Skip to content

Commit 779e03a

Browse files
authored
Merge pull request #509 from phodal/feature/chatdb-agent
feat(chatdb): add ChatDB agent for text-to-SQL interactions
2 parents f7ea233 + 55a5713 commit 779e03a

File tree

80 files changed

+11060
-487
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

80 files changed

+11060
-487
lines changed

mpp-core/build.gradle.kts

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -212,6 +212,9 @@ kotlin {
212212

213213
// JSQLParser for SQL validation and parsing
214214
implementation("com.github.jsqlparser:jsqlparser:4.9")
215+
216+
// MyNLP for Chinese NLP tokenization
217+
implementation("com.mayabot.mynlp:mynlp-all:4.0.0")
215218
}
216219
}
217220

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
package cc.unitmesh.agent.chatdb
2+
3+
/**
 * Android `actual` for the shared NlpTokenizer declaration.
 *
 * MyNLP is JVM-only and may have compatibility issues on Android, so this
 * implementation forwards every request to [FallbackNlpTokenizer] instead.
 *
 * TODO: Consider using Android's BreakIterator or a lightweight NLP library for better tokenization.
 */
actual object NlpTokenizer {
    /**
     * Extracts keywords from a natural-language query by delegating to
     * [FallbackNlpTokenizer.extractKeywords]; no Android-specific processing
     * is applied before or after the call.
     *
     * @param query the natural-language text to tokenize
     * @param stopWords words that should be filtered out of the result
     * @return the keywords returned by the fallback tokenizer
     */
    actual fun extractKeywords(query: String, stopWords: Set<String>): List<String> =
        FallbackNlpTokenizer.extractKeywords(query, stopWords)
}
23+
Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
package cc.unitmesh.agent.subagent
2+
3+
import net.sf.jsqlparser.parser.CCJSqlParserUtil
4+
import net.sf.jsqlparser.statement.Statement
5+
import net.sf.jsqlparser.util.TablesNamesFinder
6+
7+
/**
 * Android implementation of SqlValidator using JSqlParser.
 *
 * The SQL text is parsed with [CCJSqlParserUtil.parse]; anything the parser rejects
 * (syntax errors, malformed or unsupported statements) is reported as an invalid
 * result with a shortened error message. Optionally, the tables referenced by the
 * statement are checked against a caller-supplied whitelist.
 */
