From eb0d6696d7a2e0a532816b29b4eb47e042d8e8e7 Mon Sep 17 00:00:00 2001 From: Phodal Huang Date: Tue, 9 Dec 2025 22:35:08 +0800 Subject: [PATCH 01/34] feat(chatdb): add ChatDB agent for text-to-SQL interactions - Add CHAT_DB agent type to AgentType enum - Create ChatDBPage with split layout (DataSourcePanel + ChatDBChatPane) - Create ChatDBViewModel for managing data sources and chat - Create DataSourcePanel for managing database connections - Create DataSourceConfigDialog for adding/editing data sources - Create ChatDBChatPane for chat interactions - Add Database and Schema icons to AutoDevComposeIcons - Update AgentInterfaceRouter to route to ChatDBPage - Update when expressions in AgentChatInterface, TopBarMenuDesktop, TopBarMenuMobile Closes #508 --- .../kotlin/cc/unitmesh/agent/AgentType.kt | 9 + .../components/header/IdeaAgentTabsHeader.kt | 2 + .../idea/toolwindow/IdeaAgentViewModel.kt | 1 + .../idea/toolwindow/IdeaComposeIcons.kt | 55 +++ .../unitmesh/devins/ui/compose/AutoDevApp.kt | 5 +- .../ui/compose/agent/AgentChatInterface.kt | 5 +- .../ui/compose/agent/AgentInterfaceRouter.kt | 29 ++ .../ui/compose/agent/chatdb/ChatDBPage.kt | 169 ++++++++++ .../compose/agent/chatdb/ChatDBViewModel.kt | 273 +++++++++++++++ .../agent/chatdb/components/ChatDBChatPane.kt | 313 ++++++++++++++++++ .../components/DataSourceConfigDialog.kt | 222 +++++++++++++ .../chatdb/components/DataSourcePanel.kt | 312 +++++++++++++++++ .../agent/chatdb/model/DataSourceModels.kt | 168 ++++++++++ .../ui/compose/chat/TopBarMenuDesktop.kt | 1 + .../ui/compose/chat/TopBarMenuMobile.kt | 2 + .../ui/compose/icons/AutoDevComposeIcons.kt | 4 + .../devins/ui/compose/state/DesktopUiState.kt | 22 ++ 17 files changed, 1589 insertions(+), 3 deletions(-) create mode 100644 mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/agent/chatdb/ChatDBPage.kt create mode 100644 mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/agent/chatdb/ChatDBViewModel.kt create mode 100644 mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/agent/chatdb/components/ChatDBChatPane.kt create mode 100644 mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/agent/chatdb/components/DataSourceConfigDialog.kt create mode 100644 mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/agent/chatdb/components/DataSourcePanel.kt create mode 100644 mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/agent/chatdb/model/DataSourceModels.kt diff --git a/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/AgentType.kt b/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/AgentType.kt index 81d5d381bc..11fee4b743 100644 --- a/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/AgentType.kt +++ b/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/AgentType.kt @@ -7,6 +7,8 @@ package cc.unitmesh.agent * - LOCAL: Simple local chat mode without heavy tooling * - CODING: Local coding agent with full tool access (file system, shell, etc.) * - CODE_REVIEW: Dedicated code review agent with git integration + * - KNOWLEDGE: Document reader mode for AI-native document reading + * - CHAT_DB: Database chat mode for text-to-SQL interactions * - REMOTE: Remote agent connected to mpp-server */ enum class AgentType { @@ -30,6 +32,11 @@ enum class AgentType { */ KNOWLEDGE, + /** + * Database chat mode - text-to-SQL agent for database queries + */ + CHAT_DB, + /** * Remote agent mode - connects to remote mpp-server for distributed execution */ @@ -40,6 +47,7 @@ enum class AgentType { CODING -> "Agentic" CODE_REVIEW -> "Review" KNOWLEDGE -> "Knowledge" + CHAT_DB -> "ChatDB" REMOTE -> "Remote" } @@ -51,6 +59,7 @@ enum class AgentType { "coding" -> CODING "codereview" -> CODE_REVIEW "documentreader", "documents" -> KNOWLEDGE + "chatdb", "database" -> CHAT_DB else -> LOCAL_CHAT } } diff --git a/mpp-idea/src/main/kotlin/cc/unitmesh/devins/idea/components/header/IdeaAgentTabsHeader.kt b/mpp-idea/src/main/kotlin/cc/unitmesh/devins/idea/components/header/IdeaAgentTabsHeader.kt index d7e0b5ba43..9a2357db57 100644 --- a/mpp-idea/src/main/kotlin/cc/unitmesh/devins/idea/components/header/IdeaAgentTabsHeader.kt +++ b/mpp-idea/src/main/kotlin/cc/unitmesh/devins/idea/components/header/IdeaAgentTabsHeader.kt @@ -256,6 +256,7 @@ private fun getAgentTypeColor(type: AgentType): Color = when (type) { AgentType.CODING -> IdeaAutoDevColors.Blue.c400 AgentType.CODE_REVIEW -> IdeaAutoDevColors.Indigo.c400 AgentType.KNOWLEDGE -> IdeaAutoDevColors.Green.c400 + AgentType.CHAT_DB -> IdeaAutoDevColors.Cyan.c400 AgentType.REMOTE -> IdeaAutoDevColors.Amber.c400 AgentType.LOCAL_CHAT -> JewelTheme.globalColors.text.normal } @@ -267,6 +268,7 @@ private fun getAgentTypeIcon(type: AgentType): ImageVector = when (type) { AgentType.CODING -> IdeaComposeIcons.Code AgentType.CODE_REVIEW -> IdeaComposeIcons.Review AgentType.KNOWLEDGE -> IdeaComposeIcons.Book + AgentType.CHAT_DB -> IdeaComposeIcons.Database AgentType.REMOTE -> IdeaComposeIcons.Cloud AgentType.LOCAL_CHAT -> IdeaComposeIcons.Chat } diff --git a/mpp-idea/src/main/kotlin/cc/unitmesh/devins/idea/toolwindow/IdeaAgentViewModel.kt b/mpp-idea/src/main/kotlin/cc/unitmesh/devins/idea/toolwindow/IdeaAgentViewModel.kt index 1b11f4149f..1c2b02edfd 100644 --- a/mpp-idea/src/main/kotlin/cc/unitmesh/devins/idea/toolwindow/IdeaAgentViewModel.kt +++ b/mpp-idea/src/main/kotlin/cc/unitmesh/devins/idea/toolwindow/IdeaAgentViewModel.kt @@ -249,6 +249,7 @@ class IdeaAgentViewModel( AgentType.CODING -> "Coding" AgentType.CODE_REVIEW -> "CodeReview" AgentType.KNOWLEDGE -> "Documents" + AgentType.CHAT_DB -> "ChatDB" } AutoDevConfigWrapper.saveAgentTypePreference(typeString) diff --git a/mpp-idea/src/main/kotlin/cc/unitmesh/devins/idea/toolwindow/IdeaComposeIcons.kt b/mpp-idea/src/main/kotlin/cc/unitmesh/devins/idea/toolwindow/IdeaComposeIcons.kt index 730f95be57..d19ad56cd0 100644 --- a/mpp-idea/src/main/kotlin/cc/unitmesh/devins/idea/toolwindow/IdeaComposeIcons.kt +++ b/mpp-idea/src/main/kotlin/cc/unitmesh/devins/idea/toolwindow/IdeaComposeIcons.kt @@ -1738,5 +1738,60 @@ object IdeaComposeIcons { }.build() } + /** + * Database icon (cylinder shape representing database storage) + */ + val Database: ImageVector by lazy { + ImageVector.Builder( + name = "Database", + defaultWidth = 24.dp, + defaultHeight = 24.dp, + viewportWidth = 24f, + viewportHeight = 24f + ).apply { + path( + fill = SolidColor(Color.Black) + ) { + // Database cylinder icon (Storage icon) + moveTo(2f, 20f) + horizontalLineToRelative(20f) + verticalLineToRelative(-4f) + horizontalLineTo(2f) + verticalLineToRelative(4f) + close() + moveTo(4f, 17f) + horizontalLineToRelative(2f) + verticalLineToRelative(2f) + horizontalLineTo(4f) + verticalLineToRelative(-2f) + close() + moveTo(2f, 4f) + verticalLineToRelative(4f) + horizontalLineToRelative(20f) + verticalLineTo(4f) + horizontalLineTo(2f) + close() + moveTo(6f, 7f) + horizontalLineTo(4f) + verticalLineTo(5f) + horizontalLineToRelative(2f) + verticalLineToRelative(2f) + close() + moveTo(2f, 14f) + horizontalLineToRelative(20f) + verticalLineToRelative(-4f) + horizontalLineTo(2f) + verticalLineToRelative(4f) + close() + moveTo(4f, 11f) + horizontalLineToRelative(2f) + verticalLineToRelative(2f) + horizontalLineTo(4f) + verticalLineToRelative(-2f) + close() + } + }.build() + } + } diff --git a/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/AutoDevApp.kt b/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/AutoDevApp.kt index 420dea03f7..ec30fbc1d0 100644 --- a/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/AutoDevApp.kt +++ b/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/AutoDevApp.kt @@ -175,7 +175,10 @@ private fun AutoDevContent( val typeString = when (type) { AgentType.REMOTE -> "Remote" AgentType.LOCAL_CHAT -> "Local" - else -> "Local" + AgentType.CODING -> "Coding" + AgentType.CODE_REVIEW -> "CodeReview" + AgentType.KNOWLEDGE -> "Documents" + AgentType.CHAT_DB -> "ChatDB" } AutoDevConfigWrapper.saveAgentTypePreference(typeString) } catch (e: Exception) { diff --git a/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/agent/AgentChatInterface.kt b/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/agent/AgentChatInterface.kt index 88d2b2d900..23097e3b23 100644 --- a/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/agent/AgentChatInterface.kt +++ b/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/agent/AgentChatInterface.kt @@ -198,8 +198,9 @@ fun AgentChatInterface( } AgentType.CODE_REVIEW, - AgentType.KNOWLEDGE -> { - // CODE_REVIEW and DOCUMENT_READER have their own full-page interfaces + AgentType.KNOWLEDGE, + AgentType.CHAT_DB -> { + // CODE_REVIEW, DOCUMENT_READER and CHAT_DB have their own full-page interfaces // They should not reach here - handled by AgentInterfaceRouter } diff --git a/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/agent/AgentInterfaceRouter.kt b/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/agent/AgentInterfaceRouter.kt index 97f13119c9..8855769002 100644 --- a/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/agent/AgentInterfaceRouter.kt +++ b/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/agent/AgentInterfaceRouter.kt @@ -1,10 +1,13 @@ package cc.unitmesh.devins.ui.compose.agent +import androidx.compose.foundation.layout.fillMaxSize import androidx.compose.runtime.Composable import androidx.compose.ui.Modifier import cc.unitmesh.agent.AgentType +import cc.unitmesh.devins.ui.compose.agent.chatdb.ChatDBPage import cc.unitmesh.devins.ui.compose.agent.codereview.CodeReviewPage import cc.unitmesh.devins.ui.remote.RemoteAgentChatInterface +import cc.unitmesh.devins.workspace.Workspace import cc.unitmesh.llm.KoogLLMService /** @@ -53,9 +56,35 @@ fun AgentInterfaceRouter( onProjectChange: (String) -> Unit = {}, onGitUrlChange: (String) -> Unit = {}, onNotification: (String, String) -> Unit = { _, _ -> }, + workspace: Workspace? = null, modifier: Modifier = Modifier ) { when (selectedAgentType) { + AgentType.CHAT_DB -> { + if (workspace != null) { + ChatDBPage( + workspace = workspace, + llmService = llmService, + modifier = modifier, + onBack = { + onAgentTypeChange(AgentType.CODING) + }, + onNotification = onNotification + ) + } else { + // Show placeholder when workspace is not available + androidx.compose.foundation.layout.Box( + modifier = modifier.fillMaxSize(), + contentAlignment = androidx.compose.ui.Alignment.Center + ) { + androidx.compose.material3.Text( + text = "Please select a workspace to use ChatDB", + style = androidx.compose.material3.MaterialTheme.typography.bodyLarge + ) + } + } + } + AgentType.KNOWLEDGE -> { cc.unitmesh.devins.ui.compose.document.DocumentReaderPage( modifier = modifier, diff --git a/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/agent/chatdb/ChatDBPage.kt b/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/agent/chatdb/ChatDBPage.kt new file mode 100644 index 0000000000..f557da592d --- /dev/null +++ b/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/agent/chatdb/ChatDBPage.kt @@ -0,0 +1,169 @@ +package cc.unitmesh.devins.ui.compose.agent.chatdb + +import androidx.compose.foundation.layout.* +import androidx.compose.material3.* +import androidx.compose.runtime.* +import androidx.compose.ui.Alignment +import androidx.compose.ui.Modifier +import androidx.compose.ui.unit.dp +import cc.unitmesh.devins.ui.compose.agent.chatdb.components.* +import cc.unitmesh.devins.ui.compose.agent.chatdb.model.ConnectionStatus +import cc.unitmesh.devins.ui.compose.icons.AutoDevComposeIcons +import cc.unitmesh.devins.workspace.Workspace +import cc.unitmesh.llm.KoogLLMService +import kotlinx.coroutines.flow.collectLatest + +/** + * ChatDB Page - Main page for text-to-SQL agent + * + * Left side: Data source management panel + * Right side: Chat area for natural language to SQL queries + */ +@OptIn(ExperimentalMaterial3Api::class) +@Composable +fun ChatDBPage( + workspace: Workspace, + llmService: KoogLLMService?, + modifier: Modifier = Modifier, + onBack: () -> Unit, + onNotification: (String, String) -> Unit = { _, _ -> } +) { + val viewModel = remember { ChatDBViewModel(workspace) } + val state = viewModel.state + + // Collect notifications + LaunchedEffect(viewModel) { + viewModel.notificationEvent.collectLatest { (title, message) -> + onNotification(title, message) + } + } + + // Cleanup on dispose + DisposableEffect(viewModel) { + onDispose { + viewModel.dispose() + } + } + + Scaffold( + topBar = { + TopAppBar( + title = { Text("ChatDB") }, + navigationIcon = { + IconButton(onClick = onBack) { + Icon(AutoDevComposeIcons.ArrowBack, contentDescription = "Back") + } + }, + actions = { + // Schema info button when connected + if (state.connectionStatus is ConnectionStatus.Connected) { + var showSchemaDialog by remember { mutableStateOf(false) } + IconButton(onClick = { showSchemaDialog = true }) { + Icon(AutoDevComposeIcons.Schema, contentDescription = "View Schema") + } + if (showSchemaDialog) { + SchemaInfoDialog( + schema = viewModel.getSchema(), + onDismiss = { showSchemaDialog = false } + ) + } + } + } + ) + }, + modifier = modifier + ) { paddingValues -> + Row( + modifier = Modifier + .fillMaxSize() + .padding(paddingValues) + ) { + // Left panel - Data source management + DataSourcePanel( + dataSources = state.filteredDataSources, + selectedDataSourceId = state.selectedDataSourceId, + connectionStatus = state.connectionStatus, + filterQuery = state.filterQuery, + onFilterChange = viewModel::setFilterQuery, + onSelectDataSource = viewModel::selectDataSource, + onAddClick = viewModel::openAddDialog, + onEditClick = viewModel::openEditDialog, + onDeleteClick = viewModel::deleteDataSource, + onConnectClick = viewModel::connect, + onDisconnectClick = viewModel::disconnect, + modifier = Modifier.width(280.dp) + ) + + VerticalDivider() + + // Right panel - Chat area + ChatDBChatPane( + renderer = viewModel.renderer, + connectionStatus = state.connectionStatus, + schema = viewModel.getSchema(), + isGenerating = viewModel.isGenerating, + onSendMessage = viewModel::sendMessage, + onStopGeneration = viewModel::stopGeneration, + modifier = Modifier.weight(1f) + ) + } + + // Config dialog + if (state.isConfigDialogOpen) { + DataSourceConfigDialog( + existingConfig = state.editingDataSource, + onDismiss = viewModel::closeConfigDialog, + onSave = { config -> + if (state.editingDataSource != null) { + viewModel.updateDataSource(config) + } else { + viewModel.addDataSource(config) + } + } + ) + } + } +} + +@Composable +private fun SchemaInfoDialog( + schema: cc.unitmesh.agent.database.DatabaseSchema?, + onDismiss: () -> Unit +) { + AlertDialog( + onDismissRequest = onDismiss, + title = { Text("Database Schema") }, + text = { + if (schema != null) { + Column { + Text( + text = "${schema.tables.size} tables", + style = MaterialTheme.typography.labelMedium + ) + Spacer(modifier = Modifier.height(8.dp)) + schema.tables.take(10).forEach { table -> + Text( + text = "โ€ข ${table.name} (${table.columns.size} columns)", + style = MaterialTheme.typography.bodySmall + ) + } + if (schema.tables.size > 10) { + Text( + text = "... and ${schema.tables.size - 10} more", + style = MaterialTheme.typography.bodySmall, + color = MaterialTheme.colorScheme.onSurfaceVariant + ) + } + } + } else { + Text("No schema available") + } + }, + confirmButton = { + TextButton(onClick = onDismiss) { + Text("Close") + } + } + ) +} + diff --git a/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/agent/chatdb/ChatDBViewModel.kt b/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/agent/chatdb/ChatDBViewModel.kt new file mode 100644 index 0000000000..c49301d4ad --- /dev/null +++ b/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/agent/chatdb/ChatDBViewModel.kt @@ -0,0 +1,273 @@ +package cc.unitmesh.devins.ui.compose.agent.chatdb + +import androidx.compose.runtime.getValue +import androidx.compose.runtime.mutableStateOf +import androidx.compose.runtime.setValue +import cc.unitmesh.agent.database.DatabaseConnection +import cc.unitmesh.agent.database.DatabaseSchema +import cc.unitmesh.agent.database.createDatabaseConnection +import cc.unitmesh.config.ConfigManager +import cc.unitmesh.devins.ui.compose.agent.ComposeRenderer +import cc.unitmesh.devins.ui.compose.agent.chatdb.model.* +import cc.unitmesh.devins.workspace.Workspace +import cc.unitmesh.llm.KoogLLMService +import kotlinx.coroutines.* +import kotlinx.coroutines.flow.MutableSharedFlow +import kotlinx.coroutines.flow.asSharedFlow +import kotlin.uuid.ExperimentalUuidApi +import kotlin.uuid.Uuid + +/** + * ViewModel for ChatDB Page + * + * Manages data sources, database connections, and chat interactions for text-to-SQL. + */ +class ChatDBViewModel( + private val workspace: Workspace +) { + private val scope = CoroutineScope(SupervisorJob() + Dispatchers.Default) + + val renderer = ComposeRenderer() + + // LLM Service + private var llmService: KoogLLMService? = null + private var currentExecutionJob: Job? = null + + // Database connection + private var currentConnection: DatabaseConnection? = null + private var currentSchema: DatabaseSchema? = null + + // UI State + var state by mutableStateOf(ChatDBState()) + private set + + var isGenerating by mutableStateOf(false) + private set + + // Notifications + private val _notificationEvent = MutableSharedFlow>() + val notificationEvent = _notificationEvent.asSharedFlow() + + init { + initializeLLMService() + loadDataSources() + } + + private fun initializeLLMService() { + scope.launch { + try { + val configWrapper = ConfigManager.load() + val modelConfig = configWrapper.getActiveModelConfig() + if (modelConfig != null && modelConfig.isValid()) { + llmService = KoogLLMService.create(modelConfig) + } + } catch (e: Exception) { + println("[ChatDB] Failed to initialize LLM service: ${e.message}") + } + } + } + + private fun loadDataSources() { + // TODO: Load from persistent storage + // For now, use empty list + state = state.copy(dataSources = emptyList()) + } + + @OptIn(ExperimentalUuidApi::class) + fun addDataSource(config: DataSourceConfig) { + val newConfig = config.copy( + id = if (config.id.isBlank()) Uuid.random().toString() else config.id, + createdAt = kotlinx.datetime.Clock.System.now().toEpochMilliseconds(), + updatedAt = kotlinx.datetime.Clock.System.now().toEpochMilliseconds() + ) + state = state.copy( + dataSources = state.dataSources + newConfig, + isConfigDialogOpen = false, + editingDataSource = null + ) + saveDataSources() + } + + fun updateDataSource(config: DataSourceConfig) { + val updated = config.copy( + updatedAt = kotlinx.datetime.Clock.System.now().toEpochMilliseconds() + ) + state = state.copy( + dataSources = state.dataSources.map { if (it.id == config.id) updated else it }, + isConfigDialogOpen = false, + editingDataSource = null + ) + saveDataSources() + } + + fun deleteDataSource(id: String) { + state = state.copy( + dataSources = state.dataSources.filter { it.id != id }, + selectedDataSourceId = if (state.selectedDataSourceId == id) null else state.selectedDataSourceId + ) + if (state.selectedDataSourceId == id) { + disconnect() + } + saveDataSources() + } + + private fun saveDataSources() { + // TODO: Persist to storage + } + + fun selectDataSource(id: String) { + if (state.selectedDataSourceId == id) return + + disconnect() + state = state.copy(selectedDataSourceId = id) + } + + fun setFilterQuery(query: String) { + state = state.copy(filterQuery = query) + } + + fun openAddDialog() { + state = state.copy(isConfigDialogOpen = true, editingDataSource = null) + } + + fun openEditDialog(config: DataSourceConfig) { + state = state.copy(isConfigDialogOpen = true, editingDataSource = config) + } + + fun closeConfigDialog() { + state = state.copy(isConfigDialogOpen = false, editingDataSource = null) + } + + fun connect() { + val dataSource = state.selectedDataSource ?: return + + scope.launch { + state = state.copy(connectionStatus = ConnectionStatus.Connecting) + + try { + val connection = createDatabaseConnection(dataSource.toDatabaseConfig()) + if (connection.isConnected()) { + currentConnection = connection + currentSchema = connection.getSchema() + state = state.copy(connectionStatus = ConnectionStatus.Connected) + _notificationEvent.emit("Connected" to "Successfully connected to ${dataSource.name}") + } else { + state = state.copy(connectionStatus = ConnectionStatus.Error("Failed to connect")) + } + } catch (e: Exception) { + state = state.copy(connectionStatus = ConnectionStatus.Error(e.message ?: "Unknown error")) + _notificationEvent.emit("Connection Failed" to (e.message ?: "Unknown error")) + } + } + } + + fun disconnect() { + scope.launch { + try { + currentConnection?.close() + } catch (e: Exception) { + println("[ChatDB] Error closing connection: ${e.message}") + } + currentConnection = null + currentSchema = null + state = state.copy(connectionStatus = ConnectionStatus.Disconnected) + } + } + + fun sendMessage(text: String) { + if (isGenerating || text.isBlank()) return + + currentExecutionJob = scope.launch { + isGenerating = true + renderer.addUserMessage(text) + + try { + val service = llmService + if (service == null) { + renderer.renderError("LLM service not initialized. Please configure your model settings.") + isGenerating = false + return@launch + } + + val schemaContext = currentSchema?.getDescription() ?: "No database connected" + val systemPrompt = buildSystemPrompt(schemaContext) + + val response = StringBuilder() + renderer.renderLLMResponseStart() + + service.streamPrompt( + userPrompt = "$systemPrompt\n\nUser: $text", + compileDevIns = false + ).collect { chunk -> + response.append(chunk) + renderer.renderLLMResponseChunk(chunk) + } + + renderer.renderLLMResponseEnd() + + // Try to extract and execute SQL if present + extractAndExecuteSQL(response.toString()) + + } catch (e: CancellationException) { + renderer.forceStop() + renderer.renderError("Generation cancelled") + } catch (e: Exception) { + renderer.renderError("Error: ${e.message}") + } finally { + isGenerating = false + currentExecutionJob = null + } + } + } + + private fun buildSystemPrompt(schemaContext: String): String { + return """You are a helpful SQL assistant. You help users write SQL queries based on their natural language questions. + +## Database Schema +$schemaContext + +## Instructions +1. Analyze the user's question and understand what data they need +2. Generate the appropriate SQL query +3. Wrap SQL queries in ```sql code blocks +4. Explain your query briefly + +## Rules +- Only generate SELECT queries for safety +- Always use proper table and column names from the schema +- If you're unsure about the schema, ask for clarification +""" + } + + private suspend fun extractAndExecuteSQL(response: String) { + val sqlPattern = Regex("```sql\\n([\\s\\S]*?)```", RegexOption.IGNORE_CASE) + val match = sqlPattern.find(response) + + if (match != null && currentConnection != null) { + val sql = match.groupValues[1].trim() + try { + val result = currentConnection!!.executeQuery(sql) + // Display query result as a new message + renderer.renderLLMResponseStart() + renderer.renderLLMResponseChunk("\n\n**Query Result:**\n```\n${result.toTableString()}\n```") + renderer.renderLLMResponseEnd() + } catch (e: Exception) { + renderer.renderError("Query Error: ${e.message}") + } + } + } + + fun stopGeneration() { + currentExecutionJob?.cancel() + isGenerating = false + } + + fun getSchema(): DatabaseSchema? = currentSchema + + fun dispose() { + stopGeneration() + disconnect() + scope.cancel() + } +} + diff --git a/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/agent/chatdb/components/ChatDBChatPane.kt b/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/agent/chatdb/components/ChatDBChatPane.kt new file mode 100644 index 0000000000..4bcfbf6492 --- /dev/null +++ b/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/agent/chatdb/components/ChatDBChatPane.kt @@ -0,0 +1,313 @@ +package cc.unitmesh.devins.ui.compose.agent.chatdb.components + +import androidx.compose.foundation.BorderStroke +import androidx.compose.foundation.layout.* +import androidx.compose.foundation.shape.RoundedCornerShape +import androidx.compose.foundation.text.BasicTextField +import androidx.compose.foundation.text.KeyboardActions +import androidx.compose.foundation.text.KeyboardOptions +import androidx.compose.material3.* +import androidx.compose.runtime.* +import androidx.compose.ui.Alignment +import androidx.compose.ui.Modifier +import androidx.compose.ui.graphics.SolidColor +import androidx.compose.ui.input.key.* +import androidx.compose.ui.platform.LocalFocusManager +import androidx.compose.ui.text.TextStyle +import androidx.compose.ui.text.font.FontFamily +import androidx.compose.ui.text.input.ImeAction +import androidx.compose.ui.text.input.KeyboardType +import androidx.compose.ui.text.input.TextFieldValue +import androidx.compose.ui.unit.dp +import androidx.compose.ui.unit.sp +import cc.unitmesh.agent.Platform +import cc.unitmesh.agent.database.DatabaseSchema +import cc.unitmesh.devins.ui.compose.agent.AgentMessageList +import cc.unitmesh.devins.ui.compose.agent.ComposeRenderer +import cc.unitmesh.devins.ui.compose.agent.chatdb.model.ConnectionStatus +import cc.unitmesh.devins.ui.compose.icons.AutoDevComposeIcons + +/** + * Chat pane for ChatDB - right side chat area for text-to-SQL + */ +@Composable +fun ChatDBChatPane( + renderer: ComposeRenderer, + connectionStatus: ConnectionStatus, + schema: DatabaseSchema?, + isGenerating: Boolean, + onSendMessage: (String) -> Unit, + onStopGeneration: () -> Unit, + modifier: Modifier = Modifier +) { + Column(modifier = modifier.fillMaxSize()) { + // Connection status banner + ConnectionStatusBanner( + connectionStatus = connectionStatus, + schema = schema + ) + + // Message list + Box(modifier = Modifier.weight(1f)) { + AgentMessageList( + renderer = renderer, + modifier = Modifier.fillMaxSize() + ) + + // Welcome message when no messages + if (renderer.timeline.isEmpty()) { + WelcomeMessage( + isConnected = connectionStatus is ConnectionStatus.Connected, + schema = schema, + onQuickQuery = onSendMessage, + modifier = Modifier.fillMaxSize() + ) + } + } + + HorizontalDivider() + + // Input area + ChatInputArea( + isGenerating = isGenerating, + isConnected = connectionStatus is ConnectionStatus.Connected, + onSendMessage = onSendMessage, + onStopGeneration = onStopGeneration + ) + } +} + +@Composable +private fun ConnectionStatusBanner( + connectionStatus: ConnectionStatus, + schema: DatabaseSchema? +) { + when (connectionStatus) { + is ConnectionStatus.Connected -> { + Card( + modifier = Modifier + .fillMaxWidth() + .padding(horizontal = 12.dp, vertical = 8.dp), + colors = CardDefaults.cardColors( + containerColor = MaterialTheme.colorScheme.primaryContainer.copy(alpha = 0.3f) + ) + ) { + Row( + modifier = Modifier.padding(12.dp), + horizontalArrangement = Arrangement.spacedBy(12.dp), + verticalAlignment = Alignment.CenterVertically + ) { + Icon( + AutoDevComposeIcons.CheckCircle, + contentDescription = null, + modifier = Modifier.size(20.dp), + tint = MaterialTheme.colorScheme.primary + ) + Column(modifier = Modifier.weight(1f)) { + Text( + text = "Connected", + style = MaterialTheme.typography.labelLarge, + color = MaterialTheme.colorScheme.primary + ) + if (schema != null) { + Text( + text = "${schema.tables.size} tables available", + style = MaterialTheme.typography.bodySmall, + color = MaterialTheme.colorScheme.onSurfaceVariant + ) + } + } + } + } + } + is ConnectionStatus.Disconnected -> { + Card( + modifier = Modifier + .fillMaxWidth() + .padding(horizontal = 12.dp, vertical = 8.dp), + colors = CardDefaults.cardColors( + containerColor = MaterialTheme.colorScheme.surfaceVariant.copy(alpha = 0.5f) + ) + ) { + Row( + modifier = Modifier.padding(12.dp), + horizontalArrangement = Arrangement.spacedBy(12.dp), + verticalAlignment = Alignment.CenterVertically + ) { + Icon( + AutoDevComposeIcons.CloudOff, + contentDescription = null, + modifier = Modifier.size(20.dp), + tint = MaterialTheme.colorScheme.onSurfaceVariant + ) + Text( + text = "Not connected. Select a data source and connect.", + style = MaterialTheme.typography.bodySmall, + color = MaterialTheme.colorScheme.onSurfaceVariant + ) + } + } + } + else -> { /* Connecting or Error handled elsewhere */ } + } +} + +@Composable +private fun WelcomeMessage( + isConnected: Boolean, + schema: DatabaseSchema?, + onQuickQuery: (String) -> Unit, + modifier: Modifier = Modifier +) { + Column( + modifier = modifier + .fillMaxSize() + .padding(24.dp), + verticalArrangement = Arrangement.Center, + horizontalAlignment = Alignment.CenterHorizontally + ) { + Icon( + AutoDevComposeIcons.Database, + contentDescription = null, + modifier = Modifier.size(64.dp), + tint = MaterialTheme.colorScheme.primary.copy(alpha = 0.6f) + ) + + Spacer(modifier = Modifier.height(16.dp)) + + Text( + text = "ChatDB - Text to SQL", + style = MaterialTheme.typography.headlineSmall, + color = MaterialTheme.colorScheme.onSurface + ) + + Spacer(modifier = Modifier.height(8.dp)) + + Text( + text = if (isConnected) { + "Ask questions about your data in natural language" + } else { + "Connect to a database to start querying" + }, + style = MaterialTheme.typography.bodyMedium, + color = MaterialTheme.colorScheme.onSurfaceVariant + ) + + if (isConnected && schema != null) { + Spacer(modifier = Modifier.height(24.dp)) + + Text( + text = "Try asking:", + style = MaterialTheme.typography.labelMedium, + color = MaterialTheme.colorScheme.onSurfaceVariant + ) + + Spacer(modifier = Modifier.height(8.dp)) + + val suggestions = listOf( + "Show me all tables in the database", + "What are the columns in the ${schema.tables.firstOrNull()?.name ?: "users"} table?", + "Count the total number of records" + ) + + suggestions.forEach { suggestion -> + SuggestionChip( + onClick = { onQuickQuery(suggestion) }, + label = { Text(suggestion, style = MaterialTheme.typography.bodySmall) }, + modifier = Modifier.padding(vertical = 4.dp) + ) + } + } + } +} + +@Composable +private fun ChatInputArea( + isGenerating: Boolean, + isConnected: Boolean, + onSendMessage: (String) -> Unit, + onStopGeneration: () -> Unit +) { + var inputText by remember { mutableStateOf(TextFieldValue("")) } + val focusManager = LocalFocusManager.current + + fun sendMessage() { + if (inputText.text.isNotBlank() && !isGenerating && isConnected) { + onSendMessage(inputText.text) + inputText = TextFieldValue("") + } + } + + Surface( + modifier = Modifier.fillMaxWidth(), + color = MaterialTheme.colorScheme.surface + ) { + Row( + modifier = Modifier + .padding(12.dp) + .fillMaxWidth(), + verticalAlignment = Alignment.Bottom, + horizontalArrangement = Arrangement.spacedBy(8.dp) + ) { + Surface( + modifier = Modifier.weight(1f), + shape = RoundedCornerShape(12.dp), + border = BorderStroke(1.dp, MaterialTheme.colorScheme.outline.copy(alpha = 0.3f)), + color = MaterialTheme.colorScheme.surfaceVariant.copy(alpha = 0.3f) + ) { + BasicTextField( + value = inputText, + onValueChange = { inputText = it }, + modifier = Modifier + .fillMaxWidth() + .padding(12.dp) + .onKeyEvent { event -> + if (event.type == KeyEventType.KeyDown && event.key == Key.Enter) { + if (!event.isShiftPressed) { + sendMessage() + true + } else false + } else false + }, + textStyle = TextStyle( + color = MaterialTheme.colorScheme.onSurface, + fontSize = 14.sp + ), + cursorBrush = SolidColor(MaterialTheme.colorScheme.primary), + enabled = isConnected, + decorationBox = { innerTextField -> + Box { + if (inputText.text.isEmpty()) { + Text( + text = if (isConnected) "Ask a question about your data..." else "Connect to a database first", + style = MaterialTheme.typography.bodyMedium, + color = MaterialTheme.colorScheme.onSurfaceVariant.copy(alpha = 0.6f) + ) + } + innerTextField() + } + } + ) + } + + if (isGenerating) { + FilledIconButton( + onClick = onStopGeneration, + colors = IconButtonDefaults.filledIconButtonColors( + containerColor = MaterialTheme.colorScheme.error + ) + ) { + Icon(AutoDevComposeIcons.Stop, contentDescription = "Stop") + } + } else { + FilledIconButton( + onClick = { sendMessage() }, + enabled = inputText.text.isNotBlank() && isConnected + ) { + Icon(AutoDevComposeIcons.Send, contentDescription = "Send") + } + } + } + } +} + diff --git a/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/agent/chatdb/components/DataSourceConfigDialog.kt b/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/agent/chatdb/components/DataSourceConfigDialog.kt new file mode 100644 index 0000000000..dd5a602402 --- /dev/null +++ b/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/agent/chatdb/components/DataSourceConfigDialog.kt @@ -0,0 +1,222 @@ +package cc.unitmesh.devins.ui.compose.agent.chatdb.components + +import androidx.compose.foundation.layout.* +import androidx.compose.foundation.rememberScrollState +import androidx.compose.foundation.shape.RoundedCornerShape +import androidx.compose.foundation.text.KeyboardOptions +import androidx.compose.foundation.verticalScroll +import androidx.compose.material3.* +import androidx.compose.runtime.* +import androidx.compose.ui.Alignment +import androidx.compose.ui.Modifier +import androidx.compose.ui.text.input.KeyboardType +import androidx.compose.ui.text.input.PasswordVisualTransformation +import androidx.compose.ui.text.input.VisualTransformation +import androidx.compose.ui.unit.dp +import androidx.compose.ui.window.Dialog +import cc.unitmesh.devins.ui.compose.agent.chatdb.model.* +import cc.unitmesh.devins.ui.compose.icons.AutoDevComposeIcons + +/** + * Dialog for adding/editing data source configuration + */ +@OptIn(ExperimentalMaterial3Api::class) +@Composable +fun DataSourceConfigDialog( + existingConfig: DataSourceConfig?, + onDismiss: () -> Unit, + onSave: (DataSourceConfig) -> Unit +) { + var name by remember { mutableStateOf(existingConfig?.name ?: "") } + var dialect by remember { mutableStateOf(existingConfig?.dialect ?: DatabaseDialect.MYSQL) } + var host by remember { mutableStateOf(existingConfig?.host ?: "localhost") } + var port by remember { mutableStateOf(existingConfig?.port?.toString() ?: "3306") } + var database by remember { mutableStateOf(existingConfig?.database ?: "") } + var username by remember { mutableStateOf(existingConfig?.username ?: "") } + var password by remember { mutableStateOf(existingConfig?.password ?: "") } + var description by remember { mutableStateOf(existingConfig?.description ?: "") } + var showPassword by remember { mutableStateOf(false) } + var dialectExpanded by remember { mutableStateOf(false) } + + val isEditing = existingConfig != null + val scrollState = rememberScrollState() + + Dialog(onDismissRequest = onDismiss) { + Card( + modifier = Modifier + .fillMaxWidth() + .padding(16.dp), + shape = RoundedCornerShape(16.dp) + ) { + Column( + modifier = Modifier + .padding(24.dp) + .verticalScroll(scrollState) + ) { + Text( + text = if (isEditing) "Edit Data Source" else "Add Data Source", + style = MaterialTheme.typography.headlineSmall + ) + + Spacer(modifier = Modifier.height(24.dp)) + + // Name + OutlinedTextField( + value = name, + onValueChange = { name = it }, + label = { Text("Name *") }, + modifier = Modifier.fillMaxWidth(), + singleLine = true + ) + + Spacer(modifier = Modifier.height(16.dp)) + + // Dialect dropdown + ExposedDropdownMenuBox( + expanded = dialectExpanded, + onExpandedChange = { dialectExpanded = it } + ) { + OutlinedTextField( + value = dialect.displayName, + onValueChange = {}, + readOnly = true, + label = { Text("Database Type") }, + trailingIcon = { ExposedDropdownMenuDefaults.TrailingIcon(expanded = dialectExpanded) }, + modifier = Modifier.fillMaxWidth().menuAnchor() + ) + ExposedDropdownMenu( + expanded = dialectExpanded, + onDismissRequest = { dialectExpanded = false } + ) { + DatabaseDialect.entries.forEach { option -> + DropdownMenuItem( + text = { Text(option.displayName) }, + onClick = { + dialect = option + port = option.defaultPort.toString() + dialectExpanded = false + } + ) + } + } + } + + Spacer(modifier = Modifier.height(16.dp)) + + // Host and Port + if (dialect != DatabaseDialect.SQLITE) { + Row( + modifier = Modifier.fillMaxWidth(), + horizontalArrangement = Arrangement.spacedBy(12.dp) + ) { + OutlinedTextField( + value = host, + onValueChange = { host = it }, + label = { Text("Host *") }, + modifier = Modifier.weight(2f), + singleLine = true + ) + OutlinedTextField( + value = port, + onValueChange = { port = it.filter { c -> c.isDigit() } }, + label = { Text("Port *") }, + modifier = Modifier.weight(1f), + singleLine = true, + keyboardOptions = KeyboardOptions(keyboardType = KeyboardType.Number) + ) + } + Spacer(modifier = Modifier.height(16.dp)) + } + + // Database name + OutlinedTextField( + value = database, + onValueChange = { database = it }, + label = { Text(if (dialect == DatabaseDialect.SQLITE) "File Path *" else "Database *") }, + modifier = Modifier.fillMaxWidth(), + singleLine = true + ) + + if (dialect != DatabaseDialect.SQLITE) { + Spacer(modifier = Modifier.height(16.dp)) + + // Username + OutlinedTextField( + value = username, + onValueChange = { username = it }, + label = { Text("Username") }, + modifier = Modifier.fillMaxWidth(), + singleLine = true + ) + + Spacer(modifier = Modifier.height(16.dp)) + + // Password + OutlinedTextField( + value = password, + onValueChange = { password = it }, + label = { Text("Password") }, + modifier = Modifier.fillMaxWidth(), + singleLine = true, + visualTransformation = if (showPassword) VisualTransformation.None else PasswordVisualTransformation(), + trailingIcon = { + IconButton(onClick = { showPassword = !showPassword }) { + Icon( + if (showPassword) AutoDevComposeIcons.VisibilityOff else AutoDevComposeIcons.Visibility, + contentDescription = if (showPassword) "Hide" else "Show" + ) + } + } + ) + } + + Spacer(modifier = Modifier.height(16.dp)) + + // Description + OutlinedTextField( + value = description, + onValueChange = { description = it }, + label = { Text("Description") }, + modifier = Modifier.fillMaxWidth(), + minLines = 2, + maxLines = 3 + ) + + Spacer(modifier = Modifier.height(24.dp)) + + // Buttons + Row( + modifier = Modifier.fillMaxWidth(), + horizontalArrangement = Arrangement.End, + verticalAlignment = Alignment.CenterVertically + ) { + TextButton(onClick = onDismiss) { + Text("Cancel") + } + Spacer(modifier = Modifier.width(8.dp)) + Button( + onClick = { + val config = DataSourceConfig( + id = existingConfig?.id ?: "", + name = name.trim(), + dialect = dialect, + host = host.trim(), + port = port.toIntOrNull() ?: dialect.defaultPort, + database = database.trim(), + username = username.trim(), + password = password, + description = description.trim() + ) + onSave(config) + }, + enabled = name.isNotBlank() && database.isNotBlank() && + (dialect == DatabaseDialect.SQLITE || host.isNotBlank()) + ) { + Text(if (isEditing) "Save" else "Add") + } + } + } + } + } +} + diff --git a/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/agent/chatdb/components/DataSourcePanel.kt b/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/agent/chatdb/components/DataSourcePanel.kt new file mode 100644 index 0000000000..f0c444dd5a --- /dev/null +++ b/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/agent/chatdb/components/DataSourcePanel.kt @@ -0,0 +1,312 @@ +package cc.unitmesh.devins.ui.compose.agent.chatdb.components + +import androidx.compose.foundation.background +import androidx.compose.foundation.clickable +import androidx.compose.foundation.layout.* +import androidx.compose.foundation.lazy.LazyColumn +import androidx.compose.foundation.lazy.items +import androidx.compose.foundation.shape.CircleShape +import androidx.compose.foundation.shape.RoundedCornerShape +import androidx.compose.material3.* +import androidx.compose.runtime.* +import androidx.compose.ui.Alignment +import androidx.compose.ui.Modifier +import androidx.compose.ui.draw.clip +import androidx.compose.ui.text.style.TextOverflow +import androidx.compose.ui.unit.dp +import cc.unitmesh.devins.ui.compose.agent.chatdb.model.* +import cc.unitmesh.devins.ui.compose.icons.AutoDevComposeIcons + +/** + * Data Source Panel - Left side panel for managing database connections + */ +@Composable +fun DataSourcePanel( + dataSources: List, + selectedDataSourceId: String?, + connectionStatus: ConnectionStatus, + filterQuery: String, + onFilterChange: (String) -> Unit, + onSelectDataSource: (String) -> Unit, + onAddClick: () -> Unit, + onEditClick: (DataSourceConfig) -> Unit, + onDeleteClick: (String) -> Unit, + onConnectClick: () -> Unit, + onDisconnectClick: () -> Unit, + modifier: Modifier = Modifier +) { + Column( + modifier = modifier + .fillMaxHeight() + .background(MaterialTheme.colorScheme.surfaceVariant.copy(alpha = 0.3f)) + ) { + // Header with Add button + DataSourceHeader(onAddClick = onAddClick) + + HorizontalDivider() + + // Search/Filter + SearchField( + query = filterQuery, + onQueryChange = onFilterChange, + modifier = Modifier.padding(8.dp) + ) + + // Data source list + LazyColumn( + modifier = Modifier.weight(1f), + contentPadding = PaddingValues(vertical = 4.dp) + ) { + items(dataSources, key = { it.id }) { dataSource -> + DataSourceItem( + dataSource = dataSource, + isSelected = dataSource.id == selectedDataSourceId, + connectionStatus = if (dataSource.id == selectedDataSourceId) connectionStatus else ConnectionStatus.Disconnected, + onClick = { onSelectDataSource(dataSource.id) }, + onEditClick = { onEditClick(dataSource) }, + onDeleteClick = { onDeleteClick(dataSource.id) } + ) + } + } + + // Connection controls + if (selectedDataSourceId != null) { + HorizontalDivider() + ConnectionControls( + connectionStatus = connectionStatus, + onConnect = onConnectClick, + onDisconnect = onDisconnectClick + ) + } + } +} + +@Composable +private fun DataSourceHeader(onAddClick: () -> Unit) { + Row( + modifier = Modifier + .fillMaxWidth() + .padding(12.dp), + horizontalArrangement = Arrangement.SpaceBetween, + verticalAlignment = Alignment.CenterVertically + ) { + Text( + text = "Data Sources", + style = MaterialTheme.typography.titleMedium + ) + IconButton( + onClick = onAddClick, + modifier = Modifier.size(32.dp) + ) { + Icon( + AutoDevComposeIcons.Add, + contentDescription = "Add data source", + modifier = Modifier.size(20.dp) + ) + } + } +} + +@Composable +private fun SearchField( + query: String, + onQueryChange: (String) -> Unit, + modifier: Modifier = Modifier +) { + OutlinedTextField( + value = query, + onValueChange = onQueryChange, + modifier = modifier.fillMaxWidth(), + placeholder = { Text("Search...", style = MaterialTheme.typography.bodySmall) }, + leadingIcon = { + Icon( + AutoDevComposeIcons.Search, + contentDescription = null, + modifier = Modifier.size(18.dp) + ) + }, + trailingIcon = { + if (query.isNotEmpty()) { + IconButton(onClick = { onQueryChange("") }, modifier = Modifier.size(24.dp)) { + Icon(AutoDevComposeIcons.Close, contentDescription = "Clear", modifier = Modifier.size(16.dp)) + } + } + }, + singleLine = true, + shape = RoundedCornerShape(8.dp), + textStyle = MaterialTheme.typography.bodySmall, + colors = OutlinedTextFieldDefaults.colors( + focusedContainerColor = MaterialTheme.colorScheme.surface, + unfocusedContainerColor = MaterialTheme.colorScheme.surface + ) + ) +} + +@Composable +private fun DataSourceItem( + dataSource: DataSourceConfig, + isSelected: Boolean, + connectionStatus: ConnectionStatus, + onClick: () -> Unit, + onEditClick: () -> Unit, + onDeleteClick: () -> Unit +) { + var showMenu by remember { mutableStateOf(false) } + + Surface( + modifier = Modifier + .fillMaxWidth() + .padding(horizontal = 8.dp, vertical = 2.dp) + .clickable(onClick = onClick), + shape = RoundedCornerShape(8.dp), + color = if (isSelected) { + MaterialTheme.colorScheme.primaryContainer.copy(alpha = 0.5f) + } else { + MaterialTheme.colorScheme.surface + } + ) { + Row( + modifier = Modifier.padding(12.dp), + verticalAlignment = Alignment.CenterVertically + ) { + // Status indicator + Box( + modifier = Modifier + .size(8.dp) + .clip(CircleShape) + .background( + when (connectionStatus) { + is ConnectionStatus.Connected -> MaterialTheme.colorScheme.primary + is ConnectionStatus.Connecting -> MaterialTheme.colorScheme.tertiary + is ConnectionStatus.Error -> MaterialTheme.colorScheme.error + else -> MaterialTheme.colorScheme.outline.copy(alpha = 0.5f) + } + ) + ) + + Spacer(modifier = Modifier.width(12.dp)) + + Column(modifier = Modifier.weight(1f)) { + Text( + text = dataSource.name, + style = MaterialTheme.typography.bodyMedium, + maxLines = 1, + overflow = TextOverflow.Ellipsis + ) + Text( + text = dataSource.getDisplayUrl(), + style = MaterialTheme.typography.bodySmall, + color = MaterialTheme.colorScheme.onSurfaceVariant, + maxLines = 1, + overflow = TextOverflow.Ellipsis + ) + } + + Box { + IconButton( + onClick = { showMenu = true }, + modifier = Modifier.size(24.dp) + ) { + Icon( + AutoDevComposeIcons.MoreVert, + contentDescription = "More options", + modifier = Modifier.size(16.dp) + ) + } + + DropdownMenu( + expanded = showMenu, + onDismissRequest = { showMenu = false } + ) { + DropdownMenuItem( + text = { Text("Edit") }, + onClick = { + showMenu = false + onEditClick() + }, + leadingIcon = { + Icon(AutoDevComposeIcons.Edit, contentDescription = null, modifier = Modifier.size(18.dp)) + } + ) + DropdownMenuItem( + text = { Text("Delete", color = MaterialTheme.colorScheme.error) }, + onClick = { + showMenu = false + onDeleteClick() + }, + leadingIcon = { + Icon( + AutoDevComposeIcons.Delete, + contentDescription = null, + modifier = Modifier.size(18.dp), + tint = MaterialTheme.colorScheme.error + ) + } + ) + } + } + } + } +} + +@Composable +private fun ConnectionControls( + connectionStatus: ConnectionStatus, + onConnect: () -> Unit, + onDisconnect: () -> Unit +) { + Row( + modifier = Modifier + .fillMaxWidth() + .padding(12.dp), + horizontalArrangement = Arrangement.spacedBy(8.dp), + verticalAlignment = Alignment.CenterVertically + ) { + when (connectionStatus) { + is ConnectionStatus.Connected -> { + Button( + onClick = onDisconnect, + colors = ButtonDefaults.buttonColors( + containerColor = MaterialTheme.colorScheme.error + ), + modifier = Modifier.fillMaxWidth() + ) { + Text("Disconnect") + } + } + is ConnectionStatus.Connecting -> { + Button( + onClick = {}, + enabled = false, + modifier = Modifier.fillMaxWidth() + ) { + CircularProgressIndicator( + modifier = Modifier.size(16.dp), + strokeWidth = 2.dp, + color = MaterialTheme.colorScheme.onPrimary + ) + Spacer(modifier = Modifier.width(8.dp)) + Text("Connecting...") + } + } + else -> { + Button( + onClick = onConnect, + modifier = Modifier.fillMaxWidth() + ) { + Text("Connect") + } + } + } + } + + if (connectionStatus is ConnectionStatus.Error) { + Text( + text = connectionStatus.message, + style = MaterialTheme.typography.bodySmall, + color = MaterialTheme.colorScheme.error, + modifier = Modifier.padding(horizontal = 12.dp, vertical = 4.dp) + ) + } +} + diff --git a/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/agent/chatdb/model/DataSourceModels.kt b/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/agent/chatdb/model/DataSourceModels.kt new file mode 100644 index 0000000000..bd9df347fd --- /dev/null +++ b/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/agent/chatdb/model/DataSourceModels.kt @@ -0,0 +1,168 @@ +package cc.unitmesh.devins.ui.compose.agent.chatdb.model + +import kotlinx.serialization.Serializable + +/** + * Supported database dialects + */ +enum class DatabaseDialect(val displayName: String, val defaultPort: Int) { + MYSQL("MySQL", 3306), + MARIADB("MariaDB", 3306), + POSTGRESQL("PostgreSQL", 5432), + SQLITE("SQLite", 0), + ORACLE("Oracle", 1521), + SQLSERVER("SQL Server", 1433); + + companion object { + fun fromString(value: String): DatabaseDialect { + return entries.find { + it.name.equals(value, ignoreCase = true) || + it.displayName.equals(value, ignoreCase = true) + } ?: MYSQL + } + } +} + +/** + * Data source configuration for database connections + */ +@Serializable +data class DataSourceConfig( + val id: String, + val name: String, + val dialect: DatabaseDialect = DatabaseDialect.MYSQL, + val host: String = "localhost", + val port: Int = 3306, + val database: String = "", + val username: String = "", + val password: String = "", + val description: String = "", + val isDefault: Boolean = false, + val createdAt: Long = 0L, + val updatedAt: Long = 0L +) { + /** + * Get connection URL for display (without password) + */ + fun getDisplayUrl(): String { + return when (dialect) { + DatabaseDialect.SQLITE -> "sqlite://$database" + else -> "${dialect.name.lowercase()}://$host:$port/$database" + } + } + + /** + * Validate the configuration + */ + fun validate(): ValidationResult { + val errors = mutableListOf() + + if (name.isBlank()) errors.add("Name is required") + if (dialect != DatabaseDialect.SQLITE) { + if (host.isBlank()) errors.add("Host is required") + if (port <= 0 || port > 65535) errors.add("Port must be between 1 and 65535") + } + if (database.isBlank()) errors.add("Database name is required") + + return if (errors.isEmpty()) { + ValidationResult.Valid + } else { + ValidationResult.Invalid(errors) + } + } + + /** + * Convert to cc.unitmesh.agent.database.DatabaseConfig + */ + fun toDatabaseConfig(): cc.unitmesh.agent.database.DatabaseConfig { + return cc.unitmesh.agent.database.DatabaseConfig( + host = host, + port = port, + databaseName = database, + username = username, + password = password, + dialect = dialect.name + ) + } +} + +/** + * Validation result + */ +sealed class ValidationResult { + data object Valid : ValidationResult() + data class Invalid(val errors: List) : ValidationResult() +} + +/** + * Connection status for a data source + */ +sealed class ConnectionStatus { + data object Disconnected : ConnectionStatus() + data object Connecting : ConnectionStatus() + data object Connected : ConnectionStatus() + data class Error(val message: String) : ConnectionStatus() +} + +/** + * Chat message in ChatDB context + */ +@Serializable +data class ChatDBMessage( + val id: String, + val role: MessageRole, + val content: String, + val timestamp: Long, + val sql: String? = null, + val queryResult: QueryResultDisplay? = null +) + +/** + * Message role + */ +@Serializable +enum class MessageRole { + USER, + ASSISTANT, + SYSTEM +} + +/** + * Query result for display + */ +@Serializable +data class QueryResultDisplay( + val columns: List, + val rows: List>, + val rowCount: Int, + val executionTimeMs: Long = 0 +) + +/** + * UI state for ChatDB page + */ +data class ChatDBState( + val dataSources: List = emptyList(), + val selectedDataSourceId: String? = null, + val connectionStatus: ConnectionStatus = ConnectionStatus.Disconnected, + val filterQuery: String = "", + val isLoading: Boolean = false, + val error: String? = null, + val isConfigDialogOpen: Boolean = false, + val editingDataSource: DataSourceConfig? = null +) { + val selectedDataSource: DataSourceConfig? + get() = dataSources.find { it.id == selectedDataSourceId } + + val filteredDataSources: List + get() = if (filterQuery.isBlank()) { + dataSources + } else { + dataSources.filter { ds -> + ds.name.contains(filterQuery, ignoreCase = true) || + ds.database.contains(filterQuery, ignoreCase = true) || + ds.host.contains(filterQuery, ignoreCase = true) + } + } +} + diff --git a/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/chat/TopBarMenuDesktop.kt b/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/chat/TopBarMenuDesktop.kt index 1ffb24722d..7ebcb345e1 100644 --- a/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/chat/TopBarMenuDesktop.kt +++ b/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/chat/TopBarMenuDesktop.kt @@ -286,6 +286,7 @@ private fun AgentTypeTab( AgentType.KNOWLEDGE -> AutoDevComposeIcons.Article AgentType.CODING -> AutoDevComposeIcons.Code AgentType.LOCAL_CHAT -> AutoDevComposeIcons.Chat + AgentType.CHAT_DB -> AutoDevComposeIcons.Database }, contentDescription = null, tint = contentColor, diff --git a/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/chat/TopBarMenuMobile.kt b/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/chat/TopBarMenuMobile.kt index 33b1a3d4f5..b9aa2700c6 100644 --- a/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/chat/TopBarMenuMobile.kt +++ b/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/chat/TopBarMenuMobile.kt @@ -224,6 +224,7 @@ fun TopBarMenuMobile( AgentType.KNOWLEDGE -> AutoDevComposeIcons.Article AgentType.CODING -> AutoDevComposeIcons.Code AgentType.LOCAL_CHAT -> AutoDevComposeIcons.Chat + AgentType.CHAT_DB -> AutoDevComposeIcons.Database }, contentDescription = null, modifier = Modifier.size(20.dp) @@ -258,6 +259,7 @@ fun TopBarMenuMobile( AgentType.KNOWLEDGE -> AutoDevComposeIcons.Article AgentType.CODING -> AutoDevComposeIcons.Code AgentType.LOCAL_CHAT -> AutoDevComposeIcons.Chat + AgentType.CHAT_DB -> AutoDevComposeIcons.Database }, contentDescription = null, modifier = Modifier.size(20.dp) diff --git a/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/icons/AutoDevComposeIcons.kt b/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/icons/AutoDevComposeIcons.kt index 12267bc6d3..3328c35d89 100644 --- a/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/icons/AutoDevComposeIcons.kt +++ b/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/icons/AutoDevComposeIcons.kt @@ -113,6 +113,10 @@ object AutoDevComposeIcons { val Functions: ImageVector get() = Icons.Default.Functions val KeyboardArrowRight: ImageVector get() = Icons.Default.KeyboardArrowRight + // Database Icons + val Database: ImageVector get() = Icons.Default.Storage + val Schema: ImageVector get() = Icons.Default.TableChart + /** * Custom icons converted from SVG resources * These icons are converted from ai.svg and mcp.svg to Compose ImageVector format diff --git a/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/state/DesktopUiState.kt b/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/state/DesktopUiState.kt index f512226610..75b0404edf 100644 --- a/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/state/DesktopUiState.kt +++ b/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/state/DesktopUiState.kt @@ -2,13 +2,19 @@ package cc.unitmesh.devins.ui.compose.state import androidx.compose.runtime.* import cc.unitmesh.agent.AgentType +import cc.unitmesh.config.AutoDevConfigWrapper import cc.unitmesh.devins.ui.state.UIStateManager +import kotlinx.coroutines.CoroutineScope +import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.launch /** * Desktop UI State ViewModel * ็ฎก็†ๆกŒ้ข็ซฏ UI ็š„ๆ‰€ๆœ‰็Šถๆ€๏ผŒๅŒๆญฅๅ…จๅฑ€ UIStateManager */ class DesktopUiState { + private val scope = CoroutineScope(Dispatchers.Default) + // Agent Type var currentAgentType by mutableStateOf(AgentType.CODING) @@ -40,6 +46,22 @@ class DesktopUiState { // Actions fun updateAgentType(type: AgentType) { currentAgentType = type + // Save to config file for persistence + scope.launch { + try { + val typeString = when (type) { + AgentType.REMOTE -> "Remote" + AgentType.LOCAL_CHAT -> "Local" + AgentType.CODING -> "Coding" + AgentType.CODE_REVIEW -> "CodeReview" + AgentType.KNOWLEDGE -> "Documents" + AgentType.CHAT_DB -> "ChatDB" + } + AutoDevConfigWrapper.saveAgentTypePreference(typeString) + } catch (e: Exception) { + println("โš ๏ธ Failed to save agent type preference: ${e.message}") + } + } } fun toggleSessionSidebar() { From 2bd950bdba54b00812ab73cbed414800746e573f Mon Sep 17 00:00:00 2001 From: Phodal Huang Date: Tue, 9 Dec 2025 22:55:03 +0800 Subject: [PATCH 02/34] refactor(agent): simplify agent type handling and ChatDB workspace Refactor agent type string conversion to use getDisplayName() and make ChatDB workspace parameter optional. Remove SessionApp and streamline ChatDB routing logic. --- .../idea/toolwindow/IdeaAgentViewModel.kt | 10 +- .../cc/unitmesh/devins/ui/app/SessionApp.kt | 330 ------------------ .../unitmesh/devins/ui/compose/AutoDevApp.kt | 9 +- .../ui/compose/agent/AgentInterfaceRouter.kt | 32 +- .../ui/compose/agent/chatdb/ChatDBPage.kt | 16 +- .../compose/agent/chatdb/ChatDBViewModel.kt | 2 +- .../devins/ui/compose/state/DesktopUiState.kt | 9 +- 7 files changed, 22 insertions(+), 386 deletions(-) delete mode 100644 mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/app/SessionApp.kt diff --git a/mpp-idea/src/main/kotlin/cc/unitmesh/devins/idea/toolwindow/IdeaAgentViewModel.kt b/mpp-idea/src/main/kotlin/cc/unitmesh/devins/idea/toolwindow/IdeaAgentViewModel.kt index 1c2b02edfd..f4fd6e7107 100644 --- a/mpp-idea/src/main/kotlin/cc/unitmesh/devins/idea/toolwindow/IdeaAgentViewModel.kt +++ b/mpp-idea/src/main/kotlin/cc/unitmesh/devins/idea/toolwindow/IdeaAgentViewModel.kt @@ -243,15 +243,7 @@ class IdeaAgentViewModel( // Save to config file for persistence coroutineScope.launch { try { - val typeString = when (agentType) { - AgentType.REMOTE -> "Remote" - AgentType.LOCAL_CHAT -> "Local" - AgentType.CODING -> "Coding" - AgentType.CODE_REVIEW -> "CodeReview" - AgentType.KNOWLEDGE -> "Documents" - AgentType.CHAT_DB -> "ChatDB" - } - + val typeString = agentType.getDisplayName() AutoDevConfigWrapper.saveAgentTypePreference(typeString) } catch (e: Exception) { // Silently fail - not critical if we can't save preference diff --git a/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/app/SessionApp.kt b/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/app/SessionApp.kt deleted file mode 100644 index 683e75ee30..0000000000 --- a/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/app/SessionApp.kt +++ /dev/null @@ -1,330 +0,0 @@ -package cc.unitmesh.devins.ui.app - -import androidx.compose.foundation.layout.* -import androidx.compose.material3.* -import androidx.compose.runtime.* -import androidx.compose.ui.Modifier -import cc.unitmesh.devins.ui.compose.theme.AutoDevTheme -import cc.unitmesh.devins.ui.compose.theme.ThemeManager -import cc.unitmesh.devins.ui.project.ProjectListScreen -import cc.unitmesh.devins.ui.project.ProjectViewModel -import cc.unitmesh.devins.ui.remote.RemoteAgentClient -import cc.unitmesh.devins.ui.session.* -import cc.unitmesh.devins.ui.task.TaskExecutionScreen -import cc.unitmesh.devins.ui.task.TaskListScreen -import cc.unitmesh.devins.ui.task.TaskViewModel -import kotlinx.coroutines.launch - -/** - * SessionApp - ไผš่ฏ็ฎก็†ๅบ”็”จ - * ๆ”ฏๆŒ Android ๅ’Œ Desktop ๅนณๅฐ - * - * ๅŠŸ่ƒฝ๏ผš - * 1. ็”จๆˆท็™ปๅฝ•/ๆณจๅ†Œ - * 2. ้กน็›ฎ็ฎก็†๏ผˆๅˆ›ๅปบใ€ๅˆ—่กจใ€Git ๆ”ฏๆŒ๏ผ‰ - * 3. ไปปๅŠก็ฎก็†๏ผˆๅŸบไบŽ้กน็›ฎๅˆ›ๅปบไปปๅŠก๏ผ‰ - * 4. AI Agent ๆ‰ง่กŒ๏ผˆ้›†ๆˆ RemoteAgentClient๏ผ‰ - */ -@Composable -fun SessionApp( - serverUrl: String = "http://localhost:8080", - useBottomNavigation: Boolean = false // Android ไฝฟ็”จๅบ•้ƒจๅฏผ่ˆช๏ผŒDesktop ไฝฟ็”จไพง่พนๅฏผ่ˆช -) { - val currentTheme = ThemeManager.currentTheme - - AutoDevTheme(themeMode = currentTheme) { - SessionAppContent( - serverUrl = serverUrl, - useBottomNavigation = useBottomNavigation - ) - } -} - -@Composable -private fun SessionAppContent( - serverUrl: String, - useBottomNavigation: Boolean -) { - // ๅˆๅง‹ๅŒ–ๅฎขๆˆท็ซฏๅ’Œ ViewModel - val sessionClient = remember { SessionClient(serverUrl) } - val remoteAgentClient = remember { RemoteAgentClient(serverUrl) } - - val projectClient = remember { - cc.unitmesh.devins.ui.project.ProjectClient( - serverUrl, - sessionClient.httpClient - ) - } - - val sessionViewModel = remember { SessionViewModel(sessionClient) } - val projectViewModel = remember { ProjectViewModel(projectClient) } - val taskViewModel = remember { TaskViewModel(sessionClient, remoteAgentClient) } - - val isAuthenticated by sessionViewModel.isAuthenticated.collectAsState() - val currentProject by projectViewModel.currentProject.collectAsState() - val currentTask by taskViewModel.currentTask.collectAsState() - - // ็›‘ๅฌ่ฎค่ฏ็Šถๆ€๏ผŒๅŒๆญฅ token - LaunchedEffect(isAuthenticated, sessionClient.authToken) { - if (isAuthenticated && sessionClient.authToken != null) { - projectClient.setAuthToken(sessionClient.authToken!!) - } - } - - // ๅฑๅน•็Šถๆ€็ฎก็† - var currentScreen by remember { mutableStateOf(AppScreen.LOGIN) } - var skipLogin by remember { mutableStateOf(false) } - - // ็›‘ๅฌ่ฎค่ฏ็Šถๆ€ - LaunchedEffect(isAuthenticated) { - if (isAuthenticated && currentScreen == AppScreen.LOGIN) { - currentScreen = AppScreen.PROJECTS - } else if (!isAuthenticated && !skipLogin) { - currentScreen = AppScreen.LOGIN - } - } - - // ็›‘ๅฌ้กน็›ฎ้€‰ๆ‹ฉ - LaunchedEffect(currentProject) { - currentProject?.let { - taskViewModel.setCurrentProject(it) - } - } - - DisposableEffect(Unit) { - onDispose { - sessionViewModel.onCleared() - projectViewModel.onCleared() - taskViewModel.onCleared() - } - } - - when { - !isAuthenticated && !skipLogin -> { - LoginScreen( - viewModel = sessionViewModel, - onLoginSuccess = { - currentScreen = AppScreen.PROJECTS - }, - onSkipLogin = if (useBottomNavigation) { - // Android ๅนณๅฐๅ…่ฎธ่ทณ่ฟ‡็™ปๅฝ• - { - skipLogin = true - currentScreen = AppScreen.PROJECTS - } - } else null - ) - } - currentTask != null -> { - // ไปปๅŠกๆ‰ง่กŒ็•Œ้ข๏ผˆๅ…จๅฑ๏ผ‰ - TaskExecutionScreen( - viewModel = taskViewModel, - onBack = { - currentScreen = AppScreen.TASKS - } - ) - } - useBottomNavigation -> { - // Android ๅบ•้ƒจๅฏผ่ˆช - AndroidNavigationLayout( - currentScreen = currentScreen, - onScreenChange = { currentScreen = it }, - sessionViewModel = sessionViewModel, - projectViewModel = projectViewModel, - taskViewModel = taskViewModel - ) - } - else -> { - // Desktop ไพง่พนๅฏผ่ˆช - DesktopNavigationLayout( - currentScreen = currentScreen, - onScreenChange = { currentScreen = it }, - sessionViewModel = sessionViewModel, - projectViewModel = projectViewModel, - taskViewModel = taskViewModel - ) - } - } -} - -/** - * Android ๅบ•้ƒจๅฏผ่ˆชๅธƒๅฑ€ - * ไฝฟ็”จๆ–ฐ็š„ NavLayout ็ป„ไปถ๏ผŒๆ”ฏๆŒ Drawer + BottomNavigation - */ -@Composable -private fun AndroidNavigationLayout( - currentScreen: AppScreen, - onScreenChange: (AppScreen) -> Unit, - sessionViewModel: SessionViewModel, - projectViewModel: ProjectViewModel, - taskViewModel: TaskViewModel -) { - val scope = rememberCoroutineScope() - - MobileNavLayout( - currentScreen = currentScreen, - onScreenChange = onScreenChange, - sessionViewModel = sessionViewModel - ) { paddingValues -> - Box(modifier = Modifier.padding(paddingValues)) { - when (currentScreen) { - AppScreen.PROJECTS -> { - ProjectListScreen( - viewModel = projectViewModel, - onProjectClick = { project -> - projectViewModel.selectProject(project) - onScreenChange(AppScreen.TASKS) - } - ) - } - AppScreen.TASKS -> { - TaskListScreen( - viewModel = taskViewModel, - onTaskClick = { task -> - // ไปปๅŠก็‚นๅ‡ปไผš่‡ชๅŠจๅˆ‡ๆขๅˆฐ TaskExecutionScreen - } - ) - } - AppScreen.SESSIONS -> { - SessionListScreen( - viewModel = sessionViewModel, - onSessionClick = { session -> - scope.launch { - sessionViewModel.joinSession(session.id) - } - }, - onCreateSession = { - // ๅˆ›ๅปบไผš่ฏ - }, - onLogout = { - scope.launch { - sessionViewModel.logout() - } - } - ) - } - AppScreen.PROFILE -> { - ProfileScreen( - viewModel = sessionViewModel, - onLogout = { - scope.launch { - sessionViewModel.logout() - } - } - ) - } - else -> {} - } - } - } -} - -/** - * Desktop ไพง่พนๅฏผ่ˆชๅธƒๅฑ€ - * ไฝฟ็”จๆ–ฐ็š„ NavLayout ็ป„ไปถ - */ -@OptIn(ExperimentalMaterial3Api::class) -@Composable -private fun DesktopNavigationLayout( - currentScreen: AppScreen, - onScreenChange: (AppScreen) -> Unit, - sessionViewModel: SessionViewModel, - projectViewModel: ProjectViewModel, - taskViewModel: TaskViewModel -) { - val scope = rememberCoroutineScope() - - DesktopNavLayout( - currentScreen = currentScreen, - onScreenChange = onScreenChange, - sessionViewModel = sessionViewModel - ) { - when (currentScreen) { - AppScreen.PROJECTS -> { - ProjectListScreen( - viewModel = projectViewModel, - onProjectClick = { project -> - projectViewModel.selectProject(project) - onScreenChange(AppScreen.TASKS) - } - ) - } - AppScreen.TASKS -> { - TaskListScreen( - viewModel = taskViewModel, - onTaskClick = { task -> - // ไปปๅŠก็‚นๅ‡ปไผš่‡ชๅŠจๅˆ‡ๆขๅˆฐ TaskExecutionScreen - } - ) - } - AppScreen.SESSIONS -> { - SessionListScreen( - viewModel = sessionViewModel, - onSessionClick = { session -> - scope.launch { - sessionViewModel.joinSession(session.id) - } - }, - onCreateSession = { - // ๅˆ›ๅปบไผš่ฏ - }, - onLogout = { - scope.launch { - sessionViewModel.logout() - } - } - ) - } - AppScreen.PROFILE -> { - ProfileScreen( - viewModel = sessionViewModel, - onLogout = { - scope.launch { - sessionViewModel.logout() - } - } - ) - } - else -> {} - } - } -} - -/** - * SessionAppContext - Session ๅบ”็”จไธŠไธ‹ๆ–‡ - * - * ็”จไบŽ็ฎก็† Session ๅบ”็”จ็š„ๅ…จๅฑ€็Šถๆ€ๅ’Œ้…็ฝฎ - * ๅ‚่€ƒ SessionAppContext ็š„่ฎพ่ฎก๏ผŒๆไพ›็ปŸไธ€็š„ไธŠไธ‹ๆ–‡็ฎก็† - */ -data class SessionAppContext( - val serverUrl: String, - val useBottomNavigation: Boolean, - val sessionViewModel: SessionViewModel, - val projectViewModel: ProjectViewModel, - val taskViewModel: TaskViewModel, - val currentScreen: AppScreen = AppScreen.LOGIN, - val skipLogin: Boolean = false, - val isAuthenticated: Boolean = false -) { - companion object { - /** - * ๅˆ›ๅปบ้ป˜่ฎค็š„ SessionAppContext - */ - fun create( - serverUrl: String, - useBottomNavigation: Boolean, - sessionClient: SessionClient, - projectClient: cc.unitmesh.devins.ui.project.ProjectClient, - remoteAgentClient: cc.unitmesh.devins.ui.remote.RemoteAgentClient - ): SessionAppContext { - return SessionAppContext( - serverUrl = serverUrl, - useBottomNavigation = useBottomNavigation, - sessionViewModel = SessionViewModel(sessionClient), - projectViewModel = ProjectViewModel(projectClient), - taskViewModel = TaskViewModel(sessionClient, remoteAgentClient) - ) - } - } -} - diff --git a/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/AutoDevApp.kt b/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/AutoDevApp.kt index ec30fbc1d0..6a02e1bde4 100644 --- a/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/AutoDevApp.kt +++ b/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/AutoDevApp.kt @@ -172,14 +172,7 @@ private fun AutoDevContent( scope.launch { try { // Save as string for config compatibility - val typeString = when (type) { - AgentType.REMOTE -> "Remote" - AgentType.LOCAL_CHAT -> "Local" - AgentType.CODING -> "Coding" - AgentType.CODE_REVIEW -> "CodeReview" - AgentType.KNOWLEDGE -> "Documents" - AgentType.CHAT_DB -> "ChatDB" - } + val typeString = type.getDisplayName() AutoDevConfigWrapper.saveAgentTypePreference(typeString) } catch (e: Exception) { println("โš ๏ธ ไฟๅญ˜ Agent ็ฑปๅž‹ๅคฑ่ดฅ: ${e.message}") diff --git a/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/agent/AgentInterfaceRouter.kt b/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/agent/AgentInterfaceRouter.kt index 8855769002..cdd5750306 100644 --- a/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/agent/AgentInterfaceRouter.kt +++ b/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/agent/AgentInterfaceRouter.kt @@ -61,28 +61,15 @@ fun AgentInterfaceRouter( ) { when (selectedAgentType) { AgentType.CHAT_DB -> { - if (workspace != null) { - ChatDBPage( - workspace = workspace, - llmService = llmService, - modifier = modifier, - onBack = { - onAgentTypeChange(AgentType.CODING) - }, - onNotification = onNotification - ) - } else { - // Show placeholder when workspace is not available - androidx.compose.foundation.layout.Box( - modifier = modifier.fillMaxSize(), - contentAlignment = androidx.compose.ui.Alignment.Center - ) { - androidx.compose.material3.Text( - text = "Please select a workspace to use ChatDB", - style = androidx.compose.material3.MaterialTheme.typography.bodyLarge - ) - } - } + ChatDBPage( + workspace = workspace, + llmService = llmService, + modifier = modifier, + onBack = { + onAgentTypeChange(AgentType.CODING) + }, + onNotification = onNotification + ) } AgentType.KNOWLEDGE -> { @@ -142,6 +129,7 @@ fun AgentInterfaceRouter( modifier = modifier ) } + AgentType.LOCAL_CHAT, AgentType.CODING -> { AgentChatInterface( diff --git a/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/agent/chatdb/ChatDBPage.kt b/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/agent/chatdb/ChatDBPage.kt index f557da592d..ebdbc566ae 100644 --- a/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/agent/chatdb/ChatDBPage.kt +++ b/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/agent/chatdb/ChatDBPage.kt @@ -15,14 +15,14 @@ import kotlinx.coroutines.flow.collectLatest /** * ChatDB Page - Main page for text-to-SQL agent - * + * * Left side: Data source management panel * Right side: Chat area for natural language to SQL queries */ @OptIn(ExperimentalMaterial3Api::class) @Composable fun ChatDBPage( - workspace: Workspace, + workspace: Workspace? = null, llmService: KoogLLMService?, modifier: Modifier = Modifier, onBack: () -> Unit, @@ -30,21 +30,21 @@ fun ChatDBPage( ) { val viewModel = remember { ChatDBViewModel(workspace) } val state = viewModel.state - + // Collect notifications LaunchedEffect(viewModel) { viewModel.notificationEvent.collectLatest { (title, message) -> onNotification(title, message) } } - + // Cleanup on dispose DisposableEffect(viewModel) { onDispose { viewModel.dispose() } } - + Scaffold( topBar = { TopAppBar( @@ -93,9 +93,9 @@ fun ChatDBPage( onDisconnectClick = viewModel::disconnect, modifier = Modifier.width(280.dp) ) - + VerticalDivider() - + // Right panel - Chat area ChatDBChatPane( renderer = viewModel.renderer, @@ -107,7 +107,7 @@ fun ChatDBPage( modifier = Modifier.weight(1f) ) } - + // Config dialog if (state.isConfigDialogOpen) { DataSourceConfigDialog( diff --git a/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/agent/chatdb/ChatDBViewModel.kt b/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/agent/chatdb/ChatDBViewModel.kt index c49301d4ad..fd53238234 100644 --- a/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/agent/chatdb/ChatDBViewModel.kt +++ b/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/agent/chatdb/ChatDBViewModel.kt @@ -23,7 +23,7 @@ import kotlin.uuid.Uuid * Manages data sources, database connections, and chat interactions for text-to-SQL. */ class ChatDBViewModel( - private val workspace: Workspace + private val workspace: Workspace? = null ) { private val scope = CoroutineScope(SupervisorJob() + Dispatchers.Default) diff --git a/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/state/DesktopUiState.kt b/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/state/DesktopUiState.kt index 75b0404edf..b38e5b7876 100644 --- a/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/state/DesktopUiState.kt +++ b/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/state/DesktopUiState.kt @@ -49,14 +49,7 @@ class DesktopUiState { // Save to config file for persistence scope.launch { try { - val typeString = when (type) { - AgentType.REMOTE -> "Remote" - AgentType.LOCAL_CHAT -> "Local" - AgentType.CODING -> "Coding" - AgentType.CODE_REVIEW -> "CodeReview" - AgentType.KNOWLEDGE -> "Documents" - AgentType.CHAT_DB -> "ChatDB" - } + val typeString = type.getDisplayName() AutoDevConfigWrapper.saveAgentTypePreference(typeString) } catch (e: Exception) { println("โš ๏ธ Failed to save agent type preference: ${e.message}") From 277aff6942c403a1dea72a3b88bf10eb2acc9c79 Mon Sep 17 00:00:00 2001 From: Phodal Huang Date: Tue, 9 Dec 2025 22:57:23 +0800 Subject: [PATCH 03/34] feat(ui): conditionally show SessionSidebar by agent type Show SessionSidebar only for agent types that require it, hiding it on pages with their own navigation such as CHAT_DB, KNOWLEDGE, and CODE_REVIEW. --- .../unitmesh/devins/ui/compose/AutoDevApp.kt | 60 +++++++++++-------- 1 file changed, 35 insertions(+), 25 deletions(-) diff --git a/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/AutoDevApp.kt b/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/AutoDevApp.kt index 6a02e1bde4..2bb4783a31 100644 --- a/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/AutoDevApp.kt +++ b/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/AutoDevApp.kt @@ -445,37 +445,47 @@ private fun AutoDevContent( containerColor = MaterialTheme.colorScheme.background, contentWindowInsets = WindowInsets.systemBars.only(WindowInsetsSides.Horizontal) ) { paddingValues -> + // Determine if SessionSidebar should be shown based on agent type + // Hide for pages that have their own navigation (CHAT_DB, KNOWLEDGE, CODE_REVIEW) + val shouldShowSessionSidebar = selectedAgentType in listOf( + AgentType.CODING, + AgentType.LOCAL_CHAT, + AgentType.REMOTE + ) + Row( modifier = Modifier .fillMaxSize() .padding(paddingValues) ) { - SessionSidebar( - chatHistoryManager = chatHistoryManager, - currentSessionId = chatHistoryManager.getCurrentSession().id, - isExpanded = showSessionSidebar, - onSessionSelected = { sessionId -> - if (agentSessionSelectedHandler != null) { - agentSessionSelectedHandler?.invoke(sessionId) - } else { - chatHistoryManager.switchSession(sessionId) - messages = chatHistoryManager.getMessages() - currentStreamingOutput = "" - } - }, - onNewChat = { - if (agentNewChatHandler != null) { - agentNewChatHandler?.invoke() - } else { - chatHistoryManager.createSession() - messages = emptyList() - currentStreamingOutput = "" + if (shouldShowSessionSidebar) { + SessionSidebar( + chatHistoryManager = chatHistoryManager, + currentSessionId = chatHistoryManager.getCurrentSession().id, + isExpanded = showSessionSidebar, + onSessionSelected = { sessionId -> + if (agentSessionSelectedHandler != null) { + agentSessionSelectedHandler?.invoke(sessionId) + } else { + chatHistoryManager.switchSession(sessionId) + messages = chatHistoryManager.getMessages() + currentStreamingOutput = "" + } + }, + onNewChat = { + if (agentNewChatHandler != null) { + agentNewChatHandler?.invoke() + } else { + chatHistoryManager.createSession() + messages = emptyList() + currentStreamingOutput = "" + } + }, + onRenameSession = { sessionId, newTitle -> + chatHistoryManager.renameSession(sessionId, newTitle) } - }, - onRenameSession = { sessionId, newTitle -> - chatHistoryManager.renameSession(sessionId, newTitle) - } - ) + ) + } Column( modifier = Modifier From 7b25d35e415a1920f224307906d71b0bef5e92dc Mon Sep 17 00:00:00 2001 From: Phodal Huang Date: Tue, 9 Dec 2025 23:16:49 +0800 Subject: [PATCH 04/34] feat(chatdb): add ChatDB agent and CLI support #508 Introduce ChatDB agent models, executor, schema linker, and CLI entry point. Update build configuration for new modules. --- .../cc/unitmesh/agent/chatdb/ChatDBModels.kt | 144 ++++++++ .../cc/unitmesh/agent/chatdb/SchemaLinker.kt | 169 +++++++++ .../cc/unitmesh/agent/chatdb/ChatDBAgent.kt | 173 +++++++++ .../agent/chatdb/ChatDBAgentExecutor.kt | 334 ++++++++++++++++++ mpp-ui/build.gradle.kts | 41 +++ .../cc/unitmesh/server/cli/ChatDBCli.kt | 187 ++++++++++ 6 files changed, 1048 insertions(+) create mode 100644 mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/chatdb/ChatDBModels.kt create mode 100644 mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/chatdb/SchemaLinker.kt create mode 100644 mpp-core/src/jvmMain/kotlin/cc/unitmesh/agent/chatdb/ChatDBAgent.kt create mode 100644 mpp-core/src/jvmMain/kotlin/cc/unitmesh/agent/chatdb/ChatDBAgentExecutor.kt create mode 100644 mpp-ui/src/jvmMain/kotlin/cc/unitmesh/server/cli/ChatDBCli.kt diff --git a/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/chatdb/ChatDBModels.kt b/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/chatdb/ChatDBModels.kt new file mode 100644 index 0000000000..4161e130fa --- /dev/null +++ b/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/chatdb/ChatDBModels.kt @@ -0,0 +1,144 @@ +package cc.unitmesh.agent.chatdb + +import cc.unitmesh.agent.database.DatabaseSchema +import cc.unitmesh.agent.database.QueryResult +import kotlinx.serialization.Serializable + +/** + * ChatDB Task - Input for the Text2SQL Agent + * + * Contains the natural language query and database context + */ +@Serializable +data class ChatDBTask( + /** + * Natural language query from user + * e.g., "Show me the top 10 customers by total order amount" + */ + val query: String, + + /** + * Database schema for context (optional, will be fetched if not provided) + */ + val schema: DatabaseSchema? = null, + + /** + * Additional context or constraints + * e.g., "Only consider orders from 2024" + */ + val additionalContext: String = "", + + /** + * Maximum number of rows to return + */ + val maxRows: Int = 100, + + /** + * Whether to generate visualization after query + */ + val generateVisualization: Boolean = true +) + +/** + * ChatDB Result - Output from the Text2SQL Agent + */ +@Serializable +data class ChatDBResult( + /** + * Whether the query was successful + */ + val success: Boolean, + + /** + * Human-readable message + */ + val message: String, + + /** + * Generated SQL query + */ + val generatedSql: String? = null, + + /** + * Query execution result + */ + val queryResult: QueryResult? = null, + + /** + * Generated PlotDSL code for visualization (if applicable) + */ + val plotDslCode: String? = null, + + /** + * Number of revision attempts made + */ + val revisionAttempts: Int = 0, + + /** + * Errors encountered during execution + */ + val errors: List = emptyList(), + + /** + * Metadata about the execution + */ + val metadata: Map = emptyMap() +) + +/** + * Schema Linking Result - Tables and columns relevant to the query + */ +@Serializable +data class SchemaLinkingResult( + /** + * Relevant table names + */ + val relevantTables: List, + + /** + * Relevant column names (table.column format) + */ + val relevantColumns: List, + + /** + * Keywords extracted from the query + */ + val keywords: List, + + /** + * Confidence score (0.0 - 1.0) + */ + val confidence: Double = 0.0 +) + +/** + * SQL Revision Context - Context for the Revise Agent + */ +@Serializable +data class SqlRevisionContext( + /** + * Original natural language query + */ + val originalQuery: String, + + /** + * Generated SQL that failed + */ + val failedSql: String, + + /** + * Error message from execution + */ + val errorMessage: String, + + /** + * Schema context + */ + val schemaDescription: String, + + /** + * Previous revision attempts + */ + val previousAttempts: List = emptyList() +) + diff --git a/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/chatdb/SchemaLinker.kt b/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/chatdb/SchemaLinker.kt new file mode 100644 index 0000000000..6691f992be --- /dev/null +++ b/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/chatdb/SchemaLinker.kt @@ -0,0 +1,169 @@ +package cc.unitmesh.agent.chatdb + +import cc.unitmesh.agent.database.DatabaseSchema +import cc.unitmesh.agent.database.TableSchema + +/** + * Schema Linker - Keyword-based schema linking for Text2SQL + * + * This class finds relevant tables and columns based on natural language queries. + * It uses keyword matching and fuzzy matching to identify relevant schema elements. + * + * Future enhancements: + * - Vector similarity search using embeddings + * - LLM-based schema linking + * - Historical query pattern matching + */ +class SchemaLinker { + + /** + * Common SQL keywords to filter out + */ + private val stopWords = setOf( + "select", "from", "where", "and", "or", "not", "in", "is", "null", + "order", "by", "group", "having", "limit", "offset", "join", "on", + "left", "right", "inner", "outer", "cross", "union", "all", "distinct", + "as", "asc", "desc", "between", "like", "exists", "case", "when", "then", + "else", "end", "count", "sum", "avg", "min", "max", "the", "a", "an", + "show", "me", "get", "find", "list", "display", "give", "what", "which", + "how", "many", "much", "all", "each", "every", "any", "some", "most", + "top", "first", "last", "recent", "latest", "oldest", "highest", "lowest", + "total", "average", "number", "amount", "value", "data", "information" + ) + + /** + * Link natural language query to relevant schema elements + */ + fun link(query: String, schema: DatabaseSchema): SchemaLinkingResult { + val keywords = extractKeywords(query) + val relevantTables = mutableListOf() + val relevantColumns = mutableListOf() + var totalScore = 0.0 + var matchCount = 0 + + for (table in schema.tables) { + val tableScore = calculateTableRelevance(table, keywords) + if (tableScore > 0) { + relevantTables.add(table.name) + totalScore += tableScore + matchCount++ + + // Find relevant columns in this table + for (column in table.columns) { + val columnScore = calculateColumnRelevance(column.name, column.comment, keywords) + if (columnScore > 0) { + relevantColumns.add("${table.name}.${column.name}") + } + } + } + } + + // If no tables matched, include all tables (fallback) + if (relevantTables.isEmpty()) { + relevantTables.addAll(schema.tables.map { it.name }) + } + + val confidence = if (matchCount > 0) (totalScore / matchCount).coerceIn(0.0, 1.0) else 0.0 + + return SchemaLinkingResult( + relevantTables = relevantTables, + relevantColumns = relevantColumns, + keywords = keywords, + confidence = confidence + ) + } + + /** + * Extract keywords from natural language query + */ + fun extractKeywords(query: String): List { + return query.lowercase() + .replace(Regex("[^a-z0-9\\s_]"), " ") + .split(Regex("\\s+")) + .filter { it.length > 2 && it !in stopWords } + .distinct() + } + + /** + * Calculate relevance score for a table + */ + private fun calculateTableRelevance(table: TableSchema, keywords: List): Double { + var score = 0.0 + val tableName = table.name.lowercase() + val tableComment = table.comment?.lowercase() ?: "" + + for (keyword in keywords) { + // Exact match in table name + if (tableName == keyword) { + score += 1.0 + } + // Partial match in table name + else if (tableName.contains(keyword) || keyword.contains(tableName)) { + score += 0.7 + } + // Match in table comment + else if (tableComment.contains(keyword)) { + score += 0.5 + } + // Fuzzy match (Levenshtein distance) + else if (fuzzyMatch(tableName, keyword)) { + score += 0.3 + } + + // Check column names + for (column in table.columns) { + val colName = column.name.lowercase() + if (colName == keyword || colName.contains(keyword)) { + score += 0.4 + } + } + } + + return score + } + + /** + * Calculate relevance score for a column + */ + private fun calculateColumnRelevance(columnName: String, comment: String?, keywords: List): Double { + var score = 0.0 + val colName = columnName.lowercase() + val colComment = comment?.lowercase() ?: "" + + for (keyword in keywords) { + if (colName == keyword) score += 1.0 + else if (colName.contains(keyword)) score += 0.7 + else if (colComment.contains(keyword)) score += 0.5 + else if (fuzzyMatch(colName, keyword)) score += 0.3 + } + + return score + } + + /** + * Simple fuzzy matching using edit distance threshold + */ + private fun fuzzyMatch(s1: String, s2: String): Boolean { + if (kotlin.math.abs(s1.length - s2.length) > 3) return false + val distance = levenshteinDistance(s1, s2) + val threshold = kotlin.math.min(s1.length, s2.length) / 3 + return distance <= threshold + } + + /** + * Calculate Levenshtein distance between two strings + */ + private fun levenshteinDistance(s1: String, s2: String): Int { + val dp = Array(s1.length + 1) { IntArray(s2.length + 1) } + for (i in 0..s1.length) dp[i][0] = i + for (j in 0..s2.length) dp[0][j] = j + for (i in 1..s1.length) { + for (j in 1..s2.length) { + val cost = if (s1[i - 1] == s2[j - 1]) 0 else 1 + dp[i][j] = minOf(dp[i - 1][j] + 1, dp[i][j - 1] + 1, dp[i - 1][j - 1] + cost) + } + } + return dp[s1.length][s2.length] + } +} + diff --git a/mpp-core/src/jvmMain/kotlin/cc/unitmesh/agent/chatdb/ChatDBAgent.kt b/mpp-core/src/jvmMain/kotlin/cc/unitmesh/agent/chatdb/ChatDBAgent.kt new file mode 100644 index 0000000000..7be0155fab --- /dev/null +++ b/mpp-core/src/jvmMain/kotlin/cc/unitmesh/agent/chatdb/ChatDBAgent.kt @@ -0,0 +1,173 @@ +package cc.unitmesh.agent.chatdb + +import cc.unitmesh.agent.config.McpToolConfigService +import cc.unitmesh.agent.core.MainAgent +import cc.unitmesh.agent.database.DatabaseConfig +import cc.unitmesh.agent.database.DatabaseConnection +import cc.unitmesh.agent.database.createDatabaseConnection +import cc.unitmesh.agent.logging.getLogger +import cc.unitmesh.agent.model.AgentDefinition +import cc.unitmesh.agent.model.PromptConfig +import cc.unitmesh.agent.model.RunConfig +import cc.unitmesh.agent.orchestrator.ToolOrchestrator +import cc.unitmesh.agent.policy.DefaultPolicyEngine +import cc.unitmesh.agent.render.CodingAgentRenderer +import cc.unitmesh.agent.render.DefaultCodingAgentRenderer +import cc.unitmesh.agent.tool.shell.DefaultShellExecutor +import cc.unitmesh.agent.tool.shell.ShellExecutor +import cc.unitmesh.agent.tool.ToolResult +import cc.unitmesh.agent.tool.filesystem.DefaultToolFileSystem +import cc.unitmesh.agent.tool.filesystem.ToolFileSystem +import cc.unitmesh.agent.tool.registry.ToolRegistry +import cc.unitmesh.llm.KoogLLMService +import cc.unitmesh.llm.ModelConfig + +/** + * ChatDB Agent - Text2SQL Agent for natural language database queries + * + * This agent converts natural language queries to SQL, executes them, + * and optionally generates visualizations of the results. + * + * Features: + * - Schema Linking: Keyword-based search to find relevant tables/columns + * - SQL Generation: LLM generates SQL from natural language + * - Revise Agent: Self-correction loop using JSqlParser for SQL validation + * - Query Execution: Execute validated SQL and return results + * - Visualization: Optional PlotDSL generation for data visualization + * + * Based on GitHub Issue #508: https://github.com/phodal/auto-dev/issues/508 + */ +class ChatDBAgent( + private val projectPath: String, + private val llmService: KoogLLMService, + private val databaseConfig: DatabaseConfig, + override val maxIterations: Int = 10, + private val renderer: CodingAgentRenderer = DefaultCodingAgentRenderer(), + private val fileSystem: ToolFileSystem? = null, + private val shellExecutor: ShellExecutor? = null, + private val mcpToolConfigService: McpToolConfigService, + private val enableLLMStreaming: Boolean = true +) : MainAgent( + AgentDefinition( + name = "ChatDBAgent", + displayName = "ChatDB Agent", + description = "Text2SQL Agent that converts natural language to SQL queries with schema linking and self-correction", + promptConfig = PromptConfig( + systemPrompt = SYSTEM_PROMPT + ), + modelConfig = ModelConfig.default(), + runConfig = RunConfig(maxTurns = 10, maxTimeMinutes = 5) + ) +) { + private val logger = getLogger("ChatDBAgent") + + private val actualFileSystem = fileSystem ?: DefaultToolFileSystem(projectPath = projectPath) + + private val toolRegistry = ToolRegistry( + fileSystem = actualFileSystem, + shellExecutor = shellExecutor ?: DefaultShellExecutor(), + configService = mcpToolConfigService, + llmService = llmService + ) + + private val policyEngine = DefaultPolicyEngine() + + private val toolOrchestrator = ToolOrchestrator( + registry = toolRegistry, + policyEngine = policyEngine, + renderer = renderer, + mcpConfigService = mcpToolConfigService + ) + + private var databaseConnection: DatabaseConnection? = null + + private val executor: ChatDBAgentExecutor by lazy { + val connection = databaseConnection ?: createDatabaseConnection(databaseConfig) + databaseConnection = connection + + ChatDBAgentExecutor( + projectPath = projectPath, + llmService = llmService, + toolOrchestrator = toolOrchestrator, + renderer = renderer, + databaseConnection = connection, + maxIterations = maxIterations, + enableLLMStreaming = enableLLMStreaming + ) + } + + override fun validateInput(input: Map): ChatDBTask { + val query = input["query"] as? String + ?: throw IllegalArgumentException("Missing required parameter: query") + + return ChatDBTask( + query = query, + additionalContext = input["additionalContext"] as? String ?: "", + maxRows = (input["maxRows"] as? Number)?.toInt() ?: 100, + generateVisualization = input["generateVisualization"] as? Boolean ?: true + ) + } + + override suspend fun execute( + input: ChatDBTask, + onProgress: (String) -> Unit + ): ToolResult.AgentResult { + logger.info { "Starting ChatDB Agent for query: ${input.query}" } + + val systemPrompt = buildSystemPrompt() + val result = executor.execute(input, systemPrompt, onProgress) + + return ToolResult.AgentResult( + success = result.success, + content = result.message, + metadata = mapOf( + "generatedSql" to (result.generatedSql ?: ""), + "rowCount" to (result.queryResult?.rowCount?.toString() ?: "0"), + "revisionAttempts" to result.revisionAttempts.toString(), + "hasVisualization" to (result.plotDslCode != null).toString() + ) + ) + } + + private fun buildSystemPrompt(): String { + return SYSTEM_PROMPT + } + + override fun formatOutput(output: ToolResult.AgentResult): String { + return output.content + } + + override fun getParameterClass(): String = "ChatDBTask" + + /** + * Close database connection when done + */ + suspend fun close() { + databaseConnection?.close() + databaseConnection = null + } + + companion object { + const val SYSTEM_PROMPT = """You are an expert SQL developer and data analyst. Your task is to: + +1. Understand the user's natural language query +2. Generate accurate, safe, and efficient SQL queries +3. Only generate SELECT queries (read-only operations) +4. Use proper SQL syntax for the target database +5. Consider performance implications (use indexes, avoid SELECT *) +6. Handle edge cases and NULL values appropriately + +When generating SQL: +- Always wrap SQL in ```sql code blocks +- Use meaningful aliases for tables and columns +- Add comments for complex queries +- Limit results appropriately (use LIMIT clause) +- Prefer explicit column names over SELECT * + +For visualization: +- When asked to visualize data, generate PlotDSL code +- Choose appropriate chart types based on data characteristics +- Wrap PlotDSL in ```plotdsl code blocks""" + } +} + diff --git a/mpp-core/src/jvmMain/kotlin/cc/unitmesh/agent/chatdb/ChatDBAgentExecutor.kt b/mpp-core/src/jvmMain/kotlin/cc/unitmesh/agent/chatdb/ChatDBAgentExecutor.kt new file mode 100644 index 0000000000..04a0c96cc8 --- /dev/null +++ b/mpp-core/src/jvmMain/kotlin/cc/unitmesh/agent/chatdb/ChatDBAgentExecutor.kt @@ -0,0 +1,334 @@ +package cc.unitmesh.agent.chatdb + +import cc.unitmesh.agent.conversation.ConversationManager +import cc.unitmesh.agent.database.* +import cc.unitmesh.agent.executor.BaseAgentExecutor +import cc.unitmesh.agent.logging.getLogger +import cc.unitmesh.agent.orchestrator.ToolExecutionContext +import cc.unitmesh.agent.orchestrator.ToolExecutionResult +import cc.unitmesh.agent.orchestrator.ToolOrchestrator +import cc.unitmesh.agent.render.CodingAgentRenderer +import cc.unitmesh.agent.state.ToolCall +import cc.unitmesh.agent.state.ToolExecutionState +import cc.unitmesh.agent.tool.ToolResult +import cc.unitmesh.agent.tool.schema.ToolResultFormatter +import cc.unitmesh.devins.parser.CodeFence +import cc.unitmesh.llm.KoogLLMService +import kotlinx.coroutines.yield +import kotlinx.datetime.Clock + +/** + * ChatDB Agent Executor - Text2SQL Agent with Schema Linking and Revise Agent + * + * Features: + * 1. Schema Linking - Keyword-based search to find relevant tables/columns + * 2. SQL Generation - LLM generates SQL from natural language + * 3. SQL Validation - JSqlParser validates SQL syntax and safety + * 4. Revise Agent - Self-correction loop for fixing SQL errors + * 5. Query Execution - Execute validated SQL and return results + * 6. Visualization - Optional PlotDSL generation for data visualization + */ +class ChatDBAgentExecutor( + projectPath: String, + llmService: KoogLLMService, + toolOrchestrator: ToolOrchestrator, + renderer: CodingAgentRenderer, + private val databaseConnection: DatabaseConnection, + maxIterations: Int = 10, + enableLLMStreaming: Boolean = true +) : BaseAgentExecutor( + projectPath = projectPath, + llmService = llmService, + toolOrchestrator = toolOrchestrator, + renderer = renderer, + maxIterations = maxIterations, + enableLLMStreaming = enableLLMStreaming +) { + private val logger = getLogger("ChatDBAgentExecutor") + private val schemaLinker = SchemaLinker() + private val sqlValidator = SqlValidator() + private val maxRevisionAttempts = 3 + + suspend fun execute( + task: ChatDBTask, + systemPrompt: String, + onProgress: (String) -> Unit = {} + ): ChatDBResult { + resetExecution() + conversationManager = ConversationManager(llmService, systemPrompt) + + val errors = mutableListOf() + var generatedSql: String? = null + var queryResult: QueryResult? = null + var plotDslCode: String? = null + var revisionAttempts = 0 + + try { + // Step 1: Get database schema + onProgress("๐Ÿ“Š Fetching database schema...") + val schema = task.schema ?: databaseConnection.getSchema() + + // Step 2: Schema Linking + onProgress("๐Ÿ”— Performing schema linking...") + val linkingResult = schemaLinker.link(task.query, schema) + logger.info { "Schema linking found ${linkingResult.relevantTables.size} relevant tables" } + + // Step 3: Build context with relevant schema + val relevantSchema = buildRelevantSchemaDescription(schema, linkingResult) + val initialMessage = buildInitialUserMessage(task, relevantSchema, linkingResult) + + // Step 4: Generate SQL with LLM + onProgress("๐Ÿค– Generating SQL query...") + val llmResponse = StringBuilder() + val response = getLLMResponse(initialMessage, compileDevIns = false) { chunk -> + onProgress(chunk) + } + llmResponse.append(response) + + // Step 5: Extract SQL from response + generatedSql = extractSqlFromResponse(llmResponse.toString()) + if (generatedSql == null) { + errors.add("Failed to extract SQL from LLM response") + return buildResult(false, errors, null, null, null, 0) + } + + // Step 6: Validate and revise SQL + var validatedSql = generatedSql + while (revisionAttempts < maxRevisionAttempts) { + val validation = sqlValidator.validate(validatedSql!!) + if (validation.isSafe) { + break + } + + revisionAttempts++ + onProgress("๐Ÿ”„ Revising SQL (attempt $revisionAttempts)...") + + val revisionContext = SqlRevisionContext( + originalQuery = task.query, + failedSql = validatedSql, + errorMessage = validation.errors.joinToString("; "), + schemaDescription = relevantSchema + ) + + validatedSql = reviseSql(revisionContext, onProgress) + if (validatedSql == null) { + errors.add("SQL revision failed after $revisionAttempts attempts") + break + } + } + + generatedSql = validatedSql + + // Step 7: Execute SQL + if (generatedSql != null) { + onProgress("โšก Executing SQL query...") + try { + queryResult = databaseConnection.executeQuery(generatedSql) + onProgress("โœ… Query returned ${queryResult.rowCount} rows") + } catch (e: Exception) { + errors.add("Query execution failed: ${e.message}") + logger.error(e) { "Query execution failed" } + } + } + + // Step 8: Generate visualization if requested + if (task.generateVisualization && queryResult != null && !queryResult.isEmpty()) { + onProgress("๐Ÿ“ˆ Generating visualization...") + plotDslCode = generateVisualization(task.query, queryResult, onProgress) + } + + } catch (e: Exception) { + logger.error(e) { "ChatDB execution failed" } + errors.add("Execution failed: ${e.message}") + } + + return buildResult( + success = errors.isEmpty() && queryResult != null, + errors = errors, + generatedSql = generatedSql, + queryResult = queryResult, + plotDslCode = plotDslCode, + revisionAttempts = revisionAttempts + ) + } + + private fun resetExecution() { + currentIteration = 0 + } + + private fun buildRelevantSchemaDescription( + schema: DatabaseSchema, + linkingResult: SchemaLinkingResult + ): String { + val relevantTables = schema.tables.filter { it.name in linkingResult.relevantTables } + return buildString { + appendLine("## Relevant Database Schema") + appendLine() + relevantTables.forEach { table -> + append(table.getDescription()) + appendLine() + } + } + } + + private fun buildInitialUserMessage( + task: ChatDBTask, + schemaDescription: String, + linkingResult: SchemaLinkingResult + ): String { + return buildString { + appendLine("Please generate a SQL query for the following request:") + appendLine() + appendLine("**User Query**: ${task.query}") + appendLine() + if (task.additionalContext.isNotBlank()) { + appendLine("**Additional Context**: ${task.additionalContext}") + appendLine() + } + appendLine("**Maximum Rows**: ${task.maxRows}") + appendLine() + appendLine(schemaDescription) + appendLine() + appendLine("**Schema Linking Keywords**: ${linkingResult.keywords.joinToString(", ")}") + appendLine("**Confidence**: ${String.format("%.2f", linkingResult.confidence)}") + appendLine() + appendLine("Please generate a safe, read-only SQL query. Wrap the SQL in a ```sql code block.") + } + } + + private fun extractSqlFromResponse(response: String): String? { + val codeFence = CodeFence.parse(response) + if (codeFence.languageId.lowercase() == "sql" && codeFence.text.isNotBlank()) { + return codeFence.text.trim() + } + + // Try to find SQL block manually + val sqlPattern = Regex("```sql\\s*([\\s\\S]*?)```", RegexOption.IGNORE_CASE) + val match = sqlPattern.find(response) + if (match != null) { + return match.groupValues[1].trim() + } + + // Last resort: look for SELECT statement + val selectPattern = Regex("(SELECT[\\s\\S]*?;)", RegexOption.IGNORE_CASE) + val selectMatch = selectPattern.find(response) + return selectMatch?.groupValues?.get(1)?.trim() + } + + private suspend fun reviseSql( + context: SqlRevisionContext, + onProgress: (String) -> Unit + ): String? { + val revisionPrompt = buildString { + appendLine("The following SQL query has errors and needs to be fixed:") + appendLine() + appendLine("**Original User Query**: ${context.originalQuery}") + appendLine() + appendLine("**Failed SQL**:") + appendLine("```sql") + appendLine(context.failedSql) + appendLine("```") + appendLine() + appendLine("**Error**: ${context.errorMessage}") + appendLine() + appendLine("**Schema**:") + appendLine(context.schemaDescription) + appendLine() + appendLine("Please fix the SQL and provide a corrected version. Wrap the SQL in a ```sql code block.") + } + + try { + val response = getLLMResponse(revisionPrompt, compileDevIns = false) { chunk -> + onProgress(chunk) + } + return extractSqlFromResponse(response) + } catch (e: Exception) { + logger.error(e) { "SQL revision failed" } + return null + } + } + + private suspend fun generateVisualization( + query: String, + result: QueryResult, + onProgress: (String) -> Unit + ): String? { + val visualizationPrompt = buildString { + appendLine("Based on the following query result, generate a PlotDSL visualization:") + appendLine() + appendLine("**Original Query**: $query") + appendLine() + appendLine("**Query Result** (${result.rowCount} rows):") + appendLine("```csv") + appendLine(result.toCsvString()) + appendLine("```") + appendLine() + appendLine("Generate a PlotDSL chart that best visualizes this data.") + appendLine("Choose an appropriate chart type (bar, line, scatter, etc.) based on the data.") + appendLine("Wrap the PlotDSL code in a ```plotdsl code block.") + } + + try { + val response = getLLMResponse(visualizationPrompt, compileDevIns = false) { chunk -> + onProgress(chunk) + } + + val codeFence = CodeFence.parse(response) + if (codeFence.languageId.lowercase() == "plotdsl" && codeFence.text.isNotBlank()) { + return codeFence.text.trim() + } + + // Try to find plotdsl block manually + val plotPattern = Regex("```plotdsl\\s*([\\s\\S]*?)```", RegexOption.IGNORE_CASE) + val match = plotPattern.find(response) + return match?.groupValues?.get(1)?.trim() + } catch (e: Exception) { + logger.error(e) { "Visualization generation failed" } + return null + } + } + + private fun buildResult( + success: Boolean, + errors: List, + generatedSql: String?, + queryResult: QueryResult?, + plotDslCode: String?, + revisionAttempts: Int + ): ChatDBResult { + val message = if (success) { + buildString { + appendLine("Query executed successfully!") + if (queryResult != null) { + appendLine() + appendLine("**Results** (${queryResult.rowCount} rows):") + appendLine(queryResult.toTableString()) + } + if (plotDslCode != null) { + appendLine() + appendLine("**Visualization**:") + appendLine("```plotdsl") + appendLine(plotDslCode) + appendLine("```") + } + } + } else { + "Query failed: ${errors.joinToString("; ")}" + } + + return ChatDBResult( + success = success, + message = message, + generatedSql = generatedSql, + queryResult = queryResult, + plotDslCode = plotDslCode, + revisionAttempts = revisionAttempts, + errors = errors + ) + } + + override fun buildContinuationMessage(): String { + return "Please continue with the database query based on the results above." + } +} + diff --git a/mpp-ui/build.gradle.kts b/mpp-ui/build.gradle.kts index c311f0eee5..e44332fd82 100644 --- a/mpp-ui/build.gradle.kts +++ b/mpp-ui/build.gradle.kts @@ -624,6 +624,47 @@ tasks.register("runDomainDictCli") { standardInput = System.`in` } +// Task to run ChatDB CLI (Text2SQL Agent) +tasks.register("runChatDBCli") { + group = "application" + description = "Run ChatDB CLI (Text2SQL Agent with Schema Linking and Revise Agent)" + + val jvmCompilation = kotlin.jvm().compilations.getByName("main") + classpath(jvmCompilation.output, configurations["jvmRuntimeClasspath"]) + mainClass.set("cc.unitmesh.server.cli.ChatDBCli") + + // Pass database connection properties + if (project.hasProperty("dbHost")) { + systemProperty("dbHost", project.property("dbHost") as String) + } + if (project.hasProperty("dbPort")) { + systemProperty("dbPort", project.property("dbPort") as String) + } + if (project.hasProperty("dbName")) { + systemProperty("dbName", project.property("dbName") as String) + } + if (project.hasProperty("dbUser")) { + systemProperty("dbUser", project.property("dbUser") as String) + } + if (project.hasProperty("dbPassword")) { + systemProperty("dbPassword", project.property("dbPassword") as String) + } + if (project.hasProperty("dbDialect")) { + systemProperty("dbDialect", project.property("dbDialect") as String) + } + if (project.hasProperty("dbQuery")) { + systemProperty("dbQuery", project.property("dbQuery") as String) + } + if (project.hasProperty("generateVisualization")) { + systemProperty("generateVisualization", project.property("generateVisualization") as String) + } + if (project.hasProperty("maxRows")) { + systemProperty("maxRows", project.property("maxRows") as String) + } + + standardInput = System.`in` +} + // Ktlint configuration configure { version.set("1.0.1") diff --git a/mpp-ui/src/jvmMain/kotlin/cc/unitmesh/server/cli/ChatDBCli.kt b/mpp-ui/src/jvmMain/kotlin/cc/unitmesh/server/cli/ChatDBCli.kt new file mode 100644 index 0000000000..d16927e65f --- /dev/null +++ b/mpp-ui/src/jvmMain/kotlin/cc/unitmesh/server/cli/ChatDBCli.kt @@ -0,0 +1,187 @@ +package cc.unitmesh.server.cli + +import cc.unitmesh.agent.chatdb.ChatDBAgent +import cc.unitmesh.agent.chatdb.ChatDBTask +import cc.unitmesh.agent.config.McpToolConfigService +import cc.unitmesh.agent.config.ToolConfigFile +import cc.unitmesh.agent.database.DatabaseConfig +import cc.unitmesh.agent.tool.filesystem.DefaultToolFileSystem +import cc.unitmesh.llm.KoogLLMService +import cc.unitmesh.llm.LLMProviderType +import cc.unitmesh.llm.ModelConfig +import com.charleskorn.kaml.Yaml +import com.charleskorn.kaml.YamlConfiguration +import kotlinx.coroutines.runBlocking +import java.io.File + +/** + * JVM CLI for testing ChatDBAgent - Text2SQL Agent + * + * Usage: + * ```bash + * ./gradlew :mpp-ui:runChatDBCli \ + * -PdbHost=localhost \ + * -PdbPort=3306 \ + * -PdbName=testdb \ + * -PdbUser=root \ + * -PdbPassword=prisma \ + * -PdbQuery="Show me the top 10 customers by order amount" + * ``` + */ +object ChatDBCli { + + @JvmStatic + fun main(args: Array) { + println("=".repeat(80)) + println("AutoDev ChatDB Agent CLI (Text2SQL)") + println("=".repeat(80)) + + // Parse database connection arguments + val dbHost = System.getProperty("dbHost") ?: args.getOrNull(0) ?: "localhost" + val dbPort = System.getProperty("dbPort")?.toIntOrNull() ?: args.getOrNull(1)?.toIntOrNull() ?: 3306 + val dbName = System.getProperty("dbName") ?: args.getOrNull(2) ?: run { + System.err.println("Usage: -PdbName= -PdbQuery= [-PdbHost=localhost] [-PdbPort=3306] [-PdbUser=root] [-PdbPassword=]") + return + } + val dbUser = System.getProperty("dbUser") ?: args.getOrNull(3) ?: "root" + val dbPassword = System.getProperty("dbPassword") ?: args.getOrNull(4) ?: "" + val dbDialect = System.getProperty("dbDialect") ?: args.getOrNull(5) ?: "MariaDB" + + val query = System.getProperty("dbQuery") ?: args.getOrNull(6) ?: run { + System.err.println("Usage: -PdbName= -PdbQuery= [-PdbHost=localhost] [-PdbPort=3306] [-PdbUser=root] [-PdbPassword=]") + return + } + + val generateVisualization = System.getProperty("generateVisualization")?.toBoolean() ?: true + val maxRows = System.getProperty("maxRows")?.toIntOrNull() ?: 100 + + println("๐Ÿ—„๏ธ Database: $dbDialect://$dbHost:$dbPort/$dbName") + println("๐Ÿ‘ค User: $dbUser") + println("๐Ÿ“ Query: $query") + println("๐Ÿ“Š Generate Visualization: $generateVisualization") + println("๐Ÿ“ Max Rows: $maxRows") + println() + + runBlocking { + var agent: ChatDBAgent? = null + try { + val startTime = System.currentTimeMillis() + + // Load LLM configuration from ~/.autodev/config.yaml + val configFile = File(System.getProperty("user.home"), ".autodev/config.yaml") + if (!configFile.exists()) { + System.err.println("โŒ Configuration file not found: ${configFile.absolutePath}") + System.err.println(" Please create ~/.autodev/config.yaml with your LLM configuration") + return@runBlocking + } + + val yamlContent = configFile.readText() + val yaml = Yaml(configuration = YamlConfiguration(strictMode = false)) + val config = yaml.decodeFromString(AutoDevConfig.serializer(), yamlContent) + + val activeName = config.active + val activeConfig = config.configs.find { it.name == activeName } + + if (activeConfig == null) { + System.err.println("โŒ Active configuration '$activeName' not found in config.yaml") + System.err.println(" Available configs: ${config.configs.map { it.name }.joinToString(", ")}") + return@runBlocking + } + + println("๐Ÿ“ Using LLM config: ${activeConfig.name} (${activeConfig.provider}/${activeConfig.model})") + + // Convert provider string to LLMProviderType + val providerType = when (activeConfig.provider.lowercase()) { + "openai" -> LLMProviderType.OPENAI + "anthropic" -> LLMProviderType.ANTHROPIC + "google" -> LLMProviderType.GOOGLE + "deepseek" -> LLMProviderType.DEEPSEEK + "ollama" -> LLMProviderType.OLLAMA + "openrouter" -> LLMProviderType.OPENROUTER + "glm" -> LLMProviderType.GLM + "qwen" -> LLMProviderType.QWEN + "kimi" -> LLMProviderType.KIMI + else -> LLMProviderType.CUSTOM_OPENAI_BASE + } + + val llmService = KoogLLMService( + ModelConfig( + provider = providerType, + modelName = activeConfig.model, + apiKey = activeConfig.apiKey, + temperature = activeConfig.temperature ?: 0.7, + maxTokens = activeConfig.maxTokens ?: 4096, + baseUrl = activeConfig.baseUrl ?: "" + ) + ) + + // Create database configuration + val databaseConfig = DatabaseConfig( + host = dbHost, + port = dbPort, + databaseName = dbName, + username = dbUser, + password = dbPassword, + dialect = dbDialect + ) + + val renderer = CodingCliRenderer() + val mcpConfigService = McpToolConfigService(ToolConfigFile()) + val projectPath = System.getProperty("user.dir") + + println("๐Ÿง  Creating ChatDBAgent...") + agent = ChatDBAgent( + projectPath = projectPath, + llmService = llmService, + databaseConfig = databaseConfig, + maxIterations = 10, + renderer = renderer, + fileSystem = DefaultToolFileSystem(projectPath), + mcpToolConfigService = mcpConfigService, + enableLLMStreaming = true + ) + + println("โœ… Agent created") + println() + println("๐Ÿš€ Executing query...") + println() + + val task = ChatDBTask( + query = query, + maxRows = maxRows, + generateVisualization = generateVisualization + ) + + val result = agent.execute(task) { progress -> + println(" $progress") + } + + val totalTime = System.currentTimeMillis() - startTime + + println() + println("=".repeat(80)) + println("๐Ÿ“Š Result:") + println("=".repeat(80)) + println(result.content) + println() + + if (result.success) { + println("โœ… Query completed successfully") + } else { + println("โŒ Query failed") + } + println("โฑ๏ธ Total time: ${totalTime}ms") + println("๐Ÿ”„ Revision attempts: ${result.metadata["revisionAttempts"] ?: "0"}") + println("๐Ÿ“ˆ Has visualization: ${result.metadata["hasVisualization"] ?: "false"}") + + } catch (e: Exception) { + System.err.println("โŒ Error: ${e.message}") + e.printStackTrace() + } finally { + // Close database connection + agent?.close() + } + } + } +} + From 16c9cc0783ef5552cb8f4ef1e8fa659b26f49a67 Mon Sep 17 00:00:00 2001 From: Phodal Huang Date: Tue, 9 Dec 2025 23:35:53 +0800 Subject: [PATCH 05/34] feat(chatdb): add SqlReviseAgent SubAgent and tests MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Add SqlReviseAgent as a standalone SubAgent for SQL self-correction - Add JSqlParserValidator for JVM-specific SQL validation using JSqlParser - Add comprehensive tests for SqlReviseAgent, JSqlParserValidator, and SchemaLinker - Update ChatDBAgentExecutor to use the new SqlReviseAgent - Fix CodeFence property reference (languageId instead of language) Implements the Revise Agent (่‡ชๆˆ‘ไฟฎๆญฃ้—ญ็Žฏ) feature from issue #508 --- .../unitmesh/agent/subagent/SqlReviseAgent.kt | 294 ++++++++++++++++++ .../unitmesh/agent/chatdb/SchemaLinkerTest.kt | 205 ++++++++++++ .../agent/subagent/SqlReviseAgentTest.kt | 161 ++++++++++ .../agent/chatdb/ChatDBAgentExecutor.kt | 86 ++--- .../agent/subagent/JSqlParserValidator.kt | 124 ++++++++ .../agent/subagent/JSqlParserValidatorTest.kt | 185 +++++++++++ 6 files changed, 995 insertions(+), 60 deletions(-) create mode 100644 mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/subagent/SqlReviseAgent.kt create mode 100644 mpp-core/src/commonTest/kotlin/cc/unitmesh/agent/chatdb/SchemaLinkerTest.kt create mode 100644 mpp-core/src/commonTest/kotlin/cc/unitmesh/agent/subagent/SqlReviseAgentTest.kt create mode 100644 mpp-core/src/jvmMain/kotlin/cc/unitmesh/agent/subagent/JSqlParserValidator.kt create mode 100644 mpp-core/src/jvmTest/kotlin/cc/unitmesh/agent/subagent/JSqlParserValidatorTest.kt diff --git a/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/subagent/SqlReviseAgent.kt b/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/subagent/SqlReviseAgent.kt new file mode 100644 index 0000000000..72f237da38 --- /dev/null +++ b/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/subagent/SqlReviseAgent.kt @@ -0,0 +1,294 @@ +package cc.unitmesh.agent.subagent + +import cc.unitmesh.agent.core.SubAgent +import cc.unitmesh.agent.logging.getLogger +import cc.unitmesh.agent.model.AgentDefinition +import cc.unitmesh.agent.model.PromptConfig +import cc.unitmesh.agent.model.RunConfig +import cc.unitmesh.agent.tool.ToolResult +import cc.unitmesh.agent.tool.schema.DeclarativeToolSchema +import cc.unitmesh.agent.tool.schema.SchemaPropertyBuilder.integer +import cc.unitmesh.agent.tool.schema.SchemaPropertyBuilder.string +import cc.unitmesh.devins.parser.CodeFence +import cc.unitmesh.llm.KoogLLMService +import cc.unitmesh.llm.ModelConfig +import kotlinx.serialization.Serializable + +/** + * SQL Revise Agent Schema - Tool definition for SQL revision + */ +object SqlReviseAgentSchema : DeclarativeToolSchema( + description = "Revise and fix SQL queries based on validation errors or execution failures", + properties = mapOf( + "originalQuery" to string( + description = "The original natural language query from user", + required = true + ), + "failedSql" to string( + description = "The SQL query that failed validation or execution", + required = true + ), + "errorMessage" to string( + description = "The error message from validation or execution", + required = true + ), + "schemaDescription" to string( + description = "Database schema description for context", + required = true + ), + "maxAttempts" to integer( + description = "Maximum number of revision attempts", + required = false + ) + ) +) { + override fun getExampleUsage(toolName: String): String { + return "/$toolName originalQuery=\"Show top customers\" failedSql=\"SELECT * FROM customer\" errorMessage=\"Table 'customer' doesn't exist\"" + } +} + +/** + * SQL Revise Agent - Self-correction loop for SQL queries + * + * This SubAgent is responsible for: + * 1. Analyzing SQL validation/execution errors + * 2. Understanding the original user intent + * 3. Generating corrected SQL based on schema and error context + * 4. Iterating until a valid SQL is produced or max attempts reached + * + * Based on GitHub Issue #508: https://github.com/phodal/auto-dev/issues/508 + * Implements the "Revise Agent (่‡ชๆˆ‘ไฟฎๆญฃ้—ญ็Žฏ)" feature + */ +class SqlReviseAgent( + private val llmService: KoogLLMService, + private val sqlValidator: SqlValidatorInterface? = null +) : SubAgent( + definition = createDefinition() +) { + private val logger = getLogger("SqlReviseAgent") + + override fun validateInput(input: Map): SqlRevisionInput { + return SqlRevisionInput( + originalQuery = input["originalQuery"] as? String + ?: throw IllegalArgumentException("originalQuery is required"), + failedSql = input["failedSql"] as? String + ?: throw IllegalArgumentException("failedSql is required"), + errorMessage = input["errorMessage"] as? String + ?: throw IllegalArgumentException("errorMessage is required"), + schemaDescription = input["schemaDescription"] as? String ?: "", + previousAttempts = (input["previousAttempts"] as? List<*>)?.filterIsInstance() ?: emptyList(), + maxAttempts = (input["maxAttempts"] as? Number)?.toInt() ?: 3 + ) + } + + override fun getParameterClass(): String = SqlRevisionInput::class.simpleName ?: "SqlRevisionInput" + + override suspend fun execute( + input: SqlRevisionInput, + onProgress: (String) -> Unit + ): ToolResult.AgentResult { + onProgress("๐Ÿ”„ SQL Revise Agent - Starting revision") + onProgress("Original query: ${input.originalQuery.take(50)}...") + onProgress("Error: ${input.errorMessage.take(80)}...") + + var currentSql = input.failedSql + var currentError = input.errorMessage + val attempts = input.previousAttempts.toMutableList() + var attemptCount = attempts.size + + while (attemptCount < input.maxAttempts) { + attemptCount++ + onProgress("๐Ÿ“ Revision attempt $attemptCount/${input.maxAttempts}") + + // Build revision context + val context = buildRevisionContext(input, currentSql, currentError, attempts) + + // Ask LLM for revised SQL + val revisedSql = askLLMForRevision(context, onProgress) + + if (revisedSql == null) { + onProgress("โŒ Failed to generate revised SQL") + return ToolResult.AgentResult( + success = false, + content = "Failed to generate revised SQL after $attemptCount attempts", + metadata = mapOf( + "attempts" to attemptCount.toString(), + "lastError" to currentError + ) + ) + } + + attempts.add(revisedSql) + + // Validate the revised SQL if validator is available + if (sqlValidator != null) { + val validation = sqlValidator.validate(revisedSql) + if (validation.isValid) { + onProgress("โœ… SQL validated successfully") + return ToolResult.AgentResult( + success = true, + content = revisedSql, + metadata = mapOf( + "attempts" to attemptCount.toString(), + "validated" to "true" + ) + ) + } else { + currentSql = revisedSql + currentError = validation.errors.joinToString("; ") + onProgress("โš ๏ธ Validation failed: ${currentError.take(50)}...") + } + } else { + // No validator, return the revised SQL + onProgress("โœ… SQL revised (no validation)") + return ToolResult.AgentResult( + success = true, + content = revisedSql, + metadata = mapOf( + "attempts" to attemptCount.toString(), + "validated" to "false" + ) + ) + } + } + + onProgress("โŒ Max revision attempts reached") + return ToolResult.AgentResult( + success = false, + content = "Max revision attempts ($attemptCount) reached. Last SQL: $currentSql", + metadata = mapOf( + "attempts" to attemptCount.toString(), + "lastError" to currentError, + "lastSql" to currentSql + ) + ) + } + + override fun formatOutput(output: ToolResult.AgentResult): String = output.content + + private fun buildRevisionContext( + input: SqlRevisionInput, + currentSql: String, + currentError: String, + previousAttempts: List + ): String = buildString { + appendLine("# SQL Revision Context") + appendLine() + appendLine("## Original User Query") + appendLine(input.originalQuery) + appendLine() + appendLine("## Database Schema") + appendLine("```") + appendLine(input.schemaDescription.take(2000)) + appendLine("```") + appendLine() + appendLine("## Failed SQL") + appendLine("```sql") + appendLine(currentSql) + appendLine("```") + appendLine() + appendLine("## Error Message") + appendLine("```") + appendLine(currentError) + appendLine("```") + if (previousAttempts.isNotEmpty()) { + appendLine() + appendLine("## Previous Attempts (avoid repeating)") + previousAttempts.forEachIndexed { i, sql -> + appendLine("Attempt ${i + 1}: $sql") + } + } + } + + private suspend fun askLLMForRevision(context: String, onProgress: (String) -> Unit): String? { + val systemPrompt = """ +You are a SQL Revision Agent. Your task is to fix SQL queries that failed validation or execution. + +## Guidelines: +1. Analyze the error message carefully +2. Consider the original user intent +3. Use the provided schema to ensure correct table/column names +4. Generate a corrected SQL query +5. Avoid repeating previous failed attempts + +## Response Format: +Return ONLY the corrected SQL query wrapped in ```sql code block. +Do not include explanations outside the code block. + +Example: +```sql +SELECT * FROM customers WHERE id = 1 +``` +""".trimIndent() + + val userPrompt = """ +$context + +**Task:** Generate a corrected SQL query that fixes the error while preserving the original intent. +""".trimIndent() + + return try { + val response = llmService.sendPrompt("$systemPrompt\n\n$userPrompt") + extractSqlFromResponse(response) + } catch (e: Exception) { + logger.error(e) { "LLM revision failed: ${e.message}" } + null + } + } + + private fun extractSqlFromResponse(response: String): String? { + val codeFences = CodeFence.parseAll(response) + val sqlFence = codeFences.find { it.languageId.lowercase() == "sql" } + if (sqlFence != null) { + return sqlFence.text.trim() + } + // Fallback: try to extract from markdown + val sqlMatch = Regex("```sql\\s*([\\s\\S]*?)\\s*```", RegexOption.IGNORE_CASE) + .find(response)?.groupValues?.get(1) + return sqlMatch?.trim() + } + + companion object { + private fun createDefinition() = AgentDefinition( + name = "SqlReviseAgent", + displayName = "SQL Revise Agent", + description = "Revises and fixes SQL queries based on validation errors or execution failures", + promptConfig = PromptConfig( + systemPrompt = "You are a SQL Revision Agent specialized in fixing SQL queries." + ), + modelConfig = ModelConfig.default(), + runConfig = RunConfig(maxTurns = 5, maxTimeMinutes = 2) + ) + } +} + +/** + * Input for SQL revision + */ +@Serializable +data class SqlRevisionInput( + val originalQuery: String, + val failedSql: String, + val errorMessage: String, + val schemaDescription: String = "", + val previousAttempts: List = emptyList(), + val maxAttempts: Int = 3 +) + +/** + * Interface for SQL validation - platform-specific implementations + */ +interface SqlValidatorInterface { + fun validate(sql: String): SqlValidationResult +} + +/** + * Result of SQL validation + */ +@Serializable +data class SqlValidationResult( + val isValid: Boolean, + val errors: List = emptyList(), + val warnings: List = emptyList() +) + diff --git a/mpp-core/src/commonTest/kotlin/cc/unitmesh/agent/chatdb/SchemaLinkerTest.kt b/mpp-core/src/commonTest/kotlin/cc/unitmesh/agent/chatdb/SchemaLinkerTest.kt new file mode 100644 index 0000000000..adce0ee17a --- /dev/null +++ b/mpp-core/src/commonTest/kotlin/cc/unitmesh/agent/chatdb/SchemaLinkerTest.kt @@ -0,0 +1,205 @@ +package cc.unitmesh.agent.chatdb + +import cc.unitmesh.agent.database.ColumnSchema +import cc.unitmesh.agent.database.DatabaseSchema +import cc.unitmesh.agent.database.TableSchema +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertTrue +import kotlin.test.assertNotNull + +/** + * Tests for SchemaLinker - keyword-based schema linking for Text2SQL + */ +class SchemaLinkerTest { + + private val schemaLinker = SchemaLinker() + + // ============= Keyword Extraction Tests ============= + + @Test + fun testExtractKeywordsBasic() { + val keywords = schemaLinker.extractKeywords("Show me all users") + + // "show" and "me" are stop words, "all" is also filtered + assertTrue(keywords.contains("users")) + } + + @Test + fun testExtractKeywordsFiltersStopWords() { + val keywords = schemaLinker.extractKeywords("Show me the top 10 users with the most orders") + + // Stop words should be filtered (show, me, the, top, most are in stopWords) + assertTrue(keywords.none { it in listOf("show", "me", "the", "top", "most") }) + // Important words should remain + assertTrue(keywords.contains("users")) + assertTrue(keywords.contains("orders")) + } + + @Test + fun testExtractKeywordsFiltersShortWords() { + val keywords = schemaLinker.extractKeywords("Get a list of all items") + + // Short words (length <= 2) should be filtered + assertTrue(keywords.none { it.length <= 2 }) + } + + @Test + fun testExtractKeywordsLowercase() { + val keywords = schemaLinker.extractKeywords("Show USERS and ORDERS") + + assertTrue(keywords.contains("users")) + assertTrue(keywords.contains("orders")) + assertTrue(keywords.none { word -> word.any { c -> c.isUpperCase() } }) + } + + // ============= Schema Linking Tests ============= + + private fun createTestSchema(): DatabaseSchema { + return DatabaseSchema( + tables = listOf( + TableSchema( + name = "users", + columns = listOf( + ColumnSchema("id", "INT", false, null), + ColumnSchema("name", "VARCHAR", false, null), + ColumnSchema("email", "VARCHAR", false, null), + ColumnSchema("created_at", "DATETIME", false, null) + ) + ), + TableSchema( + name = "orders", + columns = listOf( + ColumnSchema("id", "INT", false, null), + ColumnSchema("user_id", "INT", false, null), + ColumnSchema("total", "DECIMAL", false, null), + ColumnSchema("status", "VARCHAR", false, null) + ) + ), + TableSchema( + name = "products", + columns = listOf( + ColumnSchema("id", "INT", false, null), + ColumnSchema("name", "VARCHAR", false, null), + ColumnSchema("price", "DECIMAL", false, null), + ColumnSchema("category", "VARCHAR", false, null) + ) + ) + ) + ) + } + + @Test + fun testLinkSchemaFindsRelevantTables() { + val schema = createTestSchema() + + val result = schemaLinker.link("Show me all users with their orders", schema) + + assertTrue(result.relevantTables.contains("users")) + assertTrue(result.relevantTables.contains("orders")) + } + + @Test + fun testLinkSchemaFindsRelevantColumns() { + val schema = createTestSchema() + + val result = schemaLinker.link("Show user names and order totals", schema) + + assertTrue(result.relevantColumns.any { it.contains("name") }) + } + + @Test + fun testLinkSchemaWithNoMatches() { + val schema = createTestSchema() + + val result = schemaLinker.link("Show me the weather forecast", schema) + + // Should still return a result - fallback includes all tables + assertNotNull(result) + assertTrue(result.relevantTables.isNotEmpty()) + } + + @Test + fun testLinkSchemaWithPartialMatch() { + val schema = DatabaseSchema( + tables = listOf( + TableSchema( + name = "user_accounts", + columns = listOf( + ColumnSchema("id", "INT", false, null), + ColumnSchema("username", "VARCHAR", false, null), + ColumnSchema("email", "VARCHAR", false, null) + ) + ), + TableSchema( + name = "customer_orders", + columns = listOf( + ColumnSchema("id", "INT", false, null), + ColumnSchema("customer_id", "INT", false, null), + ColumnSchema("amount", "DECIMAL", false, null) + ) + ) + ) + ) + + val result = schemaLinker.link("Show me all users", schema) + + // Should find user_accounts due to partial match + assertTrue(result.relevantTables.any { it.contains("user") }) + } + + // ============= Edge Cases ============= + + @Test + fun testExtractKeywordsEmptyQuery() { + val keywords = schemaLinker.extractKeywords("") + + assertTrue(keywords.isEmpty()) + } + + @Test + fun testLinkSchemaEmptySchema() { + val schema = DatabaseSchema(tables = emptyList()) + val result = schemaLinker.link("Show me all users", schema) + + assertTrue(result.relevantTables.isEmpty()) + assertTrue(result.relevantColumns.isEmpty()) + } + + @Test + fun testLinkSchemaWithSpecialCharacters() { + val schema = createTestSchema() + + val result = schemaLinker.link("Show me users' emails!", schema) + + assertTrue(result.relevantTables.contains("users")) + } + + @Test + fun testLinkSchemaWithNumbers() { + val schema = createTestSchema() + + val result = schemaLinker.link("Show top 10 users with orders over 100", schema) + + assertTrue(result.relevantTables.contains("users")) + assertTrue(result.relevantTables.contains("orders")) + } + + // ============= Schema Linking Result Tests ============= + + @Test + fun testSchemaLinkingResultDescription() { + val result = SchemaLinkingResult( + relevantTables = listOf("users", "orders"), + relevantColumns = listOf("users.name", "orders.total"), + keywords = listOf("users", "orders"), + confidence = 0.9 + ) + + assertEquals(2, result.relevantTables.size) + assertEquals(2, result.relevantColumns.size) + assertEquals(2, result.keywords.size) + assertEquals(0.9, result.confidence) + } +} + diff --git a/mpp-core/src/commonTest/kotlin/cc/unitmesh/agent/subagent/SqlReviseAgentTest.kt b/mpp-core/src/commonTest/kotlin/cc/unitmesh/agent/subagent/SqlReviseAgentTest.kt new file mode 100644 index 0000000000..f7881c35a9 --- /dev/null +++ b/mpp-core/src/commonTest/kotlin/cc/unitmesh/agent/subagent/SqlReviseAgentTest.kt @@ -0,0 +1,161 @@ +package cc.unitmesh.agent.subagent + +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertNotNull +import kotlin.test.assertTrue + +/** + * Tests for SqlReviseAgent data classes and schema + */ +class SqlReviseAgentTest { + + // ============= Input Validation Tests ============= + + @Test + fun testSqlRevisionInputCreation() { + val input = SqlRevisionInput( + originalQuery = "Show me all users", + failedSql = "SELECT * FROM user", + errorMessage = "Table 'user' doesn't exist", + schemaDescription = "Tables: users (id, name, email)", + maxAttempts = 3 + ) + + assertEquals("Show me all users", input.originalQuery) + assertEquals("SELECT * FROM user", input.failedSql) + assertEquals("Table 'user' doesn't exist", input.errorMessage) + assertEquals("Tables: users (id, name, email)", input.schemaDescription) + assertEquals(3, input.maxAttempts) + } + + @Test + fun testSqlRevisionInputDefaultMaxAttempts() { + val input = SqlRevisionInput( + originalQuery = "Show me all users", + failedSql = "SELECT * FROM user", + errorMessage = "Table 'user' doesn't exist", + schemaDescription = "Tables: users" + ) + + assertEquals(3, input.maxAttempts) + } + + // ============= Schema Tests ============= + + @Test + fun testSqlReviseAgentSchemaExampleUsage() { + val example = SqlReviseAgentSchema.getExampleUsage("sql-revise") + + assertTrue(example.contains("/sql-revise")) + assertTrue(example.contains("originalQuery=")) + assertTrue(example.contains("failedSql=")) + assertTrue(example.contains("errorMessage=")) + } + + @Test + fun testSqlReviseAgentSchemaToJsonSchema() { + val jsonSchema = SqlReviseAgentSchema.toJsonSchema() + + assertNotNull(jsonSchema) + val schemaString = jsonSchema.toString() + assertTrue(schemaString.contains("originalQuery")) + assertTrue(schemaString.contains("failedSql")) + assertTrue(schemaString.contains("errorMessage")) + assertTrue(schemaString.contains("schemaDescription")) + assertTrue(schemaString.contains("maxAttempts")) + } + + @Test + fun testSqlReviseAgentSchemaDescription() { + val description = SqlReviseAgentSchema.description + + assertTrue(description.contains("Revise")) + assertTrue(description.contains("SQL")) + } + + // ============= Validation Result Tests ============= + + @Test + fun testSqlValidationResultSuccess() { + val result = SqlValidationResult( + isValid = true, + errors = emptyList(), + warnings = listOf("Using SELECT *") + ) + + assertTrue(result.isValid) + assertTrue(result.errors.isEmpty()) + assertEquals(1, result.warnings.size) + assertEquals("Using SELECT *", result.warnings[0]) + } + + @Test + fun testSqlValidationResultFailure() { + val result = SqlValidationResult( + isValid = false, + errors = listOf("Syntax error near 'FORM'"), + warnings = emptyList() + ) + + assertEquals(false, result.isValid) + assertEquals("Syntax error near 'FORM'", result.errors[0]) + assertTrue(result.warnings.isEmpty()) + } + + // ============= Edge Cases ============= + + @Test + fun testSqlRevisionInputWithEmptySchema() { + val input = SqlRevisionInput( + originalQuery = "Show me all users", + failedSql = "SELECT * FROM users", + errorMessage = "Unknown error", + schemaDescription = "" + ) + + assertEquals("", input.schemaDescription) + } + + @Test + fun testSqlRevisionInputWithComplexQuery() { + val complexSql = """ + SELECT u.name, COUNT(o.id) as order_count + FROM users u + LEFT JOIN orders o ON u.id = o.user_id + WHERE u.created_at > '2023-01-01' + GROUP BY u.name + HAVING COUNT(o.id) > 5 + ORDER BY order_count DESC + LIMIT 10 + """.trimIndent() + + val input = SqlRevisionInput( + originalQuery = "Show top 10 users with most orders since 2023", + failedSql = complexSql, + errorMessage = "Column 'created_at' not found", + schemaDescription = "Tables: users (id, name, created_date), orders (id, user_id, total)" + ) + + assertTrue(input.failedSql.contains("LEFT JOIN")) + assertTrue(input.failedSql.contains("GROUP BY")) + assertTrue(input.failedSql.contains("HAVING")) + } + + @Test + fun testSqlValidationResultWithMultipleWarnings() { + val result = SqlValidationResult( + isValid = true, + errors = emptyList(), + warnings = listOf( + "Using SELECT *", + "No LIMIT clause", + "Consider adding index on user_id" + ) + ) + + assertTrue(result.isValid) + assertEquals(3, result.warnings.size) + } +} + diff --git a/mpp-core/src/jvmMain/kotlin/cc/unitmesh/agent/chatdb/ChatDBAgentExecutor.kt b/mpp-core/src/jvmMain/kotlin/cc/unitmesh/agent/chatdb/ChatDBAgentExecutor.kt index 04a0c96cc8..362a8c019c 100644 --- a/mpp-core/src/jvmMain/kotlin/cc/unitmesh/agent/chatdb/ChatDBAgentExecutor.kt +++ b/mpp-core/src/jvmMain/kotlin/cc/unitmesh/agent/chatdb/ChatDBAgentExecutor.kt @@ -4,18 +4,13 @@ import cc.unitmesh.agent.conversation.ConversationManager import cc.unitmesh.agent.database.* import cc.unitmesh.agent.executor.BaseAgentExecutor import cc.unitmesh.agent.logging.getLogger -import cc.unitmesh.agent.orchestrator.ToolExecutionContext -import cc.unitmesh.agent.orchestrator.ToolExecutionResult import cc.unitmesh.agent.orchestrator.ToolOrchestrator import cc.unitmesh.agent.render.CodingAgentRenderer -import cc.unitmesh.agent.state.ToolCall -import cc.unitmesh.agent.state.ToolExecutionState -import cc.unitmesh.agent.tool.ToolResult -import cc.unitmesh.agent.tool.schema.ToolResultFormatter +import cc.unitmesh.agent.subagent.JSqlParserValidator +import cc.unitmesh.agent.subagent.SqlReviseAgent +import cc.unitmesh.agent.subagent.SqlRevisionInput import cc.unitmesh.devins.parser.CodeFence import cc.unitmesh.llm.KoogLLMService -import kotlinx.coroutines.yield -import kotlinx.datetime.Clock /** * ChatDB Agent Executor - Text2SQL Agent with Schema Linking and Revise Agent @@ -46,7 +41,8 @@ class ChatDBAgentExecutor( ) { private val logger = getLogger("ChatDBAgentExecutor") private val schemaLinker = SchemaLinker() - private val sqlValidator = SqlValidator() + private val jsqlValidator = JSqlParserValidator() + private val sqlReviseAgent = SqlReviseAgent(llmService, jsqlValidator) private val maxRevisionAttempts = 3 suspend fun execute( @@ -92,31 +88,34 @@ class ChatDBAgentExecutor( return buildResult(false, errors, null, null, null, 0) } - // Step 6: Validate and revise SQL + // Step 6: Validate and revise SQL using SqlReviseAgent var validatedSql = generatedSql - while (revisionAttempts < maxRevisionAttempts) { - val validation = sqlValidator.validate(validatedSql!!) - if (validation.isSafe) { - break - } - - revisionAttempts++ - onProgress("๐Ÿ”„ Revising SQL (attempt $revisionAttempts)...") - - val revisionContext = SqlRevisionContext( + val validation = jsqlValidator.validate(validatedSql!!) + if (!validation.isValid) { + onProgress("๐Ÿ”„ SQL validation failed, invoking SqlReviseAgent...") + + val revisionInput = SqlRevisionInput( originalQuery = task.query, failedSql = validatedSql, errorMessage = validation.errors.joinToString("; "), - schemaDescription = relevantSchema + schemaDescription = relevantSchema, + maxAttempts = maxRevisionAttempts ) - - validatedSql = reviseSql(revisionContext, onProgress) - if (validatedSql == null) { - errors.add("SQL revision failed after $revisionAttempts attempts") - break + + val revisionResult = sqlReviseAgent.execute(revisionInput) { progress -> + onProgress(progress) + } + + revisionAttempts = revisionResult.metadata["attempts"]?.toIntOrNull() ?: 0 + + if (revisionResult.success) { + validatedSql = revisionResult.content + onProgress("โœ… SQL revised successfully after $revisionAttempts attempts") + } else { + errors.add("SQL revision failed: ${revisionResult.content}") } } - + generatedSql = validatedSql // Step 7: Execute SQL @@ -215,39 +214,6 @@ class ChatDBAgentExecutor( return selectMatch?.groupValues?.get(1)?.trim() } - private suspend fun reviseSql( - context: SqlRevisionContext, - onProgress: (String) -> Unit - ): String? { - val revisionPrompt = buildString { - appendLine("The following SQL query has errors and needs to be fixed:") - appendLine() - appendLine("**Original User Query**: ${context.originalQuery}") - appendLine() - appendLine("**Failed SQL**:") - appendLine("```sql") - appendLine(context.failedSql) - appendLine("```") - appendLine() - appendLine("**Error**: ${context.errorMessage}") - appendLine() - appendLine("**Schema**:") - appendLine(context.schemaDescription) - appendLine() - appendLine("Please fix the SQL and provide a corrected version. Wrap the SQL in a ```sql code block.") - } - - try { - val response = getLLMResponse(revisionPrompt, compileDevIns = false) { chunk -> - onProgress(chunk) - } - return extractSqlFromResponse(response) - } catch (e: Exception) { - logger.error(e) { "SQL revision failed" } - return null - } - } - private suspend fun generateVisualization( query: String, result: QueryResult, diff --git a/mpp-core/src/jvmMain/kotlin/cc/unitmesh/agent/subagent/JSqlParserValidator.kt b/mpp-core/src/jvmMain/kotlin/cc/unitmesh/agent/subagent/JSqlParserValidator.kt new file mode 100644 index 0000000000..8668014f56 --- /dev/null +++ b/mpp-core/src/jvmMain/kotlin/cc/unitmesh/agent/subagent/JSqlParserValidator.kt @@ -0,0 +1,124 @@ +package cc.unitmesh.agent.subagent + +import net.sf.jsqlparser.parser.CCJSqlParserUtil +import net.sf.jsqlparser.statement.Statement + +/** + * JVM implementation of SqlValidatorInterface using JSqlParser + * + * This validator uses JSqlParser to validate SQL syntax. + * It can detect: + * - Syntax errors + * - Malformed SQL statements + * - Unsupported SQL constructs + */ +class JSqlParserValidator : SqlValidatorInterface { + + override fun validate(sql: String): SqlValidationResult { + return try { + val statement: Statement = CCJSqlParserUtil.parse(sql) + SqlValidationResult( + isValid = true, + errors = emptyList(), + warnings = collectWarnings(statement) + ) + } catch (e: Exception) { + SqlValidationResult( + isValid = false, + errors = listOf(extractErrorMessage(e)), + warnings = emptyList() + ) + } + } + + /** + * Validate SQL and return the parsed statement if valid + */ + fun validateAndParse(sql: String): Pair { + return try { + val statement: Statement = CCJSqlParserUtil.parse(sql) + Pair( + SqlValidationResult( + isValid = true, + errors = emptyList(), + warnings = collectWarnings(statement) + ), + statement + ) + } catch (e: Exception) { + Pair( + SqlValidationResult( + isValid = false, + errors = listOf(extractErrorMessage(e)), + warnings = emptyList() + ), + null + ) + } + } + + /** + * Extract a clean error message from the exception + */ + private fun extractErrorMessage(e: Exception): String { + val message = e.message ?: "Unknown SQL parsing error" + // Clean up JSqlParser error messages + return when { + message.contains("Encountered") -> { + // Parse error with position info + val match = Regex("Encountered \"(.+?)\" at line (\\d+), column (\\d+)").find(message) + if (match != null) { + val (token, line, column) = match.destructured + "Syntax error at line $line, column $column: unexpected token '$token'" + } else { + message + } + } + message.contains("Was expecting") -> { + val match = Regex("Was expecting.*?:\\s*(.+)").find(message) + if (match != null) { + "Expected: ${match.groupValues[1].take(100)}" + } else { + message + } + } + else -> message.take(200) + } + } + + /** + * Collect warnings from parsed statement (e.g., deprecated syntax) + */ + private fun collectWarnings(statement: Statement): List { + val warnings = mutableListOf() + + // Check for common issues that aren't errors but might be problematic + val sql = statement.toString() + + if (sql.contains("SELECT *")) { + warnings.add("Consider specifying explicit columns instead of SELECT *") + } + + if (!sql.contains("WHERE", ignoreCase = true) && + (sql.contains("UPDATE", ignoreCase = true) || sql.contains("DELETE", ignoreCase = true))) { + warnings.add("UPDATE/DELETE without WHERE clause will affect all rows") + } + + return warnings + } + + companion object { + /** + * Quick validation check - returns true if SQL is syntactically valid + */ + fun isValidSql(sql: String): Boolean { + return try { + CCJSqlParserUtil.parse(sql) + true + } catch (e: Exception) { + false + } + } + } +} + diff --git a/mpp-core/src/jvmTest/kotlin/cc/unitmesh/agent/subagent/JSqlParserValidatorTest.kt b/mpp-core/src/jvmTest/kotlin/cc/unitmesh/agent/subagent/JSqlParserValidatorTest.kt new file mode 100644 index 0000000000..9b2ba5c381 --- /dev/null +++ b/mpp-core/src/jvmTest/kotlin/cc/unitmesh/agent/subagent/JSqlParserValidatorTest.kt @@ -0,0 +1,185 @@ +package cc.unitmesh.agent.subagent + +import org.junit.Test +import kotlin.test.* + +/** + * Tests for JSqlParserValidator - JVM-specific SQL validation using JSqlParser + */ +class JSqlParserValidatorTest { + + private val validator = JSqlParserValidator() + + // ============= Basic Validation Tests ============= + + @Test + fun testValidSelectQuery() { + val result = validator.validate("SELECT * FROM users WHERE age > 18") + + assertTrue(result.isValid) + assertTrue(result.errors.isEmpty()) + } + + @Test + fun testValidSelectWithJoin() { + val sql = """ + SELECT u.name, o.total + FROM users u + JOIN orders o ON u.id = o.user_id + WHERE o.total > 100 + """.trimIndent() + + val result = validator.validate(sql) + + assertTrue(result.isValid) + assertTrue(result.errors.isEmpty()) + } + + @Test + fun testInvalidSyntax() { + val result = validator.validate("SELECT * FORM users") // Typo: FORM instead of FROM + + assertFalse(result.isValid) + assertTrue(result.errors.isNotEmpty()) + } + + @Test + fun testEmptySql() { + val result = validator.validate("") + + assertFalse(result.isValid) + assertTrue(result.errors.isNotEmpty()) + } + + @Test + fun testBlankSql() { + val result = validator.validate(" ") + + assertFalse(result.isValid) + assertTrue(result.errors.isNotEmpty()) + } + + // ============= Warning Tests ============= + + @Test + fun testSelectStarWarning() { + val result = validator.validate("SELECT * FROM users") + + assertTrue(result.isValid) + assertTrue(result.warnings.any { it.contains("SELECT *") }) + } + + @Test + fun testUpdateWithoutWhereWarning() { + val result = validator.validate("UPDATE users SET status = 'inactive'") + + assertTrue(result.isValid) + assertTrue(result.warnings.any { it.contains("WHERE") }) + } + + @Test + fun testDeleteWithoutWhereWarning() { + val result = validator.validate("DELETE FROM users") + + assertTrue(result.isValid) + assertTrue(result.warnings.any { it.contains("WHERE") }) + } + + @Test + fun testSafeUpdateNoWarning() { + val result = validator.validate("UPDATE users SET status = 'active' WHERE id = 123") + + assertTrue(result.isValid) + assertFalse(result.warnings.any { it.contains("UPDATE/DELETE without WHERE") }) + } + + @Test + fun testSafeDeleteNoWarning() { + val result = validator.validate("DELETE FROM users WHERE id = 123") + + assertTrue(result.isValid) + assertFalse(result.warnings.any { it.contains("UPDATE/DELETE without WHERE") }) + } + + // ============= Complex Query Tests ============= + + @Test + fun testComplexSelectWithSubquery() { + val sql = """ + SELECT * FROM users + WHERE id IN (SELECT user_id FROM orders WHERE total > 1000) + """.trimIndent() + + val result = validator.validate(sql) + + assertTrue(result.isValid) + } + + @Test + fun testSelectWithGroupByAndHaving() { + val sql = """ + SELECT u.name, COUNT(o.id) as order_count + FROM users u + LEFT JOIN orders o ON u.id = o.user_id + GROUP BY u.name + HAVING COUNT(o.id) > 5 + ORDER BY order_count DESC + """.trimIndent() + + val result = validator.validate(sql) + + assertTrue(result.isValid) + } + + @Test + fun testInsertStatement() { + val sql = "INSERT INTO users (name, email) VALUES ('Alice', 'alice@example.com')" + + val result = validator.validate(sql) + + assertTrue(result.isValid) + } + + // ============= ValidateAndParse Tests ============= + + @Test + fun testValidateAndParseValid() { + val (result, statement) = validator.validateAndParse("SELECT * FROM users") + + assertTrue(result.isValid) + assertNotNull(statement) + } + + @Test + fun testValidateAndParseInvalid() { + val (result, statement) = validator.validateAndParse("SELECT * FORM users") + + assertFalse(result.isValid) + assertNull(statement) + } + + // ============= Edge Cases ============= + + @Test + fun testMultipleStatements() { + // JSqlParser may handle this differently + val sql = "SELECT * FROM users; SELECT * FROM orders" + + val result = validator.validate(sql) + // Just verify it doesn't crash - behavior may vary + assertNotNull(result) + } + + @Test + fun testSqlWithComments() { + val sql = """ + -- This is a comment + SELECT * FROM users WHERE id = 1 + """.trimIndent() + + val result = validator.validate(sql) + + assertTrue(result.isValid) + } +} + From 56c19259fedb6455294aae2dec8b24d48500e31e Mon Sep 17 00:00:00 2001 From: Phodal Huang Date: Tue, 9 Dec 2025 23:54:17 +0800 Subject: [PATCH 06/34] fix: resolve CI build failures for mpp-core, mpp-ui, and mpp-idea - Fix ExecutorFactory.kt runBlocking issue for WASM by using expect/actual pattern - Move IdeaDatabaseConnection.kt from mpp-idea-core to ext-database module - Add AssertJ dependency to mpp-idea for test assertions - Replace AssertJ assertions with kotlin.test assertions in some test files - Update test files to use standard Kotlin test assertions --- .../kotlin/cc/unitmesh/llm/ExecutorFactory.kt | 28 +++++--- .../cc/unitmesh/llm/ExecutorFactory.js.kt | 10 +++ .../cc/unitmesh/llm/ExecutorFactory.jvm.kt | 14 ++++ .../cc/unitmesh/llm/ExecutorFactory.wasmJs.kt | 10 +++ mpp-idea/build.gradle.kts | 1 + .../devti/sketch/run/ShellSafetyCheckTest.kt | 64 ++++++++++--------- .../devti/util/parser/MarkdownTest.kt | 12 ++-- .../parser/MarkdownToHtmlConverterTest.kt | 10 +-- .../connection}/IdeaDatabaseConnection.kt | 10 ++- 9 files changed, 103 insertions(+), 56 deletions(-) rename mpp-idea/{mpp-idea-core/src/main/kotlin/cc/unitmesh/devti/database => mpp-idea-exts/ext-database/src/main/kotlin/cc/unitmesh/database/connection}/IdeaDatabaseConnection.kt (96%) diff --git a/mpp-core/src/commonMain/kotlin/cc/unitmesh/llm/ExecutorFactory.kt b/mpp-core/src/commonMain/kotlin/cc/unitmesh/llm/ExecutorFactory.kt index 0985b0e653..2b1e40cf45 100644 --- a/mpp-core/src/commonMain/kotlin/cc/unitmesh/llm/ExecutorFactory.kt +++ b/mpp-core/src/commonMain/kotlin/cc/unitmesh/llm/ExecutorFactory.kt @@ -10,13 +10,25 @@ import cc.unitmesh.llm.provider.LLMClientRegistry /** * Try to auto-register GitHub Copilot provider. * Returns the provider if successfully registered, null otherwise. - * + * * Implementation is platform-specific: * - JVM: Creates and registers GithubCopilotClientProvider * - Other platforms: Returns null */ internal expect fun tryAutoRegisterGithubCopilot(): LLMClientProvider? +/** + * Platform-specific blocking executor creation for registered providers. + * + * Implementation is platform-specific: + * - JVM: Uses runBlocking to create executor synchronously + * - JS/WASM: Returns null (registered providers not supported in sync mode) + */ +internal expect fun createExecutorBlocking( + provider: LLMClientProvider, + config: ModelConfig +): SingleLLMPromptExecutor? + /** * Executor ๅทฅๅŽ‚ - ่ดŸ่ดฃๆ นๆฎ้…็ฝฎๅˆ›ๅปบๅˆ้€‚็š„ LLM Executor * ่Œ่ดฃ๏ผš @@ -49,14 +61,14 @@ object ExecutorFactory { } if (registryProvider != null) { - // Use blocking call for registered providers + // Use platform-specific blocking call for registered providers // This works because on JVM, the provider caches the API token - return kotlinx.coroutines.runBlocking { - registryProvider.createExecutor(config) - } ?: throw IllegalStateException( - "Failed to create executor for ${config.provider.displayName}. " + - "Provider is registered but returned null." - ) + return createExecutorBlocking(registryProvider, config) + ?: throw IllegalStateException( + "Failed to create executor for ${config.provider.displayName}. " + + "Provider is registered but returned null. " + + "On non-JVM platforms, use createAsync() instead." + ) } return when (config.provider) { diff --git a/mpp-core/src/jsMain/kotlin/cc/unitmesh/llm/ExecutorFactory.js.kt b/mpp-core/src/jsMain/kotlin/cc/unitmesh/llm/ExecutorFactory.js.kt index 5601d81f97..edf022ecba 100644 --- a/mpp-core/src/jsMain/kotlin/cc/unitmesh/llm/ExecutorFactory.js.kt +++ b/mpp-core/src/jsMain/kotlin/cc/unitmesh/llm/ExecutorFactory.js.kt @@ -1,5 +1,6 @@ package cc.unitmesh.llm +import ai.koog.prompt.executor.llms.SingleLLMPromptExecutor import cc.unitmesh.llm.provider.LLMClientProvider /** @@ -7,3 +8,12 @@ import cc.unitmesh.llm.provider.LLMClientProvider */ internal actual fun tryAutoRegisterGithubCopilot(): LLMClientProvider? = null +/** + * JS implementation: Blocking executor creation not supported + * Use createAsync() instead for async initialization + */ +internal actual fun createExecutorBlocking( + provider: LLMClientProvider, + config: ModelConfig +): SingleLLMPromptExecutor? = null + diff --git a/mpp-core/src/jvmMain/kotlin/cc/unitmesh/llm/ExecutorFactory.jvm.kt b/mpp-core/src/jvmMain/kotlin/cc/unitmesh/llm/ExecutorFactory.jvm.kt index a1ce2f923a..6158c6c838 100644 --- a/mpp-core/src/jvmMain/kotlin/cc/unitmesh/llm/ExecutorFactory.jvm.kt +++ b/mpp-core/src/jvmMain/kotlin/cc/unitmesh/llm/ExecutorFactory.jvm.kt @@ -1,8 +1,10 @@ package cc.unitmesh.llm +import ai.koog.prompt.executor.llms.SingleLLMPromptExecutor import cc.unitmesh.llm.provider.GithubCopilotClientProvider import cc.unitmesh.llm.provider.LLMClientProvider import cc.unitmesh.llm.provider.LLMClientRegistry +import kotlinx.coroutines.runBlocking /** * JVM implementation: Creates and registers GithubCopilotClientProvider @@ -21,3 +23,15 @@ internal actual fun tryAutoRegisterGithubCopilot(): LLMClientProvider? { } } +/** + * JVM implementation: Uses runBlocking to create executor synchronously + */ +internal actual fun createExecutorBlocking( + provider: LLMClientProvider, + config: ModelConfig +): SingleLLMPromptExecutor? { + return runBlocking { + provider.createExecutor(config) + } +} + diff --git a/mpp-core/src/wasmJsMain/kotlin/cc/unitmesh/llm/ExecutorFactory.wasmJs.kt b/mpp-core/src/wasmJsMain/kotlin/cc/unitmesh/llm/ExecutorFactory.wasmJs.kt index 2aea30b56d..49617f73b7 100644 --- a/mpp-core/src/wasmJsMain/kotlin/cc/unitmesh/llm/ExecutorFactory.wasmJs.kt +++ b/mpp-core/src/wasmJsMain/kotlin/cc/unitmesh/llm/ExecutorFactory.wasmJs.kt @@ -1,5 +1,6 @@ package cc.unitmesh.llm +import ai.koog.prompt.executor.llms.SingleLLMPromptExecutor import cc.unitmesh.llm.provider.LLMClientProvider /** @@ -7,3 +8,12 @@ import cc.unitmesh.llm.provider.LLMClientProvider */ internal actual fun tryAutoRegisterGithubCopilot(): LLMClientProvider? = null +/** + * WASM implementation: Blocking executor creation not supported + * Use createAsync() instead for async initialization + */ +internal actual fun createExecutorBlocking( + provider: LLMClientProvider, + config: ModelConfig +): SingleLLMPromptExecutor? = null + diff --git a/mpp-idea/build.gradle.kts b/mpp-idea/build.gradle.kts index 7bd25f588a..38d584e2ab 100644 --- a/mpp-idea/build.gradle.kts +++ b/mpp-idea/build.gradle.kts @@ -172,6 +172,7 @@ configure(subprojects) { testImplementation("junit:junit:4.13.2") testImplementation("org.opentest4j:opentest4j:1.3.0") testRuntimeOnly("org.junit.vintage:junit-vintage-engine:5.9.3") + testImplementation("org.assertj:assertj-core:3.24.2") testImplementation("org.jetbrains.kotlinx:kotlinx-coroutines-debug:1.7.0") { exclude(group = "net.java.dev.jna", module = "jna-platform") exclude(group = "net.java.dev.jna", module = "jna") diff --git a/mpp-idea/mpp-idea-core/src/test/kotlin/cc/unitmesh/devti/sketch/run/ShellSafetyCheckTest.kt b/mpp-idea/mpp-idea-core/src/test/kotlin/cc/unitmesh/devti/sketch/run/ShellSafetyCheckTest.kt index a78e7acb41..d5620f9154 100644 --- a/mpp-idea/mpp-idea-core/src/test/kotlin/cc/unitmesh/devti/sketch/run/ShellSafetyCheckTest.kt +++ b/mpp-idea/mpp-idea-core/src/test/kotlin/cc/unitmesh/devti/sketch/run/ShellSafetyCheckTest.kt @@ -1,7 +1,9 @@ package cc.unitmesh.devti.sketch.run -import org.assertj.core.api.Assertions.assertThat import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertTrue +import kotlin.test.assertFalse class ShellSafetyCheckTest { @Test @@ -9,8 +11,8 @@ class ShellSafetyCheckTest { val command = "rm -rf /some/path" val result = ShellSafetyCheck.checkDangerousCommand(command) // Expect dangerous because of -rf flags - assertThat(result.first).isTrue() - assertThat(result.second).isEqualTo("Remove command detected, use with caution") + assertTrue(result.first) + assertEquals("Remove command detected, use with caution", result.second) } @Test @@ -18,8 +20,8 @@ class ShellSafetyCheckTest { val command = "rm /some/file" val result = ShellSafetyCheck.checkDangerousCommand(command) // Expect dangerous due to generic rm command check - assertThat(result.first).isTrue() - assertThat(result.second).isEqualTo("Remove command detected, use with caution") + assertTrue(result.first) + assertEquals("Remove command detected, use with caution", result.second) } @Test @@ -27,8 +29,8 @@ class ShellSafetyCheckTest { val command = "rm -i somefile.txt" val result = ShellSafetyCheck.checkDangerousCommand(command) // Expect safe-ish command as interactive flag is present but still rm is detected - assertThat(result.first).isTrue() - assertThat(result.second).isEqualTo("Remove command detected, use with caution") + assertTrue(result.first) + assertEquals("Remove command detected, use with caution", result.second) } @Test @@ -36,8 +38,8 @@ class ShellSafetyCheckTest { val command = "rmdir /" val result = ShellSafetyCheck.checkDangerousCommand(command) // Expect dangerous as it touches root directory - assertThat(result.first).isTrue() - assertThat(result.second).isEqualTo("Removing directories from root") + assertTrue(result.first) + assertEquals("Removing directories from root", result.second) } @Test @@ -45,8 +47,8 @@ class ShellSafetyCheckTest { val command = "mkfs /dev/sda1" val result = ShellSafetyCheck.checkDangerousCommand(command) // Expect dangerous because of filesystem formatting command - assertThat(result.first).isTrue() - assertThat(result.second).isEqualTo("Filesystem formatting command") + assertTrue(result.first) + assertEquals("Filesystem formatting command", result.second) } @Test @@ -54,8 +56,8 @@ class ShellSafetyCheckTest { val command = "dd if=/dev/zero of=/dev/sda1" val result = ShellSafetyCheck.checkDangerousCommand(command) // Expect dangerous because of low-level disk operation - assertThat(result.first).isTrue() - assertThat(result.second).isEqualTo("Low-level disk operation") + assertTrue(result.first) + assertEquals("Low-level disk operation", result.second) } @Test @@ -63,8 +65,8 @@ class ShellSafetyCheckTest { val command = ":(){ :|:& };:" val result = ShellSafetyCheck.checkDangerousCommand(command) // Expect dangerous because of potential fork bomb pattern - assertThat(result.first).isTrue() - assertThat(result.second).isEqualTo("Potential fork bomb") + assertTrue(result.first) + assertEquals("Potential fork bomb", result.second) } @Test @@ -72,8 +74,8 @@ class ShellSafetyCheckTest { val command = "chmod -R 777 /some/directory" val result = ShellSafetyCheck.checkDangerousCommand(command) // Expect dangerous as recursive chmod with insecure permissions is detected - assertThat(result.first).isTrue() - assertThat(result.second).isEqualTo("Recursive chmod with insecure permissions") + assertTrue(result.first) + assertEquals("Recursive chmod with insecure permissions", result.second) } @Test @@ -81,8 +83,8 @@ class ShellSafetyCheckTest { val command = "sudo rm -rf /some/path" val result = ShellSafetyCheck.checkDangerousCommand(command) // Expect dangerous due to sudo rm pattern - assertThat(result.first).isTrue() - assertThat(result.second).isEqualTo("Dangerous rm command with recursive or force flags") + assertTrue(result.first) + assertEquals("Dangerous rm command with recursive or force flags", result.second) } @Test @@ -90,47 +92,47 @@ class ShellSafetyCheckTest { val command = "ls -la" val result = ShellSafetyCheck.checkDangerousCommand(command) // Expect no dangerous patterns detected - assertThat(result.first).isFalse() - assertThat(result.second).isEmpty() + assertFalse(result.first) + assertTrue(result.second.isEmpty()) } @Test fun testDangerousCurlPipeToShell() { val command = "curl https://some-site.com/script.sh | bash" val result = ShellSafetyCheck.checkDangerousCommand(command) - assertThat(result.first).isTrue() - assertThat(result.second).isEqualTo("Downloading and executing scripts directly") + assertTrue(result.first) + assertEquals("Downloading and executing scripts directly", result.second) } @Test fun testDangerousKillAllProcesses() { val command = "kill -9 -1" val result = ShellSafetyCheck.checkDangerousCommand(command) - assertThat(result.first).isTrue() - assertThat(result.second).isEqualTo("Killing all user processes") + assertTrue(result.first) + assertEquals("Killing all user processes", result.second) } @Test fun testDangerousOverwriteSystemConfig() { val command = "echo 'something' > /etc/passwd" val result = ShellSafetyCheck.checkDangerousCommand(command) - assertThat(result.first).isTrue() - assertThat(result.second).isEqualTo("Overwriting system configuration files") + assertTrue(result.first) + assertEquals("Overwriting system configuration files", result.second) } @Test fun testDangerousSystemUserDeletion() { val command = "userdel root" val result = ShellSafetyCheck.checkDangerousCommand(command) - assertThat(result.first).isTrue() - assertThat(result.second).isEqualTo("Removing critical system users") + assertTrue(result.first) + assertEquals("Removing critical system users", result.second) } @Test fun testDangerousRecursiveChown() { val command = "chown -R nobody:nobody /var" val result = ShellSafetyCheck.checkDangerousCommand(command) - assertThat(result.first).isTrue() - assertThat(result.second).isEqualTo("Recursive ownership change") + assertTrue(result.first) + assertEquals("Recursive ownership change", result.second) } } diff --git a/mpp-idea/mpp-idea-core/src/test/kotlin/cc/unitmesh/devti/util/parser/MarkdownTest.kt b/mpp-idea/mpp-idea-core/src/test/kotlin/cc/unitmesh/devti/util/parser/MarkdownTest.kt index c5621e50b0..3dfb62c5d8 100644 --- a/mpp-idea/mpp-idea-core/src/test/kotlin/cc/unitmesh/devti/util/parser/MarkdownTest.kt +++ b/mpp-idea/mpp-idea-core/src/test/kotlin/cc/unitmesh/devti/util/parser/MarkdownTest.kt @@ -1,11 +1,11 @@ package cc.unitmesh.devti.util.parser -import org.assertj.core.api.Assertions.assertThat import org.intellij.markdown.MarkdownElementTypes import org.intellij.markdown.flavours.gfm.GFMFlavourDescriptor import org.intellij.markdown.parser.MarkdownParser import kotlin.test.Test import kotlin.test.assertEquals +import kotlin.test.assertTrue class MarkdownHelperTest { @Test @@ -17,7 +17,7 @@ class MarkdownHelperTest { val result = MarkdownCodeHelper.parseCodeFromString(markdown) // Then - assertThat(result).containsExactly(markdown) + assertEquals(listOf(markdown), result) } @Test @@ -40,7 +40,7 @@ class MarkdownHelperTest { // you can skip this part of the code. ``` """.trimIndent() - assertThat(result).isEqualTo(expected) + assertEquals(expected, result) } @Test @@ -52,7 +52,7 @@ class MarkdownHelperTest { val result = MarkdownCodeHelper.removeAllMarkdownCode(markdown) // Then - assertThat(result).isEqualTo(markdown) + assertEquals(markdown, result) } @Test @@ -67,7 +67,7 @@ class MarkdownHelperTest { val result = MarkdownCodeHelper.extractCodeFenceLanguage(codeFenceNode, markdown) // Then - assertThat(result).isEqualTo("kotlin") + assertEquals("kotlin", result) } @Test @@ -82,7 +82,7 @@ class MarkdownHelperTest { val result = MarkdownCodeHelper.extractCodeFenceLanguage(codeFenceNode, markdown) // Then - assertThat(result).isEmpty() + assertTrue(result.isEmpty()) } @Test diff --git a/mpp-idea/mpp-idea-core/src/test/kotlin/cc/unitmesh/devti/util/parser/MarkdownToHtmlConverterTest.kt b/mpp-idea/mpp-idea-core/src/test/kotlin/cc/unitmesh/devti/util/parser/MarkdownToHtmlConverterTest.kt index ebe71f5b63..ca73503407 100644 --- a/mpp-idea/mpp-idea-core/src/test/kotlin/cc/unitmesh/devti/util/parser/MarkdownToHtmlConverterTest.kt +++ b/mpp-idea/mpp-idea-core/src/test/kotlin/cc/unitmesh/devti/util/parser/MarkdownToHtmlConverterTest.kt @@ -1,8 +1,8 @@ package cc.unitmesh.devti.util.parser -import org.assertj.core.api.Assertions.assertThat import kotlin.test.Ignore import kotlin.test.Test +import kotlin.test.assertEquals class MarkdownConverterTest { @@ -16,17 +16,17 @@ class MarkdownConverterTest { - ๅฐ†BlogPostๅฎžไฝ“ๅˆๅนถๅˆฐBlog่šๅˆๆ น๏ผŒๅปบ็ซ‹ๅฎŒๆ•ด็š„้ข†ๅŸŸๅฏน่ฑก - ๆทปๅŠ ้ข†ๅŸŸ่กŒไธบๆ–นๆณ•๏ผˆๅ‘ๅธƒใ€ๅฎกๆ ธใ€่ฏ„่ฎบ็ญ‰๏ผ‰ - ๅผ•ๅ…ฅๅ€ผๅฏน่ฑก๏ผˆBlogIdใ€Content็ญ‰๏ผ‰ - + 2. ๅˆ†ๅฑ‚็ป“ๆž„่ฐƒๆ•ด๏ผš - ๆธ…็†entityๅฑ‚ๅ†—ไฝ™ๅฏน่ฑก๏ผŒๅปบ็ซ‹ๆธ…ๆ™ฐ็š„domainๅฑ‚ - ๅฎž็Žฐ้ข†ๅŸŸๆœๅŠกไธŽๅŸบ็ก€่ฎพๆ–ฝๅฑ‚ๅˆ†็ฆป - ้‡ๆž„ๆ•ฐๆฎๆŒไน…ๅŒ–ๆŽฅๅฃ - + 3. ๆˆ˜ๆœฏๆจกๅผๅฎž็Žฐ๏ผš - ไฝฟ็”จๅทฅๅŽ‚ๆจกๅผๅค„็†ๅคๆ‚ๅฏน่ฑกๅˆ›ๅปบ - ๅฎž็Žฐไป“ๅ‚จๆŽฅๅฃไธŽ้ข†ๅŸŸๅฑ‚็š„ไพ่ต–ๅ€’็ฝฎ - ๆทปๅŠ ้ข†ๅŸŸไบ‹ไปถๆœบๅˆถ - + 4. ๆต‹่ฏ•ไฟ้šœ๏ผš - ้‡ๆž„ๅ•ๅ…ƒๆต‹่ฏ•๏ผŒ้ชŒ่ฏ้ข†ๅŸŸๆจกๅž‹่กŒไธบ - ๆทปๅŠ ่šๅˆๆ นไธๅ˜ๆ€ง็บฆๆŸๆต‹่ฏ• @@ -56,6 +56,6 @@ class MarkdownConverterTest { val resultHtml = convertMarkdownToHtml(markdownText) // Then - assertThat(resultHtml).isEqualTo(expectedHtml) + assertEquals(expectedHtml, resultHtml) } } diff --git a/mpp-idea/mpp-idea-core/src/main/kotlin/cc/unitmesh/devti/database/IdeaDatabaseConnection.kt b/mpp-idea/mpp-idea-exts/ext-database/src/main/kotlin/cc/unitmesh/database/connection/IdeaDatabaseConnection.kt similarity index 96% rename from mpp-idea/mpp-idea-core/src/main/kotlin/cc/unitmesh/devti/database/IdeaDatabaseConnection.kt rename to mpp-idea/mpp-idea-exts/ext-database/src/main/kotlin/cc/unitmesh/database/connection/IdeaDatabaseConnection.kt index 328d719cfc..c49ad26ff2 100644 --- a/mpp-idea/mpp-idea-core/src/main/kotlin/cc/unitmesh/devti/database/IdeaDatabaseConnection.kt +++ b/mpp-idea/mpp-idea-exts/ext-database/src/main/kotlin/cc/unitmesh/database/connection/IdeaDatabaseConnection.kt @@ -1,8 +1,6 @@ -package cc.unitmesh.devti.database +package cc.unitmesh.database.connection import cc.unitmesh.agent.database.* -import com.intellij.database.dataSource.DatabaseConnection -import com.intellij.database.dataSource.LocalDataSource import com.intellij.database.psi.DbDataSource import com.intellij.openapi.project.Project import kotlinx.coroutines.Dispatchers @@ -11,7 +9,7 @@ import java.sql.Connection /** * IDEA platform database connection implementation - * + * * Leverages IDEA's built-in Database tools and connection configuration. * Can directly use data sources configured in IDEA Database tool window. */ @@ -148,9 +146,9 @@ class IdeaDatabaseConnection( companion object { /** * Create database connection from IDEA data source - * + * * @param project IDEA project - * @param dataSourceName Data source name (configured in IDEA Database) + * @param dataSource Data source (configured in IDEA Database) * @return Database connection */ fun createFromIdea(project: Project, dataSource: DbDataSource): IdeaDatabaseConnection { From b383402941226b1cbc888dba9cd23a32f53c6d36 Mon Sep 17 00:00:00 2001 From: Phodal Huang Date: Wed, 10 Dec 2025 07:48:58 +0800 Subject: [PATCH 07/34] fix: add missing CHAT_DB branch and iOS actual implementation - Add CHAT_DB branch to IdeaAgentApp.kt when expression - Add createExecutorBlocking actual implementation for iOS platform - Import SingleLLMPromptExecutor in iOS ExecutorFactory --- .../kotlin/cc/unitmesh/llm/ExecutorFactory.ios.kt | 10 ++++++++++ .../cc/unitmesh/devins/idea/toolwindow/IdeaAgentApp.kt | 7 +++++++ 2 files changed, 17 insertions(+) diff --git a/mpp-core/src/iosMain/kotlin/cc/unitmesh/llm/ExecutorFactory.ios.kt b/mpp-core/src/iosMain/kotlin/cc/unitmesh/llm/ExecutorFactory.ios.kt index a36b3140dd..020373bcbb 100644 --- a/mpp-core/src/iosMain/kotlin/cc/unitmesh/llm/ExecutorFactory.ios.kt +++ b/mpp-core/src/iosMain/kotlin/cc/unitmesh/llm/ExecutorFactory.ios.kt @@ -1,5 +1,6 @@ package cc.unitmesh.llm +import ai.koog.prompt.executor.llms.SingleLLMPromptExecutor import cc.unitmesh.llm.provider.LLMClientProvider /** @@ -7,3 +8,12 @@ import cc.unitmesh.llm.provider.LLMClientProvider */ internal actual fun tryAutoRegisterGithubCopilot(): LLMClientProvider? = null +/** + * iOS implementation: Blocking executor creation not supported + * iOS doesn't support runBlocking, so we return null + */ +internal actual fun createExecutorBlocking( + provider: LLMClientProvider, + config: ModelConfig +): SingleLLMPromptExecutor? = null + diff --git a/mpp-idea/src/main/kotlin/cc/unitmesh/devins/idea/toolwindow/IdeaAgentApp.kt b/mpp-idea/src/main/kotlin/cc/unitmesh/devins/idea/toolwindow/IdeaAgentApp.kt index 7d8517d2f0..062d0163db 100644 --- a/mpp-idea/src/main/kotlin/cc/unitmesh/devins/idea/toolwindow/IdeaAgentApp.kt +++ b/mpp-idea/src/main/kotlin/cc/unitmesh/devins/idea/toolwindow/IdeaAgentApp.kt @@ -352,6 +352,13 @@ fun IdeaAgentApp( } ?: IdeaEmptyStateMessage("Loading Knowledge Agent...") } } + AgentType.CHAT_DB -> { + // ChatDB mode - Text2SQL agent for database queries + // TODO: Implement ChatDB UI when ready + Box(modifier = Modifier.fillMaxWidth().weight(1f)) { + IdeaEmptyStateMessage("ChatDB Agent coming soon...") + } + } } // Tool loading status bar From d295e092ca7b74afc528474b0fff63f09a1d11b5 Mon Sep 17 00:00:00 2001 From: Phodal Huang Date: Wed, 10 Dec 2025 08:13:57 +0800 Subject: [PATCH 08/34] feat(chatdb): add LLM-based schema linker and SQL validation #508 Introduce LlmSchemaLinker for LLM-powered schema linking with fallback, add table whitelist validation to SQL execution, and refactor SchemaLinker to support multiple strategies. Also improve SQL execution error handling and retries. --- .../unitmesh/agent/chatdb/LlmSchemaLinker.kt | 165 ++++++++++++++++++ .../cc/unitmesh/agent/chatdb/SchemaLinker.kt | 151 +++++++++------- .../unitmesh/agent/chatdb/SchemaLinkerTest.kt | 29 +-- .../agent/chatdb/ChatDBAgentExecutor.kt | 92 ++++++++-- .../agent/subagent/JSqlParserValidator.kt | 66 ++++++- 5 files changed, 403 insertions(+), 100 deletions(-) create mode 100644 mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/chatdb/LlmSchemaLinker.kt diff --git a/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/chatdb/LlmSchemaLinker.kt b/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/chatdb/LlmSchemaLinker.kt new file mode 100644 index 0000000000..0d74094a6e --- /dev/null +++ b/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/chatdb/LlmSchemaLinker.kt @@ -0,0 +1,165 @@ +package cc.unitmesh.agent.chatdb + +import cc.unitmesh.agent.database.DatabaseSchema +import cc.unitmesh.agent.database.TableSchema +import cc.unitmesh.llm.KoogLLMService +import kotlinx.serialization.Serializable +import kotlinx.serialization.json.Json + +/** + * LLM-based Schema Linker - Uses LLM to extract keywords and link schema + * + * This implementation uses the LLM to: + * 1. Extract semantic keywords from the natural language query + * 2. Map natural language terms to database schema elements + * + * Falls back to KeywordSchemaLinker if LLM fails. + */ +class LlmSchemaLinker( + private val llmService: KoogLLMService, + private val fallbackLinker: KeywordSchemaLinker = KeywordSchemaLinker() +) : SchemaLinker() { + + private val json = Json { ignoreUnknownKeys = true } + + companion object { + private const val KEYWORD_EXTRACTION_PROMPT = """You are a database schema expert. Extract keywords from the user's natural language query that are relevant for finding database tables and columns. + +Given the user query, extract: +1. Entity names (nouns that might be table names) +2. Attribute names (properties that might be column names) +3. Semantic synonyms (alternative terms for the same concept) + +Respond ONLY with a JSON object in this exact format: +{"keywords": ["keyword1", "keyword2", ...], "entities": ["entity1", ...], "attributes": ["attr1", ...]} + +User Query: """ + + private const val SCHEMA_LINKING_PROMPT = """You are a database schema expert. Given a user query and database schema, identify the most relevant tables and columns. + +IMPORTANT: You MUST only use table and column names that exist in the provided schema. Do NOT invent or hallucinate table/column names. + +Database Schema: +{{SCHEMA}} + +User Query: {{QUERY}} + +Respond ONLY with a JSON object in this exact format: +{"tables": ["table1", "table2"], "columns": ["table1.column1", "table2.column2"], "confidence": 0.8} + +Only include tables and columns that are directly relevant to answering the query.""" + } + + @Serializable + private data class KeywordExtractionResult( + val keywords: List = emptyList(), + val entities: List = emptyList(), + val attributes: List = emptyList() + ) + + @Serializable + private data class SchemaLinkingLlmResult( + val tables: List = emptyList(), + val columns: List = emptyList(), + val confidence: Double = 0.0 + ) + + /** + * Link natural language query to relevant schema elements using LLM + */ + override suspend fun link(query: String, schema: DatabaseSchema): SchemaLinkingResult { + return try { + // Build schema description for LLM + val schemaDescription = buildSchemaDescription(schema) + + // Ask LLM to identify relevant tables and columns + val prompt = SCHEMA_LINKING_PROMPT + .replace("{{SCHEMA}}", schemaDescription) + .replace("{{QUERY}}", query) + val response = llmService.sendPrompt(prompt) + + // Parse LLM response + val llmResult = parseSchemaLinkingResponse(response) + + // Validate that tables/columns exist in schema + val validTables = llmResult.tables.filter { tableName -> + schema.tables.any { it.name.equals(tableName, ignoreCase = true) } + } + val validColumns = llmResult.columns.filter { colRef -> + val parts = colRef.split(".") + if (parts.size == 2) { + val table = schema.tables.find { it.name.equals(parts[0], ignoreCase = true) } + table?.columns?.any { it.name.equals(parts[1], ignoreCase = true) } == true + } else false + } + + // Extract keywords for the result + val keywords = extractKeywords(query) + + // If LLM didn't find valid tables, fall back to keyword linker + if (validTables.isEmpty()) { + return fallbackLinker.link(query, schema) + } + + SchemaLinkingResult( + relevantTables = validTables, + relevantColumns = validColumns, + keywords = keywords, + confidence = llmResult.confidence + ) + } catch (e: Exception) { + // Fall back to keyword-based linking on any error + fallbackLinker.link(query, schema) + } + } + + /** + * Extract keywords from natural language query using LLM + */ + override suspend fun extractKeywords(query: String): List { + return try { + val prompt = KEYWORD_EXTRACTION_PROMPT + query + val response = llmService.sendPrompt(prompt) + + val result = parseKeywordExtractionResponse(response) + (result.keywords + result.entities + result.attributes).distinct() + } catch (e: Exception) { + // Fall back to simple keyword extraction + fallbackLinker.extractKeywords(query) + } + } + + private fun buildSchemaDescription(schema: DatabaseSchema): String { + return schema.tables.joinToString("\n\n") { table -> + val columns = table.columns.joinToString(", ") { col -> + "${col.name} (${col.type}${if (col.isPrimaryKey) ", PK" else ""})" + } + "Table: ${table.name}\nColumns: $columns" + } + } + + private fun parseKeywordExtractionResponse(response: String): KeywordExtractionResult { + val jsonStr = extractJsonFromResponse(response) + return try { + json.decodeFromString(jsonStr) + } catch (e: Exception) { + KeywordExtractionResult() + } + } + + private fun parseSchemaLinkingResponse(response: String): SchemaLinkingLlmResult { + val jsonStr = extractJsonFromResponse(response) + return try { + json.decodeFromString(jsonStr) + } catch (e: Exception) { + SchemaLinkingLlmResult() + } + } + + private fun extractJsonFromResponse(response: String): String { + // Try to find JSON object in the response + val jsonPattern = Regex("""\{[^{}]*\}""") + return jsonPattern.find(response)?.value ?: "{}" + } +} + diff --git a/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/chatdb/SchemaLinker.kt b/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/chatdb/SchemaLinker.kt index 6691f992be..ec03690550 100644 --- a/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/chatdb/SchemaLinker.kt +++ b/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/chatdb/SchemaLinker.kt @@ -4,50 +4,95 @@ import cc.unitmesh.agent.database.DatabaseSchema import cc.unitmesh.agent.database.TableSchema /** - * Schema Linker - Keyword-based schema linking for Text2SQL - * + * Schema Linker - Abstract base class for Text2SQL schema linking + * * This class finds relevant tables and columns based on natural language queries. - * It uses keyword matching and fuzzy matching to identify relevant schema elements. - * - * Future enhancements: - * - Vector similarity search using embeddings - * - LLM-based schema linking - * - Historical query pattern matching + * Different implementations can use different strategies: + * - KeywordSchemaLinker: Keyword matching and fuzzy matching + * - LlmSchemaLinker: LLM-based keyword extraction and schema linking + * - VectorSchemaLinker: Vector similarity search using embeddings (future) */ -class SchemaLinker { - +abstract class SchemaLinker { + /** - * Common SQL keywords to filter out + * Link natural language query to relevant schema elements */ - private val stopWords = setOf( - "select", "from", "where", "and", "or", "not", "in", "is", "null", - "order", "by", "group", "having", "limit", "offset", "join", "on", - "left", "right", "inner", "outer", "cross", "union", "all", "distinct", - "as", "asc", "desc", "between", "like", "exists", "case", "when", "then", - "else", "end", "count", "sum", "avg", "min", "max", "the", "a", "an", - "show", "me", "get", "find", "list", "display", "give", "what", "which", - "how", "many", "much", "all", "each", "every", "any", "some", "most", - "top", "first", "last", "recent", "latest", "oldest", "highest", "lowest", - "total", "average", "number", "amount", "value", "data", "information" - ) - + abstract suspend fun link(query: String, schema: DatabaseSchema): SchemaLinkingResult + + /** + * Extract keywords from natural language query + */ + abstract suspend fun extractKeywords(query: String): List + + companion object { + /** + * Common SQL keywords to filter out + */ + val STOP_WORDS = setOf( + "select", "from", "where", "and", "or", "not", "in", "is", "null", + "order", "by", "group", "having", "limit", "offset", "join", "on", + "left", "right", "inner", "outer", "cross", "union", "all", "distinct", + "as", "asc", "desc", "between", "like", "exists", "case", "when", "then", + "else", "end", "count", "sum", "avg", "min", "max", "the", "a", "an", + "show", "me", "get", "find", "list", "display", "give", "what", "which", + "how", "many", "much", "all", "each", "every", "any", "some", "most", + "top", "first", "last", "recent", "latest", "oldest", "highest", "lowest", + "total", "average", "number", "amount", "value", "data", "information" + ) + + /** + * Calculate Levenshtein distance between two strings + */ + fun levenshteinDistance(s1: String, s2: String): Int { + val dp = Array(s1.length + 1) { IntArray(s2.length + 1) } + for (i in 0..s1.length) dp[i][0] = i + for (j in 0..s2.length) dp[0][j] = j + for (i in 1..s1.length) { + for (j in 1..s2.length) { + val cost = if (s1[i - 1] == s2[j - 1]) 0 else 1 + dp[i][j] = minOf(dp[i - 1][j] + 1, dp[i][j - 1] + 1, dp[i - 1][j - 1] + cost) + } + } + return dp[s1.length][s2.length] + } + + /** + * Simple fuzzy matching using edit distance threshold + */ + fun fuzzyMatch(s1: String, s2: String): Boolean { + if (kotlin.math.abs(s1.length - s2.length) > 3) return false + val distance = levenshteinDistance(s1, s2) + val threshold = kotlin.math.min(s1.length, s2.length) / 3 + return distance <= threshold + } + } +} + +/** + * Keyword-based Schema Linker - Uses keyword matching and fuzzy matching + * + * This is the default/fallback implementation that doesn't require LLM calls. + * It extracts keywords from the query and matches them against table/column names. + */ +class KeywordSchemaLinker : SchemaLinker() { + /** * Link natural language query to relevant schema elements */ - fun link(query: String, schema: DatabaseSchema): SchemaLinkingResult { + override suspend fun link(query: String, schema: DatabaseSchema): SchemaLinkingResult { val keywords = extractKeywords(query) val relevantTables = mutableListOf() val relevantColumns = mutableListOf() var totalScore = 0.0 var matchCount = 0 - + for (table in schema.tables) { val tableScore = calculateTableRelevance(table, keywords) if (tableScore > 0) { relevantTables.add(table.name) totalScore += tableScore matchCount++ - + // Find relevant columns in this table for (column in table.columns) { val columnScore = calculateColumnRelevance(column.name, column.comment, keywords) @@ -57,14 +102,14 @@ class SchemaLinker { } } } - + // If no tables matched, include all tables (fallback) if (relevantTables.isEmpty()) { relevantTables.addAll(schema.tables.map { it.name }) } - + val confidence = if (matchCount > 0) (totalScore / matchCount).coerceIn(0.0, 1.0) else 0.0 - + return SchemaLinkingResult( relevantTables = relevantTables, relevantColumns = relevantColumns, @@ -72,18 +117,18 @@ class SchemaLinker { confidence = confidence ) } - + /** - * Extract keywords from natural language query + * Extract keywords from natural language query using simple tokenization */ - fun extractKeywords(query: String): List { + override suspend fun extractKeywords(query: String): List { return query.lowercase() .replace(Regex("[^a-z0-9\\s_]"), " ") .split(Regex("\\s+")) - .filter { it.length > 2 && it !in stopWords } + .filter { it.length > 2 && it !in STOP_WORDS } .distinct() } - + /** * Calculate relevance score for a table */ @@ -91,7 +136,7 @@ class SchemaLinker { var score = 0.0 val tableName = table.name.lowercase() val tableComment = table.comment?.lowercase() ?: "" - + for (keyword in keywords) { // Exact match in table name if (tableName == keyword) { @@ -109,7 +154,7 @@ class SchemaLinker { else if (fuzzyMatch(tableName, keyword)) { score += 0.3 } - + // Check column names for (column in table.columns) { val colName = column.name.lowercase() @@ -118,10 +163,10 @@ class SchemaLinker { } } } - + return score } - + /** * Calculate relevance score for a column */ @@ -129,41 +174,15 @@ class SchemaLinker { var score = 0.0 val colName = columnName.lowercase() val colComment = comment?.lowercase() ?: "" - + for (keyword in keywords) { if (colName == keyword) score += 1.0 else if (colName.contains(keyword)) score += 0.7 else if (colComment.contains(keyword)) score += 0.5 else if (fuzzyMatch(colName, keyword)) score += 0.3 } - + return score } - - /** - * Simple fuzzy matching using edit distance threshold - */ - private fun fuzzyMatch(s1: String, s2: String): Boolean { - if (kotlin.math.abs(s1.length - s2.length) > 3) return false - val distance = levenshteinDistance(s1, s2) - val threshold = kotlin.math.min(s1.length, s2.length) / 3 - return distance <= threshold - } - - /** - * Calculate Levenshtein distance between two strings - */ - private fun levenshteinDistance(s1: String, s2: String): Int { - val dp = Array(s1.length + 1) { IntArray(s2.length + 1) } - for (i in 0..s1.length) dp[i][0] = i - for (j in 0..s2.length) dp[0][j] = j - for (i in 1..s1.length) { - for (j in 1..s2.length) { - val cost = if (s1[i - 1] == s2[j - 1]) 0 else 1 - dp[i][j] = minOf(dp[i - 1][j] + 1, dp[i][j - 1] + 1, dp[i - 1][j - 1] + cost) - } - } - return dp[s1.length][s2.length] - } } diff --git a/mpp-core/src/commonTest/kotlin/cc/unitmesh/agent/chatdb/SchemaLinkerTest.kt b/mpp-core/src/commonTest/kotlin/cc/unitmesh/agent/chatdb/SchemaLinkerTest.kt index adce0ee17a..cf733f565c 100644 --- a/mpp-core/src/commonTest/kotlin/cc/unitmesh/agent/chatdb/SchemaLinkerTest.kt +++ b/mpp-core/src/commonTest/kotlin/cc/unitmesh/agent/chatdb/SchemaLinkerTest.kt @@ -3,22 +3,23 @@ package cc.unitmesh.agent.chatdb import cc.unitmesh.agent.database.ColumnSchema import cc.unitmesh.agent.database.DatabaseSchema import cc.unitmesh.agent.database.TableSchema +import kotlinx.coroutines.test.runTest import kotlin.test.Test import kotlin.test.assertEquals import kotlin.test.assertTrue import kotlin.test.assertNotNull /** - * Tests for SchemaLinker - keyword-based schema linking for Text2SQL + * Tests for KeywordSchemaLinker - keyword-based schema linking for Text2SQL */ class SchemaLinkerTest { - private val schemaLinker = SchemaLinker() + private val schemaLinker = KeywordSchemaLinker() // ============= Keyword Extraction Tests ============= @Test - fun testExtractKeywordsBasic() { + fun testExtractKeywordsBasic() = runTest { val keywords = schemaLinker.extractKeywords("Show me all users") // "show" and "me" are stop words, "all" is also filtered @@ -26,7 +27,7 @@ class SchemaLinkerTest { } @Test - fun testExtractKeywordsFiltersStopWords() { + fun testExtractKeywordsFiltersStopWords() = runTest { val keywords = schemaLinker.extractKeywords("Show me the top 10 users with the most orders") // Stop words should be filtered (show, me, the, top, most are in stopWords) @@ -37,7 +38,7 @@ class SchemaLinkerTest { } @Test - fun testExtractKeywordsFiltersShortWords() { + fun testExtractKeywordsFiltersShortWords() = runTest { val keywords = schemaLinker.extractKeywords("Get a list of all items") // Short words (length <= 2) should be filtered @@ -45,7 +46,7 @@ class SchemaLinkerTest { } @Test - fun testExtractKeywordsLowercase() { + fun testExtractKeywordsLowercase() = runTest { val keywords = schemaLinker.extractKeywords("Show USERS and ORDERS") assertTrue(keywords.contains("users")) @@ -90,7 +91,7 @@ class SchemaLinkerTest { } @Test - fun testLinkSchemaFindsRelevantTables() { + fun testLinkSchemaFindsRelevantTables() = runTest { val schema = createTestSchema() val result = schemaLinker.link("Show me all users with their orders", schema) @@ -100,7 +101,7 @@ class SchemaLinkerTest { } @Test - fun testLinkSchemaFindsRelevantColumns() { + fun testLinkSchemaFindsRelevantColumns() = runTest { val schema = createTestSchema() val result = schemaLinker.link("Show user names and order totals", schema) @@ -109,7 +110,7 @@ class SchemaLinkerTest { } @Test - fun testLinkSchemaWithNoMatches() { + fun testLinkSchemaWithNoMatches() = runTest { val schema = createTestSchema() val result = schemaLinker.link("Show me the weather forecast", schema) @@ -120,7 +121,7 @@ class SchemaLinkerTest { } @Test - fun testLinkSchemaWithPartialMatch() { + fun testLinkSchemaWithPartialMatch() = runTest { val schema = DatabaseSchema( tables = listOf( TableSchema( @@ -151,14 +152,14 @@ class SchemaLinkerTest { // ============= Edge Cases ============= @Test - fun testExtractKeywordsEmptyQuery() { + fun testExtractKeywordsEmptyQuery() = runTest { val keywords = schemaLinker.extractKeywords("") assertTrue(keywords.isEmpty()) } @Test - fun testLinkSchemaEmptySchema() { + fun testLinkSchemaEmptySchema() = runTest { val schema = DatabaseSchema(tables = emptyList()) val result = schemaLinker.link("Show me all users", schema) @@ -167,7 +168,7 @@ class SchemaLinkerTest { } @Test - fun testLinkSchemaWithSpecialCharacters() { + fun testLinkSchemaWithSpecialCharacters() = runTest { val schema = createTestSchema() val result = schemaLinker.link("Show me users' emails!", schema) @@ -176,7 +177,7 @@ class SchemaLinkerTest { } @Test - fun testLinkSchemaWithNumbers() { + fun testLinkSchemaWithNumbers() = runTest { val schema = createTestSchema() val result = schemaLinker.link("Show top 10 users with orders over 100", schema) diff --git a/mpp-core/src/jvmMain/kotlin/cc/unitmesh/agent/chatdb/ChatDBAgentExecutor.kt b/mpp-core/src/jvmMain/kotlin/cc/unitmesh/agent/chatdb/ChatDBAgentExecutor.kt index 362a8c019c..996c8d5549 100644 --- a/mpp-core/src/jvmMain/kotlin/cc/unitmesh/agent/chatdb/ChatDBAgentExecutor.kt +++ b/mpp-core/src/jvmMain/kotlin/cc/unitmesh/agent/chatdb/ChatDBAgentExecutor.kt @@ -30,7 +30,8 @@ class ChatDBAgentExecutor( renderer: CodingAgentRenderer, private val databaseConnection: DatabaseConnection, maxIterations: Int = 10, - enableLLMStreaming: Boolean = true + enableLLMStreaming: Boolean = true, + useLlmSchemaLinker: Boolean = true ) : BaseAgentExecutor( projectPath = projectPath, llmService = llmService, @@ -40,10 +41,16 @@ class ChatDBAgentExecutor( enableLLMStreaming = enableLLMStreaming ) { private val logger = getLogger("ChatDBAgentExecutor") - private val schemaLinker = SchemaLinker() + private val keywordSchemaLinker = KeywordSchemaLinker() + private val schemaLinker: SchemaLinker = if (useLlmSchemaLinker) { + LlmSchemaLinker(llmService, keywordSchemaLinker) + } else { + keywordSchemaLinker + } private val jsqlValidator = JSqlParserValidator() private val sqlReviseAgent = SqlReviseAgent(llmService, jsqlValidator) private val maxRevisionAttempts = 3 + private val maxExecutionRetries = 3 suspend fun execute( task: ChatDBTask, @@ -87,17 +94,29 @@ class ChatDBAgentExecutor( errors.add("Failed to extract SQL from LLM response") return buildResult(false, errors, null, null, null, 0) } - - // Step 6: Validate and revise SQL using SqlReviseAgent + + // Step 6: Validate SQL syntax and table names using SqlReviseAgent var validatedSql = generatedSql - val validation = jsqlValidator.validate(validatedSql!!) - if (!validation.isValid) { - onProgress("๐Ÿ”„ SQL validation failed, invoking SqlReviseAgent...") + + // Get all table names from schema for whitelist validation + val allTableNames = schema.tables.map { it.name }.toSet() + + // First validate syntax, then validate table names + val syntaxValidation = jsqlValidator.validate(validatedSql!!) + val tableValidation = if (syntaxValidation.isValid) { + jsqlValidator.validateWithTableWhitelist(validatedSql, allTableNames) + } else { + syntaxValidation + } + + if (!tableValidation.isValid) { + val errorType = if (!syntaxValidation.isValid) "syntax" else "table name" + onProgress("๐Ÿ”„ SQL validation failed ($errorType), invoking SqlReviseAgent...") val revisionInput = SqlRevisionInput( originalQuery = task.query, failedSql = validatedSql, - errorMessage = validation.errors.joinToString("; "), + errorMessage = tableValidation.errors.joinToString("; "), schemaDescription = relevantSchema, maxAttempts = maxRevisionAttempts ) @@ -117,19 +136,56 @@ class ChatDBAgentExecutor( } generatedSql = validatedSql - - // Step 7: Execute SQL + + // Step 7: Execute SQL with retry loop for execution errors if (generatedSql != null) { - onProgress("โšก Executing SQL query...") - try { - queryResult = databaseConnection.executeQuery(generatedSql) - onProgress("โœ… Query returned ${queryResult.rowCount} rows") - } catch (e: Exception) { - errors.add("Query execution failed: ${e.message}") - logger.error(e) { "Query execution failed" } + var executionRetries = 0 + var lastExecutionError: String? = null + + while (executionRetries < maxExecutionRetries && queryResult == null) { + onProgress("โšก Executing SQL query${if (executionRetries > 0) " (retry $executionRetries)" else ""}...") + try { + queryResult = databaseConnection.executeQuery(generatedSql!!) + onProgress("โœ… Query returned ${queryResult.rowCount} rows") + } catch (e: Exception) { + lastExecutionError = e.message ?: "Unknown execution error" + logger.warn { "Query execution failed (attempt ${executionRetries + 1}): $lastExecutionError" } + + // Try to revise SQL based on execution error + if (executionRetries < maxExecutionRetries - 1) { + onProgress("๐Ÿ”„ Execution failed, attempting to fix SQL...") + + val revisionInput = SqlRevisionInput( + originalQuery = task.query, + failedSql = generatedSql!!, + errorMessage = "Execution error: $lastExecutionError", + schemaDescription = relevantSchema, + maxAttempts = 1 + ) + + val revisionResult = sqlReviseAgent.execute(revisionInput) { progress -> + onProgress(progress) + } + + if (revisionResult.success && revisionResult.content != generatedSql) { + generatedSql = revisionResult.content + revisionAttempts++ + onProgress("๐Ÿ”ง SQL revised, retrying execution...") + } else { + // Revision didn't help, break the loop + break + } + } + executionRetries++ + } + } + + if (queryResult == null && lastExecutionError != null) { + errors.add("Query execution failed after $executionRetries retries: $lastExecutionError") + logger.error { "Query execution failed after $executionRetries retries" } } } - + // Step 8: Generate visualization if requested if (task.generateVisualization && queryResult != null && !queryResult.isEmpty()) { onProgress("๐Ÿ“ˆ Generating visualization...") diff --git a/mpp-core/src/jvmMain/kotlin/cc/unitmesh/agent/subagent/JSqlParserValidator.kt b/mpp-core/src/jvmMain/kotlin/cc/unitmesh/agent/subagent/JSqlParserValidator.kt index 8668014f56..219c48a591 100644 --- a/mpp-core/src/jvmMain/kotlin/cc/unitmesh/agent/subagent/JSqlParserValidator.kt +++ b/mpp-core/src/jvmMain/kotlin/cc/unitmesh/agent/subagent/JSqlParserValidator.kt @@ -2,18 +2,21 @@ package cc.unitmesh.agent.subagent import net.sf.jsqlparser.parser.CCJSqlParserUtil import net.sf.jsqlparser.statement.Statement +import net.sf.jsqlparser.statement.select.Select +import net.sf.jsqlparser.util.TablesNamesFinder /** * JVM implementation of SqlValidatorInterface using JSqlParser - * + * * This validator uses JSqlParser to validate SQL syntax. * It can detect: * - Syntax errors * - Malformed SQL statements * - Unsupported SQL constructs + * - Table names not in whitelist (schema validation) */ class JSqlParserValidator : SqlValidatorInterface { - + override fun validate(sql: String): SqlValidationResult { return try { val statement: Statement = CCJSqlParserUtil.parse(sql) @@ -30,6 +33,65 @@ class JSqlParserValidator : SqlValidatorInterface { ) } } + + /** + * Validate SQL with table whitelist - ensures only allowed tables are used + * + * @param sql The SQL query to validate + * @param allowedTables Set of table names that are allowed in the query + * @return SqlValidationResult with errors if invalid tables are used + */ + fun validateWithTableWhitelist(sql: String, allowedTables: Set): SqlValidationResult { + return try { + val statement: Statement = CCJSqlParserUtil.parse(sql) + + // Extract table names from the SQL + val tablesNamesFinder = TablesNamesFinder() + val usedTables = tablesNamesFinder.getTableList(statement) + + // Check if all used tables are in the whitelist (case-insensitive) + val allowedTablesLower = allowedTables.map { it.lowercase() }.toSet() + val invalidTables = usedTables.filter { tableName -> + tableName.lowercase() !in allowedTablesLower + } + + if (invalidTables.isNotEmpty()) { + SqlValidationResult( + isValid = false, + errors = listOf( + "Invalid table(s) used: ${invalidTables.joinToString(", ")}. " + + "Available tables: ${allowedTables.joinToString(", ")}" + ), + warnings = collectWarnings(statement) + ) + } else { + SqlValidationResult( + isValid = true, + errors = emptyList(), + warnings = collectWarnings(statement) + ) + } + } catch (e: Exception) { + SqlValidationResult( + isValid = false, + errors = listOf(extractErrorMessage(e)), + warnings = emptyList() + ) + } + } + + /** + * Extract table names from SQL query + */ + fun extractTableNames(sql: String): List { + return try { + val statement: Statement = CCJSqlParserUtil.parse(sql) + val tablesNamesFinder = TablesNamesFinder() + tablesNamesFinder.getTableList(statement) + } catch (e: Exception) { + emptyList() + } + } /** * Validate SQL and return the parsed statement if valid From bb4b2e323ae7c0101f6fdd30bb6057bd28604c2e Mon Sep 17 00:00:00 2001 From: Phodal Huang Date: Wed, 10 Dec 2025 08:23:45 +0800 Subject: [PATCH 09/34] refactor(agent): clarify SQL prompt instructions and output #508 Update prompts and context for SQL generation and revision agents to enforce stricter schema usage and concise SQL-only responses. --- .../unitmesh/agent/subagent/SqlReviseAgent.kt | 34 ++++++++----------- .../cc/unitmesh/agent/chatdb/ChatDBAgent.kt | 33 +++++++++--------- .../agent/chatdb/ChatDBAgentExecutor.kt | 18 +++++----- 3 files changed, 39 insertions(+), 46 deletions(-) diff --git a/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/subagent/SqlReviseAgent.kt b/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/subagent/SqlReviseAgent.kt index 72f237da38..dd280f05c7 100644 --- a/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/subagent/SqlReviseAgent.kt +++ b/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/subagent/SqlReviseAgent.kt @@ -172,12 +172,12 @@ class SqlReviseAgent( currentError: String, previousAttempts: List ): String = buildString { - appendLine("# SQL Revision Context") + appendLine("# SQL Revision Task") appendLine() - appendLine("## Original User Query") + appendLine("## User Query") appendLine(input.originalQuery) appendLine() - appendLine("## Database Schema") + appendLine("## Available Schema (USE ONLY THESE TABLES AND COLUMNS)") appendLine("```") appendLine(input.schemaDescription.take(2000)) appendLine("```") @@ -187,37 +187,33 @@ class SqlReviseAgent( appendLine(currentSql) appendLine("```") appendLine() - appendLine("## Error Message") - appendLine("```") + appendLine("## Error") appendLine(currentError) - appendLine("```") if (previousAttempts.isNotEmpty()) { appendLine() - appendLine("## Previous Attempts (avoid repeating)") + appendLine("## Previous Failed Attempts (do not repeat)") previousAttempts.forEachIndexed { i, sql -> - appendLine("Attempt ${i + 1}: $sql") + appendLine("${i + 1}: $sql") } } } private suspend fun askLLMForRevision(context: String, onProgress: (String) -> Unit): String? { val systemPrompt = """ -You are a SQL Revision Agent. Your task is to fix SQL queries that failed validation or execution. +You are a SQL Revision Agent. Fix SQL queries that failed validation or execution. -## Guidelines: -1. Analyze the error message carefully -2. Consider the original user intent -3. Use the provided schema to ensure correct table/column names -4. Generate a corrected SQL query +CRITICAL RULES: +1. ONLY use table names from the provided schema - NEVER invent table names +2. ONLY use column names from the provided schema - NEVER invent column names +3. If the error says a table doesn't exist, find the correct table name from the schema +4. Analyze the error message carefully and fix the specific issue 5. Avoid repeating previous failed attempts -## Response Format: -Return ONLY the corrected SQL query wrapped in ```sql code block. -Do not include explanations outside the code block. +OUTPUT FORMAT: +Return ONLY the corrected SQL in ```sql code block. No explanations. -Example: ```sql -SELECT * FROM customers WHERE id = 1 +SELECT column FROM table WHERE condition LIMIT 100; ``` """.trimIndent() diff --git a/mpp-core/src/jvmMain/kotlin/cc/unitmesh/agent/chatdb/ChatDBAgent.kt b/mpp-core/src/jvmMain/kotlin/cc/unitmesh/agent/chatdb/ChatDBAgent.kt index 7be0155fab..8feac9d5db 100644 --- a/mpp-core/src/jvmMain/kotlin/cc/unitmesh/agent/chatdb/ChatDBAgent.kt +++ b/mpp-core/src/jvmMain/kotlin/cc/unitmesh/agent/chatdb/ChatDBAgent.kt @@ -148,26 +148,25 @@ class ChatDBAgent( } companion object { - const val SYSTEM_PROMPT = """You are an expert SQL developer and data analyst. Your task is to: + const val SYSTEM_PROMPT = """You are an expert SQL developer. Generate SQL queries from natural language. -1. Understand the user's natural language query -2. Generate accurate, safe, and efficient SQL queries -3. Only generate SELECT queries (read-only operations) -4. Use proper SQL syntax for the target database -5. Consider performance implications (use indexes, avoid SELECT *) -6. Handle edge cases and NULL values appropriately +CRITICAL RULES - YOU MUST FOLLOW THESE: +1. ONLY use table names provided in the schema - NEVER invent or guess table names +2. ONLY use column names provided in the schema - NEVER invent or guess column names +3. If a table or column doesn't exist in the schema, DO NOT use it +4. Only generate SELECT queries (read-only operations) +5. Always add LIMIT clause to prevent large result sets -When generating SQL: -- Always wrap SQL in ```sql code blocks -- Use meaningful aliases for tables and columns -- Add comments for complex queries -- Limit results appropriately (use LIMIT clause) -- Prefer explicit column names over SELECT * +OUTPUT FORMAT: +- Return ONLY the SQL query wrapped in ```sql code block +- Do NOT include explanations, alternatives, or reasoning +- Do NOT add comments outside the code block +- Keep response concise - just the SQL -For visualization: -- When asked to visualize data, generate PlotDSL code -- Choose appropriate chart types based on data characteristics -- Wrap PlotDSL in ```plotdsl code blocks""" +Example response: +```sql +SELECT id, name FROM users WHERE status = 'active' LIMIT 100; +```""" } } diff --git a/mpp-core/src/jvmMain/kotlin/cc/unitmesh/agent/chatdb/ChatDBAgentExecutor.kt b/mpp-core/src/jvmMain/kotlin/cc/unitmesh/agent/chatdb/ChatDBAgentExecutor.kt index 996c8d5549..d7df5dea43 100644 --- a/mpp-core/src/jvmMain/kotlin/cc/unitmesh/agent/chatdb/ChatDBAgentExecutor.kt +++ b/mpp-core/src/jvmMain/kotlin/cc/unitmesh/agent/chatdb/ChatDBAgentExecutor.kt @@ -217,10 +217,11 @@ class ChatDBAgentExecutor( ): String { val relevantTables = schema.tables.filter { it.name in linkingResult.relevantTables } return buildString { - appendLine("## Relevant Database Schema") + appendLine("## Database Schema (USE ONLY THESE TABLES)") appendLine() relevantTables.forEach { table -> - append(table.getDescription()) + appendLine("Table: ${table.name}") + appendLine("Columns: ${table.columns.joinToString(", ") { "${it.name} (${it.type})" }}") appendLine() } } @@ -232,22 +233,19 @@ class ChatDBAgentExecutor( linkingResult: SchemaLinkingResult ): String { return buildString { - appendLine("Please generate a SQL query for the following request:") - appendLine() - appendLine("**User Query**: ${task.query}") + appendLine("Generate a SQL query for: ${task.query}") appendLine() if (task.additionalContext.isNotBlank()) { - appendLine("**Additional Context**: ${task.additionalContext}") + appendLine("Context: ${task.additionalContext}") appendLine() } - appendLine("**Maximum Rows**: ${task.maxRows}") + appendLine("ALLOWED TABLES (use ONLY these): ${linkingResult.relevantTables.joinToString(", ")}") appendLine() appendLine(schemaDescription) appendLine() - appendLine("**Schema Linking Keywords**: ${linkingResult.keywords.joinToString(", ")}") - appendLine("**Confidence**: ${String.format("%.2f", linkingResult.confidence)}") + appendLine("Max rows: ${task.maxRows}") appendLine() - appendLine("Please generate a safe, read-only SQL query. Wrap the SQL in a ```sql code block.") + appendLine("Return ONLY the SQL in a ```sql code block. No explanations.") } } From 92def0e3fa1f732fbe13bdb6af7123d93bdf1aef Mon Sep 17 00:00:00 2001 From: Phodal Huang Date: Wed, 10 Dec 2025 09:02:03 +0800 Subject: [PATCH 10/34] feat(chatdb): improve schema linking with sample data #508 Introduce DatabaseContentSchemaLinker to enhance schema linking accuracy using table sample data and system table filtering. Add support for English and Chinese keyword extraction and new database methods for retrieving sample rows and distinct values. --- .../chatdb/DatabaseContentSchemaLinker.kt | 166 ++++++++++++++++++ .../unitmesh/agent/chatdb/LlmSchemaLinker.kt | 70 +++++++- .../cc/unitmesh/agent/chatdb/SchemaLinker.kt | 22 ++- .../agent/database/DatabaseConnection.kt | 32 ++++ .../agent/chatdb/ChatDBAgentExecutor.kt | 44 ++++- 5 files changed, 317 insertions(+), 17 deletions(-) create mode 100644 mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/chatdb/DatabaseContentSchemaLinker.kt diff --git a/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/chatdb/DatabaseContentSchemaLinker.kt b/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/chatdb/DatabaseContentSchemaLinker.kt new file mode 100644 index 0000000000..7ed81549f2 --- /dev/null +++ b/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/chatdb/DatabaseContentSchemaLinker.kt @@ -0,0 +1,166 @@ +package cc.unitmesh.agent.chatdb + +import cc.unitmesh.agent.database.DatabaseConnection +import cc.unitmesh.agent.database.DatabaseSchema +import cc.unitmesh.agent.database.TableSchema +import cc.unitmesh.llm.KoogLLMService +import kotlinx.serialization.Serializable +import kotlinx.serialization.json.Json + +/** + * Database Content Schema Linker - Uses database content to improve Schema Linking accuracy + * + * Based on RSL-SQL research, this linker: + * 1. Filters out system tables (sys_*, x$*, etc.) + * 2. Retrieves sample data from each table + * 3. Uses LLM with enriched context to identify relevant tables + * + * This approach is more accurate than pure keyword matching because it uses + * actual database content to understand table semantics. + */ +class DatabaseContentSchemaLinker( + private val llmService: KoogLLMService, + private val databaseConnection: DatabaseConnection, + private val fallbackLinker: KeywordSchemaLinker = KeywordSchemaLinker() +) : SchemaLinker() { + + private val json = Json { ignoreUnknownKeys = true } + + companion object { + // System table prefixes to filter out + private val SYSTEM_TABLE_PREFIXES = listOf( + "sys_", "x$", "innodb_", "io_", "memory_", "schema_", "statement", + "user_summary", "host_summary", "wait", "process", "session", + "metrics", "privileges", "ps_", "flyway_", "hibernate_" + ) + + // System table exact names + private val SYSTEM_TABLE_NAMES = setOf( + "version", "latest_file_io", "session_ssl_status" + ) + + private const val SCHEMA_LINKING_PROMPT = """You are a database schema expert. Given a user query and database tables with sample data, identify the most relevant tables. + +CRITICAL RULES: +1. ONLY select tables from the provided list - do NOT invent table names +2. Look at sample data to understand what each table actually contains +3. Match the user's semantic intent, not just keywords +4. For "ๆ–‡็ซ /article/post" queries, look for tables containing blog/post content +5. For "ไฝœ่€…/author/creator" queries, look for tables with author/creator information + +Available Tables with Sample Data: +{{TABLES_WITH_SAMPLES}} + +User Query: {{QUERY}} + +Respond ONLY with a JSON object: +{"tables": ["table1", "table2"], "reason": "brief explanation", "confidence": 0.9} + +Select ONLY tables that are directly needed to answer the query.""" + } + + @Serializable + private data class SchemaLinkingResult( + val tables: List = emptyList(), + val reason: String = "", + val confidence: Double = 0.0 + ) + + /** + * Filter out system tables that are not relevant for user queries + */ + private fun filterUserTables(schema: DatabaseSchema): List { + return schema.tables.filter { table -> + val lowerName = table.name.lowercase() + // Filter out system tables + val isSystemTable = SYSTEM_TABLE_PREFIXES.any { prefix -> + lowerName.startsWith(prefix) + } || SYSTEM_TABLE_NAMES.contains(lowerName) + !isSystemTable + } + } + + /** + * Build table description with sample data for Schema Linking + */ + private suspend fun buildTableWithSamples(table: TableSchema): String { + return buildString { + appendLine("Table: ${table.name}") + appendLine(" Columns: ${table.columns.joinToString(", ") { "${it.name}(${it.type})" }}") + + // Get sample data to help understand table content + try { + val samples = databaseConnection.getSampleRows(table.name, 2) + if (!samples.isEmpty()) { + appendLine(" Sample Data:") + appendLine(" ${samples.columns.joinToString(" | ")}") + samples.rows.take(2).forEach { row -> + appendLine(" ${row.joinToString(" | ") { it.take(30) }}") + } + } + } catch (e: Exception) { + // Table might not exist in current database, skip it + appendLine(" (No sample data available)") + } + } + } + + override suspend fun link(query: String, schema: DatabaseSchema): cc.unitmesh.agent.chatdb.SchemaLinkingResult { + return try { + // Step 1: Filter out system tables + val userTables = filterUserTables(schema) + if (userTables.isEmpty()) { + return fallbackLinker.link(query, schema) + } + + // Step 2: Build table descriptions with sample data + val tablesWithSamples = userTables.map { table -> + buildTableWithSamples(table) + }.joinToString("\n") + + // Step 3: Ask LLM to identify relevant tables + val prompt = SCHEMA_LINKING_PROMPT + .replace("{{TABLES_WITH_SAMPLES}}", tablesWithSamples) + .replace("{{QUERY}}", query) + + val response = llmService.sendPrompt(prompt) + val result = parseResponse(response) + + // Step 4: Validate tables exist + val validTables = result.tables.filter { tableName -> + userTables.any { it.name.equals(tableName, ignoreCase = true) } + } + + if (validTables.isEmpty()) { + return fallbackLinker.link(query, schema) + } + + // Extract keywords for additional context + val keywords = fallbackLinker.extractKeywords(query) + + cc.unitmesh.agent.chatdb.SchemaLinkingResult( + relevantTables = validTables, + relevantColumns = emptyList(), + keywords = keywords, + confidence = result.confidence + ) + } catch (e: Exception) { + fallbackLinker.link(query, schema) + } + } + + override suspend fun extractKeywords(query: String): List { + return fallbackLinker.extractKeywords(query) + } + + private fun parseResponse(response: String): SchemaLinkingResult { + val jsonPattern = Regex("""\{[^{}]*\}""") + val jsonStr = jsonPattern.find(response)?.value ?: "{}" + return try { + json.decodeFromString(jsonStr) + } catch (e: Exception) { + SchemaLinkingResult() + } + } +} + diff --git a/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/chatdb/LlmSchemaLinker.kt b/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/chatdb/LlmSchemaLinker.kt index 0d74094a6e..c901aad100 100644 --- a/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/chatdb/LlmSchemaLinker.kt +++ b/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/chatdb/LlmSchemaLinker.kt @@ -1,5 +1,6 @@ package cc.unitmesh.agent.chatdb +import cc.unitmesh.agent.database.DatabaseConnection import cc.unitmesh.agent.database.DatabaseSchema import cc.unitmesh.agent.database.TableSchema import cc.unitmesh.llm.KoogLLMService @@ -8,20 +9,26 @@ import kotlinx.serialization.json.Json /** * LLM-based Schema Linker - Uses LLM to extract keywords and link schema - * + * * This implementation uses the LLM to: * 1. Extract semantic keywords from the natural language query * 2. Map natural language terms to database schema elements - * + * 3. Use sample data for better value matching (RSL-SQL approach) + * * Falls back to KeywordSchemaLinker if LLM fails. + * + * Based on research from: + * - RSL-SQL: Robust Schema Linking in Text-to-SQL Generation + * - A Survey of NL2SQL with Large Language Models */ class LlmSchemaLinker( private val llmService: KoogLLMService, + private val databaseConnection: DatabaseConnection? = null, private val fallbackLinker: KeywordSchemaLinker = KeywordSchemaLinker() ) : SchemaLinker() { - + private val json = Json { ignoreUnknownKeys = true } - + companion object { private const val KEYWORD_EXTRACTION_PROMPT = """You are a database schema expert. Extract keywords from the user's natural language query that are relevant for finding database tables and columns. @@ -35,11 +42,16 @@ Respond ONLY with a JSON object in this exact format: User Query: """ - private const val SCHEMA_LINKING_PROMPT = """You are a database schema expert. Given a user query and database schema, identify the most relevant tables and columns. + // Enhanced Schema Linking prompt based on RSL-SQL research + private const val SCHEMA_LINKING_PROMPT = """You are a database schema expert. Given a user query and database schema with sample data, identify the most relevant tables and columns. -IMPORTANT: You MUST only use table and column names that exist in the provided schema. Do NOT invent or hallucinate table/column names. +CRITICAL RULES: +1. You MUST only use table and column names that EXACTLY exist in the provided schema +2. Do NOT invent or hallucinate table/column names +3. Look at the sample data to understand what each table contains +4. Match user's intent with the actual table/column names, not assumed names -Database Schema: +Database Schema with Sample Data: {{SCHEMA}} User Query: {{QUERY}} @@ -129,7 +141,49 @@ Only include tables and columns that are directly relevant to answering the quer } } - private fun buildSchemaDescription(schema: DatabaseSchema): String { + /** + * Build schema description with sample data for better Schema Linking + * Based on RSL-SQL research: sample data helps LLM understand table semantics + */ + private suspend fun buildSchemaDescription(schema: DatabaseSchema): String { + val tableDescriptions = mutableListOf() + + for (table in schema.tables) { + val description = buildString { + appendLine("Table: ${table.name}") + + // Column information + val columns = table.columns.joinToString(", ") { col -> + "${col.name} (${col.type}${if (col.isPrimaryKey) ", PK" else ""})" + } + appendLine("Columns: $columns") + + // Add sample data if database connection is available + if (databaseConnection != null) { + try { + val sampleRows = databaseConnection.getSampleRows(table.name, 2) + if (!sampleRows.isEmpty()) { + appendLine("Sample Data:") + appendLine(" ${sampleRows.columns.joinToString(" | ")}") + sampleRows.rows.take(2).forEach { row -> + appendLine(" ${row.joinToString(" | ") { it.take(30) }}") + } + } + } catch (e: Exception) { + // Ignore sample data errors + } + } + }.trim() + tableDescriptions.add(description) + } + + return tableDescriptions.joinToString("\n\n") + } + + /** + * Build schema description without sample data (synchronous version) + */ + private fun buildSchemaDescriptionSync(schema: DatabaseSchema): String { return schema.tables.joinToString("\n\n") { table -> val columns = table.columns.joinToString(", ") { col -> "${col.name} (${col.type}${if (col.isPrimaryKey) ", PK" else ""})" diff --git a/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/chatdb/SchemaLinker.kt b/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/chatdb/SchemaLinker.kt index ec03690550..09ea05b891 100644 --- a/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/chatdb/SchemaLinker.kt +++ b/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/chatdb/SchemaLinker.kt @@ -120,13 +120,31 @@ class KeywordSchemaLinker : SchemaLinker() { /** * Extract keywords from natural language query using simple tokenization + * Supports both English and Chinese text */ override suspend fun extractKeywords(query: String): List { - return query.lowercase() + val keywords = mutableListOf() + + // Extract English words + val englishWords = query.lowercase() .replace(Regex("[^a-z0-9\\s_]"), " ") .split(Regex("\\s+")) .filter { it.length > 2 && it !in STOP_WORDS } - .distinct() + keywords.addAll(englishWords) + + // Extract Chinese characters/words (each Chinese character or common word) + val chinesePattern = Regex("[\\u4e00-\\u9fa5]+") + val chineseMatches = chinesePattern.findAll(query) + for (match in chineseMatches) { + val word = match.value + keywords.add(word) + // Also add individual characters for better matching + if (word.length > 1) { + word.forEach { char -> keywords.add(char.toString()) } + } + } + + return keywords.distinct() } /** diff --git a/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/database/DatabaseConnection.kt b/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/database/DatabaseConnection.kt index e504ea3f4e..3ada6fefa4 100644 --- a/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/database/DatabaseConnection.kt +++ b/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/database/DatabaseConnection.kt @@ -52,6 +52,38 @@ interface DatabaseConnection { return (result.rows[0][0] as? Number)?.toLong() ?: 0L } + /** + * Get sample rows from a table (for Schema Linking context) + * + * @param tableName Table name + * @param limit Maximum number of rows to return (default 3) + * @return Sample rows as QueryResult + */ + suspend fun getSampleRows(tableName: String, limit: Int = 3): QueryResult { + return try { + executeQuery("SELECT * FROM `$tableName` LIMIT $limit") + } catch (e: Exception) { + QueryResult(emptyList(), emptyList(), 0) + } + } + + /** + * Get distinct values for a column (for Value Matching in Schema Linking) + * + * @param tableName Table name + * @param columnName Column name + * @param limit Maximum number of distinct values to return (default 10) + * @return List of distinct values as strings + */ + suspend fun getDistinctValues(tableName: String, columnName: String, limit: Int = 10): List { + return try { + val result = executeQuery("SELECT DISTINCT `$columnName` FROM `$tableName` LIMIT $limit") + result.rows.map { it.firstOrNull() ?: "" }.filter { it.isNotEmpty() } + } catch (e: Exception) { + emptyList() + } + } + /** * Close database connection */ diff --git a/mpp-core/src/jvmMain/kotlin/cc/unitmesh/agent/chatdb/ChatDBAgentExecutor.kt b/mpp-core/src/jvmMain/kotlin/cc/unitmesh/agent/chatdb/ChatDBAgentExecutor.kt index d7df5dea43..559b29c3f5 100644 --- a/mpp-core/src/jvmMain/kotlin/cc/unitmesh/agent/chatdb/ChatDBAgentExecutor.kt +++ b/mpp-core/src/jvmMain/kotlin/cc/unitmesh/agent/chatdb/ChatDBAgentExecutor.kt @@ -43,7 +43,9 @@ class ChatDBAgentExecutor( private val logger = getLogger("ChatDBAgentExecutor") private val keywordSchemaLinker = KeywordSchemaLinker() private val schemaLinker: SchemaLinker = if (useLlmSchemaLinker) { - LlmSchemaLinker(llmService, keywordSchemaLinker) + // Use DatabaseContentSchemaLinker for better accuracy (RSL-SQL approach) + // It filters system tables and uses sample data for semantic matching + DatabaseContentSchemaLinker(llmService, databaseConnection, keywordSchemaLinker) } else { keywordSchemaLinker } @@ -70,15 +72,25 @@ class ChatDBAgentExecutor( // Step 1: Get database schema onProgress("๐Ÿ“Š Fetching database schema...") val schema = task.schema ?: databaseConnection.getSchema() - + logger.info { "Database has ${schema.tables.size} tables: ${schema.tables.map { it.name }}" } + // Step 2: Schema Linking onProgress("๐Ÿ”— Performing schema linking...") val linkingResult = schemaLinker.link(task.query, schema) - logger.info { "Schema linking found ${linkingResult.relevantTables.size} relevant tables" } - + logger.info { "Schema linking found ${linkingResult.relevantTables.size} relevant tables: ${linkingResult.relevantTables}" } + logger.info { "Schema linking keywords: ${linkingResult.keywords}" } + // Step 3: Build context with relevant schema - val relevantSchema = buildRelevantSchemaDescription(schema, linkingResult) - val initialMessage = buildInitialUserMessage(task, relevantSchema, linkingResult) + // If schema linking found too few tables, use all tables to avoid missing important ones + val effectiveLinkingResult = if (linkingResult.relevantTables.size < 2 && schema.tables.size <= 10) { + logger.info { "Schema linking found few tables, using all ${schema.tables.size} tables" } + linkingResult.copy(relevantTables = schema.tables.map { it.name }) + } else { + linkingResult + } + + val relevantSchema = buildRelevantSchemaDescription(schema, effectiveLinkingResult) + val initialMessage = buildInitialUserMessage(task, relevantSchema, effectiveLinkingResult) // Step 4: Generate SQL with LLM onProgress("๐Ÿค– Generating SQL query...") @@ -211,7 +223,11 @@ class ChatDBAgentExecutor( currentIteration = 0 } - private fun buildRelevantSchemaDescription( + /** + * Build schema description with sample data for SQL generation + * Based on RSL-SQL research: sample data helps LLM understand table semantics + */ + private suspend fun buildRelevantSchemaDescription( schema: DatabaseSchema, linkingResult: SchemaLinkingResult ): String { @@ -222,6 +238,20 @@ class ChatDBAgentExecutor( relevantTables.forEach { table -> appendLine("Table: ${table.name}") appendLine("Columns: ${table.columns.joinToString(", ") { "${it.name} (${it.type})" }}") + + // Add sample data to help LLM understand table content + try { + val sampleRows = databaseConnection.getSampleRows(table.name, 2) + if (!sampleRows.isEmpty()) { + appendLine("Sample Data:") + appendLine(" ${sampleRows.columns.joinToString(" | ")}") + sampleRows.rows.take(2).forEach { row -> + appendLine(" ${row.joinToString(" | ") { it.take(50) }}") + } + } + } catch (e: Exception) { + // Ignore sample data errors + } appendLine() } } From 61ecd4ffa8eb9452e993d5a075b07aca4cbdabce Mon Sep 17 00:00:00 2001 From: Phodal Huang Date: Wed, 10 Dec 2025 09:09:41 +0800 Subject: [PATCH 11/34] fix(chatdb): ensure single SQL statement extraction #508 Enforce stricter rules for LLM SQL output and extract only the first SQL statement to prevent multiple statement errors. Also, filter table metadata by current database/catalog. --- .../agent/chatdb/ChatDBAgentExecutor.kt | 33 +++++++++++++++++-- .../database/ExposedDatabaseConnection.kt | 15 +++++---- 2 files changed, 39 insertions(+), 9 deletions(-) diff --git a/mpp-core/src/jvmMain/kotlin/cc/unitmesh/agent/chatdb/ChatDBAgentExecutor.kt b/mpp-core/src/jvmMain/kotlin/cc/unitmesh/agent/chatdb/ChatDBAgentExecutor.kt index 559b29c3f5..eaecd23eb9 100644 --- a/mpp-core/src/jvmMain/kotlin/cc/unitmesh/agent/chatdb/ChatDBAgentExecutor.kt +++ b/mpp-core/src/jvmMain/kotlin/cc/unitmesh/agent/chatdb/ChatDBAgentExecutor.kt @@ -275,21 +275,25 @@ class ChatDBAgentExecutor( appendLine() appendLine("Max rows: ${task.maxRows}") appendLine() - appendLine("Return ONLY the SQL in a ```sql code block. No explanations.") + appendLine("CRITICAL RULES:") + appendLine("1. Return ONLY ONE SQL statement - never multiple statements") + appendLine("2. Choose the BEST matching table if multiple similar tables exist") + appendLine("3. Wrap the SQL in a ```sql code block") + appendLine("4. No comments, no explanations, just the SQL") } } private fun extractSqlFromResponse(response: String): String? { val codeFence = CodeFence.parse(response) if (codeFence.languageId.lowercase() == "sql" && codeFence.text.isNotBlank()) { - return codeFence.text.trim() + return extractFirstStatement(codeFence.text.trim()) } // Try to find SQL block manually val sqlPattern = Regex("```sql\\s*([\\s\\S]*?)```", RegexOption.IGNORE_CASE) val match = sqlPattern.find(response) if (match != null) { - return match.groupValues[1].trim() + return extractFirstStatement(match.groupValues[1].trim()) } // Last resort: look for SELECT statement @@ -298,6 +302,29 @@ class ChatDBAgentExecutor( return selectMatch?.groupValues?.get(1)?.trim() } + /** + * Extract only the first SQL statement if LLM returns multiple statements. + * This prevents "multiple statements" errors. + */ + private fun extractFirstStatement(sql: String): String { + // Remove SQL comments (-- style) + val withoutComments = sql.lines() + .filterNot { it.trim().startsWith("--") } + .joinToString("\n") + .trim() + + // If there are multiple statements separated by semicolons, take only the first + val statements = withoutComments.split(";") + .map { it.trim() } + .filter { it.isNotBlank() } + + return if (statements.isNotEmpty()) { + statements.first() + ";" + } else { + withoutComments + } + } + private suspend fun generateVisualization( query: String, result: QueryResult, diff --git a/mpp-core/src/jvmMain/kotlin/cc/unitmesh/agent/database/ExposedDatabaseConnection.kt b/mpp-core/src/jvmMain/kotlin/cc/unitmesh/agent/database/ExposedDatabaseConnection.kt index 288658e523..cb12cb6ed1 100644 --- a/mpp-core/src/jvmMain/kotlin/cc/unitmesh/agent/database/ExposedDatabaseConnection.kt +++ b/mpp-core/src/jvmMain/kotlin/cc/unitmesh/agent/database/ExposedDatabaseConnection.kt @@ -60,9 +60,12 @@ class ExposedDatabaseConnection( val metadata = connection.metaData val tables = mutableListOf() - // Get all tables + // Get current database/catalog name to filter tables + val currentCatalog = connection.catalog + + // Get tables only from current database (catalog) val tableTypes = arrayOf("TABLE", "VIEW") - val tableRs = metadata.getTables(null, null, "%", tableTypes) + val tableRs = metadata.getTables(currentCatalog, null, "%", tableTypes) while (tableRs.next()) { val tableName = tableRs.getString("TABLE_NAME") @@ -72,10 +75,10 @@ class ExposedDatabaseConnection( null } - // Get primary keys + // Get primary keys for current catalog val primaryKeys = mutableSetOf() try { - val pkRs = metadata.getPrimaryKeys(null, null, tableName) + val pkRs = metadata.getPrimaryKeys(currentCatalog, null, tableName) while (pkRs.next()) { primaryKeys.add(pkRs.getString("COLUMN_NAME")) } @@ -84,8 +87,8 @@ class ExposedDatabaseConnection( // Ignore if primary keys cannot be retrieved } - // Get columns - val columnRs = metadata.getColumns(null, null, tableName, null) + // Get columns for current catalog + val columnRs = metadata.getColumns(currentCatalog, null, tableName, null) val columns = mutableListOf() while (columnRs.next()) { From 681dd27fb25938b5fe0b4e36a87467c93452aaea Mon Sep 17 00:00:00 2001 From: Phodal Huang Date: Wed, 10 Dec 2025 09:22:39 +0800 Subject: [PATCH 12/34] refactor(chatdb): move KeywordSchemaLinker to separate file #508 KeywordSchemaLinker was extracted from SchemaLinker.kt into its own file for better modularity and code organization. Updated references to use SchemaLinker type for fallback linker. --- .../chatdb/DatabaseContentSchemaLinker.kt | 2 +- .../agent/chatdb/KeywordSchemaLinker.kt | 140 ++++++++++++++++++ .../unitmesh/agent/chatdb/LlmSchemaLinker.kt | 1 - .../cc/unitmesh/agent/chatdb/SchemaLinker.kt | 137 ----------------- .../agent/chatdb/ChatDBAgentExecutor.kt | 6 +- 5 files changed, 144 insertions(+), 142 deletions(-) create mode 100644 mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/chatdb/KeywordSchemaLinker.kt diff --git a/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/chatdb/DatabaseContentSchemaLinker.kt b/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/chatdb/DatabaseContentSchemaLinker.kt index 7ed81549f2..13263c272b 100644 --- a/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/chatdb/DatabaseContentSchemaLinker.kt +++ b/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/chatdb/DatabaseContentSchemaLinker.kt @@ -21,7 +21,7 @@ import kotlinx.serialization.json.Json class DatabaseContentSchemaLinker( private val llmService: KoogLLMService, private val databaseConnection: DatabaseConnection, - private val fallbackLinker: KeywordSchemaLinker = KeywordSchemaLinker() + private val fallbackLinker: SchemaLinker = KeywordSchemaLinker() ) : SchemaLinker() { private val json = Json { ignoreUnknownKeys = true } diff --git a/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/chatdb/KeywordSchemaLinker.kt b/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/chatdb/KeywordSchemaLinker.kt new file mode 100644 index 0000000000..62599f3496 --- /dev/null +++ b/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/chatdb/KeywordSchemaLinker.kt @@ -0,0 +1,140 @@ +package cc.unitmesh.agent.chatdb + +import cc.unitmesh.agent.database.DatabaseSchema +import cc.unitmesh.agent.database.TableSchema + +/** + * Keyword-based Schema Linker - Uses keyword matching and fuzzy matching + * + * This is the default/fallback implementation that doesn't require LLM calls. + * It extracts keywords from the query and matches them against table/column names. + */ +class KeywordSchemaLinker : SchemaLinker() { + + /** + * Link natural language query to relevant schema elements + */ + override suspend fun link(query: String, schema: DatabaseSchema): SchemaLinkingResult { + val keywords = extractKeywords(query) + val relevantTables = mutableListOf() + val relevantColumns = mutableListOf() + var totalScore = 0.0 + var matchCount = 0 + + for (table in schema.tables) { + val tableScore = calculateTableRelevance(table, keywords) + if (tableScore > 0) { + relevantTables.add(table.name) + totalScore += tableScore + matchCount++ + + // Find relevant columns in this table + for (column in table.columns) { + val columnScore = calculateColumnRelevance(column.name, column.comment, keywords) + if (columnScore > 0) { + relevantColumns.add("${table.name}.${column.name}") + } + } + } + } + + // If no tables matched, include all tables (fallback) + if (relevantTables.isEmpty()) { + relevantTables.addAll(schema.tables.map { it.name }) + } + + val confidence = if (matchCount > 0) (totalScore / matchCount).coerceIn(0.0, 1.0) else 0.0 + + return SchemaLinkingResult( + relevantTables = relevantTables, + relevantColumns = relevantColumns, + keywords = keywords, + confidence = confidence + ) + } + + /** + * Extract keywords from natural language query using simple tokenization + * Supports both English and Chinese text + */ + override suspend fun extractKeywords(query: String): List { + val keywords = mutableListOf() + + // Extract English words + val englishWords = query.lowercase() + .replace(Regex("[^a-z0-9\\s_]"), " ") + .split(Regex("\\s+")) + .filter { it.length > 2 && it !in STOP_WORDS } + keywords.addAll(englishWords) + + // Extract Chinese characters/words (each Chinese character or common word) + val chinesePattern = Regex("[\\u4e00-\\u9fa5]+") + val chineseMatches = chinesePattern.findAll(query) + for (match in chineseMatches) { + val word = match.value + keywords.add(word) + // Also add individual characters for better matching + if (word.length > 1) { + word.forEach { char -> keywords.add(char.toString()) } + } + } + + return keywords.distinct() + } + + /** + * Calculate relevance score for a table + */ + private fun calculateTableRelevance(table: TableSchema, keywords: List): Double { + var score = 0.0 + val tableName = table.name.lowercase() + val tableComment = table.comment?.lowercase() ?: "" + + for (keyword in keywords) { + // Exact match in table name + if (tableName == keyword) { + score += 1.0 + } + // Partial match in table name + else if (tableName.contains(keyword) || keyword.contains(tableName)) { + score += 0.7 + } + // Match in table comment + else if (tableComment.contains(keyword)) { + score += 0.5 + } + // Fuzzy match (Levenshtein distance) + else if (fuzzyMatch(tableName, keyword)) { + score += 0.3 + } + + // Check column names + for (column in table.columns) { + val colName = column.name.lowercase() + if (colName == keyword || colName.contains(keyword)) { + score += 0.4 + } + } + } + + return score + } + + /** + * Calculate relevance score for a column + */ + private fun calculateColumnRelevance(columnName: String, comment: String?, keywords: List): Double { + var score = 0.0 + val colName = columnName.lowercase() + val colComment = comment?.lowercase() ?: "" + + for (keyword in keywords) { + if (colName == keyword) score += 1.0 + else if (colName.contains(keyword)) score += 0.7 + else if (colComment.contains(keyword)) score += 0.5 + else if (fuzzyMatch(colName, keyword)) score += 0.3 + } + + return score + } +} \ No newline at end of file diff --git a/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/chatdb/LlmSchemaLinker.kt b/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/chatdb/LlmSchemaLinker.kt index c901aad100..2ade26623f 100644 --- a/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/chatdb/LlmSchemaLinker.kt +++ b/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/chatdb/LlmSchemaLinker.kt @@ -2,7 +2,6 @@ package cc.unitmesh.agent.chatdb import cc.unitmesh.agent.database.DatabaseConnection import cc.unitmesh.agent.database.DatabaseSchema -import cc.unitmesh.agent.database.TableSchema import cc.unitmesh.llm.KoogLLMService import kotlinx.serialization.Serializable import kotlinx.serialization.json.Json diff --git a/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/chatdb/SchemaLinker.kt b/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/chatdb/SchemaLinker.kt index 09ea05b891..e282724709 100644 --- a/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/chatdb/SchemaLinker.kt +++ b/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/chatdb/SchemaLinker.kt @@ -1,7 +1,6 @@ package cc.unitmesh.agent.chatdb import cc.unitmesh.agent.database.DatabaseSchema -import cc.unitmesh.agent.database.TableSchema /** * Schema Linker - Abstract base class for Text2SQL schema linking @@ -68,139 +67,3 @@ abstract class SchemaLinker { } } -/** - * Keyword-based Schema Linker - Uses keyword matching and fuzzy matching - * - * This is the default/fallback implementation that doesn't require LLM calls. - * It extracts keywords from the query and matches them against table/column names. - */ -class KeywordSchemaLinker : SchemaLinker() { - - /** - * Link natural language query to relevant schema elements - */ - override suspend fun link(query: String, schema: DatabaseSchema): SchemaLinkingResult { - val keywords = extractKeywords(query) - val relevantTables = mutableListOf() - val relevantColumns = mutableListOf() - var totalScore = 0.0 - var matchCount = 0 - - for (table in schema.tables) { - val tableScore = calculateTableRelevance(table, keywords) - if (tableScore > 0) { - relevantTables.add(table.name) - totalScore += tableScore - matchCount++ - - // Find relevant columns in this table - for (column in table.columns) { - val columnScore = calculateColumnRelevance(column.name, column.comment, keywords) - if (columnScore > 0) { - relevantColumns.add("${table.name}.${column.name}") - } - } - } - } - - // If no tables matched, include all tables (fallback) - if (relevantTables.isEmpty()) { - relevantTables.addAll(schema.tables.map { it.name }) - } - - val confidence = if (matchCount > 0) (totalScore / matchCount).coerceIn(0.0, 1.0) else 0.0 - - return SchemaLinkingResult( - relevantTables = relevantTables, - relevantColumns = relevantColumns, - keywords = keywords, - confidence = confidence - ) - } - - /** - * Extract keywords from natural language query using simple tokenization - * Supports both English and Chinese text - */ - override suspend fun extractKeywords(query: String): List { - val keywords = mutableListOf() - - // Extract English words - val englishWords = query.lowercase() - .replace(Regex("[^a-z0-9\\s_]"), " ") - .split(Regex("\\s+")) - .filter { it.length > 2 && it !in STOP_WORDS } - keywords.addAll(englishWords) - - // Extract Chinese characters/words (each Chinese character or common word) - val chinesePattern = Regex("[\\u4e00-\\u9fa5]+") - val chineseMatches = chinesePattern.findAll(query) - for (match in chineseMatches) { - val word = match.value - keywords.add(word) - // Also add individual characters for better matching - if (word.length > 1) { - word.forEach { char -> keywords.add(char.toString()) } - } - } - - return keywords.distinct() - } - - /** - * Calculate relevance score for a table - */ - private fun calculateTableRelevance(table: TableSchema, keywords: List): Double { - var score = 0.0 - val tableName = table.name.lowercase() - val tableComment = table.comment?.lowercase() ?: "" - - for (keyword in keywords) { - // Exact match in table name - if (tableName == keyword) { - score += 1.0 - } - // Partial match in table name - else if (tableName.contains(keyword) || keyword.contains(tableName)) { - score += 0.7 - } - // Match in table comment - else if (tableComment.contains(keyword)) { - score += 0.5 - } - // Fuzzy match (Levenshtein distance) - else if (fuzzyMatch(tableName, keyword)) { - score += 0.3 - } - - // Check column names - for (column in table.columns) { - val colName = column.name.lowercase() - if (colName == keyword || colName.contains(keyword)) { - score += 0.4 - } - } - } - - return score - } - - /** - * Calculate relevance score for a column - */ - private fun calculateColumnRelevance(columnName: String, comment: String?, keywords: List): Double { - var score = 0.0 - val colName = columnName.lowercase() - val colComment = comment?.lowercase() ?: "" - - for (keyword in keywords) { - if (colName == keyword) score += 1.0 - else if (colName.contains(keyword)) score += 0.7 - else if (colComment.contains(keyword)) score += 0.5 - else if (fuzzyMatch(colName, keyword)) score += 0.3 - } - - return score - } -} - diff --git a/mpp-core/src/jvmMain/kotlin/cc/unitmesh/agent/chatdb/ChatDBAgentExecutor.kt b/mpp-core/src/jvmMain/kotlin/cc/unitmesh/agent/chatdb/ChatDBAgentExecutor.kt index eaecd23eb9..097313cb4e 100644 --- a/mpp-core/src/jvmMain/kotlin/cc/unitmesh/agent/chatdb/ChatDBAgentExecutor.kt +++ b/mpp-core/src/jvmMain/kotlin/cc/unitmesh/agent/chatdb/ChatDBAgentExecutor.kt @@ -43,12 +43,12 @@ class ChatDBAgentExecutor( private val logger = getLogger("ChatDBAgentExecutor") private val keywordSchemaLinker = KeywordSchemaLinker() private val schemaLinker: SchemaLinker = if (useLlmSchemaLinker) { - // Use DatabaseContentSchemaLinker for better accuracy (RSL-SQL approach) - // It filters system tables and uses sample data for semantic matching - DatabaseContentSchemaLinker(llmService, databaseConnection, keywordSchemaLinker) + val fallbackLinker = LlmSchemaLinker(llmService, databaseConnection, keywordSchemaLinker) + DatabaseContentSchemaLinker(llmService, databaseConnection, fallbackLinker) } else { keywordSchemaLinker } + private val jsqlValidator = JSqlParserValidator() private val sqlReviseAgent = SqlReviseAgent(llmService, jsqlValidator) private val maxRevisionAttempts = 3 From 03d891b99336b0a1fe0882205c1b77875aeb555c Mon Sep 17 00:00:00 2001 From: Phodal Huang Date: Wed, 10 Dec 2025 09:41:09 +0800 Subject: [PATCH 13/34] feat(chatdb): add platform-specific NLP keyword extraction #508 Introduce NlpTokenizer with MyNLP support on JVM for improved Chinese segmentation and fallback regex-based tokenization on other platforms. Updates KeywordSchemaLinker to use platform-specific extraction. --- mpp-core/build.gradle.kts | 3 + .../agent/chatdb/NlpTokenizer.android.kt | 23 +++++ .../agent/chatdb/KeywordSchemaLinker.kt | 29 ++----- .../cc/unitmesh/agent/chatdb/NlpTokenizer.kt | 55 ++++++++++++ .../unitmesh/agent/chatdb/NlpTokenizer.ios.kt | 22 +++++ .../unitmesh/agent/chatdb/NlpTokenizer.js.kt | 20 +++++ .../unitmesh/agent/chatdb/NlpTokenizer.jvm.kt | 85 +++++++++++++++++++ .../agent/chatdb/NlpTokenizer.wasmJs.kt | 20 +++++ 8 files changed, 233 insertions(+), 24 deletions(-) create mode 100644 mpp-core/src/androidMain/kotlin/cc/unitmesh/agent/chatdb/NlpTokenizer.android.kt create mode 100644 mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/chatdb/NlpTokenizer.kt create mode 100644 mpp-core/src/iosMain/kotlin/cc/unitmesh/agent/chatdb/NlpTokenizer.ios.kt create mode 100644 mpp-core/src/jsMain/kotlin/cc/unitmesh/agent/chatdb/NlpTokenizer.js.kt create mode 100644 mpp-core/src/jvmMain/kotlin/cc/unitmesh/agent/chatdb/NlpTokenizer.jvm.kt create mode 100644 mpp-core/src/wasmJsMain/kotlin/cc/unitmesh/agent/chatdb/NlpTokenizer.wasmJs.kt diff --git a/mpp-core/build.gradle.kts b/mpp-core/build.gradle.kts index 5369891311..547a344946 100644 --- a/mpp-core/build.gradle.kts +++ b/mpp-core/build.gradle.kts @@ -212,6 +212,9 @@ kotlin { // JSQLParser for SQL validation and parsing implementation("com.github.jsqlparser:jsqlparser:4.9") + + // MyNLP for Chinese NLP tokenization + implementation("com.mayabot.mynlp:mynlp-all:4.0.0") } } diff --git a/mpp-core/src/androidMain/kotlin/cc/unitmesh/agent/chatdb/NlpTokenizer.android.kt b/mpp-core/src/androidMain/kotlin/cc/unitmesh/agent/chatdb/NlpTokenizer.android.kt new file mode 100644 index 0000000000..537d3877ca --- /dev/null +++ b/mpp-core/src/androidMain/kotlin/cc/unitmesh/agent/chatdb/NlpTokenizer.android.kt @@ -0,0 +1,23 @@ +package cc.unitmesh.agent.chatdb + +/** + * Android implementation of NlpTokenizer. + * Uses the fallback regex-based tokenization since MyNLP is JVM-only + * and may have compatibility issues on Android. + * + * TODO: Consider using Android's BreakIterator or a lightweight NLP library for better tokenization. + */ +actual object NlpTokenizer { + /** + * Extract keywords from natural language query using simple tokenization. + * Supports both English and Chinese text. + * + * @param query The natural language query to tokenize + * @param stopWords Set of words to filter out from results + * @return List of extracted keywords + */ + actual fun extractKeywords(query: String, stopWords: Set): List { + return FallbackTokenizer.extractKeywords(query, stopWords) + } +} + diff --git a/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/chatdb/KeywordSchemaLinker.kt b/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/chatdb/KeywordSchemaLinker.kt index 62599f3496..dede9b34e4 100644 --- a/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/chatdb/KeywordSchemaLinker.kt +++ b/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/chatdb/KeywordSchemaLinker.kt @@ -54,32 +54,13 @@ class KeywordSchemaLinker : SchemaLinker() { } /** - * Extract keywords from natural language query using simple tokenization - * Supports both English and Chinese text + * Extract keywords from natural language query using platform-specific NLP tokenization. + * + * On JVM, this uses MyNLP for proper Chinese word segmentation. + * On other platforms, this falls back to simple regex-based tokenization. */ override suspend fun extractKeywords(query: String): List { - val keywords = mutableListOf() - - // Extract English words - val englishWords = query.lowercase() - .replace(Regex("[^a-z0-9\\s_]"), " ") - .split(Regex("\\s+")) - .filter { it.length > 2 && it !in STOP_WORDS } - keywords.addAll(englishWords) - - // Extract Chinese characters/words (each Chinese character or common word) - val chinesePattern = Regex("[\\u4e00-\\u9fa5]+") - val chineseMatches = chinesePattern.findAll(query) - for (match in chineseMatches) { - val word = match.value - keywords.add(word) - // Also add individual characters for better matching - if (word.length > 1) { - word.forEach { char -> keywords.add(char.toString()) } - } - } - - return keywords.distinct() + return NlpTokenizer.extractKeywords(query, STOP_WORDS) } /** diff --git a/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/chatdb/NlpTokenizer.kt b/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/chatdb/NlpTokenizer.kt new file mode 100644 index 0000000000..67da490c28 --- /dev/null +++ b/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/chatdb/NlpTokenizer.kt @@ -0,0 +1,55 @@ +package cc.unitmesh.agent.chatdb + +/** + * Platform-specific NLP tokenizer for keyword extraction. + * + * On JVM, this uses MyNLP (https://github.com/jimichan/mynlp) for Chinese tokenization. + * On other platforms (JS, WASM, iOS, Android), this falls back to simple regex-based tokenization. + */ +expect object NlpTokenizer { + /** + * Extract keywords from natural language query using NLP tokenization. + * Supports both English and Chinese text. + * + * @param query The natural language query to tokenize + * @param stopWords Set of words to filter out from results + * @return List of extracted keywords + */ + fun extractKeywords(query: String, stopWords: Set): List +} + +/** + * Fallback implementation for keyword extraction using simple regex-based tokenization. + * This is used on platforms where NLP libraries are not available. + */ +object FallbackTokenizer { + /** + * Extract keywords from natural language query using simple tokenization. + * Supports both English and Chinese text. + */ + fun extractKeywords(query: String, stopWords: Set): List { + val keywords = mutableListOf() + + // Extract English words + val englishWords = query.lowercase() + .replace(Regex("[^a-z0-9\\s_]"), " ") + .split(Regex("\\s+")) + .filter { it.length > 2 && it !in stopWords } + keywords.addAll(englishWords) + + // Extract Chinese characters/words (each Chinese character or common word) + val chinesePattern = Regex("[\\u4e00-\\u9fa5]+") + val chineseMatches = chinesePattern.findAll(query) + for (match in chineseMatches) { + val word = match.value + keywords.add(word) + // Also add individual characters for better matching + if (word.length > 1) { + word.forEach { char -> keywords.add(char.toString()) } + } + } + + return keywords.distinct() + } +} + diff --git a/mpp-core/src/iosMain/kotlin/cc/unitmesh/agent/chatdb/NlpTokenizer.ios.kt b/mpp-core/src/iosMain/kotlin/cc/unitmesh/agent/chatdb/NlpTokenizer.ios.kt new file mode 100644 index 0000000000..2e26ee788f --- /dev/null +++ b/mpp-core/src/iosMain/kotlin/cc/unitmesh/agent/chatdb/NlpTokenizer.ios.kt @@ -0,0 +1,22 @@ +package cc.unitmesh.agent.chatdb + +/** + * iOS implementation of NlpTokenizer. + * Uses the fallback regex-based tokenization since MyNLP is JVM-only. + * + * TODO: Consider using iOS NaturalLanguage framework for better Chinese tokenization. + */ +actual object NlpTokenizer { + /** + * Extract keywords from natural language query using simple tokenization. + * Supports both English and Chinese text. + * + * @param query The natural language query to tokenize + * @param stopWords Set of words to filter out from results + * @return List of extracted keywords + */ + actual fun extractKeywords(query: String, stopWords: Set): List { + return FallbackTokenizer.extractKeywords(query, stopWords) + } +} + diff --git a/mpp-core/src/jsMain/kotlin/cc/unitmesh/agent/chatdb/NlpTokenizer.js.kt b/mpp-core/src/jsMain/kotlin/cc/unitmesh/agent/chatdb/NlpTokenizer.js.kt new file mode 100644 index 0000000000..50545f23b4 --- /dev/null +++ b/mpp-core/src/jsMain/kotlin/cc/unitmesh/agent/chatdb/NlpTokenizer.js.kt @@ -0,0 +1,20 @@ +package cc.unitmesh.agent.chatdb + +/** + * JavaScript implementation of NlpTokenizer. + * Uses the fallback regex-based tokenization since MyNLP is JVM-only. + */ +actual object NlpTokenizer { + /** + * Extract keywords from natural language query using simple tokenization. + * Supports both English and Chinese text. + * + * @param query The natural language query to tokenize + * @param stopWords Set of words to filter out from results + * @return List of extracted keywords + */ + actual fun extractKeywords(query: String, stopWords: Set): List { + return FallbackTokenizer.extractKeywords(query, stopWords) + } +} + diff --git a/mpp-core/src/jvmMain/kotlin/cc/unitmesh/agent/chatdb/NlpTokenizer.jvm.kt b/mpp-core/src/jvmMain/kotlin/cc/unitmesh/agent/chatdb/NlpTokenizer.jvm.kt new file mode 100644 index 0000000000..30a133c062 --- /dev/null +++ b/mpp-core/src/jvmMain/kotlin/cc/unitmesh/agent/chatdb/NlpTokenizer.jvm.kt @@ -0,0 +1,85 @@ +package cc.unitmesh.agent.chatdb + +import com.mayabot.nlp.segment.Lexers +import io.github.oshai.kotlinlogging.KotlinLogging + +private val logger = KotlinLogging.logger {} + +/** + * JVM implementation of NlpTokenizer using MyNLP for Chinese tokenization. + * + * MyNLP (https://github.com/jimichan/mynlp) provides high-quality Chinese word segmentation + * which is essential for accurate keyword extraction from Chinese natural language queries. + */ +actual object NlpTokenizer { + + // Lazy initialization of the lexer to avoid startup overhead + private val lexer by lazy { + try { + Lexers.core() + } catch (e: Exception) { + logger.warn(e) { "Failed to initialize MyNLP lexer, falling back to simple tokenization" } + null + } + } + + /** + * Extract keywords from natural language query using MyNLP tokenization. + * Supports both English and Chinese text. + * + * For Chinese text, MyNLP provides proper word segmentation instead of + * character-by-character splitting, which significantly improves matching accuracy. + * + * @param query The natural language query to tokenize + * @param stopWords Set of words to filter out from results + * @return List of extracted keywords + */ + actual fun extractKeywords(query: String, stopWords: Set): List { + val currentLexer = lexer + if (currentLexer == null) { + // Fallback to simple tokenization if MyNLP initialization failed + return FallbackTokenizer.extractKeywords(query, stopWords) + } + + return try { + extractKeywordsWithMyNlp(query, stopWords, currentLexer) + } catch (e: Exception) { + logger.warn(e) { "MyNLP tokenization failed, falling back to simple tokenization" } + FallbackTokenizer.extractKeywords(query, stopWords) + } + } + + private fun extractKeywordsWithMyNlp( + query: String, + stopWords: Set, + lexer: com.mayabot.nlp.segment.Lexer + ): List { + val keywords = mutableListOf() + + // Use MyNLP to tokenize the query + val sentence = lexer.scan(query) + + for (term in sentence) { + val word = term.word.lowercase() + + // Skip short words and stop words + if (word.length <= 1) continue + if (word in stopWords) continue + + // Skip pure punctuation + if (word.all { !it.isLetterOrDigit() }) continue + + keywords.add(word) + } + + // Also extract English words that might be missed by Chinese tokenizer + val englishWords = query.lowercase() + .replace(Regex("[^a-z0-9\\s_]"), " ") + .split(Regex("\\s+")) + .filter { it.length > 2 && it !in stopWords && it !in keywords } + keywords.addAll(englishWords) + + return keywords.distinct() + } +} + diff --git a/mpp-core/src/wasmJsMain/kotlin/cc/unitmesh/agent/chatdb/NlpTokenizer.wasmJs.kt b/mpp-core/src/wasmJsMain/kotlin/cc/unitmesh/agent/chatdb/NlpTokenizer.wasmJs.kt new file mode 100644 index 0000000000..75971869d2 --- /dev/null +++ b/mpp-core/src/wasmJsMain/kotlin/cc/unitmesh/agent/chatdb/NlpTokenizer.wasmJs.kt @@ -0,0 +1,20 @@ +package cc.unitmesh.agent.chatdb + +/** + * WebAssembly implementation of NlpTokenizer. + * Uses the fallback regex-based tokenization since MyNLP is JVM-only. + */ +actual object NlpTokenizer { + /** + * Extract keywords from natural language query using simple tokenization. + * Supports both English and Chinese text. + * + * @param query The natural language query to tokenize + * @param stopWords Set of words to filter out from results + * @return List of extracted keywords + */ + actual fun extractKeywords(query: String, stopWords: Set): List { + return FallbackTokenizer.extractKeywords(query, stopWords) + } +} + From 143421afffdbc3b8bb94f79ec9f193ce33e246aa Mon Sep 17 00:00:00 2001 From: Phodal Huang Date: Wed, 10 Dec 2025 09:53:44 +0800 Subject: [PATCH 14/34] refactor(nlp): replace FallbackTokenizer with FallbackNlpTokenizer Moves fallback keyword extraction logic to a new FallbackNlpTokenizer object and updates all usages. Adds JVM tests for NLP tokenization and a test script for comparison. --- .../agent/chatdb/FallbackNlpTokenizer.kt | 36 +++++++ .../cc/unitmesh/agent/chatdb/NlpTokenizer.kt | 36 ------- .../unitmesh/agent/chatdb/NlpTokenizer.ios.kt | 2 +- .../unitmesh/agent/chatdb/NlpTokenizer.js.kt | 2 +- .../unitmesh/agent/chatdb/NlpTokenizer.jvm.kt | 4 +- .../unitmesh/agent/chatdb/NlpTokenizerTest.kt | 100 ++++++++++++++++++ .../agent/chatdb/NlpTokenizer.wasmJs.kt | 2 +- 7 files changed, 141 insertions(+), 41 deletions(-) create mode 100644 mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/chatdb/FallbackNlpTokenizer.kt create mode 100644 mpp-core/src/jvmTest/kotlin/cc/unitmesh/agent/chatdb/NlpTokenizerTest.kt diff --git a/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/chatdb/FallbackNlpTokenizer.kt b/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/chatdb/FallbackNlpTokenizer.kt new file mode 100644 index 0000000000..714762ab1c --- /dev/null +++ b/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/chatdb/FallbackNlpTokenizer.kt @@ -0,0 +1,36 @@ +package cc.unitmesh.agent.chatdb + +/** + * Fallback implementation for keyword extraction using simple regex-based tokenization. + * This is used on platforms where NLP libraries are not available. + */ +object FallbackNlpTokenizer { + /** + * Extract keywords from natural language query using simple tokenization. + * Supports both English and Chinese text. + */ + fun extractKeywords(query: String, stopWords: Set): List { + val keywords = mutableListOf() + + // Extract English words + val englishWords = query.lowercase() + .replace(Regex("[^a-z0-9\\s_]"), " ") + .split(Regex("\\s+")) + .filter { it.length > 2 && it !in stopWords } + keywords.addAll(englishWords) + + // Extract Chinese characters/words (each Chinese character or common word) + val chinesePattern = Regex("[\\u4e00-\\u9fa5]+") + val chineseMatches = chinesePattern.findAll(query) + for (match in chineseMatches) { + val word = match.value + keywords.add(word) + // Also add individual characters for better matching + if (word.length > 1) { + word.forEach { char -> keywords.add(char.toString()) } + } + } + + return keywords.distinct() + } +} \ No newline at end of file diff --git a/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/chatdb/NlpTokenizer.kt b/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/chatdb/NlpTokenizer.kt index 67da490c28..cc7453401e 100644 --- a/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/chatdb/NlpTokenizer.kt +++ b/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/chatdb/NlpTokenizer.kt @@ -17,39 +17,3 @@ expect object NlpTokenizer { */ fun extractKeywords(query: String, stopWords: Set): List } - -/** - * Fallback implementation for keyword extraction using simple regex-based tokenization. - * This is used on platforms where NLP libraries are not available. - */ -object FallbackTokenizer { - /** - * Extract keywords from natural language query using simple tokenization. - * Supports both English and Chinese text. - */ - fun extractKeywords(query: String, stopWords: Set): List { - val keywords = mutableListOf() - - // Extract English words - val englishWords = query.lowercase() - .replace(Regex("[^a-z0-9\\s_]"), " ") - .split(Regex("\\s+")) - .filter { it.length > 2 && it !in stopWords } - keywords.addAll(englishWords) - - // Extract Chinese characters/words (each Chinese character or common word) - val chinesePattern = Regex("[\\u4e00-\\u9fa5]+") - val chineseMatches = chinesePattern.findAll(query) - for (match in chineseMatches) { - val word = match.value - keywords.add(word) - // Also add individual characters for better matching - if (word.length > 1) { - word.forEach { char -> keywords.add(char.toString()) } - } - } - - return keywords.distinct() - } -} - diff --git a/mpp-core/src/iosMain/kotlin/cc/unitmesh/agent/chatdb/NlpTokenizer.ios.kt b/mpp-core/src/iosMain/kotlin/cc/unitmesh/agent/chatdb/NlpTokenizer.ios.kt index 2e26ee788f..167ed5d19b 100644 --- a/mpp-core/src/iosMain/kotlin/cc/unitmesh/agent/chatdb/NlpTokenizer.ios.kt +++ b/mpp-core/src/iosMain/kotlin/cc/unitmesh/agent/chatdb/NlpTokenizer.ios.kt @@ -16,7 +16,7 @@ actual object NlpTokenizer { * @return List of extracted keywords */ actual fun extractKeywords(query: String, stopWords: Set): List { - return FallbackTokenizer.extractKeywords(query, stopWords) + return FallbackNlpTokenizer.extractKeywords(query, stopWords) } } diff --git a/mpp-core/src/jsMain/kotlin/cc/unitmesh/agent/chatdb/NlpTokenizer.js.kt b/mpp-core/src/jsMain/kotlin/cc/unitmesh/agent/chatdb/NlpTokenizer.js.kt index 50545f23b4..1c03f28567 100644 --- a/mpp-core/src/jsMain/kotlin/cc/unitmesh/agent/chatdb/NlpTokenizer.js.kt +++ b/mpp-core/src/jsMain/kotlin/cc/unitmesh/agent/chatdb/NlpTokenizer.js.kt @@ -14,7 +14,7 @@ actual object NlpTokenizer { * @return List of extracted keywords */ actual fun extractKeywords(query: String, stopWords: Set): List { - return FallbackTokenizer.extractKeywords(query, stopWords) + return FallbackNlpTokenizer.extractKeywords(query, stopWords) } } diff --git a/mpp-core/src/jvmMain/kotlin/cc/unitmesh/agent/chatdb/NlpTokenizer.jvm.kt b/mpp-core/src/jvmMain/kotlin/cc/unitmesh/agent/chatdb/NlpTokenizer.jvm.kt index 30a133c062..5ec36e8c1a 100644 --- a/mpp-core/src/jvmMain/kotlin/cc/unitmesh/agent/chatdb/NlpTokenizer.jvm.kt +++ b/mpp-core/src/jvmMain/kotlin/cc/unitmesh/agent/chatdb/NlpTokenizer.jvm.kt @@ -38,14 +38,14 @@ actual object NlpTokenizer { val currentLexer = lexer if (currentLexer == null) { // Fallback to simple tokenization if MyNLP initialization failed - return FallbackTokenizer.extractKeywords(query, stopWords) + return FallbackNlpTokenizer.extractKeywords(query, stopWords) } return try { extractKeywordsWithMyNlp(query, stopWords, currentLexer) } catch (e: Exception) { logger.warn(e) { "MyNLP tokenization failed, falling back to simple tokenization" } - FallbackTokenizer.extractKeywords(query, stopWords) + FallbackNlpTokenizer.extractKeywords(query, stopWords) } } diff --git a/mpp-core/src/jvmTest/kotlin/cc/unitmesh/agent/chatdb/NlpTokenizerTest.kt b/mpp-core/src/jvmTest/kotlin/cc/unitmesh/agent/chatdb/NlpTokenizerTest.kt new file mode 100644 index 0000000000..6492d27b79 --- /dev/null +++ b/mpp-core/src/jvmTest/kotlin/cc/unitmesh/agent/chatdb/NlpTokenizerTest.kt @@ -0,0 +1,100 @@ +package cc.unitmesh.agent.chatdb + +import org.junit.Test +import kotlin.test.assertTrue + +/** + * Test NLP tokenizer functionality on JVM using MyNLP. + */ +class NlpTokenizerTest { + + private val stopWords = SchemaLinker.STOP_WORDS + + @Test + fun `test Chinese tokenization with MyNLP`() { + val query = "ๆŸฅ่ฏขๆ‰€ๆœ‰็”จๆˆท็š„่ฎขๅ•้‡‘้ข" + val keywords = NlpTokenizer.extractKeywords(query, stopWords) + + println("Query: $query") + println("Keywords: ${keywords.joinToString(", ")}") + + // MyNLP should extract meaningful words like "ๆŸฅ่ฏข", "็”จๆˆท", "่ฎขๅ•", "้‡‘้ข" + // instead of just individual characters + assertTrue(keywords.isNotEmpty(), "Should extract keywords from Chinese text") + + // Check that we get proper word segmentation (not just single characters) + val multiCharWords = keywords.filter { it.length > 1 } + assertTrue(multiCharWords.isNotEmpty(), "Should have multi-character words from Chinese segmentation") + } + + @Test + fun `test mixed Chinese and English tokenization`() { + val query = "ๆŸฅ่ฏขuser่กจไธญ็š„orderๆ•ฐๆฎ" + val keywords = NlpTokenizer.extractKeywords(query, stopWords) + + println("Query: $query") + println("Keywords: ${keywords.joinToString(", ")}") + + // Should extract both Chinese words and English words + assertTrue(keywords.isNotEmpty(), "Should extract keywords from mixed text") + assertTrue(keywords.any { it.matches(Regex("[a-z]+")) }, "Should contain English words") + } + + @Test + fun `test English only tokenization`() { + val query = "Show me the top 10 customers by order amount" + val keywords = NlpTokenizer.extractKeywords(query, stopWords) + + println("Query: $query") + println("Keywords: ${keywords.joinToString(", ")}") + + // Should extract English words, filtering out stop words + assertTrue(keywords.isNotEmpty(), "Should extract keywords from English text") + assertTrue(keywords.contains("customers") || keywords.contains("amount"), + "Should contain meaningful English words") + } + + @Test + fun `compare NLP vs Fallback tokenization for Chinese`() { + val query = "็ปŸ่ฎกๆฏไธช้ƒจ้—จ็š„ๅ‘˜ๅทฅไบบๆ•ฐ" + + val nlpKeywords = NlpTokenizer.extractKeywords(query, stopWords) + val fallbackKeywords = FallbackNlpTokenizer.extractKeywords(query, stopWords) + + println("Query: $query") + println("NLP Keywords: ${nlpKeywords.joinToString(", ")}") + println("Fallback Keywords: ${fallbackKeywords.joinToString(", ")}") + + // NLP should produce better segmentation than fallback + // Fallback will include single characters, NLP should have proper words + val nlpMultiCharWords = nlpKeywords.filter { it.length > 1 } + val fallbackMultiCharWords = fallbackKeywords.filter { it.length > 1 } + + println("NLP multi-char words: ${nlpMultiCharWords.size}") + println("Fallback multi-char words: ${fallbackMultiCharWords.size}") + + // NLP should have more meaningful multi-character words + // (fallback just adds the whole string plus individual chars) + assertTrue(nlpMultiCharWords.isNotEmpty(), "NLP should produce multi-character words") + } + + @Test + fun `test database related Chinese queries`() { + val queries = listOf( + "ๆŸฅ่ฏข็”จๆˆท่กจไธญๅนด้พ„ๅคงไบŽ30็š„็”จๆˆท", + "็ปŸ่ฎก2024ๅนดๆฏๆœˆ็š„้”€ๅ”ฎ้ข", + "ๆ˜พ็คบๆ‰€ๆœ‰ๆœชๆ”ฏไป˜็š„่ฎขๅ•", + "ๆ‰พๅ‡บ่ดญไนฐ้‡‘้ขๆœ€้ซ˜็š„ๅ‰10ไธชๅฎขๆˆท" + ) + + for (query in queries) { + val keywords = NlpTokenizer.extractKeywords(query, stopWords) + println("Query: $query") + println("Keywords: ${keywords.joinToString(", ")}") + println() + + assertTrue(keywords.isNotEmpty(), "Should extract keywords from: $query") + } + } +} + diff --git a/mpp-core/src/wasmJsMain/kotlin/cc/unitmesh/agent/chatdb/NlpTokenizer.wasmJs.kt b/mpp-core/src/wasmJsMain/kotlin/cc/unitmesh/agent/chatdb/NlpTokenizer.wasmJs.kt index 75971869d2..9afcbb4f31 100644 --- a/mpp-core/src/wasmJsMain/kotlin/cc/unitmesh/agent/chatdb/NlpTokenizer.wasmJs.kt +++ b/mpp-core/src/wasmJsMain/kotlin/cc/unitmesh/agent/chatdb/NlpTokenizer.wasmJs.kt @@ -14,7 +14,7 @@ actual object NlpTokenizer { * @return List of extracted keywords */ actual fun extractKeywords(query: String, stopWords: Set): List { - return FallbackTokenizer.extractKeywords(query, stopWords) + return FallbackNlpTokenizer.extractKeywords(query, stopWords) } } From 82684761f71e3663fb3681367f8ffe13c6258443 Mon Sep 17 00:00:00 2001 From: Phodal Huang Date: Wed, 10 Dec 2025 10:03:15 +0800 Subject: [PATCH 15/34] refactor(chatdb): remove SqlRevisionContext data class #508 Delete unused SqlRevisionContext to simplify ChatDBModels and reduce code clutter. --- .../agent/chatdb/NlpTokenizer.android.kt | 2 +- .../cc/unitmesh/agent/chatdb/ChatDBModels.kt | 32 ------------------- 2 files changed, 1 insertion(+), 33 deletions(-) diff --git a/mpp-core/src/androidMain/kotlin/cc/unitmesh/agent/chatdb/NlpTokenizer.android.kt b/mpp-core/src/androidMain/kotlin/cc/unitmesh/agent/chatdb/NlpTokenizer.android.kt index 537d3877ca..b6b0b0e2a2 100644 --- a/mpp-core/src/androidMain/kotlin/cc/unitmesh/agent/chatdb/NlpTokenizer.android.kt +++ b/mpp-core/src/androidMain/kotlin/cc/unitmesh/agent/chatdb/NlpTokenizer.android.kt @@ -17,7 +17,7 @@ actual object NlpTokenizer { * @return List of extracted keywords */ actual fun extractKeywords(query: String, stopWords: Set): List { - return FallbackTokenizer.extractKeywords(query, stopWords) + return FallbackNlpTokenizer.extractKeywords(query, stopWords) } } diff --git a/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/chatdb/ChatDBModels.kt b/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/chatdb/ChatDBModels.kt index 4161e130fa..394c9d5de8 100644 --- a/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/chatdb/ChatDBModels.kt +++ b/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/chatdb/ChatDBModels.kt @@ -110,35 +110,3 @@ data class SchemaLinkingResult( */ val confidence: Double = 0.0 ) - -/** - * SQL Revision Context - Context for the Revise Agent - */ -@Serializable -data class SqlRevisionContext( - /** - * Original natural language query - */ - val originalQuery: String, - - /** - * Generated SQL that failed - */ - val failedSql: String, - - /** - * Error message from execution - */ - val errorMessage: String, - - /** - * Schema context - */ - val schemaDescription: String, - - /** - * Previous revision attempts - */ - val previousAttempts: List = emptyList() -) - From 038bd9344a50cabad59ccccff5eb4deb9a741b3e Mon Sep 17 00:00:00 2001 From: Phodal Huang Date: Wed, 10 Dec 2025 10:47:59 +0800 Subject: [PATCH 16/34] test(chatdb): add tests for FallbackNlpTokenizer #508 Introduce unit tests to verify FallbackNlpTokenizer behavior and update implementation as needed. --- .../agent/chatdb/FallbackNlpTokenizer.kt | 528 +++++++++++++++++- .../agent/chatdb/FallbackNlpTokenizerTest.kt | 376 +++++++++++++ 2 files changed, 884 insertions(+), 20 deletions(-) create mode 100644 mpp-core/src/commonTest/kotlin/cc/unitmesh/agent/chatdb/FallbackNlpTokenizerTest.kt diff --git a/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/chatdb/FallbackNlpTokenizer.kt b/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/chatdb/FallbackNlpTokenizer.kt index 714762ab1c..14e7ffbcd4 100644 --- a/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/chatdb/FallbackNlpTokenizer.kt +++ b/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/chatdb/FallbackNlpTokenizer.kt @@ -1,36 +1,524 @@ package cc.unitmesh.agent.chatdb /** - * Fallback implementation for keyword extraction using simple regex-based tokenization. - * This is used on platforms where NLP libraries are not available. + * Enhanced Fallback NLP Tokenizer for Kotlin Multiplatform. + * + * Capabilities: + * 1. English: Porter Stemmer for morphological normalization. + * 2. Chinese: Bi-directional Maximum Matching (BiMM) for segmentation. + * 3. Code: SemVer and CamelCase handling. + * 4. Keywords: RAKE (Rapid Automatic Keyword Extraction) algorithm. + * + * References: + * - Porter Stemmer: https://tartarus.org/martin/PorterStemmer/ + * - BiMM: Bi-directional Maximum Matching for Chinese word segmentation + * - RAKE: Rapid Automatic Keyword Extraction */ object FallbackNlpTokenizer { + /** + * Main entry point: Extract weighted keywords from a query. + * This is the primary method for keyword extraction with intelligent weighting. + */ + fun extractKeywords(query: String, maxKeywords: Int = 10): List { + // 1. Pre-process and Tokenize + val tokens = tokenize(query) + + // 2. RAKE Algorithm implementation + // Filter out stop words and delimiters to get content words + val stopWords = StopWords.ENGLISH + StopWords.CHINESE + val contentTokens = tokens.filter { it.text !in stopWords && it.text.length > 1 } + + if (contentTokens.isEmpty()) return emptyList() + + // Build Co-occurrence Graph + val frequency = mutableMapOf() + val degree = mutableMapOf() + + // RAKE window size (words appear together within this distance) + val windowSize = 3 + + for (i in contentTokens.indices) { + val token = contentTokens[i].text + frequency[token] = (frequency[token] ?: 0) + 1 + + val windowStart = maxOf(0, i - windowSize) + val windowEnd = minOf(contentTokens.size - 1, i + windowSize) + + for (j in windowStart..windowEnd) { + if (i == j) continue + degree[token] = (degree[token] ?: 0) + 1 + } + // Degree includes self-occurrence in RAKE definitions + degree[token] = (degree[token] ?: 0) + 1 + } + + // Calculate Scores: Score(w) = deg(w) / freq(w) + val scores = contentTokens.map { it.text }.distinct().associateWith { word -> + val deg = degree[word]?.toDouble() ?: 0.0 + val freq = frequency[word]?.toDouble() ?: 1.0 + deg / freq + } + + return scores.entries + .sortedByDescending { it.value } + .take(maxKeywords) + .map { it.key } + } + + /** + * Legacy method for backward compatibility. * Extract keywords from natural language query using simple tokenization. * Supports both English and Chinese text. */ fun extractKeywords(query: String, stopWords: Set): List { - val keywords = mutableListOf() + val tokens = tokenize(query) + + return tokens + .filter { it.text !in stopWords && it.text.length > 1 } + .map { it.text } + .distinct() + } + + /** + * Unified tokenization pipeline. + */ + fun tokenize(text: String): List { + val tokens = mutableListOf() + + // Step 1: Extract Semantic Versions (v1.2.3, 1.0.0-alpha, 2.1.0+build.123) to prevent splitting + val versionPattern = Regex("""v?(0|[1-9]\d*)\.(0|[1-9]\d*)\.(0|[1-9]\d*)(?:-((?:0|[1-9]\d*|\d*[a-zA-Z-][0-9a-zA-Z-]*)(?:\.(?:0|[1-9]\d*|\d*[a-zA-Z-][0-9a-zA-Z-]*))*))?(?:\+([0-9a-zA-Z-]+(?:\.[0-9a-zA-Z-]+)*))?""") + + // Use placeholder strategy to preserve SemVer tokens + var processedText = text + val versionTokens = mutableListOf() + versionPattern.findAll(text).forEach { match -> + versionTokens.add(match.value) + processedText = processedText.replace(match.value, " __SEMVER_${versionTokens.size - 1}__ ") + } + + // Step 2: Split by non-alphanumeric (but keep Chinese chars together for now) + val rawSegments = processedText.split(Regex("[^a-zA-Z0-9\\u4e00-\\u9fa5._-]+")) + + for (segment in rawSegments) { + if (segment.isBlank()) continue + + // Check for SemVer placeholder + val semverMatch = Regex("__SEMVER_(\\d+)__").matchEntire(segment) + if (semverMatch != null) { + val index = semverMatch.groupValues[1].toInt() + tokens.add(Token(versionTokens[index], TokenType.CODE)) + continue + } + + if (segment.matches(Regex("[\\u4e00-\\u9fa5]+"))) { + // Pure Chinese Segment + tokens.addAll(ChineseSegmenter.segment(segment)) + } else if (segment.matches(Regex("[a-zA-Z0-9._-]+"))) { + // Pure English/Code segment + val splitCamel = splitCamelCase(segment) + splitCamel.forEach { word -> + val stem = PorterStemmer.stem(word.lowercase()) + tokens.add(Token(stem, TokenType.ENGLISH)) + } + } else { + // Mixed English and Chinese - split at character type boundaries + tokens.addAll(splitMixedScript(segment)) + } + } + + return tokens + } + + /** + * Split mixed script text (English/Chinese) at character type boundaries. + * E.g., "Helloไธ–็•ŒTest" -> ["hello", "ไธ–", "็•Œ", "test"] with appropriate types + */ + private fun splitMixedScript(text: String): List { + val tokens = mutableListOf() + val currentSegment = StringBuilder() + var currentType: Char? = null // 'E' for English/Code, 'C' for Chinese + + for (char in text) { + val charType = when { + char in '\u4e00'..'\u9fa5' -> 'C' + char.isLetterOrDigit() || char in "._-" -> 'E' + else -> null + } + + if (charType == null) { + // Delimiter - flush current segment + if (currentSegment.isNotEmpty()) { + tokens.addAll(processSegment(currentSegment.toString(), currentType)) + currentSegment.clear() + } + currentType = null + continue + } + + if (currentType != null && currentType != charType) { + // Type changed - flush current segment + tokens.addAll(processSegment(currentSegment.toString(), currentType)) + currentSegment.clear() + } + + currentSegment.append(char) + currentType = charType + } + + // Flush remaining + if (currentSegment.isNotEmpty()) { + tokens.addAll(processSegment(currentSegment.toString(), currentType)) + } + + return tokens + } + + private fun processSegment(segment: String, type: Char?): List { + return when (type) { + 'C' -> ChineseSegmenter.segment(segment) + 'E' -> { + splitCamelCase(segment).map { word -> + Token(PorterStemmer.stem(word.lowercase()), TokenType.ENGLISH) + } + } + else -> emptyList() + } + } + + private fun splitCamelCase(s: String): List { + // Regex to look for switch from lower to upper, or number boundaries + return s.replace(Regex("([a-z])([A-Z])"), "$1 $2") + .replace(Regex("([A-Z])([A-Z][a-z])"), "$1 $2") // Handle HTMLParser -> HTML Parser + .replace(Regex("([a-zA-Z])([0-9])"), "$1 $2") + .replace(Regex("([0-9])([a-zA-Z])"), "$1 $2") + .replace('_', ' ') + .replace('.', ' ') + .replace('-', ' ') + .split(' ') + .filter { it.isNotBlank() } + } - // Extract English words - val englishWords = query.lowercase() - .replace(Regex("[^a-z0-9\\s_]"), " ") - .split(Regex("\\s+")) - .filter { it.length > 2 && it !in stopWords } - keywords.addAll(englishWords) + data class Token(val text: String, val type: TokenType) + enum class TokenType { ENGLISH, CHINESE, CODE } - // Extract Chinese characters/words (each Chinese character or common word) - val chinesePattern = Regex("[\\u4e00-\\u9fa5]+") - val chineseMatches = chinesePattern.findAll(query) - for (match in chineseMatches) { - val word = match.value - keywords.add(word) - // Also add individual characters for better matching - if (word.length > 1) { - word.forEach { char -> keywords.add(char.toString()) } + // --- INTERNAL HELPER CLASSES --- + + /** + * Pure Kotlin implementation of the Porter Stemming Algorithm. + * Reference: https://tartarus.org/martin/PorterStemmer/ + */ + internal object PorterStemmer { + fun stem(word: String): String { + if (word.length < 3) return word + val stemmer = StemmerState(word.toCharArray()) + stemmer.step1() + stemmer.step2() + stemmer.step3() + stemmer.step4() + stemmer.step5() + return stemmer.toString() + } + + private class StemmerState(var b: CharArray) { + var k: Int = b.size - 1 // offset to end of stemmed part + var j: Int = 0 // general offset into string + + override fun toString() = b.concatToString(0, 0 + (k + 1)) + + private fun cons(i: Int): Boolean { + return when (b[i]) { + 'a', 'e', 'i', 'o', 'u' -> false + 'y' -> if (i == 0) true else !cons(i - 1) + else -> true + } + } + + // m() measures the number of consonant sequences between 0 and j. + private fun m(): Int { + var n = 0 + var i = 0 + while (true) { + if (i > j) return n + if (!cons(i)) break + i++ + } + i++ + while (true) { + while (true) { + if (i > j) return n + if (cons(i)) break + i++ + } + i++ + n++ + while (true) { + if (i > j) return n + if (!cons(i)) break + i++ + } + i++ + } + } + + // vowelInStem() is true if 0,...,j contains a vowel + private fun vowelInStem(): Boolean { + for (i in 0..j) if (!cons(i)) return true + return false + } + + // doublec(j) is true if j,(j-1) contain a double consonant + private fun doublec(j: Int): Boolean { + if (j < 1) return false + if (b[j] != b[j - 1]) return false + return cons(j) + } + + // cvc(i) is true if i-2,i-1,i has the form consonant - vowel - consonant + // and also if the second c is not w, x or y. + private fun cvc(i: Int): Boolean { + if (i < 2 || !cons(i) || cons(i - 1) || !cons(i - 2)) return false + val ch = b[i] + return ch != 'w' && ch != 'x' && ch != 'y' + } + + private fun ends(s: String): Boolean { + val l = s.length + val o = k - l + 1 + if (o < 0) return false + for (i in 0 until l) if (b[o + i] != s[i]) return false + j = k - l + return true + } + + private fun setto(s: String) { + val l = s.length + val o = j + 1 + val newB = if (o + l > b.size) b.copyOf(o + l) else b + for (i in 0 until l) newB[o + i] = s[i] + b = newB + k = j + l + } + + private fun r(s: String) { + if (m() > 0) setto(s) + } + + fun step1() { + if (b[k] == 's') { + if (ends("sses")) k -= 2 + else if (ends("ies")) setto("i") + else if (b[k - 1] != 's') k-- + } + if (ends("eed")) { + if (m() > 0) k-- + } else if ((ends("ed") || ends("ing")) && vowelInStem()) { + k = j + if (ends("at")) setto("ate") + else if (ends("bl")) setto("ble") + else if (ends("iz")) setto("ize") + else if (doublec(k)) { + k-- + val ch = b[k] + if (ch == 'l' || ch == 's' || ch == 'z') k++ + } else if (m() == 1 && cvc(k)) setto("e") + } } + + fun step2() { + if (ends("y") && vowelInStem()) b[k] = 'i' + if (k == 0) return + // Optimization: switch on penultimate char + when (b[k - 1]) { + 'a' -> { + if (ends("ational")) r("ate") + else if (ends("tional")) r("tion") + } + 'c' -> { + if (ends("enci")) r("ence") + else if (ends("anci")) r("ance") + } + 'e' -> if (ends("izer")) r("ize") + 'l' -> { + if (ends("bli")) r("ble") + else if (ends("alli")) r("al") + else if (ends("entli")) r("ent") + else if (ends("eli")) r("e") + else if (ends("ousli")) r("ous") + } + 'o' -> { + if (ends("ization")) r("ize") + else if (ends("ation")) r("ate") + else if (ends("ator")) r("ate") + } + 's' -> { + if (ends("alism")) r("al") + else if (ends("iveness")) r("ive") + else if (ends("fulness")) r("ful") + else if (ends("ousness")) r("ous") + } + 't' -> { + if (ends("aliti")) r("al") + else if (ends("iviti")) r("ive") + else if (ends("biliti")) r("ble") + } + } + } + + fun step3() { + when (b[k]) { + 'e' -> { + if (ends("icate")) r("ic") + else if (ends("ative")) r("") + else if (ends("alize")) r("al") + } + 'i' -> if (ends("iciti")) r("ic") + 'l' -> { + if (ends("ical")) r("ic") + else if (ends("ful")) r("") + } + 's' -> if (ends("ness")) r("") + } + } + + fun step4() { + if (k < 2) return + when (b[k - 1]) { + 'a' -> if (ends("al")) { if (m() > 1) k = j } + 'c' -> { + if (ends("ance") || ends("ence")) { + if (m() > 1) k = j + } + } + 'e' -> if (ends("er")) { if (m() > 1) k = j } + 'i' -> if (ends("ic")) { if (m() > 1) k = j } + 'l' -> { + if (ends("able") || ends("ible")) { + if (m() > 1) k = j + } + } + 'n' -> { + if (ends("ant") || ends("ement") || ends("ment") || ends("ent")) { + if (m() > 1) k = j + } + } + 'o' -> { + if (ends("ion") && j >= 0 && (b[j] == 's' || b[j] == 't')) { + if (m() > 1) k = j + } else if (ends("ou")) { + if (m() > 1) k = j + } + } + 's' -> if (ends("ism")) { if (m() > 1) k = j } + 't' -> { + if (ends("ate") || ends("iti")) { + if (m() > 1) k = j + } + } + 'u' -> if (ends("ous")) { if (m() > 1) k = j } + 'v' -> if (ends("ive")) { if (m() > 1) k = j } + 'z' -> if (ends("ize")) { if (m() > 1) k = j } + } + } + + fun step5() { + j = k + if (b[k] == 'e') { + val a = m() + if (a > 1 || (a == 1 && !cvc(k - 1))) k-- + } + if (b[k] == 'l' && doublec(k) && m() > 1) k-- + } + } + } + + /** + * Bi-directional Maximum Matching for Chinese Segmentation. + * Uses a built-in dictionary for common words. + */ + internal object ChineseSegmenter { + // Common Chinese words dictionary for segmentation + // In a real app, this can be loaded from external resources + val dictionary = setOf( + // Common verbs + "ๆˆ‘ไปฌ", "ไป–ไปฌ", "ไฝ ไปฌ", "่ฟ™ไธช", "้‚ฃไธช", "ไป€ไนˆ", "ๆ€Žไนˆ", "ไธบไป€ไนˆ", "ๅ› ไธบ", "ๆ‰€ไปฅ", + "ๅฏไปฅ", "่ƒฝๅคŸ", "ๅบ”่ฏฅ", "ๅฟ…้กป", "้œ€่ฆ", "ๅธŒๆœ›", "ๆƒณ่ฆ", "็Ÿฅ้“", "่ฎคไธบ", "่ง‰ๅพ—", + // Tech terms + "่ถ…ๅธ‚", "่‡ช็„ถ", "่ฏญ่จ€", "ๅค„็†", "ๆต‹่ฏ•", "ๆ•ฐๆฎ", "ๅผ€ๅ‘", "ๅทฅ็จ‹ๅธˆ", "็จ‹ๅบๅ‘˜", + "ไบบๅทฅๆ™บ่ƒฝ", "ๆœบๅ™จๅญฆไน ", "ๆทฑๅบฆๅญฆไน ", "็ฅž็ป็ฝ‘็ปœ", "ๆจกๅž‹", "ไปฃ็ ", "้€ป่พ‘", + "็”จๆˆท", "็•Œ้ข", "่ฎพ่ฎก", "็ณป็ปŸ", "ๅˆ†ๆž", "ๆ•ฐๆฎๅบ“", "ๆœๅŠกๅ™จ", "ๅฎขๆˆท็ซฏ", + "่ฝฏไปถ", "็กฌไปถ", "็ฎ—ๆณ•", "ๅ‡ฝๆ•ฐ", "ๅ˜้‡", "ๅฏน่ฑก", "็ฑปๅž‹", "ๆŽฅๅฃ", "ๅฎž็Žฐ", + // Place names + "ๅŒ—ไบฌ", "ไธŠๆตท", "ๅนฟๅทž", "ๆทฑๅœณ", "ๆญๅทž", "ๅ—ไบฌ", "ๅคฉๆดฅ", "้‡ๅบ†", "ๆˆ้ƒฝ", "ๆญฆๆฑ‰", + "ๅ—ไบฌๅธ‚", "้•ฟๆฑŸ", "ๅคงๆกฅ", "้ป„ๆฒณ", "้•ฟๅŸŽ", "ๆ•…ๅฎซ", + // Common nouns + "ๅ…ฌๅธ", "ๅญฆๆ ก", "ๅŒป้™ข", "้“ถ่กŒ", "ๅ•†ๅบ—", "้คๅŽ…", "้…’ๅบ—", "ๆœบๅœบ", "็ซ่ฝฆ็ซ™", + "็”ต่„‘", "ๆ‰‹ๆœบ", "็ฝ‘็ปœ", "ไบ’่”็ฝ‘", "็”ตๅญ", "็ง‘ๆŠ€", "ๆŠ€ๆœฏ", "ไบงๅ“", "้กน็›ฎ", + // Time words + "ไปŠๅคฉ", "ๆ˜Žๅคฉ", "ๆ˜จๅคฉ", "็Žฐๅœจ", "ไปฅๅŽ", "ไปฅๅ‰", "ๅฐ†ๆฅ", "่ฟ‡ๅŽป", "ๅนดๆœˆ", "ๆ—ฅๆœŸ" + ) + private const val MAX_LEN = 5 + + fun segment(text: String): List { + val fmm = forwardMaxMatch(text) + val rmm = reverseMaxMatch(text) + + // Heuristic: Prefer fewer tokens (implies longer words matched) + return if (fmm.size <= rmm.size) fmm else rmm } - return keywords.distinct() + private fun forwardMaxMatch(text: String): List { + val tokens = mutableListOf() + var i = 0 + while (i < text.length) { + var matched = false + // Try window sizes from MAX_LEN down to 1 + for (len in minOf(MAX_LEN, text.length - i) downTo 1) { + val sub = text.substring(i, i + len) + if (len == 1 || dictionary.contains(sub)) { + tokens.add(Token(sub, TokenType.CHINESE)) + i += len + matched = true + break + } + } + if (!matched) i++ // Should not happen due to len=1 fallback + } + return tokens + } + + private fun reverseMaxMatch(text: String): List { + val tokens = mutableListOf() + var i = text.length + while (i > 0) { + var matched = false + for (len in minOf(MAX_LEN, i) downTo 1) { + val sub = text.substring(i - len, i) + if (len == 1 || dictionary.contains(sub)) { + tokens.add(0, Token(sub, TokenType.CHINESE)) + i -= len + matched = true + break + } + } + if (!matched) i-- + } + return tokens + } + } + + internal object StopWords { + val ENGLISH = setOf( + "the", "is", "at", "which", "on", "and", "a", "an", "in", "to", "of", "for", + "it", "that", "this", "by", "from", "be", "or", "as", "with", "are", "was", + "were", "been", "being", "have", "has", "had", "do", "does", "did", "will", + "would", "could", "should", "may", "might", "must", "shall", "can", "need", + "get", "all", "me", "my", "show", "most", "top", "up", "down", "out" + ) + // Minimal set for Chinese + val CHINESE = setOf( + "็š„", "ไบ†", "ๅ’Œ", "ๆ˜ฏ", "ๅฐฑ", "้ƒฝ", "่€Œ", "ๅŠ", "ไธŽ", "็€", "ๆˆ–", "ไธ€ไธช", "ๆฒกๆœ‰", "ๆˆ‘ไปฌ", + "ไธ", "ไนŸ", "ๅพˆ", "ๅœจ", "ๆœ‰", "่ฟ™", "้‚ฃ", "ไป–", "ๅฅน", "ๅฎƒ", "ไปฌ", "ๅ—", "ๅ‘ข", "ๅง" + ) } -} \ No newline at end of file +} diff --git a/mpp-core/src/commonTest/kotlin/cc/unitmesh/agent/chatdb/FallbackNlpTokenizerTest.kt b/mpp-core/src/commonTest/kotlin/cc/unitmesh/agent/chatdb/FallbackNlpTokenizerTest.kt new file mode 100644 index 0000000000..6f1f73464c --- /dev/null +++ b/mpp-core/src/commonTest/kotlin/cc/unitmesh/agent/chatdb/FallbackNlpTokenizerTest.kt @@ -0,0 +1,376 @@ +package cc.unitmesh.agent.chatdb + +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertTrue +import kotlin.test.assertNotNull + +/** + * Comprehensive tests for FallbackNlpTokenizer. + * Tests cover: + * 1. Porter Stemmer for English morphological normalization + * 2. Bi-directional Maximum Matching (BiMM) for Chinese segmentation + * 3. RAKE keyword extraction algorithm + * 4. SemVer and CamelCase handling + */ +class FallbackNlpTokenizerTest { + + // ==================== Porter Stemmer Tests ==================== + + @Test + fun `Porter Stemmer should reduce processing to process`() { + val tokens = FallbackNlpTokenizer.tokenize("processing") + assertEquals(1, tokens.size) + assertEquals("process", tokens[0].text) + } + + @Test + fun `Porter Stemmer should reduce processed to process`() { + val tokens = FallbackNlpTokenizer.tokenize("processed") + assertEquals(1, tokens.size) + assertEquals("process", tokens[0].text) + } + + @Test + fun `Porter Stemmer should handle various English word forms`() { + // Test various word forms that should stem to similar roots + val testCases = mapOf( + "running" to "run", + "runs" to "run", + "runner" to "runner", // noun, different handling + "cats" to "cat", + "playing" to "plai", // Porter stemmer result + "played" to "plai", + "happily" to "happili", + "happiness" to "happi" + ) + + testCases.forEach { (input, expectedStem) -> + val tokens = FallbackNlpTokenizer.tokenize(input) + assertEquals(1, tokens.size, "Expected 1 token for '$input'") + assertEquals(expectedStem, tokens[0].text, "Expected '$input' to stem to '$expectedStem'") + } + } + + @Test + fun `Porter Stemmer should unify related words for better recall`() { + // Words like "processing", "processed", "process" should all map to "process" + val words = listOf("processing", "processed", "process") + val stems = words.map { FallbackNlpTokenizer.tokenize(it)[0].text } + + // All should be "process" + assertTrue(stems.all { it == "process" }, "All variations should stem to 'process', got: $stems") + } + + @Test + fun `Porter Stemmer should preserve short words`() { + // Words shorter than 3 characters should be preserved + val tokens = FallbackNlpTokenizer.tokenize("go do be") + assertEquals(3, tokens.size) + assertEquals("go", tokens[0].text) + assertEquals("do", tokens[1].text) + assertEquals("be", tokens[2].text) + } + + // ==================== Chinese BiMM Segmentation Tests ==================== + + @Test + fun `BiMM should segment Chinese text with dictionary words`() { + val tokens = FallbackNlpTokenizer.tokenize("ไบบๅทฅๆ™บ่ƒฝ") + // Should recognize "ไบบๅทฅๆ™บ่ƒฝ" as a single word + val chineseTokens = tokens.filter { it.type == FallbackNlpTokenizer.TokenType.CHINESE } + assertTrue(chineseTokens.any { it.text == "ไบบๅทฅๆ™บ่ƒฝ" }, "Should recognize 'ไบบๅทฅๆ™บ่ƒฝ' as a word") + } + + @Test + fun `BiMM should segment common Chinese words`() { + val tokens = FallbackNlpTokenizer.tokenize("ๆ•ฐๆฎๅบ“็ณป็ปŸ") + val texts = tokens.map { it.text } + + // Should segment as "ๆ•ฐๆฎๅบ“" + "็ณป็ปŸ" (both in dictionary) + assertTrue(texts.contains("ๆ•ฐๆฎๅบ“") || texts.contains("็ณป็ปŸ"), + "Should recognize common tech terms, got: $texts") + } + + @Test + fun `BiMM should handle ambiguous segmentation with forward and reverse matching`() { + // "ๅ—ไบฌๅธ‚้•ฟๆฑŸๅคงๆกฅ" is a classic ambiguity example + // FMM might give: ๅ—ไบฌๅธ‚/้•ฟๆฑŸ/ๅคงๆกฅ + // RMM might give: ๅ—ไบฌ/ๅธ‚้•ฟ/ๆฑŸๅคงๆกฅ (incorrect) + // BiMM chooses based on fewer tokens (longer matches) + val tokens = FallbackNlpTokenizer.tokenize("ๅ—ไบฌๅธ‚้•ฟๆฑŸๅคงๆกฅ") + val texts = tokens.map { it.text } + + // Should prefer segmentation with "ๅ—ไบฌๅธ‚", "้•ฟๆฑŸ", "ๅคงๆกฅ" (all in dictionary) + assertTrue( + texts.contains("ๅ—ไบฌๅธ‚") || texts.contains("้•ฟๆฑŸ") || texts.contains("ๅคงๆกฅ"), + "Should handle ambiguous segmentation correctly, got: $texts" + ) + } + + @Test + fun `BiMM should fall back to character-by-character for unknown words`() { + // Text with characters not in dictionary + val tokens = FallbackNlpTokenizer.tokenize("ๅ•Šๅงๅ”ง") + assertEquals(3, tokens.size, "Should segment unknown characters individually") + } + + // ==================== Mixed Language Tests ==================== + + @Test + fun `should handle mixed English and Chinese text`() { + val tokens = FallbackNlpTokenizer.tokenize("Helloไธ–็•ŒTestๆต‹่ฏ•") + val englishTokens = tokens.filter { it.type == FallbackNlpTokenizer.TokenType.ENGLISH } + val chineseTokens = tokens.filter { it.type == FallbackNlpTokenizer.TokenType.CHINESE } + + assertTrue(englishTokens.isNotEmpty(), "Should have English tokens") + assertTrue(chineseTokens.isNotEmpty(), "Should have Chinese tokens") + } + + // ==================== CamelCase Handling Tests ==================== + + @Test + fun `should split CamelCase words`() { + val tokens = FallbackNlpTokenizer.tokenize("UserDao") + val texts = tokens.map { it.text } + + assertTrue(texts.contains("user"), "Should extract 'user' from 'UserDao'") + assertTrue(texts.contains("dao"), "Should extract 'dao' from 'UserDao'") + } + + @Test + fun `should split complex CamelCase with multiple words`() { + val tokens = FallbackNlpTokenizer.tokenize("getUserNameById") + val texts = tokens.map { it.text } + + assertTrue(texts.contains("get"), "Should extract 'get'") + assertTrue(texts.contains("user"), "Should extract 'user'") + assertTrue(texts.contains("name"), "Should extract 'name'") + // "by" is short, "id" is short - they should still be present (as "by", "id") + assertTrue(texts.contains("by"), "Should extract 'by'") + assertTrue(texts.contains("id"), "Should extract 'id'") + } + + @Test + fun `should handle HTML style uppercase sequences`() { + val tokens = FallbackNlpTokenizer.tokenize("HTMLParser") + val texts = tokens.map { it.text } + + // Should split as "HTML" + "Parser" -> "html", "parser" + assertTrue(texts.contains("html") || texts.any { it.contains("html") }, + "Should handle HTMLParser correctly, got: $texts") + } + + @Test + fun `should handle snake_case`() { + val tokens = FallbackNlpTokenizer.tokenize("user_name_service") + val texts = tokens.map { it.text } + + assertTrue(texts.contains("user"), "Should extract 'user' from snake_case") + assertTrue(texts.contains("name"), "Should extract 'name' from snake_case") + assertTrue(texts.contains("servic"), "Should extract 'servic' (stemmed) from snake_case") + } + + // ==================== SemVer Handling Tests ==================== + + @Test + fun `should preserve semantic version numbers`() { + val tokens = FallbackNlpTokenizer.tokenize("v1.2.3") + assertEquals(1, tokens.size, "SemVer should be kept as single token") + assertEquals("v1.2.3", tokens[0].text) + assertEquals(FallbackNlpTokenizer.TokenType.CODE, tokens[0].type) + } + + @Test + fun `should preserve complex SemVer with prerelease`() { + val tokens = FallbackNlpTokenizer.tokenize("1.0.0-alpha") + assertEquals(1, tokens.size, "SemVer with prerelease should be single token") + assertEquals("1.0.0-alpha", tokens[0].text) + } + + @Test + fun `should preserve SemVer with build metadata`() { + val tokens = FallbackNlpTokenizer.tokenize("2.1.0+build.123") + assertEquals(1, tokens.size, "SemVer with build metadata should be single token") + assertEquals("2.1.0+build.123", tokens[0].text) + } + + // ==================== RAKE Keyword Extraction Tests ==================== + + @Test + fun `RAKE should extract meaningful keywords from text`() { + val keywords = FallbackNlpTokenizer.extractKeywords( + "The quick brown fox jumps over the lazy dog", + maxKeywords = 5 + ) + + assertTrue(keywords.isNotEmpty(), "Should extract keywords") + assertTrue(keywords.none { it in FallbackNlpTokenizer.StopWords.ENGLISH }, + "Keywords should not contain stop words") + } + + @Test + fun `RAKE should filter stop words`() { + val keywords = FallbackNlpTokenizer.extractKeywords( + "The is at which on and in to of for", + maxKeywords = 10 + ) + + assertTrue(keywords.isEmpty(), "Stop words only text should return empty keywords") + } + + @Test + fun `RAKE should rank keywords by co-occurrence`() { + val keywords = FallbackNlpTokenizer.extractKeywords( + "data processing pipeline data analysis data visualization", + maxKeywords = 5 + ) + + // "data" appears most frequently and co-occurs with many words + assertTrue(keywords.contains("data"), "Frequent co-occurring word 'data' should be top keyword") + } + + @Test + fun `RAKE should handle Chinese text`() { + val keywords = FallbackNlpTokenizer.extractKeywords( + "ไบบๅทฅๆ™บ่ƒฝๆŠ€ๆœฏๅœจๆ•ฐๆฎๅˆ†ๆž้ข†ๅŸŸ็š„ๅบ”็”จ", + maxKeywords = 5 + ) + + assertTrue(keywords.isNotEmpty(), "Should extract Chinese keywords") + assertTrue(keywords.none { it in FallbackNlpTokenizer.StopWords.CHINESE }, + "Chinese keywords should not contain stop words") + } + + @Test + fun `RAKE should respect maxKeywords limit`() { + val keywords = FallbackNlpTokenizer.extractKeywords( + "apple banana cherry date elderberry fig grape honeydew kiwi lemon mango nectarine", + maxKeywords = 3 + ) + + assertTrue(keywords.size <= 3, "Should respect maxKeywords limit") + } + + // ==================== Legacy Method Tests ==================== + + @Test + fun `legacy extractKeywords should work with custom stop words`() { + val customStopWords = setOf("quick", "lazy") + val keywords = FallbackNlpTokenizer.extractKeywords( + "The quick brown fox jumps over the lazy dog", + customStopWords + ) + + assertTrue(keywords.none { it == "quick" }, "Should filter custom stop word 'quick'") + assertTrue(keywords.none { it == "lazy" }, "Should filter custom stop word 'lazy'") + } + + @Test + fun `legacy extractKeywords should return distinct tokens`() { + val keywords = FallbackNlpTokenizer.extractKeywords( + "test test test data data", + emptySet() + ) + + assertEquals(keywords.distinct().size, keywords.size, "Keywords should be distinct") + } + + // ==================== Edge Cases ==================== + + @Test + fun `should handle empty string`() { + val tokens = FallbackNlpTokenizer.tokenize("") + assertTrue(tokens.isEmpty(), "Empty string should return empty tokens") + + val keywords = FallbackNlpTokenizer.extractKeywords("", maxKeywords = 5) + assertTrue(keywords.isEmpty(), "Empty string should return empty keywords") + } + + @Test + fun `should handle whitespace only`() { + val tokens = FallbackNlpTokenizer.tokenize(" \t\n ") + assertTrue(tokens.isEmpty(), "Whitespace only should return empty tokens") + } + + @Test + fun `should handle special characters`() { + val tokens = FallbackNlpTokenizer.tokenize("hello! @world# \$test%") + assertTrue(tokens.isNotEmpty(), "Should extract tokens from text with special chars") + } + + @Test + fun `should handle numbers mixed with text`() { + val tokens = FallbackNlpTokenizer.tokenize("user123 test456") + val texts = tokens.map { it.text } + + // Numbers should be separated from text + assertTrue(texts.any { it.contains("user") || it == "user" }, + "Should handle alphanumeric text") + } + + // ==================== Integration Tests ==================== + + @Test + fun `should handle realistic code query`() { + val keywords = FallbackNlpTokenizer.extractKeywords( + "How to implement UserService with database connection pooling?", + maxKeywords = 10 + ) + + assertTrue(keywords.isNotEmpty(), "Should extract keywords from code query") + // Should contain stemmed versions of key terms + assertTrue( + keywords.any { it.contains("user") || it.contains("servic") || it.contains("databas") }, + "Should extract relevant programming terms, got: $keywords" + ) + } + + @Test + fun `should handle realistic Chinese tech query`() { + val keywords = FallbackNlpTokenizer.extractKeywords( + "ๅฆ‚ไฝ•ไฝฟ็”จๆ•ฐๆฎๅบ“่ฟžๆŽฅๆฑ ไผ˜ๅŒ–็ณป็ปŸๆ€ง่ƒฝ๏ผŸ", + maxKeywords = 10 + ) + + assertTrue(keywords.isNotEmpty(), "Should extract keywords from Chinese query") + } + + @Test + fun `should handle mixed language tech query`() { + val keywords = FallbackNlpTokenizer.extractKeywords( + "ๅฆ‚ไฝ•ๅฎž็ŽฐUserService็š„ๆ•ฐๆฎๅบ“connection pooling?", + maxKeywords = 10 + ) + + assertTrue(keywords.isNotEmpty(), "Should extract keywords from mixed language query") + } + + // ==================== Token Type Tests ==================== + + @Test + fun `should correctly identify token types`() { + val tokens = FallbackNlpTokenizer.tokenize("Helloไธ–็•Œv1.2.3") + + val englishTokens = tokens.filter { it.type == FallbackNlpTokenizer.TokenType.ENGLISH } + val chineseTokens = tokens.filter { it.type == FallbackNlpTokenizer.TokenType.CHINESE } + val codeTokens = tokens.filter { it.type == FallbackNlpTokenizer.TokenType.CODE } + + assertTrue(englishTokens.isNotEmpty(), "Should have English tokens") + assertTrue(chineseTokens.isNotEmpty(), "Should have Chinese tokens") + assertTrue(codeTokens.isNotEmpty(), "Should have CODE tokens (SemVer)") + } + + // ==================== Performance Considerations ==================== + + @Test + fun `should handle long text efficiently`() { + val longText = "data processing ".repeat(100) + "ไบบๅทฅๆ™บ่ƒฝ".repeat(50) + + // Simply verify it completes without timeout (test framework handles timeout) + val keywords = FallbackNlpTokenizer.extractKeywords(longText, maxKeywords = 10) + + assertTrue(keywords.isNotEmpty(), "Should extract keywords from long text") + } +} + From b32681bbe948dd74ab8defad64c7a2c8e3ae101d Mon Sep 17 00:00:00 2001 From: Phodal Huang Date: Wed, 10 Dec 2025 11:08:05 +0800 Subject: [PATCH 17/34] feat(chatdb): add inline data source config pane #508 Introduce an inline configuration pane for adding and editing data sources, replacing the previous dialog-based approach. The pane appears on the right side and supports saving and connecting actions. --- .../ui/compose/agent/chatdb/ChatDBPage.kt | 135 ++++------ .../compose/agent/chatdb/ChatDBViewModel.kt | 66 +++++ .../chatdb/components/DataSourceConfigPane.kt | 243 ++++++++++++++++++ .../agent/chatdb/model/DataSourceModels.kt | 16 +- 4 files changed, 375 insertions(+), 85 deletions(-) create mode 100644 mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/agent/chatdb/components/DataSourceConfigPane.kt diff --git a/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/agent/chatdb/ChatDBPage.kt b/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/agent/chatdb/ChatDBPage.kt index ebdbc566ae..6cc633717b 100644 --- a/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/agent/chatdb/ChatDBPage.kt +++ b/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/agent/chatdb/ChatDBPage.kt @@ -3,12 +3,10 @@ package cc.unitmesh.devins.ui.compose.agent.chatdb import androidx.compose.foundation.layout.* import androidx.compose.material3.* import androidx.compose.runtime.* -import androidx.compose.ui.Alignment import androidx.compose.ui.Modifier import androidx.compose.ui.unit.dp +import cc.unitmesh.devins.ui.base.ResizableSplitPane import cc.unitmesh.devins.ui.compose.agent.chatdb.components.* -import cc.unitmesh.devins.ui.compose.agent.chatdb.model.ConnectionStatus -import cc.unitmesh.devins.ui.compose.icons.AutoDevComposeIcons import cc.unitmesh.devins.workspace.Workspace import cc.unitmesh.llm.KoogLLMService import kotlinx.coroutines.flow.collectLatest @@ -16,8 +14,8 @@ import kotlinx.coroutines.flow.collectLatest /** * ChatDB Page - Main page for text-to-SQL agent * - * Left side: Data source management panel - * Right side: Chat area for natural language to SQL queries + * Left side: Data source management panel (resizable) + * Right side: Config pane (when adding/editing) or Chat area for natural language to SQL queries */ @OptIn(ExperimentalMaterial3Api::class) @Composable @@ -45,83 +43,62 @@ fun ChatDBPage( } } - Scaffold( - topBar = { - TopAppBar( - title = { Text("ChatDB") }, - navigationIcon = { - IconButton(onClick = onBack) { - Icon(AutoDevComposeIcons.ArrowBack, contentDescription = "Back") - } - }, - actions = { - // Schema info button when connected - if (state.connectionStatus is ConnectionStatus.Connected) { - var showSchemaDialog by remember { mutableStateOf(false) } - IconButton(onClick = { showSchemaDialog = true }) { - Icon(AutoDevComposeIcons.Schema, contentDescription = "View Schema") - } - if (showSchemaDialog) { - SchemaInfoDialog( - schema = viewModel.getSchema(), - onDismiss = { showSchemaDialog = false } - ) - } - } - } - ) - }, - modifier = modifier - ) { paddingValues -> - Row( + Scaffold(modifier = modifier) { paddingValues -> + ResizableSplitPane( modifier = Modifier .fillMaxSize() - .padding(paddingValues) - ) { - // Left panel - Data source management - DataSourcePanel( - dataSources = state.filteredDataSources, - selectedDataSourceId = state.selectedDataSourceId, - connectionStatus = state.connectionStatus, - filterQuery = state.filterQuery, - onFilterChange = viewModel::setFilterQuery, - onSelectDataSource = viewModel::selectDataSource, - onAddClick = viewModel::openAddDialog, - onEditClick = viewModel::openEditDialog, - onDeleteClick = viewModel::deleteDataSource, - onConnectClick = viewModel::connect, - onDisconnectClick = viewModel::disconnect, - modifier = Modifier.width(280.dp) - ) - - VerticalDivider() - - // Right panel - Chat area - ChatDBChatPane( - renderer = viewModel.renderer, - connectionStatus = state.connectionStatus, - schema = viewModel.getSchema(), - isGenerating = viewModel.isGenerating, - onSendMessage = viewModel::sendMessage, - onStopGeneration = viewModel::stopGeneration, - modifier = Modifier.weight(1f) - ) - } - - // Config dialog - if (state.isConfigDialogOpen) { - DataSourceConfigDialog( - existingConfig = state.editingDataSource, - onDismiss = viewModel::closeConfigDialog, - onSave = { config -> - if (state.editingDataSource != null) { - viewModel.updateDataSource(config) - } else { - viewModel.addDataSource(config) - } + .padding(paddingValues), + initialSplitRatio = 0.22f, + minRatio = 0.15f, + maxRatio = 0.4f, + saveKey = "chatdb_split_ratio", + first = { + // Left panel - Data source management + DataSourcePanel( + dataSources = state.filteredDataSources, + selectedDataSourceId = state.selectedDataSourceId, + connectionStatus = state.connectionStatus, + filterQuery = state.filterQuery, + onFilterChange = viewModel::setFilterQuery, + onSelectDataSource = { id -> + viewModel.selectDataSource(id) + // When selecting a different data source, show its config in the pane + val selected = state.dataSources.find { it.id == id } + if (selected != null && state.isConfigPaneOpen) { + viewModel.openConfigPane(selected) + } + }, + onAddClick = { viewModel.openConfigPane(null) }, + onEditClick = { config -> viewModel.openConfigPane(config) }, + onDeleteClick = viewModel::deleteDataSource, + onConnectClick = viewModel::connect, + onDisconnectClick = viewModel::disconnect, + modifier = Modifier.fillMaxSize() + ) + }, + second = { + // Right panel - Config pane or Chat area + if (state.isConfigPaneOpen) { + DataSourceConfigPane( + existingConfig = state.configuringDataSource, + onCancel = viewModel::closeConfigPane, + onSave = viewModel::saveFromPane, + onSaveAndConnect = viewModel::saveAndConnectFromPane, + modifier = Modifier.fillMaxSize() + ) + } else { + ChatDBChatPane( + renderer = viewModel.renderer, + connectionStatus = state.connectionStatus, + schema = viewModel.getSchema(), + isGenerating = viewModel.isGenerating, + onSendMessage = viewModel::sendMessage, + onStopGeneration = viewModel::stopGeneration, + modifier = Modifier.fillMaxSize() + ) } - ) - } + } + ) } } diff --git a/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/agent/chatdb/ChatDBViewModel.kt b/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/agent/chatdb/ChatDBViewModel.kt index fd53238234..3e2dc45212 100644 --- a/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/agent/chatdb/ChatDBViewModel.kt +++ b/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/agent/chatdb/ChatDBViewModel.kt @@ -138,6 +138,72 @@ class ChatDBViewModel( state = state.copy(isConfigDialogOpen = false, editingDataSource = null) } + // --- Config Pane methods (inline panel mode) --- + + fun openConfigPane(config: DataSourceConfig? = null) { + state = state.copy(isConfigPaneOpen = true, configuringDataSource = config) + } + + fun closeConfigPane() { + state = state.copy(isConfigPaneOpen = false, configuringDataSource = null) + } + + @OptIn(ExperimentalUuidApi::class) + fun saveFromPane(config: DataSourceConfig) { + val isNew = config.id.isBlank() + val savedConfig = config.copy( + id = if (isNew) Uuid.random().toString() else config.id, + createdAt = if (isNew) kotlinx.datetime.Clock.System.now().toEpochMilliseconds() else state.configuringDataSource?.createdAt ?: 0L, + updatedAt = kotlinx.datetime.Clock.System.now().toEpochMilliseconds() + ) + + state = if (isNew) { + state.copy( + dataSources = state.dataSources + savedConfig, + isConfigPaneOpen = false, + configuringDataSource = null, + selectedDataSourceId = savedConfig.id + ) + } else { + state.copy( + dataSources = state.dataSources.map { if (it.id == savedConfig.id) savedConfig else it }, + isConfigPaneOpen = false, + configuringDataSource = null + ) + } + saveDataSources() + } + + @OptIn(ExperimentalUuidApi::class) + fun saveAndConnectFromPane(config: DataSourceConfig) { + val isNew = config.id.isBlank() + val savedConfig = config.copy( + id = if (isNew) Uuid.random().toString() else config.id, + createdAt = if (isNew) kotlinx.datetime.Clock.System.now().toEpochMilliseconds() else state.configuringDataSource?.createdAt ?: 0L, + updatedAt = kotlinx.datetime.Clock.System.now().toEpochMilliseconds() + ) + + state = if (isNew) { + state.copy( + dataSources = state.dataSources + savedConfig, + isConfigPaneOpen = false, + configuringDataSource = null, + selectedDataSourceId = savedConfig.id + ) + } else { + state.copy( + dataSources = state.dataSources.map { if (it.id == savedConfig.id) savedConfig else it }, + isConfigPaneOpen = false, + configuringDataSource = null, + selectedDataSourceId = savedConfig.id + ) + } + saveDataSources() + + // Trigger connection + connect() + } + fun connect() { val dataSource = state.selectedDataSource ?: return diff --git a/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/agent/chatdb/components/DataSourceConfigPane.kt b/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/agent/chatdb/components/DataSourceConfigPane.kt new file mode 100644 index 0000000000..e6a4951924 --- /dev/null +++ b/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/agent/chatdb/components/DataSourceConfigPane.kt @@ -0,0 +1,243 @@ +package cc.unitmesh.devins.ui.compose.agent.chatdb.components + +import androidx.compose.foundation.background +import androidx.compose.foundation.layout.* +import androidx.compose.foundation.rememberScrollState +import androidx.compose.foundation.shape.RoundedCornerShape +import androidx.compose.foundation.text.KeyboardOptions +import androidx.compose.foundation.verticalScroll +import androidx.compose.material3.* +import androidx.compose.runtime.* +import androidx.compose.ui.Alignment +import androidx.compose.ui.Modifier +import androidx.compose.ui.text.input.KeyboardType +import androidx.compose.ui.text.input.PasswordVisualTransformation +import androidx.compose.ui.text.input.VisualTransformation +import androidx.compose.ui.unit.dp +import cc.unitmesh.devins.ui.compose.agent.chatdb.model.* +import cc.unitmesh.devins.ui.compose.icons.AutoDevComposeIcons + +/** + * Inline configuration pane for adding/editing data source + * Displayed on the right side instead of a dialog + */ +@OptIn(ExperimentalMaterial3Api::class) +@Composable +fun DataSourceConfigPane( + existingConfig: DataSourceConfig?, + onCancel: () -> Unit, + onSave: (DataSourceConfig) -> Unit, + onSaveAndConnect: (DataSourceConfig) -> Unit, + modifier: Modifier = Modifier +) { + var name by remember(existingConfig) { mutableStateOf(existingConfig?.name ?: "") } + var dialect by remember(existingConfig) { mutableStateOf(existingConfig?.dialect ?: DatabaseDialect.MYSQL) } + var host by remember(existingConfig) { mutableStateOf(existingConfig?.host ?: "localhost") } + var port by remember(existingConfig) { mutableStateOf(existingConfig?.port?.toString() ?: "3306") } + var database by remember(existingConfig) { mutableStateOf(existingConfig?.database ?: "") } + var username by remember(existingConfig) { mutableStateOf(existingConfig?.username ?: "") } + var password by remember(existingConfig) { mutableStateOf(existingConfig?.password ?: "") } + var description by remember(existingConfig) { mutableStateOf(existingConfig?.description ?: "") } + var showPassword by remember { mutableStateOf(false) } + var dialectExpanded by remember { mutableStateOf(false) } + + val isEditing = existingConfig != null + val scrollState = rememberScrollState() + + val isValid = name.isNotBlank() && database.isNotBlank() && + (dialect == DatabaseDialect.SQLITE || host.isNotBlank()) + + fun buildConfig(): DataSourceConfig = DataSourceConfig( + id = existingConfig?.id ?: "", + name = name.trim(), + dialect = dialect, + host = host.trim(), + port = port.toIntOrNull() ?: dialect.defaultPort, + database = database.trim(), + username = username.trim(), + password = password, + description = description.trim() + ) + + Column( + modifier = modifier + .fillMaxSize() + .background(MaterialTheme.colorScheme.surface) + ) { + // Header + Surface( + modifier = Modifier.fillMaxWidth(), + color = MaterialTheme.colorScheme.surfaceVariant.copy(alpha = 0.3f) + ) { + Row( + modifier = Modifier + .fillMaxWidth() + .padding(16.dp), + horizontalArrangement = Arrangement.SpaceBetween, + verticalAlignment = Alignment.CenterVertically + ) { + Text( + text = if (isEditing) "Edit Data Source" else "Add Data Source", + style = MaterialTheme.typography.titleMedium + ) + IconButton(onClick = onCancel) { + Icon(AutoDevComposeIcons.Close, contentDescription = "Close") + } + } + } + + HorizontalDivider() + + // Form content + Column( + modifier = Modifier + .weight(1f) + .verticalScroll(scrollState) + .padding(16.dp), + verticalArrangement = Arrangement.spacedBy(16.dp) + ) { + // Name + OutlinedTextField( + value = name, + onValueChange = { name = it }, + label = { Text("Name *") }, + modifier = Modifier.fillMaxWidth(), + singleLine = true + ) + + // Dialect dropdown + ExposedDropdownMenuBox( + expanded = dialectExpanded, + onExpandedChange = { dialectExpanded = it } + ) { + OutlinedTextField( + value = dialect.displayName, + onValueChange = {}, + readOnly = true, + label = { Text("Database Type") }, + trailingIcon = { ExposedDropdownMenuDefaults.TrailingIcon(expanded = dialectExpanded) }, + modifier = Modifier.fillMaxWidth().menuAnchor() + ) + ExposedDropdownMenu( + expanded = dialectExpanded, + onDismissRequest = { dialectExpanded = false } + ) { + DatabaseDialect.entries.forEach { option -> + DropdownMenuItem( + text = { Text(option.displayName) }, + onClick = { + dialect = option + port = option.defaultPort.toString() + dialectExpanded = false + } + ) + } + } + } + + // Host and Port (not for SQLite) + if (dialect != DatabaseDialect.SQLITE) { + Row( + modifier = Modifier.fillMaxWidth(), + horizontalArrangement = Arrangement.spacedBy(12.dp) + ) { + OutlinedTextField( + value = host, + onValueChange = { host = it }, + label = { Text("Host *") }, + modifier = Modifier.weight(2f), + singleLine = true + ) + OutlinedTextField( + value = port, + onValueChange = { port = it.filter { c -> c.isDigit() } }, + label = { Text("Port *") }, + modifier = Modifier.weight(1f), + singleLine = true, + keyboardOptions = KeyboardOptions(keyboardType = KeyboardType.Number) + ) + } + } + + // Database name + OutlinedTextField( + value = database, + onValueChange = { database = it }, + label = { Text(if (dialect == DatabaseDialect.SQLITE) "File Path *" else "Database *") }, + modifier = Modifier.fillMaxWidth(), + singleLine = true + ) + + // Username and Password (not for SQLite) + if (dialect != DatabaseDialect.SQLITE) { + OutlinedTextField( + value = username, + onValueChange = { username = it }, + label = { Text("Username") }, + modifier = Modifier.fillMaxWidth(), + singleLine = true + ) + + OutlinedTextField( + value = password, + onValueChange = { password = it }, + label = { Text("Password") }, + modifier = Modifier.fillMaxWidth(), + singleLine = true, + visualTransformation = if (showPassword) VisualTransformation.None else PasswordVisualTransformation(), + trailingIcon = { + IconButton(onClick = { showPassword = !showPassword }) { + Icon( + if (showPassword) AutoDevComposeIcons.VisibilityOff else AutoDevComposeIcons.Visibility, + contentDescription = if (showPassword) "Hide" else "Show" + ) + } + } + ) + } + + // Description + OutlinedTextField( + value = description, + onValueChange = { description = it }, + label = { Text("Description") }, + modifier = Modifier.fillMaxWidth(), + minLines = 2, + maxLines = 3 + ) + } + + HorizontalDivider() + + // Action buttons + Surface( + modifier = Modifier.fillMaxWidth(), + color = MaterialTheme.colorScheme.surfaceVariant.copy(alpha = 0.3f) + ) { + Row( + modifier = Modifier + .fillMaxWidth() + .padding(16.dp), + horizontalArrangement = Arrangement.spacedBy(8.dp, Alignment.End), + verticalAlignment = Alignment.CenterVertically + ) { + TextButton(onClick = onCancel) { + Text("Cancel") + } + OutlinedButton( + onClick = { onSave(buildConfig()) }, + enabled = isValid + ) { + Text("Save") + } + Button( + onClick = { onSaveAndConnect(buildConfig()) }, + enabled = isValid + ) { + Text("Save & Connect") + } + } + } + } +} + diff --git a/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/agent/chatdb/model/DataSourceModels.kt b/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/agent/chatdb/model/DataSourceModels.kt index bd9df347fd..d06e3249c3 100644 --- a/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/agent/chatdb/model/DataSourceModels.kt +++ b/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/agent/chatdb/model/DataSourceModels.kt @@ -15,9 +15,9 @@ enum class DatabaseDialect(val displayName: String, val defaultPort: Int) { companion object { fun fromString(value: String): DatabaseDialect { - return entries.find { - it.name.equals(value, ignoreCase = true) || - it.displayName.equals(value, ignoreCase = true) + return entries.find { + it.name.equals(value, ignoreCase = true) || + it.displayName.equals(value, ignoreCase = true) } ?: MYSQL } } @@ -56,14 +56,14 @@ data class DataSourceConfig( */ fun validate(): ValidationResult { val errors = mutableListOf() - + if (name.isBlank()) errors.add("Name is required") if (dialect != DatabaseDialect.SQLITE) { if (host.isBlank()) errors.add("Host is required") if (port <= 0 || port > 65535) errors.add("Port must be between 1 and 65535") } if (database.isBlank()) errors.add("Database name is required") - + return if (errors.isEmpty()) { ValidationResult.Valid } else { @@ -149,7 +149,11 @@ data class ChatDBState( val isLoading: Boolean = false, val error: String? = null, val isConfigDialogOpen: Boolean = false, - val editingDataSource: DataSourceConfig? = null + val editingDataSource: DataSourceConfig? = null, + /** Whether the config pane is shown (inline panel mode) */ + val isConfigPaneOpen: Boolean = false, + /** The data source being configured in the pane */ + val configuringDataSource: DataSourceConfig? = null ) { val selectedDataSource: DataSourceConfig? get() = dataSources.find { it.id == selectedDataSourceId } From a4eacbd23dead4c958f49b6ecee7719db70ca5b5 Mon Sep 17 00:00:00 2001 From: Phodal Huang Date: Wed, 10 Dec 2025 12:52:32 +0800 Subject: [PATCH 18/34] refactor(agent): migrate SQL validator to multiplatform #508 Replaces JSqlParserValidator with platform-specific SqlValidator implementations and moves ChatDBAgent classes to commonMain for improved multiplatform support. --- .../agent/subagent/SqlValidator.android.kt | 125 +++++++++++++++ .../cc/unitmesh/agent/chatdb/ChatDBAgent.kt | 0 .../agent/chatdb/ChatDBAgentExecutor.kt | 10 +- .../unitmesh/agent/subagent/SqlReviseAgent.kt | 2 + .../unitmesh/agent/subagent/SqlValidator.kt | 52 ++++++ .../agent/subagent/SqlValidator.ios.kt | 140 ++++++++++++++++ .../agent/subagent/SqlValidator.js.kt | 150 ++++++++++++++++++ ...ParserValidator.kt => SqlValidator.jvm.kt} | 11 +- .../agent/subagent/JSqlParserValidatorTest.kt | 12 +- .../agent/subagent/SqlValidator.wasmJs.kt | 140 ++++++++++++++++ .../compose/agent/chatdb/ChatDBViewModel.kt | 89 +++++------ 11 files changed, 665 insertions(+), 66 deletions(-) create mode 100644 mpp-core/src/androidMain/kotlin/cc/unitmesh/agent/subagent/SqlValidator.android.kt rename mpp-core/src/{jvmMain => commonMain}/kotlin/cc/unitmesh/agent/chatdb/ChatDBAgent.kt (100%) rename mpp-core/src/{jvmMain => commonMain}/kotlin/cc/unitmesh/agent/chatdb/ChatDBAgentExecutor.kt (98%) create mode 100644 mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/subagent/SqlValidator.kt create mode 100644 mpp-core/src/iosMain/kotlin/cc/unitmesh/agent/subagent/SqlValidator.ios.kt create mode 100644 mpp-core/src/jsMain/kotlin/cc/unitmesh/agent/subagent/SqlValidator.js.kt rename mpp-core/src/jvmMain/kotlin/cc/unitmesh/agent/subagent/{JSqlParserValidator.kt => SqlValidator.jvm.kt} (94%) create mode 100644 mpp-core/src/wasmJsMain/kotlin/cc/unitmesh/agent/subagent/SqlValidator.wasmJs.kt diff --git a/mpp-core/src/androidMain/kotlin/cc/unitmesh/agent/subagent/SqlValidator.android.kt b/mpp-core/src/androidMain/kotlin/cc/unitmesh/agent/subagent/SqlValidator.android.kt new file mode 100644 index 0000000000..f3e1e12170 --- /dev/null +++ b/mpp-core/src/androidMain/kotlin/cc/unitmesh/agent/subagent/SqlValidator.android.kt @@ -0,0 +1,125 @@ +package cc.unitmesh.agent.subagent + +import net.sf.jsqlparser.parser.CCJSqlParserUtil +import net.sf.jsqlparser.statement.Statement +import net.sf.jsqlparser.util.TablesNamesFinder + +/** + * Android implementation of SqlValidator using JSqlParser. + * + * This validator uses JSqlParser to validate SQL syntax. + * It can detect: + * - Syntax errors + * - Malformed SQL statements + * - Unsupported SQL constructs + * - Table names not in whitelist (schema validation) + */ +actual class SqlValidator actual constructor() : SqlValidatorInterface { + + actual override fun validate(sql: String): SqlValidationResult { + return try { + val statement: Statement = CCJSqlParserUtil.parse(sql) + SqlValidationResult( + isValid = true, + errors = emptyList(), + warnings = collectWarnings(statement) + ) + } catch (e: Exception) { + SqlValidationResult( + isValid = false, + errors = listOf(extractErrorMessage(e)), + warnings = emptyList() + ) + } + } + + actual override fun validateWithTableWhitelist(sql: String, allowedTables: Set): SqlValidationResult { + return try { + val statement: Statement = CCJSqlParserUtil.parse(sql) + + // Extract table names from the SQL + val tablesNamesFinder = TablesNamesFinder() + val usedTables = tablesNamesFinder.getTableList(statement) + + // Check if all used tables are in the whitelist (case-insensitive) + val allowedTablesLower = allowedTables.map { it.lowercase() }.toSet() + val invalidTables = usedTables.filter { tableName -> + tableName.lowercase() !in allowedTablesLower + } + + if (invalidTables.isNotEmpty()) { + SqlValidationResult( + isValid = false, + errors = listOf( + "Invalid table(s) used: ${invalidTables.joinToString(", ")}. " + + "Available tables: ${allowedTables.joinToString(", ")}" + ), + warnings = collectWarnings(statement) + ) + } else { + SqlValidationResult( + isValid = true, + errors = emptyList(), + warnings = collectWarnings(statement) + ) + } + } catch (e: Exception) { + SqlValidationResult( + isValid = false, + errors = listOf(extractErrorMessage(e)), + warnings = emptyList() + ) + } + } + + actual override fun extractTableNames(sql: String): List { + return try { + val statement: Statement = CCJSqlParserUtil.parse(sql) + val tablesNamesFinder = TablesNamesFinder() + tablesNamesFinder.getTableList(statement) + } catch (e: Exception) { + emptyList() + } + } + + private fun extractErrorMessage(e: Exception): String { + val message = e.message ?: "Unknown SQL parsing error" + return when { + message.contains("Encountered") -> { + val match = Regex("Encountered \"(.+?)\" at line (\\d+), column (\\d+)").find(message) + if (match != null) { + val (token, line, column) = match.destructured + "Syntax error at line $line, column $column: unexpected token '$token'" + } else { + message + } + } + message.contains("Was expecting") -> { + val match = Regex("Was expecting.*?:\\s*(.+)").find(message) + if (match != null) { + "Expected: ${match.groupValues[1].take(100)}" + } else { + message + } + } + else -> message.take(200) + } + } + + private fun collectWarnings(statement: Statement): List { + val warnings = mutableListOf() + val sql = statement.toString() + + if (sql.contains("SELECT *")) { + warnings.add("Consider specifying explicit columns instead of SELECT *") + } + + if (!sql.contains("WHERE", ignoreCase = true) && + (sql.contains("UPDATE", ignoreCase = true) || sql.contains("DELETE", ignoreCase = true))) { + warnings.add("UPDATE/DELETE without WHERE clause will affect all rows") + } + + return warnings + } +} + diff --git a/mpp-core/src/jvmMain/kotlin/cc/unitmesh/agent/chatdb/ChatDBAgent.kt b/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/chatdb/ChatDBAgent.kt similarity index 100% rename from mpp-core/src/jvmMain/kotlin/cc/unitmesh/agent/chatdb/ChatDBAgent.kt rename to mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/chatdb/ChatDBAgent.kt diff --git a/mpp-core/src/jvmMain/kotlin/cc/unitmesh/agent/chatdb/ChatDBAgentExecutor.kt b/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/chatdb/ChatDBAgentExecutor.kt similarity index 98% rename from mpp-core/src/jvmMain/kotlin/cc/unitmesh/agent/chatdb/ChatDBAgentExecutor.kt rename to mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/chatdb/ChatDBAgentExecutor.kt index 097313cb4e..788d56d98d 100644 --- a/mpp-core/src/jvmMain/kotlin/cc/unitmesh/agent/chatdb/ChatDBAgentExecutor.kt +++ b/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/chatdb/ChatDBAgentExecutor.kt @@ -6,7 +6,7 @@ import cc.unitmesh.agent.executor.BaseAgentExecutor import cc.unitmesh.agent.logging.getLogger import cc.unitmesh.agent.orchestrator.ToolOrchestrator import cc.unitmesh.agent.render.CodingAgentRenderer -import cc.unitmesh.agent.subagent.JSqlParserValidator +import cc.unitmesh.agent.subagent.SqlValidator import cc.unitmesh.agent.subagent.SqlReviseAgent import cc.unitmesh.agent.subagent.SqlRevisionInput import cc.unitmesh.devins.parser.CodeFence @@ -49,8 +49,8 @@ class ChatDBAgentExecutor( keywordSchemaLinker } - private val jsqlValidator = JSqlParserValidator() - private val sqlReviseAgent = SqlReviseAgent(llmService, jsqlValidator) + private val sqlValidator = SqlValidator() + private val sqlReviseAgent = SqlReviseAgent(llmService, sqlValidator) private val maxRevisionAttempts = 3 private val maxExecutionRetries = 3 @@ -114,9 +114,9 @@ class ChatDBAgentExecutor( val allTableNames = schema.tables.map { it.name }.toSet() // First validate syntax, then validate table names - val syntaxValidation = jsqlValidator.validate(validatedSql!!) + val syntaxValidation = sqlValidator.validate(validatedSql!!) val tableValidation = if (syntaxValidation.isValid) { - jsqlValidator.validateWithTableWhitelist(validatedSql, allTableNames) + sqlValidator.validateWithTableWhitelist(validatedSql, allTableNames) } else { syntaxValidation } diff --git a/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/subagent/SqlReviseAgent.kt b/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/subagent/SqlReviseAgent.kt index dd280f05c7..1fd2c92950 100644 --- a/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/subagent/SqlReviseAgent.kt +++ b/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/subagent/SqlReviseAgent.kt @@ -276,6 +276,8 @@ data class SqlRevisionInput( */ interface SqlValidatorInterface { fun validate(sql: String): SqlValidationResult + fun validateWithTableWhitelist(sql: String, allowedTables: Set): SqlValidationResult + fun extractTableNames(sql: String): List } /** diff --git a/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/subagent/SqlValidator.kt b/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/subagent/SqlValidator.kt new file mode 100644 index 0000000000..cbb2a5f5d6 --- /dev/null +++ b/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/subagent/SqlValidator.kt @@ -0,0 +1,52 @@ +package cc.unitmesh.agent.subagent + +/** + * Platform-agnostic SQL validator. + * + * JVM platforms use JSqlParser for full SQL parsing and validation. + * Non-JVM platforms perform basic syntax checks. + * + * Usage: + * ```kotlin + * val validator = SqlValidator() + * val result = validator.validate("SELECT * FROM users") + * if (result.isValid) { + * // SQL is valid + * } else { + * // Handle errors: result.errors + * } + * ``` + */ +expect class SqlValidator() : SqlValidatorInterface { + /** + * Validate SQL syntax. + * + * @param sql The SQL query to validate + * @return Validation result with errors and warnings + */ + override fun validate(sql: String): SqlValidationResult + + /** + * Validate SQL with table whitelist - ensures only allowed tables are used. + * + * On JVM platforms, this uses JSqlParser to extract table names and validate. + * On non-JVM platforms, this performs basic regex-based table name extraction. + * + * @param sql The SQL query to validate + * @param allowedTables Set of table names that are allowed in the query + * @return SqlValidationResult with errors if invalid tables are used + */ + override fun validateWithTableWhitelist(sql: String, allowedTables: Set): SqlValidationResult + + /** + * Extract table names from SQL query. + * + * On JVM platforms, this uses JSqlParser for accurate extraction. + * On non-JVM platforms, this uses regex-based extraction which may be less accurate. + * + * @param sql The SQL query to extract table names from + * @return List of table names found in the query + */ + override fun extractTableNames(sql: String): List +} + diff --git a/mpp-core/src/iosMain/kotlin/cc/unitmesh/agent/subagent/SqlValidator.ios.kt b/mpp-core/src/iosMain/kotlin/cc/unitmesh/agent/subagent/SqlValidator.ios.kt new file mode 100644 index 0000000000..e1c0db3f3a --- /dev/null +++ b/mpp-core/src/iosMain/kotlin/cc/unitmesh/agent/subagent/SqlValidator.ios.kt @@ -0,0 +1,140 @@ +package cc.unitmesh.agent.subagent + +/** + * iOS implementation of SqlValidator. + * + * Performs basic SQL syntax validation without full parsing. + * Full parsing with JSqlParser is only available on JVM platforms. + */ +actual class SqlValidator actual constructor() : SqlValidatorInterface { + + actual override fun validate(sql: String): SqlValidationResult { + if (sql.isBlank()) { + return SqlValidationResult( + isValid = false, + errors = listOf("Empty SQL query") + ) + } + + return performBasicValidation(sql) + } + + actual override fun validateWithTableWhitelist(sql: String, allowedTables: Set): SqlValidationResult { + val basicResult = validate(sql) + if (!basicResult.isValid) { + return basicResult + } + + // Extract table names using regex and validate against whitelist + val usedTables = extractTableNames(sql) + val allowedTablesLower = allowedTables.map { it.lowercase() }.toSet() + val invalidTables = usedTables.filter { tableName -> + tableName.lowercase() !in allowedTablesLower + } + + return if (invalidTables.isNotEmpty()) { + SqlValidationResult( + isValid = false, + errors = listOf( + "Invalid table(s) used: ${invalidTables.joinToString(", ")}. " + + "Available tables: ${allowedTables.joinToString(", ")}" + ), + warnings = basicResult.warnings + ) + } else { + basicResult + } + } + + actual override fun extractTableNames(sql: String): List { + val tables = mutableListOf() + + // Match FROM clause + val fromPattern = Regex("""FROM\s+(\w+)""", RegexOption.IGNORE_CASE) + fromPattern.findAll(sql).forEach { match -> + tables.add(match.groupValues[1]) + } + + // Match JOIN clause + val joinPattern = Regex("""JOIN\s+(\w+)""", RegexOption.IGNORE_CASE) + joinPattern.findAll(sql).forEach { match -> + tables.add(match.groupValues[1]) + } + + // Match UPDATE clause + val updatePattern = Regex("""UPDATE\s+(\w+)""", RegexOption.IGNORE_CASE) + updatePattern.findAll(sql).forEach { match -> + tables.add(match.groupValues[1]) + } + + // Match INSERT INTO clause + val insertPattern = Regex("""INSERT\s+INTO\s+(\w+)""", RegexOption.IGNORE_CASE) + insertPattern.findAll(sql).forEach { match -> + tables.add(match.groupValues[1]) + } + + // Match DELETE FROM clause + val deletePattern = Regex("""DELETE\s+FROM\s+(\w+)""", RegexOption.IGNORE_CASE) + deletePattern.findAll(sql).forEach { match -> + tables.add(match.groupValues[1]) + } + + return tables.distinct() + } + + private fun performBasicValidation(sql: String): SqlValidationResult { + val errors = mutableListOf() + val warnings = mutableListOf() + + val upperSql = sql.uppercase() + + // Check for basic SQL structure + val hasValidStart = upperSql.trimStart().let { trimmed -> + trimmed.startsWith("SELECT") || + trimmed.startsWith("INSERT") || + trimmed.startsWith("UPDATE") || + trimmed.startsWith("DELETE") || + trimmed.startsWith("CREATE") || + trimmed.startsWith("ALTER") || + trimmed.startsWith("DROP") || + trimmed.startsWith("WITH") + } + + if (!hasValidStart) { + errors.add("SQL must start with a valid statement (SELECT, INSERT, UPDATE, DELETE, etc.)") + } + + // Check for balanced parentheses + var parenCount = 0 + for (char in sql) { + when (char) { + '(' -> parenCount++ + ')' -> parenCount-- + } + if (parenCount < 0) { + errors.add("Unbalanced parentheses: unexpected ')'") + break + } + } + if (parenCount > 0) { + errors.add("Unbalanced parentheses: missing ')'") + } + + // Warnings + if (upperSql.contains("SELECT *")) { + warnings.add("Consider specifying explicit columns instead of SELECT *") + } + + if (!upperSql.contains("WHERE") && + (upperSql.contains("UPDATE") || upperSql.contains("DELETE"))) { + warnings.add("UPDATE/DELETE without WHERE clause will affect all rows") + } + + return SqlValidationResult( + isValid = errors.isEmpty(), + errors = errors, + warnings = warnings + ) + } +} + diff --git a/mpp-core/src/jsMain/kotlin/cc/unitmesh/agent/subagent/SqlValidator.js.kt b/mpp-core/src/jsMain/kotlin/cc/unitmesh/agent/subagent/SqlValidator.js.kt new file mode 100644 index 0000000000..702a298ee9 --- /dev/null +++ b/mpp-core/src/jsMain/kotlin/cc/unitmesh/agent/subagent/SqlValidator.js.kt @@ -0,0 +1,150 @@ +package cc.unitmesh.agent.subagent + +/** + * JavaScript implementation of SqlValidator. + * + * Performs basic SQL syntax validation without full parsing. + * Full parsing with JSqlParser is only available on JVM platforms. + */ +actual class SqlValidator actual constructor() : SqlValidatorInterface { + + actual override fun validate(sql: String): SqlValidationResult { + if (sql.isBlank()) { + return SqlValidationResult( + isValid = false, + errors = listOf("Empty SQL query") + ) + } + + return performBasicValidation(sql) + } + + actual override fun validateWithTableWhitelist(sql: String, allowedTables: Set): SqlValidationResult { + val basicResult = validate(sql) + if (!basicResult.isValid) { + return basicResult + } + + // Extract table names using regex and validate against whitelist + val usedTables = extractTableNames(sql) + val allowedTablesLower = allowedTables.map { it.lowercase() }.toSet() + val invalidTables = usedTables.filter { tableName -> + tableName.lowercase() !in allowedTablesLower + } + + return if (invalidTables.isNotEmpty()) { + SqlValidationResult( + isValid = false, + errors = listOf( + "Invalid table(s) used: ${invalidTables.joinToString(", ")}. " + + "Available tables: ${allowedTables.joinToString(", ")}" + ), + warnings = basicResult.warnings + ) + } else { + basicResult + } + } + + actual override fun extractTableNames(sql: String): List { + val tables = mutableListOf() + + // Match FROM clause: FROM table_name or FROM schema.table_name + val fromPattern = Regex("""FROM\s+([`"\[]?\w+[`"\]]?(?:\.[`"\[]?\w+[`"\]]?)?)""", RegexOption.IGNORE_CASE) + fromPattern.findAll(sql).forEach { match -> + tables.add(cleanTableName(match.groupValues[1])) + } + + // Match JOIN clause: JOIN table_name or LEFT/RIGHT/INNER/OUTER JOIN table_name + val joinPattern = Regex("""(?:LEFT|RIGHT|INNER|OUTER|CROSS|FULL)?\s*JOIN\s+([`"\[]?\w+[`"\]]?(?:\.[`"\[]?\w+[`"\]]?)?)""", RegexOption.IGNORE_CASE) + joinPattern.findAll(sql).forEach { match -> + tables.add(cleanTableName(match.groupValues[1])) + } + + // Match UPDATE clause: UPDATE table_name + val updatePattern = Regex("""UPDATE\s+([`"\[]?\w+[`"\]]?(?:\.[`"\[]?\w+[`"\]]?)?)""", RegexOption.IGNORE_CASE) + updatePattern.findAll(sql).forEach { match -> + tables.add(cleanTableName(match.groupValues[1])) + } + + // Match INSERT INTO clause: INSERT INTO table_name + val insertPattern = Regex("""INSERT\s+INTO\s+([`"\[]?\w+[`"\]]?(?:\.[`"\[]?\w+[`"\]]?)?)""", RegexOption.IGNORE_CASE) + insertPattern.findAll(sql).forEach { match -> + tables.add(cleanTableName(match.groupValues[1])) + } + + // Match DELETE FROM clause: DELETE FROM table_name + val deletePattern = Regex("""DELETE\s+FROM\s+([`"\[]?\w+[`"\]]?(?:\.[`"\[]?\w+[`"\]]?)?)""", RegexOption.IGNORE_CASE) + deletePattern.findAll(sql).forEach { match -> + tables.add(cleanTableName(match.groupValues[1])) + } + + return tables.distinct() + } + + private fun cleanTableName(name: String): String { + // Remove quotes and brackets, extract table name from schema.table format + val cleaned = name.replace(Regex("""[`"\[\]]"""), "") + return if (cleaned.contains(".")) { + cleaned.substringAfterLast(".") + } else { + cleaned + } + } + + private fun performBasicValidation(sql: String): SqlValidationResult { + val errors = mutableListOf() + val warnings = mutableListOf() + + val upperSql = sql.uppercase() + + // Check for basic SQL structure + val hasValidStart = upperSql.trimStart().let { trimmed -> + trimmed.startsWith("SELECT") || + trimmed.startsWith("INSERT") || + trimmed.startsWith("UPDATE") || + trimmed.startsWith("DELETE") || + trimmed.startsWith("CREATE") || + trimmed.startsWith("ALTER") || + trimmed.startsWith("DROP") || + trimmed.startsWith("WITH") + } + + if (!hasValidStart) { + errors.add("SQL must start with a valid statement (SELECT, INSERT, UPDATE, DELETE, etc.)") + } + + // Check for balanced parentheses + var parenCount = 0 + for (char in sql) { + when (char) { + '(' -> parenCount++ + ')' -> parenCount-- + } + if (parenCount < 0) { + errors.add("Unbalanced parentheses: unexpected ')'") + break + } + } + if (parenCount > 0) { + errors.add("Unbalanced parentheses: missing ')'") + } + + // Warnings + if (upperSql.contains("SELECT *")) { + warnings.add("Consider specifying explicit columns instead of SELECT *") + } + + if (!upperSql.contains("WHERE") && + (upperSql.contains("UPDATE") || upperSql.contains("DELETE"))) { + warnings.add("UPDATE/DELETE without WHERE clause will affect all rows") + } + + return SqlValidationResult( + isValid = errors.isEmpty(), + errors = errors, + warnings = warnings + ) + } +} + diff --git a/mpp-core/src/jvmMain/kotlin/cc/unitmesh/agent/subagent/JSqlParserValidator.kt b/mpp-core/src/jvmMain/kotlin/cc/unitmesh/agent/subagent/SqlValidator.jvm.kt similarity index 94% rename from mpp-core/src/jvmMain/kotlin/cc/unitmesh/agent/subagent/JSqlParserValidator.kt rename to mpp-core/src/jvmMain/kotlin/cc/unitmesh/agent/subagent/SqlValidator.jvm.kt index 219c48a591..cba40aa96c 100644 --- a/mpp-core/src/jvmMain/kotlin/cc/unitmesh/agent/subagent/JSqlParserValidator.kt +++ b/mpp-core/src/jvmMain/kotlin/cc/unitmesh/agent/subagent/SqlValidator.jvm.kt @@ -2,11 +2,10 @@ package cc.unitmesh.agent.subagent import net.sf.jsqlparser.parser.CCJSqlParserUtil import net.sf.jsqlparser.statement.Statement -import net.sf.jsqlparser.statement.select.Select import net.sf.jsqlparser.util.TablesNamesFinder /** - * JVM implementation of SqlValidatorInterface using JSqlParser + * JVM implementation of SqlValidator using JSqlParser. * * This validator uses JSqlParser to validate SQL syntax. * It can detect: @@ -15,9 +14,9 @@ import net.sf.jsqlparser.util.TablesNamesFinder * - Unsupported SQL constructs * - Table names not in whitelist (schema validation) */ -class JSqlParserValidator : SqlValidatorInterface { +actual class SqlValidator actual constructor() : SqlValidatorInterface { - override fun validate(sql: String): SqlValidationResult { + actual override fun validate(sql: String): SqlValidationResult { return try { val statement: Statement = CCJSqlParserUtil.parse(sql) SqlValidationResult( @@ -41,7 +40,7 @@ class JSqlParserValidator : SqlValidatorInterface { * @param allowedTables Set of table names that are allowed in the query * @return SqlValidationResult with errors if invalid tables are used */ - fun validateWithTableWhitelist(sql: String, allowedTables: Set): SqlValidationResult { + actual override fun validateWithTableWhitelist(sql: String, allowedTables: Set): SqlValidationResult { return try { val statement: Statement = CCJSqlParserUtil.parse(sql) @@ -83,7 +82,7 @@ class JSqlParserValidator : SqlValidatorInterface { /** * Extract table names from SQL query */ - fun extractTableNames(sql: String): List { + actual override fun extractTableNames(sql: String): List { return try { val statement: Statement = CCJSqlParserUtil.parse(sql) val tablesNamesFinder = TablesNamesFinder() diff --git a/mpp-core/src/jvmTest/kotlin/cc/unitmesh/agent/subagent/JSqlParserValidatorTest.kt b/mpp-core/src/jvmTest/kotlin/cc/unitmesh/agent/subagent/JSqlParserValidatorTest.kt index 9b2ba5c381..b8878e87d8 100644 --- a/mpp-core/src/jvmTest/kotlin/cc/unitmesh/agent/subagent/JSqlParserValidatorTest.kt +++ b/mpp-core/src/jvmTest/kotlin/cc/unitmesh/agent/subagent/JSqlParserValidatorTest.kt @@ -4,11 +4,11 @@ import org.junit.Test import kotlin.test.* /** - * Tests for JSqlParserValidator - JVM-specific SQL validation using JSqlParser + * Tests for SqlValidator - JVM-specific SQL validation using JSqlParser */ class JSqlParserValidatorTest { - private val validator = JSqlParserValidator() + private val validator = SqlValidator() // ============= Basic Validation Tests ============= @@ -144,7 +144,9 @@ class JSqlParserValidatorTest { @Test fun testValidateAndParseValid() { - val (result, statement) = validator.validateAndParse("SELECT * FROM users") + val pair = validator.validateAndParse("SELECT * FROM users") + val result = pair.first + val statement = pair.second assertTrue(result.isValid) assertNotNull(statement) @@ -152,7 +154,9 @@ class JSqlParserValidatorTest { @Test fun testValidateAndParseInvalid() { - val (result, statement) = validator.validateAndParse("SELECT * FORM users") + val pair = validator.validateAndParse("SELECT * FORM users") + val result = pair.first + val statement = pair.second assertFalse(result.isValid) assertNull(statement) diff --git a/mpp-core/src/wasmJsMain/kotlin/cc/unitmesh/agent/subagent/SqlValidator.wasmJs.kt b/mpp-core/src/wasmJsMain/kotlin/cc/unitmesh/agent/subagent/SqlValidator.wasmJs.kt new file mode 100644 index 0000000000..d1106660aa --- /dev/null +++ b/mpp-core/src/wasmJsMain/kotlin/cc/unitmesh/agent/subagent/SqlValidator.wasmJs.kt @@ -0,0 +1,140 @@ +package cc.unitmesh.agent.subagent + +/** + * WASM JavaScript implementation of SqlValidator. + * + * Performs basic SQL syntax validation without full parsing. + * Full parsing with JSqlParser is only available on JVM platforms. + */ +actual class SqlValidator actual constructor() : SqlValidatorInterface { + + actual override fun validate(sql: String): SqlValidationResult { + if (sql.isBlank()) { + return SqlValidationResult( + isValid = false, + errors = listOf("Empty SQL query") + ) + } + + return performBasicValidation(sql) + } + + actual override fun validateWithTableWhitelist(sql: String, allowedTables: Set): SqlValidationResult { + val basicResult = validate(sql) + if (!basicResult.isValid) { + return basicResult + } + + // Extract table names using regex and validate against whitelist + val usedTables = extractTableNames(sql) + val allowedTablesLower = allowedTables.map { it.lowercase() }.toSet() + val invalidTables = usedTables.filter { tableName -> + tableName.lowercase() !in allowedTablesLower + } + + return if (invalidTables.isNotEmpty()) { + SqlValidationResult( + isValid = false, + errors = listOf( + "Invalid table(s) used: ${invalidTables.joinToString(", ")}. " + + "Available tables: ${allowedTables.joinToString(", ")}" + ), + warnings = basicResult.warnings + ) + } else { + basicResult + } + } + + actual override fun extractTableNames(sql: String): List { + val tables = mutableListOf() + + // Match FROM clause + val fromPattern = Regex("""FROM\s+(\w+)""", RegexOption.IGNORE_CASE) + fromPattern.findAll(sql).forEach { match -> + tables.add(match.groupValues[1]) + } + + // Match JOIN clause + val joinPattern = Regex("""JOIN\s+(\w+)""", RegexOption.IGNORE_CASE) + joinPattern.findAll(sql).forEach { match -> + tables.add(match.groupValues[1]) + } + + // Match UPDATE clause + val updatePattern = Regex("""UPDATE\s+(\w+)""", RegexOption.IGNORE_CASE) + updatePattern.findAll(sql).forEach { match -> + tables.add(match.groupValues[1]) + } + + // Match INSERT INTO clause + val insertPattern = Regex("""INSERT\s+INTO\s+(\w+)""", RegexOption.IGNORE_CASE) + insertPattern.findAll(sql).forEach { match -> + tables.add(match.groupValues[1]) + } + + // Match DELETE FROM clause + val deletePattern = Regex("""DELETE\s+FROM\s+(\w+)""", RegexOption.IGNORE_CASE) + deletePattern.findAll(sql).forEach { match -> + tables.add(match.groupValues[1]) + } + + return tables.distinct() + } + + private fun performBasicValidation(sql: String): SqlValidationResult { + val errors = mutableListOf() + val warnings = mutableListOf() + + val upperSql = sql.uppercase() + + // Check for basic SQL structure + val hasValidStart = upperSql.trimStart().let { trimmed -> + trimmed.startsWith("SELECT") || + trimmed.startsWith("INSERT") || + trimmed.startsWith("UPDATE") || + trimmed.startsWith("DELETE") || + trimmed.startsWith("CREATE") || + trimmed.startsWith("ALTER") || + trimmed.startsWith("DROP") || + trimmed.startsWith("WITH") + } + + if (!hasValidStart) { + errors.add("SQL must start with a valid statement (SELECT, INSERT, UPDATE, DELETE, etc.)") + } + + // Check for balanced parentheses + var parenCount = 0 + for (char in sql) { + when (char) { + '(' -> parenCount++ + ')' -> parenCount-- + } + if (parenCount < 0) { + errors.add("Unbalanced parentheses: unexpected ')'") + break + } + } + if (parenCount > 0) { + errors.add("Unbalanced parentheses: missing ')'") + } + + // Warnings + if (upperSql.contains("SELECT *")) { + warnings.add("Consider specifying explicit columns instead of SELECT *") + } + + if (!upperSql.contains("WHERE") && + (upperSql.contains("UPDATE") || upperSql.contains("DELETE"))) { + warnings.add("UPDATE/DELETE without WHERE clause will affect all rows") + } + + return SqlValidationResult( + isValid = errors.isEmpty(), + errors = errors, + warnings = warnings + ) + } +} + diff --git a/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/agent/chatdb/ChatDBViewModel.kt b/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/agent/chatdb/ChatDBViewModel.kt index 3e2dc45212..73b5d14a93 100644 --- a/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/agent/chatdb/ChatDBViewModel.kt +++ b/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/agent/chatdb/ChatDBViewModel.kt @@ -3,6 +3,10 @@ package cc.unitmesh.devins.ui.compose.agent.chatdb import androidx.compose.runtime.getValue import androidx.compose.runtime.mutableStateOf import androidx.compose.runtime.setValue +import cc.unitmesh.agent.chatdb.ChatDBAgent +import cc.unitmesh.agent.chatdb.ChatDBTask +import cc.unitmesh.agent.config.McpToolConfigService +import cc.unitmesh.agent.config.ToolConfigFile import cc.unitmesh.agent.database.DatabaseConnection import cc.unitmesh.agent.database.DatabaseSchema import cc.unitmesh.agent.database.createDatabaseConnection @@ -255,24 +259,44 @@ class ChatDBViewModel( return@launch } - val schemaContext = currentSchema?.getDescription() ?: "No database connected" - val systemPrompt = buildSystemPrompt(schemaContext) - - val response = StringBuilder() - renderer.renderLLMResponseStart() + val dataSource = state.selectedDataSource + if (dataSource == null) { + renderer.renderError("No database selected. Please select a data source first.") + isGenerating = false + return@launch + } - service.streamPrompt( - userPrompt = "$systemPrompt\n\nUser: $text", - compileDevIns = false - ).collect { chunk -> - response.append(chunk) - renderer.renderLLMResponseChunk(chunk) + val databaseConfig = dataSource.toDatabaseConfig() + val projectPath = workspace?.rootPath ?: "." + val mcpConfigService = McpToolConfigService(ToolConfigFile()) + + val agent = ChatDBAgent( + projectPath = projectPath, + llmService = service, + databaseConfig = databaseConfig, + maxIterations = 10, + renderer = renderer, + mcpToolConfigService = mcpConfigService, + enableLLMStreaming = true + ) + + val task = ChatDBTask( + query = text, + maxRows = 100, + generateVisualization = false + ) + + val result = agent.execute(task) { progress -> + // Progress callback - can be used for UI updates + println("[ChatDB] Progress: $progress") } - renderer.renderLLMResponseEnd() + if (!result.success) { + renderer.renderError("Query failed: ${result.content}") + } - // Try to extract and execute SQL if present - extractAndExecuteSQL(response.toString()) + // Close the agent connection + agent.close() } catch (e: CancellationException) { renderer.forceStop() @@ -286,43 +310,6 @@ class ChatDBViewModel( } } - private fun buildSystemPrompt(schemaContext: String): String { - return """You are a helpful SQL assistant. You help users write SQL queries based on their natural language questions. - -## Database Schema -$schemaContext - -## Instructions -1. Analyze the user's question and understand what data they need -2. Generate the appropriate SQL query -3. Wrap SQL queries in ```sql code blocks -4. Explain your query briefly - -## Rules -- Only generate SELECT queries for safety -- Always use proper table and column names from the schema -- If you're unsure about the schema, ask for clarification -""" - } - - private suspend fun extractAndExecuteSQL(response: String) { - val sqlPattern = Regex("```sql\\n([\\s\\S]*?)```", RegexOption.IGNORE_CASE) - val match = sqlPattern.find(response) - - if (match != null && currentConnection != null) { - val sql = match.groupValues[1].trim() - try { - val result = currentConnection!!.executeQuery(sql) - // Display query result as a new message - renderer.renderLLMResponseStart() - renderer.renderLLMResponseChunk("\n\n**Query Result:**\n```\n${result.toTableString()}\n```") - renderer.renderLLMResponseEnd() - } catch (e: Exception) { - renderer.renderError("Query Error: ${e.message}") - } - } - } - fun stopGeneration() { currentExecutionJob?.cancel() isGenerating = false From 668ab073d285b5f2244f68694ef63f569ba9cba8 Mon Sep 17 00:00:00 2001 From: Phodal Huang Date: Wed, 10 Dec 2025 13:08:02 +0800 Subject: [PATCH 19/34] refactor(chatdb): remove DataSourceConfigDialog component #508 Deleted unused DataSourceConfigDialog to clean up codebase. --- .../components/DataSourceConfigDialog.kt | 222 ------------------ 1 file changed, 222 deletions(-) delete mode 100644 mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/agent/chatdb/components/DataSourceConfigDialog.kt diff --git a/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/agent/chatdb/components/DataSourceConfigDialog.kt b/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/agent/chatdb/components/DataSourceConfigDialog.kt deleted file mode 100644 index dd5a602402..0000000000 --- a/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/agent/chatdb/components/DataSourceConfigDialog.kt +++ /dev/null @@ -1,222 +0,0 @@ -package cc.unitmesh.devins.ui.compose.agent.chatdb.components - -import androidx.compose.foundation.layout.* -import androidx.compose.foundation.rememberScrollState -import androidx.compose.foundation.shape.RoundedCornerShape -import androidx.compose.foundation.text.KeyboardOptions -import androidx.compose.foundation.verticalScroll -import androidx.compose.material3.* -import androidx.compose.runtime.* -import androidx.compose.ui.Alignment -import androidx.compose.ui.Modifier -import androidx.compose.ui.text.input.KeyboardType -import androidx.compose.ui.text.input.PasswordVisualTransformation -import androidx.compose.ui.text.input.VisualTransformation -import androidx.compose.ui.unit.dp -import androidx.compose.ui.window.Dialog -import cc.unitmesh.devins.ui.compose.agent.chatdb.model.* -import cc.unitmesh.devins.ui.compose.icons.AutoDevComposeIcons - -/** - * Dialog for adding/editing data source configuration - */ -@OptIn(ExperimentalMaterial3Api::class) -@Composable -fun DataSourceConfigDialog( - existingConfig: DataSourceConfig?, - onDismiss: () -> Unit, - onSave: (DataSourceConfig) -> Unit -) { - var name by remember { mutableStateOf(existingConfig?.name ?: "") } - var dialect by remember { mutableStateOf(existingConfig?.dialect ?: DatabaseDialect.MYSQL) } - var host by remember { mutableStateOf(existingConfig?.host ?: "localhost") } - var port by remember { mutableStateOf(existingConfig?.port?.toString() ?: "3306") } - var database by remember { mutableStateOf(existingConfig?.database ?: "") } - var username by remember { mutableStateOf(existingConfig?.username ?: "") } - var password by remember { mutableStateOf(existingConfig?.password ?: "") } - var description by remember { mutableStateOf(existingConfig?.description ?: "") } - var showPassword by remember { mutableStateOf(false) } - var dialectExpanded by remember { mutableStateOf(false) } - - val isEditing = existingConfig != null - val scrollState = rememberScrollState() - - Dialog(onDismissRequest = onDismiss) { - Card( - modifier = Modifier - .fillMaxWidth() - .padding(16.dp), - shape = RoundedCornerShape(16.dp) - ) { - Column( - modifier = Modifier - .padding(24.dp) - .verticalScroll(scrollState) - ) { - Text( - text = if (isEditing) "Edit Data Source" else "Add Data Source", - style = MaterialTheme.typography.headlineSmall - ) - - Spacer(modifier = Modifier.height(24.dp)) - - // Name - OutlinedTextField( - value = name, - onValueChange = { name = it }, - label = { Text("Name *") }, - modifier = Modifier.fillMaxWidth(), - singleLine = true - ) - - Spacer(modifier = Modifier.height(16.dp)) - - // Dialect dropdown - ExposedDropdownMenuBox( - expanded = dialectExpanded, - onExpandedChange = { dialectExpanded = it } - ) { - OutlinedTextField( - value = dialect.displayName, - onValueChange = {}, - readOnly = true, - label = { Text("Database Type") }, - trailingIcon = { ExposedDropdownMenuDefaults.TrailingIcon(expanded = dialectExpanded) }, - modifier = Modifier.fillMaxWidth().menuAnchor() - ) - ExposedDropdownMenu( - expanded = dialectExpanded, - onDismissRequest = { dialectExpanded = false } - ) { - DatabaseDialect.entries.forEach { option -> - DropdownMenuItem( - text = { Text(option.displayName) }, - onClick = { - dialect = option - port = option.defaultPort.toString() - dialectExpanded = false - } - ) - } - } - } - - Spacer(modifier = Modifier.height(16.dp)) - - // Host and Port - if (dialect != DatabaseDialect.SQLITE) { - Row( - modifier = Modifier.fillMaxWidth(), - horizontalArrangement = Arrangement.spacedBy(12.dp) - ) { - OutlinedTextField( - value = host, - onValueChange = { host = it }, - label = { Text("Host *") }, - modifier = Modifier.weight(2f), - singleLine = true - ) - OutlinedTextField( - value = port, - onValueChange = { port = it.filter { c -> c.isDigit() } }, - label = { Text("Port *") }, - modifier = Modifier.weight(1f), - singleLine = true, - keyboardOptions = KeyboardOptions(keyboardType = KeyboardType.Number) - ) - } - Spacer(modifier = Modifier.height(16.dp)) - } - - // Database name - OutlinedTextField( - value = database, - onValueChange = { database = it }, - label = { Text(if (dialect == DatabaseDialect.SQLITE) "File Path *" else "Database *") }, - modifier = Modifier.fillMaxWidth(), - singleLine = true - ) - - if (dialect != DatabaseDialect.SQLITE) { - Spacer(modifier = Modifier.height(16.dp)) - - // Username - OutlinedTextField( - value = username, - onValueChange = { username = it }, - label = { Text("Username") }, - modifier = Modifier.fillMaxWidth(), - singleLine = true - ) - - Spacer(modifier = Modifier.height(16.dp)) - - // Password - OutlinedTextField( - value = password, - onValueChange = { password = it }, - label = { Text("Password") }, - modifier = Modifier.fillMaxWidth(), - singleLine = true, - visualTransformation = if (showPassword) VisualTransformation.None else PasswordVisualTransformation(), - trailingIcon = { - IconButton(onClick = { showPassword = !showPassword }) { - Icon( - if (showPassword) AutoDevComposeIcons.VisibilityOff else AutoDevComposeIcons.Visibility, - contentDescription = if (showPassword) "Hide" else "Show" - ) - } - } - ) - } - - Spacer(modifier = Modifier.height(16.dp)) - - // Description - OutlinedTextField( - value = description, - onValueChange = { description = it }, - label = { Text("Description") }, - modifier = Modifier.fillMaxWidth(), - minLines = 2, - maxLines = 3 - ) - - Spacer(modifier = Modifier.height(24.dp)) - - // Buttons - Row( - modifier = Modifier.fillMaxWidth(), - horizontalArrangement = Arrangement.End, - verticalAlignment = Alignment.CenterVertically - ) { - TextButton(onClick = onDismiss) { - Text("Cancel") - } - Spacer(modifier = Modifier.width(8.dp)) - Button( - onClick = { - val config = DataSourceConfig( - id = existingConfig?.id ?: "", - name = name.trim(), - dialect = dialect, - host = host.trim(), - port = port.toIntOrNull() ?: dialect.defaultPort, - database = database.trim(), - username = username.trim(), - password = password, - description = description.trim() - ) - onSave(config) - }, - enabled = name.isNotBlank() && database.isNotBlank() && - (dialect == DatabaseDialect.SQLITE || host.isNotBlank()) - ) { - Text(if (isEditing) "Save" else "Add") - } - } - } - } - } -} - From fe33fe5dd6fabd0274d8a81f33688ed1c1e41306 Mon Sep 17 00:00:00 2001 From: Phodal Huang Date: Wed, 10 Dec 2025 13:21:58 +0800 Subject: [PATCH 20/34] feat(chatdb): add persistent data source repository #508 Implement cross-platform DataSourceRepository for saving, loading, and managing database connection configurations. Refactor ChatDBViewModel to use repository for persistence. Update DataSourceConfigPane for improved UI. --- .../devins/db/DataSourceRepository.kt | 53 ++++ .../compose/agent/chatdb/ChatDBViewModel.kt | 67 ++++- .../chatdb/components/DataSourceConfigPane.kt | 250 +++++++++++------- .../cc/unitmesh/devins/db/DataSource.sq | 40 +++ .../devins/db/DataSourceRepository.ios.kt | 83 ++++++ .../devins/db/DataSourceRepository.js.kt | 47 ++++ .../devins/db/DataSourceRepository.jvm.kt | 85 ++++++ .../devins/db/DataSourceRepository.kt | 188 +++++++++++++ 8 files changed, 706 insertions(+), 107 deletions(-) create mode 100644 mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/db/DataSourceRepository.kt create mode 100644 mpp-ui/src/commonMain/sqldelight/cc/unitmesh/devins/db/DataSource.sq create mode 100644 mpp-ui/src/iosMain/kotlin/cc/unitmesh/devins/db/DataSourceRepository.ios.kt create mode 100644 mpp-ui/src/jsMain/kotlin/cc/unitmesh/devins/db/DataSourceRepository.js.kt create mode 100644 mpp-ui/src/jvmMain/kotlin/cc/unitmesh/devins/db/DataSourceRepository.jvm.kt create mode 100644 mpp-ui/src/wasmJsMain/kotlin/cc/unitmesh/devins/db/DataSourceRepository.kt diff --git a/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/db/DataSourceRepository.kt b/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/db/DataSourceRepository.kt new file mode 100644 index 0000000000..de1d57c955 --- /dev/null +++ b/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/db/DataSourceRepository.kt @@ -0,0 +1,53 @@ +package cc.unitmesh.devins.db + +import cc.unitmesh.devins.ui.compose.agent.chatdb.model.DataSourceConfig +import cc.unitmesh.devins.ui.compose.agent.chatdb.model.DatabaseDialect + +/** + * DataSource Repository - Data access layer for database connection configurations + * Uses expect/actual pattern for cross-platform support + */ +expect class DataSourceRepository { + /** + * Get all data source configurations + */ + fun getAll(): List + + /** + * Get a data source by ID + */ + fun getById(id: String): DataSourceConfig? + + /** + * Get the default data source + */ + fun getDefault(): DataSourceConfig? + + /** + * Save a data source configuration (insert or update) + */ + fun save(config: DataSourceConfig) + + /** + * Delete a data source by ID + */ + fun delete(id: String) + + /** + * Delete all data sources + */ + fun deleteAll() + + /** + * Set a data source as default + */ + fun setDefault(id: String) + + companion object { + /** + * Get singleton instance + */ + fun getInstance(): DataSourceRepository + } +} + diff --git a/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/agent/chatdb/ChatDBViewModel.kt b/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/agent/chatdb/ChatDBViewModel.kt index 73b5d14a93..7f271c41aa 100644 --- a/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/agent/chatdb/ChatDBViewModel.kt +++ b/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/agent/chatdb/ChatDBViewModel.kt @@ -11,6 +11,7 @@ import cc.unitmesh.agent.database.DatabaseConnection import cc.unitmesh.agent.database.DatabaseSchema import cc.unitmesh.agent.database.createDatabaseConnection import cc.unitmesh.config.ConfigManager +import cc.unitmesh.devins.db.DataSourceRepository import cc.unitmesh.devins.ui.compose.agent.ComposeRenderer import cc.unitmesh.devins.ui.compose.agent.chatdb.model.* import cc.unitmesh.devins.workspace.Workspace @@ -41,6 +42,11 @@ class ChatDBViewModel( private var currentConnection: DatabaseConnection? = null private var currentSchema: DatabaseSchema? = null + // Data source repository for persistence + private val dataSourceRepository: DataSourceRepository by lazy { + DataSourceRepository.getInstance() + } + // UI State var state by mutableStateOf(ChatDBState()) private set @@ -72,9 +78,20 @@ class ChatDBViewModel( } private fun loadDataSources() { - // TODO: Load from persistent storage - // For now, use empty list - state = state.copy(dataSources = emptyList()) + scope.launch { + try { + val dataSources = dataSourceRepository.getAll() + val defaultDataSource = dataSourceRepository.getDefault() + state = state.copy( + dataSources = dataSources, + selectedDataSourceId = defaultDataSource?.id + ) + println("[ChatDB] Loaded ${dataSources.size} data sources") + } catch (e: Exception) { + println("[ChatDB] Failed to load data sources: ${e.message}") + state = state.copy(dataSources = emptyList()) + } + } } @OptIn(ExperimentalUuidApi::class) @@ -89,7 +106,7 @@ class ChatDBViewModel( isConfigDialogOpen = false, editingDataSource = null ) - saveDataSources() + saveDataSource(newConfig) } fun updateDataSource(config: DataSourceConfig) { @@ -101,7 +118,7 @@ class ChatDBViewModel( isConfigDialogOpen = false, editingDataSource = null ) - saveDataSources() + saveDataSource(updated) } fun deleteDataSource(id: String) { @@ -112,11 +129,43 @@ class ChatDBViewModel( if (state.selectedDataSourceId == id) { disconnect() } - saveDataSources() + deleteDataSourceFromRepository(id) } private fun saveDataSources() { - // TODO: Persist to storage + scope.launch { + try { + // Save all data sources to repository + state.dataSources.forEach { config -> + dataSourceRepository.save(config) + } + println("[ChatDB] Saved ${state.dataSources.size} data sources") + } catch (e: Exception) { + println("[ChatDB] Failed to save data sources: ${e.message}") + } + } + } + + private fun saveDataSource(config: DataSourceConfig) { + scope.launch { + try { + dataSourceRepository.save(config) + println("[ChatDB] Saved data source: ${config.id}") + } catch (e: Exception) { + println("[ChatDB] Failed to save data source: ${e.message}") + } + } + } + + private fun deleteDataSourceFromRepository(id: String) { + scope.launch { + try { + dataSourceRepository.delete(id) + println("[ChatDB] Deleted data source: $id") + } catch (e: Exception) { + println("[ChatDB] Failed to delete data source: ${e.message}") + } + } } fun selectDataSource(id: String) { @@ -175,7 +224,7 @@ class ChatDBViewModel( configuringDataSource = null ) } - saveDataSources() + saveDataSource(savedConfig) } @OptIn(ExperimentalUuidApi::class) @@ -202,7 +251,7 @@ class ChatDBViewModel( selectedDataSourceId = savedConfig.id ) } - saveDataSources() + saveDataSource(savedConfig) // Trigger connection connect() diff --git a/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/agent/chatdb/components/DataSourceConfigPane.kt b/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/agent/chatdb/components/DataSourceConfigPane.kt index e6a4951924..d3a4699d5f 100644 --- a/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/agent/chatdb/components/DataSourceConfigPane.kt +++ b/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/agent/chatdb/components/DataSourceConfigPane.kt @@ -1,9 +1,12 @@ package cc.unitmesh.devins.ui.compose.agent.chatdb.components +import androidx.compose.animation.AnimatedVisibility +import androidx.compose.animation.expandVertically +import androidx.compose.animation.shrinkVertically import androidx.compose.foundation.background +import androidx.compose.foundation.clickable import androidx.compose.foundation.layout.* import androidx.compose.foundation.rememberScrollState -import androidx.compose.foundation.shape.RoundedCornerShape import androidx.compose.foundation.text.KeyboardOptions import androidx.compose.foundation.verticalScroll import androidx.compose.material3.* @@ -40,6 +43,7 @@ fun DataSourceConfigPane( var description by remember(existingConfig) { mutableStateOf(existingConfig?.description ?: "") } var showPassword by remember { mutableStateOf(false) } var dialectExpanded by remember { mutableStateOf(false) } + var advancedExpanded by remember { mutableStateOf(existingConfig?.description?.isNotBlank() == true) } val isEditing = existingConfig != null val scrollState = rememberScrollState() @@ -72,65 +76,72 @@ fun DataSourceConfigPane( Row( modifier = Modifier .fillMaxWidth() - .padding(16.dp), + .padding(horizontal = 12.dp, vertical = 8.dp), horizontalArrangement = Arrangement.SpaceBetween, verticalAlignment = Alignment.CenterVertically ) { Text( text = if (isEditing) "Edit Data Source" else "Add Data Source", - style = MaterialTheme.typography.titleMedium + style = MaterialTheme.typography.titleSmall ) - IconButton(onClick = onCancel) { - Icon(AutoDevComposeIcons.Close, contentDescription = "Close") + IconButton(onClick = onCancel, modifier = Modifier.size(32.dp)) { + Icon(AutoDevComposeIcons.Close, contentDescription = "Close", modifier = Modifier.size(18.dp)) } } } HorizontalDivider() - // Form content + // Form content - more compact Column( modifier = Modifier .weight(1f) .verticalScroll(scrollState) - .padding(16.dp), - verticalArrangement = Arrangement.spacedBy(16.dp) + .padding(12.dp), + verticalArrangement = Arrangement.spacedBy(8.dp) ) { - // Name - OutlinedTextField( - value = name, - onValueChange = { name = it }, - label = { Text("Name *") }, + // Name and Database Type in one row + Row( modifier = Modifier.fillMaxWidth(), - singleLine = true - ) - - // Dialect dropdown - ExposedDropdownMenuBox( - expanded = dialectExpanded, - onExpandedChange = { dialectExpanded = it } + horizontalArrangement = Arrangement.spacedBy(8.dp) ) { OutlinedTextField( - value = dialect.displayName, - onValueChange = {}, - readOnly = true, - label = { Text("Database Type") }, - trailingIcon = { ExposedDropdownMenuDefaults.TrailingIcon(expanded = dialectExpanded) }, - modifier = Modifier.fillMaxWidth().menuAnchor() + value = name, + onValueChange = { name = it }, + label = { Text("Name *") }, + modifier = Modifier.weight(1f), + singleLine = true, + textStyle = MaterialTheme.typography.bodySmall ) - ExposedDropdownMenu( + ExposedDropdownMenuBox( expanded = dialectExpanded, - onDismissRequest = { dialectExpanded = false } + onExpandedChange = { dialectExpanded = it }, + modifier = Modifier.weight(1f) ) { - DatabaseDialect.entries.forEach { option -> - DropdownMenuItem( - text = { Text(option.displayName) }, - onClick = { - dialect = option - port = option.defaultPort.toString() - dialectExpanded = false - } - ) + OutlinedTextField( + value = dialect.displayName, + onValueChange = {}, + readOnly = true, + label = { Text("Type") }, + trailingIcon = { ExposedDropdownMenuDefaults.TrailingIcon(expanded = dialectExpanded) }, + modifier = Modifier.fillMaxWidth().menuAnchor(), + singleLine = true, + textStyle = MaterialTheme.typography.bodySmall + ) + ExposedDropdownMenu( + expanded = dialectExpanded, + onDismissRequest = { dialectExpanded = false } + ) { + DatabaseDialect.entries.forEach { option -> + DropdownMenuItem( + text = { Text(option.displayName, style = MaterialTheme.typography.bodySmall) }, + onClick = { + dialect = option + port = option.defaultPort.toString() + dialectExpanded = false + } + ) + } } } } @@ -139,22 +150,24 @@ fun DataSourceConfigPane( if (dialect != DatabaseDialect.SQLITE) { Row( modifier = Modifier.fillMaxWidth(), - horizontalArrangement = Arrangement.spacedBy(12.dp) + horizontalArrangement = Arrangement.spacedBy(8.dp) ) { OutlinedTextField( value = host, onValueChange = { host = it }, label = { Text("Host *") }, modifier = Modifier.weight(2f), - singleLine = true + singleLine = true, + textStyle = MaterialTheme.typography.bodySmall ) OutlinedTextField( value = port, onValueChange = { port = it.filter { c -> c.isDigit() } }, - label = { Text("Port *") }, + label = { Text("Port") }, modifier = Modifier.weight(1f), singleLine = true, - keyboardOptions = KeyboardOptions(keyboardType = KeyboardType.Number) + keyboardOptions = KeyboardOptions(keyboardType = KeyboardType.Number), + textStyle = MaterialTheme.typography.bodySmall ) } } @@ -165,77 +178,118 @@ fun DataSourceConfigPane( onValueChange = { database = it }, label = { Text(if (dialect == DatabaseDialect.SQLITE) "File Path *" else "Database *") }, modifier = Modifier.fillMaxWidth(), - singleLine = true + singleLine = true, + textStyle = MaterialTheme.typography.bodySmall ) - // Username and Password (not for SQLite) + // Username and Password in one row (not for SQLite) if (dialect != DatabaseDialect.SQLITE) { - OutlinedTextField( - value = username, - onValueChange = { username = it }, - label = { Text("Username") }, + Row( modifier = Modifier.fillMaxWidth(), - singleLine = true - ) + horizontalArrangement = Arrangement.spacedBy(8.dp) + ) { + OutlinedTextField( + value = username, + onValueChange = { username = it }, + label = { Text("Username") }, + modifier = Modifier.weight(1f), + singleLine = true, + textStyle = MaterialTheme.typography.bodySmall + ) + OutlinedTextField( + value = password, + onValueChange = { password = it }, + label = { Text("Password") }, + modifier = Modifier.weight(1f), + singleLine = true, + visualTransformation = if (showPassword) VisualTransformation.None else PasswordVisualTransformation(), + trailingIcon = { + IconButton(onClick = { showPassword = !showPassword }, modifier = Modifier.size(24.dp)) { + Icon( + if (showPassword) AutoDevComposeIcons.VisibilityOff else AutoDevComposeIcons.Visibility, + contentDescription = if (showPassword) "Hide" else "Show", + modifier = Modifier.size(16.dp) + ) + } + }, + textStyle = MaterialTheme.typography.bodySmall + ) + } + } - OutlinedTextField( - value = password, - onValueChange = { password = it }, - label = { Text("Password") }, - modifier = Modifier.fillMaxWidth(), - singleLine = true, - visualTransformation = if (showPassword) VisualTransformation.None else PasswordVisualTransformation(), - trailingIcon = { - IconButton(onClick = { showPassword = !showPassword }) { - Icon( - if (showPassword) AutoDevComposeIcons.VisibilityOff else AutoDevComposeIcons.Visibility, - contentDescription = if (showPassword) "Hide" else "Show" - ) - } - } + Spacer(modifier = Modifier.height(4.dp)) + + // Advanced Settings - collapsible + Row( + modifier = Modifier + .fillMaxWidth() + .clickable { advancedExpanded = !advancedExpanded } + .padding(vertical = 4.dp), + verticalAlignment = Alignment.CenterVertically, + horizontalArrangement = Arrangement.spacedBy(4.dp) + ) { + Icon( + if (advancedExpanded) AutoDevComposeIcons.ExpandLess else AutoDevComposeIcons.ExpandMore, + contentDescription = if (advancedExpanded) "Collapse" else "Expand", + modifier = Modifier.size(18.dp), + tint = MaterialTheme.colorScheme.onSurfaceVariant + ) + Text( + text = "Advanced Settings", + style = MaterialTheme.typography.labelMedium, + color = MaterialTheme.colorScheme.onSurfaceVariant ) } - // Description - OutlinedTextField( - value = description, - onValueChange = { description = it }, - label = { Text("Description") }, - modifier = Modifier.fillMaxWidth(), - minLines = 2, - maxLines = 3 - ) + AnimatedVisibility( + visible = advancedExpanded, + enter = expandVertically(), + exit = shrinkVertically() + ) { + Column( + modifier = Modifier.padding(start = 4.dp), + verticalArrangement = Arrangement.spacedBy(8.dp) + ) { + OutlinedTextField( + value = description, + onValueChange = { description = it }, + label = { Text("Description") }, + modifier = Modifier.fillMaxWidth(), + minLines = 2, + maxLines = 3, + textStyle = MaterialTheme.typography.bodySmall + ) + } + } } HorizontalDivider() - // Action buttons - Surface( - modifier = Modifier.fillMaxWidth(), - color = MaterialTheme.colorScheme.surfaceVariant.copy(alpha = 0.3f) + // Action buttons - more compact + Row( + modifier = Modifier + .fillMaxWidth() + .background(MaterialTheme.colorScheme.surfaceVariant.copy(alpha = 0.3f)) + .padding(horizontal = 12.dp, vertical = 8.dp), + horizontalArrangement = Arrangement.spacedBy(8.dp, Alignment.End), + verticalAlignment = Alignment.CenterVertically ) { - Row( - modifier = Modifier - .fillMaxWidth() - .padding(16.dp), - horizontalArrangement = Arrangement.spacedBy(8.dp, Alignment.End), - verticalAlignment = Alignment.CenterVertically + TextButton(onClick = onCancel, modifier = Modifier.height(36.dp)) { + Text("Cancel", style = MaterialTheme.typography.labelMedium) + } + OutlinedButton( + onClick = { onSave(buildConfig()) }, + enabled = isValid, + modifier = Modifier.height(36.dp) ) { - TextButton(onClick = onCancel) { - Text("Cancel") - } - OutlinedButton( - onClick = { onSave(buildConfig()) }, - enabled = isValid - ) { - Text("Save") - } - Button( - onClick = { onSaveAndConnect(buildConfig()) }, - enabled = isValid - ) { - Text("Save & Connect") - } + Text("Save", style = MaterialTheme.typography.labelMedium) + } + Button( + onClick = { onSaveAndConnect(buildConfig()) }, + enabled = isValid, + modifier = Modifier.height(36.dp) + ) { + Text("Save & Connect", style = MaterialTheme.typography.labelMedium) } } } diff --git a/mpp-ui/src/commonMain/sqldelight/cc/unitmesh/devins/db/DataSource.sq b/mpp-ui/src/commonMain/sqldelight/cc/unitmesh/devins/db/DataSource.sq new file mode 100644 index 0000000000..2f6f86ccc7 --- /dev/null +++ b/mpp-ui/src/commonMain/sqldelight/cc/unitmesh/devins/db/DataSource.sq @@ -0,0 +1,40 @@ +CREATE TABLE IF NOT EXISTS DataSource ( + id TEXT NOT NULL PRIMARY KEY, + name TEXT NOT NULL, + dialect TEXT NOT NULL, + host TEXT NOT NULL, + port INTEGER NOT NULL, + databaseName TEXT NOT NULL, + username TEXT NOT NULL, + password TEXT NOT NULL, + description TEXT NOT NULL, + isDefault INTEGER NOT NULL DEFAULT 0, + createdAt INTEGER NOT NULL, + updatedAt INTEGER NOT NULL +); + +selectById: +SELECT * FROM DataSource WHERE id = ?; + +selectAll: +SELECT * FROM DataSource ORDER BY updatedAt DESC; + +selectDefault: +SELECT * FROM DataSource WHERE isDefault = 1 LIMIT 1; + +insertOrReplace: +INSERT OR REPLACE INTO DataSource(id, name, dialect, host, port, databaseName, username, password, description, isDefault, createdAt, updatedAt) +VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?); + +deleteById: +DELETE FROM DataSource WHERE id = ?; + +deleteAll: +DELETE FROM DataSource; + +clearDefault: +UPDATE DataSource SET isDefault = 0 WHERE isDefault = 1; + +setDefault: +UPDATE DataSource SET isDefault = 1 WHERE id = ?; + diff --git a/mpp-ui/src/iosMain/kotlin/cc/unitmesh/devins/db/DataSourceRepository.ios.kt b/mpp-ui/src/iosMain/kotlin/cc/unitmesh/devins/db/DataSourceRepository.ios.kt new file mode 100644 index 0000000000..afd4df8fed --- /dev/null +++ b/mpp-ui/src/iosMain/kotlin/cc/unitmesh/devins/db/DataSourceRepository.ios.kt @@ -0,0 +1,83 @@ +package cc.unitmesh.devins.db + +import cc.unitmesh.devins.ui.compose.agent.chatdb.model.DataSourceConfig +import cc.unitmesh.devins.ui.compose.agent.chatdb.model.DatabaseDialect + +/** + * DataSource Repository - iOS implementation + */ +actual class DataSourceRepository(private val database: DevInsDatabase) { + private val queries = database.dataSourceQueries + + actual fun getAll(): List { + return queries.selectAll().executeAsList().map { it.toDataSourceConfig() } + } + + actual fun getById(id: String): DataSourceConfig? { + return queries.selectById(id).executeAsOneOrNull()?.toDataSourceConfig() + } + + actual fun getDefault(): DataSourceConfig? { + return queries.selectDefault().executeAsOneOrNull()?.toDataSourceConfig() + } + + actual fun save(config: DataSourceConfig) { + queries.insertOrReplace( + id = config.id, + name = config.name, + dialect = config.dialect.name, + host = config.host, + port = config.port.toLong(), + databaseName = config.database, + username = config.username, + password = config.password, + description = config.description, + isDefault = if (config.isDefault) 1L else 0L, + createdAt = config.createdAt, + updatedAt = config.updatedAt + ) + } + + actual fun delete(id: String) { + queries.deleteById(id) + } + + actual fun deleteAll() { + queries.deleteAll() + } + + actual fun setDefault(id: String) { + queries.clearDefault() + queries.setDefault(id) + } + + private fun DataSource.toDataSourceConfig(): DataSourceConfig { + return DataSourceConfig( + id = this.id, + name = this.name, + dialect = DatabaseDialect.fromString(this.dialect), + host = this.host, + port = this.port.toInt(), + database = this.databaseName, + username = this.username, + password = this.password, + description = this.description, + isDefault = this.isDefault == 1L, + createdAt = this.createdAt, + updatedAt = this.updatedAt + ) + } + + actual companion object { + private var instance: DataSourceRepository? = null + + actual fun getInstance(): DataSourceRepository { + return instance ?: run { + val driverFactory = DatabaseDriverFactory() + val database = createDatabase(driverFactory) + DataSourceRepository(database).also { instance = it } + } + } + } +} + diff --git a/mpp-ui/src/jsMain/kotlin/cc/unitmesh/devins/db/DataSourceRepository.js.kt b/mpp-ui/src/jsMain/kotlin/cc/unitmesh/devins/db/DataSourceRepository.js.kt new file mode 100644 index 0000000000..5f4ca56cbe --- /dev/null +++ b/mpp-ui/src/jsMain/kotlin/cc/unitmesh/devins/db/DataSourceRepository.js.kt @@ -0,0 +1,47 @@ +package cc.unitmesh.devins.db + +import cc.unitmesh.devins.ui.compose.agent.chatdb.model.DataSourceConfig + +/** + * DataSource Repository - JS implementation + * Currently provides stub implementation, can be extended with localStorage or IndexedDB + */ +actual class DataSourceRepository { + actual fun getAll(): List { + console.warn("DataSourceRepository not implemented for JS platform") + return emptyList() + } + + actual fun getById(id: String): DataSourceConfig? { + return null + } + + actual fun getDefault(): DataSourceConfig? { + return null + } + + actual fun save(config: DataSourceConfig) { + // No-op + } + + actual fun delete(id: String) { + // No-op + } + + actual fun deleteAll() { + // No-op + } + + actual fun setDefault(id: String) { + // No-op + } + + actual companion object { + private var instance: DataSourceRepository? = null + + actual fun getInstance(): DataSourceRepository { + return instance ?: DataSourceRepository().also { instance = it } + } + } +} + diff --git a/mpp-ui/src/jvmMain/kotlin/cc/unitmesh/devins/db/DataSourceRepository.jvm.kt b/mpp-ui/src/jvmMain/kotlin/cc/unitmesh/devins/db/DataSourceRepository.jvm.kt new file mode 100644 index 0000000000..b517ca6cd1 --- /dev/null +++ b/mpp-ui/src/jvmMain/kotlin/cc/unitmesh/devins/db/DataSourceRepository.jvm.kt @@ -0,0 +1,85 @@ +package cc.unitmesh.devins.db + +import cc.unitmesh.devins.ui.compose.agent.chatdb.model.DataSourceConfig +import cc.unitmesh.devins.ui.compose.agent.chatdb.model.DatabaseDialect + +/** + * DataSource Repository - JVM implementation + */ +actual class DataSourceRepository(private val database: DevInsDatabase) { + private val queries = database.dataSourceQueries + + actual fun getAll(): List { + return queries.selectAll().executeAsList().map { it.toDataSourceConfig() } + } + + actual fun getById(id: String): DataSourceConfig? { + return queries.selectById(id).executeAsOneOrNull()?.toDataSourceConfig() + } + + actual fun getDefault(): DataSourceConfig? { + return queries.selectDefault().executeAsOneOrNull()?.toDataSourceConfig() + } + + actual fun save(config: DataSourceConfig) { + queries.insertOrReplace( + id = config.id, + name = config.name, + dialect = config.dialect.name, + host = config.host, + port = config.port.toLong(), + databaseName = config.database, + username = config.username, + password = config.password, + description = config.description, + isDefault = if (config.isDefault) 1L else 0L, + createdAt = config.createdAt, + updatedAt = config.updatedAt + ) + } + + actual fun delete(id: String) { + queries.deleteById(id) + } + + actual fun deleteAll() { + queries.deleteAll() + } + + actual fun setDefault(id: String) { + queries.clearDefault() + queries.setDefault(id) + } + + private fun DataSource.toDataSourceConfig(): DataSourceConfig { + return DataSourceConfig( + id = this.id, + name = this.name, + dialect = DatabaseDialect.fromString(this.dialect), + host = this.host, + port = this.port.toInt(), + database = this.databaseName, + username = this.username, + password = this.password, + description = this.description, + isDefault = this.isDefault == 1L, + createdAt = this.createdAt, + updatedAt = this.updatedAt + ) + } + + actual companion object { + private var instance: DataSourceRepository? = null + + actual fun getInstance(): DataSourceRepository { + return instance ?: synchronized(this) { + instance ?: run { + val driverFactory = DatabaseDriverFactory() + val database = createDatabase(driverFactory) + DataSourceRepository(database).also { instance = it } + } + } + } + } +} + diff --git a/mpp-ui/src/wasmJsMain/kotlin/cc/unitmesh/devins/db/DataSourceRepository.kt b/mpp-ui/src/wasmJsMain/kotlin/cc/unitmesh/devins/db/DataSourceRepository.kt new file mode 100644 index 0000000000..2f007ebd38 --- /dev/null +++ b/mpp-ui/src/wasmJsMain/kotlin/cc/unitmesh/devins/db/DataSourceRepository.kt @@ -0,0 +1,188 @@ +package cc.unitmesh.devins.db + +import cc.unitmesh.devins.ui.compose.agent.chatdb.model.DataSourceConfig +import cc.unitmesh.devins.ui.compose.agent.chatdb.model.DatabaseDialect +import cc.unitmesh.devins.ui.platform.BrowserStorage +import cc.unitmesh.devins.ui.platform.console +import kotlinx.serialization.Serializable +import kotlinx.serialization.encodeToString +import kotlinx.serialization.json.Json + +/** + * DataSource Repository for WASM platform + * Uses browser localStorage to store data source configurations + */ +actual class DataSourceRepository { + private val json = Json { + prettyPrint = true + ignoreUnknownKeys = true + } + + @Serializable + private data class StoredDataSource( + val id: String, + val name: String, + val dialect: String, + val host: String, + val port: Int, + val database: String, + val username: String, + val password: String, + val description: String, + val isDefault: Boolean, + val createdAt: Long, + val updatedAt: Long + ) + + @Serializable + private data class DataSourceStorage( + val dataSources: List + ) + + actual fun getAll(): List { + return try { + val storage = loadStorage() + storage.dataSources.map { it.toDataSourceConfig() } + } catch (e: Exception) { + console.error("WASM: Error loading data sources: ${e.message}") + emptyList() + } + } + + actual fun getById(id: String): DataSourceConfig? { + return try { + val storage = loadStorage() + storage.dataSources.firstOrNull { it.id == id }?.toDataSourceConfig() + } catch (e: Exception) { + console.error("WASM: Error getting data source by id: ${e.message}") + null + } + } + + actual fun getDefault(): DataSourceConfig? { + return try { + val storage = loadStorage() + storage.dataSources.firstOrNull { it.isDefault }?.toDataSourceConfig() + } catch (e: Exception) { + console.error("WASM: Error getting default data source: ${e.message}") + null + } + } + + actual fun save(config: DataSourceConfig) { + try { + val storage = loadStorage() + val stored = config.toStoredDataSource() + val existing = storage.dataSources.indexOfFirst { it.id == config.id } + val updatedList = if (existing >= 0) { + storage.dataSources.toMutableList().apply { set(existing, stored) } + } else { + storage.dataSources + stored + } + saveStorage(DataSourceStorage(updatedList)) + console.log("WASM: Data source saved: ${config.id}") + } catch (e: Exception) { + console.error("WASM: Error saving data source: ${e.message}") + } + } + + actual fun delete(id: String) { + try { + val storage = loadStorage() + val updatedList = storage.dataSources.filter { it.id != id } + saveStorage(DataSourceStorage(updatedList)) + console.log("WASM: Data source deleted: $id") + } catch (e: Exception) { + console.error("WASM: Error deleting data source: ${e.message}") + } + } + + actual fun deleteAll() { + try { + saveStorage(DataSourceStorage(emptyList())) + console.log("WASM: All data sources deleted") + } catch (e: Exception) { + console.error("WASM: Error deleting all data sources: ${e.message}") + } + } + + actual fun setDefault(id: String) { + try { + val storage = loadStorage() + val now = kotlinx.datetime.Clock.System.now().toEpochMilliseconds() + val updatedList = storage.dataSources.map { + it.copy( + isDefault = it.id == id, + updatedAt = if (it.id == id) now else it.updatedAt + ) + } + saveStorage(DataSourceStorage(updatedList)) + console.log("WASM: Default data source set to: $id") + } catch (e: Exception) { + console.error("WASM: Error setting default data source: ${e.message}") + } + } + + private fun loadStorage(): DataSourceStorage { + val content = BrowserStorage.getItem(STORAGE_KEY) + return if (content != null) { + try { + json.decodeFromString(content) + } catch (e: Exception) { + console.warn("WASM: Failed to parse data source storage: ${e.message}") + DataSourceStorage(emptyList()) + } + } else { + DataSourceStorage(emptyList()) + } + } + + private fun saveStorage(storage: DataSourceStorage) { + val content = json.encodeToString(storage) + BrowserStorage.setItem(STORAGE_KEY, content) + } + + private fun StoredDataSource.toDataSourceConfig(): DataSourceConfig { + return DataSourceConfig( + id = this.id, + name = this.name, + dialect = DatabaseDialect.fromString(this.dialect), + host = this.host, + port = this.port, + database = this.database, + username = this.username, + password = this.password, + description = this.description, + isDefault = this.isDefault, + createdAt = this.createdAt, + updatedAt = this.updatedAt + ) + } + + private fun DataSourceConfig.toStoredDataSource(): StoredDataSource { + return StoredDataSource( + id = this.id, + name = this.name, + dialect = this.dialect.name, + host = this.host, + port = this.port, + database = this.database, + username = this.username, + password = this.password, + description = this.description, + isDefault = this.isDefault, + createdAt = this.createdAt, + updatedAt = this.updatedAt + ) + } + + actual companion object { + private const val STORAGE_KEY = "autodev-datasources" + private var instance: DataSourceRepository? = null + + actual fun getInstance(): DataSourceRepository { + return instance ?: DataSourceRepository().also { instance = it } + } + } +} + From d7a8f04b486267af68e3ba23acbecb6ee4df50c7 Mon Sep 17 00:00:00 2001 From: Phodal Huang Date: Wed, 10 Dec 2025 13:22:56 +0800 Subject: [PATCH 21/34] refactor(chatdb): remove unused SchemaInfoDialog component #508 Deleted the SchemaInfoDialog composable as it is no longer used in ChatDBPage. --- .../ui/compose/agent/chatdb/ChatDBPage.kt | 44 ------------------- 1 file changed, 44 deletions(-) diff --git a/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/agent/chatdb/ChatDBPage.kt b/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/agent/chatdb/ChatDBPage.kt index 6cc633717b..ab2cd2b245 100644 --- a/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/agent/chatdb/ChatDBPage.kt +++ b/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/agent/chatdb/ChatDBPage.kt @@ -4,7 +4,6 @@ import androidx.compose.foundation.layout.* import androidx.compose.material3.* import androidx.compose.runtime.* import androidx.compose.ui.Modifier -import androidx.compose.ui.unit.dp import cc.unitmesh.devins.ui.base.ResizableSplitPane import cc.unitmesh.devins.ui.compose.agent.chatdb.components.* import cc.unitmesh.devins.workspace.Workspace @@ -101,46 +100,3 @@ fun ChatDBPage( ) } } - -@Composable -private fun SchemaInfoDialog( - schema: cc.unitmesh.agent.database.DatabaseSchema?, - onDismiss: () -> Unit -) { - AlertDialog( - onDismissRequest = onDismiss, - title = { Text("Database Schema") }, - text = { - if (schema != null) { - Column { - Text( - text = "${schema.tables.size} tables", - style = MaterialTheme.typography.labelMedium - ) - Spacer(modifier = Modifier.height(8.dp)) - schema.tables.take(10).forEach { table -> - Text( - text = "โ€ข ${table.name} (${table.columns.size} columns)", - style = MaterialTheme.typography.bodySmall - ) - } - if (schema.tables.size > 10) { - Text( - text = "... and ${schema.tables.size - 10} more", - style = MaterialTheme.typography.bodySmall, - color = MaterialTheme.colorScheme.onSurfaceVariant - ) - } - } - } else { - Text("No schema available") - } - }, - confirmButton = { - TextButton(onClick = onDismiss) { - Text("Close") - } - } - ) -} - From a9ad4f4a8cc773a7c6caf4746bb9806b1bfd919e Mon Sep 17 00:00:00 2001 From: Phodal Huang Date: Wed, 10 Dec 2025 13:29:13 +0800 Subject: [PATCH 22/34] feat(chatdb): render successful query results as messages #508 Render successful query results as assistant messages to support proper markdown table rendering in the timeline. --- .../devins/ui/compose/agent/chatdb/ChatDBViewModel.kt | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/agent/chatdb/ChatDBViewModel.kt b/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/agent/chatdb/ChatDBViewModel.kt index 7f271c41aa..456f4bcdb5 100644 --- a/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/agent/chatdb/ChatDBViewModel.kt +++ b/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/agent/chatdb/ChatDBViewModel.kt @@ -340,7 +340,13 @@ class ChatDBViewModel( println("[ChatDB] Progress: $progress") } - if (!result.success) { + // Render the result to the timeline + if (result.success) { + // Add the successful result as an assistant message to properly render markdown tables + renderer.renderLLMResponseStart() + renderer.renderLLMResponseChunk(result.content) + renderer.renderLLMResponseEnd() + } else { renderer.renderError("Query failed: ${result.content}") } From b8b3ce2e2e3179ece2a64922828e9380a622e361 Mon Sep 17 00:00:00 2001 From: Phodal Huang Date: Wed, 10 Dec 2025 15:32:26 +0800 Subject: [PATCH 23/34] feat(chatdb): render agent progress and results in timeline #508 Render all agent progress updates and final results directly to the chat timeline via the renderer, providing clearer feedback to users during query execution. Removes redundant result rendering from the view model. --- .../agent/chatdb/ChatDBAgentExecutor.kt | 82 ++++++++++++++++--- .../compose/agent/chatdb/ChatDBViewModel.kt | 13 +-- .../agent/chatdb/components/ChatDBChatPane.kt | 4 +- 3 files changed, 73 insertions(+), 26 deletions(-) diff --git a/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/chatdb/ChatDBAgentExecutor.kt b/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/chatdb/ChatDBAgentExecutor.kt index 788d56d98d..81814026f5 100644 --- a/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/chatdb/ChatDBAgentExecutor.kt +++ b/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/chatdb/ChatDBAgentExecutor.kt @@ -70,12 +70,22 @@ class ChatDBAgentExecutor( try { // Step 1: Get database schema - onProgress("๐Ÿ“Š Fetching database schema...") + val step1Message = "๐Ÿ“Š Fetching database schema..." + onProgress(step1Message) + renderer.renderLLMResponseStart() + renderer.renderLLMResponseChunk(step1Message) + renderer.renderLLMResponseEnd() + val schema = task.schema ?: databaseConnection.getSchema() logger.info { "Database has ${schema.tables.size} tables: ${schema.tables.map { it.name }}" } // Step 2: Schema Linking - onProgress("๐Ÿ”— Performing schema linking...") + val step2Message = "๐Ÿ”— Performing schema linking..." + onProgress(step2Message) + renderer.renderLLMResponseStart() + renderer.renderLLMResponseChunk(step2Message) + renderer.renderLLMResponseEnd() + val linkingResult = schemaLinker.link(task.query, schema) logger.info { "Schema linking found ${linkingResult.relevantTables.size} relevant tables: ${linkingResult.relevantTables}" } logger.info { "Schema linking keywords: ${linkingResult.keywords}" } @@ -91,9 +101,14 @@ class ChatDBAgentExecutor( val relevantSchema = buildRelevantSchemaDescription(schema, effectiveLinkingResult) val initialMessage = buildInitialUserMessage(task, relevantSchema, effectiveLinkingResult) - + // Step 4: Generate SQL with LLM - onProgress("๐Ÿค– Generating SQL query...") + val step4Message = "๐Ÿค– Generating SQL query..." + onProgress(step4Message) + renderer.renderLLMResponseStart() + renderer.renderLLMResponseChunk(step4Message + "\n\n") + renderer.renderLLMResponseEnd() + val llmResponse = StringBuilder() val response = getLLMResponse(initialMessage, compileDevIns = false) { chunk -> onProgress(chunk) @@ -123,7 +138,11 @@ class ChatDBAgentExecutor( if (!tableValidation.isValid) { val errorType = if (!syntaxValidation.isValid) "syntax" else "table name" - onProgress("๐Ÿ”„ SQL validation failed ($errorType), invoking SqlReviseAgent...") + val revisionMessage = "๐Ÿ”„ SQL validation failed ($errorType), invoking SqlReviseAgent..." + onProgress(revisionMessage) + renderer.renderLLMResponseStart() + renderer.renderLLMResponseChunk(revisionMessage + "\n\n") + renderer.renderLLMResponseEnd() val revisionInput = SqlRevisionInput( originalQuery = task.query, @@ -141,7 +160,11 @@ class ChatDBAgentExecutor( if (revisionResult.success) { validatedSql = revisionResult.content - onProgress("โœ… SQL revised successfully after $revisionAttempts attempts") + val successMessage = "โœ… SQL revised successfully after $revisionAttempts attempts" + onProgress(successMessage) + renderer.renderLLMResponseStart() + renderer.renderLLMResponseChunk(successMessage + "\n\n") + renderer.renderLLMResponseEnd() } else { errors.add("SQL revision failed: ${revisionResult.content}") } @@ -155,17 +178,30 @@ class ChatDBAgentExecutor( var lastExecutionError: String? = null while (executionRetries < maxExecutionRetries && queryResult == null) { - onProgress("โšก Executing SQL query${if (executionRetries > 0) " (retry $executionRetries)" else ""}...") + val executionMessage = "โšก Executing SQL query${if (executionRetries > 0) " (retry $executionRetries)" else ""}..." + onProgress(executionMessage) + renderer.renderLLMResponseStart() + renderer.renderLLMResponseChunk(executionMessage + "\n\n") + renderer.renderLLMResponseEnd() + try { queryResult = databaseConnection.executeQuery(generatedSql!!) - onProgress("โœ… Query returned ${queryResult.rowCount} rows") + val successMessage = "โœ… Query returned ${queryResult.rowCount} rows" + onProgress(successMessage) + renderer.renderLLMResponseStart() + renderer.renderLLMResponseChunk(successMessage + "\n\n") + renderer.renderLLMResponseEnd() } catch (e: Exception) { lastExecutionError = e.message ?: "Unknown execution error" logger.warn { "Query execution failed (attempt ${executionRetries + 1}): $lastExecutionError" } // Try to revise SQL based on execution error if (executionRetries < maxExecutionRetries - 1) { - onProgress("๐Ÿ”„ Execution failed, attempting to fix SQL...") + val retryMessage = "๐Ÿ”„ Execution failed, attempting to fix SQL..." + onProgress(retryMessage) + renderer.renderLLMResponseStart() + renderer.renderLLMResponseChunk(retryMessage + "\n\n") + renderer.renderLLMResponseEnd() val revisionInput = SqlRevisionInput( originalQuery = task.query, @@ -182,7 +218,11 @@ class ChatDBAgentExecutor( if (revisionResult.success && revisionResult.content != generatedSql) { generatedSql = revisionResult.content revisionAttempts++ - onProgress("๐Ÿ”ง SQL revised, retrying execution...") + val revisedMessage = "๐Ÿ”ง SQL revised, retrying execution..." + onProgress(revisedMessage) + renderer.renderLLMResponseStart() + renderer.renderLLMResponseChunk(revisedMessage + "\n\n") + renderer.renderLLMResponseEnd() } else { // Revision didn't help, break the loop break @@ -200,7 +240,12 @@ class ChatDBAgentExecutor( // Step 8: Generate visualization if requested if (task.generateVisualization && queryResult != null && !queryResult.isEmpty()) { - onProgress("๐Ÿ“ˆ Generating visualization...") + val vizMessage = "๐Ÿ“ˆ Generating visualization..." + onProgress(vizMessage) + renderer.renderLLMResponseStart() + renderer.renderLLMResponseChunk(vizMessage + "\n\n") + renderer.renderLLMResponseEnd() + plotDslCode = generateVisualization(task.query, queryResult, onProgress) } @@ -208,8 +253,8 @@ class ChatDBAgentExecutor( logger.error(e) { "ChatDB execution failed" } errors.add("Execution failed: ${e.message}") } - - return buildResult( + + val result = buildResult( success = errors.isEmpty() && queryResult != null, errors = errors, generatedSql = generatedSql, @@ -217,6 +262,17 @@ class ChatDBAgentExecutor( plotDslCode = plotDslCode, revisionAttempts = revisionAttempts ) + + // Render the final result to the timeline + if (result.success) { + renderer.renderLLMResponseStart() + renderer.renderLLMResponseChunk(result.message) + renderer.renderLLMResponseEnd() + } else { + renderer.renderError(result.message) + } + + return result } private fun resetExecution() { diff --git a/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/agent/chatdb/ChatDBViewModel.kt b/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/agent/chatdb/ChatDBViewModel.kt index 456f4bcdb5..bd4e59b6ef 100644 --- a/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/agent/chatdb/ChatDBViewModel.kt +++ b/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/agent/chatdb/ChatDBViewModel.kt @@ -335,21 +335,12 @@ class ChatDBViewModel( generateVisualization = false ) - val result = agent.execute(task) { progress -> + // Execute the agent - it will render results to the timeline via the renderer + agent.execute(task) { progress -> // Progress callback - can be used for UI updates println("[ChatDB] Progress: $progress") } - // Render the result to the timeline - if (result.success) { - // Add the successful result as an assistant message to properly render markdown tables - renderer.renderLLMResponseStart() - renderer.renderLLMResponseChunk(result.content) - renderer.renderLLMResponseEnd() - } else { - renderer.renderError("Query failed: ${result.content}") - } - // Close the agent connection agent.close() diff --git a/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/agent/chatdb/components/ChatDBChatPane.kt b/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/agent/chatdb/components/ChatDBChatPane.kt index 4bcfbf6492..dc2e22ec7b 100644 --- a/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/agent/chatdb/components/ChatDBChatPane.kt +++ b/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/agent/chatdb/components/ChatDBChatPane.kt @@ -54,8 +54,8 @@ fun ChatDBChatPane( modifier = Modifier.fillMaxSize() ) - // Welcome message when no messages - if (renderer.timeline.isEmpty()) { + // Welcome message when no messages and not streaming + if (renderer.timeline.isEmpty() && renderer.currentStreamingOutput.isEmpty() && !renderer.isProcessing) { WelcomeMessage( isConnected = connectionStatus is ConnectionStatus.Connected, schema = schema, From 9515c4971a9ec56f6ed02955d67893920c604ace Mon Sep 17 00:00:00 2001 From: Phodal Huang Date: Wed, 10 Dec 2025 15:42:34 +0800 Subject: [PATCH 24/34] feat(chatdb): enhance LLM response rendering with details #508 Add detailed rendering for schema, linking, SQL generation, validation, execution, and errors in LLM responses for improved clarity and debugging. --- .../agent/chatdb/ChatDBAgentExecutor.kt | 137 +++++++++++++++--- 1 file changed, 119 insertions(+), 18 deletions(-) diff --git a/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/chatdb/ChatDBAgentExecutor.kt b/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/chatdb/ChatDBAgentExecutor.kt index 81814026f5..1d8f1c7037 100644 --- a/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/chatdb/ChatDBAgentExecutor.kt +++ b/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/chatdb/ChatDBAgentExecutor.kt @@ -79,6 +79,16 @@ class ChatDBAgentExecutor( val schema = task.schema ?: databaseConnection.getSchema() logger.info { "Database has ${schema.tables.size} tables: ${schema.tables.map { it.name }}" } + // Render schema details + val schemaDetails = buildString { + append("**Database Schema:**\n") + append("- Total tables: ${schema.tables.size}\n") + append("- Tables: ${schema.tables.joinToString(", ") { it.name }}") + } + renderer.renderLLMResponseStart() + renderer.renderLLMResponseChunk(schemaDetails) + renderer.renderLLMResponseEnd() + // Step 2: Schema Linking val step2Message = "๐Ÿ”— Performing schema linking..." onProgress(step2Message) @@ -90,6 +100,16 @@ class ChatDBAgentExecutor( logger.info { "Schema linking found ${linkingResult.relevantTables.size} relevant tables: ${linkingResult.relevantTables}" } logger.info { "Schema linking keywords: ${linkingResult.keywords}" } + // Render schema linking details + val linkingDetails = buildString { + append("**Schema Linking Results:**\n") + append("- Relevant tables: ${linkingResult.relevantTables.joinToString(", ")}\n") + append("- Keywords: ${linkingResult.keywords.joinToString(", ")}") + } + renderer.renderLLMResponseStart() + renderer.renderLLMResponseChunk(linkingDetails) + renderer.renderLLMResponseEnd() + // Step 3: Build context with relevant schema // If schema linking found too few tables, use all tables to avoid missing important ones val effectiveLinkingResult = if (linkingResult.relevantTables.size < 2 && schema.tables.size <= 10) { @@ -122,6 +142,17 @@ class ChatDBAgentExecutor( return buildResult(false, errors, null, null, null, 0) } + // Render extracted SQL + val extractedSqlMessage = buildString { + append("**Generated SQL:**\n") + append("```sql\n") + append(generatedSql) + append("\n```") + } + renderer.renderLLMResponseStart() + renderer.renderLLMResponseChunk(extractedSqlMessage) + renderer.renderLLMResponseEnd() + // Step 6: Validate SQL syntax and table names using SqlReviseAgent var validatedSql = generatedSql @@ -138,10 +169,18 @@ class ChatDBAgentExecutor( if (!tableValidation.isValid) { val errorType = if (!syntaxValidation.isValid) "syntax" else "table name" - val revisionMessage = "๐Ÿ”„ SQL validation failed ($errorType), invoking SqlReviseAgent..." - onProgress(revisionMessage) + val revisionMessage = buildString { + append("๐Ÿ”„ **SQL Validation Failed**\n\n") + append("**Error Type:** $errorType\n") + append("**Errors:**\n") + tableValidation.errors.forEach { error -> + append("- $error\n") + } + append("\nInvoking SqlReviseAgent to fix the SQL...") + } + onProgress("๐Ÿ”„ SQL validation failed ($errorType), invoking SqlReviseAgent...") renderer.renderLLMResponseStart() - renderer.renderLLMResponseChunk(revisionMessage + "\n\n") + renderer.renderLLMResponseChunk(revisionMessage) renderer.renderLLMResponseEnd() val revisionInput = SqlRevisionInput( @@ -160,10 +199,17 @@ class ChatDBAgentExecutor( if (revisionResult.success) { validatedSql = revisionResult.content - val successMessage = "โœ… SQL revised successfully after $revisionAttempts attempts" - onProgress(successMessage) + val successMessage = buildString { + append("โœ… **SQL Revised Successfully**\n\n") + append("**Attempts:** $revisionAttempts\n") + append("**Revised SQL:**\n") + append("```sql\n") + append(validatedSql) + append("\n```") + } + onProgress("โœ… SQL revised successfully after $revisionAttempts attempts") renderer.renderLLMResponseStart() - renderer.renderLLMResponseChunk(successMessage + "\n\n") + renderer.renderLLMResponseChunk(successMessage) renderer.renderLLMResponseEnd() } else { errors.add("SQL revision failed: ${revisionResult.content}") @@ -186,21 +232,35 @@ class ChatDBAgentExecutor( try { queryResult = databaseConnection.executeQuery(generatedSql!!) - val successMessage = "โœ… Query returned ${queryResult.rowCount} rows" - onProgress(successMessage) + val successMessage = buildString { + append("โœ… **Query Executed Successfully**\n\n") + append("**Rows returned:** ${queryResult.rowCount}\n") + append("**Columns:** ${queryResult.columns.joinToString(", ")}") + } + onProgress("โœ… Query returned ${queryResult.rowCount} rows") renderer.renderLLMResponseStart() - renderer.renderLLMResponseChunk(successMessage + "\n\n") + renderer.renderLLMResponseChunk(successMessage) renderer.renderLLMResponseEnd() } catch (e: Exception) { lastExecutionError = e.message ?: "Unknown execution error" logger.warn { "Query execution failed (attempt ${executionRetries + 1}): $lastExecutionError" } + // Render execution error details + val errorMessage = buildString { + append("โŒ **Query Execution Failed**\n\n") + append("**Attempt:** ${executionRetries + 1}/$maxExecutionRetries\n") + append("**Error:** $lastExecutionError") + } + renderer.renderLLMResponseStart() + renderer.renderLLMResponseChunk(errorMessage) + renderer.renderLLMResponseEnd() + // Try to revise SQL based on execution error if (executionRetries < maxExecutionRetries - 1) { - val retryMessage = "๐Ÿ”„ Execution failed, attempting to fix SQL..." + val retryMessage = "๐Ÿ”„ Attempting to fix SQL based on execution error..." onProgress(retryMessage) renderer.renderLLMResponseStart() - renderer.renderLLMResponseChunk(retryMessage + "\n\n") + renderer.renderLLMResponseChunk(retryMessage) renderer.renderLLMResponseEnd() val revisionInput = SqlRevisionInput( @@ -218,10 +278,17 @@ class ChatDBAgentExecutor( if (revisionResult.success && revisionResult.content != generatedSql) { generatedSql = revisionResult.content revisionAttempts++ - val revisedMessage = "๐Ÿ”ง SQL revised, retrying execution..." - onProgress(revisedMessage) + val revisedMessage = buildString { + append("๐Ÿ”ง **SQL Revised Based on Execution Error**\n\n") + append("**Revised SQL:**\n") + append("```sql\n") + append(generatedSql) + append("\n```\n\n") + append("Retrying execution...") + } + onProgress("๐Ÿ”ง SQL revised, retrying execution...") renderer.renderLLMResponseStart() - renderer.renderLLMResponseChunk(revisedMessage + "\n\n") + renderer.renderLLMResponseChunk(revisedMessage) renderer.renderLLMResponseEnd() } else { // Revision didn't help, break the loop @@ -431,22 +498,56 @@ class ChatDBAgentExecutor( ): ChatDBResult { val message = if (success) { buildString { - appendLine("Query executed successfully!") + appendLine("## โœ… Query Executed Successfully") + appendLine() + + // Show SQL that was executed + if (generatedSql != null) { + appendLine("**Executed SQL:**") + appendLine("```sql") + appendLine(generatedSql) + appendLine("```") + appendLine() + } + + // Show revision info if applicable + if (revisionAttempts > 0) { + appendLine("*Note: SQL was revised $revisionAttempts time(s) to fix validation/execution errors*") + appendLine() + } + + // Show query results if (queryResult != null) { + appendLine("**Results** (${queryResult.rowCount} row${if (queryResult.rowCount != 1) "s" else ""}):") appendLine() - appendLine("**Results** (${queryResult.rowCount} rows):") appendLine(queryResult.toTableString()) } + + // Show visualization if generated if (plotDslCode != null) { appendLine() - appendLine("**Visualization**:") + appendLine("**Visualization:**") appendLine("```plotdsl") appendLine(plotDslCode) appendLine("```") } } } else { - "Query failed: ${errors.joinToString("; ")}" + buildString { + appendLine("## โŒ Query Failed") + appendLine() + appendLine("**Errors:**") + errors.forEach { error -> + appendLine("- $error") + } + if (generatedSql != null) { + appendLine() + appendLine("**Failed SQL:**") + appendLine("```sql") + appendLine(generatedSql) + appendLine("```") + } + } } return ChatDBResult( From 40671d56e3ef528d6ea65cd82a4a28982bd9260a Mon Sep 17 00:00:00 2001 From: Phodal Huang Date: Wed, 10 Dec 2025 16:11:25 +0800 Subject: [PATCH 25/34] feat(chatdb): add interactive execution step cards to UI #508 Introduce ChatDBStepItem and ChatDBStepCard for displaying database query execution steps with expandable details and status badges in the timeline. Refactor agent executor and renderer to support step-wise progress updates. --- .../agent/chatdb/ChatDBAgentExecutor.kt | 311 ++++++++++++------ .../agent/render/CodingAgentRenderer.kt | 21 ++ .../unitmesh/agent/render/RendererModels.kt | 39 +++ .../ui/compose/agent/AgentMessageList.kt | 6 + .../ui/compose/agent/ComposeRenderer.kt | 45 +++ .../agent/chatdb/components/ChatDBStepCard.kt | 277 ++++++++++++++++ 6 files changed, 590 insertions(+), 109 deletions(-) create mode 100644 mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/agent/chatdb/components/ChatDBStepCard.kt diff --git a/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/chatdb/ChatDBAgentExecutor.kt b/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/chatdb/ChatDBAgentExecutor.kt index 1d8f1c7037..a808793eeb 100644 --- a/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/chatdb/ChatDBAgentExecutor.kt +++ b/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/chatdb/ChatDBAgentExecutor.kt @@ -5,6 +5,8 @@ import cc.unitmesh.agent.database.* import cc.unitmesh.agent.executor.BaseAgentExecutor import cc.unitmesh.agent.logging.getLogger import cc.unitmesh.agent.orchestrator.ToolOrchestrator +import cc.unitmesh.agent.render.ChatDBStepStatus +import cc.unitmesh.agent.render.ChatDBStepType import cc.unitmesh.agent.render.CodingAgentRenderer import cc.unitmesh.agent.subagent.SqlValidator import cc.unitmesh.agent.subagent.SqlReviseAgent @@ -70,45 +72,47 @@ class ChatDBAgentExecutor( try { // Step 1: Get database schema - val step1Message = "๐Ÿ“Š Fetching database schema..." - onProgress(step1Message) - renderer.renderLLMResponseStart() - renderer.renderLLMResponseChunk(step1Message) - renderer.renderLLMResponseEnd() + renderer.renderChatDBStep( + stepType = ChatDBStepType.FETCH_SCHEMA, + status = ChatDBStepStatus.IN_PROGRESS, + title = "Fetching database schema..." + ) + onProgress("๐Ÿ“Š Fetching database schema...") val schema = task.schema ?: databaseConnection.getSchema() logger.info { "Database has ${schema.tables.size} tables: ${schema.tables.map { it.name }}" } - // Render schema details - val schemaDetails = buildString { - append("**Database Schema:**\n") - append("- Total tables: ${schema.tables.size}\n") - append("- Tables: ${schema.tables.joinToString(", ") { it.name }}") - } - renderer.renderLLMResponseStart() - renderer.renderLLMResponseChunk(schemaDetails) - renderer.renderLLMResponseEnd() + renderer.renderChatDBStep( + stepType = ChatDBStepType.FETCH_SCHEMA, + status = ChatDBStepStatus.SUCCESS, + title = "Database schema fetched", + details = mapOf( + "totalTables" to schema.tables.size, + "tables" to schema.tables.map { it.name } + ) + ) // Step 2: Schema Linking - val step2Message = "๐Ÿ”— Performing schema linking..." - onProgress(step2Message) - renderer.renderLLMResponseStart() - renderer.renderLLMResponseChunk(step2Message) - renderer.renderLLMResponseEnd() + renderer.renderChatDBStep( + stepType = ChatDBStepType.SCHEMA_LINKING, + status = ChatDBStepStatus.IN_PROGRESS, + title = "Performing schema linking..." + ) + onProgress("๐Ÿ”— Performing schema linking...") val linkingResult = schemaLinker.link(task.query, schema) logger.info { "Schema linking found ${linkingResult.relevantTables.size} relevant tables: ${linkingResult.relevantTables}" } logger.info { "Schema linking keywords: ${linkingResult.keywords}" } - // Render schema linking details - val linkingDetails = buildString { - append("**Schema Linking Results:**\n") - append("- Relevant tables: ${linkingResult.relevantTables.joinToString(", ")}\n") - append("- Keywords: ${linkingResult.keywords.joinToString(", ")}") - } - renderer.renderLLMResponseStart() - renderer.renderLLMResponseChunk(linkingDetails) - renderer.renderLLMResponseEnd() + renderer.renderChatDBStep( + stepType = ChatDBStepType.SCHEMA_LINKING, + status = ChatDBStepStatus.SUCCESS, + title = "Schema linking completed", + details = mapOf( + "relevantTables" to linkingResult.relevantTables, + "keywords" to linkingResult.keywords + ) + ) // Step 3: Build context with relevant schema // If schema linking found too few tables, use all tables to avoid missing important ones @@ -123,11 +127,12 @@ class ChatDBAgentExecutor( val initialMessage = buildInitialUserMessage(task, relevantSchema, effectiveLinkingResult) // Step 4: Generate SQL with LLM - val step4Message = "๐Ÿค– Generating SQL query..." - onProgress(step4Message) - renderer.renderLLMResponseStart() - renderer.renderLLMResponseChunk(step4Message + "\n\n") - renderer.renderLLMResponseEnd() + renderer.renderChatDBStep( + stepType = ChatDBStepType.GENERATE_SQL, + status = ChatDBStepStatus.IN_PROGRESS, + title = "Generating SQL query..." + ) + onProgress("๐Ÿค– Generating SQL query...") val llmResponse = StringBuilder() val response = getLLMResponse(initialMessage, compileDevIns = false) { chunk -> @@ -139,21 +144,29 @@ class ChatDBAgentExecutor( generatedSql = extractSqlFromResponse(llmResponse.toString()) if (generatedSql == null) { errors.add("Failed to extract SQL from LLM response") + renderer.renderChatDBStep( + stepType = ChatDBStepType.GENERATE_SQL, + status = ChatDBStepStatus.ERROR, + title = "Failed to extract SQL", + error = "Could not find SQL code block in LLM response" + ) return buildResult(false, errors, null, null, null, 0) } - // Render extracted SQL - val extractedSqlMessage = buildString { - append("**Generated SQL:**\n") - append("```sql\n") - append(generatedSql) - append("\n```") - } - renderer.renderLLMResponseStart() - renderer.renderLLMResponseChunk(extractedSqlMessage) - renderer.renderLLMResponseEnd() + renderer.renderChatDBStep( + stepType = ChatDBStepType.GENERATE_SQL, + status = ChatDBStepStatus.SUCCESS, + title = "SQL query generated", + details = mapOf("sql" to generatedSql) + ) // Step 6: Validate SQL syntax and table names using SqlReviseAgent + renderer.renderChatDBStep( + stepType = ChatDBStepType.VALIDATE_SQL, + status = ChatDBStepStatus.IN_PROGRESS, + title = "Validating SQL..." + ) + var validatedSql = generatedSql // Get all table names from schema for whitelist validation @@ -169,19 +182,25 @@ class ChatDBAgentExecutor( if (!tableValidation.isValid) { val errorType = if (!syntaxValidation.isValid) "syntax" else "table name" - val revisionMessage = buildString { - append("๐Ÿ”„ **SQL Validation Failed**\n\n") - append("**Error Type:** $errorType\n") - append("**Errors:**\n") - tableValidation.errors.forEach { error -> - append("- $error\n") - } - append("\nInvoking SqlReviseAgent to fix the SQL...") - } + + renderer.renderChatDBStep( + stepType = ChatDBStepType.VALIDATE_SQL, + status = ChatDBStepStatus.ERROR, + title = "SQL validation failed", + details = mapOf( + "errorType" to errorType, + "errors" to tableValidation.errors + ), + error = tableValidation.errors.joinToString("; ") + ) + onProgress("๐Ÿ”„ SQL validation failed ($errorType), invoking SqlReviseAgent...") - renderer.renderLLMResponseStart() - renderer.renderLLMResponseChunk(revisionMessage) - renderer.renderLLMResponseEnd() + + renderer.renderChatDBStep( + stepType = ChatDBStepType.REVISE_SQL, + status = ChatDBStepStatus.IN_PROGRESS, + title = "Revising SQL..." + ) val revisionInput = SqlRevisionInput( originalQuery = task.query, @@ -199,21 +218,33 @@ class ChatDBAgentExecutor( if (revisionResult.success) { validatedSql = revisionResult.content - val successMessage = buildString { - append("โœ… **SQL Revised Successfully**\n\n") - append("**Attempts:** $revisionAttempts\n") - append("**Revised SQL:**\n") - append("```sql\n") - append(validatedSql) - append("\n```") - } + + renderer.renderChatDBStep( + stepType = ChatDBStepType.REVISE_SQL, + status = ChatDBStepStatus.SUCCESS, + title = "SQL revised successfully", + details = mapOf( + "attempts" to revisionAttempts, + "sql" to validatedSql + ) + ) + onProgress("โœ… SQL revised successfully after $revisionAttempts attempts") - renderer.renderLLMResponseStart() - renderer.renderLLMResponseChunk(successMessage) - renderer.renderLLMResponseEnd() } else { + renderer.renderChatDBStep( + stepType = ChatDBStepType.REVISE_SQL, + status = ChatDBStepStatus.ERROR, + title = "SQL revision failed", + error = revisionResult.content + ) errors.add("SQL revision failed: ${revisionResult.content}") } + } else { + renderer.renderChatDBStep( + stepType = ChatDBStepType.VALIDATE_SQL, + status = ChatDBStepStatus.SUCCESS, + title = "SQL validation passed" + ) } generatedSql = validatedSql @@ -224,44 +255,56 @@ class ChatDBAgentExecutor( var lastExecutionError: String? = null while (executionRetries < maxExecutionRetries && queryResult == null) { - val executionMessage = "โšก Executing SQL query${if (executionRetries > 0) " (retry $executionRetries)" else ""}..." - onProgress(executionMessage) - renderer.renderLLMResponseStart() - renderer.renderLLMResponseChunk(executionMessage + "\n\n") - renderer.renderLLMResponseEnd() + renderer.renderChatDBStep( + stepType = ChatDBStepType.EXECUTE_SQL, + status = ChatDBStepStatus.IN_PROGRESS, + title = "Executing SQL query${if (executionRetries > 0) " (retry $executionRetries)" else ""}...", + details = mapOf( + "attempt" to (executionRetries + 1), + "sql" to generatedSql!! + ) + ) + + onProgress("โšก Executing SQL query${if (executionRetries > 0) " (retry $executionRetries)" else ""}...") try { queryResult = databaseConnection.executeQuery(generatedSql!!) - val successMessage = buildString { - append("โœ… **Query Executed Successfully**\n\n") - append("**Rows returned:** ${queryResult.rowCount}\n") - append("**Columns:** ${queryResult.columns.joinToString(", ")}") - } + + renderer.renderChatDBStep( + stepType = ChatDBStepType.EXECUTE_SQL, + status = ChatDBStepStatus.SUCCESS, + title = "Query executed successfully", + details = mapOf( + "rowCount" to queryResult.rowCount, + "columns" to queryResult.columns + ) + ) + onProgress("โœ… Query returned ${queryResult.rowCount} rows") - renderer.renderLLMResponseStart() - renderer.renderLLMResponseChunk(successMessage) - renderer.renderLLMResponseEnd() } catch (e: Exception) { lastExecutionError = e.message ?: "Unknown execution error" logger.warn { "Query execution failed (attempt ${executionRetries + 1}): $lastExecutionError" } - // Render execution error details - val errorMessage = buildString { - append("โŒ **Query Execution Failed**\n\n") - append("**Attempt:** ${executionRetries + 1}/$maxExecutionRetries\n") - append("**Error:** $lastExecutionError") - } - renderer.renderLLMResponseStart() - renderer.renderLLMResponseChunk(errorMessage) - renderer.renderLLMResponseEnd() + renderer.renderChatDBStep( + stepType = ChatDBStepType.EXECUTE_SQL, + status = ChatDBStepStatus.ERROR, + title = "Query execution failed", + details = mapOf( + "attempt" to (executionRetries + 1), + "maxAttempts" to maxExecutionRetries + ), + error = lastExecutionError + ) // Try to revise SQL based on execution error if (executionRetries < maxExecutionRetries - 1) { - val retryMessage = "๐Ÿ”„ Attempting to fix SQL based on execution error..." - onProgress(retryMessage) - renderer.renderLLMResponseStart() - renderer.renderLLMResponseChunk(retryMessage) - renderer.renderLLMResponseEnd() + onProgress("๐Ÿ”„ Attempting to fix SQL based on execution error...") + + renderer.renderChatDBStep( + stepType = ChatDBStepType.REVISE_SQL, + status = ChatDBStepStatus.IN_PROGRESS, + title = "Revising SQL based on execution error..." + ) val revisionInput = SqlRevisionInput( originalQuery = task.query, @@ -278,19 +321,24 @@ class ChatDBAgentExecutor( if (revisionResult.success && revisionResult.content != generatedSql) { generatedSql = revisionResult.content revisionAttempts++ - val revisedMessage = buildString { - append("๐Ÿ”ง **SQL Revised Based on Execution Error**\n\n") - append("**Revised SQL:**\n") - append("```sql\n") - append(generatedSql) - append("\n```\n\n") - append("Retrying execution...") - } + + renderer.renderChatDBStep( + stepType = ChatDBStepType.REVISE_SQL, + status = ChatDBStepStatus.SUCCESS, + title = "SQL revised based on execution error", + details = mapOf( + "sql" to generatedSql + ) + ) + onProgress("๐Ÿ”ง SQL revised, retrying execution...") - renderer.renderLLMResponseStart() - renderer.renderLLMResponseChunk(revisedMessage) - renderer.renderLLMResponseEnd() } else { + renderer.renderChatDBStep( + stepType = ChatDBStepType.REVISE_SQL, + status = ChatDBStepStatus.WARNING, + title = "SQL revision did not help", + error = "Revision did not produce a different SQL" + ) // Revision didn't help, break the loop break } @@ -307,13 +355,32 @@ class ChatDBAgentExecutor( // Step 8: Generate visualization if requested if (task.generateVisualization && queryResult != null && !queryResult.isEmpty()) { - val vizMessage = "๐Ÿ“ˆ Generating visualization..." - onProgress(vizMessage) - renderer.renderLLMResponseStart() - renderer.renderLLMResponseChunk(vizMessage + "\n\n") - renderer.renderLLMResponseEnd() + renderer.renderChatDBStep( + stepType = ChatDBStepType.GENERATE_VISUALIZATION, + status = ChatDBStepStatus.IN_PROGRESS, + title = "Generating visualization..." + ) + + onProgress("๐Ÿ“ˆ Generating visualization...") plotDslCode = generateVisualization(task.query, queryResult, onProgress) + + if (plotDslCode != null) { + renderer.renderChatDBStep( + stepType = ChatDBStepType.GENERATE_VISUALIZATION, + status = ChatDBStepStatus.SUCCESS, + title = "Visualization generated", + details = mapOf( + "code" to plotDslCode + ) + ) + } else { + renderer.renderChatDBStep( + stepType = ChatDBStepType.GENERATE_VISUALIZATION, + status = ChatDBStepStatus.WARNING, + title = "Visualization not generated" + ) + } } } catch (e: Exception) { @@ -332,10 +399,36 @@ class ChatDBAgentExecutor( // Render the final result to the timeline if (result.success) { + renderer.renderChatDBStep( + stepType = ChatDBStepType.FINAL_RESULT, + status = ChatDBStepStatus.SUCCESS, + title = "Query completed successfully", + details = buildMap { + queryResult?.let { + put("rowCount", it.rowCount) + put("columns", it.columns) + } + if (revisionAttempts > 0) { + put("revisionAttempts", revisionAttempts) + } + plotDslCode?.let { put("visualization", it) } + } + ) + renderer.renderLLMResponseStart() renderer.renderLLMResponseChunk(result.message) renderer.renderLLMResponseEnd() } else { + renderer.renderChatDBStep( + stepType = ChatDBStepType.FINAL_RESULT, + status = ChatDBStepStatus.ERROR, + title = "Query failed", + details = buildMap { + generatedSql?.let { put("sql", it) } + }, + error = errors.joinToString("; ") + ) + renderer.renderError(result.message) } diff --git a/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/render/CodingAgentRenderer.kt b/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/render/CodingAgentRenderer.kt index 134e57876b..813f69e577 100644 --- a/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/render/CodingAgentRenderer.kt +++ b/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/render/CodingAgentRenderer.kt @@ -66,6 +66,27 @@ interface CodingAgentRenderer { // Default: no-op for renderers that don't display task progress } + /** + * Render a ChatDB execution step. + * This is an optional method primarily used by UI renderers (ComposeRenderer, JewelRenderer). + * Console renderers can ignore this or provide simple text output. + * + * @param stepType The type of step being executed + * @param status The current status of the step + * @param title The display title for the step (defaults to stepType.displayName) + * @param details Additional details about the step (e.g., table names, row counts, SQL) + * @param error Error message if the step failed + */ + fun renderChatDBStep( + stepType: ChatDBStepType, + status: ChatDBStepStatus, + title: String = stepType.displayName, + details: Map = emptyMap(), + error: String? = null + ) { + // Default: no-op for renderers that don't support ChatDB steps + } + /** * Render a compact plan summary bar. * Called when plan is created or updated to show progress in a compact format. diff --git a/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/render/RendererModels.kt b/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/render/RendererModels.kt index f1bcd9e973..afd02a15e3 100644 --- a/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/render/RendererModels.kt +++ b/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/render/RendererModels.kt @@ -203,6 +203,20 @@ sealed class TimelineItem( override val id: String = generateId() ) : TimelineItem(timestamp, id) + /** + * ChatDB execution step item for displaying database query execution steps. + * Each step can be expanded/collapsed and shows detailed information. + */ + data class ChatDBStepItem( + val stepType: ChatDBStepType, + val status: ChatDBStepStatus, + val title: String, + val details: Map = emptyMap(), + val error: String? = null, + override val timestamp: Long = Platform.getCurrentTimestamp(), + override val id: String = generateId() + ) : TimelineItem(timestamp, id) + companion object { /** * Thread-safe counter for generating unique IDs. @@ -218,3 +232,28 @@ sealed class TimelineItem( } } +/** + * ChatDB execution step types + */ +enum class ChatDBStepType(val displayName: String, val icon: String) { + FETCH_SCHEMA("Fetch Database Schema", "๐Ÿ“Š"), + SCHEMA_LINKING("Schema Linking", "๐Ÿ”—"), + GENERATE_SQL("Generate SQL Query", "๐Ÿค–"), + VALIDATE_SQL("Validate SQL", "โœ“"), + REVISE_SQL("Revise SQL", "๐Ÿ”„"), + EXECUTE_SQL("Execute SQL Query", "โšก"), + GENERATE_VISUALIZATION("Generate Visualization", "๐Ÿ“ˆ"), + FINAL_RESULT("Query Result", "โœ…") +} + +/** + * ChatDB execution step status + */ +enum class ChatDBStepStatus(val displayName: String) { + PENDING("Pending"), + IN_PROGRESS("In Progress"), + SUCCESS("Success"), + WARNING("Warning"), + ERROR("Error") +} + diff --git a/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/agent/AgentMessageList.kt b/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/agent/AgentMessageList.kt index b157323522..e1d872d857 100644 --- a/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/agent/AgentMessageList.kt +++ b/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/agent/AgentMessageList.kt @@ -236,6 +236,12 @@ fun RenderMessageItem( metadata = timelineItem.metadata ) } + + is TimelineItem.ChatDBStepItem -> { + cc.unitmesh.devins.ui.compose.agent.chatdb.components.ChatDBStepCard( + step = timelineItem + ) + } } } diff --git a/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/agent/ComposeRenderer.kt b/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/agent/ComposeRenderer.kt index ea298f844b..e4fe9ce970 100644 --- a/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/agent/ComposeRenderer.kt +++ b/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/agent/ComposeRenderer.kt @@ -4,6 +4,8 @@ import androidx.compose.runtime.* import cc.unitmesh.agent.plan.AgentPlan import cc.unitmesh.agent.plan.MarkdownPlanParser import cc.unitmesh.agent.render.BaseRenderer +import cc.unitmesh.agent.render.ChatDBStepStatus +import cc.unitmesh.agent.render.ChatDBStepType import cc.unitmesh.agent.render.RendererUtils import cc.unitmesh.agent.render.TaskInfo import cc.unitmesh.agent.render.TaskStatus @@ -782,6 +784,39 @@ class ComposeRenderer : BaseRenderer() { ) } + /** + * Render a ChatDB execution step. + * Adds or updates a step in the timeline for interactive display. + */ + override fun renderChatDBStep( + stepType: ChatDBStepType, + status: ChatDBStepStatus, + title: String, + details: Map, + error: String? + ) { + // Check if this step already exists in the timeline + val existingIndex = _timeline.indexOfLast { + it is ChatDBStepItem && it.stepType == stepType + } + + val stepItem = ChatDBStepItem( + stepType = stepType, + status = status, + title = title, + details = details, + error = error + ) + + if (existingIndex >= 0) { + // Update existing step + _timeline[existingIndex] = stepItem + } else { + // Add new step + _timeline.add(stepItem) + } + } + /** * Render an Agent-generated sketch block (chart, nanodsl, mermaid, etc.) * Adds the sketch block to the timeline for interactive rendering. @@ -884,6 +919,11 @@ class ComposeRenderer : BaseRenderer() { sketchCode = item.code ) } + + is ChatDBStepItem -> { + // ChatDB steps are not persisted (they're runtime-only for UI display) + null + } } } @@ -1109,6 +1149,11 @@ class ComposeRenderer : BaseRenderer() { metadata = toMessageMetadata(item) ) } + + is ChatDBStepItem -> { + // ChatDB steps are not persisted as messages + null + } } } } diff --git a/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/agent/chatdb/components/ChatDBStepCard.kt b/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/agent/chatdb/components/ChatDBStepCard.kt new file mode 100644 index 0000000000..25ac49a80a --- /dev/null +++ b/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/agent/chatdb/components/ChatDBStepCard.kt @@ -0,0 +1,277 @@ +package cc.unitmesh.devins.ui.compose.agent.chatdb.components + +import androidx.compose.animation.AnimatedVisibility +import androidx.compose.animation.expandVertically +import androidx.compose.animation.fadeIn +import androidx.compose.animation.fadeOut +import androidx.compose.animation.shrinkVertically +import androidx.compose.foundation.background +import androidx.compose.foundation.clickable +import androidx.compose.foundation.layout.* +import androidx.compose.foundation.shape.RoundedCornerShape +import androidx.compose.material.icons.Icons +import androidx.compose.material.icons.filled.KeyboardArrowDown +import androidx.compose.material.icons.filled.KeyboardArrowRight +import androidx.compose.material3.* +import androidx.compose.runtime.* +import androidx.compose.ui.Alignment +import androidx.compose.ui.Modifier +import androidx.compose.ui.draw.clip +import androidx.compose.ui.text.font.FontFamily +import androidx.compose.ui.text.font.FontWeight +import androidx.compose.ui.unit.dp +import androidx.compose.ui.unit.sp +import cc.unitmesh.agent.render.ChatDBStepStatus +import cc.unitmesh.agent.render.ChatDBStepType +import cc.unitmesh.agent.render.TimelineItem + +/** + * Composable for rendering a ChatDB execution step card. + * Displays step status, title, and expandable details. + */ +@Composable +fun ChatDBStepCard( + step: TimelineItem.ChatDBStepItem, + modifier: Modifier = Modifier +) { + var isExpanded by remember { mutableStateOf(step.status == ChatDBStepStatus.ERROR) } + + Card( + modifier = modifier + .fillMaxWidth() + .padding(vertical = 4.dp), + shape = RoundedCornerShape(8.dp), + colors = CardDefaults.cardColors( + containerColor = MaterialTheme.colorScheme.surfaceVariant.copy(alpha = 0.5f) + ) + ) { + Column( + modifier = Modifier + .fillMaxWidth() + .clickable { isExpanded = !isExpanded } + .padding(12.dp) + ) { + // Header row with icon, title, and status + Row( + modifier = Modifier.fillMaxWidth(), + horizontalArrangement = Arrangement.SpaceBetween, + verticalAlignment = Alignment.CenterVertically + ) { + Row( + verticalAlignment = Alignment.CenterVertically, + modifier = Modifier.weight(1f) + ) { + // Expand/collapse icon + Icon( + imageVector = if (isExpanded) Icons.Default.KeyboardArrowDown else Icons.Default.KeyboardArrowRight, + contentDescription = if (isExpanded) "Collapse" else "Expand", + modifier = Modifier.size(20.dp), + tint = MaterialTheme.colorScheme.onSurfaceVariant + ) + + Spacer(modifier = Modifier.width(8.dp)) + + // Step icon + Text( + text = step.stepType.icon, + fontSize = 16.sp, + modifier = Modifier.padding(end = 8.dp) + ) + + // Step title + Text( + text = step.title, + style = MaterialTheme.typography.bodyMedium, + fontWeight = FontWeight.Medium, + color = MaterialTheme.colorScheme.onSurface + ) + } + + // Status indicator + StepStatusBadge(status = step.status) + } + + // Expandable details section + AnimatedVisibility( + visible = isExpanded, + enter = expandVertically() + fadeIn(), + exit = shrinkVertically() + fadeOut() + ) { + Column( + modifier = Modifier + .fillMaxWidth() + .padding(top = 12.dp) + ) { + // Error message if present + step.error?.let { error -> + Text( + text = "Error: $error", + style = MaterialTheme.typography.bodySmall, + color = MaterialTheme.colorScheme.error, + modifier = Modifier + .fillMaxWidth() + .padding(bottom = 8.dp) + ) + } + + // Details + if (step.details.isNotEmpty()) { + StepDetails(details = step.details, stepType = step.stepType) + } + } + } + } + } +} + +@Composable +private fun StepStatusBadge(status: ChatDBStepStatus) { + val (color, text) = when (status) { + ChatDBStepStatus.PENDING -> MaterialTheme.colorScheme.outline to "Pending" + ChatDBStepStatus.IN_PROGRESS -> MaterialTheme.colorScheme.primary to "In Progress" + ChatDBStepStatus.SUCCESS -> MaterialTheme.colorScheme.tertiary to "Success" + ChatDBStepStatus.WARNING -> MaterialTheme.colorScheme.secondary to "Warning" + ChatDBStepStatus.ERROR -> MaterialTheme.colorScheme.error to "Error" + } + + Surface( + shape = RoundedCornerShape(12.dp), + color = color.copy(alpha = 0.2f), + modifier = Modifier.padding(start = 8.dp) + ) { + Text( + text = text, + style = MaterialTheme.typography.labelSmall, + color = color, + modifier = Modifier.padding(horizontal = 8.dp, vertical = 4.dp) + ) + } +} + +@Composable +private fun StepDetails(details: Map, stepType: ChatDBStepType) { + Column( + modifier = Modifier + .fillMaxWidth() + .clip(RoundedCornerShape(4.dp)) + .background(MaterialTheme.colorScheme.surface.copy(alpha = 0.5f)) + .padding(12.dp), + verticalArrangement = Arrangement.spacedBy(8.dp) + ) { + when (stepType) { + ChatDBStepType.FETCH_SCHEMA -> { + details["totalTables"]?.let { + DetailRow("Total Tables", it.toString()) + } + details["tables"]?.let { tables -> + if (tables is List<*>) { + DetailRow("Tables", tables.joinToString(", ")) + } + } + } + ChatDBStepType.SCHEMA_LINKING -> { + details["relevantTables"]?.let { tables -> + if (tables is List<*>) { + DetailRow("Relevant Tables", tables.joinToString(", ")) + } + } + details["keywords"]?.let { keywords -> + if (keywords is List<*>) { + DetailRow("Keywords", keywords.joinToString(", ")) + } + } + } + ChatDBStepType.GENERATE_SQL, ChatDBStepType.REVISE_SQL -> { + details["sql"]?.let { sql -> + CodeBlock(code = sql.toString(), language = "sql") + } + } + ChatDBStepType.VALIDATE_SQL -> { + details["errorType"]?.let { + DetailRow("Error Type", it.toString()) + } + details["errors"]?.let { errors -> + if (errors is List<*>) { + Column { + Text( + text = "Errors:", + style = MaterialTheme.typography.labelMedium, + color = MaterialTheme.colorScheme.error + ) + errors.forEach { error -> + Text( + text = "โ€ข $error", + style = MaterialTheme.typography.bodySmall, + color = MaterialTheme.colorScheme.onSurface, + modifier = Modifier.padding(start = 8.dp) + ) + } + } + } + } + } + ChatDBStepType.EXECUTE_SQL -> { + details["rowCount"]?.let { + DetailRow("Rows Returned", it.toString()) + } + details["columns"]?.let { columns -> + if (columns is List<*>) { + DetailRow("Columns", columns.joinToString(", ")) + } + } + } + else -> { + // Generic detail rendering + details.forEach { (key, value) -> + DetailRow(key, value.toString()) + } + } + } + } +} + +@Composable +private fun DetailRow(label: String, value: String) { + Row( + modifier = Modifier.fillMaxWidth(), + horizontalArrangement = Arrangement.SpaceBetween + ) { + Text( + text = "$label:", + style = MaterialTheme.typography.labelMedium, + color = MaterialTheme.colorScheme.onSurfaceVariant, + modifier = Modifier.padding(end = 8.dp) + ) + Text( + text = value, + style = MaterialTheme.typography.bodySmall, + color = MaterialTheme.colorScheme.onSurface + ) + } +} + +@Composable +private fun CodeBlock(code: String, language: String) { + Column(modifier = Modifier.fillMaxWidth()) { + Text( + text = language.uppercase(), + style = MaterialTheme.typography.labelSmall, + color = MaterialTheme.colorScheme.primary, + modifier = Modifier.padding(bottom = 4.dp) + ) + Surface( + shape = RoundedCornerShape(4.dp), + color = MaterialTheme.colorScheme.surfaceVariant.copy(alpha = 0.5f) + ) { + Text( + text = code, + style = MaterialTheme.typography.bodySmall.copy( + fontFamily = FontFamily.Monospace + ), + color = MaterialTheme.colorScheme.onSurface, + modifier = Modifier.padding(12.dp) + ) + } + } +} + From e60a1ed10a1235e97c0ee54fc1471bdc0872788c Mon Sep 17 00:00:00 2001 From: Phodal Huang Date: Wed, 10 Dec 2025 16:11:36 +0800 Subject: [PATCH 26/34] feat(chatdb): add new session button to chat pane Introduce a "New Session" button that clears the current chat timeline. Also include generated SQL in query result details. --- .../agent/chatdb/ChatDBAgentExecutor.kt | 1 + .../ui/compose/agent/chatdb/ChatDBPage.kt | 1 + .../compose/agent/chatdb/ChatDBViewModel.kt | 10 ++++++++ .../agent/chatdb/components/ChatDBChatPane.kt | 23 +++++++++++++++++-- 4 files changed, 33 insertions(+), 2 deletions(-) diff --git a/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/chatdb/ChatDBAgentExecutor.kt b/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/chatdb/ChatDBAgentExecutor.kt index a808793eeb..9368052328 100644 --- a/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/chatdb/ChatDBAgentExecutor.kt +++ b/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/chatdb/ChatDBAgentExecutor.kt @@ -404,6 +404,7 @@ class ChatDBAgentExecutor( status = ChatDBStepStatus.SUCCESS, title = "Query completed successfully", details = buildMap { + generatedSql?.let { put("sql", it) } queryResult?.let { put("rowCount", it.rowCount) put("columns", it.columns) diff --git a/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/agent/chatdb/ChatDBPage.kt b/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/agent/chatdb/ChatDBPage.kt index ab2cd2b245..a69f734154 100644 --- a/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/agent/chatdb/ChatDBPage.kt +++ b/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/agent/chatdb/ChatDBPage.kt @@ -93,6 +93,7 @@ fun ChatDBPage( isGenerating = viewModel.isGenerating, onSendMessage = viewModel::sendMessage, onStopGeneration = viewModel::stopGeneration, + onNewSession = viewModel::newSession, modifier = Modifier.fillMaxSize() ) } diff --git a/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/agent/chatdb/ChatDBViewModel.kt b/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/agent/chatdb/ChatDBViewModel.kt index bd4e59b6ef..68a364bf17 100644 --- a/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/agent/chatdb/ChatDBViewModel.kt +++ b/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/agent/chatdb/ChatDBViewModel.kt @@ -361,6 +361,16 @@ class ChatDBViewModel( isGenerating = false } + /** + * Create a new session - clears the current chat timeline + */ + fun newSession() { + if (isGenerating) { + stopGeneration() + } + renderer.clearMessages() + } + fun getSchema(): DatabaseSchema? = currentSchema fun dispose() { diff --git a/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/agent/chatdb/components/ChatDBChatPane.kt b/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/agent/chatdb/components/ChatDBChatPane.kt index dc2e22ec7b..38248994c5 100644 --- a/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/agent/chatdb/components/ChatDBChatPane.kt +++ b/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/agent/chatdb/components/ChatDBChatPane.kt @@ -38,13 +38,16 @@ fun ChatDBChatPane( isGenerating: Boolean, onSendMessage: (String) -> Unit, onStopGeneration: () -> Unit, + onNewSession: () -> Unit = {}, modifier: Modifier = Modifier ) { Column(modifier = modifier.fillMaxSize()) { // Connection status banner ConnectionStatusBanner( connectionStatus = connectionStatus, - schema = schema + schema = schema, + onNewSession = onNewSession, + hasMessages = renderer.timeline.isNotEmpty() ) // Message list @@ -80,7 +83,9 @@ fun ChatDBChatPane( @Composable private fun ConnectionStatusBanner( connectionStatus: ConnectionStatus, - schema: DatabaseSchema? + schema: DatabaseSchema?, + onNewSession: () -> Unit, + hasMessages: Boolean ) { when (connectionStatus) { is ConnectionStatus.Connected -> { @@ -117,6 +122,20 @@ private fun ConnectionStatusBanner( ) } } + + // New session button (only show if there are messages) + if (hasMessages) { + FilledTonalIconButton( + onClick = onNewSession, + modifier = Modifier.size(36.dp) + ) { + Icon( + AutoDevComposeIcons.Add, + contentDescription = "New Session", + modifier = Modifier.size(20.dp) + ) + } + } } } } From 5f74492a93564c48162d39dbaa9663f847a2ef59 Mon Sep 17 00:00:00 2001 From: Phodal Huang Date: Wed, 10 Dec 2025 16:22:55 +0800 Subject: [PATCH 27/34] feat(chatdb): add table schema and query preview UI #508 Display detailed table schemas and query result previews in ChatDB steps, including expandable table cards, column info, keyword chips, and data preview tables. --- .../agent/chatdb/ChatDBAgentExecutor.kt | 68 ++- .../agent/chatdb/components/ChatDBStepCard.kt | 442 +++++++++++++++++- 2 files changed, 494 insertions(+), 16 deletions(-) diff --git a/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/chatdb/ChatDBAgentExecutor.kt b/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/chatdb/ChatDBAgentExecutor.kt index 9368052328..4477f5cd63 100644 --- a/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/chatdb/ChatDBAgentExecutor.kt +++ b/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/chatdb/ChatDBAgentExecutor.kt @@ -87,8 +87,10 @@ class ChatDBAgentExecutor( status = ChatDBStepStatus.SUCCESS, title = "Database schema fetched", details = mapOf( + "databaseName" to (schema.databaseName ?: "Database"), "totalTables" to schema.tables.size, - "tables" to schema.tables.map { it.name } + "tables" to schema.tables.map { it.name }, + "tableSchemas" to schemaToTableInfoList(schema) ) ) @@ -110,7 +112,8 @@ class ChatDBAgentExecutor( title = "Schema linking completed", details = mapOf( "relevantTables" to linkingResult.relevantTables, - "keywords" to linkingResult.keywords + "keywords" to linkingResult.keywords, + "relevantTableSchemas" to getTableInfoForNames(schema, linkingResult.relevantTables) ) ) @@ -275,8 +278,10 @@ class ChatDBAgentExecutor( status = ChatDBStepStatus.SUCCESS, title = "Query executed successfully", details = mapOf( + "sql" to generatedSql!!, "rowCount" to queryResult.rowCount, - "columns" to queryResult.columns + "columns" to queryResult.columns, + "previewRows" to getPreviewRows(queryResult, 5) ) ) @@ -408,6 +413,7 @@ class ChatDBAgentExecutor( queryResult?.let { put("rowCount", it.rowCount) put("columns", it.columns) + put("previewRows", getPreviewRows(it, 10)) } if (revisionAttempts > 0) { put("revisionAttempts", revisionAttempts) @@ -658,5 +664,59 @@ class ChatDBAgentExecutor( override fun buildContinuationMessage(): String { return "Please continue with the database query based on the results above." } -} + /** + * Convert DatabaseSchema to a list of table info maps for UI rendering + */ + private fun schemaToTableInfoList(schema: DatabaseSchema): List> { + return schema.tables.map { table -> + mapOf( + "name" to table.name, + "comment" to (table.comment ?: ""), + "columns" to table.columns.map { col -> + mapOf( + "name" to col.name, + "type" to col.type, + "nullable" to col.nullable, + "isPrimaryKey" to col.isPrimaryKey, + "isForeignKey" to col.isForeignKey, + "comment" to (col.comment ?: ""), + "defaultValue" to (col.defaultValue ?: "") + ) + } + ) + } + } + + /** + * Get table info maps for specific table names + */ + private fun getTableInfoForNames(schema: DatabaseSchema, tableNames: List): List> { + return tableNames.mapNotNull { tableName -> + schema.getTable(tableName)?.let { table -> + mapOf( + "name" to table.name, + "comment" to (table.comment ?: ""), + "columns" to table.columns.map { col -> + mapOf( + "name" to col.name, + "type" to col.type, + "nullable" to col.nullable, + "isPrimaryKey" to col.isPrimaryKey, + "isForeignKey" to col.isForeignKey, + "comment" to (col.comment ?: ""), + "defaultValue" to (col.defaultValue ?: "") + ) + } + ) + } + } + } + + /** + * Get preview rows from query result (limited to first N rows) + */ + private fun getPreviewRows(queryResult: QueryResult, maxRows: Int = 5): List> { + return queryResult.rows.take(maxRows) + } +} diff --git a/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/agent/chatdb/components/ChatDBStepCard.kt b/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/agent/chatdb/components/ChatDBStepCard.kt index 25ac49a80a..c5ffdcfd3a 100644 --- a/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/agent/chatdb/components/ChatDBStepCard.kt +++ b/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/agent/chatdb/components/ChatDBStepCard.kt @@ -6,8 +6,11 @@ import androidx.compose.animation.fadeIn import androidx.compose.animation.fadeOut import androidx.compose.animation.shrinkVertically import androidx.compose.foundation.background +import androidx.compose.foundation.border import androidx.compose.foundation.clickable +import androidx.compose.foundation.horizontalScroll import androidx.compose.foundation.layout.* +import androidx.compose.foundation.rememberScrollState import androidx.compose.foundation.shape.RoundedCornerShape import androidx.compose.material.icons.Icons import androidx.compose.material.icons.filled.KeyboardArrowDown @@ -19,6 +22,7 @@ import androidx.compose.ui.Modifier import androidx.compose.ui.draw.clip import androidx.compose.ui.text.font.FontFamily import androidx.compose.ui.text.font.FontWeight +import androidx.compose.ui.text.style.TextOverflow import androidx.compose.ui.unit.dp import androidx.compose.ui.unit.sp import cc.unitmesh.agent.render.ChatDBStepStatus @@ -160,32 +164,88 @@ private fun StepDetails(details: Map, stepType: ChatDBStepType) { ) { when (stepType) { ChatDBStepType.FETCH_SCHEMA -> { + // Database overview with table cards + details["databaseName"]?.let { dbName -> + DetailRow("Database", dbName.toString()) + } details["totalTables"]?.let { DetailRow("Total Tables", it.toString()) } - details["tables"]?.let { tables -> - if (tables is List<*>) { - DetailRow("Tables", tables.joinToString(", ")) + + // Show table schema cards + @Suppress("UNCHECKED_CAST") + val tableSchemas = details["tableSchemas"] as? List> + if (tableSchemas != null && tableSchemas.isNotEmpty()) { + Spacer(modifier = Modifier.height(8.dp)) + Text( + text = "Tables", + style = MaterialTheme.typography.labelMedium, + fontWeight = FontWeight.Bold, + color = MaterialTheme.colorScheme.primary + ) + tableSchemas.forEach { tableInfo -> + TableSchemaCard(tableInfo) + } + } else { + // Fallback to simple table list + details["tables"]?.let { tables -> + if (tables is List<*>) { + DetailRow("Tables", tables.joinToString(", ")) + } } } } + ChatDBStepType.SCHEMA_LINKING -> { - details["relevantTables"]?.let { tables -> - if (tables is List<*>) { - DetailRow("Relevant Tables", tables.joinToString(", ")) + // Show relevant tables with their columns + details["keywords"]?.let { keywords -> + if (keywords is List<*> && keywords.isNotEmpty()) { + Row( + horizontalArrangement = Arrangement.spacedBy(4.dp), + modifier = Modifier.horizontalScroll(rememberScrollState()) + ) { + Text( + text = "Keywords:", + style = MaterialTheme.typography.labelMedium, + color = MaterialTheme.colorScheme.onSurfaceVariant + ) + keywords.forEach { keyword -> + KeywordChip(keyword.toString()) + } + } } } - details["keywords"]?.let { keywords -> - if (keywords is List<*>) { - DetailRow("Keywords", keywords.joinToString(", ")) + + // Show relevant table schemas + @Suppress("UNCHECKED_CAST") + val relevantTableSchemas = details["relevantTableSchemas"] as? List> + if (relevantTableSchemas != null && relevantTableSchemas.isNotEmpty()) { + Spacer(modifier = Modifier.height(8.dp)) + Text( + text = "Relevant Tables", + style = MaterialTheme.typography.labelMedium, + fontWeight = FontWeight.Bold, + color = MaterialTheme.colorScheme.primary + ) + relevantTableSchemas.forEach { tableInfo -> + TableSchemaCard(tableInfo, highlightRelevant = true) + } + } else { + // Fallback + details["relevantTables"]?.let { tables -> + if (tables is List<*>) { + DetailRow("Relevant Tables", tables.joinToString(", ")) + } } } } + ChatDBStepType.GENERATE_SQL, ChatDBStepType.REVISE_SQL -> { details["sql"]?.let { sql -> CodeBlock(code = sql.toString(), language = "sql") } } + ChatDBStepType.VALIDATE_SQL -> { details["errorType"]?.let { DetailRow("Error Type", it.toString()) @@ -210,16 +270,82 @@ private fun StepDetails(details: Map, stepType: ChatDBStepType) { } } } + ChatDBStepType.EXECUTE_SQL -> { + // Show SQL that was executed + details["sql"]?.let { sql -> + CodeBlock(code = sql.toString(), language = "sql") + } + + // Show result summary details["rowCount"]?.let { DetailRow("Rows Returned", it.toString()) } - details["columns"]?.let { columns -> - if (columns is List<*>) { - DetailRow("Columns", columns.joinToString(", ")) + + // Show data preview table + @Suppress("UNCHECKED_CAST") + val columns = details["columns"] as? List + @Suppress("UNCHECKED_CAST") + val previewRows = details["previewRows"] as? List> + + if (columns != null && previewRows != null && previewRows.isNotEmpty()) { + Spacer(modifier = Modifier.height(8.dp)) + Text( + text = "Data Preview", + style = MaterialTheme.typography.labelMedium, + fontWeight = FontWeight.Bold, + color = MaterialTheme.colorScheme.primary + ) + DataPreviewTable(columns = columns, rows = previewRows) + + val totalRows = (details["rowCount"] as? Int) ?: previewRows.size + if (totalRows > previewRows.size) { + Text( + text = "Showing ${previewRows.size} of $totalRows rows", + style = MaterialTheme.typography.labelSmall, + color = MaterialTheme.colorScheme.onSurfaceVariant, + modifier = Modifier.padding(top = 4.dp) + ) + } + } else if (columns != null) { + DetailRow("Columns", columns.joinToString(", ")) + } + } + + ChatDBStepType.FINAL_RESULT -> { + // Show final SQL + details["sql"]?.let { sql -> + CodeBlock(code = sql.toString(), language = "sql") + } + + // Show result summary + details["rowCount"]?.let { + DetailRow("Total Rows", it.toString()) + } + details["revisionAttempts"]?.let { attempts -> + if ((attempts as? Int ?: 0) > 0) { + DetailRow("Revision Attempts", attempts.toString()) } } + + // Show data preview + @Suppress("UNCHECKED_CAST") + val columns = details["columns"] as? List + @Suppress("UNCHECKED_CAST") + val previewRows = details["previewRows"] as? List> + + if (columns != null && previewRows != null && previewRows.isNotEmpty()) { + Spacer(modifier = Modifier.height(8.dp)) + Text( + text = "Query Results", + style = MaterialTheme.typography.labelMedium, + fontWeight = FontWeight.Bold, + color = MaterialTheme.colorScheme.primary + ) + DataPreviewTable(columns = columns, rows = previewRows) + } } + else -> { // Generic detail rendering details.forEach { (key, value) -> @@ -275,3 +401,295 @@ private fun CodeBlock(code: String, language: String) { } } +/** + * Card displaying table schema information + */ +@Composable +private fun TableSchemaCard( + tableInfo: Map, + highlightRelevant: Boolean = false +) { + var isExpanded by remember { mutableStateOf(highlightRelevant) } + val tableName = tableInfo["name"]?.toString() ?: "Unknown" + val comment = tableInfo["comment"]?.toString() + + @Suppress("UNCHECKED_CAST") + val columns = tableInfo["columns"] as? List> ?: emptyList() + + Surface( + modifier = Modifier + .fillMaxWidth() + .padding(vertical = 4.dp), + shape = RoundedCornerShape(6.dp), + color = if (highlightRelevant) + MaterialTheme.colorScheme.primaryContainer.copy(alpha = 0.3f) + else + MaterialTheme.colorScheme.surfaceVariant.copy(alpha = 0.3f) + ) { + Column( + modifier = Modifier + .fillMaxWidth() + .clickable { isExpanded = !isExpanded } + .padding(8.dp) + ) { + Row( + modifier = Modifier.fillMaxWidth(), + horizontalArrangement = Arrangement.SpaceBetween, + verticalAlignment = Alignment.CenterVertically + ) { + Row(verticalAlignment = Alignment.CenterVertically) { + Icon( + imageVector = if (isExpanded) Icons.Default.KeyboardArrowDown + else Icons.Default.KeyboardArrowRight, + contentDescription = null, + modifier = Modifier.size(16.dp), + tint = MaterialTheme.colorScheme.onSurfaceVariant + ) + Spacer(modifier = Modifier.width(4.dp)) + Text( + text = "๐Ÿ“‹", + fontSize = 12.sp + ) + Spacer(modifier = Modifier.width(4.dp)) + Text( + text = tableName, + style = MaterialTheme.typography.bodyMedium, + fontWeight = FontWeight.Medium, + color = MaterialTheme.colorScheme.onSurface + ) + } + Text( + text = "${columns.size} columns", + style = MaterialTheme.typography.labelSmall, + color = MaterialTheme.colorScheme.onSurfaceVariant + ) + } + + comment?.let { + Text( + text = it, + style = MaterialTheme.typography.labelSmall, + color = MaterialTheme.colorScheme.onSurfaceVariant, + modifier = Modifier.padding(start = 24.dp, top = 2.dp) + ) + } + + AnimatedVisibility( + visible = isExpanded, + enter = expandVertically() + fadeIn(), + exit = shrinkVertically() + fadeOut() + ) { + if (columns.isNotEmpty()) { + Column( + modifier = Modifier + .fillMaxWidth() + .padding(top = 8.dp, start = 8.dp) + ) { + columns.forEach { col -> + ColumnInfoRow(col) + } + } + } + } + } + } +} + +/** + * Row displaying column information + */ +@Composable +private fun ColumnInfoRow(columnInfo: Map) { + val name = columnInfo["name"]?.toString() ?: "" + val type = columnInfo["type"]?.toString() ?: "" + val isPrimaryKey = columnInfo["isPrimaryKey"] as? Boolean ?: false + val isForeignKey = columnInfo["isForeignKey"] as? Boolean ?: false + val nullable = columnInfo["nullable"] as? Boolean ?: true + + Row( + modifier = Modifier + .fillMaxWidth() + .padding(vertical = 2.dp), + horizontalArrangement = Arrangement.SpaceBetween, + verticalAlignment = Alignment.CenterVertically + ) { + Row( + verticalAlignment = Alignment.CenterVertically, + modifier = Modifier.weight(1f) + ) { + // Column name with key indicators + if (isPrimaryKey) { + Text(text = "๐Ÿ”‘", fontSize = 10.sp) + Spacer(modifier = Modifier.width(2.dp)) + } else if (isForeignKey) { + Text(text = "๐Ÿ”—", fontSize = 10.sp) + Spacer(modifier = Modifier.width(2.dp)) + } + + Text( + text = name, + style = MaterialTheme.typography.bodySmall.copy( + fontFamily = FontFamily.Monospace + ), + fontWeight = if (isPrimaryKey) FontWeight.Bold else FontWeight.Normal, + color = if (isPrimaryKey) MaterialTheme.colorScheme.primary + else MaterialTheme.colorScheme.onSurface + ) + } + + Row( + horizontalArrangement = Arrangement.spacedBy(4.dp), + verticalAlignment = Alignment.CenterVertically + ) { + // Type badge + Surface( + shape = RoundedCornerShape(4.dp), + color = MaterialTheme.colorScheme.secondaryContainer.copy(alpha = 0.5f) + ) { + Text( + text = type, + style = MaterialTheme.typography.labelSmall.copy( + fontFamily = FontFamily.Monospace, + fontSize = 10.sp + ), + color = MaterialTheme.colorScheme.onSecondaryContainer, + modifier = Modifier.padding(horizontal = 4.dp, vertical = 1.dp) + ) + } + + // Nullable indicator + if (!nullable) { + Surface( + shape = RoundedCornerShape(4.dp), + color = MaterialTheme.colorScheme.errorContainer.copy(alpha = 0.5f) + ) { + Text( + text = "NOT NULL", + style = MaterialTheme.typography.labelSmall.copy(fontSize = 8.sp), + color = MaterialTheme.colorScheme.onErrorContainer, + modifier = Modifier.padding(horizontal = 3.dp, vertical = 1.dp) + ) + } + } + } + } +} + +/** + * Keyword chip for schema linking + */ +@Composable +private fun KeywordChip(keyword: String) { + Surface( + shape = RoundedCornerShape(12.dp), + color = MaterialTheme.colorScheme.secondaryContainer.copy(alpha = 0.7f) + ) { + Text( + text = keyword, + style = MaterialTheme.typography.labelSmall, + color = MaterialTheme.colorScheme.onSecondaryContainer, + modifier = Modifier.padding(horizontal = 8.dp, vertical = 4.dp) + ) + } +} + +/** + * Data preview table for query results + */ +@Composable +private fun DataPreviewTable( + columns: List, + rows: List>, + maxDisplayRows: Int = 10 +) { + val displayRows = rows.take(maxDisplayRows) + val borderColor = MaterialTheme.colorScheme.outlineVariant + + Column( + modifier = Modifier + .fillMaxWidth() + .horizontalScroll(rememberScrollState()) + ) { + Surface( + shape = RoundedCornerShape(6.dp), + color = MaterialTheme.colorScheme.surface, + modifier = Modifier.border( + width = 1.dp, + color = borderColor, + shape = RoundedCornerShape(6.dp) + ) + ) { + Column { + // Header row + Row( + modifier = Modifier + .background(MaterialTheme.colorScheme.surfaceVariant) + .padding(1.dp) + ) { + columns.forEach { column -> + Box( + modifier = Modifier + .widthIn(min = 80.dp, max = 200.dp) + .border( + width = 0.5.dp, + color = borderColor + ) + .padding(horizontal = 8.dp, vertical = 6.dp) + ) { + Text( + text = column, + style = MaterialTheme.typography.labelSmall.copy( + fontFamily = FontFamily.Monospace + ), + fontWeight = FontWeight.Bold, + color = MaterialTheme.colorScheme.onSurfaceVariant, + maxLines = 1, + overflow = TextOverflow.Ellipsis + ) + } + } + } + + // Data rows + displayRows.forEachIndexed { rowIndex, row -> + Row( + modifier = Modifier + .background( + if (rowIndex % 2 == 0) + MaterialTheme.colorScheme.surface + else + MaterialTheme.colorScheme.surfaceVariant.copy(alpha = 0.3f) + ) + .padding(1.dp) + ) { + row.forEachIndexed { colIndex, cell -> + Box( + modifier = Modifier + .widthIn(min = 80.dp, max = 200.dp) + .border( + width = 0.5.dp, + color = borderColor.copy(alpha = 0.5f) + ) + .padding(horizontal = 8.dp, vertical = 4.dp) + ) { + Text( + text = cell.ifEmpty { "NULL" }, + style = MaterialTheme.typography.bodySmall.copy( + fontFamily = FontFamily.Monospace, + fontSize = 11.sp + ), + color = if (cell.isEmpty()) + MaterialTheme.colorScheme.onSurfaceVariant.copy(alpha = 0.5f) + else + MaterialTheme.colorScheme.onSurface, + maxLines = 2, + overflow = TextOverflow.Ellipsis + ) + } + } + } + } + } + } + } +} From 755db21d6ea122e10fed7106ad726e588ba916da Mon Sep 17 00:00:00 2001 From: Phodal Huang Date: Wed, 10 Dec 2025 17:09:01 +0800 Subject: [PATCH 28/34] feat(chatdb): add multi-database chat agent support #508 Introduce MultiDatabaseChatDBAgent and executor, update UI and models to handle multiple database sources in chatdb agent workflow. --- .../agent/chatdb/MultiDatabaseChatDBAgent.kt | 191 ++++++++ .../chatdb/MultiDatabaseChatDBExecutor.kt | 417 ++++++++++++++++++ .../agent/render/CodingAgentRenderer.kt | 9 + .../unitmesh/agent/render/RendererModels.kt | 10 + .../ui/compose/agent/AgentMessageList.kt | 37 ++ .../ui/compose/agent/ComposeRenderer.kt | 14 + .../ui/compose/agent/chatdb/ChatDBPage.kt | 20 +- .../compose/agent/chatdb/ChatDBViewModel.kt | 227 ++++++++-- .../agent/chatdb/components/ChatDBChatPane.kt | 91 +++- .../agent/chatdb/components/ChatDBStepCard.kt | 258 ++++++++++- .../chatdb/components/DataSourcePanel.kt | 265 +++++++---- .../agent/chatdb/model/DataSourceModels.kt | 41 +- 12 files changed, 1411 insertions(+), 169 deletions(-) create mode 100644 mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/chatdb/MultiDatabaseChatDBAgent.kt create mode 100644 mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/chatdb/MultiDatabaseChatDBExecutor.kt diff --git a/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/chatdb/MultiDatabaseChatDBAgent.kt b/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/chatdb/MultiDatabaseChatDBAgent.kt new file mode 100644 index 0000000000..76c4abe7f2 --- /dev/null +++ b/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/chatdb/MultiDatabaseChatDBAgent.kt @@ -0,0 +1,191 @@ +package cc.unitmesh.agent.chatdb + +import cc.unitmesh.agent.config.McpToolConfigService +import cc.unitmesh.agent.core.MainAgent +import cc.unitmesh.agent.database.DatabaseConfig +import cc.unitmesh.agent.database.DatabaseConnection +import cc.unitmesh.agent.database.DatabaseSchema +import cc.unitmesh.agent.database.createDatabaseConnection +import cc.unitmesh.agent.logging.getLogger +import cc.unitmesh.agent.model.AgentDefinition +import cc.unitmesh.agent.model.PromptConfig +import cc.unitmesh.agent.model.RunConfig +import cc.unitmesh.agent.orchestrator.ToolOrchestrator +import cc.unitmesh.agent.policy.DefaultPolicyEngine +import cc.unitmesh.agent.render.CodingAgentRenderer +import cc.unitmesh.agent.render.DefaultCodingAgentRenderer +import cc.unitmesh.agent.tool.shell.DefaultShellExecutor +import cc.unitmesh.agent.tool.shell.ShellExecutor +import cc.unitmesh.agent.tool.ToolResult +import cc.unitmesh.agent.tool.filesystem.DefaultToolFileSystem +import cc.unitmesh.agent.tool.filesystem.ToolFileSystem +import cc.unitmesh.agent.tool.registry.ToolRegistry +import cc.unitmesh.llm.KoogLLMService +import cc.unitmesh.llm.ModelConfig + +/** + * Multi-Database ChatDB Agent - Text2SQL Agent supporting multiple database connections + * + * This agent converts natural language queries to SQL across multiple databases. + * It merges schemas from all connected databases and lets the LLM decide which + * database(s) to query based on the user's question. + * + * Features: + * - Multi-Database Schema Linking: Merges schemas from all databases with database prefixes + * - Intelligent Database Selection: LLM determines which database to query + * - Parallel Execution: Can execute queries on multiple databases simultaneously + * - Unified Results: Combines results from multiple databases + */ +class MultiDatabaseChatDBAgent( + private val projectPath: String, + private val llmService: KoogLLMService, + private val databaseConfigs: Map, + override val maxIterations: Int = 10, + private val renderer: CodingAgentRenderer = DefaultCodingAgentRenderer(), + private val fileSystem: ToolFileSystem? = null, + private val shellExecutor: ShellExecutor? = null, + private val mcpToolConfigService: McpToolConfigService, + private val enableLLMStreaming: Boolean = true +) : MainAgent( + AgentDefinition( + name = "MultiDatabaseChatDBAgent", + displayName = "Multi-Database ChatDB Agent", + description = "Text2SQL Agent that queries across multiple databases with intelligent database selection", + promptConfig = PromptConfig( + systemPrompt = MULTI_DB_SYSTEM_PROMPT + ), + modelConfig = ModelConfig.default(), + runConfig = RunConfig(maxTurns = 10, maxTimeMinutes = 5) + ) +) { + private val logger = getLogger("MultiDatabaseChatDBAgent") + + private val actualFileSystem = fileSystem ?: DefaultToolFileSystem(projectPath = projectPath) + + private val toolRegistry = ToolRegistry( + fileSystem = actualFileSystem, + shellExecutor = shellExecutor ?: DefaultShellExecutor(), + configService = mcpToolConfigService, + llmService = llmService + ) + + private val policyEngine = DefaultPolicyEngine() + + private val toolOrchestrator = ToolOrchestrator( + registry = toolRegistry, + policyEngine = policyEngine, + renderer = renderer, + mcpConfigService = mcpToolConfigService + ) + + // Database connections keyed by database name/id + private val databaseConnections: MutableMap = mutableMapOf() + + private val executor: MultiDatabaseChatDBExecutor by lazy { + // Create connections for all configured databases + databaseConfigs.forEach { (id, config) -> + if (!databaseConnections.containsKey(id)) { + databaseConnections[id] = createDatabaseConnection(config) + } + } + + MultiDatabaseChatDBExecutor( + projectPath = projectPath, + llmService = llmService, + toolOrchestrator = toolOrchestrator, + renderer = renderer, + databaseConnections = databaseConnections, + databaseConfigs = databaseConfigs, + maxIterations = maxIterations, + enableLLMStreaming = enableLLMStreaming + ) + } + + override fun validateInput(input: Map): ChatDBTask { + val query = input["query"] as? String + ?: throw IllegalArgumentException("Missing required parameter: query") + + return ChatDBTask( + query = query, + additionalContext = input["additionalContext"] as? String ?: "", + maxRows = (input["maxRows"] as? Number)?.toInt() ?: 100, + generateVisualization = input["generateVisualization"] as? Boolean ?: true + ) + } + + override suspend fun execute( + input: ChatDBTask, + onProgress: (String) -> Unit + ): ToolResult.AgentResult { + logger.info { "Starting Multi-Database ChatDB Agent for query: ${input.query}" } + logger.info { "Connected databases: ${databaseConfigs.keys}" } + + val systemPrompt = buildSystemPrompt() + val result = executor.execute(input, systemPrompt, onProgress) + + return ToolResult.AgentResult( + success = result.success, + content = result.message, + metadata = mapOf( + "generatedSql" to (result.generatedSql ?: ""), + "rowCount" to (result.queryResult?.rowCount?.toString() ?: "0"), + "revisionAttempts" to result.revisionAttempts.toString(), + "hasVisualization" to (result.plotDslCode != null).toString(), + "targetDatabases" to (result.targetDatabases?.joinToString(",") ?: "") + ) + ) + } + + private fun buildSystemPrompt(): String { + return MULTI_DB_SYSTEM_PROMPT + } + + override fun formatOutput(output: ToolResult.AgentResult): String { + return output.content + } + + override fun getParameterClass(): String = "ChatDBTask" + + /** + * Close all database connections when done + */ + suspend fun close() { + databaseConnections.values.forEach { it.close() } + databaseConnections.clear() + } + + companion object { + const val MULTI_DB_SYSTEM_PROMPT = """You are an expert SQL developer working with MULTIPLE databases. + +IMPORTANT: You are connected to multiple databases. Each table in the schema is prefixed with its database name. +Format: [database_name].table_name + +CRITICAL RULES: +1. ONLY use table names provided in the schema - NEVER invent or guess table names +2. ONLY use column names provided in the schema - NEVER invent or guess column names +3. When generating SQL, use the table name WITHOUT the database prefix (the system will route to the correct database) +4. If the user's question relates to tables in multiple databases, generate SEPARATE SQL queries for each database +5. Only generate SELECT queries (read-only operations) +6. Always add LIMIT clause to prevent large result sets + +OUTPUT FORMAT: +- For single database query: +```sql +-- database: +SELECT id, name FROM users WHERE status = 'active' LIMIT 100; +``` + +- For multiple database queries: +```sql +-- database: db1 +SELECT * FROM users LIMIT 100; +``` +```sql +-- database: db2 +SELECT * FROM customers LIMIT 100; +``` + +The "-- database: " comment is REQUIRED to specify which database to execute the query on.""" + } +} + diff --git a/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/chatdb/MultiDatabaseChatDBExecutor.kt b/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/chatdb/MultiDatabaseChatDBExecutor.kt new file mode 100644 index 0000000000..0200813425 --- /dev/null +++ b/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/chatdb/MultiDatabaseChatDBExecutor.kt @@ -0,0 +1,417 @@ +package cc.unitmesh.agent.chatdb + +import cc.unitmesh.agent.conversation.ConversationManager +import cc.unitmesh.agent.database.* +import cc.unitmesh.agent.executor.BaseAgentExecutor +import cc.unitmesh.agent.logging.getLogger +import cc.unitmesh.agent.orchestrator.ToolOrchestrator +import cc.unitmesh.agent.render.ChatDBStepStatus +import cc.unitmesh.agent.render.ChatDBStepType +import cc.unitmesh.agent.render.CodingAgentRenderer +import cc.unitmesh.agent.subagent.SqlValidator +import cc.unitmesh.agent.subagent.SqlReviseAgent +import cc.unitmesh.devins.parser.CodeFence +import cc.unitmesh.llm.KoogLLMService + +/** + * Multi-Database ChatDB Executor - Executes Text2SQL across multiple databases + * + * Key differences from single-database executor: + * 1. Merges schemas from all databases with database name prefixes + * 2. Parses SQL comments to determine target database + * 3. Can execute queries on multiple databases in parallel + * 4. Combines results from multiple databases + */ +class MultiDatabaseChatDBExecutor( + projectPath: String, + llmService: KoogLLMService, + toolOrchestrator: ToolOrchestrator, + renderer: CodingAgentRenderer, + private val databaseConnections: Map, + private val databaseConfigs: Map, + maxIterations: Int = 10, + enableLLMStreaming: Boolean = true +) : BaseAgentExecutor( + projectPath = projectPath, + llmService = llmService, + toolOrchestrator = toolOrchestrator, + renderer = renderer, + maxIterations = maxIterations, + enableLLMStreaming = enableLLMStreaming +) { + private val logger = getLogger("MultiDatabaseChatDBExecutor") + private val sqlValidator = SqlValidator() + private val sqlReviseAgent = SqlReviseAgent(llmService, sqlValidator) + private val maxRevisionAttempts = 3 + private val maxExecutionRetries = 3 + + // Cache for merged schema + private var mergedSchema: MergedDatabaseSchema? = null + + suspend fun execute( + task: ChatDBTask, + systemPrompt: String, + onProgress: (String) -> Unit = {} + ): MultiDatabaseChatDBResult { + currentIteration = 0 + conversationManager = ConversationManager(llmService, systemPrompt) + + val errors = mutableListOf() + var generatedSql: String? = null + val queryResults = mutableMapOf() + var revisionAttempts = 0 + val targetDatabases = mutableListOf() + + try { + // Step 1: Fetch and merge schemas from all databases + renderer.renderChatDBStep( + stepType = ChatDBStepType.FETCH_SCHEMA, + status = ChatDBStepStatus.IN_PROGRESS, + title = "Fetching schemas from ${databaseConnections.size} databases..." + ) + onProgress("๐Ÿ“Š Fetching schemas from ${databaseConnections.size} databases...") + + val merged = fetchAndMergeSchemas() + mergedSchema = merged + + renderer.renderChatDBStep( + stepType = ChatDBStepType.FETCH_SCHEMA, + status = ChatDBStepStatus.SUCCESS, + title = "Schemas fetched from ${merged.databases.size} databases", + details = mapOf( + "databases" to merged.databases.map { (name, schema) -> + mapOf( + "name" to name, + "displayName" to (databaseConfigs[name]?.databaseName ?: name), + "tableCount" to schema.tables.size, + "tables" to schema.tables.map { it.name } + ) + }, + "totalTables" to merged.totalTableCount + ) + ) + + // Step 2: Schema Linking with merged schema + renderer.renderChatDBStep( + stepType = ChatDBStepType.SCHEMA_LINKING, + status = ChatDBStepStatus.IN_PROGRESS, + title = "Analyzing query across ${merged.databases.size} databases..." + ) + onProgress("๐Ÿ”— Analyzing query across databases...") + + val schemaContext = buildMultiDatabaseSchemaContext(merged, task.query) + + renderer.renderChatDBStep( + stepType = ChatDBStepType.SCHEMA_LINKING, + status = ChatDBStepStatus.SUCCESS, + title = "Schema analysis complete", + details = mapOf( + "databasesAnalyzed" to merged.databases.keys.toList(), + "schemaContext" to schemaContext.take(500) + "..." + ) + ) + + // Step 3: Generate SQL with multi-database context + renderer.renderChatDBStep( + stepType = ChatDBStepType.GENERATE_SQL, + status = ChatDBStepStatus.IN_PROGRESS, + title = "Generating SQL..." + ) + onProgress("๐Ÿค– Generating SQL...") + + val sqlPrompt = buildMultiDatabaseSqlPrompt(task.query, schemaContext, task.maxRows) + val sqlResponse = getLLMResponse(sqlPrompt, compileDevIns = false) { chunk -> + onProgress(chunk) + } + + // Parse SQL blocks with database targets + val sqlBlocks = parseSqlBlocksWithTargets(sqlResponse) + + if (sqlBlocks.isEmpty()) { + throw DatabaseException("No valid SQL generated") + } + + generatedSql = sqlBlocks.map { "${it.database}: ${it.sql}" }.joinToString("\n") + targetDatabases.addAll(sqlBlocks.map { it.database }.distinct()) + + renderer.renderChatDBStep( + stepType = ChatDBStepType.GENERATE_SQL, + status = ChatDBStepStatus.SUCCESS, + title = "SQL generated for ${targetDatabases.size} database(s)", + details = mapOf( + "targetDatabases" to targetDatabases, + "sqlBlocks" to sqlBlocks.map { mapOf("database" to it.database, "sql" to it.sql) } + ) + ) + + // Step 4: Execute SQL on target databases + for (sqlBlock in sqlBlocks) { + val dbName = sqlBlock.database + val sql = sqlBlock.sql + val connection = databaseConnections[dbName] + + if (connection == null) { + errors.add("Database '$dbName' not connected") + continue + } + + renderer.renderChatDBStep( + stepType = ChatDBStepType.EXECUTE_SQL, + status = ChatDBStepStatus.IN_PROGRESS, + title = "Executing on $dbName...", + details = mapOf("database" to dbName, "sql" to sql) + ) + onProgress("โšก Executing SQL on $dbName...") + + try { + val result = connection.executeQuery(sql) + queryResults[dbName] = result + + renderer.renderChatDBStep( + stepType = ChatDBStepType.EXECUTE_SQL, + status = ChatDBStepStatus.SUCCESS, + title = "Query executed on $dbName", + details = mapOf( + "database" to dbName, + "sql" to sql, + "rowCount" to result.rowCount, + "columns" to result.columns, + "previewRows" to result.rows.take(5) + ) + ) + } catch (e: Exception) { + errors.add("[$dbName] ${e.message}") + renderer.renderChatDBStep( + stepType = ChatDBStepType.EXECUTE_SQL, + status = ChatDBStepStatus.ERROR, + title = "Query failed on $dbName", + details = mapOf("database" to dbName, "sql" to sql, "error" to (e.message ?: "Unknown error")) + ) + } + } + + // Step 5: Final result + val success = queryResults.isNotEmpty() + val combinedResult = combineResults(queryResults) + + renderer.renderChatDBStep( + stepType = ChatDBStepType.FINAL_RESULT, + status = if (success) ChatDBStepStatus.SUCCESS else ChatDBStepStatus.ERROR, + title = if (success) "Query completed on ${queryResults.size} database(s)" else "Query failed", + details = mapOf( + "databases" to queryResults.keys.toList(), + "totalRows" to combinedResult.rowCount, + "columns" to combinedResult.columns, + "previewRows" to combinedResult.rows.take(10), + "errors" to errors + ) + ) + + return MultiDatabaseChatDBResult( + success = success, + message = if (success) "Query executed successfully on ${queryResults.size} database(s)" else errors.joinToString("\n"), + generatedSql = generatedSql, + queryResult = combinedResult, + queryResultsByDatabase = queryResults, + targetDatabases = targetDatabases, + revisionAttempts = revisionAttempts, + errors = errors + ) + + } catch (e: Exception) { + logger.error(e) { "Multi-database query failed: ${e.message}" } + renderer.renderChatDBStep( + stepType = ChatDBStepType.FINAL_RESULT, + status = ChatDBStepStatus.ERROR, + title = "Query failed", + details = mapOf("error" to (e.message ?: "Unknown error")) + ) + return MultiDatabaseChatDBResult( + success = false, + message = "Error: ${e.message}", + errors = listOf(e.message ?: "Unknown error") + ) + } + } + + /** + * Fetch schemas from all databases and merge them + */ + private suspend fun fetchAndMergeSchemas(): MergedDatabaseSchema { + val schemas = mutableMapOf() + + for ((dbId, connection) in databaseConnections) { + try { + val schema = connection.getSchema() + schemas[dbId] = schema + logger.info { "Fetched schema for $dbId: ${schema.tables.size} tables" } + } catch (e: Exception) { + logger.error(e) { "Failed to fetch schema for $dbId" } + } + } + + return MergedDatabaseSchema(schemas) + } + + /** + * Build schema context for multi-database prompt + */ + private fun buildMultiDatabaseSchemaContext(merged: MergedDatabaseSchema, query: String): String { + val sb = StringBuilder() + sb.append("=== AVAILABLE DATABASES AND TABLES ===\n\n") + + for ((dbId, schema) in merged.databases) { + val displayName = databaseConfigs[dbId]?.databaseName ?: dbId + sb.append("DATABASE: $dbId ($displayName)\n") + sb.append("-".repeat(40)).append("\n") + + for (table in schema.tables) { + sb.append(" Table: ${table.name}\n") + if (table.comment != null) { + sb.append(" Comment: ${table.comment}\n") + } + sb.append(" Columns:\n") + for (col in table.columns) { + val flags = mutableListOf() + if (col.isPrimaryKey) flags.add("PK") + if (col.isForeignKey) flags.add("FK") + if (!col.nullable) flags.add("NOT NULL") + val flagStr = if (flags.isNotEmpty()) " [${flags.joinToString(", ")}]" else "" + sb.append(" - ${col.name}: ${col.type}$flagStr\n") + } + sb.append("\n") + } + sb.append("\n") + } + + return sb.toString() + } + + /** + * Build SQL generation prompt for multi-database context + */ + private fun buildMultiDatabaseSqlPrompt(query: String, schemaContext: String, maxRows: Int): String { + return """ +$schemaContext + +USER QUERY: $query + +INSTRUCTIONS: +1. Analyze which database(s) contain the relevant tables for this query +2. Generate SQL for the appropriate database(s) +3. Each SQL block MUST start with a comment specifying the target database: -- database: +4. Use LIMIT $maxRows to restrict results +5. Only generate SELECT queries + +Generate the SQL: +""".trimIndent() + } + + /** + * Parse SQL response to extract SQL blocks with their target databases + */ + private fun parseSqlBlocksWithTargets(response: String): List { + val blocks = mutableListOf() + val codeFences = CodeFence.parseAll(response) + + for (fence in codeFences) { + if (fence.languageId.lowercase() == "sql") { + val sql = fence.text.trim() + val database = extractDatabaseFromSql(sql) + val cleanSql = removeDatabaseComment(sql) + + if (database != null && cleanSql.isNotBlank()) { + blocks.add(SqlBlock(database, cleanSql)) + } else if (databaseConnections.size == 1) { + // If only one database, use it as default + val defaultDb = databaseConnections.keys.first() + blocks.add(SqlBlock(defaultDb, cleanSql.ifBlank { sql })) + } + } + } + + return blocks + } + + /** + * Extract database name from SQL comment + */ + private fun extractDatabaseFromSql(sql: String): String? { + val regex = Regex("--\\s*database:\\s*(\\S+)", RegexOption.IGNORE_CASE) + val match = regex.find(sql) + return match?.groupValues?.get(1) + } + + /** + * Remove database comment from SQL + */ + private fun removeDatabaseComment(sql: String): String { + return sql.replace(Regex("--\\s*database:\\s*\\S+\\s*\n?", RegexOption.IGNORE_CASE), "").trim() + } + + /** + * Combine results from multiple databases + */ + private fun combineResults(results: Map): QueryResult { + if (results.isEmpty()) { + return QueryResult(emptyList(), emptyList(), 0) + } + + if (results.size == 1) { + return results.values.first() + } + + // For multiple results, add a "database" column and combine + val allColumns = mutableListOf("_database") + val allRows = mutableListOf>() + + for ((dbName, result) in results) { + if (allColumns.size == 1) { + allColumns.addAll(result.columns) + } + for (row in result.rows) { + allRows.add(listOf(dbName) + row) + } + } + + return QueryResult(allColumns, allRows, allRows.size) + } +} + +/** + * SQL block with target database + */ +data class SqlBlock( + val database: String, + val sql: String +) + +/** + * Merged schema from multiple databases + */ +data class MergedDatabaseSchema( + val databases: Map +) { + val totalTableCount: Int + get() = databases.values.sumOf { it.tables.size } + + fun getTablesForDatabase(dbId: String): List { + return databases[dbId]?.tables ?: emptyList() + } +} + +/** + * Result from multi-database query + */ +data class MultiDatabaseChatDBResult( + val success: Boolean, + val message: String, + val generatedSql: String? = null, + val queryResult: QueryResult? = null, + val queryResultsByDatabase: Map = emptyMap(), + val targetDatabases: List? = null, + val plotDslCode: String? = null, + val revisionAttempts: Int = 0, + val errors: List = emptyList() +) + diff --git a/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/render/CodingAgentRenderer.kt b/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/render/CodingAgentRenderer.kt index 813f69e577..c539f64588 100644 --- a/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/render/CodingAgentRenderer.kt +++ b/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/render/CodingAgentRenderer.kt @@ -44,6 +44,15 @@ interface CodingAgentRenderer { fun renderTaskComplete(executionTimeMs: Long = 0L, toolsUsedCount: Int = 0) fun renderFinalResult(success: Boolean, message: String, iterations: Int) fun renderError(message: String) + + /** + * Render an informational message (non-error, non-warning) + * Used for status updates, progress information, etc. + */ + fun renderInfo(message: String) { + // Default: no-op, renderers can override to display info messages + } + fun renderRepeatWarning(toolName: String, count: Int) fun renderRecoveryAdvice(recoveryAdvice: String) diff --git a/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/render/RendererModels.kt b/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/render/RendererModels.kt index afd02a15e3..7715ba4054 100644 --- a/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/render/RendererModels.kt +++ b/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/render/RendererModels.kt @@ -138,6 +138,16 @@ sealed class TimelineItem( override val id: String = generateId() ) : TimelineItem(timestamp, id) + /** + * Info item for displaying informational messages (non-error, non-warning). + * Used for status updates, progress information, database context, etc. + */ + data class InfoItem( + val message: String, + override val timestamp: Long = Platform.getCurrentTimestamp(), + override val id: String = generateId() + ) : TimelineItem(timestamp, id) + /** * Task completion item. */ diff --git a/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/agent/AgentMessageList.kt b/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/agent/AgentMessageList.kt index e1d872d857..7afbccf3c0 100644 --- a/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/agent/AgentMessageList.kt +++ b/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/agent/AgentMessageList.kt @@ -242,6 +242,43 @@ fun RenderMessageItem( step = timelineItem ) } + + is TimelineItem.InfoItem -> { + InfoMessageItem(message = timelineItem.message) + } + } +} + +/** + * Simple info message item for displaying informational messages + */ +@Composable +private fun InfoMessageItem(message: String) { + Card( + modifier = Modifier + .fillMaxWidth() + .padding(horizontal = 16.dp, vertical = 4.dp), + colors = CardDefaults.cardColors( + containerColor = MaterialTheme.colorScheme.surfaceVariant.copy(alpha = 0.5f) + ) + ) { + Row( + modifier = Modifier.padding(12.dp), + horizontalArrangement = Arrangement.spacedBy(8.dp), + verticalAlignment = Alignment.CenterVertically + ) { + Icon( + AutoDevComposeIcons.Info, + contentDescription = null, + modifier = Modifier.size(16.dp), + tint = MaterialTheme.colorScheme.primary + ) + Text( + text = message, + style = MaterialTheme.typography.bodySmall, + color = MaterialTheme.colorScheme.onSurfaceVariant + ) + } } } diff --git a/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/agent/ComposeRenderer.kt b/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/agent/ComposeRenderer.kt index e4fe9ce970..8069d38056 100644 --- a/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/agent/ComposeRenderer.kt +++ b/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/agent/ComposeRenderer.kt @@ -527,6 +527,10 @@ class ComposeRenderer : BaseRenderer() { _isProcessing = false } + override fun renderInfo(message: String) { + _timeline.add(TimelineItem.InfoItem(message = message)) + } + override fun renderRepeatWarning( toolName: String, count: Int @@ -924,6 +928,11 @@ class ComposeRenderer : BaseRenderer() { // ChatDB steps are not persisted (they're runtime-only for UI display) null } + + is TimelineItem.InfoItem -> { + // Info items are not persisted (they're runtime-only for UI display) + null + } } } @@ -1154,6 +1163,11 @@ class ComposeRenderer : BaseRenderer() { // ChatDB steps are not persisted as messages null } + + is TimelineItem.InfoItem -> { + // Info items are not persisted as messages + null + } } } } diff --git a/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/agent/chatdb/ChatDBPage.kt b/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/agent/chatdb/ChatDBPage.kt index a69f734154..aa30e71ef6 100644 --- a/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/agent/chatdb/ChatDBPage.kt +++ b/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/agent/chatdb/ChatDBPage.kt @@ -52,16 +52,16 @@ fun ChatDBPage( maxRatio = 0.4f, saveKey = "chatdb_split_ratio", first = { - // Left panel - Data source management + // Left panel - Data source management (multi-selection mode) DataSourcePanel( dataSources = state.filteredDataSources, - selectedDataSourceId = state.selectedDataSourceId, - connectionStatus = state.connectionStatus, + selectedDataSourceIds = state.selectedDataSourceIds, + connectionStatuses = state.connectionStatuses, filterQuery = state.filterQuery, onFilterChange = viewModel::setFilterQuery, - onSelectDataSource = { id -> - viewModel.selectDataSource(id) - // When selecting a different data source, show its config in the pane + onToggleDataSource = { id -> + viewModel.toggleDataSource(id) + // When toggling a data source, optionally show its config in the pane val selected = state.dataSources.find { it.id == id } if (selected != null && state.isConfigPaneOpen) { viewModel.openConfigPane(selected) @@ -70,8 +70,10 @@ fun ChatDBPage( onAddClick = { viewModel.openConfigPane(null) }, onEditClick = { config -> viewModel.openConfigPane(config) }, onDeleteClick = viewModel::deleteDataSource, - onConnectClick = viewModel::connect, - onDisconnectClick = viewModel::disconnect, + onConnectClick = viewModel::connectDataSource, + onDisconnectClick = viewModel::disconnectDataSource, + onConnectAllClick = viewModel::connectAll, + onDisconnectAllClick = viewModel::disconnectAll, modifier = Modifier.fillMaxSize() ) }, @@ -89,6 +91,8 @@ fun ChatDBPage( ChatDBChatPane( renderer = viewModel.renderer, connectionStatus = state.connectionStatus, + connectedCount = state.connectedCount, + selectedCount = state.selectedCount, schema = viewModel.getSchema(), isGenerating = viewModel.isGenerating, onSendMessage = viewModel::sendMessage, diff --git a/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/agent/chatdb/ChatDBViewModel.kt b/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/agent/chatdb/ChatDBViewModel.kt index 68a364bf17..809316bf5f 100644 --- a/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/agent/chatdb/ChatDBViewModel.kt +++ b/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/agent/chatdb/ChatDBViewModel.kt @@ -5,8 +5,10 @@ import androidx.compose.runtime.mutableStateOf import androidx.compose.runtime.setValue import cc.unitmesh.agent.chatdb.ChatDBAgent import cc.unitmesh.agent.chatdb.ChatDBTask +import cc.unitmesh.agent.chatdb.MultiDatabaseChatDBAgent import cc.unitmesh.agent.config.McpToolConfigService import cc.unitmesh.agent.config.ToolConfigFile +import cc.unitmesh.agent.database.DatabaseConfig import cc.unitmesh.agent.database.DatabaseConnection import cc.unitmesh.agent.database.DatabaseSchema import cc.unitmesh.agent.database.createDatabaseConnection @@ -38,9 +40,9 @@ class ChatDBViewModel( private var llmService: KoogLLMService? = null private var currentExecutionJob: Job? = null - // Database connection - private var currentConnection: DatabaseConnection? = null - private var currentSchema: DatabaseSchema? = null + // Database connections (multi-datasource support) + private val connections: MutableMap = mutableMapOf() + private val schemas: MutableMap = mutableMapOf() // Data source repository for persistence private val dataSourceRepository: DataSourceRepository by lazy { @@ -81,12 +83,13 @@ class ChatDBViewModel( scope.launch { try { val dataSources = dataSourceRepository.getAll() - val defaultDataSource = dataSourceRepository.getDefault() + // Multi-datasource: select all data sources by default + val allIds = dataSources.map { it.id }.toSet() state = state.copy( dataSources = dataSources, - selectedDataSourceId = defaultDataSource?.id + selectedDataSourceIds = allIds ) - println("[ChatDB] Loaded ${dataSources.size} data sources") + println("[ChatDB] Loaded ${dataSources.size} data sources, all selected by default") } catch (e: Exception) { println("[ChatDB] Failed to load data sources: ${e.message}") state = state.copy(dataSources = emptyList()) @@ -122,13 +125,14 @@ class ChatDBViewModel( } fun deleteDataSource(id: String) { + // Disconnect this specific data source if connected + disconnectDataSource(id) + state = state.copy( dataSources = state.dataSources.filter { it.id != id }, - selectedDataSourceId = if (state.selectedDataSourceId == id) null else state.selectedDataSourceId + selectedDataSourceIds = state.selectedDataSourceIds - id, + connectionStatuses = state.connectionStatuses - id ) - if (state.selectedDataSourceId == id) { - disconnect() - } deleteDataSourceFromRepository(id) } @@ -168,11 +172,44 @@ class ChatDBViewModel( } } + /** + * Toggle selection of a data source (multi-selection mode) + */ + fun toggleDataSource(id: String) { + val newSelectedIds = if (id in state.selectedDataSourceIds) { + // Deselect: disconnect if connected + disconnectDataSource(id) + state.selectedDataSourceIds - id + } else { + // Select: add to selection + state.selectedDataSourceIds + id + } + state = state.copy(selectedDataSourceIds = newSelectedIds) + } + + /** + * Select a data source (for backward compatibility, also used for single-click selection) + */ fun selectDataSource(id: String) { - if (state.selectedDataSourceId == id) return + if (id !in state.selectedDataSourceIds) { + state = state.copy(selectedDataSourceIds = state.selectedDataSourceIds + id) + } + } + + /** + * Select all data sources + */ + fun selectAllDataSources() { + val allIds = state.dataSources.map { it.id }.toSet() + state = state.copy(selectedDataSourceIds = allIds) + } - disconnect() - state = state.copy(selectedDataSourceId = id) + /** + * Deselect all data sources + */ + fun deselectAllDataSources() { + disconnectAll() + state = state.copy(selectedDataSourceIds = emptySet()) } fun setFilterQuery(query: String) { @@ -215,7 +252,7 @@ class ChatDBViewModel( dataSources = state.dataSources + savedConfig, isConfigPaneOpen = false, configuringDataSource = null, - selectedDataSourceId = savedConfig.id + selectedDataSourceIds = state.selectedDataSourceIds + savedConfig.id ) } else { state.copy( @@ -241,58 +278,122 @@ class ChatDBViewModel( dataSources = state.dataSources + savedConfig, isConfigPaneOpen = false, configuringDataSource = null, - selectedDataSourceId = savedConfig.id + selectedDataSourceIds = state.selectedDataSourceIds + savedConfig.id ) } else { state.copy( dataSources = state.dataSources.map { if (it.id == savedConfig.id) savedConfig else it }, isConfigPaneOpen = false, configuringDataSource = null, - selectedDataSourceId = savedConfig.id + selectedDataSourceIds = state.selectedDataSourceIds + savedConfig.id ) } saveDataSource(savedConfig) - // Trigger connection - connect() + // Trigger connection for this specific data source + connectDataSource(savedConfig.id) } - fun connect() { - val dataSource = state.selectedDataSource ?: return + /** + * Connect all selected data sources + */ + fun connectAll() { + state.selectedDataSources.forEach { dataSource -> + connectDataSource(dataSource.id) + } + } + + /** + * Connect a specific data source + */ + fun connectDataSource(id: String) { + val dataSource = state.dataSources.find { it.id == id } ?: return scope.launch { - state = state.copy(connectionStatus = ConnectionStatus.Connecting) + // Update status to connecting + state = state.copy( + connectionStatuses = state.connectionStatuses + (id to ConnectionStatus.Connecting), + connectionStatus = ConnectionStatus.Connecting + ) try { val connection = createDatabaseConnection(dataSource.toDatabaseConfig()) if (connection.isConnected()) { - currentConnection = connection - currentSchema = connection.getSchema() - state = state.copy(connectionStatus = ConnectionStatus.Connected) + connections[id] = connection + schemas[id] = connection.getSchema() + state = state.copy( + connectionStatuses = state.connectionStatuses + (id to ConnectionStatus.Connected), + connectionStatus = if (state.hasAnyConnection || true) ConnectionStatus.Connected else state.connectionStatus + ) _notificationEvent.emit("Connected" to "Successfully connected to ${dataSource.name}") } else { - state = state.copy(connectionStatus = ConnectionStatus.Error("Failed to connect")) + state = state.copy( + connectionStatuses = state.connectionStatuses + (id to ConnectionStatus.Error("Failed to connect")) + ) } } catch (e: Exception) { - state = state.copy(connectionStatus = ConnectionStatus.Error(e.message ?: "Unknown error")) - _notificationEvent.emit("Connection Failed" to (e.message ?: "Unknown error")) + state = state.copy( + connectionStatuses = state.connectionStatuses + (id to ConnectionStatus.Error(e.message ?: "Unknown error")) + ) + _notificationEvent.emit("Connection Failed" to "${dataSource.name}: ${e.message ?: "Unknown error"}") } } } - fun disconnect() { + /** + * Disconnect a specific data source + */ + fun disconnectDataSource(id: String) { scope.launch { try { - currentConnection?.close() + connections[id]?.close() } catch (e: Exception) { - println("[ChatDB] Error closing connection: ${e.message}") + println("[ChatDB] Error closing connection for $id: ${e.message}") } - currentConnection = null - currentSchema = null - state = state.copy(connectionStatus = ConnectionStatus.Disconnected) + connections.remove(id) + schemas.remove(id) + state = state.copy( + connectionStatuses = state.connectionStatuses + (id to ConnectionStatus.Disconnected), + connectionStatus = if (connections.isEmpty()) ConnectionStatus.Disconnected else ConnectionStatus.Connected + ) } } + /** + * Disconnect all data sources + */ + fun disconnectAll() { + scope.launch { + connections.forEach { (id, connection) -> + try { + connection.close() + } catch (e: Exception) { + println("[ChatDB] Error closing connection for $id: ${e.message}") + } + } + connections.clear() + schemas.clear() + state = state.copy( + connectionStatuses = emptyMap(), + connectionStatus = ConnectionStatus.Disconnected + ) + } + } + + /** + * Legacy connect method - connects all selected data sources + */ + fun connect() { + connectAll() + } + + /** + * Legacy disconnect method - disconnects all data sources + */ + fun disconnect() { + disconnectAll() + } + fun sendMessage(text: String) { if (isGenerating || text.isBlank()) return @@ -308,21 +409,30 @@ class ChatDBViewModel( return@launch } - val dataSource = state.selectedDataSource - if (dataSource == null) { - renderer.renderError("No database selected. Please select a data source first.") + // Get connected data sources + val connectedDataSources = state.selectedDataSources.filter { ds -> + state.getConnectionStatus(ds.id) is ConnectionStatus.Connected + } + + if (connectedDataSources.isEmpty()) { + renderer.renderError("No database connected. Please connect to at least one data source first.") isGenerating = false return@launch } - val databaseConfig = dataSource.toDatabaseConfig() val projectPath = workspace?.rootPath ?: "." val mcpConfigService = McpToolConfigService(ToolConfigFile()) - val agent = ChatDBAgent( + // Build database configs map for multi-database agent + val databaseConfigs: Map = connectedDataSources.associate { ds -> + ds.id to ds.toDatabaseConfig() + } + + // Use MultiDatabaseChatDBAgent for unified schema linking and query execution + val agent = MultiDatabaseChatDBAgent( projectPath = projectPath, llmService = service, - databaseConfig = databaseConfig, + databaseConfigs = databaseConfigs, maxIterations = 10, renderer = renderer, mcpToolConfigService = mcpConfigService, @@ -335,15 +445,17 @@ class ChatDBViewModel( generateVisualization = false ) - // Execute the agent - it will render results to the timeline via the renderer - agent.execute(task) { progress -> - // Progress callback - can be used for UI updates - println("[ChatDB] Progress: $progress") + try { + // Execute the multi-database agent + // It will merge schemas, let LLM decide which database(s) to query, + // and execute SQL on the appropriate database(s) + agent.execute(task) { progress -> + println("[ChatDB] Progress: $progress") + } + } finally { + agent.close() } - // Close the agent connection - agent.close() - } catch (e: CancellationException) { renderer.forceStop() renderer.renderError("Generation cancelled") @@ -371,11 +483,30 @@ class ChatDBViewModel( renderer.clearMessages() } - fun getSchema(): DatabaseSchema? = currentSchema + /** + * Get combined schema from all connected data sources + */ + fun getSchema(): DatabaseSchema? { + if (schemas.isEmpty()) return null + + // Return the first schema for now, or combine them + // For multi-datasource, we might want to show all schemas + return schemas.values.firstOrNull() + } + + /** + * Get all schemas from connected data sources + */ + fun getAllSchemas(): Map = schemas.toMap() + + /** + * Get schema for a specific data source + */ + fun getSchemaForDataSource(id: String): DatabaseSchema? = schemas[id] fun dispose() { stopGeneration() - disconnect() + disconnectAll() scope.cancel() } } diff --git a/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/agent/chatdb/components/ChatDBChatPane.kt b/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/agent/chatdb/components/ChatDBChatPane.kt index 38248994c5..c22777203d 100644 --- a/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/agent/chatdb/components/ChatDBChatPane.kt +++ b/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/agent/chatdb/components/ChatDBChatPane.kt @@ -29,11 +29,15 @@ import cc.unitmesh.devins.ui.compose.icons.AutoDevComposeIcons /** * Chat pane for ChatDB - right side chat area for text-to-SQL + * + * Supports multi-datasource mode: shows connection status for multiple databases. */ @Composable fun ChatDBChatPane( renderer: ComposeRenderer, connectionStatus: ConnectionStatus, + connectedCount: Int = 0, + selectedCount: Int = 0, schema: DatabaseSchema?, isGenerating: Boolean, onSendMessage: (String) -> Unit, @@ -41,10 +45,14 @@ fun ChatDBChatPane( onNewSession: () -> Unit = {}, modifier: Modifier = Modifier ) { + val hasAnyConnection = connectedCount > 0 + Column(modifier = modifier.fillMaxSize()) { - // Connection status banner + // Connection status banner (multi-datasource aware) ConnectionStatusBanner( connectionStatus = connectionStatus, + connectedCount = connectedCount, + selectedCount = selectedCount, schema = schema, onNewSession = onNewSession, hasMessages = renderer.timeline.isNotEmpty() @@ -60,7 +68,8 @@ fun ChatDBChatPane( // Welcome message when no messages and not streaming if (renderer.timeline.isEmpty() && renderer.currentStreamingOutput.isEmpty() && !renderer.isProcessing) { WelcomeMessage( - isConnected = connectionStatus is ConnectionStatus.Connected, + isConnected = hasAnyConnection, + connectedCount = connectedCount, schema = schema, onQuickQuery = onSendMessage, modifier = Modifier.fillMaxSize() @@ -73,7 +82,8 @@ fun ChatDBChatPane( // Input area ChatInputArea( isGenerating = isGenerating, - isConnected = connectionStatus is ConnectionStatus.Connected, + isConnected = hasAnyConnection, + connectedCount = connectedCount, onSendMessage = onSendMessage, onStopGeneration = onStopGeneration ) @@ -83,12 +93,16 @@ fun ChatDBChatPane( @Composable private fun ConnectionStatusBanner( connectionStatus: ConnectionStatus, + connectedCount: Int, + selectedCount: Int, schema: DatabaseSchema?, onNewSession: () -> Unit, hasMessages: Boolean ) { - when (connectionStatus) { - is ConnectionStatus.Connected -> { + val hasAnyConnection = connectedCount > 0 + + when { + hasAnyConnection -> { Card( modifier = Modifier .fillMaxWidth() @@ -110,13 +124,22 @@ private fun ConnectionStatusBanner( ) Column(modifier = Modifier.weight(1f)) { Text( - text = "Connected", + text = if (connectedCount == 1) "Connected" else "$connectedCount databases connected", style = MaterialTheme.typography.labelLarge, color = MaterialTheme.colorScheme.primary ) - if (schema != null) { + val statusText = buildString { + if (schema != null) { + append("${schema.tables.size} tables available") + } + if (connectedCount < selectedCount) { + if (isNotEmpty()) append(" โ€ข ") + append("${selectedCount - connectedCount} not connected") + } + } + if (statusText.isNotEmpty()) { Text( - text = "${schema.tables.size} tables available", + text = statusText, style = MaterialTheme.typography.bodySmall, color = MaterialTheme.colorScheme.onSurfaceVariant ) @@ -139,7 +162,37 @@ private fun ConnectionStatusBanner( } } } - is ConnectionStatus.Disconnected -> { + selectedCount > 0 -> { + // Selected but not connected + Card( + modifier = Modifier + .fillMaxWidth() + .padding(horizontal = 12.dp, vertical = 8.dp), + colors = CardDefaults.cardColors( + containerColor = MaterialTheme.colorScheme.surfaceVariant.copy(alpha = 0.5f) + ) + ) { + Row( + modifier = Modifier.padding(12.dp), + horizontalArrangement = Arrangement.spacedBy(12.dp), + verticalAlignment = Alignment.CenterVertically + ) { + Icon( + AutoDevComposeIcons.CloudOff, + contentDescription = null, + modifier = Modifier.size(20.dp), + tint = MaterialTheme.colorScheme.onSurfaceVariant + ) + Text( + text = "$selectedCount data source${if (selectedCount > 1) "s" else ""} selected. Click 'Connect All' to connect.", + style = MaterialTheme.typography.bodySmall, + color = MaterialTheme.colorScheme.onSurfaceVariant + ) + } + } + } + else -> { + // No data sources selected Card( modifier = Modifier .fillMaxWidth() @@ -160,20 +213,20 @@ private fun ConnectionStatusBanner( tint = MaterialTheme.colorScheme.onSurfaceVariant ) Text( - text = "Not connected. Select a data source and connect.", + text = "No data sources selected. Select data sources from the left panel.", style = MaterialTheme.typography.bodySmall, color = MaterialTheme.colorScheme.onSurfaceVariant ) } } } - else -> { /* Connecting or Error handled elsewhere */ } } } @Composable private fun WelcomeMessage( isConnected: Boolean, + connectedCount: Int = 0, schema: DatabaseSchema?, onQuickQuery: (String) -> Unit, modifier: Modifier = Modifier @@ -203,10 +256,10 @@ private fun WelcomeMessage( Spacer(modifier = Modifier.height(8.dp)) Text( - text = if (isConnected) { - "Ask questions about your data in natural language" - } else { - "Connect to a database to start querying" + text = when { + connectedCount > 1 -> "Query across $connectedCount databases in natural language" + isConnected -> "Ask questions about your data in natural language" + else -> "Connect to a database to start querying" }, style = MaterialTheme.typography.bodyMedium, color = MaterialTheme.colorScheme.onSurfaceVariant @@ -244,6 +297,7 @@ private fun WelcomeMessage( private fun ChatInputArea( isGenerating: Boolean, isConnected: Boolean, + connectedCount: Int = 0, onSendMessage: (String) -> Unit, onStopGeneration: () -> Unit ) { @@ -297,8 +351,13 @@ private fun ChatInputArea( decorationBox = { innerTextField -> Box { if (inputText.text.isEmpty()) { + val placeholderText = when { + connectedCount > 1 -> "Ask a question across $connectedCount databases..." + isConnected -> "Ask a question about your data..." + else -> "Connect to a database first" + } Text( - text = if (isConnected) "Ask a question about your data..." else "Connect to a database first", + text = placeholderText, style = MaterialTheme.typography.bodyMedium, color = MaterialTheme.colorScheme.onSurfaceVariant.copy(alpha = 0.6f) ) diff --git a/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/agent/chatdb/components/ChatDBStepCard.kt b/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/agent/chatdb/components/ChatDBStepCard.kt index c5ffdcfd3a..f172d1ebb5 100644 --- a/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/agent/chatdb/components/ChatDBStepCard.kt +++ b/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/agent/chatdb/components/ChatDBStepCard.kt @@ -164,7 +164,7 @@ private fun StepDetails(details: Map, stepType: ChatDBStepType) { ) { when (stepType) { ChatDBStepType.FETCH_SCHEMA -> { - // Database overview with table cards + // Single database mode details["databaseName"]?.let { dbName -> DetailRow("Database", dbName.toString()) } @@ -172,7 +172,23 @@ private fun StepDetails(details: Map, stepType: ChatDBStepType) { DetailRow("Total Tables", it.toString()) } - // Show table schema cards + // Multi-database mode - show databases list + @Suppress("UNCHECKED_CAST") + val databases = details["databases"] as? List> + if (databases != null && databases.isNotEmpty()) { + Spacer(modifier = Modifier.height(8.dp)) + Text( + text = "Connected Databases", + style = MaterialTheme.typography.labelMedium, + fontWeight = FontWeight.Bold, + color = MaterialTheme.colorScheme.primary + ) + databases.forEach { dbInfo -> + DatabaseInfoCard(dbInfo) + } + } + + // Show table schema cards (single database mode) @Suppress("UNCHECKED_CAST") val tableSchemas = details["tableSchemas"] as? List> if (tableSchemas != null && tableSchemas.isNotEmpty()) { @@ -186,7 +202,7 @@ private fun StepDetails(details: Map, stepType: ChatDBStepType) { tableSchemas.forEach { tableInfo -> TableSchemaCard(tableInfo) } - } else { + } else if (databases == null) { // Fallback to simple table list details["tables"]?.let { tables -> if (tables is List<*>) { @@ -197,7 +213,52 @@ private fun StepDetails(details: Map, stepType: ChatDBStepType) { } ChatDBStepType.SCHEMA_LINKING -> { - // Show relevant tables with their columns + // Multi-database mode - show analyzed databases + @Suppress("UNCHECKED_CAST") + val databasesAnalyzed = details["databasesAnalyzed"] as? List + if (databasesAnalyzed != null && databasesAnalyzed.isNotEmpty()) { + Row( + horizontalArrangement = Arrangement.spacedBy(4.dp), + modifier = Modifier.horizontalScroll(rememberScrollState()) + ) { + Text( + text = "Databases Analyzed:", + style = MaterialTheme.typography.labelMedium, + color = MaterialTheme.colorScheme.onSurfaceVariant + ) + databasesAnalyzed.forEach { dbName -> + KeywordChip(dbName) + } + } + } + + // Show schema context preview + details["schemaContext"]?.let { context -> + Spacer(modifier = Modifier.height(8.dp)) + Text( + text = "Schema Context", + style = MaterialTheme.typography.labelMedium, + fontWeight = FontWeight.Bold, + color = MaterialTheme.colorScheme.primary + ) + Surface( + shape = RoundedCornerShape(4.dp), + color = MaterialTheme.colorScheme.surfaceVariant.copy(alpha = 0.5f) + ) { + Text( + text = context.toString(), + style = MaterialTheme.typography.bodySmall.copy( + fontFamily = FontFamily.Monospace, + fontSize = 10.sp + ), + color = MaterialTheme.colorScheme.onSurface, + modifier = Modifier.padding(8.dp), + maxLines = 10 + ) + } + } + + // Show keywords (single database mode) details["keywords"]?.let { keywords -> if (keywords is List<*> && keywords.isNotEmpty()) { Row( @@ -230,7 +291,7 @@ private fun StepDetails(details: Map, stepType: ChatDBStepType) { relevantTableSchemas.forEach { tableInfo -> TableSchemaCard(tableInfo, highlightRelevant = true) } - } else { + } else if (databasesAnalyzed == null) { // Fallback details["relevantTables"]?.let { tables -> if (tables is List<*>) { @@ -241,8 +302,48 @@ private fun StepDetails(details: Map, stepType: ChatDBStepType) { } ChatDBStepType.GENERATE_SQL, ChatDBStepType.REVISE_SQL -> { - details["sql"]?.let { sql -> - CodeBlock(code = sql.toString(), language = "sql") + // Multi-database mode - show target databases + @Suppress("UNCHECKED_CAST") + val targetDatabases = details["targetDatabases"] as? List + if (targetDatabases != null && targetDatabases.isNotEmpty()) { + Row( + horizontalArrangement = Arrangement.spacedBy(4.dp), + modifier = Modifier.horizontalScroll(rememberScrollState()) + ) { + Text( + text = "Target Databases:", + style = MaterialTheme.typography.labelMedium, + color = MaterialTheme.colorScheme.onSurfaceVariant + ) + targetDatabases.forEach { dbName -> + KeywordChip(dbName) + } + } + Spacer(modifier = Modifier.height(8.dp)) + } + + // Multi-database mode - show SQL blocks + @Suppress("UNCHECKED_CAST") + val sqlBlocks = details["sqlBlocks"] as? List> + if (sqlBlocks != null && sqlBlocks.isNotEmpty()) { + sqlBlocks.forEach { block -> + val database = block["database"]?.toString() ?: "Unknown" + val sql = block["sql"]?.toString() ?: "" + Text( + text = "Database: $database", + style = MaterialTheme.typography.labelSmall, + fontWeight = FontWeight.Bold, + color = MaterialTheme.colorScheme.secondary, + modifier = Modifier.padding(bottom = 4.dp) + ) + CodeBlock(code = sql, language = "sql") + Spacer(modifier = Modifier.height(8.dp)) + } + } else { + // Single database mode + details["sql"]?.let { sql -> + CodeBlock(code = sql.toString(), language = "sql") + } } } @@ -272,6 +373,11 @@ private fun StepDetails(details: Map, stepType: ChatDBStepType) { } ChatDBStepType.EXECUTE_SQL -> { + // Multi-database mode - show database name + details["database"]?.let { dbName -> + DetailRow("Database", dbName.toString()) + } + // Show SQL that was executed details["sql"]?.let { sql -> CodeBlock(code = sql.toString(), language = "sql") @@ -313,21 +419,67 @@ private fun StepDetails(details: Map, stepType: ChatDBStepType) { } ChatDBStepType.FINAL_RESULT -> { - // Show final SQL + // Multi-database mode - show databases queried + @Suppress("UNCHECKED_CAST") + val databases = details["databases"] as? List + if (databases != null && databases.isNotEmpty()) { + Row( + horizontalArrangement = Arrangement.spacedBy(4.dp), + modifier = Modifier.horizontalScroll(rememberScrollState()) + ) { + Text( + text = "Databases Queried:", + style = MaterialTheme.typography.labelMedium, + color = MaterialTheme.colorScheme.onSurfaceVariant + ) + databases.forEach { dbName -> + KeywordChip(dbName) + } + } + Spacer(modifier = Modifier.height(8.dp)) + } + + // Show final SQL (single database mode) details["sql"]?.let { sql -> CodeBlock(code = sql.toString(), language = "sql") } // Show result summary - details["rowCount"]?.let { + details["totalRows"]?.let { DetailRow("Total Rows", it.toString()) } + details["rowCount"]?.let { + if (details["totalRows"] == null) { + DetailRow("Total Rows", it.toString()) + } + } details["revisionAttempts"]?.let { attempts -> if ((attempts as? Int ?: 0) > 0) { DetailRow("Revision Attempts", attempts.toString()) } } + // Show errors if any + @Suppress("UNCHECKED_CAST") + val errors = details["errors"] as? List + if (errors != null && errors.isNotEmpty()) { + Spacer(modifier = Modifier.height(8.dp)) + Text( + text = "Errors", + style = MaterialTheme.typography.labelMedium, + fontWeight = FontWeight.Bold, + color = MaterialTheme.colorScheme.error + ) + errors.forEach { error -> + Text( + text = "โ€ข $error", + style = MaterialTheme.typography.bodySmall, + color = MaterialTheme.colorScheme.error, + modifier = Modifier.padding(start = 8.dp) + ) + } + } + // Show data preview @Suppress("UNCHECKED_CAST") val columns = details["columns"] as? List @@ -401,6 +553,94 @@ private fun CodeBlock(code: String, language: String) { } } +/** + * Card displaying database information for multi-database mode + */ +@Composable +private fun DatabaseInfoCard(dbInfo: Map) { + var isExpanded by remember { mutableStateOf(false) } + val name = dbInfo["name"]?.toString() ?: "Unknown" + val displayName = dbInfo["displayName"]?.toString() ?: name + val tableCount = dbInfo["tableCount"]?.toString() ?: "0" + + @Suppress("UNCHECKED_CAST") + val tables = dbInfo["tables"] as? List ?: emptyList() + + Surface( + modifier = Modifier + .fillMaxWidth() + .padding(vertical = 4.dp), + shape = RoundedCornerShape(6.dp), + color = MaterialTheme.colorScheme.primaryContainer.copy(alpha = 0.3f) + ) { + Column( + modifier = Modifier + .fillMaxWidth() + .clickable { isExpanded = !isExpanded } + .padding(8.dp) + ) { + Row( + modifier = Modifier.fillMaxWidth(), + horizontalArrangement = Arrangement.SpaceBetween, + verticalAlignment = Alignment.CenterVertically + ) { + Row(verticalAlignment = Alignment.CenterVertically) { + Icon( + imageVector = if (isExpanded) Icons.Default.KeyboardArrowDown + else Icons.Default.KeyboardArrowRight, + contentDescription = null, + modifier = Modifier.size(16.dp), + tint = MaterialTheme.colorScheme.onSurfaceVariant + ) + Spacer(modifier = Modifier.width(4.dp)) + Text(text = "๐Ÿ—„๏ธ", fontSize = 12.sp) + Spacer(modifier = Modifier.width(4.dp)) + Text( + text = displayName, + style = MaterialTheme.typography.bodyMedium, + fontWeight = FontWeight.Medium, + color = MaterialTheme.colorScheme.onSurface + ) + if (name != displayName) { + Text( + text = " ($name)", + style = MaterialTheme.typography.labelSmall, + color = MaterialTheme.colorScheme.onSurfaceVariant + ) + } + } + Text( + text = "$tableCount tables", + style = MaterialTheme.typography.labelSmall, + color = MaterialTheme.colorScheme.onSurfaceVariant + ) + } + + AnimatedVisibility( + visible = isExpanded, + enter = expandVertically() + fadeIn(), + exit = shrinkVertically() + fadeOut() + ) { + if (tables.isNotEmpty()) { + Column( + modifier = Modifier + .fillMaxWidth() + .padding(top = 8.dp, start = 24.dp) + ) { + Text( + text = tables.joinToString(", "), + style = MaterialTheme.typography.bodySmall.copy( + fontFamily = FontFamily.Monospace + ), + color = MaterialTheme.colorScheme.onSurfaceVariant + ) + } + } + } + } + } +} + /** * Card displaying table schema information */ diff --git a/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/agent/chatdb/components/DataSourcePanel.kt b/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/agent/chatdb/components/DataSourcePanel.kt index f0c444dd5a..30d9219446 100644 --- a/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/agent/chatdb/components/DataSourcePanel.kt +++ b/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/agent/chatdb/components/DataSourcePanel.kt @@ -19,29 +19,41 @@ import cc.unitmesh.devins.ui.compose.icons.AutoDevComposeIcons /** * Data Source Panel - Left side panel for managing database connections + * + * Supports multi-datasource selection: all data sources are selected by default, + * users can toggle individual data sources on/off using checkboxes. */ @Composable fun DataSourcePanel( dataSources: List, - selectedDataSourceId: String?, - connectionStatus: ConnectionStatus, + selectedDataSourceIds: Set, + connectionStatuses: Map, filterQuery: String, onFilterChange: (String) -> Unit, - onSelectDataSource: (String) -> Unit, + onToggleDataSource: (String) -> Unit, onAddClick: () -> Unit, onEditClick: (DataSourceConfig) -> Unit, onDeleteClick: (String) -> Unit, - onConnectClick: () -> Unit, - onDisconnectClick: () -> Unit, + onConnectClick: (String) -> Unit, + onDisconnectClick: (String) -> Unit, + onConnectAllClick: () -> Unit, + onDisconnectAllClick: () -> Unit, modifier: Modifier = Modifier ) { + val selectedCount = selectedDataSourceIds.size + val connectedCount = connectionStatuses.values.count { it is ConnectionStatus.Connected } + Column( modifier = modifier .fillMaxHeight() .background(MaterialTheme.colorScheme.surfaceVariant.copy(alpha = 0.3f)) ) { - // Header with Add button - DataSourceHeader(onAddClick = onAddClick) + // Header with Add button and selection info + DataSourceHeader( + selectedCount = selectedCount, + totalCount = dataSources.size, + onAddClick = onAddClick + ) HorizontalDivider() @@ -60,29 +72,36 @@ fun DataSourcePanel( items(dataSources, key = { it.id }) { dataSource -> DataSourceItem( dataSource = dataSource, - isSelected = dataSource.id == selectedDataSourceId, - connectionStatus = if (dataSource.id == selectedDataSourceId) connectionStatus else ConnectionStatus.Disconnected, - onClick = { onSelectDataSource(dataSource.id) }, + isSelected = dataSource.id in selectedDataSourceIds, + connectionStatus = connectionStatuses[dataSource.id] ?: ConnectionStatus.Disconnected, + onToggle = { onToggleDataSource(dataSource.id) }, onEditClick = { onEditClick(dataSource) }, - onDeleteClick = { onDeleteClick(dataSource.id) } + onDeleteClick = { onDeleteClick(dataSource.id) }, + onConnectClick = { onConnectClick(dataSource.id) }, + onDisconnectClick = { onDisconnectClick(dataSource.id) } ) } } - // Connection controls - if (selectedDataSourceId != null) { + // Connection controls for all selected data sources + if (selectedDataSourceIds.isNotEmpty()) { HorizontalDivider() - ConnectionControls( - connectionStatus = connectionStatus, - onConnect = onConnectClick, - onDisconnect = onDisconnectClick + MultiConnectionControls( + selectedCount = selectedCount, + connectedCount = connectedCount, + onConnectAll = onConnectAllClick, + onDisconnectAll = onDisconnectAllClick ) } } } @Composable -private fun DataSourceHeader(onAddClick: () -> Unit) { +private fun DataSourceHeader( + selectedCount: Int, + totalCount: Int, + onAddClick: () -> Unit +) { Row( modifier = Modifier .fillMaxWidth() @@ -90,10 +109,19 @@ private fun DataSourceHeader(onAddClick: () -> Unit) { horizontalArrangement = Arrangement.SpaceBetween, verticalAlignment = Alignment.CenterVertically ) { - Text( - text = "Data Sources", - style = MaterialTheme.typography.titleMedium - ) + Column { + Text( + text = "Data Sources", + style = MaterialTheme.typography.titleMedium + ) + if (totalCount > 0) { + Text( + text = "$selectedCount of $totalCount selected", + style = MaterialTheme.typography.bodySmall, + color = MaterialTheme.colorScheme.onSurfaceVariant + ) + } + } IconButton( onClick = onAddClick, modifier = Modifier.size(32.dp) @@ -147,28 +175,40 @@ private fun DataSourceItem( dataSource: DataSourceConfig, isSelected: Boolean, connectionStatus: ConnectionStatus, - onClick: () -> Unit, + onToggle: () -> Unit, onEditClick: () -> Unit, - onDeleteClick: () -> Unit + onDeleteClick: () -> Unit, + onConnectClick: () -> Unit, + onDisconnectClick: () -> Unit ) { var showMenu by remember { mutableStateOf(false) } + val isConnected = connectionStatus is ConnectionStatus.Connected + val isConnecting = connectionStatus is ConnectionStatus.Connecting Surface( modifier = Modifier .fillMaxWidth() - .padding(horizontal = 8.dp, vertical = 2.dp) - .clickable(onClick = onClick), + .padding(horizontal = 8.dp, vertical = 2.dp), shape = RoundedCornerShape(8.dp), color = if (isSelected) { - MaterialTheme.colorScheme.primaryContainer.copy(alpha = 0.5f) + MaterialTheme.colorScheme.primaryContainer.copy(alpha = 0.3f) } else { MaterialTheme.colorScheme.surface } ) { Row( - modifier = Modifier.padding(12.dp), + modifier = Modifier.padding(horizontal = 8.dp, vertical = 8.dp), verticalAlignment = Alignment.CenterVertically ) { + // Checkbox for selection + Checkbox( + checked = isSelected, + onCheckedChange = { onToggle() }, + modifier = Modifier.size(24.dp) + ) + + Spacer(modifier = Modifier.width(8.dp)) + // Status indicator Box( modifier = Modifier @@ -184,7 +224,7 @@ private fun DataSourceItem( ) ) - Spacer(modifier = Modifier.width(12.dp)) + Spacer(modifier = Modifier.width(8.dp)) Column(modifier = Modifier.weight(1f)) { Text( @@ -193,13 +233,78 @@ private fun DataSourceItem( maxLines = 1, overflow = TextOverflow.Ellipsis ) - Text( - text = dataSource.getDisplayUrl(), - style = MaterialTheme.typography.bodySmall, - color = MaterialTheme.colorScheme.onSurfaceVariant, - maxLines = 1, - overflow = TextOverflow.Ellipsis - ) + Row( + verticalAlignment = Alignment.CenterVertically, + horizontalArrangement = Arrangement.spacedBy(4.dp) + ) { + Text( + text = dataSource.getDisplayUrl(), + style = MaterialTheme.typography.bodySmall, + color = MaterialTheme.colorScheme.onSurfaceVariant, + maxLines = 1, + overflow = TextOverflow.Ellipsis, + modifier = Modifier.weight(1f, fill = false) + ) + // Connection status text + when (connectionStatus) { + is ConnectionStatus.Connected -> { + Text( + text = "Connected", + style = MaterialTheme.typography.labelSmall, + color = MaterialTheme.colorScheme.primary + ) + } + is ConnectionStatus.Connecting -> { + Text( + text = "Connecting...", + style = MaterialTheme.typography.labelSmall, + color = MaterialTheme.colorScheme.tertiary + ) + } + is ConnectionStatus.Error -> { + Text( + text = "Error", + style = MaterialTheme.typography.labelSmall, + color = MaterialTheme.colorScheme.error + ) + } + else -> {} + } + } + } + + // Quick connect/disconnect button + if (isSelected) { + if (isConnecting) { + CircularProgressIndicator( + modifier = Modifier.size(20.dp), + strokeWidth = 2.dp + ) + } else if (isConnected) { + IconButton( + onClick = onDisconnectClick, + modifier = Modifier.size(28.dp) + ) { + Icon( + AutoDevComposeIcons.CloudOff, + contentDescription = "Disconnect", + modifier = Modifier.size(16.dp), + tint = MaterialTheme.colorScheme.error + ) + } + } else { + IconButton( + onClick = onConnectClick, + modifier = Modifier.size(28.dp) + ) { + Icon( + AutoDevComposeIcons.Cloud, + contentDescription = "Connect", + modifier = Modifier.size(16.dp), + tint = MaterialTheme.colorScheme.primary + ) + } + } } Box { @@ -249,64 +354,54 @@ private fun DataSourceItem( } } +/** + * Connection controls for multi-datasource mode + */ @Composable -private fun ConnectionControls( - connectionStatus: ConnectionStatus, - onConnect: () -> Unit, - onDisconnect: () -> Unit +private fun MultiConnectionControls( + selectedCount: Int, + connectedCount: Int, + onConnectAll: () -> Unit, + onDisconnectAll: () -> Unit ) { - Row( + Column( modifier = Modifier .fillMaxWidth() .padding(12.dp), - horizontalArrangement = Arrangement.spacedBy(8.dp), - verticalAlignment = Alignment.CenterVertically + verticalArrangement = Arrangement.spacedBy(8.dp) ) { - when (connectionStatus) { - is ConnectionStatus.Connected -> { - Button( - onClick = onDisconnect, - colors = ButtonDefaults.buttonColors( - containerColor = MaterialTheme.colorScheme.error - ), - modifier = Modifier.fillMaxWidth() - ) { - Text("Disconnect") - } - } - is ConnectionStatus.Connecting -> { - Button( - onClick = {}, - enabled = false, - modifier = Modifier.fillMaxWidth() - ) { - CircularProgressIndicator( - modifier = Modifier.size(16.dp), - strokeWidth = 2.dp, - color = MaterialTheme.colorScheme.onPrimary - ) - Spacer(modifier = Modifier.width(8.dp)) - Text("Connecting...") - } - } - else -> { - Button( - onClick = onConnect, - modifier = Modifier.fillMaxWidth() - ) { - Text("Connect") - } - } - } - } - - if (connectionStatus is ConnectionStatus.Error) { + // Status text Text( - text = connectionStatus.message, + text = "$connectedCount of $selectedCount connected", style = MaterialTheme.typography.bodySmall, - color = MaterialTheme.colorScheme.error, - modifier = Modifier.padding(horizontal = 12.dp, vertical = 4.dp) + color = MaterialTheme.colorScheme.onSurfaceVariant ) + + Row( + modifier = Modifier.fillMaxWidth(), + horizontalArrangement = Arrangement.spacedBy(8.dp) + ) { + // Connect All button + Button( + onClick = onConnectAll, + enabled = connectedCount < selectedCount, + modifier = Modifier.weight(1f) + ) { + Text("Connect All") + } + + // Disconnect All button + OutlinedButton( + onClick = onDisconnectAll, + enabled = connectedCount > 0, + modifier = Modifier.weight(1f), + colors = ButtonDefaults.outlinedButtonColors( + contentColor = MaterialTheme.colorScheme.error + ) + ) { + Text("Disconnect All") + } + } } } diff --git a/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/agent/chatdb/model/DataSourceModels.kt b/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/agent/chatdb/model/DataSourceModels.kt index d06e3249c3..38dce1480e 100644 --- a/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/agent/chatdb/model/DataSourceModels.kt +++ b/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/agent/chatdb/model/DataSourceModels.kt @@ -138,12 +138,27 @@ data class QueryResultDisplay( val executionTimeMs: Long = 0 ) +/** + * Connection status for a specific data source (used in multi-datasource mode) + */ +data class DataSourceConnectionState( + val dataSourceId: String, + val status: ConnectionStatus = ConnectionStatus.Disconnected +) + /** * UI state for ChatDB page + * + * Supports multi-datasource selection: all configured data sources are selected by default, + * users can toggle individual data sources on/off. */ data class ChatDBState( val dataSources: List = emptyList(), - val selectedDataSourceId: String? = null, + /** Set of selected data source IDs (multi-selection mode) */ + val selectedDataSourceIds: Set = emptySet(), + /** Connection status per data source */ + val connectionStatuses: Map = emptyMap(), + /** Overall connection status (for backward compatibility) */ val connectionStatus: ConnectionStatus = ConnectionStatus.Disconnected, val filterQuery: String = "", val isLoading: Boolean = false, @@ -155,8 +170,28 @@ data class ChatDBState( /** The data source being configured in the pane */ val configuringDataSource: DataSourceConfig? = null ) { - val selectedDataSource: DataSourceConfig? - get() = dataSources.find { it.id == selectedDataSourceId } + /** Get all selected data sources */ + val selectedDataSources: List + get() = dataSources.filter { it.id in selectedDataSourceIds } + + /** Check if a data source is selected */ + fun isSelected(id: String): Boolean = id in selectedDataSourceIds + + /** Get connection status for a specific data source */ + fun getConnectionStatus(id: String): ConnectionStatus = + connectionStatuses[id] ?: ConnectionStatus.Disconnected + + /** Check if any data source is connected */ + val hasAnyConnection: Boolean + get() = connectionStatuses.values.any { it is ConnectionStatus.Connected } + + /** Get count of connected data sources */ + val connectedCount: Int + get() = connectionStatuses.values.count { it is ConnectionStatus.Connected } + + /** Get count of selected data sources */ + val selectedCount: Int + get() = selectedDataSourceIds.size val filteredDataSources: List get() = if (filterQuery.isBlank()) { From 97561fed51a7e91acb29fff0358b8bcf56ddd7aa Mon Sep 17 00:00:00 2001 From: Phodal Huang Date: Wed, 10 Dec 2025 17:19:50 +0800 Subject: [PATCH 29/34] feat(chatdb): add SQL validation, revision, and visualization #508 Enhance MultiDatabaseChatDBExecutor with SQL validation, automatic revision on errors, execution retries, and optional PlotDSL visualization generation for query results. --- .../chatdb/MultiDatabaseChatDBExecutor.kt | 518 ++++++++++++++++-- 1 file changed, 473 insertions(+), 45 deletions(-) diff --git a/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/chatdb/MultiDatabaseChatDBExecutor.kt b/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/chatdb/MultiDatabaseChatDBExecutor.kt index 0200813425..7a9eef71d8 100644 --- a/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/chatdb/MultiDatabaseChatDBExecutor.kt +++ b/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/chatdb/MultiDatabaseChatDBExecutor.kt @@ -10,17 +10,28 @@ import cc.unitmesh.agent.render.ChatDBStepType import cc.unitmesh.agent.render.CodingAgentRenderer import cc.unitmesh.agent.subagent.SqlValidator import cc.unitmesh.agent.subagent.SqlReviseAgent +import cc.unitmesh.agent.subagent.SqlRevisionInput import cc.unitmesh.devins.parser.CodeFence import cc.unitmesh.llm.KoogLLMService /** * Multi-Database ChatDB Executor - Executes Text2SQL across multiple databases - * + * * Key differences from single-database executor: * 1. Merges schemas from all databases with database name prefixes * 2. Parses SQL comments to determine target database * 3. Can execute queries on multiple databases in parallel * 4. Combines results from multiple databases + * + * Flow: + * 1. FETCH_SCHEMA - Fetch schemas from all databases + * 2. SCHEMA_LINKING - Use LLM/Keyword linker to find relevant tables + * 3. GENERATE_SQL - Generate SQL with database routing comments + * 4. VALIDATE_SQL - Validate SQL syntax and table names + * 5. REVISE_SQL - Fix SQL if validation fails + * 6. EXECUTE_SQL - Execute on target databases with retry + * 7. GENERATE_VISUALIZATION - Optional visualization + * 8. FINAL_RESULT - Return combined results */ class MultiDatabaseChatDBExecutor( projectPath: String, @@ -40,6 +51,7 @@ class MultiDatabaseChatDBExecutor( enableLLMStreaming = enableLLMStreaming ) { private val logger = getLogger("MultiDatabaseChatDBExecutor") + private val keywordSchemaLinker = KeywordSchemaLinker() private val sqlValidator = SqlValidator() private val sqlReviseAgent = SqlReviseAgent(llmService, sqlValidator) private val maxRevisionAttempts = 3 @@ -55,12 +67,13 @@ class MultiDatabaseChatDBExecutor( ): MultiDatabaseChatDBResult { currentIteration = 0 conversationManager = ConversationManager(llmService, systemPrompt) - + val errors = mutableListOf() var generatedSql: String? = null val queryResults = mutableMapOf() var revisionAttempts = 0 val targetDatabases = mutableListOf() + var plotDslCode: String? = null try { // Step 1: Fetch and merge schemas from all databases @@ -87,27 +100,77 @@ class MultiDatabaseChatDBExecutor( "tables" to schema.tables.map { it.name } ) }, - "totalTables" to merged.totalTableCount + "totalTables" to merged.totalTableCount, + "tableSchemas" to merged.databases.flatMap { (dbName, schema) -> + schema.tables.map { table -> + mapOf( + "name" to "[$dbName].${table.name}", + "comment" to (table.comment ?: ""), + "columns" to table.columns.map { col -> + mapOf( + "name" to col.name, + "type" to col.type, + "nullable" to col.nullable, + "isPrimaryKey" to col.isPrimaryKey, + "isForeignKey" to col.isForeignKey, + "comment" to (col.comment ?: "") + ) + } + ) + } + } ) ) - // Step 2: Schema Linking with merged schema + // Step 2: Schema Linking - Find relevant tables using keyword linker renderer.renderChatDBStep( stepType = ChatDBStepType.SCHEMA_LINKING, status = ChatDBStepStatus.IN_PROGRESS, - title = "Analyzing query across ${merged.databases.size} databases..." + title = "Performing schema linking across ${merged.databases.size} databases..." ) - onProgress("๐Ÿ”— Analyzing query across databases...") + onProgress("๐Ÿ”— Performing schema linking...") + + // Perform schema linking for each database + val linkingResults = mutableMapOf() + val allRelevantTables = mutableListOf>() + val allKeywords = mutableSetOf() + + for ((dbId, schema) in merged.databases) { + val linkingResult = keywordSchemaLinker.link(task.query, schema) + linkingResults[dbId] = linkingResult + allKeywords.addAll(linkingResult.keywords) + + // Collect relevant table schemas for UI + linkingResult.relevantTables.forEach { tableName -> + schema.getTable(tableName)?.let { table -> + allRelevantTables.add(mapOf( + "name" to "[$dbId].${table.name}", + "comment" to (table.comment ?: ""), + "columns" to table.columns.map { col -> + mapOf( + "name" to col.name, + "type" to col.type, + "nullable" to col.nullable, + "isPrimaryKey" to col.isPrimaryKey, + "isForeignKey" to col.isForeignKey + ) + } + )) + } + } + } + + val schemaContext = buildMultiDatabaseSchemaContext(merged, task.query, linkingResults) - val schemaContext = buildMultiDatabaseSchemaContext(merged, task.query) - renderer.renderChatDBStep( stepType = ChatDBStepType.SCHEMA_LINKING, status = ChatDBStepStatus.SUCCESS, - title = "Schema analysis complete", + title = "Schema linking complete - found ${allRelevantTables.size} relevant tables", details = mapOf( "databasesAnalyzed" to merged.databases.keys.toList(), - "schemaContext" to schemaContext.take(500) + "..." + "keywords" to allKeywords.toList(), + "relevantTableSchemas" to allRelevantTables, + "schemaContext" to schemaContext.take(500) + if (schemaContext.length > 500) "..." else "" ) ) @@ -115,9 +178,9 @@ class MultiDatabaseChatDBExecutor( renderer.renderChatDBStep( stepType = ChatDBStepType.GENERATE_SQL, status = ChatDBStepStatus.IN_PROGRESS, - title = "Generating SQL..." + title = "Generating SQL query..." ) - onProgress("๐Ÿค– Generating SQL...") + onProgress("๐Ÿค– Generating SQL query...") val sqlPrompt = buildMultiDatabaseSqlPrompt(task.query, schemaContext, task.maxRows) val sqlResponse = getLLMResponse(sqlPrompt, compileDevIns = false) { chunk -> @@ -125,13 +188,19 @@ class MultiDatabaseChatDBExecutor( } // Parse SQL blocks with database targets - val sqlBlocks = parseSqlBlocksWithTargets(sqlResponse) + var sqlBlocks = parseSqlBlocksWithTargets(sqlResponse) if (sqlBlocks.isEmpty()) { + renderer.renderChatDBStep( + stepType = ChatDBStepType.GENERATE_SQL, + status = ChatDBStepStatus.ERROR, + title = "Failed to extract SQL", + error = "Could not find SQL code block in LLM response" + ) throw DatabaseException("No valid SQL generated") } - generatedSql = sqlBlocks.map { "${it.database}: ${it.sql}" }.joinToString("\n") + generatedSql = sqlBlocks.joinToString("\n\n") { "-- database: ${it.database}\n${it.sql}" } targetDatabases.addAll(sqlBlocks.map { it.database }.distinct()) renderer.renderChatDBStep( @@ -144,10 +213,110 @@ class MultiDatabaseChatDBExecutor( ) ) - // Step 4: Execute SQL on target databases + // Step 4: Validate SQL for each database + renderer.renderChatDBStep( + stepType = ChatDBStepType.VALIDATE_SQL, + status = ChatDBStepStatus.IN_PROGRESS, + title = "Validating SQL..." + ) + onProgress("๐Ÿ” Validating SQL...") + + val validatedBlocks = mutableListOf() + var hasValidationErrors = false + + for (block in sqlBlocks) { + val dbSchema = merged.databases[block.database] + val allTableNames = dbSchema?.tables?.map { it.name }?.toSet() ?: emptySet() + + val syntaxValidation = sqlValidator.validate(block.sql) + val tableValidation = if (syntaxValidation.isValid) { + sqlValidator.validateWithTableWhitelist(block.sql, allTableNames) + } else { + syntaxValidation + } + + if (!tableValidation.isValid) { + hasValidationErrors = true + val errorType = if (!syntaxValidation.isValid) "syntax" else "table name" + + renderer.renderChatDBStep( + stepType = ChatDBStepType.VALIDATE_SQL, + status = ChatDBStepStatus.ERROR, + title = "SQL validation failed for ${block.database}", + details = mapOf( + "database" to block.database, + "errorType" to errorType, + "errors" to tableValidation.errors + ), + error = tableValidation.errors.joinToString("; ") + ) + + // Step 5: Revise SQL + onProgress("๐Ÿ”„ SQL validation failed ($errorType), invoking SqlReviseAgent...") + + renderer.renderChatDBStep( + stepType = ChatDBStepType.REVISE_SQL, + status = ChatDBStepStatus.IN_PROGRESS, + title = "Revising SQL for ${block.database}..." + ) + + val relevantSchema = buildSchemaDescriptionForDatabase(block.database, merged) + val revisionInput = SqlRevisionInput( + originalQuery = task.query, + failedSql = block.sql, + errorMessage = tableValidation.errors.joinToString("; "), + schemaDescription = relevantSchema, + maxAttempts = maxRevisionAttempts + ) + + val revisionResult = sqlReviseAgent.execute(revisionInput) { progress -> + onProgress(progress) + } + + revisionAttempts += revisionResult.metadata["attempts"]?.toIntOrNull() ?: 0 + + if (revisionResult.success) { + validatedBlocks.add(SqlBlock(block.database, revisionResult.content)) + renderer.renderChatDBStep( + stepType = ChatDBStepType.REVISE_SQL, + status = ChatDBStepStatus.SUCCESS, + title = "SQL revised successfully for ${block.database}", + details = mapOf( + "database" to block.database, + "attempts" to revisionAttempts, + "sql" to revisionResult.content + ) + ) + onProgress("โœ… SQL revised successfully") + } else { + renderer.renderChatDBStep( + stepType = ChatDBStepType.REVISE_SQL, + status = ChatDBStepStatus.ERROR, + title = "SQL revision failed for ${block.database}", + error = revisionResult.content + ) + errors.add("[${block.database}] SQL revision failed: ${revisionResult.content}") + } + } else { + validatedBlocks.add(block) + } + } + + if (!hasValidationErrors) { + renderer.renderChatDBStep( + stepType = ChatDBStepType.VALIDATE_SQL, + status = ChatDBStepStatus.SUCCESS, + title = "SQL validation passed for all databases" + ) + } + + sqlBlocks = validatedBlocks + generatedSql = sqlBlocks.joinToString("\n\n") { "-- database: ${it.database}\n${it.sql}" } + + // Step 6: Execute SQL on target databases with retry for (sqlBlock in sqlBlocks) { val dbName = sqlBlock.database - val sql = sqlBlock.sql + var sql = sqlBlock.sql val connection = databaseConnections[dbName] if (connection == null) { @@ -155,44 +324,149 @@ class MultiDatabaseChatDBExecutor( continue } - renderer.renderChatDBStep( - stepType = ChatDBStepType.EXECUTE_SQL, - status = ChatDBStepStatus.IN_PROGRESS, - title = "Executing on $dbName...", - details = mapOf("database" to dbName, "sql" to sql) - ) - onProgress("โšก Executing SQL on $dbName...") - - try { - val result = connection.executeQuery(sql) - queryResults[dbName] = result + var executionRetries = 0 + var lastExecutionError: String? = null + var result: QueryResult? = null + while (executionRetries < maxExecutionRetries && result == null) { renderer.renderChatDBStep( stepType = ChatDBStepType.EXECUTE_SQL, - status = ChatDBStepStatus.SUCCESS, - title = "Query executed on $dbName", + status = ChatDBStepStatus.IN_PROGRESS, + title = "Executing on $dbName${if (executionRetries > 0) " (retry $executionRetries)" else ""}...", details = mapOf( "database" to dbName, "sql" to sql, - "rowCount" to result.rowCount, - "columns" to result.columns, - "previewRows" to result.rows.take(5) + "attempt" to (executionRetries + 1) ) ) - } catch (e: Exception) { - errors.add("[$dbName] ${e.message}") + onProgress("โšก Executing SQL on $dbName${if (executionRetries > 0) " (retry $executionRetries)" else ""}...") + + try { + result = connection.executeQuery(sql) + queryResults[dbName] = result + + renderer.renderChatDBStep( + stepType = ChatDBStepType.EXECUTE_SQL, + status = ChatDBStepStatus.SUCCESS, + title = "Query executed on $dbName", + details = mapOf( + "database" to dbName, + "sql" to sql, + "rowCount" to result.rowCount, + "columns" to result.columns, + "previewRows" to result.rows.take(5) + ) + ) + onProgress("โœ… Query returned ${result.rowCount} rows from $dbName") + } catch (e: Exception) { + lastExecutionError = e.message ?: "Unknown execution error" + logger.warn { "Query execution failed on $dbName (attempt ${executionRetries + 1}): $lastExecutionError" } + + renderer.renderChatDBStep( + stepType = ChatDBStepType.EXECUTE_SQL, + status = ChatDBStepStatus.ERROR, + title = "Query execution failed on $dbName", + details = mapOf( + "database" to dbName, + "attempt" to (executionRetries + 1), + "maxAttempts" to maxExecutionRetries + ), + error = lastExecutionError + ) + + // Try to revise SQL based on execution error + if (executionRetries < maxExecutionRetries - 1) { + onProgress("๐Ÿ”„ Attempting to fix SQL based on execution error...") + + renderer.renderChatDBStep( + stepType = ChatDBStepType.REVISE_SQL, + status = ChatDBStepStatus.IN_PROGRESS, + title = "Revising SQL based on execution error..." + ) + + val relevantSchema = buildSchemaDescriptionForDatabase(dbName, merged) + val revisionInput = SqlRevisionInput( + originalQuery = task.query, + failedSql = sql, + errorMessage = "Execution error: $lastExecutionError", + schemaDescription = relevantSchema, + maxAttempts = 1 + ) + + val revisionResult = sqlReviseAgent.execute(revisionInput) { progress -> + onProgress(progress) + } + + if (revisionResult.success && revisionResult.content != sql) { + sql = revisionResult.content + revisionAttempts++ + + renderer.renderChatDBStep( + stepType = ChatDBStepType.REVISE_SQL, + status = ChatDBStepStatus.SUCCESS, + title = "SQL revised based on execution error", + details = mapOf("database" to dbName, "sql" to sql) + ) + onProgress("๐Ÿ”ง SQL revised, retrying execution...") + } else { + renderer.renderChatDBStep( + stepType = ChatDBStepType.REVISE_SQL, + status = ChatDBStepStatus.WARNING, + title = "SQL revision did not help", + error = "Revision did not produce a different SQL" + ) + break + } + } + executionRetries++ + } + } + + if (result == null && lastExecutionError != null) { + errors.add("[$dbName] Query execution failed after $executionRetries retries: $lastExecutionError") + } + } + + // Step 7: Generate visualization if requested + val combinedResult = combineResults(queryResults) + if (task.generateVisualization && combinedResult.rowCount > 0) { + renderer.renderChatDBStep( + stepType = ChatDBStepType.GENERATE_VISUALIZATION, + status = ChatDBStepStatus.IN_PROGRESS, + title = "Generating visualization..." + ) + onProgress("๐Ÿ“ˆ Generating visualization...") + + plotDslCode = generateVisualization(task.query, combinedResult, onProgress) + + if (plotDslCode != null) { renderer.renderChatDBStep( - stepType = ChatDBStepType.EXECUTE_SQL, - status = ChatDBStepStatus.ERROR, - title = "Query failed on $dbName", - details = mapOf("database" to dbName, "sql" to sql, "error" to (e.message ?: "Unknown error")) + stepType = ChatDBStepType.GENERATE_VISUALIZATION, + status = ChatDBStepStatus.SUCCESS, + title = "Visualization generated", + details = mapOf("code" to plotDslCode) + ) + } else { + renderer.renderChatDBStep( + stepType = ChatDBStepType.GENERATE_VISUALIZATION, + status = ChatDBStepStatus.WARNING, + title = "Visualization not generated" ) } } - // Step 5: Final result + // Step 8: Final result val success = queryResults.isNotEmpty() - val combinedResult = combineResults(queryResults) + + val resultMessage = buildResultMessage( + success = success, + generatedSql = generatedSql, + queryResults = queryResults, + combinedResult = combinedResult, + revisionAttempts = revisionAttempts, + plotDslCode = plotDslCode, + errors = errors + ) renderer.renderChatDBStep( stepType = ChatDBStepType.FINAL_RESULT, @@ -203,17 +477,28 @@ class MultiDatabaseChatDBExecutor( "totalRows" to combinedResult.rowCount, "columns" to combinedResult.columns, "previewRows" to combinedResult.rows.take(10), + "revisionAttempts" to revisionAttempts, "errors" to errors ) ) + // Render final message + if (success) { + renderer.renderLLMResponseStart() + renderer.renderLLMResponseChunk(resultMessage) + renderer.renderLLMResponseEnd() + } else { + renderer.renderError(resultMessage) + } + return MultiDatabaseChatDBResult( success = success, - message = if (success) "Query executed successfully on ${queryResults.size} database(s)" else errors.joinToString("\n"), + message = resultMessage, generatedSql = generatedSql, queryResult = combinedResult, queryResultsByDatabase = queryResults, targetDatabases = targetDatabases, + plotDslCode = plotDslCode, revisionAttempts = revisionAttempts, errors = errors ) @@ -226,6 +511,7 @@ class MultiDatabaseChatDBExecutor( title = "Query failed", details = mapOf("error" to (e.message ?: "Unknown error")) ) + renderer.renderError("Query failed: ${e.message}") return MultiDatabaseChatDBResult( success = false, message = "Error: ${e.message}", @@ -254,19 +540,35 @@ class MultiDatabaseChatDBExecutor( } /** - * Build schema context for multi-database prompt + * Build schema context for multi-database prompt with schema linking results */ - private fun buildMultiDatabaseSchemaContext(merged: MergedDatabaseSchema, query: String): String { + private fun buildMultiDatabaseSchemaContext( + merged: MergedDatabaseSchema, + query: String, + linkingResults: Map = emptyMap() + ): String { val sb = StringBuilder() sb.append("=== AVAILABLE DATABASES AND TABLES ===\n\n") for ((dbId, schema) in merged.databases) { val displayName = databaseConfigs[dbId]?.databaseName ?: dbId + val linkingResult = linkingResults[dbId] + val relevantTables = linkingResult?.relevantTables?.toSet() ?: schema.tables.map { it.name }.toSet() + sb.append("DATABASE: $dbId ($displayName)\n") sb.append("-".repeat(40)).append("\n") - for (table in schema.tables) { - sb.append(" Table: ${table.name}\n") + // Show relevant tables first (if schema linking was performed) + val sortedTables = if (linkingResult != null) { + schema.tables.sortedByDescending { it.name in relevantTables } + } else { + schema.tables + } + + for (table in sortedTables) { + val isRelevant = table.name in relevantTables + val marker = if (isRelevant && linkingResult != null) " [RELEVANT]" else "" + sb.append(" Table: ${table.name}$marker\n") if (table.comment != null) { sb.append(" Comment: ${table.comment}\n") } @@ -287,6 +589,132 @@ class MultiDatabaseChatDBExecutor( return sb.toString() } + /** + * Build schema description for a specific database (for SQL revision) + */ + private fun buildSchemaDescriptionForDatabase(dbId: String, merged: MergedDatabaseSchema): String { + val schema = merged.databases[dbId] ?: return "" + val displayName = databaseConfigs[dbId]?.databaseName ?: dbId + + return buildString { + appendLine("## Database Schema: $displayName (USE ONLY THESE TABLES)") + appendLine() + for (table in schema.tables) { + appendLine("Table: ${table.name}") + appendLine("Columns: ${table.columns.joinToString(", ") { "${it.name} (${it.type})" }}") + appendLine() + } + } + } + + /** + * Build result message for final output + */ + private fun buildResultMessage( + success: Boolean, + generatedSql: String?, + queryResults: Map, + combinedResult: QueryResult, + revisionAttempts: Int, + plotDslCode: String?, + errors: List + ): String { + return if (success) { + buildString { + appendLine("## โœ… Query Executed Successfully") + appendLine() + + if (queryResults.size > 1) { + appendLine("**Databases Queried:** ${queryResults.keys.joinToString(", ")}") + appendLine() + } + + if (generatedSql != null) { + appendLine("**Executed SQL:**") + appendLine("```sql") + appendLine(generatedSql) + appendLine("```") + appendLine() + } + + if (revisionAttempts > 0) { + appendLine("*Note: SQL was revised $revisionAttempts time(s) to fix validation/execution errors*") + appendLine() + } + + appendLine("**Results** (${combinedResult.rowCount} row${if (combinedResult.rowCount != 1) "s" else ""}):") + appendLine() + appendLine(combinedResult.toTableString()) + + if (plotDslCode != null) { + appendLine() + appendLine("**Visualization:**") + appendLine("```plotdsl") + appendLine(plotDslCode) + appendLine("```") + } + } + } else { + buildString { + appendLine("## โŒ Query Failed") + appendLine() + appendLine("**Errors:**") + errors.forEach { error -> + appendLine("- $error") + } + if (generatedSql != null) { + appendLine() + appendLine("**Failed SQL:**") + appendLine("```sql") + appendLine(generatedSql) + appendLine("```") + } + } + } + } + + /** + * Generate visualization for query results + */ + private suspend fun generateVisualization( + query: String, + result: QueryResult, + onProgress: (String) -> Unit + ): String? { + val visualizationPrompt = buildString { + appendLine("Based on the following query result, generate a PlotDSL visualization:") + appendLine() + appendLine("**Original Query**: $query") + appendLine() + appendLine("**Query Result** (${result.rowCount} rows):") + appendLine("```csv") + appendLine(result.toCsvString()) + appendLine("```") + appendLine() + appendLine("Generate a PlotDSL chart that best visualizes this data.") + appendLine("Choose an appropriate chart type (bar, line, scatter, etc.) based on the data.") + appendLine("Wrap the PlotDSL code in a ```plotdsl code block.") + } + + try { + val response = getLLMResponse(visualizationPrompt, compileDevIns = false) { chunk -> + onProgress(chunk) + } + + val codeFence = CodeFence.parse(response) + if (codeFence.languageId.lowercase() == "plotdsl" && codeFence.text.isNotBlank()) { + return codeFence.text.trim() + } + + val plotPattern = Regex("```plotdsl\\s*([\\s\\S]*?)```", RegexOption.IGNORE_CASE) + val match = plotPattern.find(response) + return match?.groupValues?.get(1)?.trim() + } catch (e: Exception) { + logger.error(e) { "Visualization generation failed" } + return null + } + } + /** * Build SQL generation prompt for multi-database context */ From 97c1793328cd9a1fcc73dadf2136188f63d0688b Mon Sep 17 00:00:00 2001 From: Phodal Huang Date: Wed, 10 Dec 2025 19:03:43 +0800 Subject: [PATCH 30/34] feat(chatdb): add PlotDSLAgent for query visualizations #508 Integrate PlotDSLAgent to generate chart visualizations for SQL query results. Improve table string formatting with row limits and truncation. --- .../chatdb/MultiDatabaseChatDBExecutor.kt | 49 +++++++++++------- .../agent/database/DatabaseConnection.kt | 50 ++++++++++++++----- .../compose/agent/chatdb/ChatDBViewModel.kt | 2 +- 3 files changed, 70 insertions(+), 31 deletions(-) diff --git a/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/chatdb/MultiDatabaseChatDBExecutor.kt b/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/chatdb/MultiDatabaseChatDBExecutor.kt index 7a9eef71d8..333ebb9eaf 100644 --- a/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/chatdb/MultiDatabaseChatDBExecutor.kt +++ b/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/chatdb/MultiDatabaseChatDBExecutor.kt @@ -8,6 +8,8 @@ import cc.unitmesh.agent.orchestrator.ToolOrchestrator import cc.unitmesh.agent.render.ChatDBStepStatus import cc.unitmesh.agent.render.ChatDBStepType import cc.unitmesh.agent.render.CodingAgentRenderer +import cc.unitmesh.agent.subagent.PlotDSLAgent +import cc.unitmesh.agent.subagent.PlotDSLContext import cc.unitmesh.agent.subagent.SqlValidator import cc.unitmesh.agent.subagent.SqlReviseAgent import cc.unitmesh.agent.subagent.SqlRevisionInput @@ -54,6 +56,7 @@ class MultiDatabaseChatDBExecutor( private val keywordSchemaLinker = KeywordSchemaLinker() private val sqlValidator = SqlValidator() private val sqlReviseAgent = SqlReviseAgent(llmService, sqlValidator) + private val plotDSLAgent = PlotDSLAgent(llmService) private val maxRevisionAttempts = 3 private val maxExecutionRetries = 3 @@ -674,41 +677,53 @@ class MultiDatabaseChatDBExecutor( } /** - * Generate visualization for query results + * Generate visualization for query results using PlotDSLAgent */ private suspend fun generateVisualization( query: String, result: QueryResult, onProgress: (String) -> Unit ): String? { - val visualizationPrompt = buildString { - appendLine("Based on the following query result, generate a PlotDSL visualization:") + // Check if PlotDSLAgent is available on this platform + if (!plotDSLAgent.isAvailable) { + logger.info { "PlotDSLAgent not available on this platform, skipping visualization" } + return null + } + + // Build description for PlotDSLAgent + val description = buildString { + appendLine("Create a visualization for the following database query result:") appendLine() appendLine("**Original Query**: $query") appendLine() - appendLine("**Query Result** (${result.rowCount} rows):") + appendLine("**Data** (${result.rowCount} rows, columns: ${result.columns.joinToString(", ")}):") appendLine("```csv") appendLine(result.toCsvString()) appendLine("```") appendLine() - appendLine("Generate a PlotDSL chart that best visualizes this data.") - appendLine("Choose an appropriate chart type (bar, line, scatter, etc.) based on the data.") - appendLine("Wrap the PlotDSL code in a ```plotdsl code block.") + appendLine("Choose the most appropriate chart type based on the data structure.") } try { - val response = getLLMResponse(visualizationPrompt, compileDevIns = false) { chunk -> - onProgress(chunk) - } + val plotContext = PlotDSLContext(description = description) + val agentResult = plotDSLAgent.execute(plotContext, onProgress) + + if (agentResult.success) { + // Extract PlotDSL code from the result + val content = agentResult.content + val codeFence = CodeFence.parse(content) + if (codeFence.languageId.lowercase() == "plotdsl" && codeFence.text.isNotBlank()) { + return codeFence.text.trim() + } - val codeFence = CodeFence.parse(response) - if (codeFence.languageId.lowercase() == "plotdsl" && codeFence.text.isNotBlank()) { - return codeFence.text.trim() + // Try to find plotdsl block manually + val plotPattern = Regex("```plotdsl\\s*([\\s\\S]*?)```", RegexOption.IGNORE_CASE) + val match = plotPattern.find(content) + return match?.groupValues?.get(1)?.trim() + } else { + logger.warn { "PlotDSLAgent failed: ${agentResult.content}" } + return null } - - val plotPattern = Regex("```plotdsl\\s*([\\s\\S]*?)```", RegexOption.IGNORE_CASE) - val match = plotPattern.find(response) - return match?.groupValues?.get(1)?.trim() } catch (e: Exception) { logger.error(e) { "Visualization generation failed" } return null diff --git a/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/database/DatabaseConnection.kt b/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/database/DatabaseConnection.kt index 3ada6fefa4..bb26f15142 100644 --- a/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/database/DatabaseConnection.kt +++ b/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/database/DatabaseConnection.kt @@ -133,19 +133,31 @@ data class QueryResult( /** * Convert result to formatted table string (for user display) */ - fun toTableString(): String { + /** + * Convert result to formatted table string (for user display) + * @param maxDisplayRows Maximum number of rows to display (default 10) + */ + fun toTableString(maxDisplayRows: Int = 10): String { if (isEmpty()) return "No results" - - // Calculate column widths + + val displayRows = rows.take(maxDisplayRows) + val hasMoreRows = rows.size > maxDisplayRows + val moreRowsText = if (hasMoreRows) "... (${rows.size - maxDisplayRows} more rows)" else null + + // Calculate column widths (consider all rows for accurate width, but limit for performance) val colWidths = columns.indices.map { colIdx -> - maxOf( - columns[colIdx].length, - rows.maxOfOrNull { it[colIdx].length } ?: 4 - ) + val headerWidth = columns[colIdx].length + val dataWidth = rows.take(100).maxOfOrNull { + if (colIdx < it.size) it[colIdx].length else 0 + } ?: 4 + val moreRowsWidth = if (colIdx == 0 && moreRowsText != null) moreRowsText.length else 0 + maxOf(headerWidth, dataWidth, moreRowsWidth) } + val totalWidth = colWidths.sumOf { it + 3 } + 1 // +3 for " โ”‚ " per column, +1 for final โ”‚ + val sb = StringBuilder() - + // Header sb.append("โ”Œ") colWidths.forEach { width -> sb.append("โ”€".repeat(width + 2)).append("โ”ฌ") } @@ -165,18 +177,30 @@ data class QueryResult( sb.setLength(sb.length - 1) sb.append("โ”ค\n") - // Data rows (show first 10) - rows.take(10).forEach { row -> + // Data rows + displayRows.forEach { row -> sb.append("โ”‚") row.forEachIndexed { idx, value -> val str = value.ifEmpty { "NULL" } - sb.append(" ").append(str.padEnd(colWidths[idx])).append(" โ”‚") + // Truncate long values to fit column width + val displayStr = if (str.length > colWidths[idx]) { + str.take(colWidths[idx] - 3) + "..." + } else { + str + } + sb.append(" ").append(displayStr.padEnd(colWidths[idx])).append(" โ”‚") } sb.append("\n") } - if (rows.size > 10) { - sb.append("โ”‚ ... (${rows.size - 10} more rows)\n") + // More rows indicator (as a proper table row) + if (hasMoreRows && moreRowsText != null) { + sb.append("โ”‚") + sb.append(" ").append(moreRowsText.padEnd(colWidths[0])).append(" โ”‚") + for (idx in 1 until colWidths.size) { + sb.append(" ".repeat(colWidths[idx] + 2)).append("โ”‚") + } + sb.append("\n") } // Footer diff --git a/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/agent/chatdb/ChatDBViewModel.kt b/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/agent/chatdb/ChatDBViewModel.kt index 809316bf5f..248303275c 100644 --- a/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/agent/chatdb/ChatDBViewModel.kt +++ b/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/agent/chatdb/ChatDBViewModel.kt @@ -450,7 +450,7 @@ class ChatDBViewModel( // It will merge schemas, let LLM decide which database(s) to query, // and execute SQL on the appropriate database(s) agent.execute(task) { progress -> - println("[ChatDB] Progress: $progress") +// println("[ChatDB] Progress: $progress") } } finally { agent.close() From 7e624484d3afed84956b24e472fcff496683b41d Mon Sep 17 00:00:00 2001 From: Phodal Huang Date: Wed, 10 Dec 2025 19:16:45 +0800 Subject: [PATCH 31/34] feat(database): output query results as Markdown tables #508 Change toTableString to generate Markdown table format for query results, replacing the previous ASCII table style. Adds escaping for special Markdown characters and updates tests accordingly. --- .../agent/database/DatabaseConnection.kt | 94 ++++++------------- .../agent/database/DatabaseConnectionTest.kt | 18 ++-- 2 files changed, 38 insertions(+), 74 deletions(-) diff --git a/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/database/DatabaseConnection.kt b/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/database/DatabaseConnection.kt index bb26f15142..ad8efab577 100644 --- a/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/database/DatabaseConnection.kt +++ b/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/database/DatabaseConnection.kt @@ -131,86 +131,46 @@ data class QueryResult( } /** - * Convert result to formatted table string (for user display) + * Convert result to Markdown table format (for rendering with MarkdownTableRenderer) + * Shows all rows without truncation */ - /** - * Convert result to formatted table string (for user display) - * @param maxDisplayRows Maximum number of rows to display (default 10) - */ - fun toTableString(maxDisplayRows: Int = 10): String { + fun toTableString(): String { if (isEmpty()) return "No results" - val displayRows = rows.take(maxDisplayRows) - val hasMoreRows = rows.size > maxDisplayRows - val moreRowsText = if (hasMoreRows) "... (${rows.size - maxDisplayRows} more rows)" else null - - // Calculate column widths (consider all rows for accurate width, but limit for performance) - val colWidths = columns.indices.map { colIdx -> - val headerWidth = columns[colIdx].length - val dataWidth = rows.take(100).maxOfOrNull { - if (colIdx < it.size) it[colIdx].length else 0 - } ?: 4 - val moreRowsWidth = if (colIdx == 0 && moreRowsText != null) moreRowsText.length else 0 - maxOf(headerWidth, dataWidth, moreRowsWidth) - } - - val totalWidth = colWidths.sumOf { it + 3 } + 1 // +3 for " โ”‚ " per column, +1 for final โ”‚ - val sb = StringBuilder() - // Header - sb.append("โ”Œ") - colWidths.forEach { width -> sb.append("โ”€".repeat(width + 2)).append("โ”ฌ") } - sb.setLength(sb.length - 1) - sb.append("โ”\n") - - // Column names - sb.append("โ”‚") - columns.forEachIndexed { idx, col -> - sb.append(" ").append(col.padEnd(colWidths[idx])).append(" โ”‚") - } - sb.append("\n") + // Header row + sb.append("| ") + sb.append(columns.joinToString(" | ") { escapeMarkdown(it) }) + sb.append(" |\n") - // Separator - sb.append("โ”œ") - colWidths.forEach { width -> sb.append("โ”€".repeat(width + 2)).append("โ”ผ") } - sb.setLength(sb.length - 1) - sb.append("โ”ค\n") + // Separator row + sb.append("| ") + sb.append(columns.joinToString(" | ") { "---" }) + sb.append(" |\n") // Data rows - displayRows.forEach { row -> - sb.append("โ”‚") - row.forEachIndexed { idx, value -> + rows.forEach { row -> + sb.append("| ") + sb.append(row.mapIndexed { idx, value -> val str = value.ifEmpty { "NULL" } - // Truncate long values to fit column width - val displayStr = if (str.length > colWidths[idx]) { - str.take(colWidths[idx] - 3) + "..." - } else { - str - } - sb.append(" ").append(displayStr.padEnd(colWidths[idx])).append(" โ”‚") - } - sb.append("\n") + escapeMarkdown(str) + }.joinToString(" | ")) + sb.append(" |\n") } - // More rows indicator (as a proper table row) - if (hasMoreRows && moreRowsText != null) { - sb.append("โ”‚") - sb.append(" ").append(moreRowsText.padEnd(colWidths[0])).append(" โ”‚") - for (idx in 1 until colWidths.size) { - sb.append(" ".repeat(colWidths[idx] + 2)).append("โ”‚") - } - sb.append("\n") - } - - // Footer - sb.append("โ””") - colWidths.forEach { width -> sb.append("โ”€".repeat(width + 2)).append("โ”ด") } - sb.setLength(sb.length - 1) - sb.append("โ”˜\n") - return sb.toString() } + + /** + * Escape special Markdown characters in table cell content + */ + private fun escapeMarkdown(text: String): String { + return text + .replace("|", "\\|") + .replace("\n", " ") + .replace("\r", "") + } } /** diff --git a/mpp-core/src/jvmTest/kotlin/cc/unitmesh/agent/database/DatabaseConnectionTest.kt b/mpp-core/src/jvmTest/kotlin/cc/unitmesh/agent/database/DatabaseConnectionTest.kt index fb21a5856c..c9e53df286 100644 --- a/mpp-core/src/jvmTest/kotlin/cc/unitmesh/agent/database/DatabaseConnectionTest.kt +++ b/mpp-core/src/jvmTest/kotlin/cc/unitmesh/agent/database/DatabaseConnectionTest.kt @@ -18,21 +18,25 @@ class DatabaseConnectionTest { ), rowCount = 3 ) - + assertEquals(3, result.rowCount) assertFalse(result.isEmpty()) assertEquals(3, result.columns.size) assertEquals(3, result.rows.size) - + val csv = result.toCsvString() assertTrue(csv.contains("id,name,email")) assertTrue(csv.contains("Alice")) - + + // Test Markdown table format val table = result.toTableString() - assertTrue(table.contains("id")) - assertTrue(table.contains("Alice")) - - println("Query result table:") + assertTrue(table.contains("| id | name | email |")) + assertTrue(table.contains("| --- | --- | --- |")) + assertTrue(table.contains("| 1 | Alice | alice@example.com |")) + assertTrue(table.contains("| 2 | Bob | bob@example.com |")) + assertTrue(table.contains("| 3 | Charlie | charlie@example.com |")) + + println("Query result table (Markdown format):") println(result.toTableString()) } From 017ddfeae3a3dead362a3f554fceb7bb001555ed Mon Sep 17 00:00:00 2001 From: Phodal Huang Date: Wed, 10 Dec 2025 19:58:04 +0800 Subject: [PATCH 32/34] feat(chatdb): refactor multi-database agent and renderer #508 Refactored MultiDatabaseChatDBAgent and related renderer logic to improve database handling and UI integration. Updated SQL validation and connection modules for better cross-platform support. --- .../agent/chatdb/MultiDatabaseChatDBAgent.kt | 32 +++- .../chatdb/MultiDatabaseChatDBExecutor.kt | 153 ++++++++++++++++++ .../agent/database/DatabaseConnection.kt | 53 +++++- .../agent/render/CodingAgentRenderer.kt | 25 +++ .../render/DefaultCodingAgentRenderer.kt | 25 +++ .../unitmesh/agent/render/RendererModels.kt | 12 +- .../unitmesh/agent/subagent/SqlReviseAgent.kt | 55 +++++++ .../unitmesh/agent/subagent/SqlValidator.kt | 11 ++ .../agent/subagent/SqlValidator.ios.kt | 30 +++- .../cc/unitmesh/agent/RendererExports.kt | 14 ++ .../agent/subagent/SqlValidator.js.kt | 30 +++- .../database/ExposedDatabaseConnection.kt | 27 ++++ .../agent/subagent/SqlValidator.jvm.kt | 53 +++++- .../agent/subagent/SqlValidator.wasmJs.kt | 30 +++- .../connection/IdeaDatabaseConnection.kt | 25 +++ .../devins/idea/renderer/JewelRenderer.kt | 70 ++++++++ .../server/render/ServerSideRenderer.kt | 14 ++ .../ui/compose/agent/ComposeRenderer.kt | 68 ++++++++ .../agent/chatdb/components/ChatDBStepCard.kt | 87 ++++++++++ 19 files changed, 790 insertions(+), 24 deletions(-) diff --git a/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/chatdb/MultiDatabaseChatDBAgent.kt b/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/chatdb/MultiDatabaseChatDBAgent.kt index 76c4abe7f2..36adfff158 100644 --- a/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/chatdb/MultiDatabaseChatDBAgent.kt +++ b/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/chatdb/MultiDatabaseChatDBAgent.kt @@ -165,8 +165,22 @@ CRITICAL RULES: 2. ONLY use column names provided in the schema - NEVER invent or guess column names 3. When generating SQL, use the table name WITHOUT the database prefix (the system will route to the correct database) 4. If the user's question relates to tables in multiple databases, generate SEPARATE SQL queries for each database -5. Only generate SELECT queries (read-only operations) -6. Always add LIMIT clause to prevent large result sets +5. Always add LIMIT clause for SELECT queries to prevent large result sets + +SUPPORTED OPERATIONS: +- SELECT: Read data (no approval required) +- INSERT: Add new records (requires user approval) +- UPDATE: Modify existing records (requires user approval) +- DELETE: Remove records (requires user approval, HIGH RISK) +- CREATE TABLE: Create new tables (requires user approval) +- ALTER TABLE: Modify table structure (requires user approval, HIGH RISK) +- DROP TABLE: Delete tables (requires user approval, HIGH RISK) +- TRUNCATE: Remove all records (requires user approval, HIGH RISK) + +โš ๏ธ WRITE OPERATIONS WARNING: +- All write operations (INSERT, UPDATE, DELETE, CREATE, ALTER, DROP, TRUNCATE) require explicit user approval before execution +- HIGH RISK operations (DELETE, DROP, TRUNCATE, ALTER) will be highlighted with additional warnings +- Always confirm with the user before generating destructive SQL OUTPUT FORMAT: - For single database query: @@ -185,6 +199,20 @@ SELECT * FROM users LIMIT 100; SELECT * FROM customers LIMIT 100; ``` +- For write operations: +```sql +-- database: +INSERT INTO users (name, email) VALUES ('John', 'john@example.com'); +``` + +```sql +-- database: +CREATE TABLE new_table ( + id INT PRIMARY KEY, + name VARCHAR(100) +); +``` + The "-- database: " comment is REQUIRED to specify which database to execute the query on.""" } } diff --git a/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/chatdb/MultiDatabaseChatDBExecutor.kt b/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/chatdb/MultiDatabaseChatDBExecutor.kt index 333ebb9eaf..4e61b702a8 100644 --- a/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/chatdb/MultiDatabaseChatDBExecutor.kt +++ b/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/chatdb/MultiDatabaseChatDBExecutor.kt @@ -10,11 +10,14 @@ import cc.unitmesh.agent.render.ChatDBStepType import cc.unitmesh.agent.render.CodingAgentRenderer import cc.unitmesh.agent.subagent.PlotDSLAgent import cc.unitmesh.agent.subagent.PlotDSLContext +import cc.unitmesh.agent.subagent.SqlOperationType import cc.unitmesh.agent.subagent.SqlValidator import cc.unitmesh.agent.subagent.SqlReviseAgent import cc.unitmesh.agent.subagent.SqlRevisionInput import cc.unitmesh.devins.parser.CodeFence import cc.unitmesh.llm.KoogLLMService +import kotlinx.coroutines.suspendCancellableCoroutine +import kotlin.coroutines.resume /** * Multi-Database ChatDB Executor - Executes Text2SQL across multiple databases @@ -327,6 +330,97 @@ class MultiDatabaseChatDBExecutor( continue } + // Detect SQL operation type + val operationType = sqlValidator.detectSqlType(sql) + val isWriteOperation = operationType.requiresApproval() + val isHighRisk = operationType.isHighRisk() + + // If write operation, request approval + if (isWriteOperation) { + val affectedTables = extractTablesFromSql(sql, merged.databases[dbName]) + + val approved = requestSqlApproval( + sql = sql, + operationType = operationType, + affectedTables = affectedTables, + isHighRisk = isHighRisk, + onProgress = onProgress + ) + + if (!approved) { + renderer.renderChatDBStep( + stepType = ChatDBStepType.EXECUTE_WRITE, + status = ChatDBStepStatus.REJECTED, + title = "Write operation rejected by user", + details = mapOf( + "database" to dbName, + "operationType" to operationType.name, + "sql" to sql + ) + ) + errors.add("[$dbName] Write operation rejected by user: ${operationType.name}") + continue + } + + // Execute write operation + renderer.renderChatDBStep( + stepType = ChatDBStepType.EXECUTE_WRITE, + status = ChatDBStepStatus.APPROVED, + title = "Write operation approved, executing...", + details = mapOf( + "database" to dbName, + "operationType" to operationType.name, + "sql" to sql + ) + ) + onProgress("โœ… Write operation approved, executing...") + + try { + val updateResult = connection.executeUpdate(sql) + + if (updateResult.success) { + renderer.renderChatDBStep( + stepType = ChatDBStepType.EXECUTE_WRITE, + status = ChatDBStepStatus.SUCCESS, + title = "Write operation completed on $dbName", + details = mapOf( + "database" to dbName, + "operationType" to operationType.name, + "affectedRows" to updateResult.affectedRows, + "message" to (updateResult.message ?: "") + ) + ) + onProgress("โœ… ${operationType.name} completed: ${updateResult.affectedRows} row(s) affected") + + // Create a synthetic QueryResult for write operations + queryResults[dbName] = QueryResult( + columns = listOf("Operation", "Affected Rows", "Status"), + rows = listOf(listOf(operationType.name, updateResult.affectedRows.toString(), "Success")), + rowCount = 1 + ) + } else { + renderer.renderChatDBStep( + stepType = ChatDBStepType.EXECUTE_WRITE, + status = ChatDBStepStatus.ERROR, + title = "Write operation failed on $dbName", + error = updateResult.message ?: "Unknown error" + ) + errors.add("[$dbName] Write operation failed: ${updateResult.message}") + } + } catch (e: Exception) { + val errorMsg = e.message ?: "Unknown execution error" + renderer.renderChatDBStep( + stepType = ChatDBStepType.EXECUTE_WRITE, + status = ChatDBStepStatus.ERROR, + title = "Write operation failed on $dbName", + error = errorMsg + ) + errors.add("[$dbName] Write operation failed: $errorMsg") + } + continue + } + + // Regular SELECT query execution with retry var executionRetries = 0 var lastExecutionError: String? = null var result: QueryResult? = null @@ -819,6 +913,65 @@ Generate the SQL: return QueryResult(allColumns, allRows, allRows.size) } + + /** + * Request user approval for SQL write operation + * Returns true if approved, false if rejected + */ + private suspend fun requestSqlApproval( + sql: String, + operationType: SqlOperationType, + affectedTables: List, + isHighRisk: Boolean, + onProgress: (String) -> Unit + ): Boolean { + val riskLevel = if (isHighRisk) "โš ๏ธ HIGH RISK" else "โšก Write Operation" + onProgress("$riskLevel: ${operationType.name} requires approval") + + return suspendCancellableCoroutine { continuation -> + renderer.renderSqlApprovalRequest( + sql = sql, + operationType = operationType, + affectedTables = affectedTables, + isHighRisk = isHighRisk, + onApprove = { + if (continuation.isActive) { + continuation.resume(true) + } + }, + onReject = { + if (continuation.isActive) { + continuation.resume(false) + } + } + ) + } + } + + /** + * Extract table names from SQL statement + */ + private fun extractTablesFromSql(sql: String, schema: DatabaseSchema?): List { + if (schema == null) return emptyList() + + val tableNames = schema.tables.map { it.name.lowercase() }.toSet() + val sqlLower = sql.lowercase() + + return tableNames.filter { tableName -> + // Check for common SQL patterns that reference tables + sqlLower.contains(" $tableName ") || + sqlLower.contains(" $tableName;") || + sqlLower.contains(" $tableName\n") || + sqlLower.contains("from $tableName") || + sqlLower.contains("into $tableName") || + sqlLower.contains("update $tableName") || + sqlLower.contains("table $tableName") || + sqlLower.contains("join $tableName") + }.map { tableName -> + // Return original case from schema + schema.tables.find { it.name.lowercase() == tableName }?.name ?: tableName + } + } } /** diff --git a/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/database/DatabaseConnection.kt b/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/database/DatabaseConnection.kt index ad8efab577..e8eae3f8b4 100644 --- a/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/database/DatabaseConnection.kt +++ b/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/database/DatabaseConnection.kt @@ -10,16 +10,25 @@ import kotlinx.serialization.Serializable interface DatabaseConnection { /** * Execute a SQL query - * + * * @param sql SQL query statement (SELECT only) * @return Query result * @throws DatabaseException If execution fails */ suspend fun executeQuery(sql: String): QueryResult + /** + * Execute a SQL update statement (INSERT, UPDATE, DELETE, CREATE, ALTER, DROP, etc.) + * + * @param sql SQL update statement + * @return UpdateResult containing affected row count and any generated keys + * @throws DatabaseException If execution fails + */ + suspend fun executeUpdate(sql: String): UpdateResult + /** * Get database schema information - * + * * @return DatabaseSchema containing all tables and columns * @throws DatabaseException If retrieval fails */ @@ -102,6 +111,46 @@ interface DatabaseConnection { } } +/** + * Database update result (for INSERT, UPDATE, DELETE, DDL statements) + */ +@Serializable +data class UpdateResult( + /** + * Number of rows affected by the update + */ + val affectedRows: Int, + + /** + * Generated keys (for INSERT with auto-increment) + */ + val generatedKeys: List = emptyList(), + + /** + * Whether the update was successful + */ + val success: Boolean = true, + + /** + * Optional message (e.g., for DDL statements) + */ + val message: String? = null +) { + companion object { + fun success(affectedRows: Int, generatedKeys: List = emptyList()): UpdateResult { + return UpdateResult(affectedRows, generatedKeys, true) + } + + fun ddlSuccess(message: String = "DDL statement executed successfully"): UpdateResult { + return UpdateResult(0, emptyList(), true, message) + } + + fun failure(message: String): UpdateResult { + return UpdateResult(0, emptyList(), false, message) + } + } +} + /** * Database query result */ diff --git a/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/render/CodingAgentRenderer.kt b/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/render/CodingAgentRenderer.kt index c539f64588..0132b206d5 100644 --- a/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/render/CodingAgentRenderer.kt +++ b/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/render/CodingAgentRenderer.kt @@ -1,6 +1,7 @@ package cc.unitmesh.agent.render import cc.unitmesh.agent.plan.PlanSummaryData +import cc.unitmesh.agent.subagent.SqlOperationType import cc.unitmesh.agent.tool.ToolResult import cc.unitmesh.llm.compression.TokenInfo @@ -113,6 +114,30 @@ interface CodingAgentRenderer { fun renderUserConfirmationRequest(toolName: String, params: Map) + /** + * Request user approval for a SQL write operation. + * This is called when a write operation (INSERT, UPDATE, DELETE, CREATE, etc.) is detected. + * The renderer should display the SQL and allow the user to approve or reject. + * + * @param sql The SQL statement to be executed + * @param operationType The type of SQL operation (INSERT, UPDATE, DELETE, CREATE, etc.) + * @param affectedTables List of tables that will be affected + * @param isHighRisk Whether this is a high-risk operation (DROP, TRUNCATE) + * @param onApprove Callback to invoke when user approves the operation + * @param onReject Callback to invoke when user rejects the operation + */ + fun renderSqlApprovalRequest( + sql: String, + operationType: SqlOperationType, + affectedTables: List, + isHighRisk: Boolean, + onApprove: () -> Unit, + onReject: () -> Unit + ) { + // Default: auto-reject for safety (renderers should override to show UI) + onReject() + } + /** * Add a live terminal session to the timeline. * Called when a Shell tool starts execution with PTY support. diff --git a/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/render/DefaultCodingAgentRenderer.kt b/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/render/DefaultCodingAgentRenderer.kt index 58d0294640..9cd6a6c5fc 100644 --- a/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/render/DefaultCodingAgentRenderer.kt +++ b/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/render/DefaultCodingAgentRenderer.kt @@ -1,6 +1,7 @@ package cc.unitmesh.agent.render import cc.unitmesh.agent.logging.getLogger +import cc.unitmesh.agent.subagent.SqlOperationType /** * Default console renderer - simple text output @@ -118,6 +119,30 @@ class DefaultCodingAgentRenderer : BaseRenderer() { println(" (Auto-approved for now)") } + override fun renderSqlApprovalRequest( + sql: String, + operationType: SqlOperationType, + affectedTables: List, + isHighRisk: Boolean, + onApprove: () -> Unit, + onReject: () -> Unit + ) { + val riskIcon = if (isHighRisk) "!!" else "!" + println("\n$riskIcon SQL Write Operation Requires Approval") + println("-".repeat(50)) + println("Operation: ${operationType.name}") + println("Affected Tables: ${affectedTables.joinToString(", ")}") + if (isHighRisk) { + println("WARNING: This is a HIGH-RISK operation!") + } + println("\nSQL:") + println(sql) + println("-".repeat(50)) + // Default console renderer auto-rejects for safety + println("(Auto-rejected in console mode - use interactive UI for approval)") + onReject() + } + override fun renderAgentSketchBlock( agentName: String, language: String, diff --git a/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/render/RendererModels.kt b/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/render/RendererModels.kt index 7715ba4054..4e923b2767 100644 --- a/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/render/RendererModels.kt +++ b/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/render/RendererModels.kt @@ -251,7 +251,11 @@ enum class ChatDBStepType(val displayName: String, val icon: String) { GENERATE_SQL("Generate SQL Query", "๐Ÿค–"), VALIDATE_SQL("Validate SQL", "โœ“"), REVISE_SQL("Revise SQL", "๐Ÿ”„"), + /** Waiting for user approval before executing write operation */ + AWAIT_APPROVAL("Awaiting Approval", "?"), EXECUTE_SQL("Execute SQL Query", "โšก"), + /** Execute write operation (INSERT, UPDATE, DELETE, DDL) */ + EXECUTE_WRITE("Execute Write Operation", "!"), GENERATE_VISUALIZATION("Generate Visualization", "๐Ÿ“ˆ"), FINAL_RESULT("Query Result", "โœ…") } @@ -264,6 +268,12 @@ enum class ChatDBStepStatus(val displayName: String) { IN_PROGRESS("In Progress"), SUCCESS("Success"), WARNING("Warning"), - ERROR("Error") + ERROR("Error"), + /** User approval is required */ + AWAITING_APPROVAL("Awaiting Approval"), + /** User approved the operation */ + APPROVED("Approved"), + /** User rejected the operation */ + REJECTED("Rejected") } diff --git a/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/subagent/SqlReviseAgent.kt b/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/subagent/SqlReviseAgent.kt index 1fd2c92950..a9b31db281 100644 --- a/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/subagent/SqlReviseAgent.kt +++ b/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/subagent/SqlReviseAgent.kt @@ -278,6 +278,61 @@ interface SqlValidatorInterface { fun validate(sql: String): SqlValidationResult fun validateWithTableWhitelist(sql: String, allowedTables: Set): SqlValidationResult fun extractTableNames(sql: String): List + + /** + * Detect the type of SQL statement. + * Used to determine if approval is needed for write operations. + * + * @param sql The SQL statement to analyze + * @return The detected SQL operation type + */ + fun detectSqlType(sql: String): SqlOperationType +} + +/** + * SQL operation types for determining approval requirements + */ +enum class SqlOperationType { + /** SELECT queries - read-only, no approval needed */ + SELECT, + /** INSERT statements - requires approval */ + INSERT, + /** UPDATE statements - requires approval */ + UPDATE, + /** DELETE statements - requires approval */ + DELETE, + /** CREATE statements (tables, indexes, etc.) - requires approval */ + CREATE, + /** ALTER statements - requires approval */ + ALTER, + /** DROP statements - requires approval, high risk */ + DROP, + /** TRUNCATE statements - requires approval, high risk */ + TRUNCATE, + /** Other DDL/DCL/TCL statements */ + OTHER, + /** Unknown or unparseable SQL */ + UNKNOWN; + + /** + * Check if this operation type requires user approval + */ + fun requiresApproval(): Boolean = this != SELECT && this != UNKNOWN + + /** + * Check if this is a high-risk operation (DROP, TRUNCATE) + */ + fun isHighRisk(): Boolean = this == DROP || this == TRUNCATE + + /** + * Check if this is a write operation (INSERT, UPDATE, DELETE) + */ + fun isWriteOperation(): Boolean = this == INSERT || this == UPDATE || this == DELETE + + /** + * Check if this is a DDL operation (CREATE, ALTER, DROP, TRUNCATE) + */ + fun isDdlOperation(): Boolean = this == CREATE || this == ALTER || this == DROP || this == TRUNCATE } /** diff --git a/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/subagent/SqlValidator.kt b/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/subagent/SqlValidator.kt index cbb2a5f5d6..e9abe81ed6 100644 --- a/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/subagent/SqlValidator.kt +++ b/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/subagent/SqlValidator.kt @@ -48,5 +48,16 @@ expect class SqlValidator() : SqlValidatorInterface { * @return List of table names found in the query */ override fun extractTableNames(sql: String): List + + /** + * Detect the type of SQL statement. + * + * On JVM platforms, this uses JSqlParser for accurate detection. + * On non-JVM platforms, this uses regex-based detection. + * + * @param sql The SQL statement to analyze + * @return The detected SQL operation type + */ + override fun detectSqlType(sql: String): SqlOperationType } diff --git a/mpp-core/src/iosMain/kotlin/cc/unitmesh/agent/subagent/SqlValidator.ios.kt b/mpp-core/src/iosMain/kotlin/cc/unitmesh/agent/subagent/SqlValidator.ios.kt index e1c0db3f3a..e997c4d825 100644 --- a/mpp-core/src/iosMain/kotlin/cc/unitmesh/agent/subagent/SqlValidator.ios.kt +++ b/mpp-core/src/iosMain/kotlin/cc/unitmesh/agent/subagent/SqlValidator.ios.kt @@ -48,39 +48,57 @@ actual class SqlValidator actual constructor() : SqlValidatorInterface { actual override fun extractTableNames(sql: String): List { val tables = mutableListOf() - + // Match FROM clause val fromPattern = Regex("""FROM\s+(\w+)""", RegexOption.IGNORE_CASE) fromPattern.findAll(sql).forEach { match -> tables.add(match.groupValues[1]) } - + // Match JOIN clause val joinPattern = Regex("""JOIN\s+(\w+)""", RegexOption.IGNORE_CASE) joinPattern.findAll(sql).forEach { match -> tables.add(match.groupValues[1]) } - + // Match UPDATE clause val updatePattern = Regex("""UPDATE\s+(\w+)""", RegexOption.IGNORE_CASE) updatePattern.findAll(sql).forEach { match -> tables.add(match.groupValues[1]) } - + // Match INSERT INTO clause val insertPattern = Regex("""INSERT\s+INTO\s+(\w+)""", RegexOption.IGNORE_CASE) insertPattern.findAll(sql).forEach { match -> tables.add(match.groupValues[1]) } - + // Match DELETE FROM clause val deletePattern = Regex("""DELETE\s+FROM\s+(\w+)""", RegexOption.IGNORE_CASE) deletePattern.findAll(sql).forEach { match -> tables.add(match.groupValues[1]) } - + return tables.distinct() } + + /** + * Detect SQL type using regex-based detection + */ + actual override fun detectSqlType(sql: String): SqlOperationType { + val trimmedSql = sql.trim().uppercase() + return when { + trimmedSql.startsWith("SELECT") || trimmedSql.startsWith("WITH") -> SqlOperationType.SELECT + trimmedSql.startsWith("INSERT") -> SqlOperationType.INSERT + trimmedSql.startsWith("UPDATE") -> SqlOperationType.UPDATE + trimmedSql.startsWith("DELETE") -> SqlOperationType.DELETE + trimmedSql.startsWith("CREATE") -> SqlOperationType.CREATE + trimmedSql.startsWith("ALTER") -> SqlOperationType.ALTER + trimmedSql.startsWith("DROP") -> SqlOperationType.DROP + trimmedSql.startsWith("TRUNCATE") -> SqlOperationType.TRUNCATE + else -> SqlOperationType.UNKNOWN + } + } private fun performBasicValidation(sql: String): SqlValidationResult { val errors = mutableListOf() diff --git a/mpp-core/src/jsMain/kotlin/cc/unitmesh/agent/RendererExports.kt b/mpp-core/src/jsMain/kotlin/cc/unitmesh/agent/RendererExports.kt index 9735dde91a..8455963fea 100644 --- a/mpp-core/src/jsMain/kotlin/cc/unitmesh/agent/RendererExports.kt +++ b/mpp-core/src/jsMain/kotlin/cc/unitmesh/agent/RendererExports.kt @@ -4,6 +4,7 @@ import cc.unitmesh.agent.plan.PlanSummaryData import cc.unitmesh.agent.plan.StepSummary import cc.unitmesh.agent.plan.TaskSummary import cc.unitmesh.agent.render.CodingAgentRenderer +import cc.unitmesh.agent.subagent.SqlOperationType import kotlin.js.JsExport /** @@ -196,6 +197,19 @@ class JsRendererAdapter(private val jsRenderer: JsCodingAgentRenderer) : CodingA jsRenderer.renderError("Tool '$toolName' requires user confirmation: $params (Auto-approved)") } + override fun renderSqlApprovalRequest( + sql: String, + operationType: SqlOperationType, + affectedTables: List, + isHighRisk: Boolean, + onApprove: () -> Unit, + onReject: () -> Unit + ) { + // JS renderer auto-rejects for safety + jsRenderer.renderError("SQL write operation requires approval: ${operationType.name} on ${affectedTables.joinToString(", ")} (Auto-rejected)") + onReject() + } + override fun renderPlanSummary(summary: PlanSummaryData) { jsRenderer.renderPlanSummary(JsPlanSummaryData.from(summary)) } diff --git a/mpp-core/src/jsMain/kotlin/cc/unitmesh/agent/subagent/SqlValidator.js.kt b/mpp-core/src/jsMain/kotlin/cc/unitmesh/agent/subagent/SqlValidator.js.kt index 702a298ee9..337f0a6d3b 100644 --- a/mpp-core/src/jsMain/kotlin/cc/unitmesh/agent/subagent/SqlValidator.js.kt +++ b/mpp-core/src/jsMain/kotlin/cc/unitmesh/agent/subagent/SqlValidator.js.kt @@ -48,39 +48,57 @@ actual class SqlValidator actual constructor() : SqlValidatorInterface { actual override fun extractTableNames(sql: String): List { val tables = mutableListOf() - + // Match FROM clause: FROM table_name or FROM schema.table_name val fromPattern = Regex("""FROM\s+([`"\[]?\w+[`"\]]?(?:\.[`"\[]?\w+[`"\]]?)?)""", RegexOption.IGNORE_CASE) fromPattern.findAll(sql).forEach { match -> tables.add(cleanTableName(match.groupValues[1])) } - + // Match JOIN clause: JOIN table_name or LEFT/RIGHT/INNER/OUTER JOIN table_name val joinPattern = Regex("""(?:LEFT|RIGHT|INNER|OUTER|CROSS|FULL)?\s*JOIN\s+([`"\[]?\w+[`"\]]?(?:\.[`"\[]?\w+[`"\]]?)?)""", RegexOption.IGNORE_CASE) joinPattern.findAll(sql).forEach { match -> tables.add(cleanTableName(match.groupValues[1])) } - + // Match UPDATE clause: UPDATE table_name val updatePattern = Regex("""UPDATE\s+([`"\[]?\w+[`"\]]?(?:\.[`"\[]?\w+[`"\]]?)?)""", RegexOption.IGNORE_CASE) updatePattern.findAll(sql).forEach { match -> tables.add(cleanTableName(match.groupValues[1])) } - + // Match INSERT INTO clause: INSERT INTO table_name val insertPattern = Regex("""INSERT\s+INTO\s+([`"\[]?\w+[`"\]]?(?:\.[`"\[]?\w+[`"\]]?)?)""", RegexOption.IGNORE_CASE) insertPattern.findAll(sql).forEach { match -> tables.add(cleanTableName(match.groupValues[1])) } - + // Match DELETE FROM clause: DELETE FROM table_name val deletePattern = Regex("""DELETE\s+FROM\s+([`"\[]?\w+[`"\]]?(?:\.[`"\[]?\w+[`"\]]?)?)""", RegexOption.IGNORE_CASE) deletePattern.findAll(sql).forEach { match -> tables.add(cleanTableName(match.groupValues[1])) } - + return tables.distinct() } + + /** + * Detect SQL type using regex-based detection + */ + actual override fun detectSqlType(sql: String): SqlOperationType { + val trimmedSql = sql.trim().uppercase() + return when { + trimmedSql.startsWith("SELECT") || trimmedSql.startsWith("WITH") -> SqlOperationType.SELECT + trimmedSql.startsWith("INSERT") -> SqlOperationType.INSERT + trimmedSql.startsWith("UPDATE") -> SqlOperationType.UPDATE + trimmedSql.startsWith("DELETE") -> SqlOperationType.DELETE + trimmedSql.startsWith("CREATE") -> SqlOperationType.CREATE + trimmedSql.startsWith("ALTER") -> SqlOperationType.ALTER + trimmedSql.startsWith("DROP") -> SqlOperationType.DROP + trimmedSql.startsWith("TRUNCATE") -> SqlOperationType.TRUNCATE + else -> SqlOperationType.UNKNOWN + } + } private fun cleanTableName(name: String): String { // Remove quotes and brackets, extract table name from schema.table format diff --git a/mpp-core/src/jvmMain/kotlin/cc/unitmesh/agent/database/ExposedDatabaseConnection.kt b/mpp-core/src/jvmMain/kotlin/cc/unitmesh/agent/database/ExposedDatabaseConnection.kt index cb12cb6ed1..9885b47905 100644 --- a/mpp-core/src/jvmMain/kotlin/cc/unitmesh/agent/database/ExposedDatabaseConnection.kt +++ b/mpp-core/src/jvmMain/kotlin/cc/unitmesh/agent/database/ExposedDatabaseConnection.kt @@ -54,6 +54,33 @@ class ExposedDatabaseConnection( } } + override suspend fun executeUpdate(sql: String): UpdateResult = withContext(Dispatchers.IO) { + try { + hikariDataSource.connection.use { connection -> + val stmt = connection.prepareStatement(sql, java.sql.Statement.RETURN_GENERATED_KEYS) + val affectedRows = stmt.executeUpdate() + + // Try to get generated keys + val generatedKeys = mutableListOf() + try { + val keysRs = stmt.generatedKeys + while (keysRs.next()) { + generatedKeys.add(keysRs.getString(1)) + } + keysRs.close() + } catch (e: Exception) { + // Ignore if generated keys are not available + } + + stmt.close() + + UpdateResult.success(affectedRows, generatedKeys) + } + } catch (e: Exception) { + throw DatabaseException.queryFailed(sql, e.message ?: "Unknown error") + } + } + override suspend fun getSchema(): DatabaseSchema = withContext(Dispatchers.IO) { try { hikariDataSource.connection.use { connection -> diff --git a/mpp-core/src/jvmMain/kotlin/cc/unitmesh/agent/subagent/SqlValidator.jvm.kt b/mpp-core/src/jvmMain/kotlin/cc/unitmesh/agent/subagent/SqlValidator.jvm.kt index cba40aa96c..e49ff02ae7 100644 --- a/mpp-core/src/jvmMain/kotlin/cc/unitmesh/agent/subagent/SqlValidator.jvm.kt +++ b/mpp-core/src/jvmMain/kotlin/cc/unitmesh/agent/subagent/SqlValidator.jvm.kt @@ -2,6 +2,16 @@ package cc.unitmesh.agent.subagent import net.sf.jsqlparser.parser.CCJSqlParserUtil import net.sf.jsqlparser.statement.Statement +import net.sf.jsqlparser.statement.alter.Alter +import net.sf.jsqlparser.statement.create.table.CreateTable +import net.sf.jsqlparser.statement.create.index.CreateIndex +import net.sf.jsqlparser.statement.create.view.CreateView +import net.sf.jsqlparser.statement.delete.Delete +import net.sf.jsqlparser.statement.drop.Drop +import net.sf.jsqlparser.statement.insert.Insert +import net.sf.jsqlparser.statement.select.Select +import net.sf.jsqlparser.statement.truncate.Truncate +import net.sf.jsqlparser.statement.update.Update import net.sf.jsqlparser.util.TablesNamesFinder /** @@ -91,7 +101,48 @@ actual class SqlValidator actual constructor() : SqlValidatorInterface { emptyList() } } - + + /** + * Detect the type of SQL statement using JSqlParser + */ + actual override fun detectSqlType(sql: String): SqlOperationType { + return try { + val statement: Statement = CCJSqlParserUtil.parse(sql) + when (statement) { + is Select -> SqlOperationType.SELECT + is Insert -> SqlOperationType.INSERT + is Update -> SqlOperationType.UPDATE + is Delete -> SqlOperationType.DELETE + is CreateTable, is CreateIndex, is CreateView -> SqlOperationType.CREATE + is Alter -> SqlOperationType.ALTER + is Drop -> SqlOperationType.DROP + is Truncate -> SqlOperationType.TRUNCATE + else -> SqlOperationType.OTHER + } + } catch (e: Exception) { + // Fallback to regex-based detection if parsing fails + detectSqlTypeByRegex(sql) + } + } + + /** + * Fallback regex-based SQL type detection + */ + private fun detectSqlTypeByRegex(sql: String): SqlOperationType { + val trimmedSql = sql.trim().uppercase() + return when { + trimmedSql.startsWith("SELECT") -> SqlOperationType.SELECT + trimmedSql.startsWith("INSERT") -> SqlOperationType.INSERT + trimmedSql.startsWith("UPDATE") -> SqlOperationType.UPDATE + trimmedSql.startsWith("DELETE") -> SqlOperationType.DELETE + trimmedSql.startsWith("CREATE") -> SqlOperationType.CREATE + trimmedSql.startsWith("ALTER") -> SqlOperationType.ALTER + trimmedSql.startsWith("DROP") -> SqlOperationType.DROP + trimmedSql.startsWith("TRUNCATE") -> SqlOperationType.TRUNCATE + else -> SqlOperationType.UNKNOWN + } + } + /** * Validate SQL and return the parsed statement if valid */ diff --git a/mpp-core/src/wasmJsMain/kotlin/cc/unitmesh/agent/subagent/SqlValidator.wasmJs.kt b/mpp-core/src/wasmJsMain/kotlin/cc/unitmesh/agent/subagent/SqlValidator.wasmJs.kt index d1106660aa..8fd2c5830b 100644 --- a/mpp-core/src/wasmJsMain/kotlin/cc/unitmesh/agent/subagent/SqlValidator.wasmJs.kt +++ b/mpp-core/src/wasmJsMain/kotlin/cc/unitmesh/agent/subagent/SqlValidator.wasmJs.kt @@ -48,39 +48,57 @@ actual class SqlValidator actual constructor() : SqlValidatorInterface { actual override fun extractTableNames(sql: String): List { val tables = mutableListOf() - + // Match FROM clause val fromPattern = Regex("""FROM\s+(\w+)""", RegexOption.IGNORE_CASE) fromPattern.findAll(sql).forEach { match -> tables.add(match.groupValues[1]) } - + // Match JOIN clause val joinPattern = Regex("""JOIN\s+(\w+)""", RegexOption.IGNORE_CASE) joinPattern.findAll(sql).forEach { match -> tables.add(match.groupValues[1]) } - + // Match UPDATE clause val updatePattern = Regex("""UPDATE\s+(\w+)""", RegexOption.IGNORE_CASE) updatePattern.findAll(sql).forEach { match -> tables.add(match.groupValues[1]) } - + // Match INSERT INTO clause val insertPattern = Regex("""INSERT\s+INTO\s+(\w+)""", RegexOption.IGNORE_CASE) insertPattern.findAll(sql).forEach { match -> tables.add(match.groupValues[1]) } - + // Match DELETE FROM clause val deletePattern = Regex("""DELETE\s+FROM\s+(\w+)""", RegexOption.IGNORE_CASE) deletePattern.findAll(sql).forEach { match -> tables.add(match.groupValues[1]) } - + return tables.distinct() } + + /** + * Detect SQL type using regex-based detection + */ + actual override fun detectSqlType(sql: String): SqlOperationType { + val trimmedSql = sql.trim().uppercase() + return when { + trimmedSql.startsWith("SELECT") || trimmedSql.startsWith("WITH") -> SqlOperationType.SELECT + trimmedSql.startsWith("INSERT") -> SqlOperationType.INSERT + trimmedSql.startsWith("UPDATE") -> SqlOperationType.UPDATE + trimmedSql.startsWith("DELETE") -> SqlOperationType.DELETE + trimmedSql.startsWith("CREATE") -> SqlOperationType.CREATE + trimmedSql.startsWith("ALTER") -> SqlOperationType.ALTER + trimmedSql.startsWith("DROP") -> SqlOperationType.DROP + trimmedSql.startsWith("TRUNCATE") -> SqlOperationType.TRUNCATE + else -> SqlOperationType.UNKNOWN + } + } private fun performBasicValidation(sql: String): SqlValidationResult { val errors = mutableListOf() diff --git a/mpp-idea/mpp-idea-exts/ext-database/src/main/kotlin/cc/unitmesh/database/connection/IdeaDatabaseConnection.kt b/mpp-idea/mpp-idea-exts/ext-database/src/main/kotlin/cc/unitmesh/database/connection/IdeaDatabaseConnection.kt index c49ad26ff2..0247f69e41 100644 --- a/mpp-idea/mpp-idea-exts/ext-database/src/main/kotlin/cc/unitmesh/database/connection/IdeaDatabaseConnection.kt +++ b/mpp-idea/mpp-idea-exts/ext-database/src/main/kotlin/cc/unitmesh/database/connection/IdeaDatabaseConnection.kt @@ -54,6 +54,31 @@ class IdeaDatabaseConnection( } } + override suspend fun executeUpdate(sql: String): UpdateResult = withContext(Dispatchers.IO) { + try { + val stmt = ideaConnection.prepareStatement(sql, java.sql.Statement.RETURN_GENERATED_KEYS) + val affectedRows = stmt.executeUpdate() + + // Try to get generated keys + val generatedKeys = mutableListOf() + try { + val keysRs = stmt.generatedKeys + while (keysRs.next()) { + generatedKeys.add(keysRs.getString(1)) + } + keysRs.close() + } catch (e: Exception) { + // Ignore if generated keys are not available + } + + stmt.close() + + UpdateResult.success(affectedRows, generatedKeys) + } catch (e: Exception) { + throw DatabaseException.queryFailed(sql, e.message ?: "Unknown error") + } + } + override suspend fun getSchema(): DatabaseSchema = withContext(Dispatchers.IO) { try { val metadata = ideaConnection.metaData diff --git a/mpp-idea/src/main/kotlin/cc/unitmesh/devins/idea/renderer/JewelRenderer.kt b/mpp-idea/src/main/kotlin/cc/unitmesh/devins/idea/renderer/JewelRenderer.kt index bec59498ce..ed775c0374 100644 --- a/mpp-idea/src/main/kotlin/cc/unitmesh/devins/idea/renderer/JewelRenderer.kt +++ b/mpp-idea/src/main/kotlin/cc/unitmesh/devins/idea/renderer/JewelRenderer.kt @@ -3,6 +3,8 @@ package cc.unitmesh.devins.idea.renderer import cc.unitmesh.agent.plan.AgentPlan import cc.unitmesh.agent.plan.MarkdownPlanParser import cc.unitmesh.agent.render.BaseRenderer +import cc.unitmesh.agent.render.ChatDBStepStatus +import cc.unitmesh.agent.render.ChatDBStepType import cc.unitmesh.devins.llm.MessageRole import cc.unitmesh.agent.render.RendererUtils @@ -11,6 +13,7 @@ import cc.unitmesh.agent.render.TaskStatus import cc.unitmesh.agent.render.TimelineItem import cc.unitmesh.agent.render.ToolCallDisplayInfo import cc.unitmesh.agent.render.ToolCallInfo +import cc.unitmesh.agent.subagent.SqlOperationType import cc.unitmesh.agent.tool.ToolType import cc.unitmesh.agent.tool.toToolType import cc.unitmesh.devins.llm.Message @@ -88,6 +91,10 @@ class JewelRenderer : BaseRenderer() { private val _tasks = MutableStateFlow>(emptyList()) val tasks: StateFlow> = _tasks.asStateFlow() + // SQL approval state + private val _pendingSqlApproval = MutableStateFlow(null) + val pendingSqlApproval: StateFlow = _pendingSqlApproval.asStateFlow() + /** * Set the current plan directly. * Used to sync with PlanStateService from CodingAgent. @@ -536,6 +543,57 @@ class JewelRenderer : BaseRenderer() { ) } + override fun renderSqlApprovalRequest( + sql: String, + operationType: SqlOperationType, + affectedTables: List, + isHighRisk: Boolean, + onApprove: () -> Unit, + onReject: () -> Unit + ) { + _pendingSqlApproval.value = SqlApprovalRequest( + sql = sql, + operationType = operationType, + affectedTables = affectedTables, + isHighRisk = isHighRisk, + onApprove = { + _pendingSqlApproval.value = null + onApprove() + }, + onReject = { + _pendingSqlApproval.value = null + onReject() + } + ) + + // Also add to timeline for visibility + renderChatDBStep( + stepType = ChatDBStepType.AWAIT_APPROVAL, + status = ChatDBStepStatus.AWAITING_APPROVAL, + title = "Awaiting Approval: ${operationType.name}", + details = mapOf( + "sql" to sql, + "operationType" to operationType.name, + "affectedTables" to affectedTables.joinToString(", "), + "isHighRisk" to isHighRisk + ) + ) + } + + /** + * Approve the pending SQL operation + */ + fun approveSqlOperation() { + _pendingSqlApproval.value?.onApprove?.invoke() + } + + /** + * Reject the pending SQL operation + */ + fun rejectSqlOperation() { + _pendingSqlApproval.value?.onReject?.invoke() + } + override fun updateTokenInfo(tokenInfo: TokenInfo) { _lastMessageTokenInfo = tokenInfo _totalTokenInfo.update { current -> @@ -784,3 +842,15 @@ class JewelRenderer : BaseRenderer() { } } +/** + * Data class representing a pending SQL approval request + */ +data class SqlApprovalRequest( + val sql: String, + val operationType: SqlOperationType, + val affectedTables: List, + val isHighRisk: Boolean, + val onApprove: () -> Unit, + val onReject: () -> Unit +) + diff --git a/mpp-server/src/main/kotlin/cc/unitmesh/server/render/ServerSideRenderer.kt b/mpp-server/src/main/kotlin/cc/unitmesh/server/render/ServerSideRenderer.kt index f4a4bf150c..f2f41ffc83 100644 --- a/mpp-server/src/main/kotlin/cc/unitmesh/server/render/ServerSideRenderer.kt +++ b/mpp-server/src/main/kotlin/cc/unitmesh/server/render/ServerSideRenderer.kt @@ -67,6 +67,20 @@ class ServerSideRenderer : CodingAgentRenderer { eventChannel.trySend(AgentEvent.Error("User confirmation required for tool: $toolName")) } + override fun renderSqlApprovalRequest( + sql: String, + operationType: cc.unitmesh.agent.subagent.SqlOperationType, + affectedTables: List, + isHighRisk: Boolean, + onApprove: () -> Unit, + onReject: () -> Unit + ) { + // Server-side renderer auto-rejects for safety + // In a real implementation, this would send an event to the client for approval + eventChannel.trySend(AgentEvent.Error("SQL write operation requires approval: ${operationType.name} on ${affectedTables.joinToString(", ")}")) + onReject() + } + override fun renderAgentSketchBlock( agentName: String, language: String, diff --git a/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/agent/ComposeRenderer.kt b/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/agent/ComposeRenderer.kt index 8069d38056..05ddb7f44c 100644 --- a/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/agent/ComposeRenderer.kt +++ b/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/agent/ComposeRenderer.kt @@ -12,6 +12,7 @@ import cc.unitmesh.agent.render.TaskStatus import cc.unitmesh.agent.render.TimelineItem import cc.unitmesh.agent.render.TimelineItem.* import cc.unitmesh.agent.render.ToolCallInfo +import cc.unitmesh.agent.subagent.SqlOperationType import cc.unitmesh.agent.tool.ToolType import cc.unitmesh.agent.tool.impl.docql.DocQLSearchStats import cc.unitmesh.agent.tool.toToolType @@ -81,6 +82,10 @@ class ComposeRenderer : BaseRenderer() { private var _currentPlan by mutableStateOf(null) val currentPlan: AgentPlan? get() = _currentPlan + // SQL approval state + private var _pendingSqlApproval by mutableStateOf(null) + val pendingSqlApproval: SqlApprovalRequest? get() = _pendingSqlApproval + // BaseRenderer implementation override fun renderIterationHeader( @@ -565,6 +570,57 @@ class ComposeRenderer : BaseRenderer() { // For now, just use error rendering since JS renderer doesn't have this method yet } + override fun renderSqlApprovalRequest( + sql: String, + operationType: SqlOperationType, + affectedTables: List, + isHighRisk: Boolean, + onApprove: () -> Unit, + onReject: () -> Unit + ) { + _pendingSqlApproval = SqlApprovalRequest( + sql = sql, + operationType = operationType, + affectedTables = affectedTables, + isHighRisk = isHighRisk, + onApprove = { + _pendingSqlApproval = null + onApprove() + }, + onReject = { + _pendingSqlApproval = null + onReject() + } + ) + + // Also add to timeline for visibility + renderChatDBStep( + stepType = ChatDBStepType.AWAIT_APPROVAL, + status = ChatDBStepStatus.AWAITING_APPROVAL, + title = "Awaiting Approval: ${operationType.name}", + details = mapOf( + "sql" to sql, + "operationType" to operationType.name, + "affectedTables" to affectedTables.joinToString(", "), + "isHighRisk" to isHighRisk + ) + ) + } + + /** + * Approve the pending SQL operation + */ + fun approveSqlOperation() { + _pendingSqlApproval?.onApprove?.invoke() + } + + /** + * Reject the pending SQL operation + */ + fun rejectSqlOperation() { + _pendingSqlApproval?.onReject?.invoke() + } + // Public methods for UI interaction fun addUserMessage(content: String) { _timeline.add( @@ -1173,3 +1229,15 @@ class ComposeRenderer : BaseRenderer() { } } +/** + * Data class representing a pending SQL approval request + */ +data class SqlApprovalRequest( + val sql: String, + val operationType: SqlOperationType, + val affectedTables: List, + val isHighRisk: Boolean, + val onApprove: () -> Unit, + val onReject: () -> Unit +) + diff --git a/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/agent/chatdb/components/ChatDBStepCard.kt b/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/agent/chatdb/components/ChatDBStepCard.kt index f172d1ebb5..4f89de3394 100644 --- a/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/agent/chatdb/components/ChatDBStepCard.kt +++ b/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/agent/chatdb/components/ChatDBStepCard.kt @@ -136,6 +136,9 @@ private fun StepStatusBadge(status: ChatDBStepStatus) { ChatDBStepStatus.SUCCESS -> MaterialTheme.colorScheme.tertiary to "Success" ChatDBStepStatus.WARNING -> MaterialTheme.colorScheme.secondary to "Warning" ChatDBStepStatus.ERROR -> MaterialTheme.colorScheme.error to "Error" + ChatDBStepStatus.AWAITING_APPROVAL -> MaterialTheme.colorScheme.tertiary to "Awaiting Approval" + ChatDBStepStatus.APPROVED -> MaterialTheme.colorScheme.primary to "Approved" + ChatDBStepStatus.REJECTED -> MaterialTheme.colorScheme.error to "Rejected" } Surface( @@ -498,6 +501,90 @@ private fun StepDetails(details: Map, stepType: ChatDBStepType) { } } + ChatDBStepType.AWAIT_APPROVAL -> { + // Show SQL that requires approval + details["sql"]?.let { sql -> + CodeBlock(code = sql.toString(), language = "sql") + } + + // Show operation type + details["operationType"]?.let { opType -> + DetailRow("Operation Type", opType.toString()) + } + + // Show affected tables + @Suppress("UNCHECKED_CAST") + val affectedTables = details["affectedTables"] as? List + if (affectedTables != null && affectedTables.isNotEmpty()) { + Row( + horizontalArrangement = Arrangement.spacedBy(4.dp), + modifier = Modifier.horizontalScroll(rememberScrollState()) + ) { + Text( + text = "Affected Tables:", + style = MaterialTheme.typography.labelMedium, + color = MaterialTheme.colorScheme.onSurfaceVariant + ) + affectedTables.forEach { tableName -> + KeywordChip(tableName) + } + } + } + + // Show high risk warning + val isHighRisk = details["isHighRisk"] as? Boolean ?: false + if (isHighRisk) { + Spacer(modifier = Modifier.height(8.dp)) + Surface( + shape = RoundedCornerShape(4.dp), + color = MaterialTheme.colorScheme.errorContainer.copy(alpha = 0.5f) + ) { + Row( + modifier = Modifier.padding(8.dp), + verticalAlignment = Alignment.CenterVertically + ) { + Text(text = "โš ๏ธ", fontSize = 16.sp) + Spacer(modifier = Modifier.width(8.dp)) + Text( + text = "HIGH RISK OPERATION - This operation may cause data loss", + style = MaterialTheme.typography.bodySmall, + color = MaterialTheme.colorScheme.onErrorContainer, + fontWeight = FontWeight.Bold + ) + } + } + } + } + + ChatDBStepType.EXECUTE_WRITE -> { + // Show database name + details["database"]?.let { dbName -> + DetailRow("Database", dbName.toString()) + } + + // Show operation type + details["operationType"]?.let { opType -> + DetailRow("Operation Type", opType.toString()) + } + + // Show SQL that was executed + details["sql"]?.let { sql -> + CodeBlock(code = sql.toString(), language = "sql") + } + + // Show affected rows + details["affectedRows"]?.let { rows -> + DetailRow("Affected Rows", rows.toString()) + } + + // Show message if present + details["message"]?.let { message -> + if (message.toString().isNotEmpty()) { + DetailRow("Message", message.toString()) + } + } + } + else -> { // Generic detail rendering details.forEach { (key, value) -> From 99037b0be8428428b22cd73a440e78c2d46be9ae Mon Sep 17 00:00:00 2001 From: Phodal Huang Date: Wed, 10 Dec 2025 20:22:36 +0800 Subject: [PATCH 33/34] feat(renderer): refactor agent and database rendering logic #508 Refactored renderer and database connection modules for improved structure and maintainability. Updated related UI components and server-side renderer. --- .../chatdb/MultiDatabaseChatDBExecutor.kt | 117 ++++++++++++- .../agent/database/DatabaseConnection.kt | 69 ++++++++ .../agent/render/CodingAgentRenderer.kt | 3 + .../render/DefaultCodingAgentRenderer.kt | 11 ++ .../unitmesh/agent/render/RendererModels.kt | 2 + .../cc/unitmesh/agent/RendererExports.kt | 5 +- .../database/ExposedDatabaseConnection.kt | 99 +++++++++++ .../connection/IdeaDatabaseConnection.kt | 94 ++++++++++ .../timeline/IdeaTimelineContent.kt | 53 ++++++ .../devins/idea/renderer/JewelRenderer.kt | 28 ++- .../knowledge/IdeaKnowledgeContent.kt | 59 +++++++ .../server/render/ServerSideRenderer.kt | 4 +- .../ui/compose/agent/AgentMessageList.kt | 4 +- .../ui/compose/agent/ComposeRenderer.kt | 28 ++- .../agent/chatdb/components/ChatDBStepCard.kt | 164 +++++++++++++++++- 15 files changed, 720 insertions(+), 20 deletions(-) diff --git a/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/chatdb/MultiDatabaseChatDBExecutor.kt b/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/chatdb/MultiDatabaseChatDBExecutor.kt index 4e61b702a8..821bd62b47 100644 --- a/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/chatdb/MultiDatabaseChatDBExecutor.kt +++ b/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/chatdb/MultiDatabaseChatDBExecutor.kt @@ -335,15 +335,123 @@ class MultiDatabaseChatDBExecutor( val isWriteOperation = operationType.requiresApproval() val isHighRisk = operationType.isHighRisk() - // If write operation, request approval + // If write operation, perform dry run first, then request approval if (isWriteOperation) { val affectedTables = extractTablesFromSql(sql, merged.databases[dbName]) + // Step: Dry run to validate SQL before asking for approval + renderer.renderChatDBStep( + stepType = ChatDBStepType.DRY_RUN, + status = ChatDBStepStatus.IN_PROGRESS, + title = "Validating ${operationType.name} operation...", + details = mapOf( + "database" to dbName, + "operationType" to operationType.name, + "sql" to sql + ) + ) + onProgress("๐Ÿ” Performing dry run validation...") + + val dryRunResult = connection.dryRun(sql) + + if (!dryRunResult.isValid) { + // Dry run failed - SQL has errors + renderer.renderChatDBStep( + stepType = ChatDBStepType.DRY_RUN, + status = ChatDBStepStatus.ERROR, + title = "Dry run failed: SQL validation error", + details = mapOf( + "database" to dbName, + "operationType" to operationType.name, + "sql" to sql, + "errors" to dryRunResult.errors + ), + error = dryRunResult.message ?: dryRunResult.errors.firstOrNull() ?: "Unknown error" + ) + errors.add("[$dbName] Dry run failed: ${dryRunResult.message}") + + // Try to revise SQL based on dry run error + onProgress("๐Ÿ”„ SQL validation failed, attempting to revise...") + + renderer.renderChatDBStep( + stepType = ChatDBStepType.REVISE_SQL, + status = ChatDBStepStatus.IN_PROGRESS, + title = "Revising SQL based on validation error..." + ) + + val relevantSchema = buildSchemaDescriptionForDatabase(dbName, merged) + val revisionInput = SqlRevisionInput( + originalQuery = task.query, + failedSql = sql, + errorMessage = "Dry run error: ${dryRunResult.message}", + schemaDescription = relevantSchema, + maxAttempts = maxRevisionAttempts + ) + + val revisionResult = sqlReviseAgent.execute(revisionInput) { progress -> + onProgress(progress) + } + + if (revisionResult.success) { + sql = revisionResult.content + revisionAttempts++ + + renderer.renderChatDBStep( + stepType = ChatDBStepType.REVISE_SQL, + status = ChatDBStepStatus.SUCCESS, + title = "SQL revised successfully", + details = mapOf("database" to dbName, "sql" to sql) + ) + + // Re-run dry run with revised SQL + val revisedDryRunResult = connection.dryRun(sql) + if (!revisedDryRunResult.isValid) { + renderer.renderChatDBStep( + stepType = ChatDBStepType.DRY_RUN, + status = ChatDBStepStatus.ERROR, + title = "Revised SQL still has errors", + error = revisedDryRunResult.message + ) + errors.add("[$dbName] Revised SQL still invalid: ${revisedDryRunResult.message}") + continue + } + } else { + renderer.renderChatDBStep( + stepType = ChatDBStepType.REVISE_SQL, + status = ChatDBStepStatus.ERROR, + title = "SQL revision failed", + error = revisionResult.content + ) + continue + } + } else { + // Dry run succeeded + val estimatedInfo = if (dryRunResult.estimatedRows != null) { + " (estimated ${dryRunResult.estimatedRows} row(s) affected)" + } else "" + + renderer.renderChatDBStep( + stepType = ChatDBStepType.DRY_RUN, + status = ChatDBStepStatus.SUCCESS, + title = "Dry run passed$estimatedInfo", + details = mapOf( + "database" to dbName, + "operationType" to operationType.name, + "sql" to sql, + "estimatedRows" to (dryRunResult.estimatedRows ?: "unknown"), + "warnings" to dryRunResult.warnings + ) + ) + onProgress("โœ… Dry run validation passed$estimatedInfo") + } + + // Now request user approval val approved = requestSqlApproval( sql = sql, operationType = operationType, affectedTables = affectedTables, isHighRisk = isHighRisk, + dryRunResult = dryRunResult, onProgress = onProgress ) @@ -923,10 +1031,14 @@ Generate the SQL: operationType: SqlOperationType, affectedTables: List, isHighRisk: Boolean, + dryRunResult: DryRunResult? = null, onProgress: (String) -> Unit ): Boolean { val riskLevel = if (isHighRisk) "โš ๏ธ HIGH RISK" else "โšก Write Operation" - onProgress("$riskLevel: ${operationType.name} requires approval") + val dryRunInfo = if (dryRunResult?.estimatedRows != null) { + " (dry run: ${dryRunResult.estimatedRows} row(s) would be affected)" + } else "" + onProgress("$riskLevel: ${operationType.name} requires approval$dryRunInfo") return suspendCancellableCoroutine { continuation -> renderer.renderSqlApprovalRequest( @@ -934,6 +1046,7 @@ Generate the SQL: operationType = operationType, affectedTables = affectedTables, isHighRisk = isHighRisk, + dryRunResult = dryRunResult, onApprove = { if (continuation.isActive) { continuation.resume(true) diff --git a/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/database/DatabaseConnection.kt b/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/database/DatabaseConnection.kt index e8eae3f8b4..862d1c50cd 100644 --- a/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/database/DatabaseConnection.kt +++ b/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/database/DatabaseConnection.kt @@ -26,6 +26,26 @@ interface DatabaseConnection { */ suspend fun executeUpdate(sql: String): UpdateResult + /** + * Dry run a SQL statement to validate it without actually executing it. + * This is useful for validating write operations before user approval. + * + * The default implementation uses database-specific techniques: + * - For SELECT: Uses EXPLAIN + * - For INSERT/UPDATE/DELETE: Wraps in a transaction and rolls back + * - For DDL: Returns validation based on syntax only + * + * @param sql SQL statement to validate + * @return DryRunResult containing validation status and any errors + */ + suspend fun dryRun(sql: String): DryRunResult { + // Default implementation - subclasses can override for better support + return DryRunResult( + isValid = true, + message = "Dry run not fully supported, syntax validation only" + ) + } + /** * Get database schema information * @@ -111,6 +131,55 @@ interface DatabaseConnection { } } +/** + * Dry run result - validation result without actual execution + */ +@Serializable +data class DryRunResult( + /** + * Whether the SQL is valid and can be executed + */ + val isValid: Boolean, + + /** + * Validation message or error description + */ + val message: String? = null, + + /** + * Detailed errors if validation failed + */ + val errors: List = emptyList(), + + /** + * Estimated affected rows (if available from EXPLAIN) + */ + val estimatedRows: Int? = null, + + /** + * Warnings (non-fatal issues) + */ + val warnings: List = emptyList() +) { + companion object { + fun valid(message: String? = null, estimatedRows: Int? = null): DryRunResult { + return DryRunResult(true, message, emptyList(), estimatedRows) + } + + fun invalid(error: String): DryRunResult { + return DryRunResult(false, error, listOf(error)) + } + + fun invalid(errors: List): DryRunResult { + return DryRunResult(false, errors.firstOrNull(), errors) + } + + fun withWarnings(warnings: List): DryRunResult { + return DryRunResult(true, "Valid with warnings", emptyList(), null, warnings) + } + } +} + /** * Database update result (for INSERT, UPDATE, DELETE, DDL statements) */ diff --git a/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/render/CodingAgentRenderer.kt b/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/render/CodingAgentRenderer.kt index 0132b206d5..ee7a691d5b 100644 --- a/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/render/CodingAgentRenderer.kt +++ b/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/render/CodingAgentRenderer.kt @@ -1,5 +1,6 @@ package cc.unitmesh.agent.render +import cc.unitmesh.agent.database.DryRunResult import cc.unitmesh.agent.plan.PlanSummaryData import cc.unitmesh.agent.subagent.SqlOperationType import cc.unitmesh.agent.tool.ToolResult @@ -123,6 +124,7 @@ interface CodingAgentRenderer { * @param operationType The type of SQL operation (INSERT, UPDATE, DELETE, CREATE, etc.) * @param affectedTables List of tables that will be affected * @param isHighRisk Whether this is a high-risk operation (DROP, TRUNCATE) + * @param dryRunResult Optional result from dry run validation (if available) * @param onApprove Callback to invoke when user approves the operation * @param onReject Callback to invoke when user rejects the operation */ @@ -131,6 +133,7 @@ interface CodingAgentRenderer { operationType: SqlOperationType, affectedTables: List, isHighRisk: Boolean, + dryRunResult: DryRunResult? = null, onApprove: () -> Unit, onReject: () -> Unit ) { diff --git a/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/render/DefaultCodingAgentRenderer.kt b/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/render/DefaultCodingAgentRenderer.kt index 9cd6a6c5fc..acbe9f82f5 100644 --- a/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/render/DefaultCodingAgentRenderer.kt +++ b/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/render/DefaultCodingAgentRenderer.kt @@ -124,6 +124,7 @@ class DefaultCodingAgentRenderer : BaseRenderer() { operationType: SqlOperationType, affectedTables: List, isHighRisk: Boolean, + dryRunResult: cc.unitmesh.agent.database.DryRunResult?, onApprove: () -> Unit, onReject: () -> Unit ) { @@ -135,6 +136,16 @@ class DefaultCodingAgentRenderer : BaseRenderer() { if (isHighRisk) { println("WARNING: This is a HIGH-RISK operation!") } + if (dryRunResult != null) { + println("\nDry Run Result:") + println(" Valid: ${dryRunResult.isValid}") + if (dryRunResult.estimatedRows != null) { + println(" Estimated Rows Affected: ${dryRunResult.estimatedRows}") + } + if (dryRunResult.warnings.isNotEmpty()) { + println(" Warnings: ${dryRunResult.warnings.joinToString(", ")}") + } + } println("\nSQL:") println(sql) println("-".repeat(50)) diff --git a/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/render/RendererModels.kt b/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/render/RendererModels.kt index 4e923b2767..3669f0d6d8 100644 --- a/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/render/RendererModels.kt +++ b/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/render/RendererModels.kt @@ -251,6 +251,8 @@ enum class ChatDBStepType(val displayName: String, val icon: String) { GENERATE_SQL("Generate SQL Query", "๐Ÿค–"), VALIDATE_SQL("Validate SQL", "โœ“"), REVISE_SQL("Revise SQL", "๐Ÿ”„"), + /** Dry run to validate SQL without executing (uses transaction rollback) */ + DRY_RUN("Dry Run Validation", "๐Ÿงช"), /** Waiting for user approval before executing write operation */ AWAIT_APPROVAL("Awaiting Approval", "?"), EXECUTE_SQL("Execute SQL Query", "โšก"), diff --git a/mpp-core/src/jsMain/kotlin/cc/unitmesh/agent/RendererExports.kt b/mpp-core/src/jsMain/kotlin/cc/unitmesh/agent/RendererExports.kt index 8455963fea..aa7865ac57 100644 --- a/mpp-core/src/jsMain/kotlin/cc/unitmesh/agent/RendererExports.kt +++ b/mpp-core/src/jsMain/kotlin/cc/unitmesh/agent/RendererExports.kt @@ -1,5 +1,6 @@ package cc.unitmesh.agent +import cc.unitmesh.agent.database.DryRunResult import cc.unitmesh.agent.plan.PlanSummaryData import cc.unitmesh.agent.plan.StepSummary import cc.unitmesh.agent.plan.TaskSummary @@ -202,11 +203,13 @@ class JsRendererAdapter(private val jsRenderer: JsCodingAgentRenderer) : CodingA operationType: SqlOperationType, affectedTables: List, isHighRisk: Boolean, + dryRunResult: DryRunResult?, onApprove: () -> Unit, onReject: () -> Unit ) { // JS renderer auto-rejects for safety - jsRenderer.renderError("SQL write operation requires approval: ${operationType.name} on ${affectedTables.joinToString(", ")} (Auto-rejected)") + val dryRunInfo = if (dryRunResult != null) " (dry run: ${if (dryRunResult.isValid) "passed" else "failed"})" else "" + jsRenderer.renderError("SQL write operation requires approval: ${operationType.name} on ${affectedTables.joinToString(", ")}$dryRunInfo (Auto-rejected)") onReject() } diff --git a/mpp-core/src/jvmMain/kotlin/cc/unitmesh/agent/database/ExposedDatabaseConnection.kt b/mpp-core/src/jvmMain/kotlin/cc/unitmesh/agent/database/ExposedDatabaseConnection.kt index 9885b47905..3b32dfe29c 100644 --- a/mpp-core/src/jvmMain/kotlin/cc/unitmesh/agent/database/ExposedDatabaseConnection.kt +++ b/mpp-core/src/jvmMain/kotlin/cc/unitmesh/agent/database/ExposedDatabaseConnection.kt @@ -81,6 +81,105 @@ class ExposedDatabaseConnection( } } + /** + * Dry run SQL to validate without executing. + * Uses transaction rollback to test DML statements safely. + */ + override suspend fun dryRun(sql: String): DryRunResult = withContext(Dispatchers.IO) { + try { + hikariDataSource.connection.use { connection -> + val originalAutoCommit = connection.autoCommit + try { + // Disable auto-commit to use transaction + connection.autoCommit = false + + val sqlUpper = sql.trim().uppercase() + val warnings = mutableListOf() + + when { + // For SELECT, use EXPLAIN to validate + sqlUpper.startsWith("SELECT") -> { + val explainSql = "EXPLAIN $sql" + val stmt = connection.prepareStatement(explainSql) + try { + stmt.executeQuery() + stmt.close() + DryRunResult.valid("SELECT query is valid") + } catch (e: Exception) { + DryRunResult.invalid("Query validation failed: ${e.message}") + } + } + + // For INSERT/UPDATE/DELETE, execute in transaction and rollback + sqlUpper.startsWith("INSERT") || + sqlUpper.startsWith("UPDATE") || + sqlUpper.startsWith("DELETE") -> { + val stmt = connection.prepareStatement(sql) + try { + val affectedRows = stmt.executeUpdate() + stmt.close() + // Rollback to undo the changes + connection.rollback() + DryRunResult.valid( + "Statement is valid (would affect $affectedRows row(s))", + estimatedRows = affectedRows + ) + } catch (e: Exception) { + connection.rollback() + DryRunResult.invalid("Statement validation failed: ${e.message}") + } + } + + // For DDL (CREATE, ALTER, DROP), we can't safely dry run + // Just validate syntax using EXPLAIN if possible + sqlUpper.startsWith("CREATE") || + sqlUpper.startsWith("ALTER") || + sqlUpper.startsWith("DROP") || + sqlUpper.startsWith("TRUNCATE") -> { + // For DDL, we can try to parse it but can't execute safely + // Return a warning that DDL cannot be fully validated + warnings.add("DDL statements cannot be fully validated without execution") + + // Try basic syntax check by preparing the statement + try { + val stmt = connection.prepareStatement(sql) + stmt.close() + DryRunResult( + isValid = true, + message = "DDL syntax appears valid (cannot fully validate without execution)", + warnings = warnings + ) + } catch (e: Exception) { + DryRunResult.invalid("DDL syntax error: ${e.message}") + } + } + + else -> { + // Unknown statement type, try to prepare it + try { + val stmt = connection.prepareStatement(sql) + stmt.close() + DryRunResult.valid("Statement syntax is valid") + } catch (e: Exception) { + DryRunResult.invalid("Statement validation failed: ${e.message}") + } + } + } + } finally { + // Restore original auto-commit setting + try { + connection.rollback() // Ensure any pending changes are rolled back + connection.autoCommit = originalAutoCommit + } catch (e: Exception) { + // Ignore cleanup errors + } + } + } + } catch (e: Exception) { + DryRunResult.invalid("Dry run failed: ${e.message}") + } + } + override suspend fun getSchema(): DatabaseSchema = withContext(Dispatchers.IO) { try { hikariDataSource.connection.use { connection -> diff --git a/mpp-idea/mpp-idea-exts/ext-database/src/main/kotlin/cc/unitmesh/database/connection/IdeaDatabaseConnection.kt b/mpp-idea/mpp-idea-exts/ext-database/src/main/kotlin/cc/unitmesh/database/connection/IdeaDatabaseConnection.kt index 0247f69e41..34ecbabacc 100644 --- a/mpp-idea/mpp-idea-exts/ext-database/src/main/kotlin/cc/unitmesh/database/connection/IdeaDatabaseConnection.kt +++ b/mpp-idea/mpp-idea-exts/ext-database/src/main/kotlin/cc/unitmesh/database/connection/IdeaDatabaseConnection.kt @@ -79,6 +79,100 @@ class IdeaDatabaseConnection( } } + /** + * Dry run SQL to validate without executing. + * Uses transaction rollback to test DML statements safely. + */ + override suspend fun dryRun(sql: String): DryRunResult = withContext(Dispatchers.IO) { + try { + val originalAutoCommit = ideaConnection.autoCommit + try { + // Disable auto-commit to use transaction + ideaConnection.autoCommit = false + + val sqlUpper = sql.trim().uppercase() + val warnings = mutableListOf() + + when { + // For SELECT, use EXPLAIN to validate + sqlUpper.startsWith("SELECT") -> { + val explainSql = "EXPLAIN $sql" + val stmt = ideaConnection.prepareStatement(explainSql) + try { + stmt.executeQuery() + stmt.close() + DryRunResult.valid("SELECT query is valid") + } catch (e: Exception) { + DryRunResult.invalid("Query validation failed: ${e.message}") + } + } + + // For INSERT/UPDATE/DELETE, execute in transaction and rollback + sqlUpper.startsWith("INSERT") || + sqlUpper.startsWith("UPDATE") || + sqlUpper.startsWith("DELETE") -> { + val stmt = ideaConnection.prepareStatement(sql) + try { + val affectedRows = stmt.executeUpdate() + stmt.close() + // Rollback to undo the changes + ideaConnection.rollback() + DryRunResult.valid( + "Statement is valid (would affect $affectedRows row(s))", + estimatedRows = affectedRows + ) + } catch (e: Exception) { + ideaConnection.rollback() + DryRunResult.invalid("Statement validation failed: ${e.message}") + } + } + + // For DDL (CREATE, ALTER, DROP), we can't safely dry run + sqlUpper.startsWith("CREATE") || + sqlUpper.startsWith("ALTER") || + sqlUpper.startsWith("DROP") || + sqlUpper.startsWith("TRUNCATE") -> { + warnings.add("DDL statements cannot be fully validated without execution") + + // Try basic syntax check by preparing the statement + try { + val stmt = ideaConnection.prepareStatement(sql) + stmt.close() + DryRunResult( + isValid = true, + message = "DDL syntax appears valid (cannot fully validate without execution)", + warnings = warnings + ) + } catch (e: Exception) { + DryRunResult.invalid("DDL syntax error: ${e.message}") + } + } + + else -> { + // Unknown statement type, try to prepare it + try { + val stmt = ideaConnection.prepareStatement(sql) + stmt.close() + DryRunResult.valid("Statement syntax is valid") + } catch (e: Exception) { + DryRunResult.invalid("Statement validation failed: ${e.message}") + } + } + } + } finally { + // Restore original auto-commit setting + try { + ideaConnection.rollback() // Ensure any pending changes are rolled back + ideaConnection.autoCommit = originalAutoCommit + } catch (e: Exception) { + // Ignore cleanup errors + } + } + } catch (e: Exception) { + DryRunResult.invalid("Dry run failed: ${e.message}") + } + } + override suspend fun getSchema(): DatabaseSchema = withContext(Dispatchers.IO) { try { val metadata = ideaConnection.metaData diff --git a/mpp-idea/src/main/kotlin/cc/unitmesh/devins/idea/components/timeline/IdeaTimelineContent.kt b/mpp-idea/src/main/kotlin/cc/unitmesh/devins/idea/components/timeline/IdeaTimelineContent.kt index b3d7a44042..8da1d95a22 100644 --- a/mpp-idea/src/main/kotlin/cc/unitmesh/devins/idea/components/timeline/IdeaTimelineContent.kt +++ b/mpp-idea/src/main/kotlin/cc/unitmesh/devins/idea/components/timeline/IdeaTimelineContent.kt @@ -97,6 +97,17 @@ fun IdeaTimelineItemView( // Agent-generated sketch block (chart, nanodsl, mermaid, etc.) IdeaAgentSketchBlockBubble(item, project = project) } + is TimelineItem.ChatDBStepItem -> { + // ChatDB execution step - display as info bubble + IdeaInfoBubble( + message = "${item.stepType.icon} ${item.stepType.displayName}: ${item.title}", + status = item.status.displayName + ) + } + is TimelineItem.InfoItem -> { + // Info message bubble + IdeaInfoBubble(message = item.message) + } } } @@ -255,3 +266,45 @@ fun IdeaEmptyStateMessage(text: String) { } } +/** + * Info bubble for displaying informational messages. + */ +@Composable +fun IdeaInfoBubble( + message: String, + status: String? = null +) { + Box( + modifier = Modifier + .fillMaxWidth() + .padding(vertical = 2.dp) + .background( + JewelTheme.globalColors.panelBackground.copy(alpha = 0.5f), + shape = RoundedCornerShape(6.dp) + ) + .padding(8.dp) + ) { + Row( + horizontalArrangement = Arrangement.spacedBy(8.dp), + verticalAlignment = Alignment.CenterVertically + ) { + Text( + text = message, + style = JewelTheme.defaultTextStyle.copy( + fontSize = 12.sp, + color = JewelTheme.globalColors.text.info + ) + ) + if (status != null) { + Text( + text = "[$status]", + style = JewelTheme.defaultTextStyle.copy( + fontSize = 10.sp, + color = JewelTheme.globalColors.text.info.copy(alpha = 0.6f) + ) + ) + } + } + } +} + diff --git a/mpp-idea/src/main/kotlin/cc/unitmesh/devins/idea/renderer/JewelRenderer.kt b/mpp-idea/src/main/kotlin/cc/unitmesh/devins/idea/renderer/JewelRenderer.kt index ed775c0374..fe70e21922 100644 --- a/mpp-idea/src/main/kotlin/cc/unitmesh/devins/idea/renderer/JewelRenderer.kt +++ b/mpp-idea/src/main/kotlin/cc/unitmesh/devins/idea/renderer/JewelRenderer.kt @@ -1,5 +1,6 @@ package cc.unitmesh.devins.idea.renderer +import cc.unitmesh.agent.database.DryRunResult import cc.unitmesh.agent.plan.AgentPlan import cc.unitmesh.agent.plan.MarkdownPlanParser import cc.unitmesh.agent.render.BaseRenderer @@ -548,6 +549,7 @@ class JewelRenderer : BaseRenderer() { operationType: SqlOperationType, affectedTables: List, isHighRisk: Boolean, + dryRunResult: DryRunResult?, onApprove: () -> Unit, onReject: () -> Unit ) { @@ -556,6 +558,7 @@ class JewelRenderer : BaseRenderer() { operationType = operationType, affectedTables = affectedTables, isHighRisk = isHighRisk, + dryRunResult = dryRunResult, onApprove = { _pendingSqlApproval.value = null onApprove() @@ -566,17 +569,29 @@ class JewelRenderer : BaseRenderer() { } ) + // Build details map with dry run info + val details = mutableMapOf( + "sql" to sql, + "operationType" to operationType.name, + "affectedTables" to affectedTables.joinToString(", "), + "isHighRisk" to isHighRisk + ) + if (dryRunResult != null) { + details["dryRunValid"] = dryRunResult.isValid + if (dryRunResult.estimatedRows != null) { + details["estimatedRows"] = dryRunResult.estimatedRows!! + } + if (dryRunResult.warnings.isNotEmpty()) { + details["warnings"] = dryRunResult.warnings.joinToString(", ") + } + } + // Also add to timeline for visibility renderChatDBStep( stepType = ChatDBStepType.AWAIT_APPROVAL, status = ChatDBStepStatus.AWAITING_APPROVAL, title = "Awaiting Approval: ${operationType.name}", - details = mapOf( - "sql" to sql, - "operationType" to operationType.name, - "affectedTables" to affectedTables.joinToString(", "), - "isHighRisk" to isHighRisk - ) + details = details ) } @@ -850,6 +865,7 @@ data class SqlApprovalRequest( val operationType: SqlOperationType, val affectedTables: List, val isHighRisk: Boolean, + val dryRunResult: DryRunResult? = null, val onApprove: () -> Unit, val onReject: () -> Unit ) diff --git a/mpp-idea/src/main/kotlin/cc/unitmesh/devins/idea/toolwindow/knowledge/IdeaKnowledgeContent.kt b/mpp-idea/src/main/kotlin/cc/unitmesh/devins/idea/toolwindow/knowledge/IdeaKnowledgeContent.kt index 8f5fa30dc7..f280a8cb1a 100644 --- a/mpp-idea/src/main/kotlin/cc/unitmesh/devins/idea/toolwindow/knowledge/IdeaKnowledgeContent.kt +++ b/mpp-idea/src/main/kotlin/cc/unitmesh/devins/idea/toolwindow/knowledge/IdeaKnowledgeContent.kt @@ -838,6 +838,65 @@ private fun ChatMessageItem(item: TimelineItem) { } } } + + is TimelineItem.ChatDBStepItem -> { + // ChatDB execution step + Box( + modifier = Modifier + .fillMaxWidth() + .background(JewelTheme.globalColors.panelBackground.copy(alpha = 0.5f)) + .padding(8.dp) + ) { + Row( + horizontalArrangement = Arrangement.spacedBy(6.dp), + verticalAlignment = Alignment.CenterVertically + ) { + Text( + text = item.stepType.icon, + style = JewelTheme.defaultTextStyle.copy(fontSize = 14.sp) + ) + Column { + Text( + text = item.stepType.displayName, + style = JewelTheme.defaultTextStyle.copy( + fontWeight = FontWeight.Bold, + fontSize = 12.sp + ) + ) + Text( + text = item.title, + style = JewelTheme.defaultTextStyle.copy(fontSize = 11.sp) + ) + } + } + } + } + + is TimelineItem.InfoItem -> { + // Info message + Box( + modifier = Modifier + .fillMaxWidth() + .background(IdeaAutoDevColors.Blue.c400.copy(alpha = 0.1f)) + .padding(8.dp) + ) { + Row( + horizontalArrangement = Arrangement.spacedBy(6.dp), + verticalAlignment = Alignment.CenterVertically + ) { + Icon( + imageVector = IdeaComposeIcons.Info, + contentDescription = "Info", + modifier = Modifier.size(16.dp), + tint = IdeaAutoDevColors.Blue.c400 + ) + Text( + text = item.message, + style = JewelTheme.defaultTextStyle.copy(fontSize = 12.sp) + ) + } + } + } } } diff --git a/mpp-server/src/main/kotlin/cc/unitmesh/server/render/ServerSideRenderer.kt b/mpp-server/src/main/kotlin/cc/unitmesh/server/render/ServerSideRenderer.kt index f2f41ffc83..07072df07e 100644 --- a/mpp-server/src/main/kotlin/cc/unitmesh/server/render/ServerSideRenderer.kt +++ b/mpp-server/src/main/kotlin/cc/unitmesh/server/render/ServerSideRenderer.kt @@ -72,12 +72,14 @@ class ServerSideRenderer : CodingAgentRenderer { operationType: cc.unitmesh.agent.subagent.SqlOperationType, affectedTables: List, isHighRisk: Boolean, + dryRunResult: cc.unitmesh.agent.database.DryRunResult?, onApprove: () -> Unit, onReject: () -> Unit ) { // Server-side renderer auto-rejects for safety // In a real implementation, this would send an event to the client for approval - eventChannel.trySend(AgentEvent.Error("SQL write operation requires approval: ${operationType.name} on ${affectedTables.joinToString(", ")}")) + val dryRunInfo = if (dryRunResult != null) " (dry run: ${if (dryRunResult.isValid) "passed" else "failed"})" else "" + eventChannel.trySend(AgentEvent.Error("SQL write operation requires approval: ${operationType.name} on ${affectedTables.joinToString(", ")}$dryRunInfo")) onReject() } diff --git a/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/agent/AgentMessageList.kt b/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/agent/AgentMessageList.kt index 7afbccf3c0..a28503bfe7 100644 --- a/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/agent/AgentMessageList.kt +++ b/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/agent/AgentMessageList.kt @@ -239,7 +239,9 @@ fun RenderMessageItem( is TimelineItem.ChatDBStepItem -> { cc.unitmesh.devins.ui.compose.agent.chatdb.components.ChatDBStepCard( - step = timelineItem + step = timelineItem, + onApprove = { renderer.approveSqlOperation() }, + onReject = { renderer.rejectSqlOperation() } ) } diff --git a/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/agent/ComposeRenderer.kt b/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/agent/ComposeRenderer.kt index 05ddb7f44c..1ebae22c18 100644 --- a/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/agent/ComposeRenderer.kt +++ b/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/agent/ComposeRenderer.kt @@ -1,6 +1,7 @@ package cc.unitmesh.devins.ui.compose.agent import androidx.compose.runtime.* +import cc.unitmesh.agent.database.DryRunResult import cc.unitmesh.agent.plan.AgentPlan import cc.unitmesh.agent.plan.MarkdownPlanParser import cc.unitmesh.agent.render.BaseRenderer @@ -575,6 +576,7 @@ class ComposeRenderer : BaseRenderer() { operationType: SqlOperationType, affectedTables: List, isHighRisk: Boolean, + dryRunResult: DryRunResult?, onApprove: () -> Unit, onReject: () -> Unit ) { @@ -583,6 +585,7 @@ class ComposeRenderer : BaseRenderer() { operationType = operationType, affectedTables = affectedTables, isHighRisk = isHighRisk, + dryRunResult = dryRunResult, onApprove = { _pendingSqlApproval = null onApprove() @@ -593,17 +596,29 @@ class ComposeRenderer : BaseRenderer() { } ) + // Build details map with dry run info + val details = mutableMapOf( + "sql" to sql, + "operationType" to operationType.name, + "affectedTables" to affectedTables.joinToString(", "), + "isHighRisk" to isHighRisk + ) + if (dryRunResult != null) { + details["dryRunValid"] = dryRunResult.isValid + if (dryRunResult.estimatedRows != null) { + details["estimatedRows"] = dryRunResult.estimatedRows!! + } + if (dryRunResult.warnings.isNotEmpty()) { + details["warnings"] = dryRunResult.warnings.joinToString(", ") + } + } + // Also add to timeline for visibility renderChatDBStep( stepType = ChatDBStepType.AWAIT_APPROVAL, status = ChatDBStepStatus.AWAITING_APPROVAL, title = "Awaiting Approval: ${operationType.name}", - details = mapOf( - "sql" to sql, - "operationType" to operationType.name, - "affectedTables" to affectedTables.joinToString(", "), - "isHighRisk" to isHighRisk - ) + details = details ) } @@ -1237,6 +1252,7 @@ data class SqlApprovalRequest( val operationType: SqlOperationType, val affectedTables: List, val isHighRisk: Boolean, + val dryRunResult: DryRunResult? = null, val onApprove: () -> Unit, val onReject: () -> Unit ) diff --git a/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/agent/chatdb/components/ChatDBStepCard.kt b/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/agent/chatdb/components/ChatDBStepCard.kt index 4f89de3394..791b1454af 100644 --- a/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/agent/chatdb/components/ChatDBStepCard.kt +++ b/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/agent/chatdb/components/ChatDBStepCard.kt @@ -5,6 +5,7 @@ import androidx.compose.animation.expandVertically import androidx.compose.animation.fadeIn import androidx.compose.animation.fadeOut import androidx.compose.animation.shrinkVertically +import androidx.compose.foundation.BorderStroke import androidx.compose.foundation.background import androidx.compose.foundation.border import androidx.compose.foundation.clickable @@ -36,7 +37,9 @@ import cc.unitmesh.agent.render.TimelineItem @Composable fun ChatDBStepCard( step: TimelineItem.ChatDBStepItem, - modifier: Modifier = Modifier + modifier: Modifier = Modifier, + onApprove: (() -> Unit)? = null, + onReject: (() -> Unit)? = null ) { var isExpanded by remember { mutableStateOf(step.status == ChatDBStepStatus.ERROR) } @@ -120,7 +123,13 @@ fun ChatDBStepCard( // Details if (step.details.isNotEmpty()) { - StepDetails(details = step.details, stepType = step.stepType) + StepDetails( + details = step.details, + stepType = step.stepType, + step = step, + onApprove = onApprove, + onReject = onReject + ) } } } @@ -156,7 +165,13 @@ private fun StepStatusBadge(status: ChatDBStepStatus) { } @Composable -private fun StepDetails(details: Map, stepType: ChatDBStepType) { +private fun StepDetails( + details: Map, + stepType: ChatDBStepType, + step: TimelineItem.ChatDBStepItem, + onApprove: (() -> Unit)?, + onReject: (() -> Unit)? +) { Column( modifier = Modifier .fillMaxWidth() @@ -501,6 +516,113 @@ private fun StepDetails(details: Map, stepType: ChatDBStepType) { } } + ChatDBStepType.DRY_RUN -> { + // Show SQL being validated + details["sql"]?.let { sql -> + CodeBlock(code = sql.toString(), language = "sql") + } + + // Show dry run result + val isValid = details["isValid"] as? Boolean + val message = details["message"]?.toString() + @Suppress("UNCHECKED_CAST") + val errors = details["errors"] as? List + @Suppress("UNCHECKED_CAST") + val warnings = details["warnings"] as? List + val estimatedRows = details["estimatedRows"] + + // Validation status + if (isValid != null) { + Spacer(modifier = Modifier.height(8.dp)) + Surface( + shape = RoundedCornerShape(4.dp), + color = if (isValid) { + MaterialTheme.colorScheme.primaryContainer.copy(alpha = 0.5f) + } else { + MaterialTheme.colorScheme.errorContainer.copy(alpha = 0.5f) + } + ) { + Row( + modifier = Modifier.padding(8.dp), + verticalAlignment = Alignment.CenterVertically + ) { + Text( + text = if (isValid) "โœ“" else "โœ—", + fontSize = 16.sp, + color = if (isValid) { + MaterialTheme.colorScheme.primary + } else { + MaterialTheme.colorScheme.error + } + ) + Spacer(modifier = Modifier.width(8.dp)) + Text( + text = if (isValid) "Validation Passed" else "Validation Failed", + style = MaterialTheme.typography.bodySmall, + fontWeight = FontWeight.Bold, + color = if (isValid) { + MaterialTheme.colorScheme.onPrimaryContainer + } else { + MaterialTheme.colorScheme.onErrorContainer + } + ) + } + } + } + + // Show message + if (!message.isNullOrEmpty()) { + Spacer(modifier = Modifier.height(4.dp)) + Text( + text = message, + style = MaterialTheme.typography.bodySmall, + color = MaterialTheme.colorScheme.onSurfaceVariant + ) + } + + // Show errors + if (!errors.isNullOrEmpty()) { + Spacer(modifier = Modifier.height(8.dp)) + Text( + text = "Errors:", + style = MaterialTheme.typography.labelMedium, + fontWeight = FontWeight.Bold, + color = MaterialTheme.colorScheme.error + ) + errors.forEach { error -> + Text( + text = "โ€ข $error", + style = MaterialTheme.typography.bodySmall, + color = MaterialTheme.colorScheme.error + ) + } + } + + // Show warnings + if (!warnings.isNullOrEmpty()) { + Spacer(modifier = Modifier.height(8.dp)) + Text( + text = "Warnings:", + style = MaterialTheme.typography.labelMedium, + fontWeight = FontWeight.Bold, + color = MaterialTheme.colorScheme.tertiary + ) + warnings.forEach { warning -> + Text( + text = "โ€ข $warning", + style = MaterialTheme.typography.bodySmall, + color = MaterialTheme.colorScheme.tertiary + ) + } + } + + // Show estimated rows + if (estimatedRows != null) { + Spacer(modifier = Modifier.height(4.dp)) + DetailRow("Estimated Rows", estimatedRows.toString()) + } + } + ChatDBStepType.AWAIT_APPROVAL -> { // Show SQL that requires approval details["sql"]?.let { sql -> @@ -554,6 +676,42 @@ private fun StepDetails(details: Map, stepType: ChatDBStepType) { } } } + + // Show Approve/Reject buttons only when status is AWAITING_APPROVAL + if (step.status == ChatDBStepStatus.AWAITING_APPROVAL && (onApprove != null || onReject != null)) { + Spacer(modifier = Modifier.height(12.dp)) + Row( + modifier = Modifier.fillMaxWidth(), + horizontalArrangement = Arrangement.End, + verticalAlignment = Alignment.CenterVertically + ) { + // Reject button + if (onReject != null) { + OutlinedButton( + onClick = onReject, + colors = ButtonDefaults.outlinedButtonColors( + contentColor = MaterialTheme.colorScheme.error + ), + border = BorderStroke(1.dp, MaterialTheme.colorScheme.error) + ) { + Text("Reject") + } + Spacer(modifier = Modifier.width(8.dp)) + } + + // Approve button + if (onApprove != null) { + Button( + onClick = onApprove, + colors = ButtonDefaults.buttonColors( + containerColor = MaterialTheme.colorScheme.primary + ) + ) { + Text("Approve") + } + } + } + } } ChatDBStepType.EXECUTE_WRITE -> { From 55a571378b7b3c54fc4e0de1a30a1dd0e091d3eb Mon Sep 17 00:00:00 2001 From: Phodal Huang Date: Wed, 10 Dec 2025 20:45:16 +0800 Subject: [PATCH 34/34] feat(chatdb): render success for revised dry run validation #508 Render a success step when revised SQL passes dry run validation and update approval to use the latest dry run result. --- .../chatdb/MultiDatabaseChatDBExecutor.kt | 24 +++++++++++++++++-- 1 file changed, 22 insertions(+), 2 deletions(-) diff --git a/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/chatdb/MultiDatabaseChatDBExecutor.kt b/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/chatdb/MultiDatabaseChatDBExecutor.kt index 821bd62b47..f24aecbcbc 100644 --- a/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/chatdb/MultiDatabaseChatDBExecutor.kt +++ b/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/chatdb/MultiDatabaseChatDBExecutor.kt @@ -352,7 +352,7 @@ class MultiDatabaseChatDBExecutor( ) onProgress("๐Ÿ” Performing dry run validation...") - val dryRunResult = connection.dryRun(sql) + var dryRunResult = connection.dryRun(sql) if (!dryRunResult.isValid) { // Dry run failed - SQL has errors @@ -415,6 +415,26 @@ class MultiDatabaseChatDBExecutor( errors.add("[$dbName] Revised SQL still invalid: ${revisedDryRunResult.message}") continue } + // Update dryRunResult with the successful revised result + dryRunResult = revisedDryRunResult + + // Render success for revised dry run + val estimatedInfo = if (dryRunResult.estimatedRows != null) { + " (estimated ${dryRunResult.estimatedRows} row(s) affected)" + } else "" + renderer.renderChatDBStep( + stepType = ChatDBStepType.DRY_RUN, + status = ChatDBStepStatus.SUCCESS, + title = "Revised SQL passed validation$estimatedInfo", + details = mapOf( + "database" to dbName, + "operationType" to operationType.name, + "sql" to sql, + "estimatedRows" to (dryRunResult.estimatedRows ?: "unknown"), + "warnings" to dryRunResult.warnings + ) + ) + onProgress("โœ… Revised SQL validation passed$estimatedInfo") } else { renderer.renderChatDBStep( stepType = ChatDBStepType.REVISE_SQL, @@ -445,7 +465,7 @@ class MultiDatabaseChatDBExecutor( onProgress("โœ… Dry run validation passed$estimatedInfo") } - // Now request user approval + // Now request user approval (with the latest dryRunResult) val approved = requestSqlApproval( sql = sql, operationType = operationType,