diff --git a/mpp-core/build.gradle.kts b/mpp-core/build.gradle.kts index 5369891311..547a344946 100644 --- a/mpp-core/build.gradle.kts +++ b/mpp-core/build.gradle.kts @@ -212,6 +212,9 @@ kotlin { // JSQLParser for SQL validation and parsing implementation("com.github.jsqlparser:jsqlparser:4.9") + + // MyNLP for Chinese NLP tokenization + implementation("com.mayabot.mynlp:mynlp-all:4.0.0") } } diff --git a/mpp-core/src/androidMain/kotlin/cc/unitmesh/agent/chatdb/NlpTokenizer.android.kt b/mpp-core/src/androidMain/kotlin/cc/unitmesh/agent/chatdb/NlpTokenizer.android.kt new file mode 100644 index 0000000000..b6b0b0e2a2 --- /dev/null +++ b/mpp-core/src/androidMain/kotlin/cc/unitmesh/agent/chatdb/NlpTokenizer.android.kt @@ -0,0 +1,23 @@ +package cc.unitmesh.agent.chatdb + +/** + * Android implementation of NlpTokenizer. + * Uses the fallback regex-based tokenization since MyNLP is JVM-only + * and may have compatibility issues on Android. + * + * TODO: Consider using Android's BreakIterator or a lightweight NLP library for better tokenization. + */ +actual object NlpTokenizer { + /** + * Extract keywords from natural language query using simple tokenization. + * Supports both English and Chinese text. + * + * @param query The natural language query to tokenize + * @param stopWords Set of words to filter out from results + * @return List of extracted keywords + */ + actual fun extractKeywords(query: String, stopWords: Set): List { + return FallbackNlpTokenizer.extractKeywords(query, stopWords) + } +} + diff --git a/mpp-core/src/androidMain/kotlin/cc/unitmesh/agent/subagent/SqlValidator.android.kt b/mpp-core/src/androidMain/kotlin/cc/unitmesh/agent/subagent/SqlValidator.android.kt new file mode 100644 index 0000000000..f3e1e12170 --- /dev/null +++ b/mpp-core/src/androidMain/kotlin/cc/unitmesh/agent/subagent/SqlValidator.android.kt @@ -0,0 +1,125 @@ +package cc.unitmesh.agent.subagent + +import net.sf.jsqlparser.parser.CCJSqlParserUtil +import net.sf.jsqlparser.statement.Statement +import net.sf.jsqlparser.util.TablesNamesFinder + +/** + * Android implementation of SqlValidator using JSqlParser. + * + * This validator uses JSqlParser to validate SQL syntax. + * It can detect: + * - Syntax errors + * - Malformed SQL statements + * - Unsupported SQL constructs + * - Table names not in whitelist (schema validation) + */ +actual class SqlValidator actual constructor() : SqlValidatorInterface { + + actual override fun validate(sql: String): SqlValidationResult { + return try { + val statement: Statement = CCJSqlParserUtil.parse(sql) + SqlValidationResult( + isValid = true, + errors = emptyList(), + warnings = collectWarnings(statement) + ) + } catch (e: Exception) { + SqlValidationResult( + isValid = false, + errors = listOf(extractErrorMessage(e)), + warnings = emptyList() + ) + } + } + + actual override fun validateWithTableWhitelist(sql: String, allowedTables: Set): SqlValidationResult { + return try { + val statement: Statement = CCJSqlParserUtil.parse(sql) + + // Extract table names from the SQL + val tablesNamesFinder = TablesNamesFinder() + val usedTables = tablesNamesFinder.getTableList(statement) + + // Check if all used tables are in the whitelist (case-insensitive) + val allowedTablesLower = allowedTables.map { it.lowercase() }.toSet() + val invalidTables = usedTables.filter { tableName -> + tableName.lowercase() !in allowedTablesLower + } + + if (invalidTables.isNotEmpty()) { + SqlValidationResult( + isValid = false, + errors = listOf( + "Invalid table(s) used: ${invalidTables.joinToString(", ")}. " + + "Available tables: ${allowedTables.joinToString(", ")}" + ), + warnings = collectWarnings(statement) + ) + } else { + SqlValidationResult( + isValid = true, + errors = emptyList(), + warnings = collectWarnings(statement) + ) + } + } catch (e: Exception) { + SqlValidationResult( + isValid = false, + errors = listOf(extractErrorMessage(e)), + warnings = emptyList() + ) + } + } + + actual override fun extractTableNames(sql: String): List { + return try { + val statement: Statement = CCJSqlParserUtil.parse(sql) + val tablesNamesFinder = TablesNamesFinder() + tablesNamesFinder.getTableList(statement) + } catch (e: Exception) { + emptyList() + } + } + + private fun extractErrorMessage(e: Exception): String { + val message = e.message ?: "Unknown SQL parsing error" + return when { + message.contains("Encountered") -> { + val match = Regex("Encountered \"(.+?)\" at line (\\d+), column (\\d+)").find(message) + if (match != null) { + val (token, line, column) = match.destructured + "Syntax error at line $line, column $column: unexpected token '$token'" + } else { + message + } + } + message.contains("Was expecting") -> { + val match = Regex("Was expecting.*?:\\s*(.+)").find(message) + if (match != null) { + "Expected: ${match.groupValues[1].take(100)}" + } else { + message + } + } + else -> message.take(200) + } + } + + private fun collectWarnings(statement: Statement): List { + val warnings = mutableListOf() + val sql = statement.toString() + + if (sql.contains("SELECT *")) { + warnings.add("Consider specifying explicit columns instead of SELECT *") + } + + if (!sql.contains("WHERE", ignoreCase = true) && + (sql.contains("UPDATE", ignoreCase = true) || sql.contains("DELETE", ignoreCase = true))) { + warnings.add("UPDATE/DELETE without WHERE clause will affect all rows") + } + + return warnings + } +} + diff --git a/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/AgentType.kt b/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/AgentType.kt index 81d5d381bc..11fee4b743 100644 --- a/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/AgentType.kt +++ b/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/AgentType.kt @@ -7,6 +7,8 @@ package cc.unitmesh.agent * - LOCAL: Simple local chat mode without heavy tooling * - CODING: Local coding agent with full tool access (file system, shell, etc.) * - CODE_REVIEW: Dedicated code review agent with git integration + * - KNOWLEDGE: Document reader mode for AI-native document reading + * - CHAT_DB: Database chat mode for text-to-SQL interactions * - REMOTE: Remote agent connected to mpp-server */ enum class AgentType { @@ -30,6 +32,11 @@ enum class AgentType { */ KNOWLEDGE, + /** + * Database chat mode - text-to-SQL agent for database queries + */ + CHAT_DB, + /** * Remote agent mode - connects to remote mpp-server for distributed execution */ @@ -40,6 +47,7 @@ enum class AgentType { CODING -> "Agentic" CODE_REVIEW -> "Review" KNOWLEDGE -> "Knowledge" + CHAT_DB -> "ChatDB" REMOTE -> "Remote" } @@ -51,6 +59,7 @@ enum class AgentType { "coding" -> CODING "codereview" -> CODE_REVIEW "documentreader", "documents" -> KNOWLEDGE + "chatdb", "database" -> CHAT_DB else -> LOCAL_CHAT } } diff --git a/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/chatdb/ChatDBAgent.kt b/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/chatdb/ChatDBAgent.kt new file mode 100644 index 0000000000..8feac9d5db --- /dev/null +++ b/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/chatdb/ChatDBAgent.kt @@ -0,0 +1,172 @@ +package cc.unitmesh.agent.chatdb + +import cc.unitmesh.agent.config.McpToolConfigService +import cc.unitmesh.agent.core.MainAgent +import cc.unitmesh.agent.database.DatabaseConfig +import cc.unitmesh.agent.database.DatabaseConnection +import cc.unitmesh.agent.database.createDatabaseConnection +import cc.unitmesh.agent.logging.getLogger +import cc.unitmesh.agent.model.AgentDefinition +import cc.unitmesh.agent.model.PromptConfig +import cc.unitmesh.agent.model.RunConfig +import cc.unitmesh.agent.orchestrator.ToolOrchestrator +import cc.unitmesh.agent.policy.DefaultPolicyEngine +import cc.unitmesh.agent.render.CodingAgentRenderer +import cc.unitmesh.agent.render.DefaultCodingAgentRenderer +import cc.unitmesh.agent.tool.shell.DefaultShellExecutor +import cc.unitmesh.agent.tool.shell.ShellExecutor +import cc.unitmesh.agent.tool.ToolResult +import cc.unitmesh.agent.tool.filesystem.DefaultToolFileSystem +import cc.unitmesh.agent.tool.filesystem.ToolFileSystem +import cc.unitmesh.agent.tool.registry.ToolRegistry +import cc.unitmesh.llm.KoogLLMService +import cc.unitmesh.llm.ModelConfig + +/** + * ChatDB Agent - Text2SQL Agent for natural language database queries + * + * This agent converts natural language queries to SQL, executes them, + * and optionally generates visualizations of the results. + * + * Features: + * - Schema Linking: Keyword-based search to find relevant tables/columns + * - SQL Generation: LLM generates SQL from natural language + * - Revise Agent: Self-correction loop using JSqlParser for SQL validation + * - Query Execution: Execute validated SQL and return results + * - Visualization: Optional PlotDSL generation for data visualization + * + * Based on GitHub Issue #508: https://github.com/phodal/auto-dev/issues/508 + */ +class ChatDBAgent( + private val projectPath: String, + private val llmService: KoogLLMService, + private val databaseConfig: DatabaseConfig, + override val maxIterations: Int = 10, + private val renderer: CodingAgentRenderer = DefaultCodingAgentRenderer(), + private val fileSystem: ToolFileSystem? = null, + private val shellExecutor: ShellExecutor? = null, + private val mcpToolConfigService: McpToolConfigService, + private val enableLLMStreaming: Boolean = true +) : MainAgent( + AgentDefinition( + name = "ChatDBAgent", + displayName = "ChatDB Agent", + description = "Text2SQL Agent that converts natural language to SQL queries with schema linking and self-correction", + promptConfig = PromptConfig( + systemPrompt = SYSTEM_PROMPT + ), + modelConfig = ModelConfig.default(), + runConfig = RunConfig(maxTurns = 10, maxTimeMinutes = 5) + ) +) { + private val logger = getLogger("ChatDBAgent") + + private val actualFileSystem = fileSystem ?: DefaultToolFileSystem(projectPath = projectPath) + + private val toolRegistry = ToolRegistry( + fileSystem = actualFileSystem, + shellExecutor = shellExecutor ?: DefaultShellExecutor(), + configService = mcpToolConfigService, + llmService = llmService + ) + + private val policyEngine = DefaultPolicyEngine() + + private val toolOrchestrator = ToolOrchestrator( + registry = toolRegistry, + policyEngine = policyEngine, + renderer = renderer, + mcpConfigService = mcpToolConfigService + ) + + private var databaseConnection: DatabaseConnection? = null + + private val executor: ChatDBAgentExecutor by lazy { + val connection = databaseConnection ?: createDatabaseConnection(databaseConfig) + databaseConnection = connection + + ChatDBAgentExecutor( + projectPath = projectPath, + llmService = llmService, + toolOrchestrator = toolOrchestrator, + renderer = renderer, + databaseConnection = connection, + maxIterations = maxIterations, + enableLLMStreaming = enableLLMStreaming + ) + } + + override fun validateInput(input: Map): ChatDBTask { + val query = input["query"] as? String + ?: throw IllegalArgumentException("Missing required parameter: query") + + return ChatDBTask( + query = query, + additionalContext = input["additionalContext"] as? String ?: "", + maxRows = (input["maxRows"] as? Number)?.toInt() ?: 100, + generateVisualization = input["generateVisualization"] as? Boolean ?: true + ) + } + + override suspend fun execute( + input: ChatDBTask, + onProgress: (String) -> Unit + ): ToolResult.AgentResult { + logger.info { "Starting ChatDB Agent for query: ${input.query}" } + + val systemPrompt = buildSystemPrompt() + val result = executor.execute(input, systemPrompt, onProgress) + + return ToolResult.AgentResult( + success = result.success, + content = result.message, + metadata = mapOf( + "generatedSql" to (result.generatedSql ?: ""), + "rowCount" to (result.queryResult?.rowCount?.toString() ?: "0"), + "revisionAttempts" to result.revisionAttempts.toString(), + "hasVisualization" to (result.plotDslCode != null).toString() + ) + ) + } + + private fun buildSystemPrompt(): String { + return SYSTEM_PROMPT + } + + override fun formatOutput(output: ToolResult.AgentResult): String { + return output.content + } + + override fun getParameterClass(): String = "ChatDBTask" + + /** + * Close database connection when done + */ + suspend fun close() { + databaseConnection?.close() + databaseConnection = null + } + + companion object { + const val SYSTEM_PROMPT = """You are an expert SQL developer. Generate SQL queries from natural language. + +CRITICAL RULES - YOU MUST FOLLOW THESE: +1. ONLY use table names provided in the schema - NEVER invent or guess table names +2. ONLY use column names provided in the schema - NEVER invent or guess column names +3. If a table or column doesn't exist in the schema, DO NOT use it +4. Only generate SELECT queries (read-only operations) +5. Always add LIMIT clause to prevent large result sets + +OUTPUT FORMAT: +- Return ONLY the SQL query wrapped in ```sql code block +- Do NOT include explanations, alternatives, or reasoning +- Do NOT add comments outside the code block +- Keep response concise - just the SQL + +Example response: +```sql +SELECT id, name FROM users WHERE status = 'active' LIMIT 100; +```""" + } +} + diff --git a/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/chatdb/ChatDBAgentExecutor.kt b/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/chatdb/ChatDBAgentExecutor.kt new file mode 100644 index 0000000000..4477f5cd63 --- /dev/null +++ b/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/chatdb/ChatDBAgentExecutor.kt @@ -0,0 +1,722 @@ +package cc.unitmesh.agent.chatdb + +import cc.unitmesh.agent.conversation.ConversationManager +import cc.unitmesh.agent.database.* +import cc.unitmesh.agent.executor.BaseAgentExecutor +import cc.unitmesh.agent.logging.getLogger +import cc.unitmesh.agent.orchestrator.ToolOrchestrator +import cc.unitmesh.agent.render.ChatDBStepStatus +import cc.unitmesh.agent.render.ChatDBStepType +import cc.unitmesh.agent.render.CodingAgentRenderer +import cc.unitmesh.agent.subagent.SqlValidator +import cc.unitmesh.agent.subagent.SqlReviseAgent +import cc.unitmesh.agent.subagent.SqlRevisionInput +import cc.unitmesh.devins.parser.CodeFence +import cc.unitmesh.llm.KoogLLMService + +/** + * ChatDB Agent Executor - Text2SQL Agent with Schema Linking and Revise Agent + * + * Features: + * 1. Schema Linking - Keyword-based search to find relevant tables/columns + * 2. SQL Generation - LLM generates SQL from natural language + * 3. SQL Validation - JSqlParser validates SQL syntax and safety + * 4. Revise Agent - Self-correction loop for fixing SQL errors + * 5. Query Execution - Execute validated SQL and return results + * 6. Visualization - Optional PlotDSL generation for data visualization + */ +class ChatDBAgentExecutor( + projectPath: String, + llmService: KoogLLMService, + toolOrchestrator: ToolOrchestrator, + renderer: CodingAgentRenderer, + private val databaseConnection: DatabaseConnection, + maxIterations: Int = 10, + enableLLMStreaming: Boolean = true, + useLlmSchemaLinker: Boolean = true +) : BaseAgentExecutor( + projectPath = projectPath, + llmService = llmService, + toolOrchestrator = toolOrchestrator, + renderer = renderer, + maxIterations = maxIterations, + enableLLMStreaming = enableLLMStreaming +) { + private val logger = getLogger("ChatDBAgentExecutor") + private val keywordSchemaLinker = KeywordSchemaLinker() + private val schemaLinker: SchemaLinker = if (useLlmSchemaLinker) { + val fallbackLinker = LlmSchemaLinker(llmService, databaseConnection, keywordSchemaLinker) + DatabaseContentSchemaLinker(llmService, databaseConnection, fallbackLinker) + } else { + keywordSchemaLinker + } + + private val sqlValidator = SqlValidator() + private val sqlReviseAgent = SqlReviseAgent(llmService, sqlValidator) + private val maxRevisionAttempts = 3 + private val maxExecutionRetries = 3 + + suspend fun execute( + task: ChatDBTask, + systemPrompt: String, + onProgress: (String) -> Unit = {} + ): ChatDBResult { + resetExecution() + conversationManager = ConversationManager(llmService, systemPrompt) + + val errors = mutableListOf() + var generatedSql: String? = null + var queryResult: QueryResult? = null + var plotDslCode: String? = null + var revisionAttempts = 0 + + try { + // Step 1: Get database schema + renderer.renderChatDBStep( + stepType = ChatDBStepType.FETCH_SCHEMA, + status = ChatDBStepStatus.IN_PROGRESS, + title = "Fetching database schema..." + ) + onProgress("๐Ÿ“Š Fetching database schema...") + + val schema = task.schema ?: databaseConnection.getSchema() + logger.info { "Database has ${schema.tables.size} tables: ${schema.tables.map { it.name }}" } + + renderer.renderChatDBStep( + stepType = ChatDBStepType.FETCH_SCHEMA, + status = ChatDBStepStatus.SUCCESS, + title = "Database schema fetched", + details = mapOf( + "databaseName" to (schema.databaseName ?: "Database"), + "totalTables" to schema.tables.size, + "tables" to schema.tables.map { it.name }, + "tableSchemas" to schemaToTableInfoList(schema) + ) + ) + + // Step 2: Schema Linking + renderer.renderChatDBStep( + stepType = ChatDBStepType.SCHEMA_LINKING, + status = ChatDBStepStatus.IN_PROGRESS, + title = "Performing schema linking..." + ) + onProgress("๐Ÿ”— Performing schema linking...") + + val linkingResult = schemaLinker.link(task.query, schema) + logger.info { "Schema linking found ${linkingResult.relevantTables.size} relevant tables: ${linkingResult.relevantTables}" } + logger.info { "Schema linking keywords: ${linkingResult.keywords}" } + + renderer.renderChatDBStep( + stepType = ChatDBStepType.SCHEMA_LINKING, + status = ChatDBStepStatus.SUCCESS, + title = "Schema linking completed", + details = mapOf( + "relevantTables" to linkingResult.relevantTables, + "keywords" to linkingResult.keywords, + "relevantTableSchemas" to getTableInfoForNames(schema, linkingResult.relevantTables) + ) + ) + + // Step 3: Build context with relevant schema + // If schema linking found too few tables, use all tables to avoid missing important ones + val effectiveLinkingResult = if (linkingResult.relevantTables.size < 2 && schema.tables.size <= 10) { + logger.info { "Schema linking found few tables, using all ${schema.tables.size} tables" } + linkingResult.copy(relevantTables = schema.tables.map { it.name }) + } else { + linkingResult + } + + val relevantSchema = buildRelevantSchemaDescription(schema, effectiveLinkingResult) + val initialMessage = buildInitialUserMessage(task, relevantSchema, effectiveLinkingResult) + + // Step 4: Generate SQL with LLM + renderer.renderChatDBStep( + stepType = ChatDBStepType.GENERATE_SQL, + status = ChatDBStepStatus.IN_PROGRESS, + title = "Generating SQL query..." + ) + onProgress("๐Ÿค– Generating SQL query...") + + val llmResponse = StringBuilder() + val response = getLLMResponse(initialMessage, compileDevIns = false) { chunk -> + onProgress(chunk) + } + llmResponse.append(response) + + // Step 5: Extract SQL from response + generatedSql = extractSqlFromResponse(llmResponse.toString()) + if (generatedSql == null) { + errors.add("Failed to extract SQL from LLM response") + renderer.renderChatDBStep( + stepType = ChatDBStepType.GENERATE_SQL, + status = ChatDBStepStatus.ERROR, + title = "Failed to extract SQL", + error = "Could not find SQL code block in LLM response" + ) + return buildResult(false, errors, null, null, null, 0) + } + + renderer.renderChatDBStep( + stepType = ChatDBStepType.GENERATE_SQL, + status = ChatDBStepStatus.SUCCESS, + title = "SQL query generated", + details = mapOf("sql" to generatedSql) + ) + + // Step 6: Validate SQL syntax and table names using SqlReviseAgent + renderer.renderChatDBStep( + stepType = ChatDBStepType.VALIDATE_SQL, + status = ChatDBStepStatus.IN_PROGRESS, + title = "Validating SQL..." + ) + + var validatedSql = generatedSql + + // Get all table names from schema for whitelist validation + val allTableNames = schema.tables.map { it.name }.toSet() + + // First validate syntax, then validate table names + val syntaxValidation = sqlValidator.validate(validatedSql!!) + val tableValidation = if (syntaxValidation.isValid) { + sqlValidator.validateWithTableWhitelist(validatedSql, allTableNames) + } else { + syntaxValidation + } + + if (!tableValidation.isValid) { + val errorType = if (!syntaxValidation.isValid) "syntax" else "table name" + + renderer.renderChatDBStep( + stepType = ChatDBStepType.VALIDATE_SQL, + status = ChatDBStepStatus.ERROR, + title = "SQL validation failed", + details = mapOf( + "errorType" to errorType, + "errors" to tableValidation.errors + ), + error = tableValidation.errors.joinToString("; ") + ) + + onProgress("๐Ÿ”„ SQL validation failed ($errorType), invoking SqlReviseAgent...") + + renderer.renderChatDBStep( + stepType = ChatDBStepType.REVISE_SQL, + status = ChatDBStepStatus.IN_PROGRESS, + title = "Revising SQL..." + ) + + val revisionInput = SqlRevisionInput( + originalQuery = task.query, + failedSql = validatedSql, + errorMessage = tableValidation.errors.joinToString("; "), + schemaDescription = relevantSchema, + maxAttempts = maxRevisionAttempts + ) + + val revisionResult = sqlReviseAgent.execute(revisionInput) { progress -> + onProgress(progress) + } + + revisionAttempts = revisionResult.metadata["attempts"]?.toIntOrNull() ?: 0 + + if (revisionResult.success) { + validatedSql = revisionResult.content + + renderer.renderChatDBStep( + stepType = ChatDBStepType.REVISE_SQL, + status = ChatDBStepStatus.SUCCESS, + title = "SQL revised successfully", + details = mapOf( + "attempts" to revisionAttempts, + "sql" to validatedSql + ) + ) + + onProgress("โœ… SQL revised successfully after $revisionAttempts attempts") + } else { + renderer.renderChatDBStep( + stepType = ChatDBStepType.REVISE_SQL, + status = ChatDBStepStatus.ERROR, + title = "SQL revision failed", + error = revisionResult.content + ) + errors.add("SQL revision failed: ${revisionResult.content}") + } + } else { + renderer.renderChatDBStep( + stepType = ChatDBStepType.VALIDATE_SQL, + status = ChatDBStepStatus.SUCCESS, + title = "SQL validation passed" + ) + } + + generatedSql = validatedSql + + // Step 7: Execute SQL with retry loop for execution errors + if (generatedSql != null) { + var executionRetries = 0 + var lastExecutionError: String? = null + + while (executionRetries < maxExecutionRetries && queryResult == null) { + renderer.renderChatDBStep( + stepType = ChatDBStepType.EXECUTE_SQL, + status = ChatDBStepStatus.IN_PROGRESS, + title = "Executing SQL query${if (executionRetries > 0) " (retry $executionRetries)" else ""}...", + details = mapOf( + "attempt" to (executionRetries + 1), + "sql" to generatedSql!! + ) + ) + + onProgress("โšก Executing SQL query${if (executionRetries > 0) " (retry $executionRetries)" else ""}...") + + try { + queryResult = databaseConnection.executeQuery(generatedSql!!) + + renderer.renderChatDBStep( + stepType = ChatDBStepType.EXECUTE_SQL, + status = ChatDBStepStatus.SUCCESS, + title = "Query executed successfully", + details = mapOf( + "sql" to generatedSql!!, + "rowCount" to queryResult.rowCount, + "columns" to queryResult.columns, + "previewRows" to getPreviewRows(queryResult, 5) + ) + ) + + onProgress("โœ… Query returned ${queryResult.rowCount} rows") + } catch (e: Exception) { + lastExecutionError = e.message ?: "Unknown execution error" + logger.warn { "Query execution failed (attempt ${executionRetries + 1}): $lastExecutionError" } + + renderer.renderChatDBStep( + stepType = ChatDBStepType.EXECUTE_SQL, + status = ChatDBStepStatus.ERROR, + title = "Query execution failed", + details = mapOf( + "attempt" to (executionRetries + 1), + "maxAttempts" to maxExecutionRetries + ), + error = lastExecutionError + ) + + // Try to revise SQL based on execution error + if (executionRetries < maxExecutionRetries - 1) { + onProgress("๐Ÿ”„ Attempting to fix SQL based on execution error...") + + renderer.renderChatDBStep( + stepType = ChatDBStepType.REVISE_SQL, + status = ChatDBStepStatus.IN_PROGRESS, + title = "Revising SQL based on execution error..." + ) + + val revisionInput = SqlRevisionInput( + originalQuery = task.query, + failedSql = generatedSql!!, + errorMessage = "Execution error: $lastExecutionError", + schemaDescription = relevantSchema, + maxAttempts = 1 + ) + + val revisionResult = sqlReviseAgent.execute(revisionInput) { progress -> + onProgress(progress) + } + + if (revisionResult.success && revisionResult.content != generatedSql) { + generatedSql = revisionResult.content + revisionAttempts++ + + renderer.renderChatDBStep( + stepType = ChatDBStepType.REVISE_SQL, + status = ChatDBStepStatus.SUCCESS, + title = "SQL revised based on execution error", + details = mapOf( + "sql" to generatedSql + ) + ) + + onProgress("๐Ÿ”ง SQL revised, retrying execution...") + } else { + renderer.renderChatDBStep( + stepType = ChatDBStepType.REVISE_SQL, + status = ChatDBStepStatus.WARNING, + title = "SQL revision did not help", + error = "Revision did not produce a different SQL" + ) + // Revision didn't help, break the loop + break + } + } + executionRetries++ + } + } + + if (queryResult == null && lastExecutionError != null) { + errors.add("Query execution failed after $executionRetries retries: $lastExecutionError") + logger.error { "Query execution failed after $executionRetries retries" } + } + } + + // Step 8: Generate visualization if requested + if (task.generateVisualization && queryResult != null && !queryResult.isEmpty()) { + renderer.renderChatDBStep( + stepType = ChatDBStepType.GENERATE_VISUALIZATION, + status = ChatDBStepStatus.IN_PROGRESS, + title = "Generating visualization..." + ) + + onProgress("๐Ÿ“ˆ Generating visualization...") + + plotDslCode = generateVisualization(task.query, queryResult, onProgress) + + if (plotDslCode != null) { + renderer.renderChatDBStep( + stepType = ChatDBStepType.GENERATE_VISUALIZATION, + status = ChatDBStepStatus.SUCCESS, + title = "Visualization generated", + details = mapOf( + "code" to plotDslCode + ) + ) + } else { + renderer.renderChatDBStep( + stepType = ChatDBStepType.GENERATE_VISUALIZATION, + status = ChatDBStepStatus.WARNING, + title = "Visualization not generated" + ) + } + } + + } catch (e: Exception) { + logger.error(e) { "ChatDB execution failed" } + errors.add("Execution failed: ${e.message}") + } + + val result = buildResult( + success = errors.isEmpty() && queryResult != null, + errors = errors, + generatedSql = generatedSql, + queryResult = queryResult, + plotDslCode = plotDslCode, + revisionAttempts = revisionAttempts + ) + + // Render the final result to the timeline + if (result.success) { + renderer.renderChatDBStep( + stepType = ChatDBStepType.FINAL_RESULT, + status = ChatDBStepStatus.SUCCESS, + title = "Query completed successfully", + details = buildMap { + generatedSql?.let { put("sql", it) } + queryResult?.let { + put("rowCount", it.rowCount) + put("columns", it.columns) + put("previewRows", getPreviewRows(it, 10)) + } + if (revisionAttempts > 0) { + put("revisionAttempts", revisionAttempts) + } + plotDslCode?.let { put("visualization", it) } + } + ) + + renderer.renderLLMResponseStart() + renderer.renderLLMResponseChunk(result.message) + renderer.renderLLMResponseEnd() + } else { + renderer.renderChatDBStep( + stepType = ChatDBStepType.FINAL_RESULT, + status = ChatDBStepStatus.ERROR, + title = "Query failed", + details = buildMap { + generatedSql?.let { put("sql", it) } + }, + error = errors.joinToString("; ") + ) + + renderer.renderError(result.message) + } + + return result + } + + private fun resetExecution() { + currentIteration = 0 + } + + /** + * Build schema description with sample data for SQL generation + * Based on RSL-SQL research: sample data helps LLM understand table semantics + */ + private suspend fun buildRelevantSchemaDescription( + schema: DatabaseSchema, + linkingResult: SchemaLinkingResult + ): String { + val relevantTables = schema.tables.filter { it.name in linkingResult.relevantTables } + return buildString { + appendLine("## Database Schema (USE ONLY THESE TABLES)") + appendLine() + relevantTables.forEach { table -> + appendLine("Table: ${table.name}") + appendLine("Columns: ${table.columns.joinToString(", ") { "${it.name} (${it.type})" }}") + + // Add sample data to help LLM understand table content + try { + val sampleRows = databaseConnection.getSampleRows(table.name, 2) + if (!sampleRows.isEmpty()) { + appendLine("Sample Data:") + appendLine(" ${sampleRows.columns.joinToString(" | ")}") + sampleRows.rows.take(2).forEach { row -> + appendLine(" ${row.joinToString(" | ") { it.take(50) }}") + } + } + } catch (e: Exception) { + // Ignore sample data errors + } + appendLine() + } + } + } + + private fun buildInitialUserMessage( + task: ChatDBTask, + schemaDescription: String, + linkingResult: SchemaLinkingResult + ): String { + return buildString { + appendLine("Generate a SQL query for: ${task.query}") + appendLine() + if (task.additionalContext.isNotBlank()) { + appendLine("Context: ${task.additionalContext}") + appendLine() + } + appendLine("ALLOWED TABLES (use ONLY these): ${linkingResult.relevantTables.joinToString(", ")}") + appendLine() + appendLine(schemaDescription) + appendLine() + appendLine("Max rows: ${task.maxRows}") + appendLine() + appendLine("CRITICAL RULES:") + appendLine("1. Return ONLY ONE SQL statement - never multiple statements") + appendLine("2. Choose the BEST matching table if multiple similar tables exist") + appendLine("3. Wrap the SQL in a ```sql code block") + appendLine("4. No comments, no explanations, just the SQL") + } + } + + private fun extractSqlFromResponse(response: String): String? { + val codeFence = CodeFence.parse(response) + if (codeFence.languageId.lowercase() == "sql" && codeFence.text.isNotBlank()) { + return extractFirstStatement(codeFence.text.trim()) + } + + // Try to find SQL block manually + val sqlPattern = Regex("```sql\\s*([\\s\\S]*?)```", RegexOption.IGNORE_CASE) + val match = sqlPattern.find(response) + if (match != null) { + return extractFirstStatement(match.groupValues[1].trim()) + } + + // Last resort: look for SELECT statement + val selectPattern = Regex("(SELECT[\\s\\S]*?;)", RegexOption.IGNORE_CASE) + val selectMatch = selectPattern.find(response) + return selectMatch?.groupValues?.get(1)?.trim() + } + + /** + * Extract only the first SQL statement if LLM returns multiple statements. + * This prevents "multiple statements" errors. + */ + private fun extractFirstStatement(sql: String): String { + // Remove SQL comments (-- style) + val withoutComments = sql.lines() + .filterNot { it.trim().startsWith("--") } + .joinToString("\n") + .trim() + + // If there are multiple statements separated by semicolons, take only the first + val statements = withoutComments.split(";") + .map { it.trim() } + .filter { it.isNotBlank() } + + return if (statements.isNotEmpty()) { + statements.first() + ";" + } else { + withoutComments + } + } + + private suspend fun generateVisualization( + query: String, + result: QueryResult, + onProgress: (String) -> Unit + ): String? { + val visualizationPrompt = buildString { + appendLine("Based on the following query result, generate a PlotDSL visualization:") + appendLine() + appendLine("**Original Query**: $query") + appendLine() + appendLine("**Query Result** (${result.rowCount} rows):") + appendLine("```csv") + appendLine(result.toCsvString()) + appendLine("```") + appendLine() + appendLine("Generate a PlotDSL chart that best visualizes this data.") + appendLine("Choose an appropriate chart type (bar, line, scatter, etc.) based on the data.") + appendLine("Wrap the PlotDSL code in a ```plotdsl code block.") + } + + try { + val response = getLLMResponse(visualizationPrompt, compileDevIns = false) { chunk -> + onProgress(chunk) + } + + val codeFence = CodeFence.parse(response) + if (codeFence.languageId.lowercase() == "plotdsl" && codeFence.text.isNotBlank()) { + return codeFence.text.trim() + } + + // Try to find plotdsl block manually + val plotPattern = Regex("```plotdsl\\s*([\\s\\S]*?)```", RegexOption.IGNORE_CASE) + val match = plotPattern.find(response) + return match?.groupValues?.get(1)?.trim() + } catch (e: Exception) { + logger.error(e) { "Visualization generation failed" } + return null + } + } + + private fun buildResult( + success: Boolean, + errors: List, + generatedSql: String?, + queryResult: QueryResult?, + plotDslCode: String?, + revisionAttempts: Int + ): ChatDBResult { + val message = if (success) { + buildString { + appendLine("## โœ… Query Executed Successfully") + appendLine() + + // Show SQL that was executed + if (generatedSql != null) { + appendLine("**Executed SQL:**") + appendLine("```sql") + appendLine(generatedSql) + appendLine("```") + appendLine() + } + + // Show revision info if applicable + if (revisionAttempts > 0) { + appendLine("*Note: SQL was revised $revisionAttempts time(s) to fix validation/execution errors*") + appendLine() + } + + // Show query results + if (queryResult != null) { + appendLine("**Results** (${queryResult.rowCount} row${if (queryResult.rowCount != 1) "s" else ""}):") + appendLine() + appendLine(queryResult.toTableString()) + } + + // Show visualization if generated + if (plotDslCode != null) { + appendLine() + appendLine("**Visualization:**") + appendLine("```plotdsl") + appendLine(plotDslCode) + appendLine("```") + } + } + } else { + buildString { + appendLine("## โŒ Query Failed") + appendLine() + appendLine("**Errors:**") + errors.forEach { error -> + appendLine("- $error") + } + if (generatedSql != null) { + appendLine() + appendLine("**Failed SQL:**") + appendLine("```sql") + appendLine(generatedSql) + appendLine("```") + } + } + } + + return ChatDBResult( + success = success, + message = message, + generatedSql = generatedSql, + queryResult = queryResult, + plotDslCode = plotDslCode, + revisionAttempts = revisionAttempts, + errors = errors + ) + } + + override fun buildContinuationMessage(): String { + return "Please continue with the database query based on the results above." + } + + /** + * Convert DatabaseSchema to a list of table info maps for UI rendering + */ + private fun schemaToTableInfoList(schema: DatabaseSchema): List> { + return schema.tables.map { table -> + mapOf( + "name" to table.name, + "comment" to (table.comment ?: ""), + "columns" to table.columns.map { col -> + mapOf( + "name" to col.name, + "type" to col.type, + "nullable" to col.nullable, + "isPrimaryKey" to col.isPrimaryKey, + "isForeignKey" to col.isForeignKey, + "comment" to (col.comment ?: ""), + "defaultValue" to (col.defaultValue ?: "") + ) + } + ) + } + } + + /** + * Get table info maps for specific table names + */ + private fun getTableInfoForNames(schema: DatabaseSchema, tableNames: List): List> { + return tableNames.mapNotNull { tableName -> + schema.getTable(tableName)?.let { table -> + mapOf( + "name" to table.name, + "comment" to (table.comment ?: ""), + "columns" to table.columns.map { col -> + mapOf( + "name" to col.name, + "type" to col.type, + "nullable" to col.nullable, + "isPrimaryKey" to col.isPrimaryKey, + "isForeignKey" to col.isForeignKey, + "comment" to (col.comment ?: ""), + "defaultValue" to (col.defaultValue ?: "") + ) + } + ) + } + } + } + + /** + * Get preview rows from query result (limited to first N rows) + */ + private fun getPreviewRows(queryResult: QueryResult, maxRows: Int = 5): List> { + return queryResult.rows.take(maxRows) + } +} diff --git a/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/chatdb/ChatDBModels.kt b/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/chatdb/ChatDBModels.kt new file mode 100644 index 0000000000..394c9d5de8 --- /dev/null +++ b/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/chatdb/ChatDBModels.kt @@ -0,0 +1,112 @@ +package cc.unitmesh.agent.chatdb + +import cc.unitmesh.agent.database.DatabaseSchema +import cc.unitmesh.agent.database.QueryResult +import kotlinx.serialization.Serializable + +/** + * ChatDB Task - Input for the Text2SQL Agent + * + * Contains the natural language query and database context + */ +@Serializable +data class ChatDBTask( + /** + * Natural language query from user + * e.g., "Show me the top 10 customers by total order amount" + */ + val query: String, + + /** + * Database schema for context (optional, will be fetched if not provided) + */ + val schema: DatabaseSchema? = null, + + /** + * Additional context or constraints + * e.g., "Only consider orders from 2024" + */ + val additionalContext: String = "", + + /** + * Maximum number of rows to return + */ + val maxRows: Int = 100, + + /** + * Whether to generate visualization after query + */ + val generateVisualization: Boolean = true +) + +/** + * ChatDB Result - Output from the Text2SQL Agent + */ +@Serializable +data class ChatDBResult( + /** + * Whether the query was successful + */ + val success: Boolean, + + /** + * Human-readable message + */ + val message: String, + + /** + * Generated SQL query + */ + val generatedSql: String? = null, + + /** + * Query execution result + */ + val queryResult: QueryResult? = null, + + /** + * Generated PlotDSL code for visualization (if applicable) + */ + val plotDslCode: String? = null, + + /** + * Number of revision attempts made + */ + val revisionAttempts: Int = 0, + + /** + * Errors encountered during execution + */ + val errors: List = emptyList(), + + /** + * Metadata about the execution + */ + val metadata: Map = emptyMap() +) + +/** + * Schema Linking Result - Tables and columns relevant to the query + */ +@Serializable +data class SchemaLinkingResult( + /** + * Relevant table names + */ + val relevantTables: List, + + /** + * Relevant column names (table.column format) + */ + val relevantColumns: List, + + /** + * Keywords extracted from the query + */ + val keywords: List, + + /** + * Confidence score (0.0 - 1.0) + */ + val confidence: Double = 0.0 +) diff --git a/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/chatdb/DatabaseContentSchemaLinker.kt b/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/chatdb/DatabaseContentSchemaLinker.kt new file mode 100644 index 0000000000..13263c272b --- /dev/null +++ b/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/chatdb/DatabaseContentSchemaLinker.kt @@ -0,0 +1,166 @@ +package cc.unitmesh.agent.chatdb + +import cc.unitmesh.agent.database.DatabaseConnection +import cc.unitmesh.agent.database.DatabaseSchema +import cc.unitmesh.agent.database.TableSchema +import cc.unitmesh.llm.KoogLLMService +import kotlinx.serialization.Serializable +import kotlinx.serialization.json.Json + +/** + * Database Content Schema Linker - Uses database content to improve Schema Linking accuracy + * + * Based on RSL-SQL research, this linker: + * 1. Filters out system tables (sys_*, x$*, etc.) + * 2. Retrieves sample data from each table + * 3. Uses LLM with enriched context to identify relevant tables + * + * This approach is more accurate than pure keyword matching because it uses + * actual database content to understand table semantics. + */ +class DatabaseContentSchemaLinker( + private val llmService: KoogLLMService, + private val databaseConnection: DatabaseConnection, + private val fallbackLinker: SchemaLinker = KeywordSchemaLinker() +) : SchemaLinker() { + + private val json = Json { ignoreUnknownKeys = true } + + companion object { + // System table prefixes to filter out + private val SYSTEM_TABLE_PREFIXES = listOf( + "sys_", "x$", "innodb_", "io_", "memory_", "schema_", "statement", + "user_summary", "host_summary", "wait", "process", "session", + "metrics", "privileges", "ps_", "flyway_", "hibernate_" + ) + + // System table exact names + private val SYSTEM_TABLE_NAMES = setOf( + "version", "latest_file_io", "session_ssl_status" + ) + + private const val SCHEMA_LINKING_PROMPT = """You are a database schema expert. Given a user query and database tables with sample data, identify the most relevant tables. + +CRITICAL RULES: +1. ONLY select tables from the provided list - do NOT invent table names +2. Look at sample data to understand what each table actually contains +3. Match the user's semantic intent, not just keywords +4. For "ๆ–‡็ซ /article/post" queries, look for tables containing blog/post content +5. For "ไฝœ่€…/author/creator" queries, look for tables with author/creator information + +Available Tables with Sample Data: +{{TABLES_WITH_SAMPLES}} + +User Query: {{QUERY}} + +Respond ONLY with a JSON object: +{"tables": ["table1", "table2"], "reason": "brief explanation", "confidence": 0.9} + +Select ONLY tables that are directly needed to answer the query.""" + } + + @Serializable + private data class SchemaLinkingResult( + val tables: List = emptyList(), + val reason: String = "", + val confidence: Double = 0.0 + ) + + /** + * Filter out system tables that are not relevant for user queries + */ + private fun filterUserTables(schema: DatabaseSchema): List { + return schema.tables.filter { table -> + val lowerName = table.name.lowercase() + // Filter out system tables + val isSystemTable = SYSTEM_TABLE_PREFIXES.any { prefix -> + lowerName.startsWith(prefix) + } || SYSTEM_TABLE_NAMES.contains(lowerName) + !isSystemTable + } + } + + /** + * Build table description with sample data for Schema Linking + */ + private suspend fun buildTableWithSamples(table: TableSchema): String { + return buildString { + appendLine("Table: ${table.name}") + appendLine(" Columns: ${table.columns.joinToString(", ") { "${it.name}(${it.type})" }}") + + // Get sample data to help understand table content + try { + val samples = databaseConnection.getSampleRows(table.name, 2) + if (!samples.isEmpty()) { + appendLine(" Sample Data:") + appendLine(" ${samples.columns.joinToString(" | ")}") + samples.rows.take(2).forEach { row -> + appendLine(" ${row.joinToString(" | ") { it.take(30) }}") + } + } + } catch (e: Exception) { + // Table might not exist in current database, skip it + appendLine(" (No sample data available)") + } + } + } + + override suspend fun link(query: String, schema: DatabaseSchema): cc.unitmesh.agent.chatdb.SchemaLinkingResult { + return try { + // Step 1: Filter out system tables + val userTables = filterUserTables(schema) + if (userTables.isEmpty()) { + return fallbackLinker.link(query, schema) + } + + // Step 2: Build table descriptions with sample data + val tablesWithSamples = userTables.map { table -> + buildTableWithSamples(table) + }.joinToString("\n") + + // Step 3: Ask LLM to identify relevant tables + val prompt = SCHEMA_LINKING_PROMPT + .replace("{{TABLES_WITH_SAMPLES}}", tablesWithSamples) + .replace("{{QUERY}}", query) + + val response = llmService.sendPrompt(prompt) + val result = parseResponse(response) + + // Step 4: Validate tables exist + val validTables = result.tables.filter { tableName -> + userTables.any { it.name.equals(tableName, ignoreCase = true) } + } + + if (validTables.isEmpty()) { + return fallbackLinker.link(query, schema) + } + + // Extract keywords for additional context + val keywords = fallbackLinker.extractKeywords(query) + + cc.unitmesh.agent.chatdb.SchemaLinkingResult( + relevantTables = validTables, + relevantColumns = emptyList(), + keywords = keywords, + confidence = result.confidence + ) + } catch (e: Exception) { + fallbackLinker.link(query, schema) + } + } + + override suspend fun extractKeywords(query: String): List { + return fallbackLinker.extractKeywords(query) + } + + private fun parseResponse(response: String): SchemaLinkingResult { + val jsonPattern = Regex("""\{[^{}]*\}""") + val jsonStr = jsonPattern.find(response)?.value ?: "{}" + return try { + json.decodeFromString(jsonStr) + } catch (e: Exception) { + SchemaLinkingResult() + } + } +} + diff --git a/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/chatdb/FallbackNlpTokenizer.kt b/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/chatdb/FallbackNlpTokenizer.kt new file mode 100644 index 0000000000..14e7ffbcd4 --- /dev/null +++ b/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/chatdb/FallbackNlpTokenizer.kt @@ -0,0 +1,524 @@ +package cc.unitmesh.agent.chatdb + +/** + * Enhanced Fallback NLP Tokenizer for Kotlin Multiplatform. + * + * Capabilities: + * 1. English: Porter Stemmer for morphological normalization. + * 2. Chinese: Bi-directional Maximum Matching (BiMM) for segmentation. + * 3. Code: SemVer and CamelCase handling. + * 4. Keywords: RAKE (Rapid Automatic Keyword Extraction) algorithm. + * + * References: + * - Porter Stemmer: https://tartarus.org/martin/PorterStemmer/ + * - BiMM: Bi-directional Maximum Matching for Chinese word segmentation + * - RAKE: Rapid Automatic Keyword Extraction + */ +object FallbackNlpTokenizer { + + /** + * Main entry point: Extract weighted keywords from a query. + * This is the primary method for keyword extraction with intelligent weighting. + */ + fun extractKeywords(query: String, maxKeywords: Int = 10): List { + // 1. Pre-process and Tokenize + val tokens = tokenize(query) + + // 2. RAKE Algorithm implementation + // Filter out stop words and delimiters to get content words + val stopWords = StopWords.ENGLISH + StopWords.CHINESE + val contentTokens = tokens.filter { it.text !in stopWords && it.text.length > 1 } + + if (contentTokens.isEmpty()) return emptyList() + + // Build Co-occurrence Graph + val frequency = mutableMapOf() + val degree = mutableMapOf() + + // RAKE window size (words appear together within this distance) + val windowSize = 3 + + for (i in contentTokens.indices) { + val token = contentTokens[i].text + frequency[token] = (frequency[token] ?: 0) + 1 + + val windowStart = maxOf(0, i - windowSize) + val windowEnd = minOf(contentTokens.size - 1, i + windowSize) + + for (j in windowStart..windowEnd) { + if (i == j) continue + degree[token] = (degree[token] ?: 0) + 1 + } + // Degree includes self-occurrence in RAKE definitions + degree[token] = (degree[token] ?: 0) + 1 + } + + // Calculate Scores: Score(w) = deg(w) / freq(w) + val scores = contentTokens.map { it.text }.distinct().associateWith { word -> + val deg = degree[word]?.toDouble() ?: 0.0 + val freq = frequency[word]?.toDouble() ?: 1.0 + deg / freq + } + + return scores.entries + .sortedByDescending { it.value } + .take(maxKeywords) + .map { it.key } + } + + /** + * Legacy method for backward compatibility. + * Extract keywords from natural language query using simple tokenization. + * Supports both English and Chinese text. + */ + fun extractKeywords(query: String, stopWords: Set): List { + val tokens = tokenize(query) + + return tokens + .filter { it.text !in stopWords && it.text.length > 1 } + .map { it.text } + .distinct() + } + + /** + * Unified tokenization pipeline. + */ + fun tokenize(text: String): List { + val tokens = mutableListOf() + + // Step 1: Extract Semantic Versions (v1.2.3, 1.0.0-alpha, 2.1.0+build.123) to prevent splitting + val versionPattern = Regex("""v?(0|[1-9]\d*)\.(0|[1-9]\d*)\.(0|[1-9]\d*)(?:-((?:0|[1-9]\d*|\d*[a-zA-Z-][0-9a-zA-Z-]*)(?:\.(?:0|[1-9]\d*|\d*[a-zA-Z-][0-9a-zA-Z-]*))*))?(?:\+([0-9a-zA-Z-]+(?:\.[0-9a-zA-Z-]+)*))?""") + + // Use placeholder strategy to preserve SemVer tokens + var processedText = text + val versionTokens = mutableListOf() + versionPattern.findAll(text).forEach { match -> + versionTokens.add(match.value) + processedText = processedText.replace(match.value, " __SEMVER_${versionTokens.size - 1}__ ") + } + + // Step 2: Split by non-alphanumeric (but keep Chinese chars together for now) + val rawSegments = processedText.split(Regex("[^a-zA-Z0-9\\u4e00-\\u9fa5._-]+")) + + for (segment in rawSegments) { + if (segment.isBlank()) continue + + // Check for SemVer placeholder + val semverMatch = Regex("__SEMVER_(\\d+)__").matchEntire(segment) + if (semverMatch != null) { + val index = semverMatch.groupValues[1].toInt() + tokens.add(Token(versionTokens[index], TokenType.CODE)) + continue + } + + if (segment.matches(Regex("[\\u4e00-\\u9fa5]+"))) { + // Pure Chinese Segment + tokens.addAll(ChineseSegmenter.segment(segment)) + } else if (segment.matches(Regex("[a-zA-Z0-9._-]+"))) { + // Pure English/Code segment + val splitCamel = splitCamelCase(segment) + splitCamel.forEach { word -> + val stem = PorterStemmer.stem(word.lowercase()) + tokens.add(Token(stem, TokenType.ENGLISH)) + } + } else { + // Mixed English and Chinese - split at character type boundaries + tokens.addAll(splitMixedScript(segment)) + } + } + + return tokens + } + + /** + * Split mixed script text (English/Chinese) at character type boundaries. + * E.g., "Helloไธ–็•ŒTest" -> ["hello", "ไธ–", "็•Œ", "test"] with appropriate types + */ + private fun splitMixedScript(text: String): List { + val tokens = mutableListOf() + val currentSegment = StringBuilder() + var currentType: Char? = null // 'E' for English/Code, 'C' for Chinese + + for (char in text) { + val charType = when { + char in '\u4e00'..'\u9fa5' -> 'C' + char.isLetterOrDigit() || char in "._-" -> 'E' + else -> null + } + + if (charType == null) { + // Delimiter - flush current segment + if (currentSegment.isNotEmpty()) { + tokens.addAll(processSegment(currentSegment.toString(), currentType)) + currentSegment.clear() + } + currentType = null + continue + } + + if (currentType != null && currentType != charType) { + // Type changed - flush current segment + tokens.addAll(processSegment(currentSegment.toString(), currentType)) + currentSegment.clear() + } + + currentSegment.append(char) + currentType = charType + } + + // Flush remaining + if (currentSegment.isNotEmpty()) { + tokens.addAll(processSegment(currentSegment.toString(), currentType)) + } + + return tokens + } + + private fun processSegment(segment: String, type: Char?): List { + return when (type) { + 'C' -> ChineseSegmenter.segment(segment) + 'E' -> { + splitCamelCase(segment).map { word -> + Token(PorterStemmer.stem(word.lowercase()), TokenType.ENGLISH) + } + } + else -> emptyList() + } + } + + private fun splitCamelCase(s: String): List { + // Regex to look for switch from lower to upper, or number boundaries + return s.replace(Regex("([a-z])([A-Z])"), "$1 $2") + .replace(Regex("([A-Z])([A-Z][a-z])"), "$1 $2") // Handle HTMLParser -> HTML Parser + .replace(Regex("([a-zA-Z])([0-9])"), "$1 $2") + .replace(Regex("([0-9])([a-zA-Z])"), "$1 $2") + .replace('_', ' ') + .replace('.', ' ') + .replace('-', ' ') + .split(' ') + .filter { it.isNotBlank() } + } + + data class Token(val text: String, val type: TokenType) + enum class TokenType { ENGLISH, CHINESE, CODE } + + // --- INTERNAL HELPER CLASSES --- + + /** + * Pure Kotlin implementation of the Porter Stemming Algorithm. + * Reference: https://tartarus.org/martin/PorterStemmer/ + */ + internal object PorterStemmer { + fun stem(word: String): String { + if (word.length < 3) return word + val stemmer = StemmerState(word.toCharArray()) + stemmer.step1() + stemmer.step2() + stemmer.step3() + stemmer.step4() + stemmer.step5() + return stemmer.toString() + } + + private class StemmerState(var b: CharArray) { + var k: Int = b.size - 1 // offset to end of stemmed part + var j: Int = 0 // general offset into string + + override fun toString() = b.concatToString(0, 0 + (k + 1)) + + private fun cons(i: Int): Boolean { + return when (b[i]) { + 'a', 'e', 'i', 'o', 'u' -> false + 'y' -> if (i == 0) true else !cons(i - 1) + else -> true + } + } + + // m() measures the number of consonant sequences between 0 and j. + private fun m(): Int { + var n = 0 + var i = 0 + while (true) { + if (i > j) return n + if (!cons(i)) break + i++ + } + i++ + while (true) { + while (true) { + if (i > j) return n + if (cons(i)) break + i++ + } + i++ + n++ + while (true) { + if (i > j) return n + if (!cons(i)) break + i++ + } + i++ + } + } + + // vowelInStem() is true if 0,...,j contains a vowel + private fun vowelInStem(): Boolean { + for (i in 0..j) if (!cons(i)) return true + return false + } + + // doublec(j) is true if j,(j-1) contain a double consonant + private fun doublec(j: Int): Boolean { + if (j < 1) return false + if (b[j] != b[j - 1]) return false + return cons(j) + } + + // cvc(i) is true if i-2,i-1,i has the form consonant - vowel - consonant + // and also if the second c is not w, x or y. + private fun cvc(i: Int): Boolean { + if (i < 2 || !cons(i) || cons(i - 1) || !cons(i - 2)) return false + val ch = b[i] + return ch != 'w' && ch != 'x' && ch != 'y' + } + + private fun ends(s: String): Boolean { + val l = s.length + val o = k - l + 1 + if (o < 0) return false + for (i in 0 until l) if (b[o + i] != s[i]) return false + j = k - l + return true + } + + private fun setto(s: String) { + val l = s.length + val o = j + 1 + val newB = if (o + l > b.size) b.copyOf(o + l) else b + for (i in 0 until l) newB[o + i] = s[i] + b = newB + k = j + l + } + + private fun r(s: String) { + if (m() > 0) setto(s) + } + + fun step1() { + if (b[k] == 's') { + if (ends("sses")) k -= 2 + else if (ends("ies")) setto("i") + else if (b[k - 1] != 's') k-- + } + if (ends("eed")) { + if (m() > 0) k-- + } else if ((ends("ed") || ends("ing")) && vowelInStem()) { + k = j + if (ends("at")) setto("ate") + else if (ends("bl")) setto("ble") + else if (ends("iz")) setto("ize") + else if (doublec(k)) { + k-- + val ch = b[k] + if (ch == 'l' || ch == 's' || ch == 'z') k++ + } else if (m() == 1 && cvc(k)) setto("e") + } + } + + fun step2() { + if (ends("y") && vowelInStem()) b[k] = 'i' + if (k == 0) return + // Optimization: switch on penultimate char + when (b[k - 1]) { + 'a' -> { + if (ends("ational")) r("ate") + else if (ends("tional")) r("tion") + } + 'c' -> { + if (ends("enci")) r("ence") + else if (ends("anci")) r("ance") + } + 'e' -> if (ends("izer")) r("ize") + 'l' -> { + if (ends("bli")) r("ble") + else if (ends("alli")) r("al") + else if (ends("entli")) r("ent") + else if (ends("eli")) r("e") + else if (ends("ousli")) r("ous") + } + 'o' -> { + if (ends("ization")) r("ize") + else if (ends("ation")) r("ate") + else if (ends("ator")) r("ate") + } + 's' -> { + if (ends("alism")) r("al") + else if (ends("iveness")) r("ive") + else if (ends("fulness")) r("ful") + else if (ends("ousness")) r("ous") + } + 't' -> { + if (ends("aliti")) r("al") + else if (ends("iviti")) r("ive") + else if (ends("biliti")) r("ble") + } + } + } + + fun step3() { + when (b[k]) { + 'e' -> { + if (ends("icate")) r("ic") + else if (ends("ative")) r("") + else if (ends("alize")) r("al") + } + 'i' -> if (ends("iciti")) r("ic") + 'l' -> { + if (ends("ical")) r("ic") + else if (ends("ful")) r("") + } + 's' -> if (ends("ness")) r("") + } + } + + fun step4() { + if (k < 2) return + when (b[k - 1]) { + 'a' -> if (ends("al")) { if (m() > 1) k = j } + 'c' -> { + if (ends("ance") || ends("ence")) { + if (m() > 1) k = j + } + } + 'e' -> if (ends("er")) { if (m() > 1) k = j } + 'i' -> if (ends("ic")) { if (m() > 1) k = j } + 'l' -> { + if (ends("able") || ends("ible")) { + if (m() > 1) k = j + } + } + 'n' -> { + if (ends("ant") || ends("ement") || ends("ment") || ends("ent")) { + if (m() > 1) k = j + } + } + 'o' -> { + if (ends("ion") && j >= 0 && (b[j] == 's' || b[j] == 't')) { + if (m() > 1) k = j + } else if (ends("ou")) { + if (m() > 1) k = j + } + } + 's' -> if (ends("ism")) { if (m() > 1) k = j } + 't' -> { + if (ends("ate") || ends("iti")) { + if (m() > 1) k = j + } + } + 'u' -> if (ends("ous")) { if (m() > 1) k = j } + 'v' -> if (ends("ive")) { if (m() > 1) k = j } + 'z' -> if (ends("ize")) { if (m() > 1) k = j } + } + } + + fun step5() { + j = k + if (b[k] == 'e') { + val a = m() + if (a > 1 || (a == 1 && !cvc(k - 1))) k-- + } + if (b[k] == 'l' && doublec(k) && m() > 1) k-- + } + } + } + + /** + * Bi-directional Maximum Matching for Chinese Segmentation. + * Uses a built-in dictionary for common words. + */ + internal object ChineseSegmenter { + // Common Chinese words dictionary for segmentation + // In a real app, this can be loaded from external resources + val dictionary = setOf( + // Common verbs + "ๆˆ‘ไปฌ", "ไป–ไปฌ", "ไฝ ไปฌ", "่ฟ™ไธช", "้‚ฃไธช", "ไป€ไนˆ", "ๆ€Žไนˆ", "ไธบไป€ไนˆ", "ๅ› ไธบ", "ๆ‰€ไปฅ", + "ๅฏไปฅ", "่ƒฝๅคŸ", "ๅบ”่ฏฅ", "ๅฟ…้กป", "้œ€่ฆ", "ๅธŒๆœ›", "ๆƒณ่ฆ", "็Ÿฅ้“", "่ฎคไธบ", "่ง‰ๅพ—", + // Tech terms + "่ถ…ๅธ‚", "่‡ช็„ถ", "่ฏญ่จ€", "ๅค„็†", "ๆต‹่ฏ•", "ๆ•ฐๆฎ", "ๅผ€ๅ‘", "ๅทฅ็จ‹ๅธˆ", "็จ‹ๅบๅ‘˜", + "ไบบๅทฅๆ™บ่ƒฝ", "ๆœบๅ™จๅญฆไน ", "ๆทฑๅบฆๅญฆไน ", "็ฅž็ป็ฝ‘็ปœ", "ๆจกๅž‹", "ไปฃ็ ", "้€ป่พ‘", + "็”จๆˆท", "็•Œ้ข", "่ฎพ่ฎก", "็ณป็ปŸ", "ๅˆ†ๆž", "ๆ•ฐๆฎๅบ“", "ๆœๅŠกๅ™จ", "ๅฎขๆˆท็ซฏ", + "่ฝฏไปถ", "็กฌไปถ", "็ฎ—ๆณ•", "ๅ‡ฝๆ•ฐ", "ๅ˜้‡", "ๅฏน่ฑก", "็ฑปๅž‹", "ๆŽฅๅฃ", "ๅฎž็Žฐ", + // Place names + "ๅŒ—ไบฌ", "ไธŠๆตท", "ๅนฟๅทž", "ๆทฑๅœณ", "ๆญๅทž", "ๅ—ไบฌ", "ๅคฉๆดฅ", "้‡ๅบ†", "ๆˆ้ƒฝ", "ๆญฆๆฑ‰", + "ๅ—ไบฌๅธ‚", "้•ฟๆฑŸ", "ๅคงๆกฅ", "้ป„ๆฒณ", "้•ฟๅŸŽ", "ๆ•…ๅฎซ", + // Common nouns + "ๅ…ฌๅธ", "ๅญฆๆ ก", "ๅŒป้™ข", "้“ถ่กŒ", "ๅ•†ๅบ—", "้คๅŽ…", "้…’ๅบ—", "ๆœบๅœบ", "็ซ่ฝฆ็ซ™", + "็”ต่„‘", "ๆ‰‹ๆœบ", "็ฝ‘็ปœ", "ไบ’่”็ฝ‘", "็”ตๅญ", "็ง‘ๆŠ€", "ๆŠ€ๆœฏ", "ไบงๅ“", "้กน็›ฎ", + // Time words + "ไปŠๅคฉ", "ๆ˜Žๅคฉ", "ๆ˜จๅคฉ", "็Žฐๅœจ", "ไปฅๅŽ", "ไปฅๅ‰", "ๅฐ†ๆฅ", "่ฟ‡ๅŽป", "ๅนดๆœˆ", "ๆ—ฅๆœŸ" + ) + private const val MAX_LEN = 5 + + fun segment(text: String): List { + val fmm = forwardMaxMatch(text) + val rmm = reverseMaxMatch(text) + + // Heuristic: Prefer fewer tokens (implies longer words matched) + return if (fmm.size <= rmm.size) fmm else rmm + } + + private fun forwardMaxMatch(text: String): List { + val tokens = mutableListOf() + var i = 0 + while (i < text.length) { + var matched = false + // Try window sizes from MAX_LEN down to 1 + for (len in minOf(MAX_LEN, text.length - i) downTo 1) { + val sub = text.substring(i, i + len) + if (len == 1 || dictionary.contains(sub)) { + tokens.add(Token(sub, TokenType.CHINESE)) + i += len + matched = true + break + } + } + if (!matched) i++ // Should not happen due to len=1 fallback + } + return tokens + } + + private fun reverseMaxMatch(text: String): List { + val tokens = mutableListOf() + var i = text.length + while (i > 0) { + var matched = false + for (len in minOf(MAX_LEN, i) downTo 1) { + val sub = text.substring(i - len, i) + if (len == 1 || dictionary.contains(sub)) { + tokens.add(0, Token(sub, TokenType.CHINESE)) + i -= len + matched = true + break + } + } + if (!matched) i-- + } + return tokens + } + } + + internal object StopWords { + val ENGLISH = setOf( + "the", "is", "at", "which", "on", "and", "a", "an", "in", "to", "of", "for", + "it", "that", "this", "by", "from", "be", "or", "as", "with", "are", "was", + "were", "been", "being", "have", "has", "had", "do", "does", "did", "will", + "would", "could", "should", "may", "might", "must", "shall", "can", "need", + "get", "all", "me", "my", "show", "most", "top", "up", "down", "out" + ) + // Minimal set for Chinese + val CHINESE = setOf( + "็š„", "ไบ†", "ๅ’Œ", "ๆ˜ฏ", "ๅฐฑ", "้ƒฝ", "่€Œ", "ๅŠ", "ไธŽ", "็€", "ๆˆ–", "ไธ€ไธช", "ๆฒกๆœ‰", "ๆˆ‘ไปฌ", + "ไธ", "ไนŸ", "ๅพˆ", "ๅœจ", "ๆœ‰", "่ฟ™", "้‚ฃ", "ไป–", "ๅฅน", "ๅฎƒ", "ไปฌ", "ๅ—", "ๅ‘ข", "ๅง" + ) + } +} diff --git a/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/chatdb/KeywordSchemaLinker.kt b/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/chatdb/KeywordSchemaLinker.kt new file mode 100644 index 0000000000..dede9b34e4 --- /dev/null +++ b/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/chatdb/KeywordSchemaLinker.kt @@ -0,0 +1,121 @@ +package cc.unitmesh.agent.chatdb + +import cc.unitmesh.agent.database.DatabaseSchema +import cc.unitmesh.agent.database.TableSchema + +/** + * Keyword-based Schema Linker - Uses keyword matching and fuzzy matching + * + * This is the default/fallback implementation that doesn't require LLM calls. + * It extracts keywords from the query and matches them against table/column names. + */ +class KeywordSchemaLinker : SchemaLinker() { + + /** + * Link natural language query to relevant schema elements + */ + override suspend fun link(query: String, schema: DatabaseSchema): SchemaLinkingResult { + val keywords = extractKeywords(query) + val relevantTables = mutableListOf() + val relevantColumns = mutableListOf() + var totalScore = 0.0 + var matchCount = 0 + + for (table in schema.tables) { + val tableScore = calculateTableRelevance(table, keywords) + if (tableScore > 0) { + relevantTables.add(table.name) + totalScore += tableScore + matchCount++ + + // Find relevant columns in this table + for (column in table.columns) { + val columnScore = calculateColumnRelevance(column.name, column.comment, keywords) + if (columnScore > 0) { + relevantColumns.add("${table.name}.${column.name}") + } + } + } + } + + // If no tables matched, include all tables (fallback) + if (relevantTables.isEmpty()) { + relevantTables.addAll(schema.tables.map { it.name }) + } + + val confidence = if (matchCount > 0) (totalScore / matchCount).coerceIn(0.0, 1.0) else 0.0 + + return SchemaLinkingResult( + relevantTables = relevantTables, + relevantColumns = relevantColumns, + keywords = keywords, + confidence = confidence + ) + } + + /** + * Extract keywords from natural language query using platform-specific NLP tokenization. + * + * On JVM, this uses MyNLP for proper Chinese word segmentation. + * On other platforms, this falls back to simple regex-based tokenization. + */ + override suspend fun extractKeywords(query: String): List { + return NlpTokenizer.extractKeywords(query, STOP_WORDS) + } + + /** + * Calculate relevance score for a table + */ + private fun calculateTableRelevance(table: TableSchema, keywords: List): Double { + var score = 0.0 + val tableName = table.name.lowercase() + val tableComment = table.comment?.lowercase() ?: "" + + for (keyword in keywords) { + // Exact match in table name + if (tableName == keyword) { + score += 1.0 + } + // Partial match in table name + else if (tableName.contains(keyword) || keyword.contains(tableName)) { + score += 0.7 + } + // Match in table comment + else if (tableComment.contains(keyword)) { + score += 0.5 + } + // Fuzzy match (Levenshtein distance) + else if (fuzzyMatch(tableName, keyword)) { + score += 0.3 + } + + // Check column names + for (column in table.columns) { + val colName = column.name.lowercase() + if (colName == keyword || colName.contains(keyword)) { + score += 0.4 + } + } + } + + return score + } + + /** + * Calculate relevance score for a column + */ + private fun calculateColumnRelevance(columnName: String, comment: String?, keywords: List): Double { + var score = 0.0 + val colName = columnName.lowercase() + val colComment = comment?.lowercase() ?: "" + + for (keyword in keywords) { + if (colName == keyword) score += 1.0 + else if (colName.contains(keyword)) score += 0.7 + else if (colComment.contains(keyword)) score += 0.5 + else if (fuzzyMatch(colName, keyword)) score += 0.3 + } + + return score + } +} \ No newline at end of file diff --git a/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/chatdb/LlmSchemaLinker.kt b/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/chatdb/LlmSchemaLinker.kt new file mode 100644 index 0000000000..2ade26623f --- /dev/null +++ b/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/chatdb/LlmSchemaLinker.kt @@ -0,0 +1,218 @@ +package cc.unitmesh.agent.chatdb + +import cc.unitmesh.agent.database.DatabaseConnection +import cc.unitmesh.agent.database.DatabaseSchema +import cc.unitmesh.llm.KoogLLMService +import kotlinx.serialization.Serializable +import kotlinx.serialization.json.Json + +/** + * LLM-based Schema Linker - Uses LLM to extract keywords and link schema + * + * This implementation uses the LLM to: + * 1. Extract semantic keywords from the natural language query + * 2. Map natural language terms to database schema elements + * 3. Use sample data for better value matching (RSL-SQL approach) + * + * Falls back to KeywordSchemaLinker if LLM fails. + * + * Based on research from: + * - RSL-SQL: Robust Schema Linking in Text-to-SQL Generation + * - A Survey of NL2SQL with Large Language Models + */ +class LlmSchemaLinker( + private val llmService: KoogLLMService, + private val databaseConnection: DatabaseConnection? = null, + private val fallbackLinker: KeywordSchemaLinker = KeywordSchemaLinker() +) : SchemaLinker() { + + private val json = Json { ignoreUnknownKeys = true } + + companion object { + private const val KEYWORD_EXTRACTION_PROMPT = """You are a database schema expert. Extract keywords from the user's natural language query that are relevant for finding database tables and columns. + +Given the user query, extract: +1. Entity names (nouns that might be table names) +2. Attribute names (properties that might be column names) +3. Semantic synonyms (alternative terms for the same concept) + +Respond ONLY with a JSON object in this exact format: +{"keywords": ["keyword1", "keyword2", ...], "entities": ["entity1", ...], "attributes": ["attr1", ...]} + +User Query: """ + + // Enhanced Schema Linking prompt based on RSL-SQL research + private const val SCHEMA_LINKING_PROMPT = """You are a database schema expert. Given a user query and database schema with sample data, identify the most relevant tables and columns. + +CRITICAL RULES: +1. You MUST only use table and column names that EXACTLY exist in the provided schema +2. Do NOT invent or hallucinate table/column names +3. Look at the sample data to understand what each table contains +4. Match user's intent with the actual table/column names, not assumed names + +Database Schema with Sample Data: +{{SCHEMA}} + +User Query: {{QUERY}} + +Respond ONLY with a JSON object in this exact format: +{"tables": ["table1", "table2"], "columns": ["table1.column1", "table2.column2"], "confidence": 0.8} + +Only include tables and columns that are directly relevant to answering the query.""" + } + + @Serializable + private data class KeywordExtractionResult( + val keywords: List = emptyList(), + val entities: List = emptyList(), + val attributes: List = emptyList() + ) + + @Serializable + private data class SchemaLinkingLlmResult( + val tables: List = emptyList(), + val columns: List = emptyList(), + val confidence: Double = 0.0 + ) + + /** + * Link natural language query to relevant schema elements using LLM + */ + override suspend fun link(query: String, schema: DatabaseSchema): SchemaLinkingResult { + return try { + // Build schema description for LLM + val schemaDescription = buildSchemaDescription(schema) + + // Ask LLM to identify relevant tables and columns + val prompt = SCHEMA_LINKING_PROMPT + .replace("{{SCHEMA}}", schemaDescription) + .replace("{{QUERY}}", query) + val response = llmService.sendPrompt(prompt) + + // Parse LLM response + val llmResult = parseSchemaLinkingResponse(response) + + // Validate that tables/columns exist in schema + val validTables = llmResult.tables.filter { tableName -> + schema.tables.any { it.name.equals(tableName, ignoreCase = true) } + } + val validColumns = llmResult.columns.filter { colRef -> + val parts = colRef.split(".") + if (parts.size == 2) { + val table = schema.tables.find { it.name.equals(parts[0], ignoreCase = true) } + table?.columns?.any { it.name.equals(parts[1], ignoreCase = true) } == true + } else false + } + + // Extract keywords for the result + val keywords = extractKeywords(query) + + // If LLM didn't find valid tables, fall back to keyword linker + if (validTables.isEmpty()) { + return fallbackLinker.link(query, schema) + } + + SchemaLinkingResult( + relevantTables = validTables, + relevantColumns = validColumns, + keywords = keywords, + confidence = llmResult.confidence + ) + } catch (e: Exception) { + // Fall back to keyword-based linking on any error + fallbackLinker.link(query, schema) + } + } + + /** + * Extract keywords from natural language query using LLM + */ + override suspend fun extractKeywords(query: String): List { + return try { + val prompt = KEYWORD_EXTRACTION_PROMPT + query + val response = llmService.sendPrompt(prompt) + + val result = parseKeywordExtractionResponse(response) + (result.keywords + result.entities + result.attributes).distinct() + } catch (e: Exception) { + // Fall back to simple keyword extraction + fallbackLinker.extractKeywords(query) + } + } + + /** + * Build schema description with sample data for better Schema Linking + * Based on RSL-SQL research: sample data helps LLM understand table semantics + */ + private suspend fun buildSchemaDescription(schema: DatabaseSchema): String { + val tableDescriptions = mutableListOf() + + for (table in schema.tables) { + val description = buildString { + appendLine("Table: ${table.name}") + + // Column information + val columns = table.columns.joinToString(", ") { col -> + "${col.name} (${col.type}${if (col.isPrimaryKey) ", PK" else ""})" + } + appendLine("Columns: $columns") + + // Add sample data if database connection is available + if (databaseConnection != null) { + try { + val sampleRows = databaseConnection.getSampleRows(table.name, 2) + if (!sampleRows.isEmpty()) { + appendLine("Sample Data:") + appendLine(" ${sampleRows.columns.joinToString(" | ")}") + sampleRows.rows.take(2).forEach { row -> + appendLine(" ${row.joinToString(" | ") { it.take(30) }}") + } + } + } catch (e: Exception) { + // Ignore sample data errors + } + } + }.trim() + tableDescriptions.add(description) + } + + return tableDescriptions.joinToString("\n\n") + } + + /** + * Build schema description without sample data (synchronous version) + */ + private fun buildSchemaDescriptionSync(schema: DatabaseSchema): String { + return schema.tables.joinToString("\n\n") { table -> + val columns = table.columns.joinToString(", ") { col -> + "${col.name} (${col.type}${if (col.isPrimaryKey) ", PK" else ""})" + } + "Table: ${table.name}\nColumns: $columns" + } + } + + private fun parseKeywordExtractionResponse(response: String): KeywordExtractionResult { + val jsonStr = extractJsonFromResponse(response) + return try { + json.decodeFromString(jsonStr) + } catch (e: Exception) { + KeywordExtractionResult() + } + } + + private fun parseSchemaLinkingResponse(response: String): SchemaLinkingLlmResult { + val jsonStr = extractJsonFromResponse(response) + return try { + json.decodeFromString(jsonStr) + } catch (e: Exception) { + SchemaLinkingLlmResult() + } + } + + private fun extractJsonFromResponse(response: String): String { + // Try to find JSON object in the response + val jsonPattern = Regex("""\{[^{}]*\}""") + return jsonPattern.find(response)?.value ?: "{}" + } +} + diff --git a/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/chatdb/MultiDatabaseChatDBAgent.kt b/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/chatdb/MultiDatabaseChatDBAgent.kt new file mode 100644 index 0000000000..36adfff158 --- /dev/null +++ b/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/chatdb/MultiDatabaseChatDBAgent.kt @@ -0,0 +1,219 @@ +package cc.unitmesh.agent.chatdb + +import cc.unitmesh.agent.config.McpToolConfigService +import cc.unitmesh.agent.core.MainAgent +import cc.unitmesh.agent.database.DatabaseConfig +import cc.unitmesh.agent.database.DatabaseConnection +import cc.unitmesh.agent.database.DatabaseSchema +import cc.unitmesh.agent.database.createDatabaseConnection +import cc.unitmesh.agent.logging.getLogger +import cc.unitmesh.agent.model.AgentDefinition +import cc.unitmesh.agent.model.PromptConfig +import cc.unitmesh.agent.model.RunConfig +import cc.unitmesh.agent.orchestrator.ToolOrchestrator +import cc.unitmesh.agent.policy.DefaultPolicyEngine +import cc.unitmesh.agent.render.CodingAgentRenderer +import cc.unitmesh.agent.render.DefaultCodingAgentRenderer +import cc.unitmesh.agent.tool.shell.DefaultShellExecutor +import cc.unitmesh.agent.tool.shell.ShellExecutor +import cc.unitmesh.agent.tool.ToolResult +import cc.unitmesh.agent.tool.filesystem.DefaultToolFileSystem +import cc.unitmesh.agent.tool.filesystem.ToolFileSystem +import cc.unitmesh.agent.tool.registry.ToolRegistry +import cc.unitmesh.llm.KoogLLMService +import cc.unitmesh.llm.ModelConfig + +/** + * Multi-Database ChatDB Agent - Text2SQL Agent supporting multiple database connections + * + * This agent converts natural language queries to SQL across multiple databases. + * It merges schemas from all connected databases and lets the LLM decide which + * database(s) to query based on the user's question. + * + * Features: + * - Multi-Database Schema Linking: Merges schemas from all databases with database prefixes + * - Intelligent Database Selection: LLM determines which database to query + * - Parallel Execution: Can execute queries on multiple databases simultaneously + * - Unified Results: Combines results from multiple databases + */ +class MultiDatabaseChatDBAgent( + private val projectPath: String, + private val llmService: KoogLLMService, + private val databaseConfigs: Map, + override val maxIterations: Int = 10, + private val renderer: CodingAgentRenderer = DefaultCodingAgentRenderer(), + private val fileSystem: ToolFileSystem? = null, + private val shellExecutor: ShellExecutor? = null, + private val mcpToolConfigService: McpToolConfigService, + private val enableLLMStreaming: Boolean = true +) : MainAgent( + AgentDefinition( + name = "MultiDatabaseChatDBAgent", + displayName = "Multi-Database ChatDB Agent", + description = "Text2SQL Agent that queries across multiple databases with intelligent database selection", + promptConfig = PromptConfig( + systemPrompt = MULTI_DB_SYSTEM_PROMPT + ), + modelConfig = ModelConfig.default(), + runConfig = RunConfig(maxTurns = 10, maxTimeMinutes = 5) + ) +) { + private val logger = getLogger("MultiDatabaseChatDBAgent") + + private val actualFileSystem = fileSystem ?: DefaultToolFileSystem(projectPath = projectPath) + + private val toolRegistry = ToolRegistry( + fileSystem = actualFileSystem, + shellExecutor = shellExecutor ?: DefaultShellExecutor(), + configService = mcpToolConfigService, + llmService = llmService + ) + + private val policyEngine = DefaultPolicyEngine() + + private val toolOrchestrator = ToolOrchestrator( + registry = toolRegistry, + policyEngine = policyEngine, + renderer = renderer, + mcpConfigService = mcpToolConfigService + ) + + // Database connections keyed by database name/id + private val databaseConnections: MutableMap = mutableMapOf() + + private val executor: MultiDatabaseChatDBExecutor by lazy { + // Create connections for all configured databases + databaseConfigs.forEach { (id, config) -> + if (!databaseConnections.containsKey(id)) { + databaseConnections[id] = createDatabaseConnection(config) + } + } + + MultiDatabaseChatDBExecutor( + projectPath = projectPath, + llmService = llmService, + toolOrchestrator = toolOrchestrator, + renderer = renderer, + databaseConnections = databaseConnections, + databaseConfigs = databaseConfigs, + maxIterations = maxIterations, + enableLLMStreaming = enableLLMStreaming + ) + } + + override fun validateInput(input: Map): ChatDBTask { + val query = input["query"] as? String + ?: throw IllegalArgumentException("Missing required parameter: query") + + return ChatDBTask( + query = query, + additionalContext = input["additionalContext"] as? String ?: "", + maxRows = (input["maxRows"] as? Number)?.toInt() ?: 100, + generateVisualization = input["generateVisualization"] as? Boolean ?: true + ) + } + + override suspend fun execute( + input: ChatDBTask, + onProgress: (String) -> Unit + ): ToolResult.AgentResult { + logger.info { "Starting Multi-Database ChatDB Agent for query: ${input.query}" } + logger.info { "Connected databases: ${databaseConfigs.keys}" } + + val systemPrompt = buildSystemPrompt() + val result = executor.execute(input, systemPrompt, onProgress) + + return ToolResult.AgentResult( + success = result.success, + content = result.message, + metadata = mapOf( + "generatedSql" to (result.generatedSql ?: ""), + "rowCount" to (result.queryResult?.rowCount?.toString() ?: "0"), + "revisionAttempts" to result.revisionAttempts.toString(), + "hasVisualization" to (result.plotDslCode != null).toString(), + "targetDatabases" to (result.targetDatabases?.joinToString(",") ?: "") + ) + ) + } + + private fun buildSystemPrompt(): String { + return MULTI_DB_SYSTEM_PROMPT + } + + override fun formatOutput(output: ToolResult.AgentResult): String { + return output.content + } + + override fun getParameterClass(): String = "ChatDBTask" + + /** + * Close all database connections when done + */ + suspend fun close() { + databaseConnections.values.forEach { it.close() } + databaseConnections.clear() + } + + companion object { + const val MULTI_DB_SYSTEM_PROMPT = """You are an expert SQL developer working with MULTIPLE databases. + +IMPORTANT: You are connected to multiple databases. Each table in the schema is prefixed with its database name. +Format: [database_name].table_name + +CRITICAL RULES: +1. ONLY use table names provided in the schema - NEVER invent or guess table names +2. ONLY use column names provided in the schema - NEVER invent or guess column names +3. When generating SQL, use the table name WITHOUT the database prefix (the system will route to the correct database) +4. If the user's question relates to tables in multiple databases, generate SEPARATE SQL queries for each database +5. Always add LIMIT clause for SELECT queries to prevent large result sets + +SUPPORTED OPERATIONS: +- SELECT: Read data (no approval required) +- INSERT: Add new records (requires user approval) +- UPDATE: Modify existing records (requires user approval) +- DELETE: Remove records (requires user approval, HIGH RISK) +- CREATE TABLE: Create new tables (requires user approval) +- ALTER TABLE: Modify table structure (requires user approval, HIGH RISK) +- DROP TABLE: Delete tables (requires user approval, HIGH RISK) +- TRUNCATE: Remove all records (requires user approval, HIGH RISK) + +โš ๏ธ WRITE OPERATIONS WARNING: +- All write operations (INSERT, UPDATE, DELETE, CREATE, ALTER, DROP, TRUNCATE) require explicit user approval before execution +- HIGH RISK operations (DELETE, DROP, TRUNCATE, ALTER) will be highlighted with additional warnings +- Always confirm with the user before generating destructive SQL + +OUTPUT FORMAT: +- For single database query: +```sql +-- database: +SELECT id, name FROM users WHERE status = 'active' LIMIT 100; +``` + +- For multiple database queries: +```sql +-- database: db1 +SELECT * FROM users LIMIT 100; +``` +```sql +-- database: db2 +SELECT * FROM customers LIMIT 100; +``` + +- For write operations: +```sql +-- database: +INSERT INTO users (name, email) VALUES ('John', 'john@example.com'); +``` + +```sql +-- database: +CREATE TABLE new_table ( + id INT PRIMARY KEY, + name VARCHAR(100) +); +``` + +The "-- database: " comment is REQUIRED to specify which database to execute the query on.""" + } +} + diff --git a/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/chatdb/MultiDatabaseChatDBExecutor.kt b/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/chatdb/MultiDatabaseChatDBExecutor.kt new file mode 100644 index 0000000000..f24aecbcbc --- /dev/null +++ b/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/chatdb/MultiDatabaseChatDBExecutor.kt @@ -0,0 +1,1146 @@ +package cc.unitmesh.agent.chatdb + +import cc.unitmesh.agent.conversation.ConversationManager +import cc.unitmesh.agent.database.* +import cc.unitmesh.agent.executor.BaseAgentExecutor +import cc.unitmesh.agent.logging.getLogger +import cc.unitmesh.agent.orchestrator.ToolOrchestrator +import cc.unitmesh.agent.render.ChatDBStepStatus +import cc.unitmesh.agent.render.ChatDBStepType +import cc.unitmesh.agent.render.CodingAgentRenderer +import cc.unitmesh.agent.subagent.PlotDSLAgent +import cc.unitmesh.agent.subagent.PlotDSLContext +import cc.unitmesh.agent.subagent.SqlOperationType +import cc.unitmesh.agent.subagent.SqlValidator +import cc.unitmesh.agent.subagent.SqlReviseAgent +import cc.unitmesh.agent.subagent.SqlRevisionInput +import cc.unitmesh.devins.parser.CodeFence +import cc.unitmesh.llm.KoogLLMService +import kotlinx.coroutines.suspendCancellableCoroutine +import kotlin.coroutines.resume + +/** + * Multi-Database ChatDB Executor - Executes Text2SQL across multiple databases + * + * Key differences from single-database executor: + * 1. Merges schemas from all databases with database name prefixes + * 2. Parses SQL comments to determine target database + * 3. Can execute queries on multiple databases in parallel + * 4. Combines results from multiple databases + * + * Flow: + * 1. FETCH_SCHEMA - Fetch schemas from all databases + * 2. SCHEMA_LINKING - Use LLM/Keyword linker to find relevant tables + * 3. GENERATE_SQL - Generate SQL with database routing comments + * 4. VALIDATE_SQL - Validate SQL syntax and table names + * 5. REVISE_SQL - Fix SQL if validation fails + * 6. EXECUTE_SQL - Execute on target databases with retry + * 7. GENERATE_VISUALIZATION - Optional visualization + * 8. FINAL_RESULT - Return combined results + */ +class MultiDatabaseChatDBExecutor( + projectPath: String, + llmService: KoogLLMService, + toolOrchestrator: ToolOrchestrator, + renderer: CodingAgentRenderer, + private val databaseConnections: Map, + private val databaseConfigs: Map, + maxIterations: Int = 10, + enableLLMStreaming: Boolean = true +) : BaseAgentExecutor( + projectPath = projectPath, + llmService = llmService, + toolOrchestrator = toolOrchestrator, + renderer = renderer, + maxIterations = maxIterations, + enableLLMStreaming = enableLLMStreaming +) { + private val logger = getLogger("MultiDatabaseChatDBExecutor") + private val keywordSchemaLinker = KeywordSchemaLinker() + private val sqlValidator = SqlValidator() + private val sqlReviseAgent = SqlReviseAgent(llmService, sqlValidator) + private val plotDSLAgent = PlotDSLAgent(llmService) + private val maxRevisionAttempts = 3 + private val maxExecutionRetries = 3 + + // Cache for merged schema + private var mergedSchema: MergedDatabaseSchema? = null + + suspend fun execute( + task: ChatDBTask, + systemPrompt: String, + onProgress: (String) -> Unit = {} + ): MultiDatabaseChatDBResult { + currentIteration = 0 + conversationManager = ConversationManager(llmService, systemPrompt) + + val errors = mutableListOf() + var generatedSql: String? = null + val queryResults = mutableMapOf() + var revisionAttempts = 0 + val targetDatabases = mutableListOf() + var plotDslCode: String? = null + + try { + // Step 1: Fetch and merge schemas from all databases + renderer.renderChatDBStep( + stepType = ChatDBStepType.FETCH_SCHEMA, + status = ChatDBStepStatus.IN_PROGRESS, + title = "Fetching schemas from ${databaseConnections.size} databases..." + ) + onProgress("๐Ÿ“Š Fetching schemas from ${databaseConnections.size} databases...") + + val merged = fetchAndMergeSchemas() + mergedSchema = merged + + renderer.renderChatDBStep( + stepType = ChatDBStepType.FETCH_SCHEMA, + status = ChatDBStepStatus.SUCCESS, + title = "Schemas fetched from ${merged.databases.size} databases", + details = mapOf( + "databases" to merged.databases.map { (name, schema) -> + mapOf( + "name" to name, + "displayName" to (databaseConfigs[name]?.databaseName ?: name), + "tableCount" to schema.tables.size, + "tables" to schema.tables.map { it.name } + ) + }, + "totalTables" to merged.totalTableCount, + "tableSchemas" to merged.databases.flatMap { (dbName, schema) -> + schema.tables.map { table -> + mapOf( + "name" to "[$dbName].${table.name}", + "comment" to (table.comment ?: ""), + "columns" to table.columns.map { col -> + mapOf( + "name" to col.name, + "type" to col.type, + "nullable" to col.nullable, + "isPrimaryKey" to col.isPrimaryKey, + "isForeignKey" to col.isForeignKey, + "comment" to (col.comment ?: "") + ) + } + ) + } + } + ) + ) + + // Step 2: Schema Linking - Find relevant tables using keyword linker + renderer.renderChatDBStep( + stepType = ChatDBStepType.SCHEMA_LINKING, + status = ChatDBStepStatus.IN_PROGRESS, + title = "Performing schema linking across ${merged.databases.size} databases..." + ) + onProgress("๐Ÿ”— Performing schema linking...") + + // Perform schema linking for each database + val linkingResults = mutableMapOf() + val allRelevantTables = mutableListOf>() + val allKeywords = mutableSetOf() + + for ((dbId, schema) in merged.databases) { + val linkingResult = keywordSchemaLinker.link(task.query, schema) + linkingResults[dbId] = linkingResult + allKeywords.addAll(linkingResult.keywords) + + // Collect relevant table schemas for UI + linkingResult.relevantTables.forEach { tableName -> + schema.getTable(tableName)?.let { table -> + allRelevantTables.add(mapOf( + "name" to "[$dbId].${table.name}", + "comment" to (table.comment ?: ""), + "columns" to table.columns.map { col -> + mapOf( + "name" to col.name, + "type" to col.type, + "nullable" to col.nullable, + "isPrimaryKey" to col.isPrimaryKey, + "isForeignKey" to col.isForeignKey + ) + } + )) + } + } + } + + val schemaContext = buildMultiDatabaseSchemaContext(merged, task.query, linkingResults) + + renderer.renderChatDBStep( + stepType = ChatDBStepType.SCHEMA_LINKING, + status = ChatDBStepStatus.SUCCESS, + title = "Schema linking complete - found ${allRelevantTables.size} relevant tables", + details = mapOf( + "databasesAnalyzed" to merged.databases.keys.toList(), + "keywords" to allKeywords.toList(), + "relevantTableSchemas" to allRelevantTables, + "schemaContext" to schemaContext.take(500) + if (schemaContext.length > 500) "..." else "" + ) + ) + + // Step 3: Generate SQL with multi-database context + renderer.renderChatDBStep( + stepType = ChatDBStepType.GENERATE_SQL, + status = ChatDBStepStatus.IN_PROGRESS, + title = "Generating SQL query..." + ) + onProgress("๐Ÿค– Generating SQL query...") + + val sqlPrompt = buildMultiDatabaseSqlPrompt(task.query, schemaContext, task.maxRows) + val sqlResponse = getLLMResponse(sqlPrompt, compileDevIns = false) { chunk -> + onProgress(chunk) + } + + // Parse SQL blocks with database targets + var sqlBlocks = parseSqlBlocksWithTargets(sqlResponse) + + if (sqlBlocks.isEmpty()) { + renderer.renderChatDBStep( + stepType = ChatDBStepType.GENERATE_SQL, + status = ChatDBStepStatus.ERROR, + title = "Failed to extract SQL", + error = "Could not find SQL code block in LLM response" + ) + throw DatabaseException("No valid SQL generated") + } + + generatedSql = sqlBlocks.joinToString("\n\n") { "-- database: ${it.database}\n${it.sql}" } + targetDatabases.addAll(sqlBlocks.map { it.database }.distinct()) + + renderer.renderChatDBStep( + stepType = ChatDBStepType.GENERATE_SQL, + status = ChatDBStepStatus.SUCCESS, + title = "SQL generated for ${targetDatabases.size} database(s)", + details = mapOf( + "targetDatabases" to targetDatabases, + "sqlBlocks" to sqlBlocks.map { mapOf("database" to it.database, "sql" to it.sql) } + ) + ) + + // Step 4: Validate SQL for each database + renderer.renderChatDBStep( + stepType = ChatDBStepType.VALIDATE_SQL, + status = ChatDBStepStatus.IN_PROGRESS, + title = "Validating SQL..." + ) + onProgress("๐Ÿ” Validating SQL...") + + val validatedBlocks = mutableListOf() + var hasValidationErrors = false + + for (block in sqlBlocks) { + val dbSchema = merged.databases[block.database] + val allTableNames = dbSchema?.tables?.map { it.name }?.toSet() ?: emptySet() + + val syntaxValidation = sqlValidator.validate(block.sql) + val tableValidation = if (syntaxValidation.isValid) { + sqlValidator.validateWithTableWhitelist(block.sql, allTableNames) + } else { + syntaxValidation + } + + if (!tableValidation.isValid) { + hasValidationErrors = true + val errorType = if (!syntaxValidation.isValid) "syntax" else "table name" + + renderer.renderChatDBStep( + stepType = ChatDBStepType.VALIDATE_SQL, + status = ChatDBStepStatus.ERROR, + title = "SQL validation failed for ${block.database}", + details = mapOf( + "database" to block.database, + "errorType" to errorType, + "errors" to tableValidation.errors + ), + error = tableValidation.errors.joinToString("; ") + ) + + // Step 5: Revise SQL + onProgress("๐Ÿ”„ SQL validation failed ($errorType), invoking SqlReviseAgent...") + + renderer.renderChatDBStep( + stepType = ChatDBStepType.REVISE_SQL, + status = ChatDBStepStatus.IN_PROGRESS, + title = "Revising SQL for ${block.database}..." + ) + + val relevantSchema = buildSchemaDescriptionForDatabase(block.database, merged) + val revisionInput = SqlRevisionInput( + originalQuery = task.query, + failedSql = block.sql, + errorMessage = tableValidation.errors.joinToString("; "), + schemaDescription = relevantSchema, + maxAttempts = maxRevisionAttempts + ) + + val revisionResult = sqlReviseAgent.execute(revisionInput) { progress -> + onProgress(progress) + } + + revisionAttempts += revisionResult.metadata["attempts"]?.toIntOrNull() ?: 0 + + if (revisionResult.success) { + validatedBlocks.add(SqlBlock(block.database, revisionResult.content)) + renderer.renderChatDBStep( + stepType = ChatDBStepType.REVISE_SQL, + status = ChatDBStepStatus.SUCCESS, + title = "SQL revised successfully for ${block.database}", + details = mapOf( + "database" to block.database, + "attempts" to revisionAttempts, + "sql" to revisionResult.content + ) + ) + onProgress("โœ… SQL revised successfully") + } else { + renderer.renderChatDBStep( + stepType = ChatDBStepType.REVISE_SQL, + status = ChatDBStepStatus.ERROR, + title = "SQL revision failed for ${block.database}", + error = revisionResult.content + ) + errors.add("[${block.database}] SQL revision failed: ${revisionResult.content}") + } + } else { + validatedBlocks.add(block) + } + } + + if (!hasValidationErrors) { + renderer.renderChatDBStep( + stepType = ChatDBStepType.VALIDATE_SQL, + status = ChatDBStepStatus.SUCCESS, + title = "SQL validation passed for all databases" + ) + } + + sqlBlocks = validatedBlocks + generatedSql = sqlBlocks.joinToString("\n\n") { "-- database: ${it.database}\n${it.sql}" } + + // Step 6: Execute SQL on target databases with retry + for (sqlBlock in sqlBlocks) { + val dbName = sqlBlock.database + var sql = sqlBlock.sql + val connection = databaseConnections[dbName] + + if (connection == null) { + errors.add("Database '$dbName' not connected") + continue + } + + // Detect SQL operation type + val operationType = sqlValidator.detectSqlType(sql) + val isWriteOperation = operationType.requiresApproval() + val isHighRisk = operationType.isHighRisk() + + // If write operation, perform dry run first, then request approval + if (isWriteOperation) { + val affectedTables = extractTablesFromSql(sql, merged.databases[dbName]) + + // Step: Dry run to validate SQL before asking for approval + renderer.renderChatDBStep( + stepType = ChatDBStepType.DRY_RUN, + status = ChatDBStepStatus.IN_PROGRESS, + title = "Validating ${operationType.name} operation...", + details = mapOf( + "database" to dbName, + "operationType" to operationType.name, + "sql" to sql + ) + ) + onProgress("๐Ÿ” Performing dry run validation...") + + var dryRunResult = connection.dryRun(sql) + + if (!dryRunResult.isValid) { + // Dry run failed - SQL has errors + renderer.renderChatDBStep( + stepType = ChatDBStepType.DRY_RUN, + status = ChatDBStepStatus.ERROR, + title = "Dry run failed: SQL validation error", + details = mapOf( + "database" to dbName, + "operationType" to operationType.name, + "sql" to sql, + "errors" to dryRunResult.errors + ), + error = dryRunResult.message ?: dryRunResult.errors.firstOrNull() ?: "Unknown error" + ) + errors.add("[$dbName] Dry run failed: ${dryRunResult.message}") + + // Try to revise SQL based on dry run error + onProgress("๐Ÿ”„ SQL validation failed, attempting to revise...") + + renderer.renderChatDBStep( + stepType = ChatDBStepType.REVISE_SQL, + status = ChatDBStepStatus.IN_PROGRESS, + title = "Revising SQL based on validation error..." + ) + + val relevantSchema = buildSchemaDescriptionForDatabase(dbName, merged) + val revisionInput = SqlRevisionInput( + originalQuery = task.query, + failedSql = sql, + errorMessage = "Dry run error: ${dryRunResult.message}", + schemaDescription = relevantSchema, + maxAttempts = maxRevisionAttempts + ) + + val revisionResult = sqlReviseAgent.execute(revisionInput) { progress -> + onProgress(progress) + } + + if (revisionResult.success) { + sql = revisionResult.content + revisionAttempts++ + + renderer.renderChatDBStep( + stepType = ChatDBStepType.REVISE_SQL, + status = ChatDBStepStatus.SUCCESS, + title = "SQL revised successfully", + details = mapOf("database" to dbName, "sql" to sql) + ) + + // Re-run dry run with revised SQL + val revisedDryRunResult = connection.dryRun(sql) + if (!revisedDryRunResult.isValid) { + renderer.renderChatDBStep( + stepType = ChatDBStepType.DRY_RUN, + status = ChatDBStepStatus.ERROR, + title = "Revised SQL still has errors", + error = revisedDryRunResult.message + ) + errors.add("[$dbName] Revised SQL still invalid: ${revisedDryRunResult.message}") + continue + } + // Update dryRunResult with the successful revised result + dryRunResult = revisedDryRunResult + + // Render success for revised dry run + val estimatedInfo = if (dryRunResult.estimatedRows != null) { + " (estimated ${dryRunResult.estimatedRows} row(s) affected)" + } else "" + renderer.renderChatDBStep( + stepType = ChatDBStepType.DRY_RUN, + status = ChatDBStepStatus.SUCCESS, + title = "Revised SQL passed validation$estimatedInfo", + details = mapOf( + "database" to dbName, + "operationType" to operationType.name, + "sql" to sql, + "estimatedRows" to (dryRunResult.estimatedRows ?: "unknown"), + "warnings" to dryRunResult.warnings + ) + ) + onProgress("โœ… Revised SQL validation passed$estimatedInfo") + } else { + renderer.renderChatDBStep( + stepType = ChatDBStepType.REVISE_SQL, + status = ChatDBStepStatus.ERROR, + title = "SQL revision failed", + error = revisionResult.content + ) + continue + } + } else { + // Dry run succeeded + val estimatedInfo = if (dryRunResult.estimatedRows != null) { + " (estimated ${dryRunResult.estimatedRows} row(s) affected)" + } else "" + + renderer.renderChatDBStep( + stepType = ChatDBStepType.DRY_RUN, + status = ChatDBStepStatus.SUCCESS, + title = "Dry run passed$estimatedInfo", + details = mapOf( + "database" to dbName, + "operationType" to operationType.name, + "sql" to sql, + "estimatedRows" to (dryRunResult.estimatedRows ?: "unknown"), + "warnings" to dryRunResult.warnings + ) + ) + onProgress("โœ… Dry run validation passed$estimatedInfo") + } + + // Now request user approval (with the latest dryRunResult) + val approved = requestSqlApproval( + sql = sql, + operationType = operationType, + affectedTables = affectedTables, + isHighRisk = isHighRisk, + dryRunResult = dryRunResult, + onProgress = onProgress + ) + + if (!approved) { + renderer.renderChatDBStep( + stepType = ChatDBStepType.EXECUTE_WRITE, + status = ChatDBStepStatus.REJECTED, + title = "Write operation rejected by user", + details = mapOf( + "database" to dbName, + "operationType" to operationType.name, + "sql" to sql + ) + ) + errors.add("[$dbName] Write operation rejected by user: ${operationType.name}") + continue + } + + // Execute write operation + renderer.renderChatDBStep( + stepType = ChatDBStepType.EXECUTE_WRITE, + status = ChatDBStepStatus.APPROVED, + title = "Write operation approved, executing...", + details = mapOf( + "database" to dbName, + "operationType" to operationType.name, + "sql" to sql + ) + ) + onProgress("โœ… Write operation approved, executing...") + + try { + val updateResult = connection.executeUpdate(sql) + + if (updateResult.success) { + renderer.renderChatDBStep( + stepType = ChatDBStepType.EXECUTE_WRITE, + status = ChatDBStepStatus.SUCCESS, + title = "Write operation completed on $dbName", + details = mapOf( + "database" to dbName, + "operationType" to operationType.name, + "affectedRows" to updateResult.affectedRows, + "message" to (updateResult.message ?: "") + ) + ) + onProgress("โœ… ${operationType.name} completed: ${updateResult.affectedRows} row(s) affected") + + // Create a synthetic QueryResult for write operations + queryResults[dbName] = QueryResult( + columns = listOf("Operation", "Affected Rows", "Status"), + rows = listOf(listOf(operationType.name, updateResult.affectedRows.toString(), "Success")), + rowCount = 1 + ) + } else { + renderer.renderChatDBStep( + stepType = ChatDBStepType.EXECUTE_WRITE, + status = ChatDBStepStatus.ERROR, + title = "Write operation failed on $dbName", + error = updateResult.message ?: "Unknown error" + ) + errors.add("[$dbName] Write operation failed: ${updateResult.message}") + } + } catch (e: Exception) { + val errorMsg = e.message ?: "Unknown execution error" + renderer.renderChatDBStep( + stepType = ChatDBStepType.EXECUTE_WRITE, + status = ChatDBStepStatus.ERROR, + title = "Write operation failed on $dbName", + error = errorMsg + ) + errors.add("[$dbName] Write operation failed: $errorMsg") + } + continue + } + + // Regular SELECT query execution with retry + var executionRetries = 0 + var lastExecutionError: String? = null + var result: QueryResult? = null + + while (executionRetries < maxExecutionRetries && result == null) { + renderer.renderChatDBStep( + stepType = ChatDBStepType.EXECUTE_SQL, + status = ChatDBStepStatus.IN_PROGRESS, + title = "Executing on $dbName${if (executionRetries > 0) " (retry $executionRetries)" else ""}...", + details = mapOf( + "database" to dbName, + "sql" to sql, + "attempt" to (executionRetries + 1) + ) + ) + onProgress("โšก Executing SQL on $dbName${if (executionRetries > 0) " (retry $executionRetries)" else ""}...") + + try { + result = connection.executeQuery(sql) + queryResults[dbName] = result + + renderer.renderChatDBStep( + stepType = ChatDBStepType.EXECUTE_SQL, + status = ChatDBStepStatus.SUCCESS, + title = "Query executed on $dbName", + details = mapOf( + "database" to dbName, + "sql" to sql, + "rowCount" to result.rowCount, + "columns" to result.columns, + "previewRows" to result.rows.take(5) + ) + ) + onProgress("โœ… Query returned ${result.rowCount} rows from $dbName") + } catch (e: Exception) { + lastExecutionError = e.message ?: "Unknown execution error" + logger.warn { "Query execution failed on $dbName (attempt ${executionRetries + 1}): $lastExecutionError" } + + renderer.renderChatDBStep( + stepType = ChatDBStepType.EXECUTE_SQL, + status = ChatDBStepStatus.ERROR, + title = "Query execution failed on $dbName", + details = mapOf( + "database" to dbName, + "attempt" to (executionRetries + 1), + "maxAttempts" to maxExecutionRetries + ), + error = lastExecutionError + ) + + // Try to revise SQL based on execution error + if (executionRetries < maxExecutionRetries - 1) { + onProgress("๐Ÿ”„ Attempting to fix SQL based on execution error...") + + renderer.renderChatDBStep( + stepType = ChatDBStepType.REVISE_SQL, + status = ChatDBStepStatus.IN_PROGRESS, + title = "Revising SQL based on execution error..." + ) + + val relevantSchema = buildSchemaDescriptionForDatabase(dbName, merged) + val revisionInput = SqlRevisionInput( + originalQuery = task.query, + failedSql = sql, + errorMessage = "Execution error: $lastExecutionError", + schemaDescription = relevantSchema, + maxAttempts = 1 + ) + + val revisionResult = sqlReviseAgent.execute(revisionInput) { progress -> + onProgress(progress) + } + + if (revisionResult.success && revisionResult.content != sql) { + sql = revisionResult.content + revisionAttempts++ + + renderer.renderChatDBStep( + stepType = ChatDBStepType.REVISE_SQL, + status = ChatDBStepStatus.SUCCESS, + title = "SQL revised based on execution error", + details = mapOf("database" to dbName, "sql" to sql) + ) + onProgress("๐Ÿ”ง SQL revised, retrying execution...") + } else { + renderer.renderChatDBStep( + stepType = ChatDBStepType.REVISE_SQL, + status = ChatDBStepStatus.WARNING, + title = "SQL revision did not help", + error = "Revision did not produce a different SQL" + ) + break + } + } + executionRetries++ + } + } + + if (result == null && lastExecutionError != null) { + errors.add("[$dbName] Query execution failed after $executionRetries retries: $lastExecutionError") + } + } + + // Step 7: Generate visualization if requested + val combinedResult = combineResults(queryResults) + if (task.generateVisualization && combinedResult.rowCount > 0) { + renderer.renderChatDBStep( + stepType = ChatDBStepType.GENERATE_VISUALIZATION, + status = ChatDBStepStatus.IN_PROGRESS, + title = "Generating visualization..." + ) + onProgress("๐Ÿ“ˆ Generating visualization...") + + plotDslCode = generateVisualization(task.query, combinedResult, onProgress) + + if (plotDslCode != null) { + renderer.renderChatDBStep( + stepType = ChatDBStepType.GENERATE_VISUALIZATION, + status = ChatDBStepStatus.SUCCESS, + title = "Visualization generated", + details = mapOf("code" to plotDslCode) + ) + } else { + renderer.renderChatDBStep( + stepType = ChatDBStepType.GENERATE_VISUALIZATION, + status = ChatDBStepStatus.WARNING, + title = "Visualization not generated" + ) + } + } + + // Step 8: Final result + val success = queryResults.isNotEmpty() + + val resultMessage = buildResultMessage( + success = success, + generatedSql = generatedSql, + queryResults = queryResults, + combinedResult = combinedResult, + revisionAttempts = revisionAttempts, + plotDslCode = plotDslCode, + errors = errors + ) + + renderer.renderChatDBStep( + stepType = ChatDBStepType.FINAL_RESULT, + status = if (success) ChatDBStepStatus.SUCCESS else ChatDBStepStatus.ERROR, + title = if (success) "Query completed on ${queryResults.size} database(s)" else "Query failed", + details = mapOf( + "databases" to queryResults.keys.toList(), + "totalRows" to combinedResult.rowCount, + "columns" to combinedResult.columns, + "previewRows" to combinedResult.rows.take(10), + "revisionAttempts" to revisionAttempts, + "errors" to errors + ) + ) + + // Render final message + if (success) { + renderer.renderLLMResponseStart() + renderer.renderLLMResponseChunk(resultMessage) + renderer.renderLLMResponseEnd() + } else { + renderer.renderError(resultMessage) + } + + return MultiDatabaseChatDBResult( + success = success, + message = resultMessage, + generatedSql = generatedSql, + queryResult = combinedResult, + queryResultsByDatabase = queryResults, + targetDatabases = targetDatabases, + plotDslCode = plotDslCode, + revisionAttempts = revisionAttempts, + errors = errors + ) + + } catch (e: Exception) { + logger.error(e) { "Multi-database query failed: ${e.message}" } + renderer.renderChatDBStep( + stepType = ChatDBStepType.FINAL_RESULT, + status = ChatDBStepStatus.ERROR, + title = "Query failed", + details = mapOf("error" to (e.message ?: "Unknown error")) + ) + renderer.renderError("Query failed: ${e.message}") + return MultiDatabaseChatDBResult( + success = false, + message = "Error: ${e.message}", + errors = listOf(e.message ?: "Unknown error") + ) + } + } + + /** + * Fetch schemas from all databases and merge them + */ + private suspend fun fetchAndMergeSchemas(): MergedDatabaseSchema { + val schemas = mutableMapOf() + + for ((dbId, connection) in databaseConnections) { + try { + val schema = connection.getSchema() + schemas[dbId] = schema + logger.info { "Fetched schema for $dbId: ${schema.tables.size} tables" } + } catch (e: Exception) { + logger.error(e) { "Failed to fetch schema for $dbId" } + } + } + + return MergedDatabaseSchema(schemas) + } + + /** + * Build schema context for multi-database prompt with schema linking results + */ + private fun buildMultiDatabaseSchemaContext( + merged: MergedDatabaseSchema, + query: String, + linkingResults: Map = emptyMap() + ): String { + val sb = StringBuilder() + sb.append("=== AVAILABLE DATABASES AND TABLES ===\n\n") + + for ((dbId, schema) in merged.databases) { + val displayName = databaseConfigs[dbId]?.databaseName ?: dbId + val linkingResult = linkingResults[dbId] + val relevantTables = linkingResult?.relevantTables?.toSet() ?: schema.tables.map { it.name }.toSet() + + sb.append("DATABASE: $dbId ($displayName)\n") + sb.append("-".repeat(40)).append("\n") + + // Show relevant tables first (if schema linking was performed) + val sortedTables = if (linkingResult != null) { + schema.tables.sortedByDescending { it.name in relevantTables } + } else { + schema.tables + } + + for (table in sortedTables) { + val isRelevant = table.name in relevantTables + val marker = if (isRelevant && linkingResult != null) " [RELEVANT]" else "" + sb.append(" Table: ${table.name}$marker\n") + if (table.comment != null) { + sb.append(" Comment: ${table.comment}\n") + } + sb.append(" Columns:\n") + for (col in table.columns) { + val flags = mutableListOf() + if (col.isPrimaryKey) flags.add("PK") + if (col.isForeignKey) flags.add("FK") + if (!col.nullable) flags.add("NOT NULL") + val flagStr = if (flags.isNotEmpty()) " [${flags.joinToString(", ")}]" else "" + sb.append(" - ${col.name}: ${col.type}$flagStr\n") + } + sb.append("\n") + } + sb.append("\n") + } + + return sb.toString() + } + + /** + * Build schema description for a specific database (for SQL revision) + */ + private fun buildSchemaDescriptionForDatabase(dbId: String, merged: MergedDatabaseSchema): String { + val schema = merged.databases[dbId] ?: return "" + val displayName = databaseConfigs[dbId]?.databaseName ?: dbId + + return buildString { + appendLine("## Database Schema: $displayName (USE ONLY THESE TABLES)") + appendLine() + for (table in schema.tables) { + appendLine("Table: ${table.name}") + appendLine("Columns: ${table.columns.joinToString(", ") { "${it.name} (${it.type})" }}") + appendLine() + } + } + } + + /** + * Build result message for final output + */ + private fun buildResultMessage( + success: Boolean, + generatedSql: String?, + queryResults: Map, + combinedResult: QueryResult, + revisionAttempts: Int, + plotDslCode: String?, + errors: List + ): String { + return if (success) { + buildString { + appendLine("## โœ… Query Executed Successfully") + appendLine() + + if (queryResults.size > 1) { + appendLine("**Databases Queried:** ${queryResults.keys.joinToString(", ")}") + appendLine() + } + + if (generatedSql != null) { + appendLine("**Executed SQL:**") + appendLine("```sql") + appendLine(generatedSql) + appendLine("```") + appendLine() + } + + if (revisionAttempts > 0) { + appendLine("*Note: SQL was revised $revisionAttempts time(s) to fix validation/execution errors*") + appendLine() + } + + appendLine("**Results** (${combinedResult.rowCount} row${if (combinedResult.rowCount != 1) "s" else ""}):") + appendLine() + appendLine(combinedResult.toTableString()) + + if (plotDslCode != null) { + appendLine() + appendLine("**Visualization:**") + appendLine("```plotdsl") + appendLine(plotDslCode) + appendLine("```") + } + } + } else { + buildString { + appendLine("## โŒ Query Failed") + appendLine() + appendLine("**Errors:**") + errors.forEach { error -> + appendLine("- $error") + } + if (generatedSql != null) { + appendLine() + appendLine("**Failed SQL:**") + appendLine("```sql") + appendLine(generatedSql) + appendLine("```") + } + } + } + } + + /** + * Generate visualization for query results using PlotDSLAgent + */ + private suspend fun generateVisualization( + query: String, + result: QueryResult, + onProgress: (String) -> Unit + ): String? { + // Check if PlotDSLAgent is available on this platform + if (!plotDSLAgent.isAvailable) { + logger.info { "PlotDSLAgent not available on this platform, skipping visualization" } + return null + } + + // Build description for PlotDSLAgent + val description = buildString { + appendLine("Create a visualization for the following database query result:") + appendLine() + appendLine("**Original Query**: $query") + appendLine() + appendLine("**Data** (${result.rowCount} rows, columns: ${result.columns.joinToString(", ")}):") + appendLine("```csv") + appendLine(result.toCsvString()) + appendLine("```") + appendLine() + appendLine("Choose the most appropriate chart type based on the data structure.") + } + + try { + val plotContext = PlotDSLContext(description = description) + val agentResult = plotDSLAgent.execute(plotContext, onProgress) + + if (agentResult.success) { + // Extract PlotDSL code from the result + val content = agentResult.content + val codeFence = CodeFence.parse(content) + if (codeFence.languageId.lowercase() == "plotdsl" && codeFence.text.isNotBlank()) { + return codeFence.text.trim() + } + + // Try to find plotdsl block manually + val plotPattern = Regex("```plotdsl\\s*([\\s\\S]*?)```", RegexOption.IGNORE_CASE) + val match = plotPattern.find(content) + return match?.groupValues?.get(1)?.trim() + } else { + logger.warn { "PlotDSLAgent failed: ${agentResult.content}" } + return null + } + } catch (e: Exception) { + logger.error(e) { "Visualization generation failed" } + return null + } + } + + /** + * Build SQL generation prompt for multi-database context + */ + private fun buildMultiDatabaseSqlPrompt(query: String, schemaContext: String, maxRows: Int): String { + return """ +$schemaContext + +USER QUERY: $query + +INSTRUCTIONS: +1. Analyze which database(s) contain the relevant tables for this query +2. Generate SQL for the appropriate database(s) +3. Each SQL block MUST start with a comment specifying the target database: -- database: +4. Use LIMIT $maxRows to restrict results +5. Only generate SELECT queries + +Generate the SQL: +""".trimIndent() + } + + /** + * Parse SQL response to extract SQL blocks with their target databases + */ + private fun parseSqlBlocksWithTargets(response: String): List { + val blocks = mutableListOf() + val codeFences = CodeFence.parseAll(response) + + for (fence in codeFences) { + if (fence.languageId.lowercase() == "sql") { + val sql = fence.text.trim() + val database = extractDatabaseFromSql(sql) + val cleanSql = removeDatabaseComment(sql) + + if (database != null && cleanSql.isNotBlank()) { + blocks.add(SqlBlock(database, cleanSql)) + } else if (databaseConnections.size == 1) { + // If only one database, use it as default + val defaultDb = databaseConnections.keys.first() + blocks.add(SqlBlock(defaultDb, cleanSql.ifBlank { sql })) + } + } + } + + return blocks + } + + /** + * Extract database name from SQL comment + */ + private fun extractDatabaseFromSql(sql: String): String? { + val regex = Regex("--\\s*database:\\s*(\\S+)", RegexOption.IGNORE_CASE) + val match = regex.find(sql) + return match?.groupValues?.get(1) + } + + /** + * Remove database comment from SQL + */ + private fun removeDatabaseComment(sql: String): String { + return sql.replace(Regex("--\\s*database:\\s*\\S+\\s*\n?", RegexOption.IGNORE_CASE), "").trim() + } + + /** + * Combine results from multiple databases + */ + private fun combineResults(results: Map): QueryResult { + if (results.isEmpty()) { + return QueryResult(emptyList(), emptyList(), 0) + } + + if (results.size == 1) { + return results.values.first() + } + + // For multiple results, add a "database" column and combine + val allColumns = mutableListOf("_database") + val allRows = mutableListOf>() + + for ((dbName, result) in results) { + if (allColumns.size == 1) { + allColumns.addAll(result.columns) + } + for (row in result.rows) { + allRows.add(listOf(dbName) + row) + } + } + + return QueryResult(allColumns, allRows, allRows.size) + } + + /** + * Request user approval for SQL write operation + * Returns true if approved, false if rejected + */ + private suspend fun requestSqlApproval( + sql: String, + operationType: SqlOperationType, + affectedTables: List, + isHighRisk: Boolean, + dryRunResult: DryRunResult? = null, + onProgress: (String) -> Unit + ): Boolean { + val riskLevel = if (isHighRisk) "โš ๏ธ HIGH RISK" else "โšก Write Operation" + val dryRunInfo = if (dryRunResult?.estimatedRows != null) { + " (dry run: ${dryRunResult.estimatedRows} row(s) would be affected)" + } else "" + onProgress("$riskLevel: ${operationType.name} requires approval$dryRunInfo") + + return suspendCancellableCoroutine { continuation -> + renderer.renderSqlApprovalRequest( + sql = sql, + operationType = operationType, + affectedTables = affectedTables, + isHighRisk = isHighRisk, + dryRunResult = dryRunResult, + onApprove = { + if (continuation.isActive) { + continuation.resume(true) + } + }, + onReject = { + if (continuation.isActive) { + continuation.resume(false) + } + } + ) + } + } + + /** + * Extract table names from SQL statement + */ + private fun extractTablesFromSql(sql: String, schema: DatabaseSchema?): List { + if (schema == null) return emptyList() + + val tableNames = schema.tables.map { it.name.lowercase() }.toSet() + val sqlLower = sql.lowercase() + + return tableNames.filter { tableName -> + // Check for common SQL patterns that reference tables + sqlLower.contains(" $tableName ") || + sqlLower.contains(" $tableName;") || + sqlLower.contains(" $tableName\n") || + sqlLower.contains("from $tableName") || + sqlLower.contains("into $tableName") || + sqlLower.contains("update $tableName") || + sqlLower.contains("table $tableName") || + sqlLower.contains("join $tableName") + }.map { tableName -> + // Return original case from schema + schema.tables.find { it.name.lowercase() == tableName }?.name ?: tableName + } + } +} + +/** + * SQL block with target database + */ +data class SqlBlock( + val database: String, + val sql: String +) + +/** + * Merged schema from multiple databases + */ +data class MergedDatabaseSchema( + val databases: Map +) { + val totalTableCount: Int + get() = databases.values.sumOf { it.tables.size } + + fun getTablesForDatabase(dbId: String): List { + return databases[dbId]?.tables ?: emptyList() + } +} + +/** + * Result from multi-database query + */ +data class MultiDatabaseChatDBResult( + val success: Boolean, + val message: String, + val generatedSql: String? = null, + val queryResult: QueryResult? = null, + val queryResultsByDatabase: Map = emptyMap(), + val targetDatabases: List? = null, + val plotDslCode: String? = null, + val revisionAttempts: Int = 0, + val errors: List = emptyList() +) + diff --git a/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/chatdb/NlpTokenizer.kt b/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/chatdb/NlpTokenizer.kt new file mode 100644 index 0000000000..cc7453401e --- /dev/null +++ b/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/chatdb/NlpTokenizer.kt @@ -0,0 +1,19 @@ +package cc.unitmesh.agent.chatdb + +/** + * Platform-specific NLP tokenizer for keyword extraction. + * + * On JVM, this uses MyNLP (https://github.com/jimichan/mynlp) for Chinese tokenization. + * On other platforms (JS, WASM, iOS, Android), this falls back to simple regex-based tokenization. + */ +expect object NlpTokenizer { + /** + * Extract keywords from natural language query using NLP tokenization. + * Supports both English and Chinese text. + * + * @param query The natural language query to tokenize + * @param stopWords Set of words to filter out from results + * @return List of extracted keywords + */ + fun extractKeywords(query: String, stopWords: Set): List +} diff --git a/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/chatdb/SchemaLinker.kt b/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/chatdb/SchemaLinker.kt new file mode 100644 index 0000000000..e282724709 --- /dev/null +++ b/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/chatdb/SchemaLinker.kt @@ -0,0 +1,69 @@ +package cc.unitmesh.agent.chatdb + +import cc.unitmesh.agent.database.DatabaseSchema + +/** + * Schema Linker - Abstract base class for Text2SQL schema linking + * + * This class finds relevant tables and columns based on natural language queries. + * Different implementations can use different strategies: + * - KeywordSchemaLinker: Keyword matching and fuzzy matching + * - LlmSchemaLinker: LLM-based keyword extraction and schema linking + * - VectorSchemaLinker: Vector similarity search using embeddings (future) + */ +abstract class SchemaLinker { + + /** + * Link natural language query to relevant schema elements + */ + abstract suspend fun link(query: String, schema: DatabaseSchema): SchemaLinkingResult + + /** + * Extract keywords from natural language query + */ + abstract suspend fun extractKeywords(query: String): List + + companion object { + /** + * Common SQL keywords to filter out + */ + val STOP_WORDS = setOf( + "select", "from", "where", "and", "or", "not", "in", "is", "null", + "order", "by", "group", "having", "limit", "offset", "join", "on", + "left", "right", "inner", "outer", "cross", "union", "all", "distinct", + "as", "asc", "desc", "between", "like", "exists", "case", "when", "then", + "else", "end", "count", "sum", "avg", "min", "max", "the", "a", "an", + "show", "me", "get", "find", "list", "display", "give", "what", "which", + "how", "many", "much", "all", "each", "every", "any", "some", "most", + "top", "first", "last", "recent", "latest", "oldest", "highest", "lowest", + "total", "average", "number", "amount", "value", "data", "information" + ) + + /** + * Calculate Levenshtein distance between two strings + */ + fun levenshteinDistance(s1: String, s2: String): Int { + val dp = Array(s1.length + 1) { IntArray(s2.length + 1) } + for (i in 0..s1.length) dp[i][0] = i + for (j in 0..s2.length) dp[0][j] = j + for (i in 1..s1.length) { + for (j in 1..s2.length) { + val cost = if (s1[i - 1] == s2[j - 1]) 0 else 1 + dp[i][j] = minOf(dp[i - 1][j] + 1, dp[i][j - 1] + 1, dp[i - 1][j - 1] + cost) + } + } + return dp[s1.length][s2.length] + } + + /** + * Simple fuzzy matching using edit distance threshold + */ + fun fuzzyMatch(s1: String, s2: String): Boolean { + if (kotlin.math.abs(s1.length - s2.length) > 3) return false + val distance = levenshteinDistance(s1, s2) + val threshold = kotlin.math.min(s1.length, s2.length) / 3 + return distance <= threshold + } + } +} + diff --git a/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/database/DatabaseConnection.kt b/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/database/DatabaseConnection.kt index e504ea3f4e..862d1c50cd 100644 --- a/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/database/DatabaseConnection.kt +++ b/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/database/DatabaseConnection.kt @@ -10,16 +10,45 @@ import kotlinx.serialization.Serializable interface DatabaseConnection { /** * Execute a SQL query - * + * * @param sql SQL query statement (SELECT only) * @return Query result * @throws DatabaseException If execution fails */ suspend fun executeQuery(sql: String): QueryResult + /** + * Execute a SQL update statement (INSERT, UPDATE, DELETE, CREATE, ALTER, DROP, etc.) + * + * @param sql SQL update statement + * @return UpdateResult containing affected row count and any generated keys + * @throws DatabaseException If execution fails + */ + suspend fun executeUpdate(sql: String): UpdateResult + + /** + * Dry run a SQL statement to validate it without actually executing it. + * This is useful for validating write operations before user approval. + * + * The default implementation uses database-specific techniques: + * - For SELECT: Uses EXPLAIN + * - For INSERT/UPDATE/DELETE: Wraps in a transaction and rolls back + * - For DDL: Returns validation based on syntax only + * + * @param sql SQL statement to validate + * @return DryRunResult containing validation status and any errors + */ + suspend fun dryRun(sql: String): DryRunResult { + // Default implementation - subclasses can override for better support + return DryRunResult( + isValid = true, + message = "Dry run not fully supported, syntax validation only" + ) + } + /** * Get database schema information - * + * * @return DatabaseSchema containing all tables and columns * @throws DatabaseException If retrieval fails */ @@ -52,6 +81,38 @@ interface DatabaseConnection { return (result.rows[0][0] as? Number)?.toLong() ?: 0L } + /** + * Get sample rows from a table (for Schema Linking context) + * + * @param tableName Table name + * @param limit Maximum number of rows to return (default 3) + * @return Sample rows as QueryResult + */ + suspend fun getSampleRows(tableName: String, limit: Int = 3): QueryResult { + return try { + executeQuery("SELECT * FROM `$tableName` LIMIT $limit") + } catch (e: Exception) { + QueryResult(emptyList(), emptyList(), 0) + } + } + + /** + * Get distinct values for a column (for Value Matching in Schema Linking) + * + * @param tableName Table name + * @param columnName Column name + * @param limit Maximum number of distinct values to return (default 10) + * @return List of distinct values as strings + */ + suspend fun getDistinctValues(tableName: String, columnName: String, limit: Int = 10): List { + return try { + val result = executeQuery("SELECT DISTINCT `$columnName` FROM `$tableName` LIMIT $limit") + result.rows.map { it.firstOrNull() ?: "" }.filter { it.isNotEmpty() } + } catch (e: Exception) { + emptyList() + } + } + /** * Close database connection */ @@ -70,6 +131,95 @@ interface DatabaseConnection { } } +/** + * Dry run result - validation result without actual execution + */ +@Serializable +data class DryRunResult( + /** + * Whether the SQL is valid and can be executed + */ + val isValid: Boolean, + + /** + * Validation message or error description + */ + val message: String? = null, + + /** + * Detailed errors if validation failed + */ + val errors: List = emptyList(), + + /** + * Estimated affected rows (if available from EXPLAIN) + */ + val estimatedRows: Int? = null, + + /** + * Warnings (non-fatal issues) + */ + val warnings: List = emptyList() +) { + companion object { + fun valid(message: String? = null, estimatedRows: Int? = null): DryRunResult { + return DryRunResult(true, message, emptyList(), estimatedRows) + } + + fun invalid(error: String): DryRunResult { + return DryRunResult(false, error, listOf(error)) + } + + fun invalid(errors: List): DryRunResult { + return DryRunResult(false, errors.firstOrNull(), errors) + } + + fun withWarnings(warnings: List): DryRunResult { + return DryRunResult(true, "Valid with warnings", emptyList(), null, warnings) + } + } +} + +/** + * Database update result (for INSERT, UPDATE, DELETE, DDL statements) + */ +@Serializable +data class UpdateResult( + /** + * Number of rows affected by the update + */ + val affectedRows: Int, + + /** + * Generated keys (for INSERT with auto-increment) + */ + val generatedKeys: List = emptyList(), + + /** + * Whether the update was successful + */ + val success: Boolean = true, + + /** + * Optional message (e.g., for DDL statements) + */ + val message: String? = null +) { + companion object { + fun success(affectedRows: Int, generatedKeys: List = emptyList()): UpdateResult { + return UpdateResult(affectedRows, generatedKeys, true) + } + + fun ddlSuccess(message: String = "DDL statement executed successfully"): UpdateResult { + return UpdateResult(0, emptyList(), true, message) + } + + fun failure(message: String): UpdateResult { + return UpdateResult(0, emptyList(), false, message) + } + } +} + /** * Database query result */ @@ -99,62 +249,46 @@ data class QueryResult( } /** - * Convert result to formatted table string (for user display) + * Convert result to Markdown table format (for rendering with MarkdownTableRenderer) + * Shows all rows without truncation */ fun toTableString(): String { if (isEmpty()) return "No results" - - // Calculate column widths - val colWidths = columns.indices.map { colIdx -> - maxOf( - columns[colIdx].length, - rows.maxOfOrNull { it[colIdx].length } ?: 4 - ) - } val sb = StringBuilder() - - // Header - sb.append("โ”Œ") - colWidths.forEach { width -> sb.append("โ”€".repeat(width + 2)).append("โ”ฌ") } - sb.setLength(sb.length - 1) - sb.append("โ”\n") - - // Column names - sb.append("โ”‚") - columns.forEachIndexed { idx, col -> - sb.append(" ").append(col.padEnd(colWidths[idx])).append(" โ”‚") - } - sb.append("\n") - - // Separator - sb.append("โ”œ") - colWidths.forEach { width -> sb.append("โ”€".repeat(width + 2)).append("โ”ผ") } - sb.setLength(sb.length - 1) - sb.append("โ”ค\n") - - // Data rows (show first 10) - rows.take(10).forEach { row -> - sb.append("โ”‚") - row.forEachIndexed { idx, value -> - val str = value.ifEmpty { "NULL" } - sb.append(" ").append(str.padEnd(colWidths[idx])).append(" โ”‚") - } - sb.append("\n") - } - if (rows.size > 10) { - sb.append("โ”‚ ... (${rows.size - 10} more rows)\n") - } + // Header row + sb.append("| ") + sb.append(columns.joinToString(" | ") { escapeMarkdown(it) }) + sb.append(" |\n") - // Footer - sb.append("โ””") - colWidths.forEach { width -> sb.append("โ”€".repeat(width + 2)).append("โ”ด") } - sb.setLength(sb.length - 1) - sb.append("โ”˜\n") + // Separator row + sb.append("| ") + sb.append(columns.joinToString(" | ") { "---" }) + sb.append(" |\n") + + // Data rows + rows.forEach { row -> + sb.append("| ") + sb.append(row.mapIndexed { idx, value -> + val str = value.ifEmpty { "NULL" } + escapeMarkdown(str) + }.joinToString(" | ")) + sb.append(" |\n") + } return sb.toString() } + + /** + * Escape special Markdown characters in table cell content + */ + private fun escapeMarkdown(text: String): String { + return text + .replace("|", "\\|") + .replace("\n", " ") + .replace("\r", "") + } } /** diff --git a/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/render/CodingAgentRenderer.kt b/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/render/CodingAgentRenderer.kt index 134e57876b..ee7a691d5b 100644 --- a/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/render/CodingAgentRenderer.kt +++ b/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/render/CodingAgentRenderer.kt @@ -1,6 +1,8 @@ package cc.unitmesh.agent.render +import cc.unitmesh.agent.database.DryRunResult import cc.unitmesh.agent.plan.PlanSummaryData +import cc.unitmesh.agent.subagent.SqlOperationType import cc.unitmesh.agent.tool.ToolResult import cc.unitmesh.llm.compression.TokenInfo @@ -44,6 +46,15 @@ interface CodingAgentRenderer { fun renderTaskComplete(executionTimeMs: Long = 0L, toolsUsedCount: Int = 0) fun renderFinalResult(success: Boolean, message: String, iterations: Int) fun renderError(message: String) + + /** + * Render an informational message (non-error, non-warning) + * Used for status updates, progress information, etc. + */ + fun renderInfo(message: String) { + // Default: no-op, renderers can override to display info messages + } + fun renderRepeatWarning(toolName: String, count: Int) fun renderRecoveryAdvice(recoveryAdvice: String) @@ -66,6 +77,27 @@ interface CodingAgentRenderer { // Default: no-op for renderers that don't display task progress } + /** + * Render a ChatDB execution step. + * This is an optional method primarily used by UI renderers (ComposeRenderer, JewelRenderer). + * Console renderers can ignore this or provide simple text output. + * + * @param stepType The type of step being executed + * @param status The current status of the step + * @param title The display title for the step (defaults to stepType.displayName) + * @param details Additional details about the step (e.g., table names, row counts, SQL) + * @param error Error message if the step failed + */ + fun renderChatDBStep( + stepType: ChatDBStepType, + status: ChatDBStepStatus, + title: String = stepType.displayName, + details: Map = emptyMap(), + error: String? = null + ) { + // Default: no-op for renderers that don't support ChatDB steps + } + /** * Render a compact plan summary bar. * Called when plan is created or updated to show progress in a compact format. @@ -83,6 +115,32 @@ interface CodingAgentRenderer { fun renderUserConfirmationRequest(toolName: String, params: Map) + /** + * Request user approval for a SQL write operation. + * This is called when a write operation (INSERT, UPDATE, DELETE, CREATE, etc.) is detected. + * The renderer should display the SQL and allow the user to approve or reject. + * + * @param sql The SQL statement to be executed + * @param operationType The type of SQL operation (INSERT, UPDATE, DELETE, CREATE, etc.) + * @param affectedTables List of tables that will be affected + * @param isHighRisk Whether this is a high-risk operation (DROP, TRUNCATE) + * @param dryRunResult Optional result from dry run validation (if available) + * @param onApprove Callback to invoke when user approves the operation + * @param onReject Callback to invoke when user rejects the operation + */ + fun renderSqlApprovalRequest( + sql: String, + operationType: SqlOperationType, + affectedTables: List, + isHighRisk: Boolean, + dryRunResult: DryRunResult? = null, + onApprove: () -> Unit, + onReject: () -> Unit + ) { + // Default: auto-reject for safety (renderers should override to show UI) + onReject() + } + /** * Add a live terminal session to the timeline. * Called when a Shell tool starts execution with PTY support. diff --git a/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/render/DefaultCodingAgentRenderer.kt b/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/render/DefaultCodingAgentRenderer.kt index 58d0294640..acbe9f82f5 100644 --- a/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/render/DefaultCodingAgentRenderer.kt +++ b/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/render/DefaultCodingAgentRenderer.kt @@ -1,6 +1,7 @@ package cc.unitmesh.agent.render import cc.unitmesh.agent.logging.getLogger +import cc.unitmesh.agent.subagent.SqlOperationType /** * Default console renderer - simple text output @@ -118,6 +119,41 @@ class DefaultCodingAgentRenderer : BaseRenderer() { println(" (Auto-approved for now)") } + override fun renderSqlApprovalRequest( + sql: String, + operationType: SqlOperationType, + affectedTables: List, + isHighRisk: Boolean, + dryRunResult: cc.unitmesh.agent.database.DryRunResult?, + onApprove: () -> Unit, + onReject: () -> Unit + ) { + val riskIcon = if (isHighRisk) "!!" else "!" + println("\n$riskIcon SQL Write Operation Requires Approval") + println("-".repeat(50)) + println("Operation: ${operationType.name}") + println("Affected Tables: ${affectedTables.joinToString(", ")}") + if (isHighRisk) { + println("WARNING: This is a HIGH-RISK operation!") + } + if (dryRunResult != null) { + println("\nDry Run Result:") + println(" Valid: ${dryRunResult.isValid}") + if (dryRunResult.estimatedRows != null) { + println(" Estimated Rows Affected: ${dryRunResult.estimatedRows}") + } + if (dryRunResult.warnings.isNotEmpty()) { + println(" Warnings: ${dryRunResult.warnings.joinToString(", ")}") + } + } + println("\nSQL:") + println(sql) + println("-".repeat(50)) + // Default console renderer auto-rejects for safety + println("(Auto-rejected in console mode - use interactive UI for approval)") + onReject() + } + override fun renderAgentSketchBlock( agentName: String, language: String, diff --git a/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/render/RendererModels.kt b/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/render/RendererModels.kt index f1bcd9e973..3669f0d6d8 100644 --- a/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/render/RendererModels.kt +++ b/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/render/RendererModels.kt @@ -138,6 +138,16 @@ sealed class TimelineItem( override val id: String = generateId() ) : TimelineItem(timestamp, id) + /** + * Info item for displaying informational messages (non-error, non-warning). + * Used for status updates, progress information, database context, etc. + */ + data class InfoItem( + val message: String, + override val timestamp: Long = Platform.getCurrentTimestamp(), + override val id: String = generateId() + ) : TimelineItem(timestamp, id) + /** * Task completion item. */ @@ -203,6 +213,20 @@ sealed class TimelineItem( override val id: String = generateId() ) : TimelineItem(timestamp, id) + /** + * ChatDB execution step item for displaying database query execution steps. + * Each step can be expanded/collapsed and shows detailed information. + */ + data class ChatDBStepItem( + val stepType: ChatDBStepType, + val status: ChatDBStepStatus, + val title: String, + val details: Map = emptyMap(), + val error: String? = null, + override val timestamp: Long = Platform.getCurrentTimestamp(), + override val id: String = generateId() + ) : TimelineItem(timestamp, id) + companion object { /** * Thread-safe counter for generating unique IDs. @@ -218,3 +242,40 @@ sealed class TimelineItem( } } +/** + * ChatDB execution step types + */ +enum class ChatDBStepType(val displayName: String, val icon: String) { + FETCH_SCHEMA("Fetch Database Schema", "๐Ÿ“Š"), + SCHEMA_LINKING("Schema Linking", "๐Ÿ”—"), + GENERATE_SQL("Generate SQL Query", "๐Ÿค–"), + VALIDATE_SQL("Validate SQL", "โœ“"), + REVISE_SQL("Revise SQL", "๐Ÿ”„"), + /** Dry run to validate SQL without executing (uses transaction rollback) */ + DRY_RUN("Dry Run Validation", "๐Ÿงช"), + /** Waiting for user approval before executing write operation */ + AWAIT_APPROVAL("Awaiting Approval", "?"), + EXECUTE_SQL("Execute SQL Query", "โšก"), + /** Execute write operation (INSERT, UPDATE, DELETE, DDL) */ + EXECUTE_WRITE("Execute Write Operation", "!"), + GENERATE_VISUALIZATION("Generate Visualization", "๐Ÿ“ˆ"), + FINAL_RESULT("Query Result", "โœ…") +} + +/** + * ChatDB execution step status + */ +enum class ChatDBStepStatus(val displayName: String) { + PENDING("Pending"), + IN_PROGRESS("In Progress"), + SUCCESS("Success"), + WARNING("Warning"), + ERROR("Error"), + /** User approval is required */ + AWAITING_APPROVAL("Awaiting Approval"), + /** User approved the operation */ + APPROVED("Approved"), + /** User rejected the operation */ + REJECTED("Rejected") +} + diff --git a/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/subagent/SqlReviseAgent.kt b/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/subagent/SqlReviseAgent.kt new file mode 100644 index 0000000000..a9b31db281 --- /dev/null +++ b/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/subagent/SqlReviseAgent.kt @@ -0,0 +1,347 @@ +package cc.unitmesh.agent.subagent + +import cc.unitmesh.agent.core.SubAgent +import cc.unitmesh.agent.logging.getLogger +import cc.unitmesh.agent.model.AgentDefinition +import cc.unitmesh.agent.model.PromptConfig +import cc.unitmesh.agent.model.RunConfig +import cc.unitmesh.agent.tool.ToolResult +import cc.unitmesh.agent.tool.schema.DeclarativeToolSchema +import cc.unitmesh.agent.tool.schema.SchemaPropertyBuilder.integer +import cc.unitmesh.agent.tool.schema.SchemaPropertyBuilder.string +import cc.unitmesh.devins.parser.CodeFence +import cc.unitmesh.llm.KoogLLMService +import cc.unitmesh.llm.ModelConfig +import kotlinx.serialization.Serializable + +/** + * SQL Revise Agent Schema - Tool definition for SQL revision + */ +object SqlReviseAgentSchema : DeclarativeToolSchema( + description = "Revise and fix SQL queries based on validation errors or execution failures", + properties = mapOf( + "originalQuery" to string( + description = "The original natural language query from user", + required = true + ), + "failedSql" to string( + description = "The SQL query that failed validation or execution", + required = true + ), + "errorMessage" to string( + description = "The error message from validation or execution", + required = true + ), + "schemaDescription" to string( + description = "Database schema description for context", + required = true + ), + "maxAttempts" to integer( + description = "Maximum number of revision attempts", + required = false + ) + ) +) { + override fun getExampleUsage(toolName: String): String { + return "/$toolName originalQuery=\"Show top customers\" failedSql=\"SELECT * FROM customer\" errorMessage=\"Table 'customer' doesn't exist\"" + } +} + +/** + * SQL Revise Agent - Self-correction loop for SQL queries + * + * This SubAgent is responsible for: + * 1. Analyzing SQL validation/execution errors + * 2. Understanding the original user intent + * 3. Generating corrected SQL based on schema and error context + * 4. Iterating until a valid SQL is produced or max attempts reached + * + * Based on GitHub Issue #508: https://github.com/phodal/auto-dev/issues/508 + * Implements the "Revise Agent (่‡ชๆˆ‘ไฟฎๆญฃ้—ญ็Žฏ)" feature + */ +class SqlReviseAgent( + private val llmService: KoogLLMService, + private val sqlValidator: SqlValidatorInterface? = null +) : SubAgent( + definition = createDefinition() +) { + private val logger = getLogger("SqlReviseAgent") + + override fun validateInput(input: Map): SqlRevisionInput { + return SqlRevisionInput( + originalQuery = input["originalQuery"] as? String + ?: throw IllegalArgumentException("originalQuery is required"), + failedSql = input["failedSql"] as? String + ?: throw IllegalArgumentException("failedSql is required"), + errorMessage = input["errorMessage"] as? String + ?: throw IllegalArgumentException("errorMessage is required"), + schemaDescription = input["schemaDescription"] as? String ?: "", + previousAttempts = (input["previousAttempts"] as? List<*>)?.filterIsInstance() ?: emptyList(), + maxAttempts = (input["maxAttempts"] as? Number)?.toInt() ?: 3 + ) + } + + override fun getParameterClass(): String = SqlRevisionInput::class.simpleName ?: "SqlRevisionInput" + + override suspend fun execute( + input: SqlRevisionInput, + onProgress: (String) -> Unit + ): ToolResult.AgentResult { + onProgress("๐Ÿ”„ SQL Revise Agent - Starting revision") + onProgress("Original query: ${input.originalQuery.take(50)}...") + onProgress("Error: ${input.errorMessage.take(80)}...") + + var currentSql = input.failedSql + var currentError = input.errorMessage + val attempts = input.previousAttempts.toMutableList() + var attemptCount = attempts.size + + while (attemptCount < input.maxAttempts) { + attemptCount++ + onProgress("๐Ÿ“ Revision attempt $attemptCount/${input.maxAttempts}") + + // Build revision context + val context = buildRevisionContext(input, currentSql, currentError, attempts) + + // Ask LLM for revised SQL + val revisedSql = askLLMForRevision(context, onProgress) + + if (revisedSql == null) { + onProgress("โŒ Failed to generate revised SQL") + return ToolResult.AgentResult( + success = false, + content = "Failed to generate revised SQL after $attemptCount attempts", + metadata = mapOf( + "attempts" to attemptCount.toString(), + "lastError" to currentError + ) + ) + } + + attempts.add(revisedSql) + + // Validate the revised SQL if validator is available + if (sqlValidator != null) { + val validation = sqlValidator.validate(revisedSql) + if (validation.isValid) { + onProgress("โœ… SQL validated successfully") + return ToolResult.AgentResult( + success = true, + content = revisedSql, + metadata = mapOf( + "attempts" to attemptCount.toString(), + "validated" to "true" + ) + ) + } else { + currentSql = revisedSql + currentError = validation.errors.joinToString("; ") + onProgress("โš ๏ธ Validation failed: ${currentError.take(50)}...") + } + } else { + // No validator, return the revised SQL + onProgress("โœ… SQL revised (no validation)") + return ToolResult.AgentResult( + success = true, + content = revisedSql, + metadata = mapOf( + "attempts" to attemptCount.toString(), + "validated" to "false" + ) + ) + } + } + + onProgress("โŒ Max revision attempts reached") + return ToolResult.AgentResult( + success = false, + content = "Max revision attempts ($attemptCount) reached. Last SQL: $currentSql", + metadata = mapOf( + "attempts" to attemptCount.toString(), + "lastError" to currentError, + "lastSql" to currentSql + ) + ) + } + + override fun formatOutput(output: ToolResult.AgentResult): String = output.content + + private fun buildRevisionContext( + input: SqlRevisionInput, + currentSql: String, + currentError: String, + previousAttempts: List + ): String = buildString { + appendLine("# SQL Revision Task") + appendLine() + appendLine("## User Query") + appendLine(input.originalQuery) + appendLine() + appendLine("## Available Schema (USE ONLY THESE TABLES AND COLUMNS)") + appendLine("```") + appendLine(input.schemaDescription.take(2000)) + appendLine("```") + appendLine() + appendLine("## Failed SQL") + appendLine("```sql") + appendLine(currentSql) + appendLine("```") + appendLine() + appendLine("## Error") + appendLine(currentError) + if (previousAttempts.isNotEmpty()) { + appendLine() + appendLine("## Previous Failed Attempts (do not repeat)") + previousAttempts.forEachIndexed { i, sql -> + appendLine("${i + 1}: $sql") + } + } + } + + private suspend fun askLLMForRevision(context: String, onProgress: (String) -> Unit): String? { + val systemPrompt = """ +You are a SQL Revision Agent. Fix SQL queries that failed validation or execution. + +CRITICAL RULES: +1. ONLY use table names from the provided schema - NEVER invent table names +2. ONLY use column names from the provided schema - NEVER invent column names +3. If the error says a table doesn't exist, find the correct table name from the schema +4. Analyze the error message carefully and fix the specific issue +5. Avoid repeating previous failed attempts + +OUTPUT FORMAT: +Return ONLY the corrected SQL in ```sql code block. No explanations. + +```sql +SELECT column FROM table WHERE condition LIMIT 100; +``` +""".trimIndent() + + val userPrompt = """ +$context + +**Task:** Generate a corrected SQL query that fixes the error while preserving the original intent. +""".trimIndent() + + return try { + val response = llmService.sendPrompt("$systemPrompt\n\n$userPrompt") + extractSqlFromResponse(response) + } catch (e: Exception) { + logger.error(e) { "LLM revision failed: ${e.message}" } + null + } + } + + private fun extractSqlFromResponse(response: String): String? { + val codeFences = CodeFence.parseAll(response) + val sqlFence = codeFences.find { it.languageId.lowercase() == "sql" } + if (sqlFence != null) { + return sqlFence.text.trim() + } + // Fallback: try to extract from markdown + val sqlMatch = Regex("```sql\\s*([\\s\\S]*?)\\s*```", RegexOption.IGNORE_CASE) + .find(response)?.groupValues?.get(1) + return sqlMatch?.trim() + } + + companion object { + private fun createDefinition() = AgentDefinition( + name = "SqlReviseAgent", + displayName = "SQL Revise Agent", + description = "Revises and fixes SQL queries based on validation errors or execution failures", + promptConfig = PromptConfig( + systemPrompt = "You are a SQL Revision Agent specialized in fixing SQL queries." + ), + modelConfig = ModelConfig.default(), + runConfig = RunConfig(maxTurns = 5, maxTimeMinutes = 2) + ) + } +} + +/** + * Input for SQL revision + */ +@Serializable +data class SqlRevisionInput( + val originalQuery: String, + val failedSql: String, + val errorMessage: String, + val schemaDescription: String = "", + val previousAttempts: List = emptyList(), + val maxAttempts: Int = 3 +) + +/** + * Interface for SQL validation - platform-specific implementations + */ +interface SqlValidatorInterface { + fun validate(sql: String): SqlValidationResult + fun validateWithTableWhitelist(sql: String, allowedTables: Set): SqlValidationResult + fun extractTableNames(sql: String): List + + /** + * Detect the type of SQL statement. + * Used to determine if approval is needed for write operations. + * + * @param sql The SQL statement to analyze + * @return The detected SQL operation type + */ + fun detectSqlType(sql: String): SqlOperationType +} + +/** + * SQL operation types for determining approval requirements + */ +enum class SqlOperationType { + /** SELECT queries - read-only, no approval needed */ + SELECT, + /** INSERT statements - requires approval */ + INSERT, + /** UPDATE statements - requires approval */ + UPDATE, + /** DELETE statements - requires approval */ + DELETE, + /** CREATE statements (tables, indexes, etc.) - requires approval */ + CREATE, + /** ALTER statements - requires approval */ + ALTER, + /** DROP statements - requires approval, high risk */ + DROP, + /** TRUNCATE statements - requires approval, high risk */ + TRUNCATE, + /** Other DDL/DCL/TCL statements */ + OTHER, + /** Unknown or unparseable SQL */ + UNKNOWN; + + /** + * Check if this operation type requires user approval + */ + fun requiresApproval(): Boolean = this != SELECT && this != UNKNOWN + + /** + * Check if this is a high-risk operation (DROP, TRUNCATE) + */ + fun isHighRisk(): Boolean = this == DROP || this == TRUNCATE + + /** + * Check if this is a write operation (INSERT, UPDATE, DELETE) + */ + fun isWriteOperation(): Boolean = this == INSERT || this == UPDATE || this == DELETE + + /** + * Check if this is a DDL operation (CREATE, ALTER, DROP, TRUNCATE) + */ + fun isDdlOperation(): Boolean = this == CREATE || this == ALTER || this == DROP || this == TRUNCATE +} + +/** + * Result of SQL validation + */ +@Serializable +data class SqlValidationResult( + val isValid: Boolean, + val errors: List = emptyList(), + val warnings: List = emptyList() +) + diff --git a/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/subagent/SqlValidator.kt b/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/subagent/SqlValidator.kt new file mode 100644 index 0000000000..e9abe81ed6 --- /dev/null +++ b/mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/subagent/SqlValidator.kt @@ -0,0 +1,63 @@ +package cc.unitmesh.agent.subagent + +/** + * Platform-agnostic SQL validator. + * + * JVM platforms use JSqlParser for full SQL parsing and validation. + * Non-JVM platforms perform basic syntax checks. + * + * Usage: + * ```kotlin + * val validator = SqlValidator() + * val result = validator.validate("SELECT * FROM users") + * if (result.isValid) { + * // SQL is valid + * } else { + * // Handle errors: result.errors + * } + * ``` + */ +expect class SqlValidator() : SqlValidatorInterface { + /** + * Validate SQL syntax. + * + * @param sql The SQL query to validate + * @return Validation result with errors and warnings + */ + override fun validate(sql: String): SqlValidationResult + + /** + * Validate SQL with table whitelist - ensures only allowed tables are used. + * + * On JVM platforms, this uses JSqlParser to extract table names and validate. + * On non-JVM platforms, this performs basic regex-based table name extraction. + * + * @param sql The SQL query to validate + * @param allowedTables Set of table names that are allowed in the query + * @return SqlValidationResult with errors if invalid tables are used + */ + override fun validateWithTableWhitelist(sql: String, allowedTables: Set): SqlValidationResult + + /** + * Extract table names from SQL query. + * + * On JVM platforms, this uses JSqlParser for accurate extraction. + * On non-JVM platforms, this uses regex-based extraction which may be less accurate. + * + * @param sql The SQL query to extract table names from + * @return List of table names found in the query + */ + override fun extractTableNames(sql: String): List + + /** + * Detect the type of SQL statement. + * + * On JVM platforms, this uses JSqlParser for accurate detection. + * On non-JVM platforms, this uses regex-based detection. + * + * @param sql The SQL statement to analyze + * @return The detected SQL operation type + */ + override fun detectSqlType(sql: String): SqlOperationType +} + diff --git a/mpp-core/src/commonMain/kotlin/cc/unitmesh/llm/ExecutorFactory.kt b/mpp-core/src/commonMain/kotlin/cc/unitmesh/llm/ExecutorFactory.kt index 0985b0e653..2b1e40cf45 100644 --- a/mpp-core/src/commonMain/kotlin/cc/unitmesh/llm/ExecutorFactory.kt +++ b/mpp-core/src/commonMain/kotlin/cc/unitmesh/llm/ExecutorFactory.kt @@ -10,13 +10,25 @@ import cc.unitmesh.llm.provider.LLMClientRegistry /** * Try to auto-register GitHub Copilot provider. * Returns the provider if successfully registered, null otherwise. - * + * * Implementation is platform-specific: * - JVM: Creates and registers GithubCopilotClientProvider * - Other platforms: Returns null */ internal expect fun tryAutoRegisterGithubCopilot(): LLMClientProvider? +/** + * Platform-specific blocking executor creation for registered providers. + * + * Implementation is platform-specific: + * - JVM: Uses runBlocking to create executor synchronously + * - JS/WASM: Returns null (registered providers not supported in sync mode) + */ +internal expect fun createExecutorBlocking( + provider: LLMClientProvider, + config: ModelConfig +): SingleLLMPromptExecutor? + /** * Executor ๅทฅๅŽ‚ - ่ดŸ่ดฃๆ นๆฎ้…็ฝฎๅˆ›ๅปบๅˆ้€‚็š„ LLM Executor * ่Œ่ดฃ๏ผš @@ -49,14 +61,14 @@ object ExecutorFactory { } if (registryProvider != null) { - // Use blocking call for registered providers + // Use platform-specific blocking call for registered providers // This works because on JVM, the provider caches the API token - return kotlinx.coroutines.runBlocking { - registryProvider.createExecutor(config) - } ?: throw IllegalStateException( - "Failed to create executor for ${config.provider.displayName}. " + - "Provider is registered but returned null." - ) + return createExecutorBlocking(registryProvider, config) + ?: throw IllegalStateException( + "Failed to create executor for ${config.provider.displayName}. " + + "Provider is registered but returned null. " + + "On non-JVM platforms, use createAsync() instead." + ) } return when (config.provider) { diff --git a/mpp-core/src/commonTest/kotlin/cc/unitmesh/agent/chatdb/FallbackNlpTokenizerTest.kt b/mpp-core/src/commonTest/kotlin/cc/unitmesh/agent/chatdb/FallbackNlpTokenizerTest.kt new file mode 100644 index 0000000000..6f1f73464c --- /dev/null +++ b/mpp-core/src/commonTest/kotlin/cc/unitmesh/agent/chatdb/FallbackNlpTokenizerTest.kt @@ -0,0 +1,376 @@ +package cc.unitmesh.agent.chatdb + +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertTrue +import kotlin.test.assertNotNull + +/** + * Comprehensive tests for FallbackNlpTokenizer. + * Tests cover: + * 1. Porter Stemmer for English morphological normalization + * 2. Bi-directional Maximum Matching (BiMM) for Chinese segmentation + * 3. RAKE keyword extraction algorithm + * 4. SemVer and CamelCase handling + */ +class FallbackNlpTokenizerTest { + + // ==================== Porter Stemmer Tests ==================== + + @Test + fun `Porter Stemmer should reduce processing to process`() { + val tokens = FallbackNlpTokenizer.tokenize("processing") + assertEquals(1, tokens.size) + assertEquals("process", tokens[0].text) + } + + @Test + fun `Porter Stemmer should reduce processed to process`() { + val tokens = FallbackNlpTokenizer.tokenize("processed") + assertEquals(1, tokens.size) + assertEquals("process", tokens[0].text) + } + + @Test + fun `Porter Stemmer should handle various English word forms`() { + // Test various word forms that should stem to similar roots + val testCases = mapOf( + "running" to "run", + "runs" to "run", + "runner" to "runner", // noun, different handling + "cats" to "cat", + "playing" to "plai", // Porter stemmer result + "played" to "plai", + "happily" to "happili", + "happiness" to "happi" + ) + + testCases.forEach { (input, expectedStem) -> + val tokens = FallbackNlpTokenizer.tokenize(input) + assertEquals(1, tokens.size, "Expected 1 token for '$input'") + assertEquals(expectedStem, tokens[0].text, "Expected '$input' to stem to '$expectedStem'") + } + } + + @Test + fun `Porter Stemmer should unify related words for better recall`() { + // Words like "processing", "processed", "process" should all map to "process" + val words = listOf("processing", "processed", "process") + val stems = words.map { FallbackNlpTokenizer.tokenize(it)[0].text } + + // All should be "process" + assertTrue(stems.all { it == "process" }, "All variations should stem to 'process', got: $stems") + } + + @Test + fun `Porter Stemmer should preserve short words`() { + // Words shorter than 3 characters should be preserved + val tokens = FallbackNlpTokenizer.tokenize("go do be") + assertEquals(3, tokens.size) + assertEquals("go", tokens[0].text) + assertEquals("do", tokens[1].text) + assertEquals("be", tokens[2].text) + } + + // ==================== Chinese BiMM Segmentation Tests ==================== + + @Test + fun `BiMM should segment Chinese text with dictionary words`() { + val tokens = FallbackNlpTokenizer.tokenize("ไบบๅทฅๆ™บ่ƒฝ") + // Should recognize "ไบบๅทฅๆ™บ่ƒฝ" as a single word + val chineseTokens = tokens.filter { it.type == FallbackNlpTokenizer.TokenType.CHINESE } + assertTrue(chineseTokens.any { it.text == "ไบบๅทฅๆ™บ่ƒฝ" }, "Should recognize 'ไบบๅทฅๆ™บ่ƒฝ' as a word") + } + + @Test + fun `BiMM should segment common Chinese words`() { + val tokens = FallbackNlpTokenizer.tokenize("ๆ•ฐๆฎๅบ“็ณป็ปŸ") + val texts = tokens.map { it.text } + + // Should segment as "ๆ•ฐๆฎๅบ“" + "็ณป็ปŸ" (both in dictionary) + assertTrue(texts.contains("ๆ•ฐๆฎๅบ“") || texts.contains("็ณป็ปŸ"), + "Should recognize common tech terms, got: $texts") + } + + @Test + fun `BiMM should handle ambiguous segmentation with forward and reverse matching`() { + // "ๅ—ไบฌๅธ‚้•ฟๆฑŸๅคงๆกฅ" is a classic ambiguity example + // FMM might give: ๅ—ไบฌๅธ‚/้•ฟๆฑŸ/ๅคงๆกฅ + // RMM might give: ๅ—ไบฌ/ๅธ‚้•ฟ/ๆฑŸๅคงๆกฅ (incorrect) + // BiMM chooses based on fewer tokens (longer matches) + val tokens = FallbackNlpTokenizer.tokenize("ๅ—ไบฌๅธ‚้•ฟๆฑŸๅคงๆกฅ") + val texts = tokens.map { it.text } + + // Should prefer segmentation with "ๅ—ไบฌๅธ‚", "้•ฟๆฑŸ", "ๅคงๆกฅ" (all in dictionary) + assertTrue( + texts.contains("ๅ—ไบฌๅธ‚") || texts.contains("้•ฟๆฑŸ") || texts.contains("ๅคงๆกฅ"), + "Should handle ambiguous segmentation correctly, got: $texts" + ) + } + + @Test + fun `BiMM should fall back to character-by-character for unknown words`() { + // Text with characters not in dictionary + val tokens = FallbackNlpTokenizer.tokenize("ๅ•Šๅงๅ”ง") + assertEquals(3, tokens.size, "Should segment unknown characters individually") + } + + // ==================== Mixed Language Tests ==================== + + @Test + fun `should handle mixed English and Chinese text`() { + val tokens = FallbackNlpTokenizer.tokenize("Helloไธ–็•ŒTestๆต‹่ฏ•") + val englishTokens = tokens.filter { it.type == FallbackNlpTokenizer.TokenType.ENGLISH } + val chineseTokens = tokens.filter { it.type == FallbackNlpTokenizer.TokenType.CHINESE } + + assertTrue(englishTokens.isNotEmpty(), "Should have English tokens") + assertTrue(chineseTokens.isNotEmpty(), "Should have Chinese tokens") + } + + // ==================== CamelCase Handling Tests ==================== + + @Test + fun `should split CamelCase words`() { + val tokens = FallbackNlpTokenizer.tokenize("UserDao") + val texts = tokens.map { it.text } + + assertTrue(texts.contains("user"), "Should extract 'user' from 'UserDao'") + assertTrue(texts.contains("dao"), "Should extract 'dao' from 'UserDao'") + } + + @Test + fun `should split complex CamelCase with multiple words`() { + val tokens = FallbackNlpTokenizer.tokenize("getUserNameById") + val texts = tokens.map { it.text } + + assertTrue(texts.contains("get"), "Should extract 'get'") + assertTrue(texts.contains("user"), "Should extract 'user'") + assertTrue(texts.contains("name"), "Should extract 'name'") + // "by" is short, "id" is short - they should still be present (as "by", "id") + assertTrue(texts.contains("by"), "Should extract 'by'") + assertTrue(texts.contains("id"), "Should extract 'id'") + } + + @Test + fun `should handle HTML style uppercase sequences`() { + val tokens = FallbackNlpTokenizer.tokenize("HTMLParser") + val texts = tokens.map { it.text } + + // Should split as "HTML" + "Parser" -> "html", "parser" + assertTrue(texts.contains("html") || texts.any { it.contains("html") }, + "Should handle HTMLParser correctly, got: $texts") + } + + @Test + fun `should handle snake_case`() { + val tokens = FallbackNlpTokenizer.tokenize("user_name_service") + val texts = tokens.map { it.text } + + assertTrue(texts.contains("user"), "Should extract 'user' from snake_case") + assertTrue(texts.contains("name"), "Should extract 'name' from snake_case") + assertTrue(texts.contains("servic"), "Should extract 'servic' (stemmed) from snake_case") + } + + // ==================== SemVer Handling Tests ==================== + + @Test + fun `should preserve semantic version numbers`() { + val tokens = FallbackNlpTokenizer.tokenize("v1.2.3") + assertEquals(1, tokens.size, "SemVer should be kept as single token") + assertEquals("v1.2.3", tokens[0].text) + assertEquals(FallbackNlpTokenizer.TokenType.CODE, tokens[0].type) + } + + @Test + fun `should preserve complex SemVer with prerelease`() { + val tokens = FallbackNlpTokenizer.tokenize("1.0.0-alpha") + assertEquals(1, tokens.size, "SemVer with prerelease should be single token") + assertEquals("1.0.0-alpha", tokens[0].text) + } + + @Test + fun `should preserve SemVer with build metadata`() { + val tokens = FallbackNlpTokenizer.tokenize("2.1.0+build.123") + assertEquals(1, tokens.size, "SemVer with build metadata should be single token") + assertEquals("2.1.0+build.123", tokens[0].text) + } + + // ==================== RAKE Keyword Extraction Tests ==================== + + @Test + fun `RAKE should extract meaningful keywords from text`() { + val keywords = FallbackNlpTokenizer.extractKeywords( + "The quick brown fox jumps over the lazy dog", + maxKeywords = 5 + ) + + assertTrue(keywords.isNotEmpty(), "Should extract keywords") + assertTrue(keywords.none { it in FallbackNlpTokenizer.StopWords.ENGLISH }, + "Keywords should not contain stop words") + } + + @Test + fun `RAKE should filter stop words`() { + val keywords = FallbackNlpTokenizer.extractKeywords( + "The is at which on and in to of for", + maxKeywords = 10 + ) + + assertTrue(keywords.isEmpty(), "Stop words only text should return empty keywords") + } + + @Test + fun `RAKE should rank keywords by co-occurrence`() { + val keywords = FallbackNlpTokenizer.extractKeywords( + "data processing pipeline data analysis data visualization", + maxKeywords = 5 + ) + + // "data" appears most frequently and co-occurs with many words + assertTrue(keywords.contains("data"), "Frequent co-occurring word 'data' should be top keyword") + } + + @Test + fun `RAKE should handle Chinese text`() { + val keywords = FallbackNlpTokenizer.extractKeywords( + "ไบบๅทฅๆ™บ่ƒฝๆŠ€ๆœฏๅœจๆ•ฐๆฎๅˆ†ๆž้ข†ๅŸŸ็š„ๅบ”็”จ", + maxKeywords = 5 + ) + + assertTrue(keywords.isNotEmpty(), "Should extract Chinese keywords") + assertTrue(keywords.none { it in FallbackNlpTokenizer.StopWords.CHINESE }, + "Chinese keywords should not contain stop words") + } + + @Test + fun `RAKE should respect maxKeywords limit`() { + val keywords = FallbackNlpTokenizer.extractKeywords( + "apple banana cherry date elderberry fig grape honeydew kiwi lemon mango nectarine", + maxKeywords = 3 + ) + + assertTrue(keywords.size <= 3, "Should respect maxKeywords limit") + } + + // ==================== Legacy Method Tests ==================== + + @Test + fun `legacy extractKeywords should work with custom stop words`() { + val customStopWords = setOf("quick", "lazy") + val keywords = FallbackNlpTokenizer.extractKeywords( + "The quick brown fox jumps over the lazy dog", + customStopWords + ) + + assertTrue(keywords.none { it == "quick" }, "Should filter custom stop word 'quick'") + assertTrue(keywords.none { it == "lazy" }, "Should filter custom stop word 'lazy'") + } + + @Test + fun `legacy extractKeywords should return distinct tokens`() { + val keywords = FallbackNlpTokenizer.extractKeywords( + "test test test data data", + emptySet() + ) + + assertEquals(keywords.distinct().size, keywords.size, "Keywords should be distinct") + } + + // ==================== Edge Cases ==================== + + @Test + fun `should handle empty string`() { + val tokens = FallbackNlpTokenizer.tokenize("") + assertTrue(tokens.isEmpty(), "Empty string should return empty tokens") + + val keywords = FallbackNlpTokenizer.extractKeywords("", maxKeywords = 5) + assertTrue(keywords.isEmpty(), "Empty string should return empty keywords") + } + + @Test + fun `should handle whitespace only`() { + val tokens = FallbackNlpTokenizer.tokenize(" \t\n ") + assertTrue(tokens.isEmpty(), "Whitespace only should return empty tokens") + } + + @Test + fun `should handle special characters`() { + val tokens = FallbackNlpTokenizer.tokenize("hello! @world# \$test%") + assertTrue(tokens.isNotEmpty(), "Should extract tokens from text with special chars") + } + + @Test + fun `should handle numbers mixed with text`() { + val tokens = FallbackNlpTokenizer.tokenize("user123 test456") + val texts = tokens.map { it.text } + + // Numbers should be separated from text + assertTrue(texts.any { it.contains("user") || it == "user" }, + "Should handle alphanumeric text") + } + + // ==================== Integration Tests ==================== + + @Test + fun `should handle realistic code query`() { + val keywords = FallbackNlpTokenizer.extractKeywords( + "How to implement UserService with database connection pooling?", + maxKeywords = 10 + ) + + assertTrue(keywords.isNotEmpty(), "Should extract keywords from code query") + // Should contain stemmed versions of key terms + assertTrue( + keywords.any { it.contains("user") || it.contains("servic") || it.contains("databas") }, + "Should extract relevant programming terms, got: $keywords" + ) + } + + @Test + fun `should handle realistic Chinese tech query`() { + val keywords = FallbackNlpTokenizer.extractKeywords( + "ๅฆ‚ไฝ•ไฝฟ็”จๆ•ฐๆฎๅบ“่ฟžๆŽฅๆฑ ไผ˜ๅŒ–็ณป็ปŸๆ€ง่ƒฝ๏ผŸ", + maxKeywords = 10 + ) + + assertTrue(keywords.isNotEmpty(), "Should extract keywords from Chinese query") + } + + @Test + fun `should handle mixed language tech query`() { + val keywords = FallbackNlpTokenizer.extractKeywords( + "ๅฆ‚ไฝ•ๅฎž็ŽฐUserService็š„ๆ•ฐๆฎๅบ“connection pooling?", + maxKeywords = 10 + ) + + assertTrue(keywords.isNotEmpty(), "Should extract keywords from mixed language query") + } + + // ==================== Token Type Tests ==================== + + @Test + fun `should correctly identify token types`() { + val tokens = FallbackNlpTokenizer.tokenize("Helloไธ–็•Œv1.2.3") + + val englishTokens = tokens.filter { it.type == FallbackNlpTokenizer.TokenType.ENGLISH } + val chineseTokens = tokens.filter { it.type == FallbackNlpTokenizer.TokenType.CHINESE } + val codeTokens = tokens.filter { it.type == FallbackNlpTokenizer.TokenType.CODE } + + assertTrue(englishTokens.isNotEmpty(), "Should have English tokens") + assertTrue(chineseTokens.isNotEmpty(), "Should have Chinese tokens") + assertTrue(codeTokens.isNotEmpty(), "Should have CODE tokens (SemVer)") + } + + // ==================== Performance Considerations ==================== + + @Test + fun `should handle long text efficiently`() { + val longText = "data processing ".repeat(100) + "ไบบๅทฅๆ™บ่ƒฝ".repeat(50) + + // Simply verify it completes without timeout (test framework handles timeout) + val keywords = FallbackNlpTokenizer.extractKeywords(longText, maxKeywords = 10) + + assertTrue(keywords.isNotEmpty(), "Should extract keywords from long text") + } +} + diff --git a/mpp-core/src/commonTest/kotlin/cc/unitmesh/agent/chatdb/SchemaLinkerTest.kt b/mpp-core/src/commonTest/kotlin/cc/unitmesh/agent/chatdb/SchemaLinkerTest.kt new file mode 100644 index 0000000000..cf733f565c --- /dev/null +++ b/mpp-core/src/commonTest/kotlin/cc/unitmesh/agent/chatdb/SchemaLinkerTest.kt @@ -0,0 +1,206 @@ +package cc.unitmesh.agent.chatdb + +import cc.unitmesh.agent.database.ColumnSchema +import cc.unitmesh.agent.database.DatabaseSchema +import cc.unitmesh.agent.database.TableSchema +import kotlinx.coroutines.test.runTest +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertTrue +import kotlin.test.assertNotNull + +/** + * Tests for KeywordSchemaLinker - keyword-based schema linking for Text2SQL + */ +class SchemaLinkerTest { + + private val schemaLinker = KeywordSchemaLinker() + + // ============= Keyword Extraction Tests ============= + + @Test + fun testExtractKeywordsBasic() = runTest { + val keywords = schemaLinker.extractKeywords("Show me all users") + + // "show" and "me" are stop words, "all" is also filtered + assertTrue(keywords.contains("users")) + } + + @Test + fun testExtractKeywordsFiltersStopWords() = runTest { + val keywords = schemaLinker.extractKeywords("Show me the top 10 users with the most orders") + + // Stop words should be filtered (show, me, the, top, most are in stopWords) + assertTrue(keywords.none { it in listOf("show", "me", "the", "top", "most") }) + // Important words should remain + assertTrue(keywords.contains("users")) + assertTrue(keywords.contains("orders")) + } + + @Test + fun testExtractKeywordsFiltersShortWords() = runTest { + val keywords = schemaLinker.extractKeywords("Get a list of all items") + + // Short words (length <= 2) should be filtered + assertTrue(keywords.none { it.length <= 2 }) + } + + @Test + fun testExtractKeywordsLowercase() = runTest { + val keywords = schemaLinker.extractKeywords("Show USERS and ORDERS") + + assertTrue(keywords.contains("users")) + assertTrue(keywords.contains("orders")) + assertTrue(keywords.none { word -> word.any { c -> c.isUpperCase() } }) + } + + // ============= Schema Linking Tests ============= + + private fun createTestSchema(): DatabaseSchema { + return DatabaseSchema( + tables = listOf( + TableSchema( + name = "users", + columns = listOf( + ColumnSchema("id", "INT", false, null), + ColumnSchema("name", "VARCHAR", false, null), + ColumnSchema("email", "VARCHAR", false, null), + ColumnSchema("created_at", "DATETIME", false, null) + ) + ), + TableSchema( + name = "orders", + columns = listOf( + ColumnSchema("id", "INT", false, null), + ColumnSchema("user_id", "INT", false, null), + ColumnSchema("total", "DECIMAL", false, null), + ColumnSchema("status", "VARCHAR", false, null) + ) + ), + TableSchema( + name = "products", + columns = listOf( + ColumnSchema("id", "INT", false, null), + ColumnSchema("name", "VARCHAR", false, null), + ColumnSchema("price", "DECIMAL", false, null), + ColumnSchema("category", "VARCHAR", false, null) + ) + ) + ) + ) + } + + @Test + fun testLinkSchemaFindsRelevantTables() = runTest { + val schema = createTestSchema() + + val result = schemaLinker.link("Show me all users with their orders", schema) + + assertTrue(result.relevantTables.contains("users")) + assertTrue(result.relevantTables.contains("orders")) + } + + @Test + fun testLinkSchemaFindsRelevantColumns() = runTest { + val schema = createTestSchema() + + val result = schemaLinker.link("Show user names and order totals", schema) + + assertTrue(result.relevantColumns.any { it.contains("name") }) + } + + @Test + fun testLinkSchemaWithNoMatches() = runTest { + val schema = createTestSchema() + + val result = schemaLinker.link("Show me the weather forecast", schema) + + // Should still return a result - fallback includes all tables + assertNotNull(result) + assertTrue(result.relevantTables.isNotEmpty()) + } + + @Test + fun testLinkSchemaWithPartialMatch() = runTest { + val schema = DatabaseSchema( + tables = listOf( + TableSchema( + name = "user_accounts", + columns = listOf( + ColumnSchema("id", "INT", false, null), + ColumnSchema("username", "VARCHAR", false, null), + ColumnSchema("email", "VARCHAR", false, null) + ) + ), + TableSchema( + name = "customer_orders", + columns = listOf( + ColumnSchema("id", "INT", false, null), + ColumnSchema("customer_id", "INT", false, null), + ColumnSchema("amount", "DECIMAL", false, null) + ) + ) + ) + ) + + val result = schemaLinker.link("Show me all users", schema) + + // Should find user_accounts due to partial match + assertTrue(result.relevantTables.any { it.contains("user") }) + } + + // ============= Edge Cases ============= + + @Test + fun testExtractKeywordsEmptyQuery() = runTest { + val keywords = schemaLinker.extractKeywords("") + + assertTrue(keywords.isEmpty()) + } + + @Test + fun testLinkSchemaEmptySchema() = runTest { + val schema = DatabaseSchema(tables = emptyList()) + val result = schemaLinker.link("Show me all users", schema) + + assertTrue(result.relevantTables.isEmpty()) + assertTrue(result.relevantColumns.isEmpty()) + } + + @Test + fun testLinkSchemaWithSpecialCharacters() = runTest { + val schema = createTestSchema() + + val result = schemaLinker.link("Show me users' emails!", schema) + + assertTrue(result.relevantTables.contains("users")) + } + + @Test + fun testLinkSchemaWithNumbers() = runTest { + val schema = createTestSchema() + + val result = schemaLinker.link("Show top 10 users with orders over 100", schema) + + assertTrue(result.relevantTables.contains("users")) + assertTrue(result.relevantTables.contains("orders")) + } + + // ============= Schema Linking Result Tests ============= + + @Test + fun testSchemaLinkingResultDescription() { + val result = SchemaLinkingResult( + relevantTables = listOf("users", "orders"), + relevantColumns = listOf("users.name", "orders.total"), + keywords = listOf("users", "orders"), + confidence = 0.9 + ) + + assertEquals(2, result.relevantTables.size) + assertEquals(2, result.relevantColumns.size) + assertEquals(2, result.keywords.size) + assertEquals(0.9, result.confidence) + } +} + diff --git a/mpp-core/src/commonTest/kotlin/cc/unitmesh/agent/subagent/SqlReviseAgentTest.kt b/mpp-core/src/commonTest/kotlin/cc/unitmesh/agent/subagent/SqlReviseAgentTest.kt new file mode 100644 index 0000000000..f7881c35a9 --- /dev/null +++ b/mpp-core/src/commonTest/kotlin/cc/unitmesh/agent/subagent/SqlReviseAgentTest.kt @@ -0,0 +1,161 @@ +package cc.unitmesh.agent.subagent + +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertNotNull +import kotlin.test.assertTrue + +/** + * Tests for SqlReviseAgent data classes and schema + */ +class SqlReviseAgentTest { + + // ============= Input Validation Tests ============= + + @Test + fun testSqlRevisionInputCreation() { + val input = SqlRevisionInput( + originalQuery = "Show me all users", + failedSql = "SELECT * FROM user", + errorMessage = "Table 'user' doesn't exist", + schemaDescription = "Tables: users (id, name, email)", + maxAttempts = 3 + ) + + assertEquals("Show me all users", input.originalQuery) + assertEquals("SELECT * FROM user", input.failedSql) + assertEquals("Table 'user' doesn't exist", input.errorMessage) + assertEquals("Tables: users (id, name, email)", input.schemaDescription) + assertEquals(3, input.maxAttempts) + } + + @Test + fun testSqlRevisionInputDefaultMaxAttempts() { + val input = SqlRevisionInput( + originalQuery = "Show me all users", + failedSql = "SELECT * FROM user", + errorMessage = "Table 'user' doesn't exist", + schemaDescription = "Tables: users" + ) + + assertEquals(3, input.maxAttempts) + } + + // ============= Schema Tests ============= + + @Test + fun testSqlReviseAgentSchemaExampleUsage() { + val example = SqlReviseAgentSchema.getExampleUsage("sql-revise") + + assertTrue(example.contains("/sql-revise")) + assertTrue(example.contains("originalQuery=")) + assertTrue(example.contains("failedSql=")) + assertTrue(example.contains("errorMessage=")) + } + + @Test + fun testSqlReviseAgentSchemaToJsonSchema() { + val jsonSchema = SqlReviseAgentSchema.toJsonSchema() + + assertNotNull(jsonSchema) + val schemaString = jsonSchema.toString() + assertTrue(schemaString.contains("originalQuery")) + assertTrue(schemaString.contains("failedSql")) + assertTrue(schemaString.contains("errorMessage")) + assertTrue(schemaString.contains("schemaDescription")) + assertTrue(schemaString.contains("maxAttempts")) + } + + @Test + fun testSqlReviseAgentSchemaDescription() { + val description = SqlReviseAgentSchema.description + + assertTrue(description.contains("Revise")) + assertTrue(description.contains("SQL")) + } + + // ============= Validation Result Tests ============= + + @Test + fun testSqlValidationResultSuccess() { + val result = SqlValidationResult( + isValid = true, + errors = emptyList(), + warnings = listOf("Using SELECT *") + ) + + assertTrue(result.isValid) + assertTrue(result.errors.isEmpty()) + assertEquals(1, result.warnings.size) + assertEquals("Using SELECT *", result.warnings[0]) + } + + @Test + fun testSqlValidationResultFailure() { + val result = SqlValidationResult( + isValid = false, + errors = listOf("Syntax error near 'FORM'"), + warnings = emptyList() + ) + + assertEquals(false, result.isValid) + assertEquals("Syntax error near 'FORM'", result.errors[0]) + assertTrue(result.warnings.isEmpty()) + } + + // ============= Edge Cases ============= + + @Test + fun testSqlRevisionInputWithEmptySchema() { + val input = SqlRevisionInput( + originalQuery = "Show me all users", + failedSql = "SELECT * FROM users", + errorMessage = "Unknown error", + schemaDescription = "" + ) + + assertEquals("", input.schemaDescription) + } + + @Test + fun testSqlRevisionInputWithComplexQuery() { + val complexSql = """ + SELECT u.name, COUNT(o.id) as order_count + FROM users u + LEFT JOIN orders o ON u.id = o.user_id + WHERE u.created_at > '2023-01-01' + GROUP BY u.name + HAVING COUNT(o.id) > 5 + ORDER BY order_count DESC + LIMIT 10 + """.trimIndent() + + val input = SqlRevisionInput( + originalQuery = "Show top 10 users with most orders since 2023", + failedSql = complexSql, + errorMessage = "Column 'created_at' not found", + schemaDescription = "Tables: users (id, name, created_date), orders (id, user_id, total)" + ) + + assertTrue(input.failedSql.contains("LEFT JOIN")) + assertTrue(input.failedSql.contains("GROUP BY")) + assertTrue(input.failedSql.contains("HAVING")) + } + + @Test + fun testSqlValidationResultWithMultipleWarnings() { + val result = SqlValidationResult( + isValid = true, + errors = emptyList(), + warnings = listOf( + "Using SELECT *", + "No LIMIT clause", + "Consider adding index on user_id" + ) + ) + + assertTrue(result.isValid) + assertEquals(3, result.warnings.size) + } +} + diff --git a/mpp-core/src/iosMain/kotlin/cc/unitmesh/agent/chatdb/NlpTokenizer.ios.kt b/mpp-core/src/iosMain/kotlin/cc/unitmesh/agent/chatdb/NlpTokenizer.ios.kt new file mode 100644 index 0000000000..167ed5d19b --- /dev/null +++ b/mpp-core/src/iosMain/kotlin/cc/unitmesh/agent/chatdb/NlpTokenizer.ios.kt @@ -0,0 +1,22 @@ +package cc.unitmesh.agent.chatdb + +/** + * iOS implementation of NlpTokenizer. + * Uses the fallback regex-based tokenization since MyNLP is JVM-only. + * + * TODO: Consider using iOS NaturalLanguage framework for better Chinese tokenization. + */ +actual object NlpTokenizer { + /** + * Extract keywords from natural language query using simple tokenization. + * Supports both English and Chinese text. + * + * @param query The natural language query to tokenize + * @param stopWords Set of words to filter out from results + * @return List of extracted keywords + */ + actual fun extractKeywords(query: String, stopWords: Set): List { + return FallbackNlpTokenizer.extractKeywords(query, stopWords) + } +} + diff --git a/mpp-core/src/iosMain/kotlin/cc/unitmesh/agent/subagent/SqlValidator.ios.kt b/mpp-core/src/iosMain/kotlin/cc/unitmesh/agent/subagent/SqlValidator.ios.kt new file mode 100644 index 0000000000..e997c4d825 --- /dev/null +++ b/mpp-core/src/iosMain/kotlin/cc/unitmesh/agent/subagent/SqlValidator.ios.kt @@ -0,0 +1,158 @@ +package cc.unitmesh.agent.subagent + +/** + * iOS implementation of SqlValidator. + * + * Performs basic SQL syntax validation without full parsing. + * Full parsing with JSqlParser is only available on JVM platforms. + */ +actual class SqlValidator actual constructor() : SqlValidatorInterface { + + actual override fun validate(sql: String): SqlValidationResult { + if (sql.isBlank()) { + return SqlValidationResult( + isValid = false, + errors = listOf("Empty SQL query") + ) + } + + return performBasicValidation(sql) + } + + actual override fun validateWithTableWhitelist(sql: String, allowedTables: Set): SqlValidationResult { + val basicResult = validate(sql) + if (!basicResult.isValid) { + return basicResult + } + + // Extract table names using regex and validate against whitelist + val usedTables = extractTableNames(sql) + val allowedTablesLower = allowedTables.map { it.lowercase() }.toSet() + val invalidTables = usedTables.filter { tableName -> + tableName.lowercase() !in allowedTablesLower + } + + return if (invalidTables.isNotEmpty()) { + SqlValidationResult( + isValid = false, + errors = listOf( + "Invalid table(s) used: ${invalidTables.joinToString(", ")}. " + + "Available tables: ${allowedTables.joinToString(", ")}" + ), + warnings = basicResult.warnings + ) + } else { + basicResult + } + } + + actual override fun extractTableNames(sql: String): List { + val tables = mutableListOf() + + // Match FROM clause + val fromPattern = Regex("""FROM\s+(\w+)""", RegexOption.IGNORE_CASE) + fromPattern.findAll(sql).forEach { match -> + tables.add(match.groupValues[1]) + } + + // Match JOIN clause + val joinPattern = Regex("""JOIN\s+(\w+)""", RegexOption.IGNORE_CASE) + joinPattern.findAll(sql).forEach { match -> + tables.add(match.groupValues[1]) + } + + // Match UPDATE clause + val updatePattern = Regex("""UPDATE\s+(\w+)""", RegexOption.IGNORE_CASE) + updatePattern.findAll(sql).forEach { match -> + tables.add(match.groupValues[1]) + } + + // Match INSERT INTO clause + val insertPattern = Regex("""INSERT\s+INTO\s+(\w+)""", RegexOption.IGNORE_CASE) + insertPattern.findAll(sql).forEach { match -> + tables.add(match.groupValues[1]) + } + + // Match DELETE FROM clause + val deletePattern = Regex("""DELETE\s+FROM\s+(\w+)""", RegexOption.IGNORE_CASE) + deletePattern.findAll(sql).forEach { match -> + tables.add(match.groupValues[1]) + } + + return tables.distinct() + } + + /** + * Detect SQL type using regex-based detection + */ + actual override fun detectSqlType(sql: String): SqlOperationType { + val trimmedSql = sql.trim().uppercase() + return when { + trimmedSql.startsWith("SELECT") || trimmedSql.startsWith("WITH") -> SqlOperationType.SELECT + trimmedSql.startsWith("INSERT") -> SqlOperationType.INSERT + trimmedSql.startsWith("UPDATE") -> SqlOperationType.UPDATE + trimmedSql.startsWith("DELETE") -> SqlOperationType.DELETE + trimmedSql.startsWith("CREATE") -> SqlOperationType.CREATE + trimmedSql.startsWith("ALTER") -> SqlOperationType.ALTER + trimmedSql.startsWith("DROP") -> SqlOperationType.DROP + trimmedSql.startsWith("TRUNCATE") -> SqlOperationType.TRUNCATE + else -> SqlOperationType.UNKNOWN + } + } + + private fun performBasicValidation(sql: String): SqlValidationResult { + val errors = mutableListOf() + val warnings = mutableListOf() + + val upperSql = sql.uppercase() + + // Check for basic SQL structure + val hasValidStart = upperSql.trimStart().let { trimmed -> + trimmed.startsWith("SELECT") || + trimmed.startsWith("INSERT") || + trimmed.startsWith("UPDATE") || + trimmed.startsWith("DELETE") || + trimmed.startsWith("CREATE") || + trimmed.startsWith("ALTER") || + trimmed.startsWith("DROP") || + trimmed.startsWith("WITH") + } + + if (!hasValidStart) { + errors.add("SQL must start with a valid statement (SELECT, INSERT, UPDATE, DELETE, etc.)") + } + + // Check for balanced parentheses + var parenCount = 0 + for (char in sql) { + when (char) { + '(' -> parenCount++ + ')' -> parenCount-- + } + if (parenCount < 0) { + errors.add("Unbalanced parentheses: unexpected ')'") + break + } + } + if (parenCount > 0) { + errors.add("Unbalanced parentheses: missing ')'") + } + + // Warnings + if (upperSql.contains("SELECT *")) { + warnings.add("Consider specifying explicit columns instead of SELECT *") + } + + if (!upperSql.contains("WHERE") && + (upperSql.contains("UPDATE") || upperSql.contains("DELETE"))) { + warnings.add("UPDATE/DELETE without WHERE clause will affect all rows") + } + + return SqlValidationResult( + isValid = errors.isEmpty(), + errors = errors, + warnings = warnings + ) + } +} + diff --git a/mpp-core/src/iosMain/kotlin/cc/unitmesh/llm/ExecutorFactory.ios.kt b/mpp-core/src/iosMain/kotlin/cc/unitmesh/llm/ExecutorFactory.ios.kt index a36b3140dd..020373bcbb 100644 --- a/mpp-core/src/iosMain/kotlin/cc/unitmesh/llm/ExecutorFactory.ios.kt +++ b/mpp-core/src/iosMain/kotlin/cc/unitmesh/llm/ExecutorFactory.ios.kt @@ -1,5 +1,6 @@ package cc.unitmesh.llm +import ai.koog.prompt.executor.llms.SingleLLMPromptExecutor import cc.unitmesh.llm.provider.LLMClientProvider /** @@ -7,3 +8,12 @@ import cc.unitmesh.llm.provider.LLMClientProvider */ internal actual fun tryAutoRegisterGithubCopilot(): LLMClientProvider? = null +/** + * iOS implementation: Blocking executor creation not supported + * iOS doesn't support runBlocking, so we return null + */ +internal actual fun createExecutorBlocking( + provider: LLMClientProvider, + config: ModelConfig +): SingleLLMPromptExecutor? = null + diff --git a/mpp-core/src/jsMain/kotlin/cc/unitmesh/agent/RendererExports.kt b/mpp-core/src/jsMain/kotlin/cc/unitmesh/agent/RendererExports.kt index 9735dde91a..aa7865ac57 100644 --- a/mpp-core/src/jsMain/kotlin/cc/unitmesh/agent/RendererExports.kt +++ b/mpp-core/src/jsMain/kotlin/cc/unitmesh/agent/RendererExports.kt @@ -1,9 +1,11 @@ package cc.unitmesh.agent +import cc.unitmesh.agent.database.DryRunResult import cc.unitmesh.agent.plan.PlanSummaryData import cc.unitmesh.agent.plan.StepSummary import cc.unitmesh.agent.plan.TaskSummary import cc.unitmesh.agent.render.CodingAgentRenderer +import cc.unitmesh.agent.subagent.SqlOperationType import kotlin.js.JsExport /** @@ -196,6 +198,21 @@ class JsRendererAdapter(private val jsRenderer: JsCodingAgentRenderer) : CodingA jsRenderer.renderError("Tool '$toolName' requires user confirmation: $params (Auto-approved)") } + override fun renderSqlApprovalRequest( + sql: String, + operationType: SqlOperationType, + affectedTables: List, + isHighRisk: Boolean, + dryRunResult: DryRunResult?, + onApprove: () -> Unit, + onReject: () -> Unit + ) { + // JS renderer auto-rejects for safety + val dryRunInfo = if (dryRunResult != null) " (dry run: ${if (dryRunResult.isValid) "passed" else "failed"})" else "" + jsRenderer.renderError("SQL write operation requires approval: ${operationType.name} on ${affectedTables.joinToString(", ")}$dryRunInfo (Auto-rejected)") + onReject() + } + override fun renderPlanSummary(summary: PlanSummaryData) { jsRenderer.renderPlanSummary(JsPlanSummaryData.from(summary)) } diff --git a/mpp-core/src/jsMain/kotlin/cc/unitmesh/agent/chatdb/NlpTokenizer.js.kt b/mpp-core/src/jsMain/kotlin/cc/unitmesh/agent/chatdb/NlpTokenizer.js.kt new file mode 100644 index 0000000000..1c03f28567 --- /dev/null +++ b/mpp-core/src/jsMain/kotlin/cc/unitmesh/agent/chatdb/NlpTokenizer.js.kt @@ -0,0 +1,20 @@ +package cc.unitmesh.agent.chatdb + +/** + * JavaScript implementation of NlpTokenizer. + * Uses the fallback regex-based tokenization since MyNLP is JVM-only. + */ +actual object NlpTokenizer { + /** + * Extract keywords from natural language query using simple tokenization. + * Supports both English and Chinese text. + * + * @param query The natural language query to tokenize + * @param stopWords Set of words to filter out from results + * @return List of extracted keywords + */ + actual fun extractKeywords(query: String, stopWords: Set): List { + return FallbackNlpTokenizer.extractKeywords(query, stopWords) + } +} + diff --git a/mpp-core/src/jsMain/kotlin/cc/unitmesh/agent/subagent/SqlValidator.js.kt b/mpp-core/src/jsMain/kotlin/cc/unitmesh/agent/subagent/SqlValidator.js.kt new file mode 100644 index 0000000000..337f0a6d3b --- /dev/null +++ b/mpp-core/src/jsMain/kotlin/cc/unitmesh/agent/subagent/SqlValidator.js.kt @@ -0,0 +1,168 @@ +package cc.unitmesh.agent.subagent + +/** + * JavaScript implementation of SqlValidator. + * + * Performs basic SQL syntax validation without full parsing. + * Full parsing with JSqlParser is only available on JVM platforms. + */ +actual class SqlValidator actual constructor() : SqlValidatorInterface { + + actual override fun validate(sql: String): SqlValidationResult { + if (sql.isBlank()) { + return SqlValidationResult( + isValid = false, + errors = listOf("Empty SQL query") + ) + } + + return performBasicValidation(sql) + } + + actual override fun validateWithTableWhitelist(sql: String, allowedTables: Set): SqlValidationResult { + val basicResult = validate(sql) + if (!basicResult.isValid) { + return basicResult + } + + // Extract table names using regex and validate against whitelist + val usedTables = extractTableNames(sql) + val allowedTablesLower = allowedTables.map { it.lowercase() }.toSet() + val invalidTables = usedTables.filter { tableName -> + tableName.lowercase() !in allowedTablesLower + } + + return if (invalidTables.isNotEmpty()) { + SqlValidationResult( + isValid = false, + errors = listOf( + "Invalid table(s) used: ${invalidTables.joinToString(", ")}. " + + "Available tables: ${allowedTables.joinToString(", ")}" + ), + warnings = basicResult.warnings + ) + } else { + basicResult + } + } + + actual override fun extractTableNames(sql: String): List { + val tables = mutableListOf() + + // Match FROM clause: FROM table_name or FROM schema.table_name + val fromPattern = Regex("""FROM\s+([`"\[]?\w+[`"\]]?(?:\.[`"\[]?\w+[`"\]]?)?)""", RegexOption.IGNORE_CASE) + fromPattern.findAll(sql).forEach { match -> + tables.add(cleanTableName(match.groupValues[1])) + } + + // Match JOIN clause: JOIN table_name or LEFT/RIGHT/INNER/OUTER JOIN table_name + val joinPattern = Regex("""(?:LEFT|RIGHT|INNER|OUTER|CROSS|FULL)?\s*JOIN\s+([`"\[]?\w+[`"\]]?(?:\.[`"\[]?\w+[`"\]]?)?)""", RegexOption.IGNORE_CASE) + joinPattern.findAll(sql).forEach { match -> + tables.add(cleanTableName(match.groupValues[1])) + } + + // Match UPDATE clause: UPDATE table_name + val updatePattern = Regex("""UPDATE\s+([`"\[]?\w+[`"\]]?(?:\.[`"\[]?\w+[`"\]]?)?)""", RegexOption.IGNORE_CASE) + updatePattern.findAll(sql).forEach { match -> + tables.add(cleanTableName(match.groupValues[1])) + } + + // Match INSERT INTO clause: INSERT INTO table_name + val insertPattern = Regex("""INSERT\s+INTO\s+([`"\[]?\w+[`"\]]?(?:\.[`"\[]?\w+[`"\]]?)?)""", RegexOption.IGNORE_CASE) + insertPattern.findAll(sql).forEach { match -> + tables.add(cleanTableName(match.groupValues[1])) + } + + // Match DELETE FROM clause: DELETE FROM table_name + val deletePattern = Regex("""DELETE\s+FROM\s+([`"\[]?\w+[`"\]]?(?:\.[`"\[]?\w+[`"\]]?)?)""", RegexOption.IGNORE_CASE) + deletePattern.findAll(sql).forEach { match -> + tables.add(cleanTableName(match.groupValues[1])) + } + + return tables.distinct() + } + + /** + * Detect SQL type using regex-based detection + */ + actual override fun detectSqlType(sql: String): SqlOperationType { + val trimmedSql = sql.trim().uppercase() + return when { + trimmedSql.startsWith("SELECT") || trimmedSql.startsWith("WITH") -> SqlOperationType.SELECT + trimmedSql.startsWith("INSERT") -> SqlOperationType.INSERT + trimmedSql.startsWith("UPDATE") -> SqlOperationType.UPDATE + trimmedSql.startsWith("DELETE") -> SqlOperationType.DELETE + trimmedSql.startsWith("CREATE") -> SqlOperationType.CREATE + trimmedSql.startsWith("ALTER") -> SqlOperationType.ALTER + trimmedSql.startsWith("DROP") -> SqlOperationType.DROP + trimmedSql.startsWith("TRUNCATE") -> SqlOperationType.TRUNCATE + else -> SqlOperationType.UNKNOWN + } + } + + private fun cleanTableName(name: String): String { + // Remove quotes and brackets, extract table name from schema.table format + val cleaned = name.replace(Regex("""[`"\[\]]"""), "") + return if (cleaned.contains(".")) { + cleaned.substringAfterLast(".") + } else { + cleaned + } + } + + private fun performBasicValidation(sql: String): SqlValidationResult { + val errors = mutableListOf() + val warnings = mutableListOf() + + val upperSql = sql.uppercase() + + // Check for basic SQL structure + val hasValidStart = upperSql.trimStart().let { trimmed -> + trimmed.startsWith("SELECT") || + trimmed.startsWith("INSERT") || + trimmed.startsWith("UPDATE") || + trimmed.startsWith("DELETE") || + trimmed.startsWith("CREATE") || + trimmed.startsWith("ALTER") || + trimmed.startsWith("DROP") || + trimmed.startsWith("WITH") + } + + if (!hasValidStart) { + errors.add("SQL must start with a valid statement (SELECT, INSERT, UPDATE, DELETE, etc.)") + } + + // Check for balanced parentheses + var parenCount = 0 + for (char in sql) { + when (char) { + '(' -> parenCount++ + ')' -> parenCount-- + } + if (parenCount < 0) { + errors.add("Unbalanced parentheses: unexpected ')'") + break + } + } + if (parenCount > 0) { + errors.add("Unbalanced parentheses: missing ')'") + } + + // Warnings + if (upperSql.contains("SELECT *")) { + warnings.add("Consider specifying explicit columns instead of SELECT *") + } + + if (!upperSql.contains("WHERE") && + (upperSql.contains("UPDATE") || upperSql.contains("DELETE"))) { + warnings.add("UPDATE/DELETE without WHERE clause will affect all rows") + } + + return SqlValidationResult( + isValid = errors.isEmpty(), + errors = errors, + warnings = warnings + ) + } +} + diff --git a/mpp-core/src/jsMain/kotlin/cc/unitmesh/llm/ExecutorFactory.js.kt b/mpp-core/src/jsMain/kotlin/cc/unitmesh/llm/ExecutorFactory.js.kt index 5601d81f97..edf022ecba 100644 --- a/mpp-core/src/jsMain/kotlin/cc/unitmesh/llm/ExecutorFactory.js.kt +++ b/mpp-core/src/jsMain/kotlin/cc/unitmesh/llm/ExecutorFactory.js.kt @@ -1,5 +1,6 @@ package cc.unitmesh.llm +import ai.koog.prompt.executor.llms.SingleLLMPromptExecutor import cc.unitmesh.llm.provider.LLMClientProvider /** @@ -7,3 +8,12 @@ import cc.unitmesh.llm.provider.LLMClientProvider */ internal actual fun tryAutoRegisterGithubCopilot(): LLMClientProvider? = null +/** + * JS implementation: Blocking executor creation not supported + * Use createAsync() instead for async initialization + */ +internal actual fun createExecutorBlocking( + provider: LLMClientProvider, + config: ModelConfig +): SingleLLMPromptExecutor? = null + diff --git a/mpp-core/src/jvmMain/kotlin/cc/unitmesh/agent/chatdb/NlpTokenizer.jvm.kt b/mpp-core/src/jvmMain/kotlin/cc/unitmesh/agent/chatdb/NlpTokenizer.jvm.kt new file mode 100644 index 0000000000..5ec36e8c1a --- /dev/null +++ b/mpp-core/src/jvmMain/kotlin/cc/unitmesh/agent/chatdb/NlpTokenizer.jvm.kt @@ -0,0 +1,85 @@ +package cc.unitmesh.agent.chatdb + +import com.mayabot.nlp.segment.Lexers +import io.github.oshai.kotlinlogging.KotlinLogging + +private val logger = KotlinLogging.logger {} + +/** + * JVM implementation of NlpTokenizer using MyNLP for Chinese tokenization. + * + * MyNLP (https://github.com/jimichan/mynlp) provides high-quality Chinese word segmentation + * which is essential for accurate keyword extraction from Chinese natural language queries. + */ +actual object NlpTokenizer { + + // Lazy initialization of the lexer to avoid startup overhead + private val lexer by lazy { + try { + Lexers.core() + } catch (e: Exception) { + logger.warn(e) { "Failed to initialize MyNLP lexer, falling back to simple tokenization" } + null + } + } + + /** + * Extract keywords from natural language query using MyNLP tokenization. + * Supports both English and Chinese text. + * + * For Chinese text, MyNLP provides proper word segmentation instead of + * character-by-character splitting, which significantly improves matching accuracy. + * + * @param query The natural language query to tokenize + * @param stopWords Set of words to filter out from results + * @return List of extracted keywords + */ + actual fun extractKeywords(query: String, stopWords: Set): List { + val currentLexer = lexer + if (currentLexer == null) { + // Fallback to simple tokenization if MyNLP initialization failed + return FallbackNlpTokenizer.extractKeywords(query, stopWords) + } + + return try { + extractKeywordsWithMyNlp(query, stopWords, currentLexer) + } catch (e: Exception) { + logger.warn(e) { "MyNLP tokenization failed, falling back to simple tokenization" } + FallbackNlpTokenizer.extractKeywords(query, stopWords) + } + } + + private fun extractKeywordsWithMyNlp( + query: String, + stopWords: Set, + lexer: com.mayabot.nlp.segment.Lexer + ): List { + val keywords = mutableListOf() + + // Use MyNLP to tokenize the query + val sentence = lexer.scan(query) + + for (term in sentence) { + val word = term.word.lowercase() + + // Skip short words and stop words + if (word.length <= 1) continue + if (word in stopWords) continue + + // Skip pure punctuation + if (word.all { !it.isLetterOrDigit() }) continue + + keywords.add(word) + } + + // Also extract English words that might be missed by Chinese tokenizer + val englishWords = query.lowercase() + .replace(Regex("[^a-z0-9\\s_]"), " ") + .split(Regex("\\s+")) + .filter { it.length > 2 && it !in stopWords && it !in keywords } + keywords.addAll(englishWords) + + return keywords.distinct() + } +} + diff --git a/mpp-core/src/jvmMain/kotlin/cc/unitmesh/agent/database/ExposedDatabaseConnection.kt b/mpp-core/src/jvmMain/kotlin/cc/unitmesh/agent/database/ExposedDatabaseConnection.kt index 288658e523..3b32dfe29c 100644 --- a/mpp-core/src/jvmMain/kotlin/cc/unitmesh/agent/database/ExposedDatabaseConnection.kt +++ b/mpp-core/src/jvmMain/kotlin/cc/unitmesh/agent/database/ExposedDatabaseConnection.kt @@ -54,15 +54,144 @@ class ExposedDatabaseConnection( } } + override suspend fun executeUpdate(sql: String): UpdateResult = withContext(Dispatchers.IO) { + try { + hikariDataSource.connection.use { connection -> + val stmt = connection.prepareStatement(sql, java.sql.Statement.RETURN_GENERATED_KEYS) + val affectedRows = stmt.executeUpdate() + + // Try to get generated keys + val generatedKeys = mutableListOf() + try { + val keysRs = stmt.generatedKeys + while (keysRs.next()) { + generatedKeys.add(keysRs.getString(1)) + } + keysRs.close() + } catch (e: Exception) { + // Ignore if generated keys are not available + } + + stmt.close() + + UpdateResult.success(affectedRows, generatedKeys) + } + } catch (e: Exception) { + throw DatabaseException.queryFailed(sql, e.message ?: "Unknown error") + } + } + + /** + * Dry run SQL to validate without executing. + * Uses transaction rollback to test DML statements safely. + */ + override suspend fun dryRun(sql: String): DryRunResult = withContext(Dispatchers.IO) { + try { + hikariDataSource.connection.use { connection -> + val originalAutoCommit = connection.autoCommit + try { + // Disable auto-commit to use transaction + connection.autoCommit = false + + val sqlUpper = sql.trim().uppercase() + val warnings = mutableListOf() + + when { + // For SELECT, use EXPLAIN to validate + sqlUpper.startsWith("SELECT") -> { + val explainSql = "EXPLAIN $sql" + val stmt = connection.prepareStatement(explainSql) + try { + stmt.executeQuery() + stmt.close() + DryRunResult.valid("SELECT query is valid") + } catch (e: Exception) { + DryRunResult.invalid("Query validation failed: ${e.message}") + } + } + + // For INSERT/UPDATE/DELETE, execute in transaction and rollback + sqlUpper.startsWith("INSERT") || + sqlUpper.startsWith("UPDATE") || + sqlUpper.startsWith("DELETE") -> { + val stmt = connection.prepareStatement(sql) + try { + val affectedRows = stmt.executeUpdate() + stmt.close() + // Rollback to undo the changes + connection.rollback() + DryRunResult.valid( + "Statement is valid (would affect $affectedRows row(s))", + estimatedRows = affectedRows + ) + } catch (e: Exception) { + connection.rollback() + DryRunResult.invalid("Statement validation failed: ${e.message}") + } + } + + // For DDL (CREATE, ALTER, DROP), we can't safely dry run + // Just validate syntax using EXPLAIN if possible + sqlUpper.startsWith("CREATE") || + sqlUpper.startsWith("ALTER") || + sqlUpper.startsWith("DROP") || + sqlUpper.startsWith("TRUNCATE") -> { + // For DDL, we can try to parse it but can't execute safely + // Return a warning that DDL cannot be fully validated + warnings.add("DDL statements cannot be fully validated without execution") + + // Try basic syntax check by preparing the statement + try { + val stmt = connection.prepareStatement(sql) + stmt.close() + DryRunResult( + isValid = true, + message = "DDL syntax appears valid (cannot fully validate without execution)", + warnings = warnings + ) + } catch (e: Exception) { + DryRunResult.invalid("DDL syntax error: ${e.message}") + } + } + + else -> { + // Unknown statement type, try to prepare it + try { + val stmt = connection.prepareStatement(sql) + stmt.close() + DryRunResult.valid("Statement syntax is valid") + } catch (e: Exception) { + DryRunResult.invalid("Statement validation failed: ${e.message}") + } + } + } + } finally { + // Restore original auto-commit setting + try { + connection.rollback() // Ensure any pending changes are rolled back + connection.autoCommit = originalAutoCommit + } catch (e: Exception) { + // Ignore cleanup errors + } + } + } + } catch (e: Exception) { + DryRunResult.invalid("Dry run failed: ${e.message}") + } + } + override suspend fun getSchema(): DatabaseSchema = withContext(Dispatchers.IO) { try { hikariDataSource.connection.use { connection -> val metadata = connection.metaData val tables = mutableListOf() - // Get all tables + // Get current database/catalog name to filter tables + val currentCatalog = connection.catalog + + // Get tables only from current database (catalog) val tableTypes = arrayOf("TABLE", "VIEW") - val tableRs = metadata.getTables(null, null, "%", tableTypes) + val tableRs = metadata.getTables(currentCatalog, null, "%", tableTypes) while (tableRs.next()) { val tableName = tableRs.getString("TABLE_NAME") @@ -72,10 +201,10 @@ class ExposedDatabaseConnection( null } - // Get primary keys + // Get primary keys for current catalog val primaryKeys = mutableSetOf() try { - val pkRs = metadata.getPrimaryKeys(null, null, tableName) + val pkRs = metadata.getPrimaryKeys(currentCatalog, null, tableName) while (pkRs.next()) { primaryKeys.add(pkRs.getString("COLUMN_NAME")) } @@ -84,8 +213,8 @@ class ExposedDatabaseConnection( // Ignore if primary keys cannot be retrieved } - // Get columns - val columnRs = metadata.getColumns(null, null, tableName, null) + // Get columns for current catalog + val columnRs = metadata.getColumns(currentCatalog, null, tableName, null) val columns = mutableListOf() while (columnRs.next()) { diff --git a/mpp-core/src/jvmMain/kotlin/cc/unitmesh/agent/subagent/SqlValidator.jvm.kt b/mpp-core/src/jvmMain/kotlin/cc/unitmesh/agent/subagent/SqlValidator.jvm.kt new file mode 100644 index 0000000000..e49ff02ae7 --- /dev/null +++ b/mpp-core/src/jvmMain/kotlin/cc/unitmesh/agent/subagent/SqlValidator.jvm.kt @@ -0,0 +1,236 @@ +package cc.unitmesh.agent.subagent + +import net.sf.jsqlparser.parser.CCJSqlParserUtil +import net.sf.jsqlparser.statement.Statement +import net.sf.jsqlparser.statement.alter.Alter +import net.sf.jsqlparser.statement.create.table.CreateTable +import net.sf.jsqlparser.statement.create.index.CreateIndex +import net.sf.jsqlparser.statement.create.view.CreateView +import net.sf.jsqlparser.statement.delete.Delete +import net.sf.jsqlparser.statement.drop.Drop +import net.sf.jsqlparser.statement.insert.Insert +import net.sf.jsqlparser.statement.select.Select +import net.sf.jsqlparser.statement.truncate.Truncate +import net.sf.jsqlparser.statement.update.Update +import net.sf.jsqlparser.util.TablesNamesFinder + +/** + * JVM implementation of SqlValidator using JSqlParser. + * + * This validator uses JSqlParser to validate SQL syntax. + * It can detect: + * - Syntax errors + * - Malformed SQL statements + * - Unsupported SQL constructs + * - Table names not in whitelist (schema validation) + */ +actual class SqlValidator actual constructor() : SqlValidatorInterface { + + actual override fun validate(sql: String): SqlValidationResult { + return try { + val statement: Statement = CCJSqlParserUtil.parse(sql) + SqlValidationResult( + isValid = true, + errors = emptyList(), + warnings = collectWarnings(statement) + ) + } catch (e: Exception) { + SqlValidationResult( + isValid = false, + errors = listOf(extractErrorMessage(e)), + warnings = emptyList() + ) + } + } + + /** + * Validate SQL with table whitelist - ensures only allowed tables are used + * + * @param sql The SQL query to validate + * @param allowedTables Set of table names that are allowed in the query + * @return SqlValidationResult with errors if invalid tables are used + */ + actual override fun validateWithTableWhitelist(sql: String, allowedTables: Set): SqlValidationResult { + return try { + val statement: Statement = CCJSqlParserUtil.parse(sql) + + // Extract table names from the SQL + val tablesNamesFinder = TablesNamesFinder() + val usedTables = tablesNamesFinder.getTableList(statement) + + // Check if all used tables are in the whitelist (case-insensitive) + val allowedTablesLower = allowedTables.map { it.lowercase() }.toSet() + val invalidTables = usedTables.filter { tableName -> + tableName.lowercase() !in allowedTablesLower + } + + if (invalidTables.isNotEmpty()) { + SqlValidationResult( + isValid = false, + errors = listOf( + "Invalid table(s) used: ${invalidTables.joinToString(", ")}. " + + "Available tables: ${allowedTables.joinToString(", ")}" + ), + warnings = collectWarnings(statement) + ) + } else { + SqlValidationResult( + isValid = true, + errors = emptyList(), + warnings = collectWarnings(statement) + ) + } + } catch (e: Exception) { + SqlValidationResult( + isValid = false, + errors = listOf(extractErrorMessage(e)), + warnings = emptyList() + ) + } + } + + /** + * Extract table names from SQL query + */ + actual override fun extractTableNames(sql: String): List { + return try { + val statement: Statement = CCJSqlParserUtil.parse(sql) + val tablesNamesFinder = TablesNamesFinder() + tablesNamesFinder.getTableList(statement) + } catch (e: Exception) { + emptyList() + } + } + + /** + * Detect the type of SQL statement using JSqlParser + */ + actual override fun detectSqlType(sql: String): SqlOperationType { + return try { + val statement: Statement = CCJSqlParserUtil.parse(sql) + when (statement) { + is Select -> SqlOperationType.SELECT + is Insert -> SqlOperationType.INSERT + is Update -> SqlOperationType.UPDATE + is Delete -> SqlOperationType.DELETE + is CreateTable, is CreateIndex, is CreateView -> SqlOperationType.CREATE + is Alter -> SqlOperationType.ALTER + is Drop -> SqlOperationType.DROP + is Truncate -> SqlOperationType.TRUNCATE + else -> SqlOperationType.OTHER + } + } catch (e: Exception) { + // Fallback to regex-based detection if parsing fails + detectSqlTypeByRegex(sql) + } + } + + /** + * Fallback regex-based SQL type detection + */ + private fun detectSqlTypeByRegex(sql: String): SqlOperationType { + val trimmedSql = sql.trim().uppercase() + return when { + trimmedSql.startsWith("SELECT") -> SqlOperationType.SELECT + trimmedSql.startsWith("INSERT") -> SqlOperationType.INSERT + trimmedSql.startsWith("UPDATE") -> SqlOperationType.UPDATE + trimmedSql.startsWith("DELETE") -> SqlOperationType.DELETE + trimmedSql.startsWith("CREATE") -> SqlOperationType.CREATE + trimmedSql.startsWith("ALTER") -> SqlOperationType.ALTER + trimmedSql.startsWith("DROP") -> SqlOperationType.DROP + trimmedSql.startsWith("TRUNCATE") -> SqlOperationType.TRUNCATE + else -> SqlOperationType.UNKNOWN + } + } + + /** + * Validate SQL and return the parsed statement if valid + */ + fun validateAndParse(sql: String): Pair { + return try { + val statement: Statement = CCJSqlParserUtil.parse(sql) + Pair( + SqlValidationResult( + isValid = true, + errors = emptyList(), + warnings = collectWarnings(statement) + ), + statement + ) + } catch (e: Exception) { + Pair( + SqlValidationResult( + isValid = false, + errors = listOf(extractErrorMessage(e)), + warnings = emptyList() + ), + null + ) + } + } + + /** + * Extract a clean error message from the exception + */ + private fun extractErrorMessage(e: Exception): String { + val message = e.message ?: "Unknown SQL parsing error" + // Clean up JSqlParser error messages + return when { + message.contains("Encountered") -> { + // Parse error with position info + val match = Regex("Encountered \"(.+?)\" at line (\\d+), column (\\d+)").find(message) + if (match != null) { + val (token, line, column) = match.destructured + "Syntax error at line $line, column $column: unexpected token '$token'" + } else { + message + } + } + message.contains("Was expecting") -> { + val match = Regex("Was expecting.*?:\\s*(.+)").find(message) + if (match != null) { + "Expected: ${match.groupValues[1].take(100)}" + } else { + message + } + } + else -> message.take(200) + } + } + + /** + * Collect warnings from parsed statement (e.g., deprecated syntax) + */ + private fun collectWarnings(statement: Statement): List { + val warnings = mutableListOf() + + // Check for common issues that aren't errors but might be problematic + val sql = statement.toString() + + if (sql.contains("SELECT *")) { + warnings.add("Consider specifying explicit columns instead of SELECT *") + } + + if (!sql.contains("WHERE", ignoreCase = true) && + (sql.contains("UPDATE", ignoreCase = true) || sql.contains("DELETE", ignoreCase = true))) { + warnings.add("UPDATE/DELETE without WHERE clause will affect all rows") + } + + return warnings + } + + companion object { + /** + * Quick validation check - returns true if SQL is syntactically valid + */ + fun isValidSql(sql: String): Boolean { + return try { + CCJSqlParserUtil.parse(sql) + true + } catch (e: Exception) { + false + } + } + } +} + diff --git a/mpp-core/src/jvmMain/kotlin/cc/unitmesh/llm/ExecutorFactory.jvm.kt b/mpp-core/src/jvmMain/kotlin/cc/unitmesh/llm/ExecutorFactory.jvm.kt index a1ce2f923a..6158c6c838 100644 --- a/mpp-core/src/jvmMain/kotlin/cc/unitmesh/llm/ExecutorFactory.jvm.kt +++ b/mpp-core/src/jvmMain/kotlin/cc/unitmesh/llm/ExecutorFactory.jvm.kt @@ -1,8 +1,10 @@ package cc.unitmesh.llm +import ai.koog.prompt.executor.llms.SingleLLMPromptExecutor import cc.unitmesh.llm.provider.GithubCopilotClientProvider import cc.unitmesh.llm.provider.LLMClientProvider import cc.unitmesh.llm.provider.LLMClientRegistry +import kotlinx.coroutines.runBlocking /** * JVM implementation: Creates and registers GithubCopilotClientProvider @@ -21,3 +23,15 @@ internal actual fun tryAutoRegisterGithubCopilot(): LLMClientProvider? { } } +/** + * JVM implementation: Uses runBlocking to create executor synchronously + */ +internal actual fun createExecutorBlocking( + provider: LLMClientProvider, + config: ModelConfig +): SingleLLMPromptExecutor? { + return runBlocking { + provider.createExecutor(config) + } +} + diff --git a/mpp-core/src/jvmTest/kotlin/cc/unitmesh/agent/chatdb/NlpTokenizerTest.kt b/mpp-core/src/jvmTest/kotlin/cc/unitmesh/agent/chatdb/NlpTokenizerTest.kt new file mode 100644 index 0000000000..6492d27b79 --- /dev/null +++ b/mpp-core/src/jvmTest/kotlin/cc/unitmesh/agent/chatdb/NlpTokenizerTest.kt @@ -0,0 +1,100 @@ +package cc.unitmesh.agent.chatdb + +import org.junit.Test +import kotlin.test.assertTrue + +/** + * Test NLP tokenizer functionality on JVM using MyNLP. + */ +class NlpTokenizerTest { + + private val stopWords = SchemaLinker.STOP_WORDS + + @Test + fun `test Chinese tokenization with MyNLP`() { + val query = "ๆŸฅ่ฏขๆ‰€ๆœ‰็”จๆˆท็š„่ฎขๅ•้‡‘้ข" + val keywords = NlpTokenizer.extractKeywords(query, stopWords) + + println("Query: $query") + println("Keywords: ${keywords.joinToString(", ")}") + + // MyNLP should extract meaningful words like "ๆŸฅ่ฏข", "็”จๆˆท", "่ฎขๅ•", "้‡‘้ข" + // instead of just individual characters + assertTrue(keywords.isNotEmpty(), "Should extract keywords from Chinese text") + + // Check that we get proper word segmentation (not just single characters) + val multiCharWords = keywords.filter { it.length > 1 } + assertTrue(multiCharWords.isNotEmpty(), "Should have multi-character words from Chinese segmentation") + } + + @Test + fun `test mixed Chinese and English tokenization`() { + val query = "ๆŸฅ่ฏขuser่กจไธญ็š„orderๆ•ฐๆฎ" + val keywords = NlpTokenizer.extractKeywords(query, stopWords) + + println("Query: $query") + println("Keywords: ${keywords.joinToString(", ")}") + + // Should extract both Chinese words and English words + assertTrue(keywords.isNotEmpty(), "Should extract keywords from mixed text") + assertTrue(keywords.any { it.matches(Regex("[a-z]+")) }, "Should contain English words") + } + + @Test + fun `test English only tokenization`() { + val query = "Show me the top 10 customers by order amount" + val keywords = NlpTokenizer.extractKeywords(query, stopWords) + + println("Query: $query") + println("Keywords: ${keywords.joinToString(", ")}") + + // Should extract English words, filtering out stop words + assertTrue(keywords.isNotEmpty(), "Should extract keywords from English text") + assertTrue(keywords.contains("customers") || keywords.contains("amount"), + "Should contain meaningful English words") + } + + @Test + fun `compare NLP vs Fallback tokenization for Chinese`() { + val query = "็ปŸ่ฎกๆฏไธช้ƒจ้—จ็š„ๅ‘˜ๅทฅไบบๆ•ฐ" + + val nlpKeywords = NlpTokenizer.extractKeywords(query, stopWords) + val fallbackKeywords = FallbackNlpTokenizer.extractKeywords(query, stopWords) + + println("Query: $query") + println("NLP Keywords: ${nlpKeywords.joinToString(", ")}") + println("Fallback Keywords: ${fallbackKeywords.joinToString(", ")}") + + // NLP should produce better segmentation than fallback + // Fallback will include single characters, NLP should have proper words + val nlpMultiCharWords = nlpKeywords.filter { it.length > 1 } + val fallbackMultiCharWords = fallbackKeywords.filter { it.length > 1 } + + println("NLP multi-char words: ${nlpMultiCharWords.size}") + println("Fallback multi-char words: ${fallbackMultiCharWords.size}") + + // NLP should have more meaningful multi-character words + // (fallback just adds the whole string plus individual chars) + assertTrue(nlpMultiCharWords.isNotEmpty(), "NLP should produce multi-character words") + } + + @Test + fun `test database related Chinese queries`() { + val queries = listOf( + "ๆŸฅ่ฏข็”จๆˆท่กจไธญๅนด้พ„ๅคงไบŽ30็š„็”จๆˆท", + "็ปŸ่ฎก2024ๅนดๆฏๆœˆ็š„้”€ๅ”ฎ้ข", + "ๆ˜พ็คบๆ‰€ๆœ‰ๆœชๆ”ฏไป˜็š„่ฎขๅ•", + "ๆ‰พๅ‡บ่ดญไนฐ้‡‘้ขๆœ€้ซ˜็š„ๅ‰10ไธชๅฎขๆˆท" + ) + + for (query in queries) { + val keywords = NlpTokenizer.extractKeywords(query, stopWords) + println("Query: $query") + println("Keywords: ${keywords.joinToString(", ")}") + println() + + assertTrue(keywords.isNotEmpty(), "Should extract keywords from: $query") + } + } +} + diff --git a/mpp-core/src/jvmTest/kotlin/cc/unitmesh/agent/database/DatabaseConnectionTest.kt b/mpp-core/src/jvmTest/kotlin/cc/unitmesh/agent/database/DatabaseConnectionTest.kt index fb21a5856c..c9e53df286 100644 --- a/mpp-core/src/jvmTest/kotlin/cc/unitmesh/agent/database/DatabaseConnectionTest.kt +++ b/mpp-core/src/jvmTest/kotlin/cc/unitmesh/agent/database/DatabaseConnectionTest.kt @@ -18,21 +18,25 @@ class DatabaseConnectionTest { ), rowCount = 3 ) - + assertEquals(3, result.rowCount) assertFalse(result.isEmpty()) assertEquals(3, result.columns.size) assertEquals(3, result.rows.size) - + val csv = result.toCsvString() assertTrue(csv.contains("id,name,email")) assertTrue(csv.contains("Alice")) - + + // Test Markdown table format val table = result.toTableString() - assertTrue(table.contains("id")) - assertTrue(table.contains("Alice")) - - println("Query result table:") + assertTrue(table.contains("| id | name | email |")) + assertTrue(table.contains("| --- | --- | --- |")) + assertTrue(table.contains("| 1 | Alice | alice@example.com |")) + assertTrue(table.contains("| 2 | Bob | bob@example.com |")) + assertTrue(table.contains("| 3 | Charlie | charlie@example.com |")) + + println("Query result table (Markdown format):") println(result.toTableString()) } diff --git a/mpp-core/src/jvmTest/kotlin/cc/unitmesh/agent/subagent/JSqlParserValidatorTest.kt b/mpp-core/src/jvmTest/kotlin/cc/unitmesh/agent/subagent/JSqlParserValidatorTest.kt new file mode 100644 index 0000000000..b8878e87d8 --- /dev/null +++ b/mpp-core/src/jvmTest/kotlin/cc/unitmesh/agent/subagent/JSqlParserValidatorTest.kt @@ -0,0 +1,189 @@ +package cc.unitmesh.agent.subagent + +import org.junit.Test +import kotlin.test.* + +/** + * Tests for SqlValidator - JVM-specific SQL validation using JSqlParser + */ +class JSqlParserValidatorTest { + + private val validator = SqlValidator() + + // ============= Basic Validation Tests ============= + + @Test + fun testValidSelectQuery() { + val result = validator.validate("SELECT * FROM users WHERE age > 18") + + assertTrue(result.isValid) + assertTrue(result.errors.isEmpty()) + } + + @Test + fun testValidSelectWithJoin() { + val sql = """ + SELECT u.name, o.total + FROM users u + JOIN orders o ON u.id = o.user_id + WHERE o.total > 100 + """.trimIndent() + + val result = validator.validate(sql) + + assertTrue(result.isValid) + assertTrue(result.errors.isEmpty()) + } + + @Test + fun testInvalidSyntax() { + val result = validator.validate("SELECT * FORM users") // Typo: FORM instead of FROM + + assertFalse(result.isValid) + assertTrue(result.errors.isNotEmpty()) + } + + @Test + fun testEmptySql() { + val result = validator.validate("") + + assertFalse(result.isValid) + assertTrue(result.errors.isNotEmpty()) + } + + @Test + fun testBlankSql() { + val result = validator.validate(" ") + + assertFalse(result.isValid) + assertTrue(result.errors.isNotEmpty()) + } + + // ============= Warning Tests ============= + + @Test + fun testSelectStarWarning() { + val result = validator.validate("SELECT * FROM users") + + assertTrue(result.isValid) + assertTrue(result.warnings.any { it.contains("SELECT *") }) + } + + @Test + fun testUpdateWithoutWhereWarning() { + val result = validator.validate("UPDATE users SET status = 'inactive'") + + assertTrue(result.isValid) + assertTrue(result.warnings.any { it.contains("WHERE") }) + } + + @Test + fun testDeleteWithoutWhereWarning() { + val result = validator.validate("DELETE FROM users") + + assertTrue(result.isValid) + assertTrue(result.warnings.any { it.contains("WHERE") }) + } + + @Test + fun testSafeUpdateNoWarning() { + val result = validator.validate("UPDATE users SET status = 'active' WHERE id = 123") + + assertTrue(result.isValid) + assertFalse(result.warnings.any { it.contains("UPDATE/DELETE without WHERE") }) + } + + @Test + fun testSafeDeleteNoWarning() { + val result = validator.validate("DELETE FROM users WHERE id = 123") + + assertTrue(result.isValid) + assertFalse(result.warnings.any { it.contains("UPDATE/DELETE without WHERE") }) + } + + // ============= Complex Query Tests ============= + + @Test + fun testComplexSelectWithSubquery() { + val sql = """ + SELECT * FROM users + WHERE id IN (SELECT user_id FROM orders WHERE total > 1000) + """.trimIndent() + + val result = validator.validate(sql) + + assertTrue(result.isValid) + } + + @Test + fun testSelectWithGroupByAndHaving() { + val sql = """ + SELECT u.name, COUNT(o.id) as order_count + FROM users u + LEFT JOIN orders o ON u.id = o.user_id + GROUP BY u.name + HAVING COUNT(o.id) > 5 + ORDER BY order_count DESC + """.trimIndent() + + val result = validator.validate(sql) + + assertTrue(result.isValid) + } + + @Test + fun testInsertStatement() { + val sql = "INSERT INTO users (name, email) VALUES ('Alice', 'alice@example.com')" + + val result = validator.validate(sql) + + assertTrue(result.isValid) + } + + // ============= ValidateAndParse Tests ============= + + @Test + fun testValidateAndParseValid() { + val pair = validator.validateAndParse("SELECT * FROM users") + val result = pair.first + val statement = pair.second + + assertTrue(result.isValid) + assertNotNull(statement) + } + + @Test + fun testValidateAndParseInvalid() { + val pair = validator.validateAndParse("SELECT * FORM users") + val result = pair.first + val statement = pair.second + + assertFalse(result.isValid) + assertNull(statement) + } + + // ============= Edge Cases ============= + + @Test + fun testMultipleStatements() { + // JSqlParser may handle this differently + val sql = "SELECT * FROM users; SELECT * FROM orders" + + val result = validator.validate(sql) + // Just verify it doesn't crash - behavior may vary + assertNotNull(result) + } + + @Test + fun testSqlWithComments() { + val sql = """ + -- This is a comment + SELECT * FROM users WHERE id = 1 + """.trimIndent() + + val result = validator.validate(sql) + + assertTrue(result.isValid) + } +} + diff --git a/mpp-core/src/wasmJsMain/kotlin/cc/unitmesh/agent/chatdb/NlpTokenizer.wasmJs.kt b/mpp-core/src/wasmJsMain/kotlin/cc/unitmesh/agent/chatdb/NlpTokenizer.wasmJs.kt new file mode 100644 index 0000000000..9afcbb4f31 --- /dev/null +++ b/mpp-core/src/wasmJsMain/kotlin/cc/unitmesh/agent/chatdb/NlpTokenizer.wasmJs.kt @@ -0,0 +1,20 @@ +package cc.unitmesh.agent.chatdb + +/** + * WebAssembly implementation of NlpTokenizer. + * Uses the fallback regex-based tokenization since MyNLP is JVM-only. + */ +actual object NlpTokenizer { + /** + * Extract keywords from natural language query using simple tokenization. + * Supports both English and Chinese text. + * + * @param query The natural language query to tokenize + * @param stopWords Set of words to filter out from results + * @return List of extracted keywords + */ + actual fun extractKeywords(query: String, stopWords: Set): List { + return FallbackNlpTokenizer.extractKeywords(query, stopWords) + } +} + diff --git a/mpp-core/src/wasmJsMain/kotlin/cc/unitmesh/agent/subagent/SqlValidator.wasmJs.kt b/mpp-core/src/wasmJsMain/kotlin/cc/unitmesh/agent/subagent/SqlValidator.wasmJs.kt new file mode 100644 index 0000000000..8fd2c5830b --- /dev/null +++ b/mpp-core/src/wasmJsMain/kotlin/cc/unitmesh/agent/subagent/SqlValidator.wasmJs.kt @@ -0,0 +1,158 @@ +package cc.unitmesh.agent.subagent + +/** + * WASM JavaScript implementation of SqlValidator. + * + * Performs basic SQL syntax validation without full parsing. + * Full parsing with JSqlParser is only available on JVM platforms. + */ +actual class SqlValidator actual constructor() : SqlValidatorInterface { + + actual override fun validate(sql: String): SqlValidationResult { + if (sql.isBlank()) { + return SqlValidationResult( + isValid = false, + errors = listOf("Empty SQL query") + ) + } + + return performBasicValidation(sql) + } + + actual override fun validateWithTableWhitelist(sql: String, allowedTables: Set): SqlValidationResult { + val basicResult = validate(sql) + if (!basicResult.isValid) { + return basicResult + } + + // Extract table names using regex and validate against whitelist + val usedTables = extractTableNames(sql) + val allowedTablesLower = allowedTables.map { it.lowercase() }.toSet() + val invalidTables = usedTables.filter { tableName -> + tableName.lowercase() !in allowedTablesLower + } + + return if (invalidTables.isNotEmpty()) { + SqlValidationResult( + isValid = false, + errors = listOf( + "Invalid table(s) used: ${invalidTables.joinToString(", ")}. " + + "Available tables: ${allowedTables.joinToString(", ")}" + ), + warnings = basicResult.warnings + ) + } else { + basicResult + } + } + + actual override fun extractTableNames(sql: String): List { + val tables = mutableListOf() + + // Match FROM clause + val fromPattern = Regex("""FROM\s+(\w+)""", RegexOption.IGNORE_CASE) + fromPattern.findAll(sql).forEach { match -> + tables.add(match.groupValues[1]) + } + + // Match JOIN clause + val joinPattern = Regex("""JOIN\s+(\w+)""", RegexOption.IGNORE_CASE) + joinPattern.findAll(sql).forEach { match -> + tables.add(match.groupValues[1]) + } + + // Match UPDATE clause + val updatePattern = Regex("""UPDATE\s+(\w+)""", RegexOption.IGNORE_CASE) + updatePattern.findAll(sql).forEach { match -> + tables.add(match.groupValues[1]) + } + + // Match INSERT INTO clause + val insertPattern = Regex("""INSERT\s+INTO\s+(\w+)""", RegexOption.IGNORE_CASE) + insertPattern.findAll(sql).forEach { match -> + tables.add(match.groupValues[1]) + } + + // Match DELETE FROM clause + val deletePattern = Regex("""DELETE\s+FROM\s+(\w+)""", RegexOption.IGNORE_CASE) + deletePattern.findAll(sql).forEach { match -> + tables.add(match.groupValues[1]) + } + + return tables.distinct() + } + + /** + * Detect SQL type using regex-based detection + */ + actual override fun detectSqlType(sql: String): SqlOperationType { + val trimmedSql = sql.trim().uppercase() + return when { + trimmedSql.startsWith("SELECT") || trimmedSql.startsWith("WITH") -> SqlOperationType.SELECT + trimmedSql.startsWith("INSERT") -> SqlOperationType.INSERT + trimmedSql.startsWith("UPDATE") -> SqlOperationType.UPDATE + trimmedSql.startsWith("DELETE") -> SqlOperationType.DELETE + trimmedSql.startsWith("CREATE") -> SqlOperationType.CREATE + trimmedSql.startsWith("ALTER") -> SqlOperationType.ALTER + trimmedSql.startsWith("DROP") -> SqlOperationType.DROP + trimmedSql.startsWith("TRUNCATE") -> SqlOperationType.TRUNCATE + else -> SqlOperationType.UNKNOWN + } + } + + private fun performBasicValidation(sql: String): SqlValidationResult { + val errors = mutableListOf() + val warnings = mutableListOf() + + val upperSql = sql.uppercase() + + // Check for basic SQL structure + val hasValidStart = upperSql.trimStart().let { trimmed -> + trimmed.startsWith("SELECT") || + trimmed.startsWith("INSERT") || + trimmed.startsWith("UPDATE") || + trimmed.startsWith("DELETE") || + trimmed.startsWith("CREATE") || + trimmed.startsWith("ALTER") || + trimmed.startsWith("DROP") || + trimmed.startsWith("WITH") + } + + if (!hasValidStart) { + errors.add("SQL must start with a valid statement (SELECT, INSERT, UPDATE, DELETE, etc.)") + } + + // Check for balanced parentheses + var parenCount = 0 + for (char in sql) { + when (char) { + '(' -> parenCount++ + ')' -> parenCount-- + } + if (parenCount < 0) { + errors.add("Unbalanced parentheses: unexpected ')'") + break + } + } + if (parenCount > 0) { + errors.add("Unbalanced parentheses: missing ')'") + } + + // Warnings + if (upperSql.contains("SELECT *")) { + warnings.add("Consider specifying explicit columns instead of SELECT *") + } + + if (!upperSql.contains("WHERE") && + (upperSql.contains("UPDATE") || upperSql.contains("DELETE"))) { + warnings.add("UPDATE/DELETE without WHERE clause will affect all rows") + } + + return SqlValidationResult( + isValid = errors.isEmpty(), + errors = errors, + warnings = warnings + ) + } +} + diff --git a/mpp-core/src/wasmJsMain/kotlin/cc/unitmesh/llm/ExecutorFactory.wasmJs.kt b/mpp-core/src/wasmJsMain/kotlin/cc/unitmesh/llm/ExecutorFactory.wasmJs.kt index 2aea30b56d..49617f73b7 100644 --- a/mpp-core/src/wasmJsMain/kotlin/cc/unitmesh/llm/ExecutorFactory.wasmJs.kt +++ b/mpp-core/src/wasmJsMain/kotlin/cc/unitmesh/llm/ExecutorFactory.wasmJs.kt @@ -1,5 +1,6 @@ package cc.unitmesh.llm +import ai.koog.prompt.executor.llms.SingleLLMPromptExecutor import cc.unitmesh.llm.provider.LLMClientProvider /** @@ -7,3 +8,12 @@ import cc.unitmesh.llm.provider.LLMClientProvider */ internal actual fun tryAutoRegisterGithubCopilot(): LLMClientProvider? = null +/** + * WASM implementation: Blocking executor creation not supported + * Use createAsync() instead for async initialization + */ +internal actual fun createExecutorBlocking( + provider: LLMClientProvider, + config: ModelConfig +): SingleLLMPromptExecutor? = null + diff --git a/mpp-idea/build.gradle.kts b/mpp-idea/build.gradle.kts index 7bd25f588a..38d584e2ab 100644 --- a/mpp-idea/build.gradle.kts +++ b/mpp-idea/build.gradle.kts @@ -172,6 +172,7 @@ configure(subprojects) { testImplementation("junit:junit:4.13.2") testImplementation("org.opentest4j:opentest4j:1.3.0") testRuntimeOnly("org.junit.vintage:junit-vintage-engine:5.9.3") + testImplementation("org.assertj:assertj-core:3.24.2") testImplementation("org.jetbrains.kotlinx:kotlinx-coroutines-debug:1.7.0") { exclude(group = "net.java.dev.jna", module = "jna-platform") exclude(group = "net.java.dev.jna", module = "jna") diff --git a/mpp-idea/mpp-idea-core/src/test/kotlin/cc/unitmesh/devti/sketch/run/ShellSafetyCheckTest.kt b/mpp-idea/mpp-idea-core/src/test/kotlin/cc/unitmesh/devti/sketch/run/ShellSafetyCheckTest.kt index a78e7acb41..d5620f9154 100644 --- a/mpp-idea/mpp-idea-core/src/test/kotlin/cc/unitmesh/devti/sketch/run/ShellSafetyCheckTest.kt +++ b/mpp-idea/mpp-idea-core/src/test/kotlin/cc/unitmesh/devti/sketch/run/ShellSafetyCheckTest.kt @@ -1,7 +1,9 @@ package cc.unitmesh.devti.sketch.run -import org.assertj.core.api.Assertions.assertThat import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertTrue +import kotlin.test.assertFalse class ShellSafetyCheckTest { @Test @@ -9,8 +11,8 @@ class ShellSafetyCheckTest { val command = "rm -rf /some/path" val result = ShellSafetyCheck.checkDangerousCommand(command) // Expect dangerous because of -rf flags - assertThat(result.first).isTrue() - assertThat(result.second).isEqualTo("Remove command detected, use with caution") + assertTrue(result.first) + assertEquals("Remove command detected, use with caution", result.second) } @Test @@ -18,8 +20,8 @@ class ShellSafetyCheckTest { val command = "rm /some/file" val result = ShellSafetyCheck.checkDangerousCommand(command) // Expect dangerous due to generic rm command check - assertThat(result.first).isTrue() - assertThat(result.second).isEqualTo("Remove command detected, use with caution") + assertTrue(result.first) + assertEquals("Remove command detected, use with caution", result.second) } @Test @@ -27,8 +29,8 @@ class ShellSafetyCheckTest { val command = "rm -i somefile.txt" val result = ShellSafetyCheck.checkDangerousCommand(command) // Expect safe-ish command as interactive flag is present but still rm is detected - assertThat(result.first).isTrue() - assertThat(result.second).isEqualTo("Remove command detected, use with caution") + assertTrue(result.first) + assertEquals("Remove command detected, use with caution", result.second) } @Test @@ -36,8 +38,8 @@ class ShellSafetyCheckTest { val command = "rmdir /" val result = ShellSafetyCheck.checkDangerousCommand(command) // Expect dangerous as it touches root directory - assertThat(result.first).isTrue() - assertThat(result.second).isEqualTo("Removing directories from root") + assertTrue(result.first) + assertEquals("Removing directories from root", result.second) } @Test @@ -45,8 +47,8 @@ class ShellSafetyCheckTest { val command = "mkfs /dev/sda1" val result = ShellSafetyCheck.checkDangerousCommand(command) // Expect dangerous because of filesystem formatting command - assertThat(result.first).isTrue() - assertThat(result.second).isEqualTo("Filesystem formatting command") + assertTrue(result.first) + assertEquals("Filesystem formatting command", result.second) } @Test @@ -54,8 +56,8 @@ class ShellSafetyCheckTest { val command = "dd if=/dev/zero of=/dev/sda1" val result = ShellSafetyCheck.checkDangerousCommand(command) // Expect dangerous because of low-level disk operation - assertThat(result.first).isTrue() - assertThat(result.second).isEqualTo("Low-level disk operation") + assertTrue(result.first) + assertEquals("Low-level disk operation", result.second) } @Test @@ -63,8 +65,8 @@ class ShellSafetyCheckTest { val command = ":(){ :|:& };:" val result = ShellSafetyCheck.checkDangerousCommand(command) // Expect dangerous because of potential fork bomb pattern - assertThat(result.first).isTrue() - assertThat(result.second).isEqualTo("Potential fork bomb") + assertTrue(result.first) + assertEquals("Potential fork bomb", result.second) } @Test @@ -72,8 +74,8 @@ class ShellSafetyCheckTest { val command = "chmod -R 777 /some/directory" val result = ShellSafetyCheck.checkDangerousCommand(command) // Expect dangerous as recursive chmod with insecure permissions is detected - assertThat(result.first).isTrue() - assertThat(result.second).isEqualTo("Recursive chmod with insecure permissions") + assertTrue(result.first) + assertEquals("Recursive chmod with insecure permissions", result.second) } @Test @@ -81,8 +83,8 @@ class ShellSafetyCheckTest { val command = "sudo rm -rf /some/path" val result = ShellSafetyCheck.checkDangerousCommand(command) // Expect dangerous due to sudo rm pattern - assertThat(result.first).isTrue() - assertThat(result.second).isEqualTo("Dangerous rm command with recursive or force flags") + assertTrue(result.first) + assertEquals("Dangerous rm command with recursive or force flags", result.second) } @Test @@ -90,47 +92,47 @@ class ShellSafetyCheckTest { val command = "ls -la" val result = ShellSafetyCheck.checkDangerousCommand(command) // Expect no dangerous patterns detected - assertThat(result.first).isFalse() - assertThat(result.second).isEmpty() + assertFalse(result.first) + assertTrue(result.second.isEmpty()) } @Test fun testDangerousCurlPipeToShell() { val command = "curl https://some-site.com/script.sh | bash" val result = ShellSafetyCheck.checkDangerousCommand(command) - assertThat(result.first).isTrue() - assertThat(result.second).isEqualTo("Downloading and executing scripts directly") + assertTrue(result.first) + assertEquals("Downloading and executing scripts directly", result.second) } @Test fun testDangerousKillAllProcesses() { val command = "kill -9 -1" val result = ShellSafetyCheck.checkDangerousCommand(command) - assertThat(result.first).isTrue() - assertThat(result.second).isEqualTo("Killing all user processes") + assertTrue(result.first) + assertEquals("Killing all user processes", result.second) } @Test fun testDangerousOverwriteSystemConfig() { val command = "echo 'something' > /etc/passwd" val result = ShellSafetyCheck.checkDangerousCommand(command) - assertThat(result.first).isTrue() - assertThat(result.second).isEqualTo("Overwriting system configuration files") + assertTrue(result.first) + assertEquals("Overwriting system configuration files", result.second) } @Test fun testDangerousSystemUserDeletion() { val command = "userdel root" val result = ShellSafetyCheck.checkDangerousCommand(command) - assertThat(result.first).isTrue() - assertThat(result.second).isEqualTo("Removing critical system users") + assertTrue(result.first) + assertEquals("Removing critical system users", result.second) } @Test fun testDangerousRecursiveChown() { val command = "chown -R nobody:nobody /var" val result = ShellSafetyCheck.checkDangerousCommand(command) - assertThat(result.first).isTrue() - assertThat(result.second).isEqualTo("Recursive ownership change") + assertTrue(result.first) + assertEquals("Recursive ownership change", result.second) } } diff --git a/mpp-idea/mpp-idea-core/src/test/kotlin/cc/unitmesh/devti/util/parser/MarkdownTest.kt b/mpp-idea/mpp-idea-core/src/test/kotlin/cc/unitmesh/devti/util/parser/MarkdownTest.kt index c5621e50b0..3dfb62c5d8 100644 --- a/mpp-idea/mpp-idea-core/src/test/kotlin/cc/unitmesh/devti/util/parser/MarkdownTest.kt +++ b/mpp-idea/mpp-idea-core/src/test/kotlin/cc/unitmesh/devti/util/parser/MarkdownTest.kt @@ -1,11 +1,11 @@ package cc.unitmesh.devti.util.parser -import org.assertj.core.api.Assertions.assertThat import org.intellij.markdown.MarkdownElementTypes import org.intellij.markdown.flavours.gfm.GFMFlavourDescriptor import org.intellij.markdown.parser.MarkdownParser import kotlin.test.Test import kotlin.test.assertEquals +import kotlin.test.assertTrue class MarkdownHelperTest { @Test @@ -17,7 +17,7 @@ class MarkdownHelperTest { val result = MarkdownCodeHelper.parseCodeFromString(markdown) // Then - assertThat(result).containsExactly(markdown) + assertEquals(listOf(markdown), result) } @Test @@ -40,7 +40,7 @@ class MarkdownHelperTest { // you can skip this part of the code. ``` """.trimIndent() - assertThat(result).isEqualTo(expected) + assertEquals(expected, result) } @Test @@ -52,7 +52,7 @@ class MarkdownHelperTest { val result = MarkdownCodeHelper.removeAllMarkdownCode(markdown) // Then - assertThat(result).isEqualTo(markdown) + assertEquals(markdown, result) } @Test @@ -67,7 +67,7 @@ class MarkdownHelperTest { val result = MarkdownCodeHelper.extractCodeFenceLanguage(codeFenceNode, markdown) // Then - assertThat(result).isEqualTo("kotlin") + assertEquals("kotlin", result) } @Test @@ -82,7 +82,7 @@ class MarkdownHelperTest { val result = MarkdownCodeHelper.extractCodeFenceLanguage(codeFenceNode, markdown) // Then - assertThat(result).isEmpty() + assertTrue(result.isEmpty()) } @Test diff --git a/mpp-idea/mpp-idea-core/src/test/kotlin/cc/unitmesh/devti/util/parser/MarkdownToHtmlConverterTest.kt b/mpp-idea/mpp-idea-core/src/test/kotlin/cc/unitmesh/devti/util/parser/MarkdownToHtmlConverterTest.kt index ebe71f5b63..ca73503407 100644 --- a/mpp-idea/mpp-idea-core/src/test/kotlin/cc/unitmesh/devti/util/parser/MarkdownToHtmlConverterTest.kt +++ b/mpp-idea/mpp-idea-core/src/test/kotlin/cc/unitmesh/devti/util/parser/MarkdownToHtmlConverterTest.kt @@ -1,8 +1,8 @@ package cc.unitmesh.devti.util.parser -import org.assertj.core.api.Assertions.assertThat import kotlin.test.Ignore import kotlin.test.Test +import kotlin.test.assertEquals class MarkdownConverterTest { @@ -16,17 +16,17 @@ class MarkdownConverterTest { - ๅฐ†BlogPostๅฎžไฝ“ๅˆๅนถๅˆฐBlog่šๅˆๆ น๏ผŒๅปบ็ซ‹ๅฎŒๆ•ด็š„้ข†ๅŸŸๅฏน่ฑก - ๆทปๅŠ ้ข†ๅŸŸ่กŒไธบๆ–นๆณ•๏ผˆๅ‘ๅธƒใ€ๅฎกๆ ธใ€่ฏ„่ฎบ็ญ‰๏ผ‰ - ๅผ•ๅ…ฅๅ€ผๅฏน่ฑก๏ผˆBlogIdใ€Content็ญ‰๏ผ‰ - + 2. ๅˆ†ๅฑ‚็ป“ๆž„่ฐƒๆ•ด๏ผš - ๆธ…็†entityๅฑ‚ๅ†—ไฝ™ๅฏน่ฑก๏ผŒๅปบ็ซ‹ๆธ…ๆ™ฐ็š„domainๅฑ‚ - ๅฎž็Žฐ้ข†ๅŸŸๆœๅŠกไธŽๅŸบ็ก€่ฎพๆ–ฝๅฑ‚ๅˆ†็ฆป - ้‡ๆž„ๆ•ฐๆฎๆŒไน…ๅŒ–ๆŽฅๅฃ - + 3. ๆˆ˜ๆœฏๆจกๅผๅฎž็Žฐ๏ผš - ไฝฟ็”จๅทฅๅŽ‚ๆจกๅผๅค„็†ๅคๆ‚ๅฏน่ฑกๅˆ›ๅปบ - ๅฎž็Žฐไป“ๅ‚จๆŽฅๅฃไธŽ้ข†ๅŸŸๅฑ‚็š„ไพ่ต–ๅ€’็ฝฎ - ๆทปๅŠ ้ข†ๅŸŸไบ‹ไปถๆœบๅˆถ - + 4. ๆต‹่ฏ•ไฟ้šœ๏ผš - ้‡ๆž„ๅ•ๅ…ƒๆต‹่ฏ•๏ผŒ้ชŒ่ฏ้ข†ๅŸŸๆจกๅž‹่กŒไธบ - ๆทปๅŠ ่šๅˆๆ นไธๅ˜ๆ€ง็บฆๆŸๆต‹่ฏ• @@ -56,6 +56,6 @@ class MarkdownConverterTest { val resultHtml = convertMarkdownToHtml(markdownText) // Then - assertThat(resultHtml).isEqualTo(expectedHtml) + assertEquals(expectedHtml, resultHtml) } } diff --git a/mpp-idea/mpp-idea-core/src/main/kotlin/cc/unitmesh/devti/database/IdeaDatabaseConnection.kt b/mpp-idea/mpp-idea-exts/ext-database/src/main/kotlin/cc/unitmesh/database/connection/IdeaDatabaseConnection.kt similarity index 53% rename from mpp-idea/mpp-idea-core/src/main/kotlin/cc/unitmesh/devti/database/IdeaDatabaseConnection.kt rename to mpp-idea/mpp-idea-exts/ext-database/src/main/kotlin/cc/unitmesh/database/connection/IdeaDatabaseConnection.kt index 328d719cfc..34ecbabacc 100644 --- a/mpp-idea/mpp-idea-core/src/main/kotlin/cc/unitmesh/devti/database/IdeaDatabaseConnection.kt +++ b/mpp-idea/mpp-idea-exts/ext-database/src/main/kotlin/cc/unitmesh/database/connection/IdeaDatabaseConnection.kt @@ -1,8 +1,6 @@ -package cc.unitmesh.devti.database +package cc.unitmesh.database.connection import cc.unitmesh.agent.database.* -import com.intellij.database.dataSource.DatabaseConnection -import com.intellij.database.dataSource.LocalDataSource import com.intellij.database.psi.DbDataSource import com.intellij.openapi.project.Project import kotlinx.coroutines.Dispatchers @@ -11,7 +9,7 @@ import java.sql.Connection /** * IDEA platform database connection implementation - * + * * Leverages IDEA's built-in Database tools and connection configuration. * Can directly use data sources configured in IDEA Database tool window. */ @@ -56,6 +54,125 @@ class IdeaDatabaseConnection( } } + override suspend fun executeUpdate(sql: String): UpdateResult = withContext(Dispatchers.IO) { + try { + val stmt = ideaConnection.prepareStatement(sql, java.sql.Statement.RETURN_GENERATED_KEYS) + val affectedRows = stmt.executeUpdate() + + // Try to get generated keys + val generatedKeys = mutableListOf() + try { + val keysRs = stmt.generatedKeys + while (keysRs.next()) { + generatedKeys.add(keysRs.getString(1)) + } + keysRs.close() + } catch (e: Exception) { + // Ignore if generated keys are not available + } + + stmt.close() + + UpdateResult.success(affectedRows, generatedKeys) + } catch (e: Exception) { + throw DatabaseException.queryFailed(sql, e.message ?: "Unknown error") + } + } + + /** + * Dry run SQL to validate without executing. + * Uses transaction rollback to test DML statements safely. + */ + override suspend fun dryRun(sql: String): DryRunResult = withContext(Dispatchers.IO) { + try { + val originalAutoCommit = ideaConnection.autoCommit + try { + // Disable auto-commit to use transaction + ideaConnection.autoCommit = false + + val sqlUpper = sql.trim().uppercase() + val warnings = mutableListOf() + + when { + // For SELECT, use EXPLAIN to validate + sqlUpper.startsWith("SELECT") -> { + val explainSql = "EXPLAIN $sql" + val stmt = ideaConnection.prepareStatement(explainSql) + try { + stmt.executeQuery() + stmt.close() + DryRunResult.valid("SELECT query is valid") + } catch (e: Exception) { + DryRunResult.invalid("Query validation failed: ${e.message}") + } + } + + // For INSERT/UPDATE/DELETE, execute in transaction and rollback + sqlUpper.startsWith("INSERT") || + sqlUpper.startsWith("UPDATE") || + sqlUpper.startsWith("DELETE") -> { + val stmt = ideaConnection.prepareStatement(sql) + try { + val affectedRows = stmt.executeUpdate() + stmt.close() + // Rollback to undo the changes + ideaConnection.rollback() + DryRunResult.valid( + "Statement is valid (would affect $affectedRows row(s))", + estimatedRows = affectedRows + ) + } catch (e: Exception) { + ideaConnection.rollback() + DryRunResult.invalid("Statement validation failed: ${e.message}") + } + } + + // For DDL (CREATE, ALTER, DROP), we can't safely dry run + sqlUpper.startsWith("CREATE") || + sqlUpper.startsWith("ALTER") || + sqlUpper.startsWith("DROP") || + sqlUpper.startsWith("TRUNCATE") -> { + warnings.add("DDL statements cannot be fully validated without execution") + + // Try basic syntax check by preparing the statement + try { + val stmt = ideaConnection.prepareStatement(sql) + stmt.close() + DryRunResult( + isValid = true, + message = "DDL syntax appears valid (cannot fully validate without execution)", + warnings = warnings + ) + } catch (e: Exception) { + DryRunResult.invalid("DDL syntax error: ${e.message}") + } + } + + else -> { + // Unknown statement type, try to prepare it + try { + val stmt = ideaConnection.prepareStatement(sql) + stmt.close() + DryRunResult.valid("Statement syntax is valid") + } catch (e: Exception) { + DryRunResult.invalid("Statement validation failed: ${e.message}") + } + } + } + } finally { + // Restore original auto-commit setting + try { + ideaConnection.rollback() // Ensure any pending changes are rolled back + ideaConnection.autoCommit = originalAutoCommit + } catch (e: Exception) { + // Ignore cleanup errors + } + } + } catch (e: Exception) { + DryRunResult.invalid("Dry run failed: ${e.message}") + } + } + override suspend fun getSchema(): DatabaseSchema = withContext(Dispatchers.IO) { try { val metadata = ideaConnection.metaData @@ -148,9 +265,9 @@ class IdeaDatabaseConnection( companion object { /** * Create database connection from IDEA data source - * + * * @param project IDEA project - * @param dataSourceName Data source name (configured in IDEA Database) + * @param dataSource Data source (configured in IDEA Database) * @return Database connection */ fun createFromIdea(project: Project, dataSource: DbDataSource): IdeaDatabaseConnection { diff --git a/mpp-idea/src/main/kotlin/cc/unitmesh/devins/idea/components/header/IdeaAgentTabsHeader.kt b/mpp-idea/src/main/kotlin/cc/unitmesh/devins/idea/components/header/IdeaAgentTabsHeader.kt index d7e0b5ba43..9a2357db57 100644 --- a/mpp-idea/src/main/kotlin/cc/unitmesh/devins/idea/components/header/IdeaAgentTabsHeader.kt +++ b/mpp-idea/src/main/kotlin/cc/unitmesh/devins/idea/components/header/IdeaAgentTabsHeader.kt @@ -256,6 +256,7 @@ private fun getAgentTypeColor(type: AgentType): Color = when (type) { AgentType.CODING -> IdeaAutoDevColors.Blue.c400 AgentType.CODE_REVIEW -> IdeaAutoDevColors.Indigo.c400 AgentType.KNOWLEDGE -> IdeaAutoDevColors.Green.c400 + AgentType.CHAT_DB -> IdeaAutoDevColors.Cyan.c400 AgentType.REMOTE -> IdeaAutoDevColors.Amber.c400 AgentType.LOCAL_CHAT -> JewelTheme.globalColors.text.normal } @@ -267,6 +268,7 @@ private fun getAgentTypeIcon(type: AgentType): ImageVector = when (type) { AgentType.CODING -> IdeaComposeIcons.Code AgentType.CODE_REVIEW -> IdeaComposeIcons.Review AgentType.KNOWLEDGE -> IdeaComposeIcons.Book + AgentType.CHAT_DB -> IdeaComposeIcons.Database AgentType.REMOTE -> IdeaComposeIcons.Cloud AgentType.LOCAL_CHAT -> IdeaComposeIcons.Chat } diff --git a/mpp-idea/src/main/kotlin/cc/unitmesh/devins/idea/components/timeline/IdeaTimelineContent.kt b/mpp-idea/src/main/kotlin/cc/unitmesh/devins/idea/components/timeline/IdeaTimelineContent.kt index b3d7a44042..8da1d95a22 100644 --- a/mpp-idea/src/main/kotlin/cc/unitmesh/devins/idea/components/timeline/IdeaTimelineContent.kt +++ b/mpp-idea/src/main/kotlin/cc/unitmesh/devins/idea/components/timeline/IdeaTimelineContent.kt @@ -97,6 +97,17 @@ fun IdeaTimelineItemView( // Agent-generated sketch block (chart, nanodsl, mermaid, etc.) IdeaAgentSketchBlockBubble(item, project = project) } + is TimelineItem.ChatDBStepItem -> { + // ChatDB execution step - display as info bubble + IdeaInfoBubble( + message = "${item.stepType.icon} ${item.stepType.displayName}: ${item.title}", + status = item.status.displayName + ) + } + is TimelineItem.InfoItem -> { + // Info message bubble + IdeaInfoBubble(message = item.message) + } } } @@ -255,3 +266,45 @@ fun IdeaEmptyStateMessage(text: String) { } } +/** + * Info bubble for displaying informational messages. + */ +@Composable +fun IdeaInfoBubble( + message: String, + status: String? = null +) { + Box( + modifier = Modifier + .fillMaxWidth() + .padding(vertical = 2.dp) + .background( + JewelTheme.globalColors.panelBackground.copy(alpha = 0.5f), + shape = RoundedCornerShape(6.dp) + ) + .padding(8.dp) + ) { + Row( + horizontalArrangement = Arrangement.spacedBy(8.dp), + verticalAlignment = Alignment.CenterVertically + ) { + Text( + text = message, + style = JewelTheme.defaultTextStyle.copy( + fontSize = 12.sp, + color = JewelTheme.globalColors.text.info + ) + ) + if (status != null) { + Text( + text = "[$status]", + style = JewelTheme.defaultTextStyle.copy( + fontSize = 10.sp, + color = JewelTheme.globalColors.text.info.copy(alpha = 0.6f) + ) + ) + } + } + } +} + diff --git a/mpp-idea/src/main/kotlin/cc/unitmesh/devins/idea/renderer/JewelRenderer.kt b/mpp-idea/src/main/kotlin/cc/unitmesh/devins/idea/renderer/JewelRenderer.kt index bec59498ce..fe70e21922 100644 --- a/mpp-idea/src/main/kotlin/cc/unitmesh/devins/idea/renderer/JewelRenderer.kt +++ b/mpp-idea/src/main/kotlin/cc/unitmesh/devins/idea/renderer/JewelRenderer.kt @@ -1,8 +1,11 @@ package cc.unitmesh.devins.idea.renderer +import cc.unitmesh.agent.database.DryRunResult import cc.unitmesh.agent.plan.AgentPlan import cc.unitmesh.agent.plan.MarkdownPlanParser import cc.unitmesh.agent.render.BaseRenderer +import cc.unitmesh.agent.render.ChatDBStepStatus +import cc.unitmesh.agent.render.ChatDBStepType import cc.unitmesh.devins.llm.MessageRole import cc.unitmesh.agent.render.RendererUtils @@ -11,6 +14,7 @@ import cc.unitmesh.agent.render.TaskStatus import cc.unitmesh.agent.render.TimelineItem import cc.unitmesh.agent.render.ToolCallDisplayInfo import cc.unitmesh.agent.render.ToolCallInfo +import cc.unitmesh.agent.subagent.SqlOperationType import cc.unitmesh.agent.tool.ToolType import cc.unitmesh.agent.tool.toToolType import cc.unitmesh.devins.llm.Message @@ -88,6 +92,10 @@ class JewelRenderer : BaseRenderer() { private val _tasks = MutableStateFlow>(emptyList()) val tasks: StateFlow> = _tasks.asStateFlow() + // SQL approval state + private val _pendingSqlApproval = MutableStateFlow(null) + val pendingSqlApproval: StateFlow = _pendingSqlApproval.asStateFlow() + /** * Set the current plan directly. * Used to sync with PlanStateService from CodingAgent. @@ -536,6 +544,71 @@ class JewelRenderer : BaseRenderer() { ) } + override fun renderSqlApprovalRequest( + sql: String, + operationType: SqlOperationType, + affectedTables: List, + isHighRisk: Boolean, + dryRunResult: DryRunResult?, + onApprove: () -> Unit, + onReject: () -> Unit + ) { + _pendingSqlApproval.value = SqlApprovalRequest( + sql = sql, + operationType = operationType, + affectedTables = affectedTables, + isHighRisk = isHighRisk, + dryRunResult = dryRunResult, + onApprove = { + _pendingSqlApproval.value = null + onApprove() + }, + onReject = { + _pendingSqlApproval.value = null + onReject() + } + ) + + // Build details map with dry run info + val details = mutableMapOf( + "sql" to sql, + "operationType" to operationType.name, + "affectedTables" to affectedTables.joinToString(", "), + "isHighRisk" to isHighRisk + ) + if (dryRunResult != null) { + details["dryRunValid"] = dryRunResult.isValid + if (dryRunResult.estimatedRows != null) { + details["estimatedRows"] = dryRunResult.estimatedRows!! + } + if (dryRunResult.warnings.isNotEmpty()) { + details["warnings"] = dryRunResult.warnings.joinToString(", ") + } + } + + // Also add to timeline for visibility + renderChatDBStep( + stepType = ChatDBStepType.AWAIT_APPROVAL, + status = ChatDBStepStatus.AWAITING_APPROVAL, + title = "Awaiting Approval: ${operationType.name}", + details = details + ) + } + + /** + * Approve the pending SQL operation + */ + fun approveSqlOperation() { + _pendingSqlApproval.value?.onApprove?.invoke() + } + + /** + * Reject the pending SQL operation + */ + fun rejectSqlOperation() { + _pendingSqlApproval.value?.onReject?.invoke() + } + override fun updateTokenInfo(tokenInfo: TokenInfo) { _lastMessageTokenInfo = tokenInfo _totalTokenInfo.update { current -> @@ -784,3 +857,16 @@ class JewelRenderer : BaseRenderer() { } } +/** + * Data class representing a pending SQL approval request + */ +data class SqlApprovalRequest( + val sql: String, + val operationType: SqlOperationType, + val affectedTables: List, + val isHighRisk: Boolean, + val dryRunResult: DryRunResult? = null, + val onApprove: () -> Unit, + val onReject: () -> Unit +) + diff --git a/mpp-idea/src/main/kotlin/cc/unitmesh/devins/idea/toolwindow/IdeaAgentApp.kt b/mpp-idea/src/main/kotlin/cc/unitmesh/devins/idea/toolwindow/IdeaAgentApp.kt index 7d8517d2f0..062d0163db 100644 --- a/mpp-idea/src/main/kotlin/cc/unitmesh/devins/idea/toolwindow/IdeaAgentApp.kt +++ b/mpp-idea/src/main/kotlin/cc/unitmesh/devins/idea/toolwindow/IdeaAgentApp.kt @@ -352,6 +352,13 @@ fun IdeaAgentApp( } ?: IdeaEmptyStateMessage("Loading Knowledge Agent...") } } + AgentType.CHAT_DB -> { + // ChatDB mode - Text2SQL agent for database queries + // TODO: Implement ChatDB UI when ready + Box(modifier = Modifier.fillMaxWidth().weight(1f)) { + IdeaEmptyStateMessage("ChatDB Agent coming soon...") + } + } } // Tool loading status bar diff --git a/mpp-idea/src/main/kotlin/cc/unitmesh/devins/idea/toolwindow/IdeaAgentViewModel.kt b/mpp-idea/src/main/kotlin/cc/unitmesh/devins/idea/toolwindow/IdeaAgentViewModel.kt index 1b11f4149f..f4fd6e7107 100644 --- a/mpp-idea/src/main/kotlin/cc/unitmesh/devins/idea/toolwindow/IdeaAgentViewModel.kt +++ b/mpp-idea/src/main/kotlin/cc/unitmesh/devins/idea/toolwindow/IdeaAgentViewModel.kt @@ -243,14 +243,7 @@ class IdeaAgentViewModel( // Save to config file for persistence coroutineScope.launch { try { - val typeString = when (agentType) { - AgentType.REMOTE -> "Remote" - AgentType.LOCAL_CHAT -> "Local" - AgentType.CODING -> "Coding" - AgentType.CODE_REVIEW -> "CodeReview" - AgentType.KNOWLEDGE -> "Documents" - } - + val typeString = agentType.getDisplayName() AutoDevConfigWrapper.saveAgentTypePreference(typeString) } catch (e: Exception) { // Silently fail - not critical if we can't save preference diff --git a/mpp-idea/src/main/kotlin/cc/unitmesh/devins/idea/toolwindow/IdeaComposeIcons.kt b/mpp-idea/src/main/kotlin/cc/unitmesh/devins/idea/toolwindow/IdeaComposeIcons.kt index 730f95be57..d19ad56cd0 100644 --- a/mpp-idea/src/main/kotlin/cc/unitmesh/devins/idea/toolwindow/IdeaComposeIcons.kt +++ b/mpp-idea/src/main/kotlin/cc/unitmesh/devins/idea/toolwindow/IdeaComposeIcons.kt @@ -1738,5 +1738,60 @@ object IdeaComposeIcons { }.build() } + /** + * Database icon (cylinder shape representing database storage) + */ + val Database: ImageVector by lazy { + ImageVector.Builder( + name = "Database", + defaultWidth = 24.dp, + defaultHeight = 24.dp, + viewportWidth = 24f, + viewportHeight = 24f + ).apply { + path( + fill = SolidColor(Color.Black) + ) { + // Database cylinder icon (Storage icon) + moveTo(2f, 20f) + horizontalLineToRelative(20f) + verticalLineToRelative(-4f) + horizontalLineTo(2f) + verticalLineToRelative(4f) + close() + moveTo(4f, 17f) + horizontalLineToRelative(2f) + verticalLineToRelative(2f) + horizontalLineTo(4f) + verticalLineToRelative(-2f) + close() + moveTo(2f, 4f) + verticalLineToRelative(4f) + horizontalLineToRelative(20f) + verticalLineTo(4f) + horizontalLineTo(2f) + close() + moveTo(6f, 7f) + horizontalLineTo(4f) + verticalLineTo(5f) + horizontalLineToRelative(2f) + verticalLineToRelative(2f) + close() + moveTo(2f, 14f) + horizontalLineToRelative(20f) + verticalLineToRelative(-4f) + horizontalLineTo(2f) + verticalLineToRelative(4f) + close() + moveTo(4f, 11f) + horizontalLineToRelative(2f) + verticalLineToRelative(2f) + horizontalLineTo(4f) + verticalLineToRelative(-2f) + close() + } + }.build() + } + } diff --git a/mpp-idea/src/main/kotlin/cc/unitmesh/devins/idea/toolwindow/knowledge/IdeaKnowledgeContent.kt b/mpp-idea/src/main/kotlin/cc/unitmesh/devins/idea/toolwindow/knowledge/IdeaKnowledgeContent.kt index 8f5fa30dc7..f280a8cb1a 100644 --- a/mpp-idea/src/main/kotlin/cc/unitmesh/devins/idea/toolwindow/knowledge/IdeaKnowledgeContent.kt +++ b/mpp-idea/src/main/kotlin/cc/unitmesh/devins/idea/toolwindow/knowledge/IdeaKnowledgeContent.kt @@ -838,6 +838,65 @@ private fun ChatMessageItem(item: TimelineItem) { } } } + + is TimelineItem.ChatDBStepItem -> { + // ChatDB execution step + Box( + modifier = Modifier + .fillMaxWidth() + .background(JewelTheme.globalColors.panelBackground.copy(alpha = 0.5f)) + .padding(8.dp) + ) { + Row( + horizontalArrangement = Arrangement.spacedBy(6.dp), + verticalAlignment = Alignment.CenterVertically + ) { + Text( + text = item.stepType.icon, + style = JewelTheme.defaultTextStyle.copy(fontSize = 14.sp) + ) + Column { + Text( + text = item.stepType.displayName, + style = JewelTheme.defaultTextStyle.copy( + fontWeight = FontWeight.Bold, + fontSize = 12.sp + ) + ) + Text( + text = item.title, + style = JewelTheme.defaultTextStyle.copy(fontSize = 11.sp) + ) + } + } + } + } + + is TimelineItem.InfoItem -> { + // Info message + Box( + modifier = Modifier + .fillMaxWidth() + .background(IdeaAutoDevColors.Blue.c400.copy(alpha = 0.1f)) + .padding(8.dp) + ) { + Row( + horizontalArrangement = Arrangement.spacedBy(6.dp), + verticalAlignment = Alignment.CenterVertically + ) { + Icon( + imageVector = IdeaComposeIcons.Info, + contentDescription = "Info", + modifier = Modifier.size(16.dp), + tint = IdeaAutoDevColors.Blue.c400 + ) + Text( + text = item.message, + style = JewelTheme.defaultTextStyle.copy(fontSize = 12.sp) + ) + } + } + } } } diff --git a/mpp-server/src/main/kotlin/cc/unitmesh/server/render/ServerSideRenderer.kt b/mpp-server/src/main/kotlin/cc/unitmesh/server/render/ServerSideRenderer.kt index f4a4bf150c..07072df07e 100644 --- a/mpp-server/src/main/kotlin/cc/unitmesh/server/render/ServerSideRenderer.kt +++ b/mpp-server/src/main/kotlin/cc/unitmesh/server/render/ServerSideRenderer.kt @@ -67,6 +67,22 @@ class ServerSideRenderer : CodingAgentRenderer { eventChannel.trySend(AgentEvent.Error("User confirmation required for tool: $toolName")) } + override fun renderSqlApprovalRequest( + sql: String, + operationType: cc.unitmesh.agent.subagent.SqlOperationType, + affectedTables: List, + isHighRisk: Boolean, + dryRunResult: cc.unitmesh.agent.database.DryRunResult?, + onApprove: () -> Unit, + onReject: () -> Unit + ) { + // Server-side renderer auto-rejects for safety + // In a real implementation, this would send an event to the client for approval + val dryRunInfo = if (dryRunResult != null) " (dry run: ${if (dryRunResult.isValid) "passed" else "failed"})" else "" + eventChannel.trySend(AgentEvent.Error("SQL write operation requires approval: ${operationType.name} on ${affectedTables.joinToString(", ")}$dryRunInfo")) + onReject() + } + override fun renderAgentSketchBlock( agentName: String, language: String, diff --git a/mpp-ui/build.gradle.kts b/mpp-ui/build.gradle.kts index c311f0eee5..e44332fd82 100644 --- a/mpp-ui/build.gradle.kts +++ b/mpp-ui/build.gradle.kts @@ -624,6 +624,47 @@ tasks.register("runDomainDictCli") { standardInput = System.`in` } +// Task to run ChatDB CLI (Text2SQL Agent) +tasks.register("runChatDBCli") { + group = "application" + description = "Run ChatDB CLI (Text2SQL Agent with Schema Linking and Revise Agent)" + + val jvmCompilation = kotlin.jvm().compilations.getByName("main") + classpath(jvmCompilation.output, configurations["jvmRuntimeClasspath"]) + mainClass.set("cc.unitmesh.server.cli.ChatDBCli") + + // Pass database connection properties + if (project.hasProperty("dbHost")) { + systemProperty("dbHost", project.property("dbHost") as String) + } + if (project.hasProperty("dbPort")) { + systemProperty("dbPort", project.property("dbPort") as String) + } + if (project.hasProperty("dbName")) { + systemProperty("dbName", project.property("dbName") as String) + } + if (project.hasProperty("dbUser")) { + systemProperty("dbUser", project.property("dbUser") as String) + } + if (project.hasProperty("dbPassword")) { + systemProperty("dbPassword", project.property("dbPassword") as String) + } + if (project.hasProperty("dbDialect")) { + systemProperty("dbDialect", project.property("dbDialect") as String) + } + if (project.hasProperty("dbQuery")) { + systemProperty("dbQuery", project.property("dbQuery") as String) + } + if (project.hasProperty("generateVisualization")) { + systemProperty("generateVisualization", project.property("generateVisualization") as String) + } + if (project.hasProperty("maxRows")) { + systemProperty("maxRows", project.property("maxRows") as String) + } + + standardInput = System.`in` +} + // Ktlint configuration configure { version.set("1.0.1") diff --git a/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/db/DataSourceRepository.kt b/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/db/DataSourceRepository.kt new file mode 100644 index 0000000000..de1d57c955 --- /dev/null +++ b/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/db/DataSourceRepository.kt @@ -0,0 +1,53 @@ +package cc.unitmesh.devins.db + +import cc.unitmesh.devins.ui.compose.agent.chatdb.model.DataSourceConfig +import cc.unitmesh.devins.ui.compose.agent.chatdb.model.DatabaseDialect + +/** + * DataSource Repository - Data access layer for database connection configurations + * Uses expect/actual pattern for cross-platform support + */ +expect class DataSourceRepository { + /** + * Get all data source configurations + */ + fun getAll(): List + + /** + * Get a data source by ID + */ + fun getById(id: String): DataSourceConfig? + + /** + * Get the default data source + */ + fun getDefault(): DataSourceConfig? + + /** + * Save a data source configuration (insert or update) + */ + fun save(config: DataSourceConfig) + + /** + * Delete a data source by ID + */ + fun delete(id: String) + + /** + * Delete all data sources + */ + fun deleteAll() + + /** + * Set a data source as default + */ + fun setDefault(id: String) + + companion object { + /** + * Get singleton instance + */ + fun getInstance(): DataSourceRepository + } +} + diff --git a/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/app/SessionApp.kt b/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/app/SessionApp.kt deleted file mode 100644 index 683e75ee30..0000000000 --- a/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/app/SessionApp.kt +++ /dev/null @@ -1,330 +0,0 @@ -package cc.unitmesh.devins.ui.app - -import androidx.compose.foundation.layout.* -import androidx.compose.material3.* -import androidx.compose.runtime.* -import androidx.compose.ui.Modifier -import cc.unitmesh.devins.ui.compose.theme.AutoDevTheme -import cc.unitmesh.devins.ui.compose.theme.ThemeManager -import cc.unitmesh.devins.ui.project.ProjectListScreen -import cc.unitmesh.devins.ui.project.ProjectViewModel -import cc.unitmesh.devins.ui.remote.RemoteAgentClient -import cc.unitmesh.devins.ui.session.* -import cc.unitmesh.devins.ui.task.TaskExecutionScreen -import cc.unitmesh.devins.ui.task.TaskListScreen -import cc.unitmesh.devins.ui.task.TaskViewModel -import kotlinx.coroutines.launch - -/** - * SessionApp - ไผš่ฏ็ฎก็†ๅบ”็”จ - * ๆ”ฏๆŒ Android ๅ’Œ Desktop ๅนณๅฐ - * - * ๅŠŸ่ƒฝ๏ผš - * 1. ็”จๆˆท็™ปๅฝ•/ๆณจๅ†Œ - * 2. ้กน็›ฎ็ฎก็†๏ผˆๅˆ›ๅปบใ€ๅˆ—่กจใ€Git ๆ”ฏๆŒ๏ผ‰ - * 3. ไปปๅŠก็ฎก็†๏ผˆๅŸบไบŽ้กน็›ฎๅˆ›ๅปบไปปๅŠก๏ผ‰ - * 4. AI Agent ๆ‰ง่กŒ๏ผˆ้›†ๆˆ RemoteAgentClient๏ผ‰ - */ -@Composable -fun SessionApp( - serverUrl: String = "http://localhost:8080", - useBottomNavigation: Boolean = false // Android ไฝฟ็”จๅบ•้ƒจๅฏผ่ˆช๏ผŒDesktop ไฝฟ็”จไพง่พนๅฏผ่ˆช -) { - val currentTheme = ThemeManager.currentTheme - - AutoDevTheme(themeMode = currentTheme) { - SessionAppContent( - serverUrl = serverUrl, - useBottomNavigation = useBottomNavigation - ) - } -} - -@Composable -private fun SessionAppContent( - serverUrl: String, - useBottomNavigation: Boolean -) { - // ๅˆๅง‹ๅŒ–ๅฎขๆˆท็ซฏๅ’Œ ViewModel - val sessionClient = remember { SessionClient(serverUrl) } - val remoteAgentClient = remember { RemoteAgentClient(serverUrl) } - - val projectClient = remember { - cc.unitmesh.devins.ui.project.ProjectClient( - serverUrl, - sessionClient.httpClient - ) - } - - val sessionViewModel = remember { SessionViewModel(sessionClient) } - val projectViewModel = remember { ProjectViewModel(projectClient) } - val taskViewModel = remember { TaskViewModel(sessionClient, remoteAgentClient) } - - val isAuthenticated by sessionViewModel.isAuthenticated.collectAsState() - val currentProject by projectViewModel.currentProject.collectAsState() - val currentTask by taskViewModel.currentTask.collectAsState() - - // ็›‘ๅฌ่ฎค่ฏ็Šถๆ€๏ผŒๅŒๆญฅ token - LaunchedEffect(isAuthenticated, sessionClient.authToken) { - if (isAuthenticated && sessionClient.authToken != null) { - projectClient.setAuthToken(sessionClient.authToken!!) - } - } - - // ๅฑๅน•็Šถๆ€็ฎก็† - var currentScreen by remember { mutableStateOf(AppScreen.LOGIN) } - var skipLogin by remember { mutableStateOf(false) } - - // ็›‘ๅฌ่ฎค่ฏ็Šถๆ€ - LaunchedEffect(isAuthenticated) { - if (isAuthenticated && currentScreen == AppScreen.LOGIN) { - currentScreen = AppScreen.PROJECTS - } else if (!isAuthenticated && !skipLogin) { - currentScreen = AppScreen.LOGIN - } - } - - // ็›‘ๅฌ้กน็›ฎ้€‰ๆ‹ฉ - LaunchedEffect(currentProject) { - currentProject?.let { - taskViewModel.setCurrentProject(it) - } - } - - DisposableEffect(Unit) { - onDispose { - sessionViewModel.onCleared() - projectViewModel.onCleared() - taskViewModel.onCleared() - } - } - - when { - !isAuthenticated && !skipLogin -> { - LoginScreen( - viewModel = sessionViewModel, - onLoginSuccess = { - currentScreen = AppScreen.PROJECTS - }, - onSkipLogin = if (useBottomNavigation) { - // Android ๅนณๅฐๅ…่ฎธ่ทณ่ฟ‡็™ปๅฝ• - { - skipLogin = true - currentScreen = AppScreen.PROJECTS - } - } else null - ) - } - currentTask != null -> { - // ไปปๅŠกๆ‰ง่กŒ็•Œ้ข๏ผˆๅ…จๅฑ๏ผ‰ - TaskExecutionScreen( - viewModel = taskViewModel, - onBack = { - currentScreen = AppScreen.TASKS - } - ) - } - useBottomNavigation -> { - // Android ๅบ•้ƒจๅฏผ่ˆช - AndroidNavigationLayout( - currentScreen = currentScreen, - onScreenChange = { currentScreen = it }, - sessionViewModel = sessionViewModel, - projectViewModel = projectViewModel, - taskViewModel = taskViewModel - ) - } - else -> { - // Desktop ไพง่พนๅฏผ่ˆช - DesktopNavigationLayout( - currentScreen = currentScreen, - onScreenChange = { currentScreen = it }, - sessionViewModel = sessionViewModel, - projectViewModel = projectViewModel, - taskViewModel = taskViewModel - ) - } - } -} - -/** - * Android ๅบ•้ƒจๅฏผ่ˆชๅธƒๅฑ€ - * ไฝฟ็”จๆ–ฐ็š„ NavLayout ็ป„ไปถ๏ผŒๆ”ฏๆŒ Drawer + BottomNavigation - */ -@Composable -private fun AndroidNavigationLayout( - currentScreen: AppScreen, - onScreenChange: (AppScreen) -> Unit, - sessionViewModel: SessionViewModel, - projectViewModel: ProjectViewModel, - taskViewModel: TaskViewModel -) { - val scope = rememberCoroutineScope() - - MobileNavLayout( - currentScreen = currentScreen, - onScreenChange = onScreenChange, - sessionViewModel = sessionViewModel - ) { paddingValues -> - Box(modifier = Modifier.padding(paddingValues)) { - when (currentScreen) { - AppScreen.PROJECTS -> { - ProjectListScreen( - viewModel = projectViewModel, - onProjectClick = { project -> - projectViewModel.selectProject(project) - onScreenChange(AppScreen.TASKS) - } - ) - } - AppScreen.TASKS -> { - TaskListScreen( - viewModel = taskViewModel, - onTaskClick = { task -> - // ไปปๅŠก็‚นๅ‡ปไผš่‡ชๅŠจๅˆ‡ๆขๅˆฐ TaskExecutionScreen - } - ) - } - AppScreen.SESSIONS -> { - SessionListScreen( - viewModel = sessionViewModel, - onSessionClick = { session -> - scope.launch { - sessionViewModel.joinSession(session.id) - } - }, - onCreateSession = { - // ๅˆ›ๅปบไผš่ฏ - }, - onLogout = { - scope.launch { - sessionViewModel.logout() - } - } - ) - } - AppScreen.PROFILE -> { - ProfileScreen( - viewModel = sessionViewModel, - onLogout = { - scope.launch { - sessionViewModel.logout() - } - } - ) - } - else -> {} - } - } - } -} - -/** - * Desktop ไพง่พนๅฏผ่ˆชๅธƒๅฑ€ - * ไฝฟ็”จๆ–ฐ็š„ NavLayout ็ป„ไปถ - */ -@OptIn(ExperimentalMaterial3Api::class) -@Composable -private fun DesktopNavigationLayout( - currentScreen: AppScreen, - onScreenChange: (AppScreen) -> Unit, - sessionViewModel: SessionViewModel, - projectViewModel: ProjectViewModel, - taskViewModel: TaskViewModel -) { - val scope = rememberCoroutineScope() - - DesktopNavLayout( - currentScreen = currentScreen, - onScreenChange = onScreenChange, - sessionViewModel = sessionViewModel - ) { - when (currentScreen) { - AppScreen.PROJECTS -> { - ProjectListScreen( - viewModel = projectViewModel, - onProjectClick = { project -> - projectViewModel.selectProject(project) - onScreenChange(AppScreen.TASKS) - } - ) - } - AppScreen.TASKS -> { - TaskListScreen( - viewModel = taskViewModel, - onTaskClick = { task -> - // ไปปๅŠก็‚นๅ‡ปไผš่‡ชๅŠจๅˆ‡ๆขๅˆฐ TaskExecutionScreen - } - ) - } - AppScreen.SESSIONS -> { - SessionListScreen( - viewModel = sessionViewModel, - onSessionClick = { session -> - scope.launch { - sessionViewModel.joinSession(session.id) - } - }, - onCreateSession = { - // ๅˆ›ๅปบไผš่ฏ - }, - onLogout = { - scope.launch { - sessionViewModel.logout() - } - } - ) - } - AppScreen.PROFILE -> { - ProfileScreen( - viewModel = sessionViewModel, - onLogout = { - scope.launch { - sessionViewModel.logout() - } - } - ) - } - else -> {} - } - } -} - -/** - * SessionAppContext - Session ๅบ”็”จไธŠไธ‹ๆ–‡ - * - * ็”จไบŽ็ฎก็† Session ๅบ”็”จ็š„ๅ…จๅฑ€็Šถๆ€ๅ’Œ้…็ฝฎ - * ๅ‚่€ƒ SessionAppContext ็š„่ฎพ่ฎก๏ผŒๆไพ›็ปŸไธ€็š„ไธŠไธ‹ๆ–‡็ฎก็† - */ -data class SessionAppContext( - val serverUrl: String, - val useBottomNavigation: Boolean, - val sessionViewModel: SessionViewModel, - val projectViewModel: ProjectViewModel, - val taskViewModel: TaskViewModel, - val currentScreen: AppScreen = AppScreen.LOGIN, - val skipLogin: Boolean = false, - val isAuthenticated: Boolean = false -) { - companion object { - /** - * ๅˆ›ๅปบ้ป˜่ฎค็š„ SessionAppContext - */ - fun create( - serverUrl: String, - useBottomNavigation: Boolean, - sessionClient: SessionClient, - projectClient: cc.unitmesh.devins.ui.project.ProjectClient, - remoteAgentClient: cc.unitmesh.devins.ui.remote.RemoteAgentClient - ): SessionAppContext { - return SessionAppContext( - serverUrl = serverUrl, - useBottomNavigation = useBottomNavigation, - sessionViewModel = SessionViewModel(sessionClient), - projectViewModel = ProjectViewModel(projectClient), - taskViewModel = TaskViewModel(sessionClient, remoteAgentClient) - ) - } - } -} - diff --git a/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/AutoDevApp.kt b/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/AutoDevApp.kt index 420dea03f7..2bb4783a31 100644 --- a/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/AutoDevApp.kt +++ b/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/AutoDevApp.kt @@ -172,11 +172,7 @@ private fun AutoDevContent( scope.launch { try { // Save as string for config compatibility - val typeString = when (type) { - AgentType.REMOTE -> "Remote" - AgentType.LOCAL_CHAT -> "Local" - else -> "Local" - } + val typeString = type.getDisplayName() AutoDevConfigWrapper.saveAgentTypePreference(typeString) } catch (e: Exception) { println("โš ๏ธ ไฟๅญ˜ Agent ็ฑปๅž‹ๅคฑ่ดฅ: ${e.message}") @@ -449,37 +445,47 @@ private fun AutoDevContent( containerColor = MaterialTheme.colorScheme.background, contentWindowInsets = WindowInsets.systemBars.only(WindowInsetsSides.Horizontal) ) { paddingValues -> + // Determine if SessionSidebar should be shown based on agent type + // Hide for pages that have their own navigation (CHAT_DB, KNOWLEDGE, CODE_REVIEW) + val shouldShowSessionSidebar = selectedAgentType in listOf( + AgentType.CODING, + AgentType.LOCAL_CHAT, + AgentType.REMOTE + ) + Row( modifier = Modifier .fillMaxSize() .padding(paddingValues) ) { - SessionSidebar( - chatHistoryManager = chatHistoryManager, - currentSessionId = chatHistoryManager.getCurrentSession().id, - isExpanded = showSessionSidebar, - onSessionSelected = { sessionId -> - if (agentSessionSelectedHandler != null) { - agentSessionSelectedHandler?.invoke(sessionId) - } else { - chatHistoryManager.switchSession(sessionId) - messages = chatHistoryManager.getMessages() - currentStreamingOutput = "" - } - }, - onNewChat = { - if (agentNewChatHandler != null) { - agentNewChatHandler?.invoke() - } else { - chatHistoryManager.createSession() - messages = emptyList() - currentStreamingOutput = "" + if (shouldShowSessionSidebar) { + SessionSidebar( + chatHistoryManager = chatHistoryManager, + currentSessionId = chatHistoryManager.getCurrentSession().id, + isExpanded = showSessionSidebar, + onSessionSelected = { sessionId -> + if (agentSessionSelectedHandler != null) { + agentSessionSelectedHandler?.invoke(sessionId) + } else { + chatHistoryManager.switchSession(sessionId) + messages = chatHistoryManager.getMessages() + currentStreamingOutput = "" + } + }, + onNewChat = { + if (agentNewChatHandler != null) { + agentNewChatHandler?.invoke() + } else { + chatHistoryManager.createSession() + messages = emptyList() + currentStreamingOutput = "" + } + }, + onRenameSession = { sessionId, newTitle -> + chatHistoryManager.renameSession(sessionId, newTitle) } - }, - onRenameSession = { sessionId, newTitle -> - chatHistoryManager.renameSession(sessionId, newTitle) - } - ) + ) + } Column( modifier = Modifier diff --git a/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/agent/AgentChatInterface.kt b/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/agent/AgentChatInterface.kt index 88d2b2d900..23097e3b23 100644 --- a/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/agent/AgentChatInterface.kt +++ b/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/agent/AgentChatInterface.kt @@ -198,8 +198,9 @@ fun AgentChatInterface( } AgentType.CODE_REVIEW, - AgentType.KNOWLEDGE -> { - // CODE_REVIEW and DOCUMENT_READER have their own full-page interfaces + AgentType.KNOWLEDGE, + AgentType.CHAT_DB -> { + // CODE_REVIEW, DOCUMENT_READER and CHAT_DB have their own full-page interfaces // They should not reach here - handled by AgentInterfaceRouter } diff --git a/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/agent/AgentInterfaceRouter.kt b/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/agent/AgentInterfaceRouter.kt index 97f13119c9..cdd5750306 100644 --- a/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/agent/AgentInterfaceRouter.kt +++ b/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/agent/AgentInterfaceRouter.kt @@ -1,10 +1,13 @@ package cc.unitmesh.devins.ui.compose.agent +import androidx.compose.foundation.layout.fillMaxSize import androidx.compose.runtime.Composable import androidx.compose.ui.Modifier import cc.unitmesh.agent.AgentType +import cc.unitmesh.devins.ui.compose.agent.chatdb.ChatDBPage import cc.unitmesh.devins.ui.compose.agent.codereview.CodeReviewPage import cc.unitmesh.devins.ui.remote.RemoteAgentChatInterface +import cc.unitmesh.devins.workspace.Workspace import cc.unitmesh.llm.KoogLLMService /** @@ -53,9 +56,22 @@ fun AgentInterfaceRouter( onProjectChange: (String) -> Unit = {}, onGitUrlChange: (String) -> Unit = {}, onNotification: (String, String) -> Unit = { _, _ -> }, + workspace: Workspace? = null, modifier: Modifier = Modifier ) { when (selectedAgentType) { + AgentType.CHAT_DB -> { + ChatDBPage( + workspace = workspace, + llmService = llmService, + modifier = modifier, + onBack = { + onAgentTypeChange(AgentType.CODING) + }, + onNotification = onNotification + ) + } + AgentType.KNOWLEDGE -> { cc.unitmesh.devins.ui.compose.document.DocumentReaderPage( modifier = modifier, @@ -113,6 +129,7 @@ fun AgentInterfaceRouter( modifier = modifier ) } + AgentType.LOCAL_CHAT, AgentType.CODING -> { AgentChatInterface( diff --git a/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/agent/AgentMessageList.kt b/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/agent/AgentMessageList.kt index b157323522..a28503bfe7 100644 --- a/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/agent/AgentMessageList.kt +++ b/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/agent/AgentMessageList.kt @@ -236,6 +236,51 @@ fun RenderMessageItem( metadata = timelineItem.metadata ) } + + is TimelineItem.ChatDBStepItem -> { + cc.unitmesh.devins.ui.compose.agent.chatdb.components.ChatDBStepCard( + step = timelineItem, + onApprove = { renderer.approveSqlOperation() }, + onReject = { renderer.rejectSqlOperation() } + ) + } + + is TimelineItem.InfoItem -> { + InfoMessageItem(message = timelineItem.message) + } + } +} + +/** + * Simple info message item for displaying informational messages + */ +@Composable +private fun InfoMessageItem(message: String) { + Card( + modifier = Modifier + .fillMaxWidth() + .padding(horizontal = 16.dp, vertical = 4.dp), + colors = CardDefaults.cardColors( + containerColor = MaterialTheme.colorScheme.surfaceVariant.copy(alpha = 0.5f) + ) + ) { + Row( + modifier = Modifier.padding(12.dp), + horizontalArrangement = Arrangement.spacedBy(8.dp), + verticalAlignment = Alignment.CenterVertically + ) { + Icon( + AutoDevComposeIcons.Info, + contentDescription = null, + modifier = Modifier.size(16.dp), + tint = MaterialTheme.colorScheme.primary + ) + Text( + text = message, + style = MaterialTheme.typography.bodySmall, + color = MaterialTheme.colorScheme.onSurfaceVariant + ) + } } } diff --git a/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/agent/ComposeRenderer.kt b/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/agent/ComposeRenderer.kt index ea298f844b..1ebae22c18 100644 --- a/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/agent/ComposeRenderer.kt +++ b/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/agent/ComposeRenderer.kt @@ -1,15 +1,19 @@ package cc.unitmesh.devins.ui.compose.agent import androidx.compose.runtime.* +import cc.unitmesh.agent.database.DryRunResult import cc.unitmesh.agent.plan.AgentPlan import cc.unitmesh.agent.plan.MarkdownPlanParser import cc.unitmesh.agent.render.BaseRenderer +import cc.unitmesh.agent.render.ChatDBStepStatus +import cc.unitmesh.agent.render.ChatDBStepType import cc.unitmesh.agent.render.RendererUtils import cc.unitmesh.agent.render.TaskInfo import cc.unitmesh.agent.render.TaskStatus import cc.unitmesh.agent.render.TimelineItem import cc.unitmesh.agent.render.TimelineItem.* import cc.unitmesh.agent.render.ToolCallInfo +import cc.unitmesh.agent.subagent.SqlOperationType import cc.unitmesh.agent.tool.ToolType import cc.unitmesh.agent.tool.impl.docql.DocQLSearchStats import cc.unitmesh.agent.tool.toToolType @@ -79,6 +83,10 @@ class ComposeRenderer : BaseRenderer() { private var _currentPlan by mutableStateOf(null) val currentPlan: AgentPlan? get() = _currentPlan + // SQL approval state + private var _pendingSqlApproval by mutableStateOf(null) + val pendingSqlApproval: SqlApprovalRequest? get() = _pendingSqlApproval + // BaseRenderer implementation override fun renderIterationHeader( @@ -525,6 +533,10 @@ class ComposeRenderer : BaseRenderer() { _isProcessing = false } + override fun renderInfo(message: String) { + _timeline.add(TimelineItem.InfoItem(message = message)) + } + override fun renderRepeatWarning( toolName: String, count: Int @@ -559,6 +571,71 @@ class ComposeRenderer : BaseRenderer() { // For now, just use error rendering since JS renderer doesn't have this method yet } + override fun renderSqlApprovalRequest( + sql: String, + operationType: SqlOperationType, + affectedTables: List, + isHighRisk: Boolean, + dryRunResult: DryRunResult?, + onApprove: () -> Unit, + onReject: () -> Unit + ) { + _pendingSqlApproval = SqlApprovalRequest( + sql = sql, + operationType = operationType, + affectedTables = affectedTables, + isHighRisk = isHighRisk, + dryRunResult = dryRunResult, + onApprove = { + _pendingSqlApproval = null + onApprove() + }, + onReject = { + _pendingSqlApproval = null + onReject() + } + ) + + // Build details map with dry run info + val details = mutableMapOf( + "sql" to sql, + "operationType" to operationType.name, + "affectedTables" to affectedTables.joinToString(", "), + "isHighRisk" to isHighRisk + ) + if (dryRunResult != null) { + details["dryRunValid"] = dryRunResult.isValid + if (dryRunResult.estimatedRows != null) { + details["estimatedRows"] = dryRunResult.estimatedRows!! + } + if (dryRunResult.warnings.isNotEmpty()) { + details["warnings"] = dryRunResult.warnings.joinToString(", ") + } + } + + // Also add to timeline for visibility + renderChatDBStep( + stepType = ChatDBStepType.AWAIT_APPROVAL, + status = ChatDBStepStatus.AWAITING_APPROVAL, + title = "Awaiting Approval: ${operationType.name}", + details = details + ) + } + + /** + * Approve the pending SQL operation + */ + fun approveSqlOperation() { + _pendingSqlApproval?.onApprove?.invoke() + } + + /** + * Reject the pending SQL operation + */ + fun rejectSqlOperation() { + _pendingSqlApproval?.onReject?.invoke() + } + // Public methods for UI interaction fun addUserMessage(content: String) { _timeline.add( @@ -782,6 +859,39 @@ class ComposeRenderer : BaseRenderer() { ) } + /** + * Render a ChatDB execution step. + * Adds or updates a step in the timeline for interactive display. + */ + override fun renderChatDBStep( + stepType: ChatDBStepType, + status: ChatDBStepStatus, + title: String, + details: Map, + error: String? + ) { + // Check if this step already exists in the timeline + val existingIndex = _timeline.indexOfLast { + it is ChatDBStepItem && it.stepType == stepType + } + + val stepItem = ChatDBStepItem( + stepType = stepType, + status = status, + title = title, + details = details, + error = error + ) + + if (existingIndex >= 0) { + // Update existing step + _timeline[existingIndex] = stepItem + } else { + // Add new step + _timeline.add(stepItem) + } + } + /** * Render an Agent-generated sketch block (chart, nanodsl, mermaid, etc.) * Adds the sketch block to the timeline for interactive rendering. @@ -884,6 +994,16 @@ class ComposeRenderer : BaseRenderer() { sketchCode = item.code ) } + + is ChatDBStepItem -> { + // ChatDB steps are not persisted (they're runtime-only for UI display) + null + } + + is TimelineItem.InfoItem -> { + // Info items are not persisted (they're runtime-only for UI display) + null + } } } @@ -1109,8 +1229,31 @@ class ComposeRenderer : BaseRenderer() { metadata = toMessageMetadata(item) ) } + + is ChatDBStepItem -> { + // ChatDB steps are not persisted as messages + null + } + + is TimelineItem.InfoItem -> { + // Info items are not persisted as messages + null + } } } } } +/** + * Data class representing a pending SQL approval request + */ +data class SqlApprovalRequest( + val sql: String, + val operationType: SqlOperationType, + val affectedTables: List, + val isHighRisk: Boolean, + val dryRunResult: DryRunResult? = null, + val onApprove: () -> Unit, + val onReject: () -> Unit +) + diff --git a/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/agent/chatdb/ChatDBPage.kt b/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/agent/chatdb/ChatDBPage.kt new file mode 100644 index 0000000000..aa30e71ef6 --- /dev/null +++ b/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/agent/chatdb/ChatDBPage.kt @@ -0,0 +1,107 @@ +package cc.unitmesh.devins.ui.compose.agent.chatdb + +import androidx.compose.foundation.layout.* +import androidx.compose.material3.* +import androidx.compose.runtime.* +import androidx.compose.ui.Modifier +import cc.unitmesh.devins.ui.base.ResizableSplitPane +import cc.unitmesh.devins.ui.compose.agent.chatdb.components.* +import cc.unitmesh.devins.workspace.Workspace +import cc.unitmesh.llm.KoogLLMService +import kotlinx.coroutines.flow.collectLatest + +/** + * ChatDB Page - Main page for text-to-SQL agent + * + * Left side: Data source management panel (resizable) + * Right side: Config pane (when adding/editing) or Chat area for natural language to SQL queries + */ +@OptIn(ExperimentalMaterial3Api::class) +@Composable +fun ChatDBPage( + workspace: Workspace? = null, + llmService: KoogLLMService?, + modifier: Modifier = Modifier, + onBack: () -> Unit, + onNotification: (String, String) -> Unit = { _, _ -> } +) { + val viewModel = remember { ChatDBViewModel(workspace) } + val state = viewModel.state + + // Collect notifications + LaunchedEffect(viewModel) { + viewModel.notificationEvent.collectLatest { (title, message) -> + onNotification(title, message) + } + } + + // Cleanup on dispose + DisposableEffect(viewModel) { + onDispose { + viewModel.dispose() + } + } + + Scaffold(modifier = modifier) { paddingValues -> + ResizableSplitPane( + modifier = Modifier + .fillMaxSize() + .padding(paddingValues), + initialSplitRatio = 0.22f, + minRatio = 0.15f, + maxRatio = 0.4f, + saveKey = "chatdb_split_ratio", + first = { + // Left panel - Data source management (multi-selection mode) + DataSourcePanel( + dataSources = state.filteredDataSources, + selectedDataSourceIds = state.selectedDataSourceIds, + connectionStatuses = state.connectionStatuses, + filterQuery = state.filterQuery, + onFilterChange = viewModel::setFilterQuery, + onToggleDataSource = { id -> + viewModel.toggleDataSource(id) + // When toggling a data source, optionally show its config in the pane + val selected = state.dataSources.find { it.id == id } + if (selected != null && state.isConfigPaneOpen) { + viewModel.openConfigPane(selected) + } + }, + onAddClick = { viewModel.openConfigPane(null) }, + onEditClick = { config -> viewModel.openConfigPane(config) }, + onDeleteClick = viewModel::deleteDataSource, + onConnectClick = viewModel::connectDataSource, + onDisconnectClick = viewModel::disconnectDataSource, + onConnectAllClick = viewModel::connectAll, + onDisconnectAllClick = viewModel::disconnectAll, + modifier = Modifier.fillMaxSize() + ) + }, + second = { + // Right panel - Config pane or Chat area + if (state.isConfigPaneOpen) { + DataSourceConfigPane( + existingConfig = state.configuringDataSource, + onCancel = viewModel::closeConfigPane, + onSave = viewModel::saveFromPane, + onSaveAndConnect = viewModel::saveAndConnectFromPane, + modifier = Modifier.fillMaxSize() + ) + } else { + ChatDBChatPane( + renderer = viewModel.renderer, + connectionStatus = state.connectionStatus, + connectedCount = state.connectedCount, + selectedCount = state.selectedCount, + schema = viewModel.getSchema(), + isGenerating = viewModel.isGenerating, + onSendMessage = viewModel::sendMessage, + onStopGeneration = viewModel::stopGeneration, + onNewSession = viewModel::newSession, + modifier = Modifier.fillMaxSize() + ) + } + } + ) + } +} diff --git a/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/agent/chatdb/ChatDBViewModel.kt b/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/agent/chatdb/ChatDBViewModel.kt new file mode 100644 index 0000000000..248303275c --- /dev/null +++ b/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/agent/chatdb/ChatDBViewModel.kt @@ -0,0 +1,513 @@ +package cc.unitmesh.devins.ui.compose.agent.chatdb + +import androidx.compose.runtime.getValue +import androidx.compose.runtime.mutableStateOf +import androidx.compose.runtime.setValue +import cc.unitmesh.agent.chatdb.ChatDBAgent +import cc.unitmesh.agent.chatdb.ChatDBTask +import cc.unitmesh.agent.chatdb.MultiDatabaseChatDBAgent +import cc.unitmesh.agent.config.McpToolConfigService +import cc.unitmesh.agent.config.ToolConfigFile +import cc.unitmesh.agent.database.DatabaseConfig +import cc.unitmesh.agent.database.DatabaseConnection +import cc.unitmesh.agent.database.DatabaseSchema +import cc.unitmesh.agent.database.createDatabaseConnection +import cc.unitmesh.config.ConfigManager +import cc.unitmesh.devins.db.DataSourceRepository +import cc.unitmesh.devins.ui.compose.agent.ComposeRenderer +import cc.unitmesh.devins.ui.compose.agent.chatdb.model.* +import cc.unitmesh.devins.workspace.Workspace +import cc.unitmesh.llm.KoogLLMService +import kotlinx.coroutines.* +import kotlinx.coroutines.flow.MutableSharedFlow +import kotlinx.coroutines.flow.asSharedFlow +import kotlin.uuid.ExperimentalUuidApi +import kotlin.uuid.Uuid + +/** + * ViewModel for ChatDB Page + * + * Manages data sources, database connections, and chat interactions for text-to-SQL. + */ +class ChatDBViewModel( + private val workspace: Workspace? = null +) { + private val scope = CoroutineScope(SupervisorJob() + Dispatchers.Default) + + val renderer = ComposeRenderer() + + // LLM Service + private var llmService: KoogLLMService? = null + private var currentExecutionJob: Job? = null + + // Database connections (multi-datasource support) + private val connections: MutableMap = mutableMapOf() + private val schemas: MutableMap = mutableMapOf() + + // Data source repository for persistence + private val dataSourceRepository: DataSourceRepository by lazy { + DataSourceRepository.getInstance() + } + + // UI State + var state by mutableStateOf(ChatDBState()) + private set + + var isGenerating by mutableStateOf(false) + private set + + // Notifications + private val _notificationEvent = MutableSharedFlow>() + val notificationEvent = _notificationEvent.asSharedFlow() + + init { + initializeLLMService() + loadDataSources() + } + + private fun initializeLLMService() { + scope.launch { + try { + val configWrapper = ConfigManager.load() + val modelConfig = configWrapper.getActiveModelConfig() + if (modelConfig != null && modelConfig.isValid()) { + llmService = KoogLLMService.create(modelConfig) + } + } catch (e: Exception) { + println("[ChatDB] Failed to initialize LLM service: ${e.message}") + } + } + } + + private fun loadDataSources() { + scope.launch { + try { + val dataSources = dataSourceRepository.getAll() + // Multi-datasource: select all data sources by default + val allIds = dataSources.map { it.id }.toSet() + state = state.copy( + dataSources = dataSources, + selectedDataSourceIds = allIds + ) + println("[ChatDB] Loaded ${dataSources.size} data sources, all selected by default") + } catch (e: Exception) { + println("[ChatDB] Failed to load data sources: ${e.message}") + state = state.copy(dataSources = emptyList()) + } + } + } + + @OptIn(ExperimentalUuidApi::class) + fun addDataSource(config: DataSourceConfig) { + val newConfig = config.copy( + id = if (config.id.isBlank()) Uuid.random().toString() else config.id, + createdAt = kotlinx.datetime.Clock.System.now().toEpochMilliseconds(), + updatedAt = kotlinx.datetime.Clock.System.now().toEpochMilliseconds() + ) + state = state.copy( + dataSources = state.dataSources + newConfig, + isConfigDialogOpen = false, + editingDataSource = null + ) + saveDataSource(newConfig) + } + + fun updateDataSource(config: DataSourceConfig) { + val updated = config.copy( + updatedAt = kotlinx.datetime.Clock.System.now().toEpochMilliseconds() + ) + state = state.copy( + dataSources = state.dataSources.map { if (it.id == config.id) updated else it }, + isConfigDialogOpen = false, + editingDataSource = null + ) + saveDataSource(updated) + } + + fun deleteDataSource(id: String) { + // Disconnect this specific data source if connected + disconnectDataSource(id) + + state = state.copy( + dataSources = state.dataSources.filter { it.id != id }, + selectedDataSourceIds = state.selectedDataSourceIds - id, + connectionStatuses = state.connectionStatuses - id + ) + deleteDataSourceFromRepository(id) + } + + private fun saveDataSources() { + scope.launch { + try { + // Save all data sources to repository + state.dataSources.forEach { config -> + dataSourceRepository.save(config) + } + println("[ChatDB] Saved ${state.dataSources.size} data sources") + } catch (e: Exception) { + println("[ChatDB] Failed to save data sources: ${e.message}") + } + } + } + + private fun saveDataSource(config: DataSourceConfig) { + scope.launch { + try { + dataSourceRepository.save(config) + println("[ChatDB] Saved data source: ${config.id}") + } catch (e: Exception) { + println("[ChatDB] Failed to save data source: ${e.message}") + } + } + } + + private fun deleteDataSourceFromRepository(id: String) { + scope.launch { + try { + dataSourceRepository.delete(id) + println("[ChatDB] Deleted data source: $id") + } catch (e: Exception) { + println("[ChatDB] Failed to delete data source: ${e.message}") + } + } + } + + /** + * Toggle selection of a data source (multi-selection mode) + */ + fun toggleDataSource(id: String) { + val newSelectedIds = if (id in state.selectedDataSourceIds) { + // Deselect: disconnect if connected + disconnectDataSource(id) + state.selectedDataSourceIds - id + } else { + // Select: add to selection + state.selectedDataSourceIds + id + } + state = state.copy(selectedDataSourceIds = newSelectedIds) + } + + /** + * Select a data source (for backward compatibility, also used for single-click selection) + */ + fun selectDataSource(id: String) { + if (id !in state.selectedDataSourceIds) { + state = state.copy(selectedDataSourceIds = state.selectedDataSourceIds + id) + } + } + + /** + * Select all data sources + */ + fun selectAllDataSources() { + val allIds = state.dataSources.map { it.id }.toSet() + state = state.copy(selectedDataSourceIds = allIds) + } + + /** + * Deselect all data sources + */ + fun deselectAllDataSources() { + disconnectAll() + state = state.copy(selectedDataSourceIds = emptySet()) + } + + fun setFilterQuery(query: String) { + state = state.copy(filterQuery = query) + } + + fun openAddDialog() { + state = state.copy(isConfigDialogOpen = true, editingDataSource = null) + } + + fun openEditDialog(config: DataSourceConfig) { + state = state.copy(isConfigDialogOpen = true, editingDataSource = config) + } + + fun closeConfigDialog() { + state = state.copy(isConfigDialogOpen = false, editingDataSource = null) + } + + // --- Config Pane methods (inline panel mode) --- + + fun openConfigPane(config: DataSourceConfig? = null) { + state = state.copy(isConfigPaneOpen = true, configuringDataSource = config) + } + + fun closeConfigPane() { + state = state.copy(isConfigPaneOpen = false, configuringDataSource = null) + } + + @OptIn(ExperimentalUuidApi::class) + fun saveFromPane(config: DataSourceConfig) { + val isNew = config.id.isBlank() + val savedConfig = config.copy( + id = if (isNew) Uuid.random().toString() else config.id, + createdAt = if (isNew) kotlinx.datetime.Clock.System.now().toEpochMilliseconds() else state.configuringDataSource?.createdAt ?: 0L, + updatedAt = kotlinx.datetime.Clock.System.now().toEpochMilliseconds() + ) + + state = if (isNew) { + state.copy( + dataSources = state.dataSources + savedConfig, + isConfigPaneOpen = false, + configuringDataSource = null, + selectedDataSourceIds = state.selectedDataSourceIds + savedConfig.id + ) + } else { + state.copy( + dataSources = state.dataSources.map { if (it.id == savedConfig.id) savedConfig else it }, + isConfigPaneOpen = false, + configuringDataSource = null + ) + } + saveDataSource(savedConfig) + } + + @OptIn(ExperimentalUuidApi::class) + fun saveAndConnectFromPane(config: DataSourceConfig) { + val isNew = config.id.isBlank() + val savedConfig = config.copy( + id = if (isNew) Uuid.random().toString() else config.id, + createdAt = if (isNew) kotlinx.datetime.Clock.System.now().toEpochMilliseconds() else state.configuringDataSource?.createdAt ?: 0L, + updatedAt = kotlinx.datetime.Clock.System.now().toEpochMilliseconds() + ) + + state = if (isNew) { + state.copy( + dataSources = state.dataSources + savedConfig, + isConfigPaneOpen = false, + configuringDataSource = null, + selectedDataSourceIds = state.selectedDataSourceIds + savedConfig.id + ) + } else { + state.copy( + dataSources = state.dataSources.map { if (it.id == savedConfig.id) savedConfig else it }, + isConfigPaneOpen = false, + configuringDataSource = null, + selectedDataSourceIds = state.selectedDataSourceIds + savedConfig.id + ) + } + saveDataSource(savedConfig) + + // Trigger connection for this specific data source + connectDataSource(savedConfig.id) + } + + /** + * Connect all selected data sources + */ + fun connectAll() { + state.selectedDataSources.forEach { dataSource -> + connectDataSource(dataSource.id) + } + } + + /** + * Connect a specific data source + */ + fun connectDataSource(id: String) { + val dataSource = state.dataSources.find { it.id == id } ?: return + + scope.launch { + // Update status to connecting + state = state.copy( + connectionStatuses = state.connectionStatuses + (id to ConnectionStatus.Connecting), + connectionStatus = ConnectionStatus.Connecting + ) + + try { + val connection = createDatabaseConnection(dataSource.toDatabaseConfig()) + if (connection.isConnected()) { + connections[id] = connection + schemas[id] = connection.getSchema() + state = state.copy( + connectionStatuses = state.connectionStatuses + (id to ConnectionStatus.Connected), + connectionStatus = if (state.hasAnyConnection || true) ConnectionStatus.Connected else state.connectionStatus + ) + _notificationEvent.emit("Connected" to "Successfully connected to ${dataSource.name}") + } else { + state = state.copy( + connectionStatuses = state.connectionStatuses + (id to ConnectionStatus.Error("Failed to connect")) + ) + } + } catch (e: Exception) { + state = state.copy( + connectionStatuses = state.connectionStatuses + (id to ConnectionStatus.Error(e.message ?: "Unknown error")) + ) + _notificationEvent.emit("Connection Failed" to "${dataSource.name}: ${e.message ?: "Unknown error"}") + } + } + } + + /** + * Disconnect a specific data source + */ + fun disconnectDataSource(id: String) { + scope.launch { + try { + connections[id]?.close() + } catch (e: Exception) { + println("[ChatDB] Error closing connection for $id: ${e.message}") + } + connections.remove(id) + schemas.remove(id) + state = state.copy( + connectionStatuses = state.connectionStatuses + (id to ConnectionStatus.Disconnected), + connectionStatus = if (connections.isEmpty()) ConnectionStatus.Disconnected else ConnectionStatus.Connected + ) + } + } + + /** + * Disconnect all data sources + */ + fun disconnectAll() { + scope.launch { + connections.forEach { (id, connection) -> + try { + connection.close() + } catch (e: Exception) { + println("[ChatDB] Error closing connection for $id: ${e.message}") + } + } + connections.clear() + schemas.clear() + state = state.copy( + connectionStatuses = emptyMap(), + connectionStatus = ConnectionStatus.Disconnected + ) + } + } + + /** + * Legacy connect method - connects all selected data sources + */ + fun connect() { + connectAll() + } + + /** + * Legacy disconnect method - disconnects all data sources + */ + fun disconnect() { + disconnectAll() + } + + fun sendMessage(text: String) { + if (isGenerating || text.isBlank()) return + + currentExecutionJob = scope.launch { + isGenerating = true + renderer.addUserMessage(text) + + try { + val service = llmService + if (service == null) { + renderer.renderError("LLM service not initialized. Please configure your model settings.") + isGenerating = false + return@launch + } + + // Get connected data sources + val connectedDataSources = state.selectedDataSources.filter { ds -> + state.getConnectionStatus(ds.id) is ConnectionStatus.Connected + } + + if (connectedDataSources.isEmpty()) { + renderer.renderError("No database connected. Please connect to at least one data source first.") + isGenerating = false + return@launch + } + + val projectPath = workspace?.rootPath ?: "." + val mcpConfigService = McpToolConfigService(ToolConfigFile()) + + // Build database configs map for multi-database agent + val databaseConfigs: Map = connectedDataSources.associate { ds -> + ds.id to ds.toDatabaseConfig() + } + + // Use MultiDatabaseChatDBAgent for unified schema linking and query execution + val agent = MultiDatabaseChatDBAgent( + projectPath = projectPath, + llmService = service, + databaseConfigs = databaseConfigs, + maxIterations = 10, + renderer = renderer, + mcpToolConfigService = mcpConfigService, + enableLLMStreaming = true + ) + + val task = ChatDBTask( + query = text, + maxRows = 100, + generateVisualization = false + ) + + try { + // Execute the multi-database agent + // It will merge schemas, let LLM decide which database(s) to query, + // and execute SQL on the appropriate database(s) + agent.execute(task) { progress -> +// println("[ChatDB] Progress: $progress") + } + } finally { + agent.close() + } + + } catch (e: CancellationException) { + renderer.forceStop() + renderer.renderError("Generation cancelled") + } catch (e: Exception) { + renderer.renderError("Error: ${e.message}") + } finally { + isGenerating = false + currentExecutionJob = null + } + } + } + + fun stopGeneration() { + currentExecutionJob?.cancel() + isGenerating = false + } + + /** + * Create a new session - clears the current chat timeline + */ + fun newSession() { + if (isGenerating) { + stopGeneration() + } + renderer.clearMessages() + } + + /** + * Get combined schema from all connected data sources + */ + fun getSchema(): DatabaseSchema? { + if (schemas.isEmpty()) return null + + // Return the first schema for now, or combine them + // For multi-datasource, we might want to show all schemas + return schemas.values.firstOrNull() + } + + /** + * Get all schemas from connected data sources + */ + fun getAllSchemas(): Map = schemas.toMap() + + /** + * Get schema for a specific data source + */ + fun getSchemaForDataSource(id: String): DatabaseSchema? = schemas[id] + + fun dispose() { + stopGeneration() + disconnectAll() + scope.cancel() + } +} + diff --git a/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/agent/chatdb/components/ChatDBChatPane.kt b/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/agent/chatdb/components/ChatDBChatPane.kt new file mode 100644 index 0000000000..c22777203d --- /dev/null +++ b/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/agent/chatdb/components/ChatDBChatPane.kt @@ -0,0 +1,391 @@ +package cc.unitmesh.devins.ui.compose.agent.chatdb.components + +import androidx.compose.foundation.BorderStroke +import androidx.compose.foundation.layout.* +import androidx.compose.foundation.shape.RoundedCornerShape +import androidx.compose.foundation.text.BasicTextField +import androidx.compose.foundation.text.KeyboardActions +import androidx.compose.foundation.text.KeyboardOptions +import androidx.compose.material3.* +import androidx.compose.runtime.* +import androidx.compose.ui.Alignment +import androidx.compose.ui.Modifier +import androidx.compose.ui.graphics.SolidColor +import androidx.compose.ui.input.key.* +import androidx.compose.ui.platform.LocalFocusManager +import androidx.compose.ui.text.TextStyle +import androidx.compose.ui.text.font.FontFamily +import androidx.compose.ui.text.input.ImeAction +import androidx.compose.ui.text.input.KeyboardType +import androidx.compose.ui.text.input.TextFieldValue +import androidx.compose.ui.unit.dp +import androidx.compose.ui.unit.sp +import cc.unitmesh.agent.Platform +import cc.unitmesh.agent.database.DatabaseSchema +import cc.unitmesh.devins.ui.compose.agent.AgentMessageList +import cc.unitmesh.devins.ui.compose.agent.ComposeRenderer +import cc.unitmesh.devins.ui.compose.agent.chatdb.model.ConnectionStatus +import cc.unitmesh.devins.ui.compose.icons.AutoDevComposeIcons + +/** + * Chat pane for ChatDB - right side chat area for text-to-SQL + * + * Supports multi-datasource mode: shows connection status for multiple databases. + */ +@Composable +fun ChatDBChatPane( + renderer: ComposeRenderer, + connectionStatus: ConnectionStatus, + connectedCount: Int = 0, + selectedCount: Int = 0, + schema: DatabaseSchema?, + isGenerating: Boolean, + onSendMessage: (String) -> Unit, + onStopGeneration: () -> Unit, + onNewSession: () -> Unit = {}, + modifier: Modifier = Modifier +) { + val hasAnyConnection = connectedCount > 0 + + Column(modifier = modifier.fillMaxSize()) { + // Connection status banner (multi-datasource aware) + ConnectionStatusBanner( + connectionStatus = connectionStatus, + connectedCount = connectedCount, + selectedCount = selectedCount, + schema = schema, + onNewSession = onNewSession, + hasMessages = renderer.timeline.isNotEmpty() + ) + + // Message list + Box(modifier = Modifier.weight(1f)) { + AgentMessageList( + renderer = renderer, + modifier = Modifier.fillMaxSize() + ) + + // Welcome message when no messages and not streaming + if (renderer.timeline.isEmpty() && renderer.currentStreamingOutput.isEmpty() && !renderer.isProcessing) { + WelcomeMessage( + isConnected = hasAnyConnection, + connectedCount = connectedCount, + schema = schema, + onQuickQuery = onSendMessage, + modifier = Modifier.fillMaxSize() + ) + } + } + + HorizontalDivider() + + // Input area + ChatInputArea( + isGenerating = isGenerating, + isConnected = hasAnyConnection, + connectedCount = connectedCount, + onSendMessage = onSendMessage, + onStopGeneration = onStopGeneration + ) + } +} + +@Composable +private fun ConnectionStatusBanner( + connectionStatus: ConnectionStatus, + connectedCount: Int, + selectedCount: Int, + schema: DatabaseSchema?, + onNewSession: () -> Unit, + hasMessages: Boolean +) { + val hasAnyConnection = connectedCount > 0 + + when { + hasAnyConnection -> { + Card( + modifier = Modifier + .fillMaxWidth() + .padding(horizontal = 12.dp, vertical = 8.dp), + colors = CardDefaults.cardColors( + containerColor = MaterialTheme.colorScheme.primaryContainer.copy(alpha = 0.3f) + ) + ) { + Row( + modifier = Modifier.padding(12.dp), + horizontalArrangement = Arrangement.spacedBy(12.dp), + verticalAlignment = Alignment.CenterVertically + ) { + Icon( + AutoDevComposeIcons.CheckCircle, + contentDescription = null, + modifier = Modifier.size(20.dp), + tint = MaterialTheme.colorScheme.primary + ) + Column(modifier = Modifier.weight(1f)) { + Text( + text = if (connectedCount == 1) "Connected" else "$connectedCount databases connected", + style = MaterialTheme.typography.labelLarge, + color = MaterialTheme.colorScheme.primary + ) + val statusText = buildString { + if (schema != null) { + append("${schema.tables.size} tables available") + } + if (connectedCount < selectedCount) { + if (isNotEmpty()) append(" โ€ข ") + append("${selectedCount - connectedCount} not connected") + } + } + if (statusText.isNotEmpty()) { + Text( + text = statusText, + style = MaterialTheme.typography.bodySmall, + color = MaterialTheme.colorScheme.onSurfaceVariant + ) + } + } + + // New session button (only show if there are messages) + if (hasMessages) { + FilledTonalIconButton( + onClick = onNewSession, + modifier = Modifier.size(36.dp) + ) { + Icon( + AutoDevComposeIcons.Add, + contentDescription = "New Session", + modifier = Modifier.size(20.dp) + ) + } + } + } + } + } + selectedCount > 0 -> { + // Selected but not connected + Card( + modifier = Modifier + .fillMaxWidth() + .padding(horizontal = 12.dp, vertical = 8.dp), + colors = CardDefaults.cardColors( + containerColor = MaterialTheme.colorScheme.surfaceVariant.copy(alpha = 0.5f) + ) + ) { + Row( + modifier = Modifier.padding(12.dp), + horizontalArrangement = Arrangement.spacedBy(12.dp), + verticalAlignment = Alignment.CenterVertically + ) { + Icon( + AutoDevComposeIcons.CloudOff, + contentDescription = null, + modifier = Modifier.size(20.dp), + tint = MaterialTheme.colorScheme.onSurfaceVariant + ) + Text( + text = "$selectedCount data source${if (selectedCount > 1) "s" else ""} selected. Click 'Connect All' to connect.", + style = MaterialTheme.typography.bodySmall, + color = MaterialTheme.colorScheme.onSurfaceVariant + ) + } + } + } + else -> { + // No data sources selected + Card( + modifier = Modifier + .fillMaxWidth() + .padding(horizontal = 12.dp, vertical = 8.dp), + colors = CardDefaults.cardColors( + containerColor = MaterialTheme.colorScheme.surfaceVariant.copy(alpha = 0.5f) + ) + ) { + Row( + modifier = Modifier.padding(12.dp), + horizontalArrangement = Arrangement.spacedBy(12.dp), + verticalAlignment = Alignment.CenterVertically + ) { + Icon( + AutoDevComposeIcons.CloudOff, + contentDescription = null, + modifier = Modifier.size(20.dp), + tint = MaterialTheme.colorScheme.onSurfaceVariant + ) + Text( + text = "No data sources selected. Select data sources from the left panel.", + style = MaterialTheme.typography.bodySmall, + color = MaterialTheme.colorScheme.onSurfaceVariant + ) + } + } + } + } +} + +@Composable +private fun WelcomeMessage( + isConnected: Boolean, + connectedCount: Int = 0, + schema: DatabaseSchema?, + onQuickQuery: (String) -> Unit, + modifier: Modifier = Modifier +) { + Column( + modifier = modifier + .fillMaxSize() + .padding(24.dp), + verticalArrangement = Arrangement.Center, + horizontalAlignment = Alignment.CenterHorizontally + ) { + Icon( + AutoDevComposeIcons.Database, + contentDescription = null, + modifier = Modifier.size(64.dp), + tint = MaterialTheme.colorScheme.primary.copy(alpha = 0.6f) + ) + + Spacer(modifier = Modifier.height(16.dp)) + + Text( + text = "ChatDB - Text to SQL", + style = MaterialTheme.typography.headlineSmall, + color = MaterialTheme.colorScheme.onSurface + ) + + Spacer(modifier = Modifier.height(8.dp)) + + Text( + text = when { + connectedCount > 1 -> "Query across $connectedCount databases in natural language" + isConnected -> "Ask questions about your data in natural language" + else -> "Connect to a database to start querying" + }, + style = MaterialTheme.typography.bodyMedium, + color = MaterialTheme.colorScheme.onSurfaceVariant + ) + + if (isConnected && schema != null) { + Spacer(modifier = Modifier.height(24.dp)) + + Text( + text = "Try asking:", + style = MaterialTheme.typography.labelMedium, + color = MaterialTheme.colorScheme.onSurfaceVariant + ) + + Spacer(modifier = Modifier.height(8.dp)) + + val suggestions = listOf( + "Show me all tables in the database", + "What are the columns in the ${schema.tables.firstOrNull()?.name ?: "users"} table?", + "Count the total number of records" + ) + + suggestions.forEach { suggestion -> + SuggestionChip( + onClick = { onQuickQuery(suggestion) }, + label = { Text(suggestion, style = MaterialTheme.typography.bodySmall) }, + modifier = Modifier.padding(vertical = 4.dp) + ) + } + } + } +} + +@Composable +private fun ChatInputArea( + isGenerating: Boolean, + isConnected: Boolean, + connectedCount: Int = 0, + onSendMessage: (String) -> Unit, + onStopGeneration: () -> Unit +) { + var inputText by remember { mutableStateOf(TextFieldValue("")) } + val focusManager = LocalFocusManager.current + + fun sendMessage() { + if (inputText.text.isNotBlank() && !isGenerating && isConnected) { + onSendMessage(inputText.text) + inputText = TextFieldValue("") + } + } + + Surface( + modifier = Modifier.fillMaxWidth(), + color = MaterialTheme.colorScheme.surface + ) { + Row( + modifier = Modifier + .padding(12.dp) + .fillMaxWidth(), + verticalAlignment = Alignment.Bottom, + horizontalArrangement = Arrangement.spacedBy(8.dp) + ) { + Surface( + modifier = Modifier.weight(1f), + shape = RoundedCornerShape(12.dp), + border = BorderStroke(1.dp, MaterialTheme.colorScheme.outline.copy(alpha = 0.3f)), + color = MaterialTheme.colorScheme.surfaceVariant.copy(alpha = 0.3f) + ) { + BasicTextField( + value = inputText, + onValueChange = { inputText = it }, + modifier = Modifier + .fillMaxWidth() + .padding(12.dp) + .onKeyEvent { event -> + if (event.type == KeyEventType.KeyDown && event.key == Key.Enter) { + if (!event.isShiftPressed) { + sendMessage() + true + } else false + } else false + }, + textStyle = TextStyle( + color = MaterialTheme.colorScheme.onSurface, + fontSize = 14.sp + ), + cursorBrush = SolidColor(MaterialTheme.colorScheme.primary), + enabled = isConnected, + decorationBox = { innerTextField -> + Box { + if (inputText.text.isEmpty()) { + val placeholderText = when { + connectedCount > 1 -> "Ask a question across $connectedCount databases..." + isConnected -> "Ask a question about your data..." + else -> "Connect to a database first" + } + Text( + text = placeholderText, + style = MaterialTheme.typography.bodyMedium, + color = MaterialTheme.colorScheme.onSurfaceVariant.copy(alpha = 0.6f) + ) + } + innerTextField() + } + } + ) + } + + if (isGenerating) { + FilledIconButton( + onClick = onStopGeneration, + colors = IconButtonDefaults.filledIconButtonColors( + containerColor = MaterialTheme.colorScheme.error + ) + ) { + Icon(AutoDevComposeIcons.Stop, contentDescription = "Stop") + } + } else { + FilledIconButton( + onClick = { sendMessage() }, + enabled = inputText.text.isNotBlank() && isConnected + ) { + Icon(AutoDevComposeIcons.Send, contentDescription = "Send") + } + } + } + } +} + diff --git a/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/agent/chatdb/components/ChatDBStepCard.kt b/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/agent/chatdb/components/ChatDBStepCard.kt new file mode 100644 index 0000000000..791b1454af --- /dev/null +++ b/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/agent/chatdb/components/ChatDBStepCard.kt @@ -0,0 +1,1180 @@ +package cc.unitmesh.devins.ui.compose.agent.chatdb.components + +import androidx.compose.animation.AnimatedVisibility +import androidx.compose.animation.expandVertically +import androidx.compose.animation.fadeIn +import androidx.compose.animation.fadeOut +import androidx.compose.animation.shrinkVertically +import androidx.compose.foundation.BorderStroke +import androidx.compose.foundation.background +import androidx.compose.foundation.border +import androidx.compose.foundation.clickable +import androidx.compose.foundation.horizontalScroll +import androidx.compose.foundation.layout.* +import androidx.compose.foundation.rememberScrollState +import androidx.compose.foundation.shape.RoundedCornerShape +import androidx.compose.material.icons.Icons +import androidx.compose.material.icons.filled.KeyboardArrowDown +import androidx.compose.material.icons.filled.KeyboardArrowRight +import androidx.compose.material3.* +import androidx.compose.runtime.* +import androidx.compose.ui.Alignment +import androidx.compose.ui.Modifier +import androidx.compose.ui.draw.clip +import androidx.compose.ui.text.font.FontFamily +import androidx.compose.ui.text.font.FontWeight +import androidx.compose.ui.text.style.TextOverflow +import androidx.compose.ui.unit.dp +import androidx.compose.ui.unit.sp +import cc.unitmesh.agent.render.ChatDBStepStatus +import cc.unitmesh.agent.render.ChatDBStepType +import cc.unitmesh.agent.render.TimelineItem + +/** + * Composable for rendering a ChatDB execution step card. + * Displays step status, title, and expandable details. + */ +@Composable +fun ChatDBStepCard( + step: TimelineItem.ChatDBStepItem, + modifier: Modifier = Modifier, + onApprove: (() -> Unit)? = null, + onReject: (() -> Unit)? = null +) { + var isExpanded by remember { mutableStateOf(step.status == ChatDBStepStatus.ERROR) } + + Card( + modifier = modifier + .fillMaxWidth() + .padding(vertical = 4.dp), + shape = RoundedCornerShape(8.dp), + colors = CardDefaults.cardColors( + containerColor = MaterialTheme.colorScheme.surfaceVariant.copy(alpha = 0.5f) + ) + ) { + Column( + modifier = Modifier + .fillMaxWidth() + .clickable { isExpanded = !isExpanded } + .padding(12.dp) + ) { + // Header row with icon, title, and status + Row( + modifier = Modifier.fillMaxWidth(), + horizontalArrangement = Arrangement.SpaceBetween, + verticalAlignment = Alignment.CenterVertically + ) { + Row( + verticalAlignment = Alignment.CenterVertically, + modifier = Modifier.weight(1f) + ) { + // Expand/collapse icon + Icon( + imageVector = if (isExpanded) Icons.Default.KeyboardArrowDown else Icons.Default.KeyboardArrowRight, + contentDescription = if (isExpanded) "Collapse" else "Expand", + modifier = Modifier.size(20.dp), + tint = MaterialTheme.colorScheme.onSurfaceVariant + ) + + Spacer(modifier = Modifier.width(8.dp)) + + // Step icon + Text( + text = step.stepType.icon, + fontSize = 16.sp, + modifier = Modifier.padding(end = 8.dp) + ) + + // Step title + Text( + text = step.title, + style = MaterialTheme.typography.bodyMedium, + fontWeight = FontWeight.Medium, + color = MaterialTheme.colorScheme.onSurface + ) + } + + // Status indicator + StepStatusBadge(status = step.status) + } + + // Expandable details section + AnimatedVisibility( + visible = isExpanded, + enter = expandVertically() + fadeIn(), + exit = shrinkVertically() + fadeOut() + ) { + Column( + modifier = Modifier + .fillMaxWidth() + .padding(top = 12.dp) + ) { + // Error message if present + step.error?.let { error -> + Text( + text = "Error: $error", + style = MaterialTheme.typography.bodySmall, + color = MaterialTheme.colorScheme.error, + modifier = Modifier + .fillMaxWidth() + .padding(bottom = 8.dp) + ) + } + + // Details + if (step.details.isNotEmpty()) { + StepDetails( + details = step.details, + stepType = step.stepType, + step = step, + onApprove = onApprove, + onReject = onReject + ) + } + } + } + } + } +} + +@Composable +private fun StepStatusBadge(status: ChatDBStepStatus) { + val (color, text) = when (status) { + ChatDBStepStatus.PENDING -> MaterialTheme.colorScheme.outline to "Pending" + ChatDBStepStatus.IN_PROGRESS -> MaterialTheme.colorScheme.primary to "In Progress" + ChatDBStepStatus.SUCCESS -> MaterialTheme.colorScheme.tertiary to "Success" + ChatDBStepStatus.WARNING -> MaterialTheme.colorScheme.secondary to "Warning" + ChatDBStepStatus.ERROR -> MaterialTheme.colorScheme.error to "Error" + ChatDBStepStatus.AWAITING_APPROVAL -> MaterialTheme.colorScheme.tertiary to "Awaiting Approval" + ChatDBStepStatus.APPROVED -> MaterialTheme.colorScheme.primary to "Approved" + ChatDBStepStatus.REJECTED -> MaterialTheme.colorScheme.error to "Rejected" + } + + Surface( + shape = RoundedCornerShape(12.dp), + color = color.copy(alpha = 0.2f), + modifier = Modifier.padding(start = 8.dp) + ) { + Text( + text = text, + style = MaterialTheme.typography.labelSmall, + color = color, + modifier = Modifier.padding(horizontal = 8.dp, vertical = 4.dp) + ) + } +} + +@Composable +private fun StepDetails( + details: Map, + stepType: ChatDBStepType, + step: TimelineItem.ChatDBStepItem, + onApprove: (() -> Unit)?, + onReject: (() -> Unit)? +) { + Column( + modifier = Modifier + .fillMaxWidth() + .clip(RoundedCornerShape(4.dp)) + .background(MaterialTheme.colorScheme.surface.copy(alpha = 0.5f)) + .padding(12.dp), + verticalArrangement = Arrangement.spacedBy(8.dp) + ) { + when (stepType) { + ChatDBStepType.FETCH_SCHEMA -> { + // Single database mode + details["databaseName"]?.let { dbName -> + DetailRow("Database", dbName.toString()) + } + details["totalTables"]?.let { + DetailRow("Total Tables", it.toString()) + } + + // Multi-database mode - show databases list + @Suppress("UNCHECKED_CAST") + val databases = details["databases"] as? List> + if (databases != null && databases.isNotEmpty()) { + Spacer(modifier = Modifier.height(8.dp)) + Text( + text = "Connected Databases", + style = MaterialTheme.typography.labelMedium, + fontWeight = FontWeight.Bold, + color = MaterialTheme.colorScheme.primary + ) + databases.forEach { dbInfo -> + DatabaseInfoCard(dbInfo) + } + } + + // Show table schema cards (single database mode) + @Suppress("UNCHECKED_CAST") + val tableSchemas = details["tableSchemas"] as? List> + if (tableSchemas != null && tableSchemas.isNotEmpty()) { + Spacer(modifier = Modifier.height(8.dp)) + Text( + text = "Tables", + style = MaterialTheme.typography.labelMedium, + fontWeight = FontWeight.Bold, + color = MaterialTheme.colorScheme.primary + ) + tableSchemas.forEach { tableInfo -> + TableSchemaCard(tableInfo) + } + } else if (databases == null) { + // Fallback to simple table list + details["tables"]?.let { tables -> + if (tables is List<*>) { + DetailRow("Tables", tables.joinToString(", ")) + } + } + } + } + + ChatDBStepType.SCHEMA_LINKING -> { + // Multi-database mode - show analyzed databases + @Suppress("UNCHECKED_CAST") + val databasesAnalyzed = details["databasesAnalyzed"] as? List + if (databasesAnalyzed != null && databasesAnalyzed.isNotEmpty()) { + Row( + horizontalArrangement = Arrangement.spacedBy(4.dp), + modifier = Modifier.horizontalScroll(rememberScrollState()) + ) { + Text( + text = "Databases Analyzed:", + style = MaterialTheme.typography.labelMedium, + color = MaterialTheme.colorScheme.onSurfaceVariant + ) + databasesAnalyzed.forEach { dbName -> + KeywordChip(dbName) + } + } + } + + // Show schema context preview + details["schemaContext"]?.let { context -> + Spacer(modifier = Modifier.height(8.dp)) + Text( + text = "Schema Context", + style = MaterialTheme.typography.labelMedium, + fontWeight = FontWeight.Bold, + color = MaterialTheme.colorScheme.primary + ) + Surface( + shape = RoundedCornerShape(4.dp), + color = MaterialTheme.colorScheme.surfaceVariant.copy(alpha = 0.5f) + ) { + Text( + text = context.toString(), + style = MaterialTheme.typography.bodySmall.copy( + fontFamily = FontFamily.Monospace, + fontSize = 10.sp + ), + color = MaterialTheme.colorScheme.onSurface, + modifier = Modifier.padding(8.dp), + maxLines = 10 + ) + } + } + + // Show keywords (single database mode) + details["keywords"]?.let { keywords -> + if (keywords is List<*> && keywords.isNotEmpty()) { + Row( + horizontalArrangement = Arrangement.spacedBy(4.dp), + modifier = Modifier.horizontalScroll(rememberScrollState()) + ) { + Text( + text = "Keywords:", + style = MaterialTheme.typography.labelMedium, + color = MaterialTheme.colorScheme.onSurfaceVariant + ) + keywords.forEach { keyword -> + KeywordChip(keyword.toString()) + } + } + } + } + + // Show relevant table schemas + @Suppress("UNCHECKED_CAST") + val relevantTableSchemas = details["relevantTableSchemas"] as? List> + if (relevantTableSchemas != null && relevantTableSchemas.isNotEmpty()) { + Spacer(modifier = Modifier.height(8.dp)) + Text( + text = "Relevant Tables", + style = MaterialTheme.typography.labelMedium, + fontWeight = FontWeight.Bold, + color = MaterialTheme.colorScheme.primary + ) + relevantTableSchemas.forEach { tableInfo -> + TableSchemaCard(tableInfo, highlightRelevant = true) + } + } else if (databasesAnalyzed == null) { + // Fallback + details["relevantTables"]?.let { tables -> + if (tables is List<*>) { + DetailRow("Relevant Tables", tables.joinToString(", ")) + } + } + } + } + + ChatDBStepType.GENERATE_SQL, ChatDBStepType.REVISE_SQL -> { + // Multi-database mode - show target databases + @Suppress("UNCHECKED_CAST") + val targetDatabases = details["targetDatabases"] as? List + if (targetDatabases != null && targetDatabases.isNotEmpty()) { + Row( + horizontalArrangement = Arrangement.spacedBy(4.dp), + modifier = Modifier.horizontalScroll(rememberScrollState()) + ) { + Text( + text = "Target Databases:", + style = MaterialTheme.typography.labelMedium, + color = MaterialTheme.colorScheme.onSurfaceVariant + ) + targetDatabases.forEach { dbName -> + KeywordChip(dbName) + } + } + Spacer(modifier = Modifier.height(8.dp)) + } + + // Multi-database mode - show SQL blocks + @Suppress("UNCHECKED_CAST") + val sqlBlocks = details["sqlBlocks"] as? List> + if (sqlBlocks != null && sqlBlocks.isNotEmpty()) { + sqlBlocks.forEach { block -> + val database = block["database"]?.toString() ?: "Unknown" + val sql = block["sql"]?.toString() ?: "" + Text( + text = "Database: $database", + style = MaterialTheme.typography.labelSmall, + fontWeight = FontWeight.Bold, + color = MaterialTheme.colorScheme.secondary, + modifier = Modifier.padding(bottom = 4.dp) + ) + CodeBlock(code = sql, language = "sql") + Spacer(modifier = Modifier.height(8.dp)) + } + } else { + // Single database mode + details["sql"]?.let { sql -> + CodeBlock(code = sql.toString(), language = "sql") + } + } + } + + ChatDBStepType.VALIDATE_SQL -> { + details["errorType"]?.let { + DetailRow("Error Type", it.toString()) + } + details["errors"]?.let { errors -> + if (errors is List<*>) { + Column { + Text( + text = "Errors:", + style = MaterialTheme.typography.labelMedium, + color = MaterialTheme.colorScheme.error + ) + errors.forEach { error -> + Text( + text = "โ€ข $error", + style = MaterialTheme.typography.bodySmall, + color = MaterialTheme.colorScheme.onSurface, + modifier = Modifier.padding(start = 8.dp) + ) + } + } + } + } + } + + ChatDBStepType.EXECUTE_SQL -> { + // Multi-database mode - show database name + details["database"]?.let { dbName -> + DetailRow("Database", dbName.toString()) + } + + // Show SQL that was executed + details["sql"]?.let { sql -> + CodeBlock(code = sql.toString(), language = "sql") + } + + // Show result summary + details["rowCount"]?.let { + DetailRow("Rows Returned", it.toString()) + } + + // Show data preview table + @Suppress("UNCHECKED_CAST") + val columns = details["columns"] as? List + @Suppress("UNCHECKED_CAST") + val previewRows = details["previewRows"] as? List> + + if (columns != null && previewRows != null && previewRows.isNotEmpty()) { + Spacer(modifier = Modifier.height(8.dp)) + Text( + text = "Data Preview", + style = MaterialTheme.typography.labelMedium, + fontWeight = FontWeight.Bold, + color = MaterialTheme.colorScheme.primary + ) + DataPreviewTable(columns = columns, rows = previewRows) + + val totalRows = (details["rowCount"] as? Int) ?: previewRows.size + if (totalRows > previewRows.size) { + Text( + text = "Showing ${previewRows.size} of $totalRows rows", + style = MaterialTheme.typography.labelSmall, + color = MaterialTheme.colorScheme.onSurfaceVariant, + modifier = Modifier.padding(top = 4.dp) + ) + } + } else if (columns != null) { + DetailRow("Columns", columns.joinToString(", ")) + } + } + + ChatDBStepType.FINAL_RESULT -> { + // Multi-database mode - show databases queried + @Suppress("UNCHECKED_CAST") + val databases = details["databases"] as? List + if (databases != null && databases.isNotEmpty()) { + Row( + horizontalArrangement = Arrangement.spacedBy(4.dp), + modifier = Modifier.horizontalScroll(rememberScrollState()) + ) { + Text( + text = "Databases Queried:", + style = MaterialTheme.typography.labelMedium, + color = MaterialTheme.colorScheme.onSurfaceVariant + ) + databases.forEach { dbName -> + KeywordChip(dbName) + } + } + Spacer(modifier = Modifier.height(8.dp)) + } + + // Show final SQL (single database mode) + details["sql"]?.let { sql -> + CodeBlock(code = sql.toString(), language = "sql") + } + + // Show result summary + details["totalRows"]?.let { + DetailRow("Total Rows", it.toString()) + } + details["rowCount"]?.let { + if (details["totalRows"] == null) { + DetailRow("Total Rows", it.toString()) + } + } + details["revisionAttempts"]?.let { attempts -> + if ((attempts as? Int ?: 0) > 0) { + DetailRow("Revision Attempts", attempts.toString()) + } + } + + // Show errors if any + @Suppress("UNCHECKED_CAST") + val errors = details["errors"] as? List + if (errors != null && errors.isNotEmpty()) { + Spacer(modifier = Modifier.height(8.dp)) + Text( + text = "Errors", + style = MaterialTheme.typography.labelMedium, + fontWeight = FontWeight.Bold, + color = MaterialTheme.colorScheme.error + ) + errors.forEach { error -> + Text( + text = "โ€ข $error", + style = MaterialTheme.typography.bodySmall, + color = MaterialTheme.colorScheme.error, + modifier = Modifier.padding(start = 8.dp) + ) + } + } + + // Show data preview + @Suppress("UNCHECKED_CAST") + val columns = details["columns"] as? List + @Suppress("UNCHECKED_CAST") + val previewRows = details["previewRows"] as? List> + + if (columns != null && previewRows != null && previewRows.isNotEmpty()) { + Spacer(modifier = Modifier.height(8.dp)) + Text( + text = "Query Results", + style = MaterialTheme.typography.labelMedium, + fontWeight = FontWeight.Bold, + color = MaterialTheme.colorScheme.primary + ) + DataPreviewTable(columns = columns, rows = previewRows) + } + } + + ChatDBStepType.DRY_RUN -> { + // Show SQL being validated + details["sql"]?.let { sql -> + CodeBlock(code = sql.toString(), language = "sql") + } + + // Show dry run result + val isValid = details["isValid"] as? Boolean + val message = details["message"]?.toString() + @Suppress("UNCHECKED_CAST") + val errors = details["errors"] as? List + @Suppress("UNCHECKED_CAST") + val warnings = details["warnings"] as? List + val estimatedRows = details["estimatedRows"] + + // Validation status + if (isValid != null) { + Spacer(modifier = Modifier.height(8.dp)) + Surface( + shape = RoundedCornerShape(4.dp), + color = if (isValid) { + MaterialTheme.colorScheme.primaryContainer.copy(alpha = 0.5f) + } else { + MaterialTheme.colorScheme.errorContainer.copy(alpha = 0.5f) + } + ) { + Row( + modifier = Modifier.padding(8.dp), + verticalAlignment = Alignment.CenterVertically + ) { + Text( + text = if (isValid) "โœ“" else "โœ—", + fontSize = 16.sp, + color = if (isValid) { + MaterialTheme.colorScheme.primary + } else { + MaterialTheme.colorScheme.error + } + ) + Spacer(modifier = Modifier.width(8.dp)) + Text( + text = if (isValid) "Validation Passed" else "Validation Failed", + style = MaterialTheme.typography.bodySmall, + fontWeight = FontWeight.Bold, + color = if (isValid) { + MaterialTheme.colorScheme.onPrimaryContainer + } else { + MaterialTheme.colorScheme.onErrorContainer + } + ) + } + } + } + + // Show message + if (!message.isNullOrEmpty()) { + Spacer(modifier = Modifier.height(4.dp)) + Text( + text = message, + style = MaterialTheme.typography.bodySmall, + color = MaterialTheme.colorScheme.onSurfaceVariant + ) + } + + // Show errors + if (!errors.isNullOrEmpty()) { + Spacer(modifier = Modifier.height(8.dp)) + Text( + text = "Errors:", + style = MaterialTheme.typography.labelMedium, + fontWeight = FontWeight.Bold, + color = MaterialTheme.colorScheme.error + ) + errors.forEach { error -> + Text( + text = "โ€ข $error", + style = MaterialTheme.typography.bodySmall, + color = MaterialTheme.colorScheme.error + ) + } + } + + // Show warnings + if (!warnings.isNullOrEmpty()) { + Spacer(modifier = Modifier.height(8.dp)) + Text( + text = "Warnings:", + style = MaterialTheme.typography.labelMedium, + fontWeight = FontWeight.Bold, + color = MaterialTheme.colorScheme.tertiary + ) + warnings.forEach { warning -> + Text( + text = "โ€ข $warning", + style = MaterialTheme.typography.bodySmall, + color = MaterialTheme.colorScheme.tertiary + ) + } + } + + // Show estimated rows + if (estimatedRows != null) { + Spacer(modifier = Modifier.height(4.dp)) + DetailRow("Estimated Rows", estimatedRows.toString()) + } + } + + ChatDBStepType.AWAIT_APPROVAL -> { + // Show SQL that requires approval + details["sql"]?.let { sql -> + CodeBlock(code = sql.toString(), language = "sql") + } + + // Show operation type + details["operationType"]?.let { opType -> + DetailRow("Operation Type", opType.toString()) + } + + // Show affected tables + @Suppress("UNCHECKED_CAST") + val affectedTables = details["affectedTables"] as? List + if (affectedTables != null && affectedTables.isNotEmpty()) { + Row( + horizontalArrangement = Arrangement.spacedBy(4.dp), + modifier = Modifier.horizontalScroll(rememberScrollState()) + ) { + Text( + text = "Affected Tables:", + style = MaterialTheme.typography.labelMedium, + color = MaterialTheme.colorScheme.onSurfaceVariant + ) + affectedTables.forEach { tableName -> + KeywordChip(tableName) + } + } + } + + // Show high risk warning + val isHighRisk = details["isHighRisk"] as? Boolean ?: false + if (isHighRisk) { + Spacer(modifier = Modifier.height(8.dp)) + Surface( + shape = RoundedCornerShape(4.dp), + color = MaterialTheme.colorScheme.errorContainer.copy(alpha = 0.5f) + ) { + Row( + modifier = Modifier.padding(8.dp), + verticalAlignment = Alignment.CenterVertically + ) { + Text(text = "โš ๏ธ", fontSize = 16.sp) + Spacer(modifier = Modifier.width(8.dp)) + Text( + text = "HIGH RISK OPERATION - This operation may cause data loss", + style = MaterialTheme.typography.bodySmall, + color = MaterialTheme.colorScheme.onErrorContainer, + fontWeight = FontWeight.Bold + ) + } + } + } + + // Show Approve/Reject buttons only when status is AWAITING_APPROVAL + if (step.status == ChatDBStepStatus.AWAITING_APPROVAL && (onApprove != null || onReject != null)) { + Spacer(modifier = Modifier.height(12.dp)) + Row( + modifier = Modifier.fillMaxWidth(), + horizontalArrangement = Arrangement.End, + verticalAlignment = Alignment.CenterVertically + ) { + // Reject button + if (onReject != null) { + OutlinedButton( + onClick = onReject, + colors = ButtonDefaults.outlinedButtonColors( + contentColor = MaterialTheme.colorScheme.error + ), + border = BorderStroke(1.dp, MaterialTheme.colorScheme.error) + ) { + Text("Reject") + } + Spacer(modifier = Modifier.width(8.dp)) + } + + // Approve button + if (onApprove != null) { + Button( + onClick = onApprove, + colors = ButtonDefaults.buttonColors( + containerColor = MaterialTheme.colorScheme.primary + ) + ) { + Text("Approve") + } + } + } + } + } + + ChatDBStepType.EXECUTE_WRITE -> { + // Show database name + details["database"]?.let { dbName -> + DetailRow("Database", dbName.toString()) + } + + // Show operation type + details["operationType"]?.let { opType -> + DetailRow("Operation Type", opType.toString()) + } + + // Show SQL that was executed + details["sql"]?.let { sql -> + CodeBlock(code = sql.toString(), language = "sql") + } + + // Show affected rows + details["affectedRows"]?.let { rows -> + DetailRow("Affected Rows", rows.toString()) + } + + // Show message if present + details["message"]?.let { message -> + if (message.toString().isNotEmpty()) { + DetailRow("Message", message.toString()) + } + } + } + + else -> { + // Generic detail rendering + details.forEach { (key, value) -> + DetailRow(key, value.toString()) + } + } + } + } +} + +@Composable +private fun DetailRow(label: String, value: String) { + Row( + modifier = Modifier.fillMaxWidth(), + horizontalArrangement = Arrangement.SpaceBetween + ) { + Text( + text = "$label:", + style = MaterialTheme.typography.labelMedium, + color = MaterialTheme.colorScheme.onSurfaceVariant, + modifier = Modifier.padding(end = 8.dp) + ) + Text( + text = value, + style = MaterialTheme.typography.bodySmall, + color = MaterialTheme.colorScheme.onSurface + ) + } +} + +@Composable +private fun CodeBlock(code: String, language: String) { + Column(modifier = Modifier.fillMaxWidth()) { + Text( + text = language.uppercase(), + style = MaterialTheme.typography.labelSmall, + color = MaterialTheme.colorScheme.primary, + modifier = Modifier.padding(bottom = 4.dp) + ) + Surface( + shape = RoundedCornerShape(4.dp), + color = MaterialTheme.colorScheme.surfaceVariant.copy(alpha = 0.5f) + ) { + Text( + text = code, + style = MaterialTheme.typography.bodySmall.copy( + fontFamily = FontFamily.Monospace + ), + color = MaterialTheme.colorScheme.onSurface, + modifier = Modifier.padding(12.dp) + ) + } + } +} + +/** + * Card displaying database information for multi-database mode + */ +@Composable +private fun DatabaseInfoCard(dbInfo: Map) { + var isExpanded by remember { mutableStateOf(false) } + val name = dbInfo["name"]?.toString() ?: "Unknown" + val displayName = dbInfo["displayName"]?.toString() ?: name + val tableCount = dbInfo["tableCount"]?.toString() ?: "0" + + @Suppress("UNCHECKED_CAST") + val tables = dbInfo["tables"] as? List ?: emptyList() + + Surface( + modifier = Modifier + .fillMaxWidth() + .padding(vertical = 4.dp), + shape = RoundedCornerShape(6.dp), + color = MaterialTheme.colorScheme.primaryContainer.copy(alpha = 0.3f) + ) { + Column( + modifier = Modifier + .fillMaxWidth() + .clickable { isExpanded = !isExpanded } + .padding(8.dp) + ) { + Row( + modifier = Modifier.fillMaxWidth(), + horizontalArrangement = Arrangement.SpaceBetween, + verticalAlignment = Alignment.CenterVertically + ) { + Row(verticalAlignment = Alignment.CenterVertically) { + Icon( + imageVector = if (isExpanded) Icons.Default.KeyboardArrowDown + else Icons.Default.KeyboardArrowRight, + contentDescription = null, + modifier = Modifier.size(16.dp), + tint = MaterialTheme.colorScheme.onSurfaceVariant + ) + Spacer(modifier = Modifier.width(4.dp)) + Text(text = "๐Ÿ—„๏ธ", fontSize = 12.sp) + Spacer(modifier = Modifier.width(4.dp)) + Text( + text = displayName, + style = MaterialTheme.typography.bodyMedium, + fontWeight = FontWeight.Medium, + color = MaterialTheme.colorScheme.onSurface + ) + if (name != displayName) { + Text( + text = " ($name)", + style = MaterialTheme.typography.labelSmall, + color = MaterialTheme.colorScheme.onSurfaceVariant + ) + } + } + Text( + text = "$tableCount tables", + style = MaterialTheme.typography.labelSmall, + color = MaterialTheme.colorScheme.onSurfaceVariant + ) + } + + AnimatedVisibility( + visible = isExpanded, + enter = expandVertically() + fadeIn(), + exit = shrinkVertically() + fadeOut() + ) { + if (tables.isNotEmpty()) { + Column( + modifier = Modifier + .fillMaxWidth() + .padding(top = 8.dp, start = 24.dp) + ) { + Text( + text = tables.joinToString(", "), + style = MaterialTheme.typography.bodySmall.copy( + fontFamily = FontFamily.Monospace + ), + color = MaterialTheme.colorScheme.onSurfaceVariant + ) + } + } + } + } + } +} + +/** + * Card displaying table schema information + */ +@Composable +private fun TableSchemaCard( + tableInfo: Map, + highlightRelevant: Boolean = false +) { + var isExpanded by remember { mutableStateOf(highlightRelevant) } + val tableName = tableInfo["name"]?.toString() ?: "Unknown" + val comment = tableInfo["comment"]?.toString() + + @Suppress("UNCHECKED_CAST") + val columns = tableInfo["columns"] as? List> ?: emptyList() + + Surface( + modifier = Modifier + .fillMaxWidth() + .padding(vertical = 4.dp), + shape = RoundedCornerShape(6.dp), + color = if (highlightRelevant) + MaterialTheme.colorScheme.primaryContainer.copy(alpha = 0.3f) + else + MaterialTheme.colorScheme.surfaceVariant.copy(alpha = 0.3f) + ) { + Column( + modifier = Modifier + .fillMaxWidth() + .clickable { isExpanded = !isExpanded } + .padding(8.dp) + ) { + Row( + modifier = Modifier.fillMaxWidth(), + horizontalArrangement = Arrangement.SpaceBetween, + verticalAlignment = Alignment.CenterVertically + ) { + Row(verticalAlignment = Alignment.CenterVertically) { + Icon( + imageVector = if (isExpanded) Icons.Default.KeyboardArrowDown + else Icons.Default.KeyboardArrowRight, + contentDescription = null, + modifier = Modifier.size(16.dp), + tint = MaterialTheme.colorScheme.onSurfaceVariant + ) + Spacer(modifier = Modifier.width(4.dp)) + Text( + text = "๐Ÿ“‹", + fontSize = 12.sp + ) + Spacer(modifier = Modifier.width(4.dp)) + Text( + text = tableName, + style = MaterialTheme.typography.bodyMedium, + fontWeight = FontWeight.Medium, + color = MaterialTheme.colorScheme.onSurface + ) + } + Text( + text = "${columns.size} columns", + style = MaterialTheme.typography.labelSmall, + color = MaterialTheme.colorScheme.onSurfaceVariant + ) + } + + comment?.let { + Text( + text = it, + style = MaterialTheme.typography.labelSmall, + color = MaterialTheme.colorScheme.onSurfaceVariant, + modifier = Modifier.padding(start = 24.dp, top = 2.dp) + ) + } + + AnimatedVisibility( + visible = isExpanded, + enter = expandVertically() + fadeIn(), + exit = shrinkVertically() + fadeOut() + ) { + if (columns.isNotEmpty()) { + Column( + modifier = Modifier + .fillMaxWidth() + .padding(top = 8.dp, start = 8.dp) + ) { + columns.forEach { col -> + ColumnInfoRow(col) + } + } + } + } + } + } +} + +/** + * Row displaying column information + */ +@Composable +private fun ColumnInfoRow(columnInfo: Map) { + val name = columnInfo["name"]?.toString() ?: "" + val type = columnInfo["type"]?.toString() ?: "" + val isPrimaryKey = columnInfo["isPrimaryKey"] as? Boolean ?: false + val isForeignKey = columnInfo["isForeignKey"] as? Boolean ?: false + val nullable = columnInfo["nullable"] as? Boolean ?: true + + Row( + modifier = Modifier + .fillMaxWidth() + .padding(vertical = 2.dp), + horizontalArrangement = Arrangement.SpaceBetween, + verticalAlignment = Alignment.CenterVertically + ) { + Row( + verticalAlignment = Alignment.CenterVertically, + modifier = Modifier.weight(1f) + ) { + // Column name with key indicators + if (isPrimaryKey) { + Text(text = "๐Ÿ”‘", fontSize = 10.sp) + Spacer(modifier = Modifier.width(2.dp)) + } else if (isForeignKey) { + Text(text = "๐Ÿ”—", fontSize = 10.sp) + Spacer(modifier = Modifier.width(2.dp)) + } + + Text( + text = name, + style = MaterialTheme.typography.bodySmall.copy( + fontFamily = FontFamily.Monospace + ), + fontWeight = if (isPrimaryKey) FontWeight.Bold else FontWeight.Normal, + color = if (isPrimaryKey) MaterialTheme.colorScheme.primary + else MaterialTheme.colorScheme.onSurface + ) + } + + Row( + horizontalArrangement = Arrangement.spacedBy(4.dp), + verticalAlignment = Alignment.CenterVertically + ) { + // Type badge + Surface( + shape = RoundedCornerShape(4.dp), + color = MaterialTheme.colorScheme.secondaryContainer.copy(alpha = 0.5f) + ) { + Text( + text = type, + style = MaterialTheme.typography.labelSmall.copy( + fontFamily = FontFamily.Monospace, + fontSize = 10.sp + ), + color = MaterialTheme.colorScheme.onSecondaryContainer, + modifier = Modifier.padding(horizontal = 4.dp, vertical = 1.dp) + ) + } + + // Nullable indicator + if (!nullable) { + Surface( + shape = RoundedCornerShape(4.dp), + color = MaterialTheme.colorScheme.errorContainer.copy(alpha = 0.5f) + ) { + Text( + text = "NOT NULL", + style = MaterialTheme.typography.labelSmall.copy(fontSize = 8.sp), + color = MaterialTheme.colorScheme.onErrorContainer, + modifier = Modifier.padding(horizontal = 3.dp, vertical = 1.dp) + ) + } + } + } + } +} + +/** + * Keyword chip for schema linking + */ +@Composable +private fun KeywordChip(keyword: String) { + Surface( + shape = RoundedCornerShape(12.dp), + color = MaterialTheme.colorScheme.secondaryContainer.copy(alpha = 0.7f) + ) { + Text( + text = keyword, + style = MaterialTheme.typography.labelSmall, + color = MaterialTheme.colorScheme.onSecondaryContainer, + modifier = Modifier.padding(horizontal = 8.dp, vertical = 4.dp) + ) + } +} + +/** + * Data preview table for query results + */ +@Composable +private fun DataPreviewTable( + columns: List, + rows: List>, + maxDisplayRows: Int = 10 +) { + val displayRows = rows.take(maxDisplayRows) + val borderColor = MaterialTheme.colorScheme.outlineVariant + + Column( + modifier = Modifier + .fillMaxWidth() + .horizontalScroll(rememberScrollState()) + ) { + Surface( + shape = RoundedCornerShape(6.dp), + color = MaterialTheme.colorScheme.surface, + modifier = Modifier.border( + width = 1.dp, + color = borderColor, + shape = RoundedCornerShape(6.dp) + ) + ) { + Column { + // Header row + Row( + modifier = Modifier + .background(MaterialTheme.colorScheme.surfaceVariant) + .padding(1.dp) + ) { + columns.forEach { column -> + Box( + modifier = Modifier + .widthIn(min = 80.dp, max = 200.dp) + .border( + width = 0.5.dp, + color = borderColor + ) + .padding(horizontal = 8.dp, vertical = 6.dp) + ) { + Text( + text = column, + style = MaterialTheme.typography.labelSmall.copy( + fontFamily = FontFamily.Monospace + ), + fontWeight = FontWeight.Bold, + color = MaterialTheme.colorScheme.onSurfaceVariant, + maxLines = 1, + overflow = TextOverflow.Ellipsis + ) + } + } + } + + // Data rows + displayRows.forEachIndexed { rowIndex, row -> + Row( + modifier = Modifier + .background( + if (rowIndex % 2 == 0) + MaterialTheme.colorScheme.surface + else + MaterialTheme.colorScheme.surfaceVariant.copy(alpha = 0.3f) + ) + .padding(1.dp) + ) { + row.forEachIndexed { colIndex, cell -> + Box( + modifier = Modifier + .widthIn(min = 80.dp, max = 200.dp) + .border( + width = 0.5.dp, + color = borderColor.copy(alpha = 0.5f) + ) + .padding(horizontal = 8.dp, vertical = 4.dp) + ) { + Text( + text = cell.ifEmpty { "NULL" }, + style = MaterialTheme.typography.bodySmall.copy( + fontFamily = FontFamily.Monospace, + fontSize = 11.sp + ), + color = if (cell.isEmpty()) + MaterialTheme.colorScheme.onSurfaceVariant.copy(alpha = 0.5f) + else + MaterialTheme.colorScheme.onSurface, + maxLines = 2, + overflow = TextOverflow.Ellipsis + ) + } + } + } + } + } + } + } +} diff --git a/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/agent/chatdb/components/DataSourceConfigPane.kt b/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/agent/chatdb/components/DataSourceConfigPane.kt new file mode 100644 index 0000000000..d3a4699d5f --- /dev/null +++ b/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/agent/chatdb/components/DataSourceConfigPane.kt @@ -0,0 +1,297 @@ +package cc.unitmesh.devins.ui.compose.agent.chatdb.components + +import androidx.compose.animation.AnimatedVisibility +import androidx.compose.animation.expandVertically +import androidx.compose.animation.shrinkVertically +import androidx.compose.foundation.background +import androidx.compose.foundation.clickable +import androidx.compose.foundation.layout.* +import androidx.compose.foundation.rememberScrollState +import androidx.compose.foundation.text.KeyboardOptions +import androidx.compose.foundation.verticalScroll +import androidx.compose.material3.* +import androidx.compose.runtime.* +import androidx.compose.ui.Alignment +import androidx.compose.ui.Modifier +import androidx.compose.ui.text.input.KeyboardType +import androidx.compose.ui.text.input.PasswordVisualTransformation +import androidx.compose.ui.text.input.VisualTransformation +import androidx.compose.ui.unit.dp +import cc.unitmesh.devins.ui.compose.agent.chatdb.model.* +import cc.unitmesh.devins.ui.compose.icons.AutoDevComposeIcons + +/** + * Inline configuration pane for adding/editing data source + * Displayed on the right side instead of a dialog + */ +@OptIn(ExperimentalMaterial3Api::class) +@Composable +fun DataSourceConfigPane( + existingConfig: DataSourceConfig?, + onCancel: () -> Unit, + onSave: (DataSourceConfig) -> Unit, + onSaveAndConnect: (DataSourceConfig) -> Unit, + modifier: Modifier = Modifier +) { + var name by remember(existingConfig) { mutableStateOf(existingConfig?.name ?: "") } + var dialect by remember(existingConfig) { mutableStateOf(existingConfig?.dialect ?: DatabaseDialect.MYSQL) } + var host by remember(existingConfig) { mutableStateOf(existingConfig?.host ?: "localhost") } + var port by remember(existingConfig) { mutableStateOf(existingConfig?.port?.toString() ?: "3306") } + var database by remember(existingConfig) { mutableStateOf(existingConfig?.database ?: "") } + var username by remember(existingConfig) { mutableStateOf(existingConfig?.username ?: "") } + var password by remember(existingConfig) { mutableStateOf(existingConfig?.password ?: "") } + var description by remember(existingConfig) { mutableStateOf(existingConfig?.description ?: "") } + var showPassword by remember { mutableStateOf(false) } + var dialectExpanded by remember { mutableStateOf(false) } + var advancedExpanded by remember { mutableStateOf(existingConfig?.description?.isNotBlank() == true) } + + val isEditing = existingConfig != null + val scrollState = rememberScrollState() + + val isValid = name.isNotBlank() && database.isNotBlank() && + (dialect == DatabaseDialect.SQLITE || host.isNotBlank()) + + fun buildConfig(): DataSourceConfig = DataSourceConfig( + id = existingConfig?.id ?: "", + name = name.trim(), + dialect = dialect, + host = host.trim(), + port = port.toIntOrNull() ?: dialect.defaultPort, + database = database.trim(), + username = username.trim(), + password = password, + description = description.trim() + ) + + Column( + modifier = modifier + .fillMaxSize() + .background(MaterialTheme.colorScheme.surface) + ) { + // Header + Surface( + modifier = Modifier.fillMaxWidth(), + color = MaterialTheme.colorScheme.surfaceVariant.copy(alpha = 0.3f) + ) { + Row( + modifier = Modifier + .fillMaxWidth() + .padding(horizontal = 12.dp, vertical = 8.dp), + horizontalArrangement = Arrangement.SpaceBetween, + verticalAlignment = Alignment.CenterVertically + ) { + Text( + text = if (isEditing) "Edit Data Source" else "Add Data Source", + style = MaterialTheme.typography.titleSmall + ) + IconButton(onClick = onCancel, modifier = Modifier.size(32.dp)) { + Icon(AutoDevComposeIcons.Close, contentDescription = "Close", modifier = Modifier.size(18.dp)) + } + } + } + + HorizontalDivider() + + // Form content - more compact + Column( + modifier = Modifier + .weight(1f) + .verticalScroll(scrollState) + .padding(12.dp), + verticalArrangement = Arrangement.spacedBy(8.dp) + ) { + // Name and Database Type in one row + Row( + modifier = Modifier.fillMaxWidth(), + horizontalArrangement = Arrangement.spacedBy(8.dp) + ) { + OutlinedTextField( + value = name, + onValueChange = { name = it }, + label = { Text("Name *") }, + modifier = Modifier.weight(1f), + singleLine = true, + textStyle = MaterialTheme.typography.bodySmall + ) + ExposedDropdownMenuBox( + expanded = dialectExpanded, + onExpandedChange = { dialectExpanded = it }, + modifier = Modifier.weight(1f) + ) { + OutlinedTextField( + value = dialect.displayName, + onValueChange = {}, + readOnly = true, + label = { Text("Type") }, + trailingIcon = { ExposedDropdownMenuDefaults.TrailingIcon(expanded = dialectExpanded) }, + modifier = Modifier.fillMaxWidth().menuAnchor(), + singleLine = true, + textStyle = MaterialTheme.typography.bodySmall + ) + ExposedDropdownMenu( + expanded = dialectExpanded, + onDismissRequest = { dialectExpanded = false } + ) { + DatabaseDialect.entries.forEach { option -> + DropdownMenuItem( + text = { Text(option.displayName, style = MaterialTheme.typography.bodySmall) }, + onClick = { + dialect = option + port = option.defaultPort.toString() + dialectExpanded = false + } + ) + } + } + } + } + + // Host and Port (not for SQLite) + if (dialect != DatabaseDialect.SQLITE) { + Row( + modifier = Modifier.fillMaxWidth(), + horizontalArrangement = Arrangement.spacedBy(8.dp) + ) { + OutlinedTextField( + value = host, + onValueChange = { host = it }, + label = { Text("Host *") }, + modifier = Modifier.weight(2f), + singleLine = true, + textStyle = MaterialTheme.typography.bodySmall + ) + OutlinedTextField( + value = port, + onValueChange = { port = it.filter { c -> c.isDigit() } }, + label = { Text("Port") }, + modifier = Modifier.weight(1f), + singleLine = true, + keyboardOptions = KeyboardOptions(keyboardType = KeyboardType.Number), + textStyle = MaterialTheme.typography.bodySmall + ) + } + } + + // Database name + OutlinedTextField( + value = database, + onValueChange = { database = it }, + label = { Text(if (dialect == DatabaseDialect.SQLITE) "File Path *" else "Database *") }, + modifier = Modifier.fillMaxWidth(), + singleLine = true, + textStyle = MaterialTheme.typography.bodySmall + ) + + // Username and Password in one row (not for SQLite) + if (dialect != DatabaseDialect.SQLITE) { + Row( + modifier = Modifier.fillMaxWidth(), + horizontalArrangement = Arrangement.spacedBy(8.dp) + ) { + OutlinedTextField( + value = username, + onValueChange = { username = it }, + label = { Text("Username") }, + modifier = Modifier.weight(1f), + singleLine = true, + textStyle = MaterialTheme.typography.bodySmall + ) + OutlinedTextField( + value = password, + onValueChange = { password = it }, + label = { Text("Password") }, + modifier = Modifier.weight(1f), + singleLine = true, + visualTransformation = if (showPassword) VisualTransformation.None else PasswordVisualTransformation(), + trailingIcon = { + IconButton(onClick = { showPassword = !showPassword }, modifier = Modifier.size(24.dp)) { + Icon( + if (showPassword) AutoDevComposeIcons.VisibilityOff else AutoDevComposeIcons.Visibility, + contentDescription = if (showPassword) "Hide" else "Show", + modifier = Modifier.size(16.dp) + ) + } + }, + textStyle = MaterialTheme.typography.bodySmall + ) + } + } + + Spacer(modifier = Modifier.height(4.dp)) + + // Advanced Settings - collapsible + Row( + modifier = Modifier + .fillMaxWidth() + .clickable { advancedExpanded = !advancedExpanded } + .padding(vertical = 4.dp), + verticalAlignment = Alignment.CenterVertically, + horizontalArrangement = Arrangement.spacedBy(4.dp) + ) { + Icon( + if (advancedExpanded) AutoDevComposeIcons.ExpandLess else AutoDevComposeIcons.ExpandMore, + contentDescription = if (advancedExpanded) "Collapse" else "Expand", + modifier = Modifier.size(18.dp), + tint = MaterialTheme.colorScheme.onSurfaceVariant + ) + Text( + text = "Advanced Settings", + style = MaterialTheme.typography.labelMedium, + color = MaterialTheme.colorScheme.onSurfaceVariant + ) + } + + AnimatedVisibility( + visible = advancedExpanded, + enter = expandVertically(), + exit = shrinkVertically() + ) { + Column( + modifier = Modifier.padding(start = 4.dp), + verticalArrangement = Arrangement.spacedBy(8.dp) + ) { + OutlinedTextField( + value = description, + onValueChange = { description = it }, + label = { Text("Description") }, + modifier = Modifier.fillMaxWidth(), + minLines = 2, + maxLines = 3, + textStyle = MaterialTheme.typography.bodySmall + ) + } + } + } + + HorizontalDivider() + + // Action buttons - more compact + Row( + modifier = Modifier + .fillMaxWidth() + .background(MaterialTheme.colorScheme.surfaceVariant.copy(alpha = 0.3f)) + .padding(horizontal = 12.dp, vertical = 8.dp), + horizontalArrangement = Arrangement.spacedBy(8.dp, Alignment.End), + verticalAlignment = Alignment.CenterVertically + ) { + TextButton(onClick = onCancel, modifier = Modifier.height(36.dp)) { + Text("Cancel", style = MaterialTheme.typography.labelMedium) + } + OutlinedButton( + onClick = { onSave(buildConfig()) }, + enabled = isValid, + modifier = Modifier.height(36.dp) + ) { + Text("Save", style = MaterialTheme.typography.labelMedium) + } + Button( + onClick = { onSaveAndConnect(buildConfig()) }, + enabled = isValid, + modifier = Modifier.height(36.dp) + ) { + Text("Save & Connect", style = MaterialTheme.typography.labelMedium) + } + } + } +} + diff --git a/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/agent/chatdb/components/DataSourcePanel.kt b/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/agent/chatdb/components/DataSourcePanel.kt new file mode 100644 index 0000000000..30d9219446 --- /dev/null +++ b/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/agent/chatdb/components/DataSourcePanel.kt @@ -0,0 +1,407 @@ +package cc.unitmesh.devins.ui.compose.agent.chatdb.components + +import androidx.compose.foundation.background +import androidx.compose.foundation.clickable +import androidx.compose.foundation.layout.* +import androidx.compose.foundation.lazy.LazyColumn +import androidx.compose.foundation.lazy.items +import androidx.compose.foundation.shape.CircleShape +import androidx.compose.foundation.shape.RoundedCornerShape +import androidx.compose.material3.* +import androidx.compose.runtime.* +import androidx.compose.ui.Alignment +import androidx.compose.ui.Modifier +import androidx.compose.ui.draw.clip +import androidx.compose.ui.text.style.TextOverflow +import androidx.compose.ui.unit.dp +import cc.unitmesh.devins.ui.compose.agent.chatdb.model.* +import cc.unitmesh.devins.ui.compose.icons.AutoDevComposeIcons + +/** + * Data Source Panel - Left side panel for managing database connections + * + * Supports multi-datasource selection: all data sources are selected by default, + * users can toggle individual data sources on/off using checkboxes. + */ +@Composable +fun DataSourcePanel( + dataSources: List, + selectedDataSourceIds: Set, + connectionStatuses: Map, + filterQuery: String, + onFilterChange: (String) -> Unit, + onToggleDataSource: (String) -> Unit, + onAddClick: () -> Unit, + onEditClick: (DataSourceConfig) -> Unit, + onDeleteClick: (String) -> Unit, + onConnectClick: (String) -> Unit, + onDisconnectClick: (String) -> Unit, + onConnectAllClick: () -> Unit, + onDisconnectAllClick: () -> Unit, + modifier: Modifier = Modifier +) { + val selectedCount = selectedDataSourceIds.size + val connectedCount = connectionStatuses.values.count { it is ConnectionStatus.Connected } + + Column( + modifier = modifier + .fillMaxHeight() + .background(MaterialTheme.colorScheme.surfaceVariant.copy(alpha = 0.3f)) + ) { + // Header with Add button and selection info + DataSourceHeader( + selectedCount = selectedCount, + totalCount = dataSources.size, + onAddClick = onAddClick + ) + + HorizontalDivider() + + // Search/Filter + SearchField( + query = filterQuery, + onQueryChange = onFilterChange, + modifier = Modifier.padding(8.dp) + ) + + // Data source list + LazyColumn( + modifier = Modifier.weight(1f), + contentPadding = PaddingValues(vertical = 4.dp) + ) { + items(dataSources, key = { it.id }) { dataSource -> + DataSourceItem( + dataSource = dataSource, + isSelected = dataSource.id in selectedDataSourceIds, + connectionStatus = connectionStatuses[dataSource.id] ?: ConnectionStatus.Disconnected, + onToggle = { onToggleDataSource(dataSource.id) }, + onEditClick = { onEditClick(dataSource) }, + onDeleteClick = { onDeleteClick(dataSource.id) }, + onConnectClick = { onConnectClick(dataSource.id) }, + onDisconnectClick = { onDisconnectClick(dataSource.id) } + ) + } + } + + // Connection controls for all selected data sources + if (selectedDataSourceIds.isNotEmpty()) { + HorizontalDivider() + MultiConnectionControls( + selectedCount = selectedCount, + connectedCount = connectedCount, + onConnectAll = onConnectAllClick, + onDisconnectAll = onDisconnectAllClick + ) + } + } +} + +@Composable +private fun DataSourceHeader( + selectedCount: Int, + totalCount: Int, + onAddClick: () -> Unit +) { + Row( + modifier = Modifier + .fillMaxWidth() + .padding(12.dp), + horizontalArrangement = Arrangement.SpaceBetween, + verticalAlignment = Alignment.CenterVertically + ) { + Column { + Text( + text = "Data Sources", + style = MaterialTheme.typography.titleMedium + ) + if (totalCount > 0) { + Text( + text = "$selectedCount of $totalCount selected", + style = MaterialTheme.typography.bodySmall, + color = MaterialTheme.colorScheme.onSurfaceVariant + ) + } + } + IconButton( + onClick = onAddClick, + modifier = Modifier.size(32.dp) + ) { + Icon( + AutoDevComposeIcons.Add, + contentDescription = "Add data source", + modifier = Modifier.size(20.dp) + ) + } + } +} + +@Composable +private fun SearchField( + query: String, + onQueryChange: (String) -> Unit, + modifier: Modifier = Modifier +) { + OutlinedTextField( + value = query, + onValueChange = onQueryChange, + modifier = modifier.fillMaxWidth(), + placeholder = { Text("Search...", style = MaterialTheme.typography.bodySmall) }, + leadingIcon = { + Icon( + AutoDevComposeIcons.Search, + contentDescription = null, + modifier = Modifier.size(18.dp) + ) + }, + trailingIcon = { + if (query.isNotEmpty()) { + IconButton(onClick = { onQueryChange("") }, modifier = Modifier.size(24.dp)) { + Icon(AutoDevComposeIcons.Close, contentDescription = "Clear", modifier = Modifier.size(16.dp)) + } + } + }, + singleLine = true, + shape = RoundedCornerShape(8.dp), + textStyle = MaterialTheme.typography.bodySmall, + colors = OutlinedTextFieldDefaults.colors( + focusedContainerColor = MaterialTheme.colorScheme.surface, + unfocusedContainerColor = MaterialTheme.colorScheme.surface + ) + ) +} + +@Composable +private fun DataSourceItem( + dataSource: DataSourceConfig, + isSelected: Boolean, + connectionStatus: ConnectionStatus, + onToggle: () -> Unit, + onEditClick: () -> Unit, + onDeleteClick: () -> Unit, + onConnectClick: () -> Unit, + onDisconnectClick: () -> Unit +) { + var showMenu by remember { mutableStateOf(false) } + val isConnected = connectionStatus is ConnectionStatus.Connected + val isConnecting = connectionStatus is ConnectionStatus.Connecting + + Surface( + modifier = Modifier + .fillMaxWidth() + .padding(horizontal = 8.dp, vertical = 2.dp), + shape = RoundedCornerShape(8.dp), + color = if (isSelected) { + MaterialTheme.colorScheme.primaryContainer.copy(alpha = 0.3f) + } else { + MaterialTheme.colorScheme.surface + } + ) { + Row( + modifier = Modifier.padding(horizontal = 8.dp, vertical = 8.dp), + verticalAlignment = Alignment.CenterVertically + ) { + // Checkbox for selection + Checkbox( + checked = isSelected, + onCheckedChange = { onToggle() }, + modifier = Modifier.size(24.dp) + ) + + Spacer(modifier = Modifier.width(8.dp)) + + // Status indicator + Box( + modifier = Modifier + .size(8.dp) + .clip(CircleShape) + .background( + when (connectionStatus) { + is ConnectionStatus.Connected -> MaterialTheme.colorScheme.primary + is ConnectionStatus.Connecting -> MaterialTheme.colorScheme.tertiary + is ConnectionStatus.Error -> MaterialTheme.colorScheme.error + else -> MaterialTheme.colorScheme.outline.copy(alpha = 0.5f) + } + ) + ) + + Spacer(modifier = Modifier.width(8.dp)) + + Column(modifier = Modifier.weight(1f)) { + Text( + text = dataSource.name, + style = MaterialTheme.typography.bodyMedium, + maxLines = 1, + overflow = TextOverflow.Ellipsis + ) + Row( + verticalAlignment = Alignment.CenterVertically, + horizontalArrangement = Arrangement.spacedBy(4.dp) + ) { + Text( + text = dataSource.getDisplayUrl(), + style = MaterialTheme.typography.bodySmall, + color = MaterialTheme.colorScheme.onSurfaceVariant, + maxLines = 1, + overflow = TextOverflow.Ellipsis, + modifier = Modifier.weight(1f, fill = false) + ) + // Connection status text + when (connectionStatus) { + is ConnectionStatus.Connected -> { + Text( + text = "Connected", + style = MaterialTheme.typography.labelSmall, + color = MaterialTheme.colorScheme.primary + ) + } + is ConnectionStatus.Connecting -> { + Text( + text = "Connecting...", + style = MaterialTheme.typography.labelSmall, + color = MaterialTheme.colorScheme.tertiary + ) + } + is ConnectionStatus.Error -> { + Text( + text = "Error", + style = MaterialTheme.typography.labelSmall, + color = MaterialTheme.colorScheme.error + ) + } + else -> {} + } + } + } + + // Quick connect/disconnect button + if (isSelected) { + if (isConnecting) { + CircularProgressIndicator( + modifier = Modifier.size(20.dp), + strokeWidth = 2.dp + ) + } else if (isConnected) { + IconButton( + onClick = onDisconnectClick, + modifier = Modifier.size(28.dp) + ) { + Icon( + AutoDevComposeIcons.CloudOff, + contentDescription = "Disconnect", + modifier = Modifier.size(16.dp), + tint = MaterialTheme.colorScheme.error + ) + } + } else { + IconButton( + onClick = onConnectClick, + modifier = Modifier.size(28.dp) + ) { + Icon( + AutoDevComposeIcons.Cloud, + contentDescription = "Connect", + modifier = Modifier.size(16.dp), + tint = MaterialTheme.colorScheme.primary + ) + } + } + } + + Box { + IconButton( + onClick = { showMenu = true }, + modifier = Modifier.size(24.dp) + ) { + Icon( + AutoDevComposeIcons.MoreVert, + contentDescription = "More options", + modifier = Modifier.size(16.dp) + ) + } + + DropdownMenu( + expanded = showMenu, + onDismissRequest = { showMenu = false } + ) { + DropdownMenuItem( + text = { Text("Edit") }, + onClick = { + showMenu = false + onEditClick() + }, + leadingIcon = { + Icon(AutoDevComposeIcons.Edit, contentDescription = null, modifier = Modifier.size(18.dp)) + } + ) + DropdownMenuItem( + text = { Text("Delete", color = MaterialTheme.colorScheme.error) }, + onClick = { + showMenu = false + onDeleteClick() + }, + leadingIcon = { + Icon( + AutoDevComposeIcons.Delete, + contentDescription = null, + modifier = Modifier.size(18.dp), + tint = MaterialTheme.colorScheme.error + ) + } + ) + } + } + } + } +} + +/** + * Connection controls for multi-datasource mode + */ +@Composable +private fun MultiConnectionControls( + selectedCount: Int, + connectedCount: Int, + onConnectAll: () -> Unit, + onDisconnectAll: () -> Unit +) { + Column( + modifier = Modifier + .fillMaxWidth() + .padding(12.dp), + verticalArrangement = Arrangement.spacedBy(8.dp) + ) { + // Status text + Text( + text = "$connectedCount of $selectedCount connected", + style = MaterialTheme.typography.bodySmall, + color = MaterialTheme.colorScheme.onSurfaceVariant + ) + + Row( + modifier = Modifier.fillMaxWidth(), + horizontalArrangement = Arrangement.spacedBy(8.dp) + ) { + // Connect All button + Button( + onClick = onConnectAll, + enabled = connectedCount < selectedCount, + modifier = Modifier.weight(1f) + ) { + Text("Connect All") + } + + // Disconnect All button + OutlinedButton( + onClick = onDisconnectAll, + enabled = connectedCount > 0, + modifier = Modifier.weight(1f), + colors = ButtonDefaults.outlinedButtonColors( + contentColor = MaterialTheme.colorScheme.error + ) + ) { + Text("Disconnect All") + } + } + } +} + diff --git a/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/agent/chatdb/model/DataSourceModels.kt b/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/agent/chatdb/model/DataSourceModels.kt new file mode 100644 index 0000000000..38dce1480e --- /dev/null +++ b/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/agent/chatdb/model/DataSourceModels.kt @@ -0,0 +1,207 @@ +package cc.unitmesh.devins.ui.compose.agent.chatdb.model + +import kotlinx.serialization.Serializable + +/** + * Supported database dialects + */ +enum class DatabaseDialect(val displayName: String, val defaultPort: Int) { + MYSQL("MySQL", 3306), + MARIADB("MariaDB", 3306), + POSTGRESQL("PostgreSQL", 5432), + SQLITE("SQLite", 0), + ORACLE("Oracle", 1521), + SQLSERVER("SQL Server", 1433); + + companion object { + fun fromString(value: String): DatabaseDialect { + return entries.find { + it.name.equals(value, ignoreCase = true) || + it.displayName.equals(value, ignoreCase = true) + } ?: MYSQL + } + } +} + +/** + * Data source configuration for database connections + */ +@Serializable +data class DataSourceConfig( + val id: String, + val name: String, + val dialect: DatabaseDialect = DatabaseDialect.MYSQL, + val host: String = "localhost", + val port: Int = 3306, + val database: String = "", + val username: String = "", + val password: String = "", + val description: String = "", + val isDefault: Boolean = false, + val createdAt: Long = 0L, + val updatedAt: Long = 0L +) { + /** + * Get connection URL for display (without password) + */ + fun getDisplayUrl(): String { + return when (dialect) { + DatabaseDialect.SQLITE -> "sqlite://$database" + else -> "${dialect.name.lowercase()}://$host:$port/$database" + } + } + + /** + * Validate the configuration + */ + fun validate(): ValidationResult { + val errors = mutableListOf() + + if (name.isBlank()) errors.add("Name is required") + if (dialect != DatabaseDialect.SQLITE) { + if (host.isBlank()) errors.add("Host is required") + if (port <= 0 || port > 65535) errors.add("Port must be between 1 and 65535") + } + if (database.isBlank()) errors.add("Database name is required") + + return if (errors.isEmpty()) { + ValidationResult.Valid + } else { + ValidationResult.Invalid(errors) + } + } + + /** + * Convert to cc.unitmesh.agent.database.DatabaseConfig + */ + fun toDatabaseConfig(): cc.unitmesh.agent.database.DatabaseConfig { + return cc.unitmesh.agent.database.DatabaseConfig( + host = host, + port = port, + databaseName = database, + username = username, + password = password, + dialect = dialect.name + ) + } +} + +/** + * Validation result + */ +sealed class ValidationResult { + data object Valid : ValidationResult() + data class Invalid(val errors: List) : ValidationResult() +} + +/** + * Connection status for a data source + */ +sealed class ConnectionStatus { + data object Disconnected : ConnectionStatus() + data object Connecting : ConnectionStatus() + data object Connected : ConnectionStatus() + data class Error(val message: String) : ConnectionStatus() +} + +/** + * Chat message in ChatDB context + */ +@Serializable +data class ChatDBMessage( + val id: String, + val role: MessageRole, + val content: String, + val timestamp: Long, + val sql: String? = null, + val queryResult: QueryResultDisplay? = null +) + +/** + * Message role + */ +@Serializable +enum class MessageRole { + USER, + ASSISTANT, + SYSTEM +} + +/** + * Query result for display + */ +@Serializable +data class QueryResultDisplay( + val columns: List, + val rows: List>, + val rowCount: Int, + val executionTimeMs: Long = 0 +) + +/** + * Connection status for a specific data source (used in multi-datasource mode) + */ +data class DataSourceConnectionState( + val dataSourceId: String, + val status: ConnectionStatus = ConnectionStatus.Disconnected +) + +/** + * UI state for ChatDB page + * + * Supports multi-datasource selection: all configured data sources are selected by default, + * users can toggle individual data sources on/off. + */ +data class ChatDBState( + val dataSources: List = emptyList(), + /** Set of selected data source IDs (multi-selection mode) */ + val selectedDataSourceIds: Set = emptySet(), + /** Connection status per data source */ + val connectionStatuses: Map = emptyMap(), + /** Overall connection status (for backward compatibility) */ + val connectionStatus: ConnectionStatus = ConnectionStatus.Disconnected, + val filterQuery: String = "", + val isLoading: Boolean = false, + val error: String? = null, + val isConfigDialogOpen: Boolean = false, + val editingDataSource: DataSourceConfig? = null, + /** Whether the config pane is shown (inline panel mode) */ + val isConfigPaneOpen: Boolean = false, + /** The data source being configured in the pane */ + val configuringDataSource: DataSourceConfig? = null +) { + /** Get all selected data sources */ + val selectedDataSources: List + get() = dataSources.filter { it.id in selectedDataSourceIds } + + /** Check if a data source is selected */ + fun isSelected(id: String): Boolean = id in selectedDataSourceIds + + /** Get connection status for a specific data source */ + fun getConnectionStatus(id: String): ConnectionStatus = + connectionStatuses[id] ?: ConnectionStatus.Disconnected + + /** Check if any data source is connected */ + val hasAnyConnection: Boolean + get() = connectionStatuses.values.any { it is ConnectionStatus.Connected } + + /** Get count of connected data sources */ + val connectedCount: Int + get() = connectionStatuses.values.count { it is ConnectionStatus.Connected } + + /** Get count of selected data sources */ + val selectedCount: Int + get() = selectedDataSourceIds.size + + val filteredDataSources: List + get() = if (filterQuery.isBlank()) { + dataSources + } else { + dataSources.filter { ds -> + ds.name.contains(filterQuery, ignoreCase = true) || + ds.database.contains(filterQuery, ignoreCase = true) || + ds.host.contains(filterQuery, ignoreCase = true) + } + } +} + diff --git a/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/chat/TopBarMenuDesktop.kt b/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/chat/TopBarMenuDesktop.kt index 1ffb24722d..7ebcb345e1 100644 --- a/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/chat/TopBarMenuDesktop.kt +++ b/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/chat/TopBarMenuDesktop.kt @@ -286,6 +286,7 @@ private fun AgentTypeTab( AgentType.KNOWLEDGE -> AutoDevComposeIcons.Article AgentType.CODING -> AutoDevComposeIcons.Code AgentType.LOCAL_CHAT -> AutoDevComposeIcons.Chat + AgentType.CHAT_DB -> AutoDevComposeIcons.Database }, contentDescription = null, tint = contentColor, diff --git a/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/chat/TopBarMenuMobile.kt b/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/chat/TopBarMenuMobile.kt index 33b1a3d4f5..b9aa2700c6 100644 --- a/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/chat/TopBarMenuMobile.kt +++ b/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/chat/TopBarMenuMobile.kt @@ -224,6 +224,7 @@ fun TopBarMenuMobile( AgentType.KNOWLEDGE -> AutoDevComposeIcons.Article AgentType.CODING -> AutoDevComposeIcons.Code AgentType.LOCAL_CHAT -> AutoDevComposeIcons.Chat + AgentType.CHAT_DB -> AutoDevComposeIcons.Database }, contentDescription = null, modifier = Modifier.size(20.dp) @@ -258,6 +259,7 @@ fun TopBarMenuMobile( AgentType.KNOWLEDGE -> AutoDevComposeIcons.Article AgentType.CODING -> AutoDevComposeIcons.Code AgentType.LOCAL_CHAT -> AutoDevComposeIcons.Chat + AgentType.CHAT_DB -> AutoDevComposeIcons.Database }, contentDescription = null, modifier = Modifier.size(20.dp) diff --git a/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/icons/AutoDevComposeIcons.kt b/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/icons/AutoDevComposeIcons.kt index 12267bc6d3..3328c35d89 100644 --- a/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/icons/AutoDevComposeIcons.kt +++ b/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/icons/AutoDevComposeIcons.kt @@ -113,6 +113,10 @@ object AutoDevComposeIcons { val Functions: ImageVector get() = Icons.Default.Functions val KeyboardArrowRight: ImageVector get() = Icons.Default.KeyboardArrowRight + // Database Icons + val Database: ImageVector get() = Icons.Default.Storage + val Schema: ImageVector get() = Icons.Default.TableChart + /** * Custom icons converted from SVG resources * These icons are converted from ai.svg and mcp.svg to Compose ImageVector format diff --git a/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/state/DesktopUiState.kt b/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/state/DesktopUiState.kt index f512226610..b38e5b7876 100644 --- a/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/state/DesktopUiState.kt +++ b/mpp-ui/src/commonMain/kotlin/cc/unitmesh/devins/ui/compose/state/DesktopUiState.kt @@ -2,13 +2,19 @@ package cc.unitmesh.devins.ui.compose.state import androidx.compose.runtime.* import cc.unitmesh.agent.AgentType +import cc.unitmesh.config.AutoDevConfigWrapper import cc.unitmesh.devins.ui.state.UIStateManager +import kotlinx.coroutines.CoroutineScope +import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.launch /** * Desktop UI State ViewModel * ็ฎก็†ๆกŒ้ข็ซฏ UI ็š„ๆ‰€ๆœ‰็Šถๆ€๏ผŒๅŒๆญฅๅ…จๅฑ€ UIStateManager */ class DesktopUiState { + private val scope = CoroutineScope(Dispatchers.Default) + // Agent Type var currentAgentType by mutableStateOf(AgentType.CODING) @@ -40,6 +46,15 @@ class DesktopUiState { // Actions fun updateAgentType(type: AgentType) { currentAgentType = type + // Save to config file for persistence + scope.launch { + try { + val typeString = type.getDisplayName() + AutoDevConfigWrapper.saveAgentTypePreference(typeString) + } catch (e: Exception) { + println("โš ๏ธ Failed to save agent type preference: ${e.message}") + } + } } fun toggleSessionSidebar() { diff --git a/mpp-ui/src/commonMain/sqldelight/cc/unitmesh/devins/db/DataSource.sq b/mpp-ui/src/commonMain/sqldelight/cc/unitmesh/devins/db/DataSource.sq new file mode 100644 index 0000000000..2f6f86ccc7 --- /dev/null +++ b/mpp-ui/src/commonMain/sqldelight/cc/unitmesh/devins/db/DataSource.sq @@ -0,0 +1,40 @@ +CREATE TABLE IF NOT EXISTS DataSource ( + id TEXT NOT NULL PRIMARY KEY, + name TEXT NOT NULL, + dialect TEXT NOT NULL, + host TEXT NOT NULL, + port INTEGER NOT NULL, + databaseName TEXT NOT NULL, + username TEXT NOT NULL, + password TEXT NOT NULL, + description TEXT NOT NULL, + isDefault INTEGER NOT NULL DEFAULT 0, + createdAt INTEGER NOT NULL, + updatedAt INTEGER NOT NULL +); + +selectById: +SELECT * FROM DataSource WHERE id = ?; + +selectAll: +SELECT * FROM DataSource ORDER BY updatedAt DESC; + +selectDefault: +SELECT * FROM DataSource WHERE isDefault = 1 LIMIT 1; + +insertOrReplace: +INSERT OR REPLACE INTO DataSource(id, name, dialect, host, port, databaseName, username, password, description, isDefault, createdAt, updatedAt) +VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?); + +deleteById: +DELETE FROM DataSource WHERE id = ?; + +deleteAll: +DELETE FROM DataSource; + +clearDefault: +UPDATE DataSource SET isDefault = 0 WHERE isDefault = 1; + +setDefault: +UPDATE DataSource SET isDefault = 1 WHERE id = ?; + diff --git a/mpp-ui/src/iosMain/kotlin/cc/unitmesh/devins/db/DataSourceRepository.ios.kt b/mpp-ui/src/iosMain/kotlin/cc/unitmesh/devins/db/DataSourceRepository.ios.kt new file mode 100644 index 0000000000..afd4df8fed --- /dev/null +++ b/mpp-ui/src/iosMain/kotlin/cc/unitmesh/devins/db/DataSourceRepository.ios.kt @@ -0,0 +1,83 @@ +package cc.unitmesh.devins.db + +import cc.unitmesh.devins.ui.compose.agent.chatdb.model.DataSourceConfig +import cc.unitmesh.devins.ui.compose.agent.chatdb.model.DatabaseDialect + +/** + * DataSource Repository - iOS implementation + */ +actual class DataSourceRepository(private val database: DevInsDatabase) { + private val queries = database.dataSourceQueries + + actual fun getAll(): List { + return queries.selectAll().executeAsList().map { it.toDataSourceConfig() } + } + + actual fun getById(id: String): DataSourceConfig? { + return queries.selectById(id).executeAsOneOrNull()?.toDataSourceConfig() + } + + actual fun getDefault(): DataSourceConfig? { + return queries.selectDefault().executeAsOneOrNull()?.toDataSourceConfig() + } + + actual fun save(config: DataSourceConfig) { + queries.insertOrReplace( + id = config.id, + name = config.name, + dialect = config.dialect.name, + host = config.host, + port = config.port.toLong(), + databaseName = config.database, + username = config.username, + password = config.password, + description = config.description, + isDefault = if (config.isDefault) 1L else 0L, + createdAt = config.createdAt, + updatedAt = config.updatedAt + ) + } + + actual fun delete(id: String) { + queries.deleteById(id) + } + + actual fun deleteAll() { + queries.deleteAll() + } + + actual fun setDefault(id: String) { + queries.clearDefault() + queries.setDefault(id) + } + + private fun DataSource.toDataSourceConfig(): DataSourceConfig { + return DataSourceConfig( + id = this.id, + name = this.name, + dialect = DatabaseDialect.fromString(this.dialect), + host = this.host, + port = this.port.toInt(), + database = this.databaseName, + username = this.username, + password = this.password, + description = this.description, + isDefault = this.isDefault == 1L, + createdAt = this.createdAt, + updatedAt = this.updatedAt + ) + } + + actual companion object { + private var instance: DataSourceRepository? = null + + actual fun getInstance(): DataSourceRepository { + return instance ?: run { + val driverFactory = DatabaseDriverFactory() + val database = createDatabase(driverFactory) + DataSourceRepository(database).also { instance = it } + } + } + } +} + diff --git a/mpp-ui/src/jsMain/kotlin/cc/unitmesh/devins/db/DataSourceRepository.js.kt b/mpp-ui/src/jsMain/kotlin/cc/unitmesh/devins/db/DataSourceRepository.js.kt new file mode 100644 index 0000000000..5f4ca56cbe --- /dev/null +++ b/mpp-ui/src/jsMain/kotlin/cc/unitmesh/devins/db/DataSourceRepository.js.kt @@ -0,0 +1,47 @@ +package cc.unitmesh.devins.db + +import cc.unitmesh.devins.ui.compose.agent.chatdb.model.DataSourceConfig + +/** + * DataSource Repository - JS implementation + * Currently provides stub implementation, can be extended with localStorage or IndexedDB + */ +actual class DataSourceRepository { + actual fun getAll(): List { + console.warn("DataSourceRepository not implemented for JS platform") + return emptyList() + } + + actual fun getById(id: String): DataSourceConfig? { + return null + } + + actual fun getDefault(): DataSourceConfig? { + return null + } + + actual fun save(config: DataSourceConfig) { + // No-op + } + + actual fun delete(id: String) { + // No-op + } + + actual fun deleteAll() { + // No-op + } + + actual fun setDefault(id: String) { + // No-op + } + + actual companion object { + private var instance: DataSourceRepository? = null + + actual fun getInstance(): DataSourceRepository { + return instance ?: DataSourceRepository().also { instance = it } + } + } +} + diff --git a/mpp-ui/src/jvmMain/kotlin/cc/unitmesh/devins/db/DataSourceRepository.jvm.kt b/mpp-ui/src/jvmMain/kotlin/cc/unitmesh/devins/db/DataSourceRepository.jvm.kt new file mode 100644 index 0000000000..b517ca6cd1 --- /dev/null +++ b/mpp-ui/src/jvmMain/kotlin/cc/unitmesh/devins/db/DataSourceRepository.jvm.kt @@ -0,0 +1,85 @@ +package cc.unitmesh.devins.db + +import cc.unitmesh.devins.ui.compose.agent.chatdb.model.DataSourceConfig +import cc.unitmesh.devins.ui.compose.agent.chatdb.model.DatabaseDialect + +/** + * DataSource Repository - JVM implementation + */ +actual class DataSourceRepository(private val database: DevInsDatabase) { + private val queries = database.dataSourceQueries + + actual fun getAll(): List { + return queries.selectAll().executeAsList().map { it.toDataSourceConfig() } + } + + actual fun getById(id: String): DataSourceConfig? { + return queries.selectById(id).executeAsOneOrNull()?.toDataSourceConfig() + } + + actual fun getDefault(): DataSourceConfig? { + return queries.selectDefault().executeAsOneOrNull()?.toDataSourceConfig() + } + + actual fun save(config: DataSourceConfig) { + queries.insertOrReplace( + id = config.id, + name = config.name, + dialect = config.dialect.name, + host = config.host, + port = config.port.toLong(), + databaseName = config.database, + username = config.username, + password = config.password, + description = config.description, + isDefault = if (config.isDefault) 1L else 0L, + createdAt = config.createdAt, + updatedAt = config.updatedAt + ) + } + + actual fun delete(id: String) { + queries.deleteById(id) + } + + actual fun deleteAll() { + queries.deleteAll() + } + + actual fun setDefault(id: String) { + queries.clearDefault() + queries.setDefault(id) + } + + private fun DataSource.toDataSourceConfig(): DataSourceConfig { + return DataSourceConfig( + id = this.id, + name = this.name, + dialect = DatabaseDialect.fromString(this.dialect), + host = this.host, + port = this.port.toInt(), + database = this.databaseName, + username = this.username, + password = this.password, + description = this.description, + isDefault = this.isDefault == 1L, + createdAt = this.createdAt, + updatedAt = this.updatedAt + ) + } + + actual companion object { + private var instance: DataSourceRepository? = null + + actual fun getInstance(): DataSourceRepository { + return instance ?: synchronized(this) { + instance ?: run { + val driverFactory = DatabaseDriverFactory() + val database = createDatabase(driverFactory) + DataSourceRepository(database).also { instance = it } + } + } + } + } +} + diff --git a/mpp-ui/src/jvmMain/kotlin/cc/unitmesh/server/cli/ChatDBCli.kt b/mpp-ui/src/jvmMain/kotlin/cc/unitmesh/server/cli/ChatDBCli.kt new file mode 100644 index 0000000000..d16927e65f --- /dev/null +++ b/mpp-ui/src/jvmMain/kotlin/cc/unitmesh/server/cli/ChatDBCli.kt @@ -0,0 +1,187 @@ +package cc.unitmesh.server.cli + +import cc.unitmesh.agent.chatdb.ChatDBAgent +import cc.unitmesh.agent.chatdb.ChatDBTask +import cc.unitmesh.agent.config.McpToolConfigService +import cc.unitmesh.agent.config.ToolConfigFile +import cc.unitmesh.agent.database.DatabaseConfig +import cc.unitmesh.agent.tool.filesystem.DefaultToolFileSystem +import cc.unitmesh.llm.KoogLLMService +import cc.unitmesh.llm.LLMProviderType +import cc.unitmesh.llm.ModelConfig +import com.charleskorn.kaml.Yaml +import com.charleskorn.kaml.YamlConfiguration +import kotlinx.coroutines.runBlocking +import java.io.File + +/** + * JVM CLI for testing ChatDBAgent - Text2SQL Agent + * + * Usage: + * ```bash + * ./gradlew :mpp-ui:runChatDBCli \ + * -PdbHost=localhost \ + * -PdbPort=3306 \ + * -PdbName=testdb \ + * -PdbUser=root \ + * -PdbPassword=prisma \ + * -PdbQuery="Show me the top 10 customers by order amount" + * ``` + */ +object ChatDBCli { + + @JvmStatic + fun main(args: Array) { + println("=".repeat(80)) + println("AutoDev ChatDB Agent CLI (Text2SQL)") + println("=".repeat(80)) + + // Parse database connection arguments + val dbHost = System.getProperty("dbHost") ?: args.getOrNull(0) ?: "localhost" + val dbPort = System.getProperty("dbPort")?.toIntOrNull() ?: args.getOrNull(1)?.toIntOrNull() ?: 3306 + val dbName = System.getProperty("dbName") ?: args.getOrNull(2) ?: run { + System.err.println("Usage: -PdbName= -PdbQuery= [-PdbHost=localhost] [-PdbPort=3306] [-PdbUser=root] [-PdbPassword=]") + return + } + val dbUser = System.getProperty("dbUser") ?: args.getOrNull(3) ?: "root" + val dbPassword = System.getProperty("dbPassword") ?: args.getOrNull(4) ?: "" + val dbDialect = System.getProperty("dbDialect") ?: args.getOrNull(5) ?: "MariaDB" + + val query = System.getProperty("dbQuery") ?: args.getOrNull(6) ?: run { + System.err.println("Usage: -PdbName= -PdbQuery= [-PdbHost=localhost] [-PdbPort=3306] [-PdbUser=root] [-PdbPassword=]") + return + } + + val generateVisualization = System.getProperty("generateVisualization")?.toBoolean() ?: true + val maxRows = System.getProperty("maxRows")?.toIntOrNull() ?: 100 + + println("๐Ÿ—„๏ธ Database: $dbDialect://$dbHost:$dbPort/$dbName") + println("๐Ÿ‘ค User: $dbUser") + println("๐Ÿ“ Query: $query") + println("๐Ÿ“Š Generate Visualization: $generateVisualization") + println("๐Ÿ“ Max Rows: $maxRows") + println() + + runBlocking { + var agent: ChatDBAgent? = null + try { + val startTime = System.currentTimeMillis() + + // Load LLM configuration from ~/.autodev/config.yaml + val configFile = File(System.getProperty("user.home"), ".autodev/config.yaml") + if (!configFile.exists()) { + System.err.println("โŒ Configuration file not found: ${configFile.absolutePath}") + System.err.println(" Please create ~/.autodev/config.yaml with your LLM configuration") + return@runBlocking + } + + val yamlContent = configFile.readText() + val yaml = Yaml(configuration = YamlConfiguration(strictMode = false)) + val config = yaml.decodeFromString(AutoDevConfig.serializer(), yamlContent) + + val activeName = config.active + val activeConfig = config.configs.find { it.name == activeName } + + if (activeConfig == null) { + System.err.println("โŒ Active configuration '$activeName' not found in config.yaml") + System.err.println(" Available configs: ${config.configs.map { it.name }.joinToString(", ")}") + return@runBlocking + } + + println("๐Ÿ“ Using LLM config: ${activeConfig.name} (${activeConfig.provider}/${activeConfig.model})") + + // Convert provider string to LLMProviderType + val providerType = when (activeConfig.provider.lowercase()) { + "openai" -> LLMProviderType.OPENAI + "anthropic" -> LLMProviderType.ANTHROPIC + "google" -> LLMProviderType.GOOGLE + "deepseek" -> LLMProviderType.DEEPSEEK + "ollama" -> LLMProviderType.OLLAMA + "openrouter" -> LLMProviderType.OPENROUTER + "glm" -> LLMProviderType.GLM + "qwen" -> LLMProviderType.QWEN + "kimi" -> LLMProviderType.KIMI + else -> LLMProviderType.CUSTOM_OPENAI_BASE + } + + val llmService = KoogLLMService( + ModelConfig( + provider = providerType, + modelName = activeConfig.model, + apiKey = activeConfig.apiKey, + temperature = activeConfig.temperature ?: 0.7, + maxTokens = activeConfig.maxTokens ?: 4096, + baseUrl = activeConfig.baseUrl ?: "" + ) + ) + + // Create database configuration + val databaseConfig = DatabaseConfig( + host = dbHost, + port = dbPort, + databaseName = dbName, + username = dbUser, + password = dbPassword, + dialect = dbDialect + ) + + val renderer = CodingCliRenderer() + val mcpConfigService = McpToolConfigService(ToolConfigFile()) + val projectPath = System.getProperty("user.dir") + + println("๐Ÿง  Creating ChatDBAgent...") + agent = ChatDBAgent( + projectPath = projectPath, + llmService = llmService, + databaseConfig = databaseConfig, + maxIterations = 10, + renderer = renderer, + fileSystem = DefaultToolFileSystem(projectPath), + mcpToolConfigService = mcpConfigService, + enableLLMStreaming = true + ) + + println("โœ… Agent created") + println() + println("๐Ÿš€ Executing query...") + println() + + val task = ChatDBTask( + query = query, + maxRows = maxRows, + generateVisualization = generateVisualization + ) + + val result = agent.execute(task) { progress -> + println(" $progress") + } + + val totalTime = System.currentTimeMillis() - startTime + + println() + println("=".repeat(80)) + println("๐Ÿ“Š Result:") + println("=".repeat(80)) + println(result.content) + println() + + if (result.success) { + println("โœ… Query completed successfully") + } else { + println("โŒ Query failed") + } + println("โฑ๏ธ Total time: ${totalTime}ms") + println("๐Ÿ”„ Revision attempts: ${result.metadata["revisionAttempts"] ?: "0"}") + println("๐Ÿ“ˆ Has visualization: ${result.metadata["hasVisualization"] ?: "false"}") + + } catch (e: Exception) { + System.err.println("โŒ Error: ${e.message}") + e.printStackTrace() + } finally { + // Close database connection + agent?.close() + } + } + } +} + diff --git a/mpp-ui/src/wasmJsMain/kotlin/cc/unitmesh/devins/db/DataSourceRepository.kt b/mpp-ui/src/wasmJsMain/kotlin/cc/unitmesh/devins/db/DataSourceRepository.kt new file mode 100644 index 0000000000..2f007ebd38 --- /dev/null +++ b/mpp-ui/src/wasmJsMain/kotlin/cc/unitmesh/devins/db/DataSourceRepository.kt @@ -0,0 +1,188 @@ +package cc.unitmesh.devins.db + +import cc.unitmesh.devins.ui.compose.agent.chatdb.model.DataSourceConfig +import cc.unitmesh.devins.ui.compose.agent.chatdb.model.DatabaseDialect +import cc.unitmesh.devins.ui.platform.BrowserStorage +import cc.unitmesh.devins.ui.platform.console +import kotlinx.serialization.Serializable +import kotlinx.serialization.encodeToString +import kotlinx.serialization.json.Json + +/** + * DataSource Repository for WASM platform + * Uses browser localStorage to store data source configurations + */ +actual class DataSourceRepository { + private val json = Json { + prettyPrint = true + ignoreUnknownKeys = true + } + + @Serializable + private data class StoredDataSource( + val id: String, + val name: String, + val dialect: String, + val host: String, + val port: Int, + val database: String, + val username: String, + val password: String, + val description: String, + val isDefault: Boolean, + val createdAt: Long, + val updatedAt: Long + ) + + @Serializable + private data class DataSourceStorage( + val dataSources: List + ) + + actual fun getAll(): List { + return try { + val storage = loadStorage() + storage.dataSources.map { it.toDataSourceConfig() } + } catch (e: Exception) { + console.error("WASM: Error loading data sources: ${e.message}") + emptyList() + } + } + + actual fun getById(id: String): DataSourceConfig? { + return try { + val storage = loadStorage() + storage.dataSources.firstOrNull { it.id == id }?.toDataSourceConfig() + } catch (e: Exception) { + console.error("WASM: Error getting data source by id: ${e.message}") + null + } + } + + actual fun getDefault(): DataSourceConfig? { + return try { + val storage = loadStorage() + storage.dataSources.firstOrNull { it.isDefault }?.toDataSourceConfig() + } catch (e: Exception) { + console.error("WASM: Error getting default data source: ${e.message}") + null + } + } + + actual fun save(config: DataSourceConfig) { + try { + val storage = loadStorage() + val stored = config.toStoredDataSource() + val existing = storage.dataSources.indexOfFirst { it.id == config.id } + val updatedList = if (existing >= 0) { + storage.dataSources.toMutableList().apply { set(existing, stored) } + } else { + storage.dataSources + stored + } + saveStorage(DataSourceStorage(updatedList)) + console.log("WASM: Data source saved: ${config.id}") + } catch (e: Exception) { + console.error("WASM: Error saving data source: ${e.message}") + } + } + + actual fun delete(id: String) { + try { + val storage = loadStorage() + val updatedList = storage.dataSources.filter { it.id != id } + saveStorage(DataSourceStorage(updatedList)) + console.log("WASM: Data source deleted: $id") + } catch (e: Exception) { + console.error("WASM: Error deleting data source: ${e.message}") + } + } + + actual fun deleteAll() { + try { + saveStorage(DataSourceStorage(emptyList())) + console.log("WASM: All data sources deleted") + } catch (e: Exception) { + console.error("WASM: Error deleting all data sources: ${e.message}") + } + } + + actual fun setDefault(id: String) { + try { + val storage = loadStorage() + val now = kotlinx.datetime.Clock.System.now().toEpochMilliseconds() + val updatedList = storage.dataSources.map { + it.copy( + isDefault = it.id == id, + updatedAt = if (it.id == id) now else it.updatedAt + ) + } + saveStorage(DataSourceStorage(updatedList)) + console.log("WASM: Default data source set to: $id") + } catch (e: Exception) { + console.error("WASM: Error setting default data source: ${e.message}") + } + } + + private fun loadStorage(): DataSourceStorage { + val content = BrowserStorage.getItem(STORAGE_KEY) + return if (content != null) { + try { + json.decodeFromString(content) + } catch (e: Exception) { + console.warn("WASM: Failed to parse data source storage: ${e.message}") + DataSourceStorage(emptyList()) + } + } else { + DataSourceStorage(emptyList()) + } + } + + private fun saveStorage(storage: DataSourceStorage) { + val content = json.encodeToString(storage) + BrowserStorage.setItem(STORAGE_KEY, content) + } + + private fun StoredDataSource.toDataSourceConfig(): DataSourceConfig { + return DataSourceConfig( + id = this.id, + name = this.name, + dialect = DatabaseDialect.fromString(this.dialect), + host = this.host, + port = this.port, + database = this.database, + username = this.username, + password = this.password, + description = this.description, + isDefault = this.isDefault, + createdAt = this.createdAt, + updatedAt = this.updatedAt + ) + } + + private fun DataSourceConfig.toStoredDataSource(): StoredDataSource { + return StoredDataSource( + id = this.id, + name = this.name, + dialect = this.dialect.name, + host = this.host, + port = this.port, + database = this.database, + username = this.username, + password = this.password, + description = this.description, + isDefault = this.isDefault, + createdAt = this.createdAt, + updatedAt = this.updatedAt + ) + } + + actual companion object { + private const val STORAGE_KEY = "autodev-datasources" + private var instance: DataSourceRepository? = null + + actual fun getInstance(): DataSourceRepository { + return instance ?: DataSourceRepository().also { instance = it } + } + } +} +