diff --git a/.gitignore b/.gitignore
index dca4913..3da298b 100644
--- a/.gitignore
+++ b/.gitignore
@@ -11,3 +11,4 @@ sql_parse/clauses/__pycache__
sql_parse/clauses
sql_parse/test.py
sql_command/__pycache__/DB.cpython-310.pyc
+pyrightconfig.json
diff --git a/.idea/.gitignore b/.idea/.gitignore
new file mode 100644
index 0000000..26d3352
--- /dev/null
+++ b/.idea/.gitignore
@@ -0,0 +1,3 @@
+# Default ignored files
+/shelf/
+/workspace.xml
diff --git a/.idea/inspectionProfiles/Project_Default.xml b/.idea/inspectionProfiles/Project_Default.xml
new file mode 100644
index 0000000..717e9bd
--- /dev/null
+++ b/.idea/inspectionProfiles/Project_Default.xml
@@ -0,0 +1,19 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/inspectionProfiles/profiles_settings.xml b/.idea/inspectionProfiles/profiles_settings.xml
new file mode 100644
index 0000000..105ce2d
--- /dev/null
+++ b/.idea/inspectionProfiles/profiles_settings.xml
@@ -0,0 +1,6 @@
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/modules.xml b/.idea/modules.xml
new file mode 100644
index 0000000..58e8fc9
--- /dev/null
+++ b/.idea/modules.xml
@@ -0,0 +1,8 @@
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/sql_minibuilder.iml b/.idea/sql_minibuilder.iml
new file mode 100644
index 0000000..8e5446a
--- /dev/null
+++ b/.idea/sql_minibuilder.iml
@@ -0,0 +1,14 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/vcs.xml b/.idea/vcs.xml
new file mode 100644
index 0000000..35eb1dd
--- /dev/null
+++ b/.idea/vcs.xml
@@ -0,0 +1,6 @@
+
+
+
+
+
+
\ No newline at end of file
diff --git a/README.md b/README.md
index c6e1efe..47d347c 100644
--- a/README.md
+++ b/README.md
@@ -1,4 +1,4 @@
-~~目前只需要改commands/DB.py就行了~~
+~~目前只需要改command/DB.py就行了~~
commands中的执行功能都差不多写完了,tokenize已完成,AST写了一点
@@ -43,27 +43,38 @@ WHERE id = 1 AND this < 2.3;
完成情况:
-- [x] 对SELECT的实现
+- [x] 对SELECT的解析
-- [X] 对FROM的实现
+- [X] 对FROM的解析
-- [x] 对WHERE的基础实现
+- [x] 对WHERE的基础解析
-- [ ] 对WHERE中AND、OR的正确顺序判断
+- [x] 对WHERE中AND、OR的正确顺序判断
-- [ ] 对CREATE的实现
+- [x] 对SET的解析
-- [ ] 对主键`PRIMARY`的实现
+- [x] 对UPDATE的解析
-- [ ] 对非空`NOT NULL`的实现
+- [x] 对DELETE的解析
-- [x] 对UPDATE的实现
+##### 星期五添加
-- [x] 对DELETE的实现
+- [x] 对SELECT时`WILDCARD`的解析(就是选中所有列,详见输入输出5)
+- [x] 对CREATE时主键`PRIMARY`的解析
+- [x] 对CREATE时非空`NOT NULL`的解析
-**目前已部分完成,只到能够解析查询命令(SELECT)、更新命令(UPDATE)与删除命令(DELETE)的地方**
+- [ ] 对CREATE各种约束的解析(摆,不想做,因为外键有点麻烦)
+
+- [x] 对SET赋值式右边的Binary Expression的解析(抱歉的是SET时的expression结构也换了,换成assignment了,详见输出2)
+
+- [x] 对CREATE的解析
+
+- [x] 对INSERT INTO的解析
+
+
+**目前已部分完成,只到能够解析查询命令(SELECT)、更新命令(UPDATE)、删除命令(DELETE)与基础创建命令(CREATE)(不包含NOT NULL)的地方**
输入1:
```
@@ -133,36 +144,27 @@ WHERE level:AST_KEYWORDS.CLAUSE
输入2:
```
UPDATE table1
-SET alexa = 50000, country='USA', salary = 14.5
+SET alexa = 50000, country='USA', salary = salary * 14.5
WHERE id = 1 AND this < 2.3 OR name>1;
```
**实际输出2**:
```
-UPDATE level:AST_KEYWORDS.CLAUSE
-
+UPDATE level:AST_KEYWORDS.CLAUSE
-- table1
SET level:AST_KEYWORDS.CLAUSE
-
--SET level:AST_KEYWORDS.EXPRESSION
-
----- {'left': 'alexa', 'op': '=', 'right': 50000}
+---- {'assignment': 'alexa', 'expression': {'left': 50000}}
--SET level:AST_KEYWORDS.EXPRESSION
-
----- {'left': 'country', 'op': '=', 'right': 'USA'}
+---- {'assignment': 'country', 'expression': {'left': 'USA'}}
--SET level:AST_KEYWORDS.EXPRESSION
-
----- {'left': 'salary', 'op': '=', 'right': 14.5}
+---- {'assignment': 'salary', 'expression': {'left': 'salary', 'op': '*', 'right': 14.5}}
WHERE level:AST_KEYWORDS.CLAUSE
-
--WHERE level:AST_KEYWORDS.EXPRESSION
-
---- {'left': 'id', 'op': '=', 'right': 1}
--AND level:AST_KEYWORDS.EXPRESSION
-
---- {'left': 'this', 'op': '<', 'right': 2.3}
--OR level:AST_KEYWORDS.EXPRESSION
-
---- {'left': 'name', 'op': '>', 'right': 1}
```
@@ -189,6 +191,106 @@ WHERE level:AST_KEYWORDS.
---- {'left': 'name', 'op': '>', 'right': 1}
```
+输入4:
+```
+CREATE TABLE Persons
+(
+ PersonID SERIAL int,
+ LastName PRIMARY varchar(255),
+ FirstName char(255) NOT NULL,
+ Address float,
+ City varchar(255)
+);
+```
+
+
+**实际输出4**:
+```
+CREATE level:AST_KEYWORDS.CLAUSE
+-- Persons
+COLUMNS level:AST_KEYWORDS.CLAUSE
+--COLUMN_DEFINITION level:AST_KEYWORDS.COLUMN_DEFINITION
+---- {'PRIMARY': False, 'NOT NULL': False, 'name': 'PERSONID', 'type': 'int'}
+--COLUMN_DEFINITION level:AST_KEYWORDS.COLUMN_DEFINITION
+---- {'PRIMARY': True, 'NOT NULL': False, 'name': 'LASTNAME', 'type': 'varchar', 'length': 255}
+--COLUMN_DEFINITION level:AST_KEYWORDS.COLUMN_DEFINITION
+---- {'PRIMARY': False, 'NOT NULL': True, 'name': 'FIRSTNAME', 'type': 'char', 'length': 255}
+--COLUMN_DEFINITION level:AST_KEYWORDS.COLUMN_DEFINITION
+---- {'PRIMARY': False, 'NOT NULL': False, 'name': 'ADDRESS', 'type': 'float'}
+--COLUMN_DEFINITION level:AST_KEYWORDS.COLUMN_DEFINITION
+---- {'PRIMARY': False, 'NOT NULL': False, 'name': 'CITY', 'type': 'varchar', 'length': 255}
+```
+
+输入5:(与输入1相比,SELECT所有列):
+```
+SELECT *
+FROM table1, table2
+WHERE id = 1 AND "this" < 2.3;
+```
+
+**实际输出5**:
+```
+SELECT level:AST_KEYWORDS.CLAUSE
+-- Token.Wildcard
+FROM level:AST_KEYWORDS.CLAUSE
+-- table1
+-- table2
+WHERE level:AST_KEYWORDS.CLAUSE
+--WHERE level:AST_KEYWORDS.EXPRESSION
+---- {'left': 'id', 'op': '=', 'right': 1}
+--AND level:AST_KEYWORDS.EXPRESSION
+---- {'left': 'this', 'op': '<', 'right': 2.3}
+```
+
+输入6:
+```
+INSERT INTO table1 (id, name, this)
+VALUES (1, 'alex', 2.3);
+```
+
+**实际输出6**:
+```
+INSERT level:AST_KEYWORDS.CLAUSE
+-- table1
+COLUMNS level:AST_KEYWORDS.CLAUSE
+-- id
+-- name
+-- this
+VALUES level:AST_KEYWORDS.CLAUSE
+-- 1
+-- alex
+-- 2.3
+```
+
+输入7:(与输入6的区别在于,INSERT INTO未指定表名,同时是多列)
+```
+INSERT INTO table1
+VALUES(
+ (1, 'alex', 2.3),
+ (2, 'bob', 2.4),
+ (5, 'jjj', 2.5)
+);
+```
+
+**实际输出7**
+```
+INSERT level:AST_KEYWORDS.CLAUSE
+-- table1
+COLUMNS level:AST_KEYWORDS.CLAUSE
+VALUES level:AST_KEYWORDS.CLAUSE
+-- 1
+-- alex
+-- 2.3
+VALUES level:AST_KEYWORDS.CLAUSE
+-- 2
+-- bob
+-- 2.4
+VALUES level:AST_KEYWORDS.CLAUSE
+-- 5
+-- jjj
+-- 2.5
+```
+
**用语解释**:
@@ -201,19 +303,39 @@ WHERE level:AST_KEYWORDS.
-#### 1.3 在AST基础上,实现对`sql_commands`的调用,完成语句
+#### 1.3 在AST基础上,实现对`sql_command`的调用,完成语句
+
+文件:`sql_control/main.py`(示例用,之后可能会改)
+
+- [x] 完成示例文件,演示如何结合AST与实现的`sql_command`,通过FROM语句获得databse中的表
+
+- [x] 完成AST与执行代码关于WHERE的结合
+
+- [x] 完成AST与执行代码关于SELECT的结合
+
+- [x] 完成AST与执行代码关于DELETE的结合
+
+- [x] 完成AST与执行代码关于UPDATE的结合
+
+##### 星期五添加的
+
+- [ ] 一个小问题:所以什么时候在硬盘写入文件,读入文件?
+
+- [x] 完成AST与执行代码关于SELECT的WILDCARD设置(也就是选中所有列,详见输入5,输出5)
+
+- [x] 完成AST与执行代码关于SET的右侧赋值式结合
+
+- [x] 完成AST与执行代码关于CREATE的基础结合,限制每一列的类型
-文件:`sql_commands/main.py`(示例用,之后可能会改)
+- [x] 完成AST与执行代码关于CREATE的主键设置,NOT NULL设置
-- [x] 完成示例文件,演示如何结合AST与实现的`sql_commands`,通过FROM语句获得databse中的表
+- [ ] 完成AST与执行代码关于CREATE中类型有SERIAL时,在INSERT时主键的自动更新并插入
-- [ ] 完成AST与执行代码关于WHERE的结合
-- [ ] 完成AST与执行代码关于SELECT的结合
-- [ ] 完成AST与执行代码关于DELETE的结合
+- [ ] 完成AST与执行代码关于INSERT的基础结合
-- [ ] 完成AST与执行代码关于UPDATE的结合
+- [ ] 完成AST与执行代码关于INSERT在表名未指定列时的正确处理(详见输入输出7)
- [ ] 更多……
diff --git a/sql_command/DB.py b/sql_command/DB.py
index 0cc1b42..4b000b0 100644
--- a/sql_command/DB.py
+++ b/sql_command/DB.py
@@ -2,16 +2,23 @@
class DB:
def __init__(self):
- self.dbtypes=dict[dict['tablename':str,'tabledata':pd.DataFrame,'datatypes':dict]]
+ self.dbtypes=dict[dict['tablename':str,'tabledata':pd.DataFrame,'datatypes':dict,'not_null_flag':dict,'primary_key':str,'attri_len':dict]]
self.database = {}
#访问某个表用 database['表名']
#访问某个表中的数据用 database['表名']['tabledata'] 这是一个pd.DataFrame型的数据
#访问某个表中某个属性的数据类型用 database['表名']['datatypes']['属性名']
+ #查询某个表中某个属性是否是not null :database['表名']['not_null_flag']['属性名']
+ #访问某个表中某个(这里仅考虑char varchar类型)属性设定的长度 :database['表名']['attri_len']['属性名']
+ #primary: database['表名']['primary_key']['属性名']
def create(self,
table: str,
attributes: list[str],
- types: list[str]) -> bool:
+ types: list[str],
+ not_null: list[bool],
+ primary_key: list[str],
+ char_attri: list[str],
+ char_attri_len: list[int]) -> bool:
#检查表是否已经存在
if table in self.database:
print(f"Table '{table}' already exists.")
@@ -21,7 +28,10 @@ def create(self,
newtable = {
'tablename': table,
'tabledata': pd.DataFrame(columns=attributes),
- 'datatypes': {attr: data_type for attr, data_type in zip(attributes, types)}
+ 'datatypes': {attr: data_type for attr, data_type in zip(attributes, types)},
+ 'not_null_flag': {attr: not_null for attr, not_null in zip(attributes, not_null)},
+ 'primary_key': primary_key,
+ 'attri_len': {ch: chlen for ch, chlen in zip(char_attri,char_attri_len)}
}
#数据库添加新表
self.database[table] = newtable
diff --git a/sql_control/main.py b/sql_control/main.py
index 6fea041..c4c699d 100644
--- a/sql_control/main.py
+++ b/sql_control/main.py
@@ -1,5 +1,13 @@
from sql_parse.AST_builder import AST
from sql_command.DB import DB
+import os
+import pandas as pd
+
+from sql_parse.tokens import Token
+
+# data_folder = '../data'
+# file_names = os.listdir(data_folder)
+
class blabla:
def __init__(self):
@@ -7,8 +15,14 @@ def __init__(self):
示例用
"""
self.db = DB()
- self.db.create("test",["索引","姓名"], [int,str])
-
+ # 外部数据读取 test
+ # for file_name in file_names:
+ # if file_name.endswith('.csv'):
+ # var_name = os.path.splitext(file_name)[0]
+ # file_path = data_folder + '/' + file_name
+ # data_table = pd.read_csv(file_path)
+ # self.dict[var_name] = data_table
+
def ast_clear(self):
"""
清除上一个语句留下的AST
@@ -22,13 +36,13 @@ def extract(self,text:str):
获取不同clause(子句)的内容
"""
# 根据text,生成AST(这里AST只实现了SELECT查询)
- self.ast = AST(sql).content
+ self.ast = AST(text).content
# 将AST中的不同clause保存过来
self.clause = {}
for clause in self.ast.content:
self.clause[clause.value] = clause
- def execute(self,text:str):
+ def execute(self, text: str):
"""
执行ast中的语句
"""
@@ -36,27 +50,304 @@ def execute(self,text:str):
self.ast_clear()
# 获取当前语句的AST
self.extract(text)
- # 示例用,演示怎么从AST解析后的树中获取table
- tables = self.get_tables()
- print(tables)
-
- def get_tables(self):
- #从FROM clause中获得content
- table_names = self.clause["FROM"].content
- ret = []
- for table_name_i in table_names:
- # 调用sql_commands中实现的方法,访问得到每个table的DataFrame
- table = self.db.database[table_name_i]['tabledata']
- ret.append(table)
- return ret
-
+
+ function = self.funct()
+
+ # 查询
+ if function == 'SELECT':
+ tables = self.get_tables(function)
+ cols = self.get_col()
+ for table in tables:
+ cur_table = self.db.database[table]
+ rows = self.get_row(table) if "WHERE" in self.clause.keys() else None
+ result = self.db.select(cur_table['tabledata'], cols, rows) if Token.Wildcard not in cols else self.db.select(cur_table['tabledata'], cur_table['tabledata'].columns, rows)
+ print(result)
+
+ # 更新
+ elif function == 'UPDATE':
+ tables = self.get_tables(function)
+ for table in tables:
+ cur_table = self.db.database[table]
+ rows = self.get_row(table) if "WHERE" in self.clause.keys() else None
+ for up in self.clause["SET"].content:
+ v = up.content[0]['expression']
+ if len(v) == 1:
+ if isinstance(v['left'], int) or isinstance(v['left'], float) or (isinstance(v['left'], str) and f"'{v['left']}'" in text):
+ value = v['left']
+ else:
+ value = self.db.select(cur_table['tabledata'], v['left'], rows).values.tolist()
+ else:
+ value = self.get_val(table, up.content[0]['expression'], rows)
+ self.db.update(cur_table['tabledata'], update_rows=rows, attributes=[up.content[0]['assignment']], values=value)
+
+ # 添加
+ elif function == 'INSERT':
+ tables = self.get_tables(function)
+ for table in tables:
+ cols = self.clause["COLUMNS"].content
+ values_clauses = self.ast.content[2:]
+ cur_table = self.db.database[table]
+ values_for_insert = []
+ for values_clause_i in values_clauses:
+ values = values_clause_i.content
+ self.check_types(cur_table['datatypes'], cols, values)
+ self.check_null(cur_table['not_null_flag'], cols, values)
+ self.check_primary(cur_table['tabledata'],cur_table['primary_key'],cols,values)
+ values_for_insert.append(values)
+ _, cur_table['tabledata'] = self.db.insert(cur_table['tabledata'], cols, values_for_insert)
+
+ # 删除
+ elif function == 'DELETE':
+ tables = self.get_tables(function)
+ for table in tables:
+ cur_table = self.db.database[table]
+ rows = self.get_row(table) if "WHERE" in self.clause.keys() else None
+ self.db.delete(cur_table['tabledata'], del_rows=rows)
+
+ # 创建新表
+ elif function == 'CREATE':
+ self.excute_create_statement()
+
+ # 删除表
+ elif function == 'DROP':
+ print("DROP")
+
+ else:
+ print("ERROR FUNCTION")
+
+ if function != "SELECT":
+ self.save_tables()
+
+ def check_types(self, requires, cols, values):
+ for i in range(len(cols)):
+ if requires[cols[i]].upper() in ["VARCHAR", "CHAR"]:
+ if not isinstance(values[i], str):
+ raise Exception(f"Type of {cols[i]} is {requires[cols[i]]}, but {type(values[i])} is given.")
+ if requires[cols[i]].upper() in ["INT"]:
+ if not isinstance(values[i], int):
+ raise Exception(f"Type of {cols[i]} is {requires[cols[i]]}, but {type(values[i])} is given.")
+ if requires[cols[i]].upper() in ["FLOAT"]:
+ if not isinstance(values[i],float):
+ pass
+ raise Exception(f"Type of {cols[i]} is {requires[cols[i]]}, but {type(values[i])} is given.")
+
+ def check_null(self, requires, cols, values):
+ for k in requires:
+ v = requires[k]
+ if v is True:
+ if k not in cols:
+ raise Exception(f"{k} is not null, but null is given.")
+ else:
+ if values[cols.index(k)] is None:
+ raise Exception(f"{k} is not null, but null is given.")
+
+ def check_primary(self, table: pd.DataFrame, primary_keys, cols, values):
+ for k in primary_keys:
+ if k in cols:
+ if values[cols.index(k)] in table[k].values:
+ raise Exception(f"{k} is primary key, but {values[cols.index(k)]} is already in table.")
+ else:
+ raise Exception(f"{k} is primary key, but not given in insertion values.")
+
+
+ def save_tables(self):
+ # 待实现
+ pass
+
+ # 判断功能
+ def funct(self):
+ return self.ast.content[0].value
+
+ # SELECT功能的筛选列
+ def get_col(self):
+ num_col = len(self.clause["SELECT"].content)
+ cols = []
+ for i in range(0, num_col):
+ col = self.clause["SELECT"].content[i]
+ cols.append(col)
+ return cols
+
+ # WHERE语句的筛选行和对AND、OR运算结果
+ def get_row(self, table):
+ cur_table = self.db.database[table]
+ num_condition = len(self.clause["WHERE"].content)
+ rows_ = []
+ for i in range(0, num_condition):
+ condition = self.clause["WHERE"].content[i].content[0]
+ row = self.db.where(cur_table['tabledata'], condition['left'], condition['right'], condition['op'])
+ rows_.append(set(row))
+ operators = [expression_i.value for expression_i in self.clause["WHERE"].content[1:]]
+ rows = rows_[0]
+ i = 1
+ while i < len(rows_):
+ if operators[i - 1] == "AND":
+ rows = rows.intersection(rows_[i])
+ elif operators[i - 1] == "OR":
+ temp_result = rows_[i]
+ j = i + 1
+ while j < len(rows_) and operators[j - 1] == "AND":
+ temp_result = temp_result.intersection(rows_[j])
+ j += 1
+ rows = rows.union(temp_result)
+ i = j - 1
+ i += 1
+ return list(rows)
+
+ # SET 右边为表达式的情况
+ def get_val(self, table, expression, rows):
+ cur_table = self.db.database[table]
+ result = self.db.select(cur_table['tabledata'], expression['left'], rows)
+ if expression['op'] == '+':
+ result += expression['right']
+ elif expression['op'] == '-':
+ result -= expression['right']
+ elif expression['op'] == '*':
+ result *= expression['right']
+ elif expression['op'] == '/':
+ result /= expression['right']
+ result = result.values.tolist()
+ return result
+
+ def get_tables(self, function):
+ if function == "SELECT" or function == "DELETE":
+ table_names = self.clause["FROM"].content
+ elif function == "UPDATE":
+ table_names = self.clause["UPDATE"].content
+ elif function == "INSERT":
+ table_names = self.clause["INSERT"].content
+ return table_names
+
+ def excute_create_statement(self):
+ table = self.clause["CREATE"].content[0]
+ attribute_ls = []
+ types_ls = []
+ not_null_ls = []
+ primary_key_ls = []
+ char_attri_ls = []
+ char_attri_len_ls = []
+
+ for col_def in self.clause["COLUMNS"].content:
+ t = col_def.content[0]
+ attribute_ls.append(t["name"])
+ types_ls.append(t["type"])
+ if t["type"] in ["char","varchar"]:
+ char_attri_ls.append(t["name"])
+ char_attri_len_ls.append(t["length"])
+ # attention: if it is a primary key, then it should always be not null
+ not_null_ls.append(t["NOT NULL"] or t["PRIMARY"])
+ if t["PRIMARY"]:
+ primary_key_ls.append(t["name"])
+ if self.db.create(table=table,
+ attributes=attribute_ls,
+ types=types_ls,
+ not_null=not_null_ls,
+ primary_key=primary_key_ls,
+ char_attri=char_attri_ls,
+ char_attri_len=char_attri_len_ls):
+ print("创建成功!")
+ # print(self.db.database[table])
+ else:
+ print("创建失败!")
if __name__ == "__main__":
- sql = """
- SELECT id, name, this
- FROM test
- WHERE id = 1 AND this < 2.3;
- """
a = blabla()
- a.execute(sql)
\ No newline at end of file
+
+ # 创建
+ sql1 = """
+ CREATE TABLE Persons
+ (
+ PersonID PRIMARY int,
+ LastName varchar(255) NOT NULL,
+ FirstName char(255),
+ Address float,
+ City varchar(255)
+ );
+ """
+ a.execute(sql1)
+ print(a.db.database["Persons"]['tabledata'])
+ print("="*20)
+ print("="*20)
+
+ # 插入
+ sql2 = """
+ INSERT INTO Persons (PersonID,LastName, Address, City)
+ VALUES (
+ (3,'my', 2.3, "this"),
+ (4,'she', 5.6, "7"),
+ (5,'thistsdas',1.1,"9")
+ );
+ """
+ a.execute(sql2)
+ print(a.db.database["Persons"]['tabledata'])
+ print("="*20)
+ print("="*20)
+
+ # 更新
+ sql3 = """
+ UPDATE Persons
+ SET Address = Address * 1.1
+ WHERE PersonID >= 2
+ """
+ a.execute(sql3)
+ print(a.db.database["Persons"]['tabledata'])
+ print("="*20)
+ print("="*20)
+
+ # 查询
+ sql4 = """
+ SELECT *
+ FROM Persons
+ WHERE PersonID >= 4
+ """
+ a.execute(sql4)
+ print("=" * 20)
+ print("=" * 20)
+
+ # print(a.db.database["Persons"]['datatypes'])
+ # print(a.db.database["Persons"]['not_null_flag'])
+ # print(a.db.database["Persons"]['primary_key'])
+ # a = blabla()
+ # sql1 = """
+ # SELECT id, name
+ # FROM test2
+ # WHERE id >= 2 AND this < 3.1 OR this2 = 13 AND id = 1;
+ # """
+ # a.execute(sql1)
+
+ # sql2 = """
+ # UPDATE test2
+ # SET this2 = 100, name = 'Li Si', this = this * 1.1
+ # WHERE id >= 2 AND this2 < 13;
+ # """
+ # a.execute(sql2)
+
+ # sql3 = """
+ # SELECT id, name, this, this2
+ # FROM test2
+ # """
+ # a.execute(sql3)
+
+ # sql4 = """
+ # DELETE FROM test2
+ # WHERE id = 1
+ # """
+ # a.execute(sql4)
+ # a.execute(sql3)
+
+ # sql5 = """
+ # SELECT *
+ # FROM test
+ # WHERE gender = Female;
+ # """
+ # a.execute(sql5)
+
+ # sql6 = """
+ # UPDATE test
+ # SET salary = salary * 1.2
+ # WHERE id = 1 OR id = 2 OR id = 5;
+ # """
+ # a.execute(sql6)
+
+ # sql7 = """SELECT * FROM test"""
+ # a.execute(sql7)
\ No newline at end of file
diff --git a/sql_minibuilder-main.zip b/sql_minibuilder-main.zip
new file mode 100644
index 0000000..c1749de
Binary files /dev/null and b/sql_minibuilder-main.zip differ
diff --git a/sql_parse/AST_builder.py b/sql_parse/AST_builder.py
index 77a4eb5..34c87d6 100644
--- a/sql_parse/AST_builder.py
+++ b/sql_parse/AST_builder.py
@@ -1,12 +1,12 @@
from sql_parse.tokenizer import tokenizer
from sql_parse import tokens
from sql_parse.ast_def import AST_KEYWORDS
-from sql_parse.ast_def import _statement,_clause,_expression
+from sql_parse.ast_def import _statement,_clause,_expression,_coldef
class AST:
def __init__(self, text = None):
- self.get_tokens(text)
+ self.get_tokens(text=text)
new_statement = _statement()
new_statement.attribute = AST_KEYWORDS.STATEMENT
self.content , _ = self.build_AST(start_idx=0, cur_node=new_statement)
@@ -51,7 +51,138 @@ def create_node(self, level):
return _clause()
if(level == AST_KEYWORDS.EXPRESSION):
return _expression()
+ if(level == AST_KEYWORDS.COLUMN_DEFINITION):
+ return _coldef()
+ def build_AST_CREATE(self, start_idx = 0, statement_node = None):
+ stream = self.tokens
+ total_idx = len(stream)
+
+ # 第一个必定会存在的clause: value="CREATE",content包含table的名称
+ create_clause_node = self.create_node(AST_KEYWORDS.CLAUSE)
+ create_clause_node.value = "CREATE"
+ statement_node.content.append(create_clause_node)
+
+ # 第二个必定会存在的clause:value="COLUMNS",content包含每一列的定义
+ columns_clause_node = self.create_node(AST_KEYWORDS.CLAUSE)
+ columns_clause_node.value = "COLUMNS"
+ statement_node.content.append(columns_clause_node)
+
+ # 在找到CREATE关键词后,我们希望找的是:CREATE TABLE 表名
+ cur_node = create_clause_node
+ idx = start_idx + 1
+
+ pair_level = 0
+
+ while idx < total_idx:
+ cls, value = stream[idx]
+
+ # 如果当前token并不特殊,非关键字,那么就是当前node需要接受的内容
+ if (cls not in tokens.Keyword) and (cls not in tokens.Punctuation):
+ cur_node.deal(cls, value)
+
+ # 如果当前token为标点
+ elif cls in tokens.Punctuation:
+ # ()的处理
+ if value in ["(", ")"]:
+ if value == "(":
+ pair_level = pair_level + 1
+ if pair_level == 1 and value == "(":
+ # 遇到这里,说明读到了CREATE TABLE 表名 ( 的情况,接下来该是新的clause了
+ if cur_node.value == "CREATE":
+ cur_coldef_node = self.create_node(AST_KEYWORDS.COLUMN_DEFINITION)
+ cur_coldef_node.value = "COLUMN_DEFINITION"
+ columns_clause_node.content.append(cur_coldef_node)
+ cur_node = cur_coldef_node
+ idx = idx + 1
+ continue
+ if pair_level == 1 and value == ")":
+ # 遇到这里,说明是在希望读取下一个列的时候,发现已经没有更多列了
+ # 那可真是个天大的喜事,说明读取完成了
+ return statement_node, idx
+ if value == ")":
+ pair_level = pair_level - 1
+ # 逗号的处理
+ if value in [","]:
+ # 一个列的结束,另一个列的开始
+ cur_coldef_node = self.create_node(AST_KEYWORDS.COLUMN_DEFINITION)
+ cur_coldef_node.value = "COLUMN_DEFINITION"
+ columns_clause_node.content.append(cur_coldef_node)
+ cur_node = cur_coldef_node
+ idx = idx + 1
+ continue
+
+ # 如果当前token特殊,为关键字……(但其实唯一要看的关键字就是TABLE和PRIMARY(目前的话))
+ else:
+ val = value.upper()
+ if val == "TABLE": pass
+ if val == "PRIMARY":
+ cur_node.content[0]["PRIMARY"] = True
+ if val == "NOT NULL":
+ cur_node.content[0]["NOT NULL"] = True
+ idx = idx + 1
+ return statement_node, idx
+
+ def build_AST_INSERT(self, start_idx = 0, statement_node = None):
+ stream = self.tokens
+ total_idx = len(stream)
+
+ # 第一个必定会存在的clause: value="INSERT",content包含table的名称
+ insert_clause_node = self.create_node(AST_KEYWORDS.CLAUSE)
+ insert_clause_node.value = "INSERT"
+ statement_node.content.append(insert_clause_node)
+
+ # 第二个必定会存在的clause:value="COLUMNS",content包含每一列的名称
+ columns_clause_node = self.create_node(AST_KEYWORDS.CLAUSE)
+ columns_clause_node.value = "COLUMNS"
+ statement_node.content.append(columns_clause_node)
+
+
+ # 在找到INSERT关键词后,我们希望找的是:INSERT INTO 表名
+ requireValue = "INTO"
+ cur_node = insert_clause_node
+ idx = start_idx + 1
+
+ pair_level = 0
+
+ while idx < total_idx:
+ cls, value = stream[idx]
+
+ # 如果当前token并不特殊,非关键字,那么就是当前node需要接受的内容
+ if (cls not in tokens.Keyword) and (cls not in tokens.Punctuation):
+ cur_node.deal(cls, value)
+
+ # 如果当前token为标点
+ elif cls in tokens.Punctuation:
+ # ()的处理
+ if value == "(":
+ pair_level = pair_level + 1
+ if value in ["(", ")"]:
+ if cur_node.value == "INSERT" and value == "(":
+ # 遇到这里,说明读到了INSERT INTO 表名 ( 的情况,接下来该是第二个clause了
+ cur_node = columns_clause_node
+ idx = idx + 1
+ continue
+ elif (cur_node.value == "COLUMNS" and pair_level == 1 and value == "(" and self.next_token_value(idx+1) != "(") or (pair_level == 2 and value == "("):
+ # 第三个必定会存在的clause:value="VALUES",content包含每一列的值
+ values_clause_node = self.create_node(AST_KEYWORDS.CLAUSE)
+ values_clause_node.value = "VALUES"
+ statement_node.content.append(values_clause_node)
+ cur_node = values_clause_node
+ if value == ")":
+ pair_level = pair_level - 1
+
+ # 如果当前token特殊,为关键字
+ else:
+ val = value.upper()
+ if val == "INTO": pass
+ # 读到这里,说明columns的名字已经读完了,在读VALUES了,进入下一个clause
+ if val == "VALUES":
+ pass
+ idx = idx + 1
+ return statement_node, idx
+
+
def build_AST(self, start_idx = 0, cur_node = None):
stream = self.tokens
idx = start_idx
@@ -85,6 +216,16 @@ def build_AST(self, start_idx = 0, cur_node = None):
# 如果当前token特殊,为关键字
else:
+ # 特殊语法检查
+ # CREATE的语法区别和其他的差别太大了,放弃兼容性,直接单独处理
+ if value.upper() == "CREATE":
+ node_create, _ = self.build_AST_CREATE(idx, cur_node)
+ return node_create, None
+ # INSERT同理
+ if value.upper() == "INSERT":
+ node_insert, _ = self.build_AST_INSERT(idx, cur_node)
+ return node_insert, None
+
# 判断当前token的层级
par_cls_level = cur_node.attribute
cur_cls_level = self.get_level(value=value.upper(),par_value=cur_node.value)
@@ -115,7 +256,19 @@ def build_AST(self, start_idx = 0, cur_node = None):
raise Exception(f"Current node level is {cur_node.attribute}, but the word level is {cur_cls_level}")
idx = idx + 1
return cur_node, idx
-
+
+ def next_token_value(self, start_idx):
+ stream = self.tokens
+ idx = start_idx
+ total_idx = len(stream)
+ while idx < total_idx:
+ cls, value = stream[idx]
+ if value == " ":
+ idx = idx+1
+ continue
+ return value
+ return None
+
def pprint(self):
"""
用还算好看的方式打印自己的content
@@ -127,7 +280,7 @@ def pprint_impl(self,depth = 0, cur_list = None,_pre = ''):
pprint的单层实现
"""
for idx, cur_list_i in enumerate(cur_list):
- if isinstance(cur_list_i, _clause) or isinstance(cur_list_i, _expression):
+ if isinstance(cur_list_i, _clause) or isinstance(cur_list_i, _expression) or isinstance(cur_list_i,_coldef):
level = "level:" + str(cur_list_i.attribute)
item = _pre + str(cur_list_i.value)
print(f"{item:<60} {level:<50}")
@@ -139,29 +292,55 @@ def pprint_impl(self,depth = 0, cur_list = None,_pre = ''):
if __name__ == "__main__":
+ # 基础SELECT
sql1 = """
SELECT id, name, this
FROM table1, table2
WHERE id = 1 AND "this" < 2.3;
"""
+ # 基础UPDATE
sql2 = """
UPDATE table1
- SET alexa = 50000, country='USA', salary = 14.5
+ SET alexa = 50000, country='USA', salary = salary * 14.5
WHERE id = 1 AND this < 2.3 OR name>1;
"""
+ # 基础DELETE
sql3 = """
DELETE table1
WHERE id = 1 AND this < 2.3 OR name>1;
"""
+ # 基础CREATE
+ sql4 = """
+ CREATE TABLE Persons (
+ PersonID SERIAL int,
+ LastName varchar(255),
+ FirstName char(255) NOT NULL,
+ Address float,
+ City varchar(255),
+ );
+ """
+ # SELECT的特殊情况:Wildcard
+ sql5 = """
+ SELECT *
+ FROM table1, table2
+ WHERE id = 1 AND "this" < 2.3;
+ """
- a = AST(sql3)
- # TODO: 由于自己的实现是从左往右读TOKEN,而没有提前读等操作,因而不可能先读
- # AND再读WHERE。自己的一个暂时的解决方法是将AND和WHERE一样看作一个
- # expression,这样能保证一个CLAUSE中只有一个表达式(例如a=3),读到AND时执行
- # WHERE查询,再将AND查询的结果,与WHERE查询的结果取交集。如果后面还有AND,就
- # 再与左边的取交集……这样对于多个AND没有问题,但是如果有OR,那么优先级就被打
- # 乱了,必须要先完成OR两边的再对两边的结果取并集
- # 现在来看这部分有点困难,先不要动为好
+ # 基础INSERT
+ sql6 = """
+ INSERT INTO table1 (id, name, this)
+ VALUES (1, 'alex', 2.3);
+ """
+ # INSERT的特殊情况,不指定列名
+ sql7 = """
+ INSERT INTO table1(id,name)
+ VALUES(
+ (1, 'alex'),
+ (2, 'bob'),
+ (5, 'jjj')
+ );
+ """
+ a = AST(sql7)
a.pprint()
show = a.content
- print("\n\n",show)
\ No newline at end of file
+ print("\n\n",show)
diff --git a/sql_parse/ast_def.py b/sql_parse/ast_def.py
index 647e2f5..9e47d69 100644
--- a/sql_parse/ast_def.py
+++ b/sql_parse/ast_def.py
@@ -5,6 +5,7 @@ class AST_KEYWORDS(enum.IntEnum):
STATEMENT = 0
CLAUSE = 1
EXPRESSION = 2
+ COLUMN_DEFINITION = 10
class _statement:
def __init__(self):
@@ -20,21 +21,26 @@ class _clause:
def __init__(self):
self.attribute = AST_KEYWORDS.CLAUSE
self.content = []
+ self.value = None
def deal(self, cls, value):
"""
- 对于一个非关键词的token,根据自己clause的类型,将其加入到自己的内容中
+ 对于一个非关键词的token,根据自己的类型,将其加入到自己的内容中
"""
# 当前columns的类型决定其会记录columns
- if self.value in ["SELECT"]:
+ if self.value in ["SELECT", "CREATE", "COLUMNS"]:
if cls in tokens.Name:
self.content.append(value)
+ elif cls in tokens.Wildcard:
+ self.content.append(tokens.Wildcard)
# 当前clause的类型决定其会记录tables
- if self.value in ["FROM", "UPDATE"]:
+ if self.value in ["FROM", "UPDATE", "INSERT"]:
if cls in tokens.Name:
self.content.append(value)
- # 当前clause的类型决定其会记录一个表达式
- # TODO: 完成一些例如CREATE,PRIMARY之类的处理
+ # 当前clause的类型决定其会记录values,但并不是表达式
+ if self.value in ["VALUES"]:
+ if cls in tokens.Name or cls in tokens.Literal:
+ self.content.append(Numerize(cls,value))
# WHERE很特殊,之后WHERE会被重复利用一次,以完成多个表达式的连接
# 因此理论上WHERE clause不可能接受到非关键词token(忽略),WHERE expression才会接受到
@@ -46,12 +52,13 @@ class _expression:
def __init__(self):
self.attribute = AST_KEYWORDS.EXPRESSION
self.content = []
+ self.value = None
def deal(self, cls, value):
"""
对于一个非关键词的token,根据自己expression的类型,将其加入到自己的内容中
"""
- if self.value in ["WHERE", "AND", "OR", "SET"]:
+ if self.value in ["WHERE", "AND", "OR"]:
if cls in tokens.Name or cls in tokens.Literal:
if self.content == []: # 左边
sub_expr_left = {"left": Numerize(cls,value), "op": None, "right": None}
@@ -60,6 +67,41 @@ def deal(self, cls, value):
self.content[0]["right"] = Numerize(cls,value)
if cls in tokens.Operator: # 中间的Operator
self.content[0]["op"] = value
+ if self.value in ["SET"]:
+ if cls in tokens.Name or cls in tokens.Literal:
+ if self.content == []: # 左边
+ sub_expr_left = {"assignment": Numerize(cls,value), "expression": None}
+ self.content.append(sub_expr_left)
+ else: # 右边
+ if self.content[0]["expression"] == None:
+ self.content[0]["expression"] = dict()
+ self.content[0]["expression"]["left"] = Numerize(cls,value)
+ else:
+ self.content[0]["expression"]["right"] = Numerize(cls,value)
+ if cls in tokens.Operator or cls in tokens.Wildcard: # 中间的Operator,Wildcard是指*
+ if self.content[0]["expression"] == None:
+ pass
+ else:
+ self.content[0]["expression"]["op"] = value
+
+class _coldef:
+ def __init__(self):
+ self.attribute = AST_KEYWORDS.COLUMN_DEFINITION
+ self.content = [{"PRIMARY":False,"NOT NULL":False}]
+
+ def deal(self, cls, value):
+ """
+ 对于一个非关键词的token,根据自己的类型,将其加入到自己的内容中
+ """
+ # 说明这是一个column的类型提示
+ if cls in tokens.Name.Builtin:
+ self.content[0]["type"] = value
+
+ elif cls in tokens.Name:
+ self.content[0]["name"] = value
+
+ elif cls in tokens.Literal:
+ self.content[0]["length"] = Numerize(cls,value)
def Numerize(cls, text: str):
diff --git a/sql_parse/keywords.py b/sql_parse/keywords.py
index 5d3a405..d851f33 100644
--- a/sql_parse/keywords.py
+++ b/sql_parse/keywords.py
@@ -711,4 +711,6 @@
'MIN': tokens.Keyword,
'MAX': tokens.Keyword,
'DISTINCT': tokens.Keyword,
-}
\ No newline at end of file
+}
+
+
diff --git a/sql_parse/tokenizer.py b/sql_parse/tokenizer.py
index d694953..d0bde05 100644
--- a/sql_parse/tokenizer.py
+++ b/sql_parse/tokenizer.py
@@ -26,6 +26,13 @@ def default_initialization(self):
self._keywords.append(keywords.KEYWORDS_COMMON)
self._keywords.append(keywords.KEYWORDS)
+ # reverse to find the built-in keywords
+ self.KEYWORDS_BUILTIN_LIST = []
+ for k, v in keywords.KEYWORDS.items():
+ if v != tokens.Name.Builtin:
+ continue
+ self.KEYWORDS_BUILTIN_LIST.append(k)
+
def is_keyword(self, value):
"""
判断当前的内容是一个NAME还是关键字
@@ -40,6 +47,20 @@ def is_keyword(self, value):
else:
return tokens.Name, value
+ def builtin_check(self, action, value):
+ """
+ 一个特殊的check,保证当前的token是一个内置的类型,而非一个自定义的类型
+ 这里多余的处理是因为最开始的正则不太对
+ """
+ if action != tokens.Token.Name:
+ return action, value
+ val = value.upper()
+ if val in self.KEYWORDS_BUILTIN_LIST:
+ return tokens.Token.Name.Builtin, value
+ else:
+ return action, value
+
+
def consume(self,iterator, n):
"""
将迭代器向后推动n个位置,即跳过n个元素,这里是为了在匹配到字符为token后,跳过已匹配部分
@@ -64,7 +85,7 @@ def tokenize(self,text: str):
if not m: # 从这个字符往后看,并不能匹配出任何的关键字或者token
continue
elif isinstance(action, tokens._TokenType): # 如果这是一个普通Token
- yield action, m.group()
+ yield self.builtin_check(action,m.group())
elif action is keywords.PROCESS_AS_KEYWORD: # 如果这是一个关键字
yield self.is_keyword(m.group())