actual class SqlValidator actual constructor() : SqlValidatorInterface {

    // Compiled once per validator instead of on every failed parse.
    private val encounteredPattern = Regex("Encountered \"(.+?)\" at line (\\d+), column (\\d+)")
    private val expectingPattern = Regex("Was expecting.*?:\\s*(.+)")

    /**
     * Checks that [sql] parses.
     *
     * @param sql the SQL text to validate
     * @return a valid result carrying advisory warnings, or an invalid result with one error message
     */
    actual override fun validate(sql: String): SqlValidationResult =
        parseThen(sql) { statement ->
            SqlValidationResult(
                isValid = true,
                errors = emptyList(),
                warnings = collectWarnings(statement)
            )
        }

    /**
     * Checks that [sql] parses and only references tables listed in [allowedTables].
     *
     * Names are compared case-insensitively and with identifier quoting removed, so
     * `users`, `"users"` and `` `Users` `` all match a whitelist entry `users`.
     * Schema qualifiers are kept: `other.users` does not match `users`.
     *
     * @param sql the SQL text to validate
     * @param allowedTables table names the statement is permitted to reference
     * @return an invalid result naming the offending tables, or a valid result; warnings are included either way
     */
    actual override fun validateWithTableWhitelist(sql: String, allowedTables: Set<String>): SqlValidationResult =
        parseThen(sql) { statement ->
            val permitted = allowedTables.map { normalizeTableName(it) }.toSet()
            val invalidTables = TablesNamesFinder().getTableList(statement)
                .filter { tableName -> normalizeTableName(tableName) !in permitted }

            if (invalidTables.isNotEmpty()) {
                SqlValidationResult(
                    isValid = false,
                    errors = listOf(
                        "Invalid table(s) used: ${invalidTables.joinToString(", ")}. " +
                            "Available tables: ${allowedTables.joinToString(", ")}"
                    ),
                    warnings = collectWarnings(statement)
                )
            } else {
                SqlValidationResult(
                    isValid = true,
                    errors = emptyList(),
                    warnings = collectWarnings(statement)
                )
            }
        }

    /**
     * Lists the table names referenced by [sql] exactly as JSqlParser reports them
     * (no quote stripping or case folding is applied here).
     *
     * @param sql the SQL text to inspect
     * @return the referenced table names, or an empty list when the SQL cannot be parsed
     */
    actual override fun extractTableNames(sql: String): List<String> {
        return try {
            TablesNamesFinder().getTableList(CCJSqlParserUtil.parse(sql))
        } catch (e: Exception) {
            emptyList()
        }
    }

    /**
     * Parses [sql] and hands the statement to [onParsed]. Any exception thrown by the
     * parser or by [onParsed] itself is converted into an invalid result, matching the
     * scope of the try/catch that each public method previously carried on its own.
     */
    private fun parseThen(sql: String, onParsed: (Statement) -> SqlValidationResult): SqlValidationResult {
        return try {
            onParsed(CCJSqlParserUtil.parse(sql))
        } catch (e: Exception) {
            SqlValidationResult(
                isValid = false,
                errors = listOf(extractErrorMessage(e)),
                warnings = emptyList()
            )
        }
    }

    /**
     * Canonical form used for whitelist comparison: every dot-separated segment has
     * surrounding double quotes, backticks or square brackets removed and is lower-cased.
     */
    private fun normalizeTableName(name: String): String =
        name.split('.').joinToString(".") { segment ->
            segment.trim().trim('"', '`', '[', ']').lowercase()
        }

    /**
     * Condenses a parser exception into a short, user-facing message. Falls back to the
     * raw message when the expected pattern is not found, and to its first 200 characters
     * when neither marker text is present.
     */
    private fun extractErrorMessage(e: Exception): String {
        val message = e.message ?: "Unknown SQL parsing error"
        return when {
            message.contains("Encountered") ->
                encounteredPattern.find(message)?.destructured?.let { (token, line, column) ->
                    "Syntax error at line $line, column $column: unexpected token '$token'"
                } ?: message

            message.contains("Was expecting") ->
                expectingPattern.find(message)?.let { match ->
                    "Expected: ${match.groupValues[1].take(100)}"
                } ?: message

            else -> message.take(200)
        }
    }

    /**
     * Produces advisory (non-fatal) warnings for a successfully parsed statement.
     *
     * The missing-WHERE warning is decided from the parsed statement type and its own
     * WHERE clause rather than by searching the SQL text, so column names such as
     * `updated_at` or `is_deleted` no longer trigger it, and a WHERE that only appears
     * inside a subquery no longer hides an unbounded UPDATE/DELETE.
     */
    private fun collectWarnings(statement: Statement): List<String> {
        val usesSelectStar = statement.toString().contains("SELECT *")

        val modifiesAllRows = when (statement) {
            is net.sf.jsqlparser.statement.update.Update -> statement.where == null
            is net.sf.jsqlparser.statement.delete.Delete -> statement.where == null
            else -> false
        }

        return listOfNotNull(
            "Consider specifying explicit columns instead of SELECT *".takeIf { usesSelectStar },
            "UPDATE/DELETE without WHERE clause will affect all rows".takeIf { modifiesAllRows }
        )
    }
}
125+

