Commit 5bab31b

improve api
1 parent 0488e63 commit 5bab31b

4 files changed

+153
-35
lines changed

README.md

Lines changed: 70 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,34 @@
11
# [Query.Farm](https://query.farm) SQL Manipulation
22

3-
A Python library for intelligent SQL predicate manipulation using [SQLGlot](https://sqlglot.com/sqlglot.html). This library provides tools to safely remove specific predicates from `SQL WHERE` clauses and filter SQL statements based on column availability.
3+
A Python library for intelligent SQL predicate manipulation using [SQLGlot](https://sqlglot.org)
4+
5+
### Column Filtering with Complex Expressions
6+
7+
```python
8+
import sqlglot
9+
from query_farm_sql_manipulation import transforms
10+
11+
sql = '''
12+
SELECT * FROM users
13+
WHERE age > 18
14+
AND (status = 'active' OR role = 'admin')
15+
AND department IN ('engineering', 'sales')
16+
'''
17+
18+
# Parse the statement first
19+
statement = sqlglot.parse_one(sql, dialect="duckdb")
20+
21+
# Only keep predicates involving 'age' and 'role'
22+
allowed_columns = {'age', 'role'}
23+
24+
result = transforms.filter_column_references(
25+
statement=statement,
26+
selector=lambda col: col.name in allowed_columns,
27+
)
28+
29+
# Result: SELECT * FROM users WHERE age > 18 AND role = 'admin'
30+
print(result.sql())
31+
```
432

533
## Features
634

@@ -44,19 +72,22 @@ transforms.remove_expression_part(target_predicate)
4472
print(statement.sql())
4573
```
4674

47-
### Column-Based Filtering
75+
### Column-Name Based Filtering
4876

4977
```python
78+
import sqlglot
5079
from query_farm_sql_manipulation import transforms
5180

52-
# Filter SQL to only include predicates with allowed columns
81+
# Parse SQL statement first
5382
sql = 'SELECT * FROM data WHERE color = "red" AND size > 10 AND type = "car"'
83+
statement = sqlglot.parse_one(sql, dialect="duckdb")
84+
85+
# Filter to only include predicates with allowed columns
5486
allowed_columns = {"color", "type"}
5587

56-
filtered = transforms.filter_column_references_statement(
57-
sql=sql,
58-
allowed_column_names=allowed_columns,
59-
dialect="duckdb"
88+
filtered = transforms.filter_column_references(
89+
statement=statement,
90+
selector=lambda col: col.name in allowed_columns,
6091
)
6192

6293
# Result: SELECT * FROM data WHERE color = "red" AND type = "car"
@@ -67,7 +98,7 @@ print(filtered.sql())
6798

6899
### `remove_expression_part(child: sqlglot.Expression) -> None`
69100

70-
Removes the specified expression from its parent, respecting logical structure.
101+
Removes the specified SQLGlot expression from its parent, respecting logical structure.
71102

72103
**Parameters:**
73104
- `child`: The SQLGlot expression to remove
@@ -82,21 +113,40 @@ Removes the specified expression from its parent, respecting logical structure.
82113
- `NOT` expressions: Removes the entire NOT expression
83114
- `CASE` statements: Removes conditional branches
84115

85-
### `filter_column_references_statement(*, sql: str, allowed_column_names: Container[str], dialect: str = "duckdb") -> sqlglot.Expression`
116+
### `filter_column_references(*, statement: sqlglot.Expression, selector: Callable[[sqlglot.expressions.Column], bool]) -> sqlglot.Expression`
86117

87-
Filters a SQL statement to remove predicates containing columns not in the allowed set.
118+
Filters a SQL statement to remove predicates containing columns that don't match the selector criteria.
88119

89120
**Parameters:**
90-
- `sql`: The SQL statement to filter
91-
- `allowed_column_names`: Container of column names that should be preserved
92-
- `dialect`: SQL dialect for parsing (default: "duckdb")
121+
- `statement`: The SQLGlot expression to filter
122+
- `selector`: A callable that takes a Column and returns True if it should be preserved, False if it should be removed
93123

94124
**Returns:**
95-
- Filtered SQLGlot expression with non-allowed columns removed
125+
- Filtered SQLGlot expression with non-matching columns removed
96126

97127
**Raises:**
98128
- `ValueError`: If a column can't be cleanly removed due to interactions with allowed columns
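
The move from a container of allowed names to a `selector` callable can be illustrated without SQLGlot. The sketch below is not the library's code: `Column` is a hypothetical stand-in for `sqlglot.expressions.Column` that carries only `name` and `table`, and `selector_from_names` and `columns_to_remove` are illustrative helpers. It only shows how a selector decides which column references are marked for removal, assuming the same "remove what the selector rejects" rule described above.

```python
from collections.abc import Callable, Container, Iterable
from dataclasses import dataclass


@dataclass(frozen=True)
class Column:
    """Hypothetical stand-in for sqlglot.expressions.Column (name and table only)."""

    name: str
    table: str = ""


Selector = Callable[[Column], bool]


def selector_from_names(allowed: Container[str]) -> Selector:
    """Adapt an old-style allowed_column_names container to a selector (illustrative helper)."""
    return lambda col: col.name in allowed


def columns_to_remove(columns: Iterable[Column], selector: Selector) -> list[Column]:
    """Return the columns the selector rejects, i.e. those whose predicates would be removed."""
    return [col for col in columns if not selector(col)]


columns = [Column("color"), Column("size"), Column("type", table="t")]

# Name-based selection, equivalent to the previous allowed_column_names argument.
by_name = selector_from_names({"color", "type"})
removed_by_name = [col.name for col in columns_to_remove(columns, by_name)]

# A selector may also look at other attributes, here the stand-in table qualifier.
by_table: Selector = lambda col: col.table == "t"
removed_by_table = [col.name for col in columns_to_remove(columns, by_table)]

print(removed_by_name)
print(removed_by_table)
```

A callable keeps the old name-based behaviour available through a one-line lambda while letting callers select on any property of the column.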
99129

130+
### `where_clause_contents(statement: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression | None`
131+
132+
Extracts the contents of the WHERE clause from a SQLGlot expression.
133+
134+
**Parameters:**
135+
- `statement`: The SQLGlot expression to extract from
136+
137+
**Returns:**
138+
- The contents of the WHERE clause, or None if no WHERE clause exists
139+
140+
### `filter_predicates_with_right_side_column_references(statement: sqlglot.expressions.Expression) -> sqlglot.Expression`
141+
142+
Filters out predicates that have column references on the right side of comparisons.
143+
144+
**Parameters:**
145+
- `statement`: The SQLGlot expression to filter
146+
147+
**Returns:**
148+
- Filtered SQLGlot expression with right-side column reference predicates removed
149+
100150
## Examples
101151

102152
### Complex Logic Handling
@@ -140,9 +190,15 @@ result = transforms.filter_column_references_statement(
140190
The library will raise `ValueError` when predicates cannot be safely removed:
141191

142192
```python
193+
import sqlglot
194+
from query_farm_sql_manipulation import transforms
195+
143196
# Removing x = 1 here would raise ValueError because it is part of a larger expression
144197
sql = "SELECT * FROM data WHERE result = (x = 1)"
198+
statement = sqlglot.parse_one(sql, dialect="duckdb")
199+
145200
# Cannot remove x = 1 because it's used as a value, not a predicate
201+
# This would raise ValueError if attempted
146202
```
147203

148204
## Supported SQL Constructs

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "query-farm-sql-manipulation"
3-
version = "0.1.4"
3+
version = "0.1.5"
44
description = "A Python library for intelligent SQL predicate manipulation using SQLGlot. This library provides tools to safely remove specific predicates from `SQL WHERE` clauses and filter SQL statements based on column availability."
55
authors = [
66
{ name = "Rusty Conover", email = "rusty@query.farm" }

src/query_farm_sql_manipulation/test_transforms.py

Lines changed: 30 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ def _remove_target_predicate(sql: str, target: str, expected: str) -> None:
4444
matching_predicates = [p for p in predicates if p == target_predicate]
4545

4646
# Assert that we found at least one match
47-
assert matching_predicates, f"No matching predicate found for: {target}"
47+
assert matching_predicates, f"No matching predicate found for: {target} {expression}"
4848

4949
# Remove all matching predicates
5050
for matched_predicate in matching_predicates:
@@ -145,6 +145,31 @@ def test_remove_logical_predicates_errors(sql: str, target: str) -> None:
145145
_remove_target_predicate(sql, target, "")
146146

147147

148+
@pytest.mark.parametrize(
149+
"sql, expected", [("""x = 1""", "x = 1"), ("foo = bar and z = 4", "foo = bar AND z = 4")]
150+
)
151+
def test_where_clause_extract(sql: str, expected: str) -> None:
152+
statement = sqlglot.parse_one(f'SELECT * FROM "data" WHERE {sql}')
153+
extracted_where = transforms.where_clause_contents(statement)
154+
assert extracted_where is not None, "Expected a WHERE clause to be present"
155+
assert extracted_where.sql("duckdb") == expected
156+
157+
158+
@pytest.mark.parametrize(
159+
"sql, expected",
160+
[
161+
("""v1 >= v1 + 5 and z = 5""", "z = 5"),
162+
("""((v1 >= v1 + 5) or t1 = 5) and z = 5""", "(t1 = 5) AND z = 5"),
163+
],
164+
)
165+
def test_filter_predicates_with_right_side_column_references(sql: str, expected: str) -> None:
166+
statement = sqlglot.parse_one(f'SELECT * FROM "data" WHERE {sql}')
167+
updated = transforms.filter_predicates_with_right_side_column_references(statement)
168+
extracted_where = transforms.where_clause_contents(updated)
169+
assert extracted_where is not None, "Expected a WHERE clause to be present"
170+
assert extracted_where.sql("duckdb") == expected
171+
172+
148173
@pytest.mark.parametrize(
149174
"column_names, sql, expected_sql",
150175
[
@@ -356,8 +381,10 @@ def test_filter_column_references(
356381
) -> None:
357382
dialect = "duckdb"
358383
full_sql = f'SELECT * FROM "data" AS "data" WHERE {sql}'
359-
result = transforms.filter_column_references_statement(
360-
sql=full_sql, allowed_column_names=column_names
384+
statement = sqlglot.parse_one(full_sql, dialect="duckdb")
385+
result = transforms.filter_column_references(
386+
statement=statement,
387+
selector=lambda col: col.name in column_names,
361388
)
362389

363390
optimized = sqlglot.optimizer.optimize(result, rules=RULES_WITHOUT_NORMALIZE, dialect=dialect)

src/query_farm_sql_manipulation/transforms.py

Lines changed: 52 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
1-
from collections.abc import Container
1+
from collections.abc import Callable
22

33
import sqlglot
44
import sqlglot.expressions
5+
import sqlglot.optimizer.simplify
56

67

78
def remove_expression_part(child: sqlglot.Expression) -> None:
@@ -65,36 +66,70 @@ def remove_expression_part(child: sqlglot.Expression) -> None:
6566
raise ValueError(f"Cannot remove child from parent of type {type(parent)} {parent.sql()}")
6667

6768

68-
def filter_column_references_statement(
69-
*, sql: str, allowed_column_names: Container[str], dialect: str = "duckdb"
69+
def where_clause_contents(
70+
statement: sqlglot.expressions.Expression,
71+
) -> sqlglot.expressions.Expression | None:
72+
"""
73+
Extract the contents of the WHERE clause from a SQLGlot expression.
74+
Args:
75+
statement: The SQLGlot expression to extract from
76+
Returns:
77+
The contents of the WHERE clause, or None if no WHERE clause exists
78+
"""
79+
where_clause = statement.find(sqlglot.expressions.Where)
80+
if where_clause is None:
81+
return None
82+
return where_clause.this
83+
84+
85+
def filter_predicates_with_right_side_column_references(
86+
statement: sqlglot.expressions.Expression,
87+
) -> sqlglot.Expression:
88+
# Need to simplify the statement to move the column references to the
89+
# left side by default.
90+
statement = sqlglot.optimizer.simplify.simplify(statement)
91+
92+
# If there's no WHERE clause, nothing to filter
93+
where_clause = statement.find(sqlglot.expressions.Where)
94+
if where_clause is None:
95+
return statement
96+
assert where_clause is not None
97+
98+
for predicate in where_clause.find_all(sqlglot.expressions.Predicate):
99+
assert predicate is not None
100+
if predicate.right.find(sqlglot.expressions.Column):
101+
remove_expression_part(predicate)
102+
103+
return statement
104+
105+
106+
def filter_column_references(
107+
*,
108+
statement: sqlglot.expressions.Expression,
109+
selector: Callable[[sqlglot.expressions.Column], bool],
70110
) -> sqlglot.Expression:
71111
"""
72112
Filter SQL statement to remove predicates with columns rejected by the selector.
73113
74114
Args:
75115
statement: The SQLGlot expression to filter
76-
allowed_column_names: Container of column names that should be preserved
77-
dialect: The SQL dialect to use for parsing (default is "duckdb")
116+
selector: A callable that determines if a column should be preserved.
117+
It should return True for columns that are allowed, and False for those to be removed.
78118
79119
Returns:
80120
Filtered SQLGlot expression with non-allowed columns removed
81121
82122
Raises:
83123
ValueError: If a column can't be cleanly removed due to interactions with allowed columns
84124
"""
85-
# Parse and optimize the statement for predictable traversal
86-
statement = sqlglot.parse_one(sql, dialect=dialect)
87-
88125
# If there's no WHERE clause, nothing to filter
89126
where_clause = statement.find(sqlglot.expressions.Where)
90127
if where_clause is None:
91128
return statement
92129

93130
# Find all column references rejected by the selector
94131
column_refs_to_remove = [
95-
col
96-
for col in where_clause.find_all(sqlglot.expressions.Column)
97-
if col.name not in allowed_column_names
132+
col for col in where_clause.find_all(sqlglot.expressions.Column) if not selector(col)
98133
]
99134

100135
# Process each column reference that needs to be removed
@@ -104,7 +139,7 @@ def filter_column_references_statement(
104139
closest_expression = _find_closest_removable_expression(column_ref)
105140

106141
# Check if removing this expression would affect allowed columns
107-
if _can_safely_remove_expression(closest_expression, allowed_column_names):
142+
if _can_safely_remove_expression(closest_expression, selector):
108143
remove_expression_part(closest_expression)
109144
else:
110145
raise ValueError(
@@ -132,14 +167,16 @@ def _find_closest_removable_expression(
132167

133168

134169
def _can_safely_remove_expression(
135-
expression: sqlglot.expressions.Expression, allowed_column_names: Container[str]
170+
expression: sqlglot.expressions.Expression,
171+
selector: Callable[[sqlglot.expressions.Column], bool],
136172
) -> bool:
137173
"""
138174
Check if an expression can be safely removed without affecting allowed columns.
139175
140176
Args:
141177
expression: The expression to check
142-
allowed_column_names: Container of allowed column names
178+
selector: A callable that determines if a column should be preserved.
179+
It should return True for columns that are allowed, and False for those to be removed.
143180
144181
Returns:
145182
True if the expression can be safely removed, False otherwise
@@ -161,9 +198,7 @@ def _can_safely_remove_expression(
161198

162199
# Check if this expression references any allowed columns
163200
allowed_columns_referenced = [
164-
col.name
165-
for col in expression.find_all(sqlglot.expressions.Column)
166-
if col.name in allowed_column_names
201+
col.name for col in expression.find_all(sqlglot.expressions.Column) if selector(col)
167202
]
168203

169204
# If there are no allowed columns referenced, it's safe to remove
