diff --git a/.gitignore b/.gitignore index dca4913..3da298b 100644 --- a/.gitignore +++ b/.gitignore @@ -11,3 +11,4 @@ sql_parse/clauses/__pycache__ sql_parse/clauses sql_parse/test.py sql_command/__pycache__/DB.cpython-310.pyc +pyrightconfig.json diff --git a/.idea/.gitignore b/.idea/.gitignore new file mode 100644 index 0000000..26d3352 --- /dev/null +++ b/.idea/.gitignore @@ -0,0 +1,3 @@ +# Default ignored files +/shelf/ +/workspace.xml diff --git a/.idea/inspectionProfiles/Project_Default.xml b/.idea/inspectionProfiles/Project_Default.xml new file mode 100644 index 0000000..717e9bd --- /dev/null +++ b/.idea/inspectionProfiles/Project_Default.xml @@ -0,0 +1,19 @@ + + + + \ No newline at end of file diff --git a/.idea/inspectionProfiles/profiles_settings.xml b/.idea/inspectionProfiles/profiles_settings.xml new file mode 100644 index 0000000..105ce2d --- /dev/null +++ b/.idea/inspectionProfiles/profiles_settings.xml @@ -0,0 +1,6 @@ + + + + \ No newline at end of file diff --git a/.idea/modules.xml b/.idea/modules.xml new file mode 100644 index 0000000..58e8fc9 --- /dev/null +++ b/.idea/modules.xml @@ -0,0 +1,8 @@ + + + + + + + + \ No newline at end of file diff --git a/.idea/sql_minibuilder.iml b/.idea/sql_minibuilder.iml new file mode 100644 index 0000000..8e5446a --- /dev/null +++ b/.idea/sql_minibuilder.iml @@ -0,0 +1,14 @@ + + + + + + + + + + + + \ No newline at end of file diff --git a/.idea/vcs.xml b/.idea/vcs.xml new file mode 100644 index 0000000..35eb1dd --- /dev/null +++ b/.idea/vcs.xml @@ -0,0 +1,6 @@ + + + + + + \ No newline at end of file diff --git a/README.md b/README.md index c6e1efe..47d347c 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,4 @@ -~~目前只需要改commands/DB.py就行了~~ +~~目前只需要改command/DB.py就行了~~ commands中的执行功能都差不多写完了,tokenize已完成,AST写了一点 @@ -43,27 +43,38 @@ WHERE id = 1 AND this < 2.3; 完成情况: -- [x] 对SELECT的实现 +- [x] 对SELECT的解析 -- [X] 对FROM的实现 +- [X] 对FROM的解析 -- [x] 对WHERE的基础实现 +- [x] 对WHERE的基础解析 -- [ ] 对WHERE中AND、OR的正确顺序判断 +- [x] 对WHERE中AND、OR的正确顺序判断 -- [ ] 对CREATE的实现 +- [x] 对SET的解析 -- [ ] 对主键`PRIMARY`的实现 +- [x] 对UPDATE的解析 -- [ ] 对非空`NOT NULL`的实现 +- [x] 对DELETE的解析 -- [x] 对UPDATE的实现 +##### 星期五添加 -- [x] 对DELETE的实现 +- [x] 对SELECT时`WILDCARD`的解析(就是选中所有列,详见输入输出5) +- [x] 对CREATE时主键`PRIMARY`的解析 +- [x] 对CREATE时非空`NOT NULL`的解析 -**目前已部分完成,只到能够解析查询命令(SELECT)、更新命令(UPDATE)与删除命令(DELETE)的地方** +- [ ] 对CREATE各种约束的解析(摆,不想做,因为外键有点麻烦) + +- [x] 对SET赋值式右边的Binary Expression的解析(抱歉的是SET时的expression结构也换了,换成assignment了,详见输出2) + +- [x] 对CREATE的解析 + +- [x] 对INSERT INTO的解析 + + +**目前已部分完成,只到能够解析查询命令(SELECT)、更新命令(UPDATE)、删除命令(DELETE)与基础创建命令(CREATE)(不包含NOT NULL)的地方** 输入1: ``` @@ -133,36 +144,27 @@ WHERE level:AST_KEYWORDS.CLAUSE 输入2: ``` UPDATE table1 -SET alexa = 50000, country='USA', salary = 14.5 +SET alexa = 50000, country='USA', salary = salary * 14.5 WHERE id = 1 AND this < 2.3 OR name>1; ``` **实际输出2**: ``` -UPDATE level:AST_KEYWORDS.CLAUSE - +UPDATE level:AST_KEYWORDS.CLAUSE -- table1 SET level:AST_KEYWORDS.CLAUSE - --SET level:AST_KEYWORDS.EXPRESSION - ----- {'left': 'alexa', 'op': '=', 'right': 50000} +---- {'assignment': 'alexa', 'expression': {'left': 50000}} --SET level:AST_KEYWORDS.EXPRESSION - ----- {'left': 'country', 'op': '=', 'right': 'USA'} +---- {'assignment': 'country', 'expression': {'left': 'USA'}} --SET level:AST_KEYWORDS.EXPRESSION - ----- {'left': 'salary', 'op': '=', 'right': 14.5} +---- {'assignment': 'salary', 'expression': {'left': 'salary', 'op': '*', 'right': 14.5}} WHERE level:AST_KEYWORDS.CLAUSE - --WHERE level:AST_KEYWORDS.EXPRESSION - ---- {'left': 'id', 'op': '=', 'right': 1} --AND level:AST_KEYWORDS.EXPRESSION - ---- {'left': 'this', 'op': '<', 'right': 2.3} --OR level:AST_KEYWORDS.EXPRESSION - ---- {'left': 'name', 'op': '>', 'right': 1} ``` @@ -189,6 +191,106 @@ WHERE level:AST_KEYWORDS. ---- {'left': 'name', 'op': '>', 'right': 1} ``` +输入4: +``` +CREATE TABLE Persons +( + PersonID SERIAL int, + LastName PRIMARY varchar(255), + FirstName char(255) NOT NULL, + Address float, + City varchar(255) +); +``` + + +**实际输出4**: +``` +CREATE level:AST_KEYWORDS.CLAUSE +-- Persons +COLUMNS level:AST_KEYWORDS.CLAUSE +--COLUMN_DEFINITION level:AST_KEYWORDS.COLUMN_DEFINITION +---- {'PRIMARY': False, 'NOT NULL': False, 'name': 'PERSONID', 'type': 'int'} +--COLUMN_DEFINITION level:AST_KEYWORDS.COLUMN_DEFINITION +---- {'PRIMARY': True, 'NOT NULL': False, 'name': 'LASTNAME', 'type': 'varchar', 'length': 255} +--COLUMN_DEFINITION level:AST_KEYWORDS.COLUMN_DEFINITION +---- {'PRIMARY': False, 'NOT NULL': True, 'name': 'FIRSTNAME', 'type': 'char', 'length': 255} +--COLUMN_DEFINITION level:AST_KEYWORDS.COLUMN_DEFINITION +---- {'PRIMARY': False, 'NOT NULL': False, 'name': 'ADDRESS', 'type': 'float'} +--COLUMN_DEFINITION level:AST_KEYWORDS.COLUMN_DEFINITION +---- {'PRIMARY': False, 'NOT NULL': False, 'name': 'CITY', 'type': 'varchar', 'length': 255} +``` + +输入5:(与输入1相比,SELECT所有列): +``` +SELECT * +FROM table1, table2 +WHERE id = 1 AND "this" < 2.3; +``` + +**实际输出5**: +``` +SELECT level:AST_KEYWORDS.CLAUSE +-- Token.Wildcard +FROM level:AST_KEYWORDS.CLAUSE +-- table1 +-- table2 +WHERE level:AST_KEYWORDS.CLAUSE +--WHERE level:AST_KEYWORDS.EXPRESSION +---- {'left': 'id', 'op': '=', 'right': 1} +--AND level:AST_KEYWORDS.EXPRESSION +---- {'left': 'this', 'op': '<', 'right': 2.3} +``` + +输入6: +``` +INSERT INTO table1 (id, name, this) +VALUES (1, 'alex', 2.3); +``` + +**实际输出6**: +``` +INSERT level:AST_KEYWORDS.CLAUSE +-- table1 +COLUMNS level:AST_KEYWORDS.CLAUSE +-- id +-- name +-- this +VALUES level:AST_KEYWORDS.CLAUSE +-- 1 +-- alex +-- 2.3 +``` + +输入7:(与输入6的区别在于,INSERT INTO未指定表名,同时是多列) +``` +INSERT INTO table1 +VALUES( + (1, 'alex', 2.3), + (2, 'bob', 2.4), + (5, 'jjj', 2.5) +); +``` + +**实际输出7** +``` +INSERT level:AST_KEYWORDS.CLAUSE +-- table1 +COLUMNS level:AST_KEYWORDS.CLAUSE +VALUES level:AST_KEYWORDS.CLAUSE +-- 1 +-- alex +-- 2.3 +VALUES level:AST_KEYWORDS.CLAUSE +-- 2 +-- bob +-- 2.4 +VALUES level:AST_KEYWORDS.CLAUSE +-- 5 +-- jjj +-- 2.5 +``` + **用语解释**: @@ -201,19 +303,39 @@ WHERE level:AST_KEYWORDS. -#### 1.3 在AST基础上,实现对`sql_commands`的调用,完成语句 +#### 1.3 在AST基础上,实现对`sql_command`的调用,完成语句 + +文件:`sql_control/main.py`(示例用,之后可能会改) + +- [x] 完成示例文件,演示如何结合AST与实现的`sql_command`,通过FROM语句获得databse中的表 + +- [x] 完成AST与执行代码关于WHERE的结合 + +- [x] 完成AST与执行代码关于SELECT的结合 + +- [x] 完成AST与执行代码关于DELETE的结合 + +- [x] 完成AST与执行代码关于UPDATE的结合 + +##### 星期五添加的 + +- [ ] 一个小问题:所以什么时候在硬盘写入文件,读入文件? + +- [x] 完成AST与执行代码关于SELECT的WILDCARD设置(也就是选中所有列,详见输入5,输出5) + +- [x] 完成AST与执行代码关于SET的右侧赋值式结合 + +- [x] 完成AST与执行代码关于CREATE的基础结合,限制每一列的类型 -文件:`sql_commands/main.py`(示例用,之后可能会改) +- [x] 完成AST与执行代码关于CREATE的主键设置,NOT NULL设置 -- [x] 完成示例文件,演示如何结合AST与实现的`sql_commands`,通过FROM语句获得databse中的表 +- [ ] 完成AST与执行代码关于CREATE中类型有SERIAL时,在INSERT时主键的自动更新并插入 -- [ ] 完成AST与执行代码关于WHERE的结合 -- [ ] 完成AST与执行代码关于SELECT的结合 -- [ ] 完成AST与执行代码关于DELETE的结合 +- [ ] 完成AST与执行代码关于INSERT的基础结合 -- [ ] 完成AST与执行代码关于UPDATE的结合 +- [ ] 完成AST与执行代码关于INSERT在表名未指定列时的正确处理(详见输入输出7) - [ ] 更多…… diff --git a/sql_command/DB.py b/sql_command/DB.py index 0cc1b42..4b000b0 100644 --- a/sql_command/DB.py +++ b/sql_command/DB.py @@ -2,16 +2,23 @@ class DB: def __init__(self): - self.dbtypes=dict[dict['tablename':str,'tabledata':pd.DataFrame,'datatypes':dict]] + self.dbtypes=dict[dict['tablename':str,'tabledata':pd.DataFrame,'datatypes':dict,'not_null_flag':dict,'primary_key':str,'attri_len':dict]] self.database = {} #访问某个表用 database['表名'] #访问某个表中的数据用 database['表名']['tabledata'] 这是一个pd.DataFrame型的数据 #访问某个表中某个属性的数据类型用 database['表名']['datatypes']['属性名'] + #查询某个表中某个属性是否是not null :database['表名']['not_null_flag']['属性名'] + #访问某个表中某个(这里仅考虑char varchar类型)属性设定的长度 :database['表名']['attri_len']['属性名'] + #primary: database['表名']['primary_key']['属性名'] def create(self, table: str, attributes: list[str], - types: list[str]) -> bool: + types: list[str], + not_null: list[bool], + primary_key: list[str], + char_attri: list[str], + char_attri_len: list[int]) -> bool: #检查表是否已经存在 if table in self.database: print(f"Table '{table}' already exists.") @@ -21,7 +28,10 @@ def create(self, newtable = { 'tablename': table, 'tabledata': pd.DataFrame(columns=attributes), - 'datatypes': {attr: data_type for attr, data_type in zip(attributes, types)} + 'datatypes': {attr: data_type for attr, data_type in zip(attributes, types)}, + 'not_null_flag': {attr: not_null for attr, not_null in zip(attributes, not_null)}, + 'primary_key': primary_key, + 'attri_len': {ch: chlen for ch, chlen in zip(char_attri,char_attri_len)} } #数据库添加新表 self.database[table] = newtable diff --git a/sql_control/main.py b/sql_control/main.py index 6fea041..c4c699d 100644 --- a/sql_control/main.py +++ b/sql_control/main.py @@ -1,5 +1,13 @@ from sql_parse.AST_builder import AST from sql_command.DB import DB +import os +import pandas as pd + +from sql_parse.tokens import Token + +# data_folder = '../data' +# file_names = os.listdir(data_folder) + class blabla: def __init__(self): @@ -7,8 +15,14 @@ def __init__(self): 示例用 """ self.db = DB() - self.db.create("test",["索引","姓名"], [int,str]) - + # 外部数据读取 test + # for file_name in file_names: + # if file_name.endswith('.csv'): + # var_name = os.path.splitext(file_name)[0] + # file_path = data_folder + '/' + file_name + # data_table = pd.read_csv(file_path) + # self.dict[var_name] = data_table + def ast_clear(self): """ 清除上一个语句留下的AST @@ -22,13 +36,13 @@ def extract(self,text:str): 获取不同clause(子句)的内容 """ # 根据text,生成AST(这里AST只实现了SELECT查询) - self.ast = AST(sql).content + self.ast = AST(text).content # 将AST中的不同clause保存过来 self.clause = {} for clause in self.ast.content: self.clause[clause.value] = clause - def execute(self,text:str): + def execute(self, text: str): """ 执行ast中的语句 """ @@ -36,27 +50,304 @@ def execute(self,text:str): self.ast_clear() # 获取当前语句的AST self.extract(text) - # 示例用,演示怎么从AST解析后的树中获取table - tables = self.get_tables() - print(tables) - - def get_tables(self): - #从FROM clause中获得content - table_names = self.clause["FROM"].content - ret = [] - for table_name_i in table_names: - # 调用sql_commands中实现的方法,访问得到每个table的DataFrame - table = self.db.database[table_name_i]['tabledata'] - ret.append(table) - return ret - + + function = self.funct() + + # 查询 + if function == 'SELECT': + tables = self.get_tables(function) + cols = self.get_col() + for table in tables: + cur_table = self.db.database[table] + rows = self.get_row(table) if "WHERE" in self.clause.keys() else None + result = self.db.select(cur_table['tabledata'], cols, rows) if Token.Wildcard not in cols else self.db.select(cur_table['tabledata'], cur_table['tabledata'].columns, rows) + print(result) + + # 更新 + elif function == 'UPDATE': + tables = self.get_tables(function) + for table in tables: + cur_table = self.db.database[table] + rows = self.get_row(table) if "WHERE" in self.clause.keys() else None + for up in self.clause["SET"].content: + v = up.content[0]['expression'] + if len(v) == 1: + if isinstance(v['left'], int) or isinstance(v['left'], float) or (isinstance(v['left'], str) and f"'{v['left']}'" in text): + value = v['left'] + else: + value = self.db.select(cur_table['tabledata'], v['left'], rows).values.tolist() + else: + value = self.get_val(table, up.content[0]['expression'], rows) + self.db.update(cur_table['tabledata'], update_rows=rows, attributes=[up.content[0]['assignment']], values=value) + + # 添加 + elif function == 'INSERT': + tables = self.get_tables(function) + for table in tables: + cols = self.clause["COLUMNS"].content + values_clauses = self.ast.content[2:] + cur_table = self.db.database[table] + values_for_insert = [] + for values_clause_i in values_clauses: + values = values_clause_i.content + self.check_types(cur_table['datatypes'], cols, values) + self.check_null(cur_table['not_null_flag'], cols, values) + self.check_primary(cur_table['tabledata'],cur_table['primary_key'],cols,values) + values_for_insert.append(values) + _, cur_table['tabledata'] = self.db.insert(cur_table['tabledata'], cols, values_for_insert) + + # 删除 + elif function == 'DELETE': + tables = self.get_tables(function) + for table in tables: + cur_table = self.db.database[table] + rows = self.get_row(table) if "WHERE" in self.clause.keys() else None + self.db.delete(cur_table['tabledata'], del_rows=rows) + + # 创建新表 + elif function == 'CREATE': + self.excute_create_statement() + + # 删除表 + elif function == 'DROP': + print("DROP") + + else: + print("ERROR FUNCTION") + + if function != "SELECT": + self.save_tables() + + def check_types(self, requires, cols, values): + for i in range(len(cols)): + if requires[cols[i]].upper() in ["VARCHAR", "CHAR"]: + if not isinstance(values[i], str): + raise Exception(f"Type of {cols[i]} is {requires[cols[i]]}, but {type(values[i])} is given.") + if requires[cols[i]].upper() in ["INT"]: + if not isinstance(values[i], int): + raise Exception(f"Type of {cols[i]} is {requires[cols[i]]}, but {type(values[i])} is given.") + if requires[cols[i]].upper() in ["FLOAT"]: + if not isinstance(values[i],float): + pass + raise Exception(f"Type of {cols[i]} is {requires[cols[i]]}, but {type(values[i])} is given.") + + def check_null(self, requires, cols, values): + for k in requires: + v = requires[k] + if v is True: + if k not in cols: + raise Exception(f"{k} is not null, but null is given.") + else: + if values[cols.index(k)] is None: + raise Exception(f"{k} is not null, but null is given.") + + def check_primary(self, table: pd.DataFrame, primary_keys, cols, values): + for k in primary_keys: + if k in cols: + if values[cols.index(k)] in table[k].values: + raise Exception(f"{k} is primary key, but {values[cols.index(k)]} is already in table.") + else: + raise Exception(f"{k} is primary key, but not given in insertion values.") + + + def save_tables(self): + # 待实现 + pass + + # 判断功能 + def funct(self): + return self.ast.content[0].value + + # SELECT功能的筛选列 + def get_col(self): + num_col = len(self.clause["SELECT"].content) + cols = [] + for i in range(0, num_col): + col = self.clause["SELECT"].content[i] + cols.append(col) + return cols + + # WHERE语句的筛选行和对AND、OR运算结果 + def get_row(self, table): + cur_table = self.db.database[table] + num_condition = len(self.clause["WHERE"].content) + rows_ = [] + for i in range(0, num_condition): + condition = self.clause["WHERE"].content[i].content[0] + row = self.db.where(cur_table['tabledata'], condition['left'], condition['right'], condition['op']) + rows_.append(set(row)) + operators = [expression_i.value for expression_i in self.clause["WHERE"].content[1:]] + rows = rows_[0] + i = 1 + while i < len(rows_): + if operators[i - 1] == "AND": + rows = rows.intersection(rows_[i]) + elif operators[i - 1] == "OR": + temp_result = rows_[i] + j = i + 1 + while j < len(rows_) and operators[j - 1] == "AND": + temp_result = temp_result.intersection(rows_[j]) + j += 1 + rows = rows.union(temp_result) + i = j - 1 + i += 1 + return list(rows) + + # SET 右边为表达式的情况 + def get_val(self, table, expression, rows): + cur_table = self.db.database[table] + result = self.db.select(cur_table['tabledata'], expression['left'], rows) + if expression['op'] == '+': + result += expression['right'] + elif expression['op'] == '-': + result -= expression['right'] + elif expression['op'] == '*': + result *= expression['right'] + elif expression['op'] == '/': + result /= expression['right'] + result = result.values.tolist() + return result + + def get_tables(self, function): + if function == "SELECT" or function == "DELETE": + table_names = self.clause["FROM"].content + elif function == "UPDATE": + table_names = self.clause["UPDATE"].content + elif function == "INSERT": + table_names = self.clause["INSERT"].content + return table_names + + def excute_create_statement(self): + table = self.clause["CREATE"].content[0] + attribute_ls = [] + types_ls = [] + not_null_ls = [] + primary_key_ls = [] + char_attri_ls = [] + char_attri_len_ls = [] + + for col_def in self.clause["COLUMNS"].content: + t = col_def.content[0] + attribute_ls.append(t["name"]) + types_ls.append(t["type"]) + if t["type"] in ["char","varchar"]: + char_attri_ls.append(t["name"]) + char_attri_len_ls.append(t["length"]) + # attention: if it is a primary key, then it should always be not null + not_null_ls.append(t["NOT NULL"] or t["PRIMARY"]) + if t["PRIMARY"]: + primary_key_ls.append(t["name"]) + if self.db.create(table=table, + attributes=attribute_ls, + types=types_ls, + not_null=not_null_ls, + primary_key=primary_key_ls, + char_attri=char_attri_ls, + char_attri_len=char_attri_len_ls): + print("创建成功!") + # print(self.db.database[table]) + else: + print("创建失败!") if __name__ == "__main__": - sql = """ - SELECT id, name, this - FROM test - WHERE id = 1 AND this < 2.3; - """ a = blabla() - a.execute(sql) \ No newline at end of file + + # 创建 + sql1 = """ + CREATE TABLE Persons + ( + PersonID PRIMARY int, + LastName varchar(255) NOT NULL, + FirstName char(255), + Address float, + City varchar(255) + ); + """ + a.execute(sql1) + print(a.db.database["Persons"]['tabledata']) + print("="*20) + print("="*20) + + # 插入 + sql2 = """ + INSERT INTO Persons (PersonID,LastName, Address, City) + VALUES ( + (3,'my', 2.3, "this"), + (4,'she', 5.6, "7"), + (5,'thistsdas',1.1,"9") + ); + """ + a.execute(sql2) + print(a.db.database["Persons"]['tabledata']) + print("="*20) + print("="*20) + + # 更新 + sql3 = """ + UPDATE Persons + SET Address = Address * 1.1 + WHERE PersonID >= 2 + """ + a.execute(sql3) + print(a.db.database["Persons"]['tabledata']) + print("="*20) + print("="*20) + + # 查询 + sql4 = """ + SELECT * + FROM Persons + WHERE PersonID >= 4 + """ + a.execute(sql4) + print("=" * 20) + print("=" * 20) + + # print(a.db.database["Persons"]['datatypes']) + # print(a.db.database["Persons"]['not_null_flag']) + # print(a.db.database["Persons"]['primary_key']) + # a = blabla() + # sql1 = """ + # SELECT id, name + # FROM test2 + # WHERE id >= 2 AND this < 3.1 OR this2 = 13 AND id = 1; + # """ + # a.execute(sql1) + + # sql2 = """ + # UPDATE test2 + # SET this2 = 100, name = 'Li Si', this = this * 1.1 + # WHERE id >= 2 AND this2 < 13; + # """ + # a.execute(sql2) + + # sql3 = """ + # SELECT id, name, this, this2 + # FROM test2 + # """ + # a.execute(sql3) + + # sql4 = """ + # DELETE FROM test2 + # WHERE id = 1 + # """ + # a.execute(sql4) + # a.execute(sql3) + + # sql5 = """ + # SELECT * + # FROM test + # WHERE gender = Female; + # """ + # a.execute(sql5) + + # sql6 = """ + # UPDATE test + # SET salary = salary * 1.2 + # WHERE id = 1 OR id = 2 OR id = 5; + # """ + # a.execute(sql6) + + # sql7 = """SELECT * FROM test""" + # a.execute(sql7) \ No newline at end of file diff --git a/sql_minibuilder-main.zip b/sql_minibuilder-main.zip new file mode 100644 index 0000000..c1749de Binary files /dev/null and b/sql_minibuilder-main.zip differ diff --git a/sql_parse/AST_builder.py b/sql_parse/AST_builder.py index 77a4eb5..34c87d6 100644 --- a/sql_parse/AST_builder.py +++ b/sql_parse/AST_builder.py @@ -1,12 +1,12 @@ from sql_parse.tokenizer import tokenizer from sql_parse import tokens from sql_parse.ast_def import AST_KEYWORDS -from sql_parse.ast_def import _statement,_clause,_expression +from sql_parse.ast_def import _statement,_clause,_expression,_coldef class AST: def __init__(self, text = None): - self.get_tokens(text) + self.get_tokens(text=text) new_statement = _statement() new_statement.attribute = AST_KEYWORDS.STATEMENT self.content , _ = self.build_AST(start_idx=0, cur_node=new_statement) @@ -51,7 +51,138 @@ def create_node(self, level): return _clause() if(level == AST_KEYWORDS.EXPRESSION): return _expression() + if(level == AST_KEYWORDS.COLUMN_DEFINITION): + return _coldef() + def build_AST_CREATE(self, start_idx = 0, statement_node = None): + stream = self.tokens + total_idx = len(stream) + + # 第一个必定会存在的clause: value="CREATE",content包含table的名称 + create_clause_node = self.create_node(AST_KEYWORDS.CLAUSE) + create_clause_node.value = "CREATE" + statement_node.content.append(create_clause_node) + + # 第二个必定会存在的clause:value="COLUMNS",content包含每一列的定义 + columns_clause_node = self.create_node(AST_KEYWORDS.CLAUSE) + columns_clause_node.value = "COLUMNS" + statement_node.content.append(columns_clause_node) + + # 在找到CREATE关键词后,我们希望找的是:CREATE TABLE 表名 + cur_node = create_clause_node + idx = start_idx + 1 + + pair_level = 0 + + while idx < total_idx: + cls, value = stream[idx] + + # 如果当前token并不特殊,非关键字,那么就是当前node需要接受的内容 + if (cls not in tokens.Keyword) and (cls not in tokens.Punctuation): + cur_node.deal(cls, value) + + # 如果当前token为标点 + elif cls in tokens.Punctuation: + # ()的处理 + if value in ["(", ")"]: + if value == "(": + pair_level = pair_level + 1 + if pair_level == 1 and value == "(": + # 遇到这里,说明读到了CREATE TABLE 表名 ( 的情况,接下来该是新的clause了 + if cur_node.value == "CREATE": + cur_coldef_node = self.create_node(AST_KEYWORDS.COLUMN_DEFINITION) + cur_coldef_node.value = "COLUMN_DEFINITION" + columns_clause_node.content.append(cur_coldef_node) + cur_node = cur_coldef_node + idx = idx + 1 + continue + if pair_level == 1 and value == ")": + # 遇到这里,说明是在希望读取下一个列的时候,发现已经没有更多列了 + # 那可真是个天大的喜事,说明读取完成了 + return statement_node, idx + if value == ")": + pair_level = pair_level - 1 + # 逗号的处理 + if value in [","]: + # 一个列的结束,另一个列的开始 + cur_coldef_node = self.create_node(AST_KEYWORDS.COLUMN_DEFINITION) + cur_coldef_node.value = "COLUMN_DEFINITION" + columns_clause_node.content.append(cur_coldef_node) + cur_node = cur_coldef_node + idx = idx + 1 + continue + + # 如果当前token特殊,为关键字……(但其实唯一要看的关键字就是TABLE和PRIMARY(目前的话)) + else: + val = value.upper() + if val == "TABLE": pass + if val == "PRIMARY": + cur_node.content[0]["PRIMARY"] = True + if val == "NOT NULL": + cur_node.content[0]["NOT NULL"] = True + idx = idx + 1 + return statement_node, idx + + def build_AST_INSERT(self, start_idx = 0, statement_node = None): + stream = self.tokens + total_idx = len(stream) + + # 第一个必定会存在的clause: value="INSERT",content包含table的名称 + insert_clause_node = self.create_node(AST_KEYWORDS.CLAUSE) + insert_clause_node.value = "INSERT" + statement_node.content.append(insert_clause_node) + + # 第二个必定会存在的clause:value="COLUMNS",content包含每一列的名称 + columns_clause_node = self.create_node(AST_KEYWORDS.CLAUSE) + columns_clause_node.value = "COLUMNS" + statement_node.content.append(columns_clause_node) + + + # 在找到INSERT关键词后,我们希望找的是:INSERT INTO 表名 + requireValue = "INTO" + cur_node = insert_clause_node + idx = start_idx + 1 + + pair_level = 0 + + while idx < total_idx: + cls, value = stream[idx] + + # 如果当前token并不特殊,非关键字,那么就是当前node需要接受的内容 + if (cls not in tokens.Keyword) and (cls not in tokens.Punctuation): + cur_node.deal(cls, value) + + # 如果当前token为标点 + elif cls in tokens.Punctuation: + # ()的处理 + if value == "(": + pair_level = pair_level + 1 + if value in ["(", ")"]: + if cur_node.value == "INSERT" and value == "(": + # 遇到这里,说明读到了INSERT INTO 表名 ( 的情况,接下来该是第二个clause了 + cur_node = columns_clause_node + idx = idx + 1 + continue + elif (cur_node.value == "COLUMNS" and pair_level == 1 and value == "(" and self.next_token_value(idx+1) != "(") or (pair_level == 2 and value == "("): + # 第三个必定会存在的clause:value="VALUES",content包含每一列的值 + values_clause_node = self.create_node(AST_KEYWORDS.CLAUSE) + values_clause_node.value = "VALUES" + statement_node.content.append(values_clause_node) + cur_node = values_clause_node + if value == ")": + pair_level = pair_level - 1 + + # 如果当前token特殊,为关键字 + else: + val = value.upper() + if val == "INTO": pass + # 读到这里,说明columns的名字已经读完了,在读VALUES了,进入下一个clause + if val == "VALUES": + pass + idx = idx + 1 + return statement_node, idx + + def build_AST(self, start_idx = 0, cur_node = None): stream = self.tokens idx = start_idx @@ -85,6 +216,16 @@ def build_AST(self, start_idx = 0, cur_node = None): # 如果当前token特殊,为关键字 else: + # 特殊语法检查 + # CREATE的语法区别和其他的差别太大了,放弃兼容性,直接单独处理 + if value.upper() == "CREATE": + node_create, _ = self.build_AST_CREATE(idx, cur_node) + return node_create, None + # INSERT同理 + if value.upper() == "INSERT": + node_insert, _ = self.build_AST_INSERT(idx, cur_node) + return node_insert, None + # 判断当前token的层级 par_cls_level = cur_node.attribute cur_cls_level = self.get_level(value=value.upper(),par_value=cur_node.value) @@ -115,7 +256,19 @@ def build_AST(self, start_idx = 0, cur_node = None): raise Exception(f"Current node level is {cur_node.attribute}, but the word level is {cur_cls_level}") idx = idx + 1 return cur_node, idx - + + def next_token_value(self, start_idx): + stream = self.tokens + idx = start_idx + total_idx = len(stream) + while idx < total_idx: + cls, value = stream[idx] + if value == " ": + idx = idx+1 + continue + return value + return None + def pprint(self): """ 用还算好看的方式打印自己的content @@ -127,7 +280,7 @@ def pprint_impl(self,depth = 0, cur_list = None,_pre = ''): pprint的单层实现 """ for idx, cur_list_i in enumerate(cur_list): - if isinstance(cur_list_i, _clause) or isinstance(cur_list_i, _expression): + if isinstance(cur_list_i, _clause) or isinstance(cur_list_i, _expression) or isinstance(cur_list_i,_coldef): level = "level:" + str(cur_list_i.attribute) item = _pre + str(cur_list_i.value) print(f"{item:<60} {level:<50}") @@ -139,29 +292,55 @@ def pprint_impl(self,depth = 0, cur_list = None,_pre = ''): if __name__ == "__main__": + # 基础SELECT sql1 = """ SELECT id, name, this FROM table1, table2 WHERE id = 1 AND "this" < 2.3; """ + # 基础UPDATE sql2 = """ UPDATE table1 - SET alexa = 50000, country='USA', salary = 14.5 + SET alexa = 50000, country='USA', salary = salary * 14.5 WHERE id = 1 AND this < 2.3 OR name>1; """ + # 基础DELETE sql3 = """ DELETE table1 WHERE id = 1 AND this < 2.3 OR name>1; """ + # 基础CREATE + sql4 = """ + CREATE TABLE Persons ( + PersonID SERIAL int, + LastName varchar(255), + FirstName char(255) NOT NULL, + Address float, + City varchar(255), + ); + """ + # SELECT的特殊情况:Wildcard + sql5 = """ + SELECT * + FROM table1, table2 + WHERE id = 1 AND "this" < 2.3; + """ - a = AST(sql3) - # TODO: 由于自己的实现是从左往右读TOKEN,而没有提前读等操作,因而不可能先读 - # AND再读WHERE。自己的一个暂时的解决方法是将AND和WHERE一样看作一个 - # expression,这样能保证一个CLAUSE中只有一个表达式(例如a=3),读到AND时执行 - # WHERE查询,再将AND查询的结果,与WHERE查询的结果取交集。如果后面还有AND,就 - # 再与左边的取交集……这样对于多个AND没有问题,但是如果有OR,那么优先级就被打 - # 乱了,必须要先完成OR两边的再对两边的结果取并集 - # 现在来看这部分有点困难,先不要动为好 + # 基础INSERT + sql6 = """ + INSERT INTO table1 (id, name, this) + VALUES (1, 'alex', 2.3); + """ + # INSERT的特殊情况,不指定列名 + sql7 = """ + INSERT INTO table1(id,name) + VALUES( + (1, 'alex'), + (2, 'bob'), + (5, 'jjj') + ); + """ + a = AST(sql7) a.pprint() show = a.content - print("\n\n",show) \ No newline at end of file + print("\n\n",show) diff --git a/sql_parse/ast_def.py b/sql_parse/ast_def.py index 647e2f5..9e47d69 100644 --- a/sql_parse/ast_def.py +++ b/sql_parse/ast_def.py @@ -5,6 +5,7 @@ class AST_KEYWORDS(enum.IntEnum): STATEMENT = 0 CLAUSE = 1 EXPRESSION = 2 + COLUMN_DEFINITION = 10 class _statement: def __init__(self): @@ -20,21 +21,26 @@ class _clause: def __init__(self): self.attribute = AST_KEYWORDS.CLAUSE self.content = [] + self.value = None def deal(self, cls, value): """ - 对于一个非关键词的token,根据自己clause的类型,将其加入到自己的内容中 + 对于一个非关键词的token,根据自己的类型,将其加入到自己的内容中 """ # 当前columns的类型决定其会记录columns - if self.value in ["SELECT"]: + if self.value in ["SELECT", "CREATE", "COLUMNS"]: if cls in tokens.Name: self.content.append(value) + elif cls in tokens.Wildcard: + self.content.append(tokens.Wildcard) # 当前clause的类型决定其会记录tables - if self.value in ["FROM", "UPDATE"]: + if self.value in ["FROM", "UPDATE", "INSERT"]: if cls in tokens.Name: self.content.append(value) - # 当前clause的类型决定其会记录一个表达式 - # TODO: 完成一些例如CREATE,PRIMARY之类的处理 + # 当前clause的类型决定其会记录values,但并不是表达式 + if self.value in ["VALUES"]: + if cls in tokens.Name or cls in tokens.Literal: + self.content.append(Numerize(cls,value)) # WHERE很特殊,之后WHERE会被重复利用一次,以完成多个表达式的连接 # 因此理论上WHERE clause不可能接受到非关键词token(忽略),WHERE expression才会接受到 @@ -46,12 +52,13 @@ class _expression: def __init__(self): self.attribute = AST_KEYWORDS.EXPRESSION self.content = [] + self.value = None def deal(self, cls, value): """ 对于一个非关键词的token,根据自己expression的类型,将其加入到自己的内容中 """ - if self.value in ["WHERE", "AND", "OR", "SET"]: + if self.value in ["WHERE", "AND", "OR"]: if cls in tokens.Name or cls in tokens.Literal: if self.content == []: # 左边 sub_expr_left = {"left": Numerize(cls,value), "op": None, "right": None} @@ -60,6 +67,41 @@ def deal(self, cls, value): self.content[0]["right"] = Numerize(cls,value) if cls in tokens.Operator: # 中间的Operator self.content[0]["op"] = value + if self.value in ["SET"]: + if cls in tokens.Name or cls in tokens.Literal: + if self.content == []: # 左边 + sub_expr_left = {"assignment": Numerize(cls,value), "expression": None} + self.content.append(sub_expr_left) + else: # 右边 + if self.content[0]["expression"] == None: + self.content[0]["expression"] = dict() + self.content[0]["expression"]["left"] = Numerize(cls,value) + else: + self.content[0]["expression"]["right"] = Numerize(cls,value) + if cls in tokens.Operator or cls in tokens.Wildcard: # 中间的Operator,Wildcard是指* + if self.content[0]["expression"] == None: + pass + else: + self.content[0]["expression"]["op"] = value + +class _coldef: + def __init__(self): + self.attribute = AST_KEYWORDS.COLUMN_DEFINITION + self.content = [{"PRIMARY":False,"NOT NULL":False}] + + def deal(self, cls, value): + """ + 对于一个非关键词的token,根据自己的类型,将其加入到自己的内容中 + """ + # 说明这是一个column的类型提示 + if cls in tokens.Name.Builtin: + self.content[0]["type"] = value + + elif cls in tokens.Name: + self.content[0]["name"] = value + + elif cls in tokens.Literal: + self.content[0]["length"] = Numerize(cls,value) def Numerize(cls, text: str): diff --git a/sql_parse/keywords.py b/sql_parse/keywords.py index 5d3a405..d851f33 100644 --- a/sql_parse/keywords.py +++ b/sql_parse/keywords.py @@ -711,4 +711,6 @@ 'MIN': tokens.Keyword, 'MAX': tokens.Keyword, 'DISTINCT': tokens.Keyword, -} \ No newline at end of file +} + + diff --git a/sql_parse/tokenizer.py b/sql_parse/tokenizer.py index d694953..d0bde05 100644 --- a/sql_parse/tokenizer.py +++ b/sql_parse/tokenizer.py @@ -26,6 +26,13 @@ def default_initialization(self): self._keywords.append(keywords.KEYWORDS_COMMON) self._keywords.append(keywords.KEYWORDS) + # reverse to find the built-in keywords + self.KEYWORDS_BUILTIN_LIST = [] + for k, v in keywords.KEYWORDS.items(): + if v != tokens.Name.Builtin: + continue + self.KEYWORDS_BUILTIN_LIST.append(k) + def is_keyword(self, value): """ 判断当前的内容是一个NAME还是关键字 @@ -40,6 +47,20 @@ def is_keyword(self, value): else: return tokens.Name, value + def builtin_check(self, action, value): + """ + 一个特殊的check,保证当前的token是一个内置的类型,而非一个自定义的类型 + 这里多余的处理是因为最开始的正则不太对 + """ + if action != tokens.Token.Name: + return action, value + val = value.upper() + if val in self.KEYWORDS_BUILTIN_LIST: + return tokens.Token.Name.Builtin, value + else: + return action, value + + def consume(self,iterator, n): """ 将迭代器向后推动n个位置,即跳过n个元素,这里是为了在匹配到字符为token后,跳过已匹配部分 @@ -64,7 +85,7 @@ def tokenize(self,text: str): if not m: # 从这个字符往后看,并不能匹配出任何的关键字或者token continue elif isinstance(action, tokens._TokenType): # 如果这是一个普通Token - yield action, m.group() + yield self.builtin_check(action,m.group()) elif action is keywords.PROCESS_AS_KEYWORD: # 如果这是一个关键字 yield self.is_keyword(m.group())