mpp-core/src/commonMain/kotlin/cc/unitmesh/agent/AgentType.kt

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@ package cc.unitmesh.agent
77
* - LOCAL: Simple local chat mode without heavy tooling
88
* - CODING: Local coding agent with full tool access (file system, shell, etc.)
99
* - CODE_REVIEW: Dedicated code review agent with git integration
10+
* - KNOWLEDGE: Document reader mode for AI-native document reading
11+
* - CHAT_DB: Database chat mode for text-to-SQL interactions
1012
* - REMOTE: Remote agent connected to mpp-server
1113
*/
1214
enum class AgentType {
@@ -30,6 +32,11 @@ enum class AgentType {
3032
*/
3133
KNOWLEDGE,
3234

35+
/**
36+
* Database chat mode - text-to-SQL agent for database queries
37+
*/
38+
CHAT_DB,
39+
3340
/**
3441
* Remote agent mode - connects to remote mpp-server for distributed execution
3542
*/
@@ -40,6 +47,7 @@ enum class AgentType {
4047
CODING -> "Agentic"
4148
CODE_REVIEW -> "Review"
4249
KNOWLEDGE -> "Knowledge"
50+
CHAT_DB -> "ChatDB"
4351
REMOTE -> "Remote"
4452
}
4553

@@ -51,6 +59,7 @@ enum class AgentType {
5159
"coding" -> CODING
5260
"codereview" -> CODE_REVIEW
5361
"documentreader", "documents" -> KNOWLEDGE
62+
"chatdb", "database" -> CHAT_DB
5463
else -> LOCAL_CHAT
5564
}
5665
}
Lines changed: 172 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,172 @@
1+
package cc.unitmesh.agent.chatdb
2+
3+
import cc.unitmesh.agent.config.McpToolConfigService
4+
import cc.unitmesh.agent.core.MainAgent
5+
import cc.unitmesh.agent.database.DatabaseConfig
6+
import cc.unitmesh.agent.database.DatabaseConnection
7+
import cc.unitmesh.agent.database.createDatabaseConnection
8+
import cc.unitmesh.agent.logging.getLogger
9+
import cc.unitmesh.agent.model.AgentDefinition
10+
import cc.unitmesh.agent.model.PromptConfig
11+
import cc.unitmesh.agent.model.RunConfig
12+
import cc.unitmesh.agent.orchestrator.ToolOrchestrator
13+
import cc.unitmesh.agent.policy.DefaultPolicyEngine
14+
import cc.unitmesh.agent.render.CodingAgentRenderer
15+
import cc.unitmesh.agent.render.DefaultCodingAgentRenderer
16+
import cc.unitmesh.agent.tool.shell.DefaultShellExecutor
17+
import cc.unitmesh.agent.tool.shell.ShellExecutor
18+
import cc.unitmesh.agent.tool.ToolResult
19+
import cc.unitmesh.agent.tool.filesystem.DefaultToolFileSystem
20+
import cc.unitmesh.agent.tool.filesystem.ToolFileSystem
21+
import cc.unitmesh.agent.tool.registry.ToolRegistry
22+
import cc.unitmesh.llm.KoogLLMService
23+
import cc.unitmesh.llm.ModelConfig
24+
25+
/**
 * ChatDB Agent - Text2SQL Agent for natural language database queries
 *
 * This agent converts natural language queries to SQL, executes them,
 * and optionally generates visualizations of the results. The actual work is
 * delegated to [ChatDBAgentExecutor]; this class wires up its dependencies,
 * validates input and maps the executor result to [ToolResult.AgentResult].
 *
 * Features:
 * - Schema Linking: Keyword-based search to find relevant tables/columns
 * - SQL Generation: LLM generates SQL from natural language
 * - Revise Agent: Self-correction loop using JSqlParser for SQL validation
 * - Query Execution: Execute validated SQL and return results
 * - Visualization: Optional PlotDSL generation for data visualization
 *
 * The database connection is opened on first use and released by [close]. After
 * [close] the agent stays usable: the next [execute] opens a new connection.
 *
 * Based on GitHub Issue #508: https://github.com/phodal/auto-dev/issues/508
 */
class ChatDBAgent(
    private val projectPath: String,
    private val llmService: KoogLLMService,
    private val databaseConfig: DatabaseConfig,
    override val maxIterations: Int = 10,
    private val renderer: CodingAgentRenderer = DefaultCodingAgentRenderer(),
    private val fileSystem: ToolFileSystem? = null,
    private val shellExecutor: ShellExecutor? = null,
    private val mcpToolConfigService: McpToolConfigService,
    private val enableLLMStreaming: Boolean = true
) : MainAgent<ChatDBTask, ToolResult.AgentResult>(
    AgentDefinition(
        name = "ChatDBAgent",
        displayName = "ChatDB Agent",
        description = "Text2SQL Agent that converts natural language to SQL queries with schema linking and self-correction",
        promptConfig = PromptConfig(
            systemPrompt = SYSTEM_PROMPT
        ),
        modelConfig = ModelConfig.default(),
        runConfig = RunConfig(maxTurns = 10, maxTimeMinutes = 5)
    )
) {
    private val logger = getLogger("ChatDBAgent")

    // Falls back to a file system rooted at the project when none is injected.
    private val actualFileSystem = fileSystem ?: DefaultToolFileSystem(projectPath = projectPath)

    private val toolRegistry = ToolRegistry(
        fileSystem = actualFileSystem,
        shellExecutor = shellExecutor ?: DefaultShellExecutor(),
        configService = mcpToolConfigService,
        llmService = llmService
    )

    private val policyEngine = DefaultPolicyEngine()

    private val toolOrchestrator = ToolOrchestrator(
        registry = toolRegistry,
        policyEngine = policyEngine,
        renderer = renderer,
        mcpConfigService = mcpToolConfigService
    )

    private var databaseConnection: DatabaseConnection? = null

    // The executor captures the connection it was built with, so it must be rebuilt
    // once that connection has been closed. Keeping the Lazy in a replaceable holder
    // preserves lazy first-use initialization while letting close() discard the
    // stale executor.
    private var executorHolder: Lazy<ChatDBAgentExecutor> = lazy { createExecutor() }

    private val executor: ChatDBAgentExecutor
        get() = executorHolder.value

    /**
     * Builds an executor bound to the current connection, opening one from
     * [databaseConfig] when none is held.
     */
    private fun createExecutor(): ChatDBAgentExecutor {
        val connection = databaseConnection ?: createDatabaseConnection(databaseConfig)
        databaseConnection = connection

        return ChatDBAgentExecutor(
            projectPath = projectPath,
            llmService = llmService,
            toolOrchestrator = toolOrchestrator,
            renderer = renderer,
            databaseConnection = connection,
            maxIterations = maxIterations,
            enableLLMStreaming = enableLLMStreaming
        )
    }

    /**
     * Converts the raw parameter map into a [ChatDBTask].
     *
     * Recognized keys: `query` (required, non-blank String), `additionalContext`
     * (String, default empty), `maxRows` (Number, default 100) and
     * `generateVisualization` (Boolean, default true). Values of the wrong type are
     * treated as absent.
     *
     * @throws IllegalArgumentException when `query` is missing, not a String, or blank
     */
    override fun validateInput(input: Map<String, Any>): ChatDBTask {
        val query = (input["query"] as? String)?.takeIf { it.isNotBlank() }
            ?: throw IllegalArgumentException("Missing required parameter: query")

        return ChatDBTask(
            query = query,
            additionalContext = input["additionalContext"] as? String ?: "",
            maxRows = (input["maxRows"] as? Number)?.toInt() ?: 100,
            generateVisualization = input["generateVisualization"] as? Boolean ?: true
        )
    }

    /**
     * Runs the task through the executor and wraps its outcome.
     *
     * @param input the validated task
     * @param onProgress callback forwarded to the executor for progress messages
     * @return the executor's success flag and message, plus string-valued metadata:
     *   `generatedSql`, `rowCount`, `revisionAttempts` and `hasVisualization`
     */
    override suspend fun execute(
        input: ChatDBTask,
        onProgress: (String) -> Unit
    ): ToolResult.AgentResult {
        logger.info { "Starting ChatDB Agent for query: ${input.query}" }

        val systemPrompt = buildSystemPrompt()
        val result = executor.execute(input, systemPrompt, onProgress)

        return ToolResult.AgentResult(
            success = result.success,
            content = result.message,
            metadata = mapOf(
                "generatedSql" to (result.generatedSql ?: ""),
                "rowCount" to (result.queryResult?.rowCount?.toString() ?: "0"),
                "revisionAttempts" to result.revisionAttempts.toString(),
                "hasVisualization" to (result.plotDslCode != null).toString()
            )
        )
    }

    /** Returns the prompt sent as the system message; currently the fixed [SYSTEM_PROMPT]. */
    private fun buildSystemPrompt(): String {
        return SYSTEM_PROMPT
    }

    /** The textual output of this agent is the result content as-is. */
    override fun formatOutput(output: ToolResult.AgentResult): String {
        return output.content
    }

    override fun getParameterClass(): String = "ChatDBTask"

    /**
     * Close database connection when done.
     *
     * Also discards the cached executor, which still references the connection being
     * closed, so that a later [execute] builds a new executor on a new connection
     * instead of running against the closed one.
     */
    suspend fun close() {
        databaseConnection?.close()
        databaseConnection = null
        executorHolder = lazy { createExecutor() }
    }

    companion object {
        const val SYSTEM_PROMPT = """You are an expert SQL developer. Generate SQL queries from natural language.

CRITICAL RULES - YOU MUST FOLLOW THESE:
1. ONLY use table names provided in the schema - NEVER invent or guess table names
2. ONLY use column names provided in the schema - NEVER invent or guess column names
3. If a table or column doesn't exist in the schema, DO NOT use it
4. Only generate SELECT queries (read-only operations)
5. Always add LIMIT clause to prevent large result sets

OUTPUT FORMAT:
- Return ONLY the SQL query wrapped in ```sql code block
- Do NOT include explanations, alternatives, or reasoning
- Do NOT add comments outside the code block
- Keep response concise - just the SQL

Example response:
```sql
SELECT id, name FROM users WHERE status = 'active' LIMIT 100;
```"""
    }
}
172+

0 commit comments

Comments
 (0)