Skip to content

Commit 8fa4b85

Browse files
committed
Fix SQL linting errors
1 parent 968f61b commit 8fa4b85

File tree

2 files changed

+134
-18
lines changed

2 files changed

+134
-18
lines changed

data_pipelines_cli/cli_commands/lint.py

Lines changed: 100 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,26 @@
11
import pathlib
22
import sys
3-
from typing import List, Tuple, cast
3+
from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar, cast
44

55
import click
66
import sqlfluff
77
import yaml
88
from sqlfluff.core import SQLFluffUserError
99

1010
from ..cli_constants import BUILD_DIR
11-
from ..cli_utils import echo_error, echo_suberror, echo_subinfo, echo_warning
11+
from ..cli_utils import echo_error, echo_info, echo_suberror, echo_subinfo, echo_warning
1212
from ..config_generation import read_dictionary_from_config_directory
1313
from ..errors import SQLLintError
1414
from ..func_utils import flat_map
1515

1616
if sys.version_info >= (3, 8):
17-
from typing import TypedDict # pylint: disable=no-name-in-module
17+
from typing import Protocol, TypedDict # pylint: disable=no-name-in-module
1818
else:
19-
from typing_extensions import TypedDict
19+
from typing_extensions import Protocol, TypedDict
20+
21+
22+
T = TypeVar("T", covariant=True)
23+
S = TypeVar("S")
2024

2125

2226
class LintResult(TypedDict):
@@ -26,6 +30,19 @@ class LintResult(TypedDict):
2630
description: str
2731

2832

33+
# Utility class representing sqlfluff `lint` and `fix` types
34+
class SqlFluffCallable(Protocol[T]):
35+
def __call__(
36+
self,
37+
sql: str,
38+
dialect: str = "ansi",
39+
rules: Optional[List[str]] = None,
40+
exclude_rules: Optional[List[str]] = None,
41+
config_path: Optional[str] = None,
42+
) -> T:
43+
...
44+
45+
2946
def _get_dialect_or_default() -> str:
3047
"""Read ``dbt.yml`` config file and return its ``target_type`` or just the ``ansi``."""
3148
env, dbt_filename = "base", "dbt.yml"
@@ -54,8 +71,13 @@ def _get_source_tests_paths() -> List[pathlib.Path]:
5471
return list(map(lambda dir_name: pathlib.Path.cwd().joinpath(dir_name), dir_names))
5572

5673

57-
def _lint_sql_files(dialect: str) -> List[Tuple[pathlib.Path, List[LintResult]]]:
58-
lint_results = []
74+
def _process_sql_files(
75+
dialect: str,
76+
include_rules: Optional[List[str]],
77+
exclude_rules: Optional[List[str]],
78+
sqlfluff_fn: SqlFluffCallable[S],
79+
result_callback: Callable[[pathlib.Path, S], None],
80+
) -> None:
5981
sql_file_paths: List[pathlib.Path] = flat_map(
6082
lambda dir_path: dir_path.rglob("*.sql"), _get_source_tests_paths()
6183
)
@@ -64,21 +86,45 @@ def _lint_sql_files(dialect: str) -> List[Tuple[pathlib.Path, List[LintResult]]]
6486
sql_file_str = sql_file.read()
6587

6688
try:
67-
lint_result = sqlfluff.lint(sql_file_str, dialect=dialect)
89+
result = sqlfluff_fn(
90+
sql_file_str, dialect=dialect, rules=include_rules, exclude_rules=exclude_rules
91+
)
6892
except SQLFluffUserError: # dialect does not exist, try default instead
6993
echo_warning(
7094
f"Dialect {dialect} did not get recognized. "
7195
'Linting using default one ("ansi") instead.'
7296
)
73-
lint_result = sqlfluff.lint(sql_file_str)
97+
result = sqlfluff_fn(sql_file_str)
98+
99+
result_callback(sql_file_path, result)
100+
74101

102+
def _fix_sql_files(
103+
dialect: str, include_rules: Optional[List[str]], exclude_rules: Optional[List[str]]
104+
) -> None:
105+
def result_callback(sql_file_path: pathlib.Path, fix_result: str) -> None:
106+
with open(sql_file_path, "w") as sql_file:
107+
sql_file.write(fix_result)
108+
109+
_process_sql_files(dialect, include_rules, exclude_rules, sqlfluff.fix, result_callback)
110+
111+
112+
def _lint_sql_files(
113+
dialect: str, include_rules: Optional[List[str]], exclude_rules: Optional[List[str]]
114+
) -> List[Tuple[pathlib.Path, List[LintResult]]]:
115+
lint_results = []
116+
117+
def result_callback(sql_file_path: pathlib.Path, lint_result: List[Dict[str, Any]]) -> None:
118+
nonlocal lint_results
75119
if len(lint_result) > 0:
76120
lint_results.append(
77121
(
78122
sql_file_path.relative_to(pathlib.Path.cwd()),
79123
cast(List[LintResult], lint_result),
80124
)
81125
)
126+
127+
_process_sql_files(dialect, include_rules, exclude_rules, sqlfluff.lint, result_callback)
82128
return lint_results
83129

84130

@@ -102,17 +148,58 @@ def _print_lint_results(lint_results: List[Tuple[pathlib.Path, List[LintResult]]
102148
click.echo("")
103149

104150

105-
def lint() -> None:
151+
def lint(fix: bool, include_rules: Optional[List[str]], exclude_rules: Optional[List[str]]) -> None:
106152
"""
107153
Lint and format SQL.
154+
155+
:param fix: Whether to lint and fix linting errors, or just lint.
156+
:type fix: bool
157+
:param include_rules: A subset of rules to lint with.
158+
:type include_rules: Optional[List[str]]
159+
:param exclude_rules: A subset of rules not to lint with.
160+
:type exclude_rules: Optional[List[str]]
108161
"""
162+
echo_info("Linting SQLs:")
109163
dialect = _get_dialect_or_default()
110-
lint_results = _lint_sql_files(dialect)
164+
if fix:
165+
echo_subinfo("Attempting to fix SQLs. Not every error can be automatically fixed.")
166+
_fix_sql_files(dialect, include_rules, exclude_rules)
167+
echo_subinfo("Linting SQLs.")
168+
lint_results = _lint_sql_files(dialect, include_rules, exclude_rules)
111169
_print_lint_results(lint_results)
112170
if len(lint_results) > 0:
113171
raise SQLLintError(list(map(lambda tup: str(tup[0]), lint_results)))
114172

115173

116-
@click.command(name="lint", help="Lint and format SQL")
117-
def lint_command() -> None:
118-
lint()
174+
@click.command(
175+
name="lint",
176+
short_help="Lint and format SQL",
177+
help="Lint and format SQL using SQLFluff.\n\n"
178+
"For more information on rules and the workings of SQLFluff, "
179+
"refer to https://docs.sqlfluff.com/",
180+
)
181+
@click.option(
182+
"--no-fix",
183+
is_flag=True,
184+
default=False,
185+
type=bool,
186+
help="Whether to lint and fix linting errors, or just lint.",
187+
)
188+
@click.option(
189+
"--rules",
190+
required=False,
191+
type=str,
192+
help="A subset of rules to lint with, as string of rules separated by a comma.",
193+
)
194+
@click.option(
195+
"--exclude-rules",
196+
required=False,
197+
type=str,
198+
help="A subset of rules not to lint with, as string of rules separated by a comma.",
199+
)
200+
def lint_command(no_fix: bool, rules: Optional[str], exclude_rules: Optional[str]) -> None:
201+
lint(
202+
not no_fix,
203+
rules.split(",") if rules else None,
204+
exclude_rules.split(",") if exclude_rules else None,
205+
)

tests/cli_commands/test_lint.py

Lines changed: 34 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,10 @@ def _lint_without_errors(self, sql_content, *_args, **_kwargs):
6363
self.linted_sqls.append(sql_content)
6464
return []
6565

66+
@staticmethod
67+
def _fix(sql_content: str, *_args, **_kwargs) -> str:
68+
return str(len(sql_content))
69+
6670
@patch("data_pipelines_cli.cli_commands.lint.BUILD_DIR", pathlib.Path("/a/b/c/d/e/f"))
6771
def test_lint_no_sqls(self):
6872
with tempfile.TemporaryDirectory() as tmp_dir, patch(
@@ -74,46 +78,71 @@ def test_lint_no_sqls(self):
7478
)
7579

7680
runner = CliRunner()
77-
result = runner.invoke(_cli, ["lint"])
81+
result = runner.invoke(_cli, ["lint", "--no-fix"])
7882
self.assertEqual(
7983
0, result.exit_code, msg="\n".join([str(result.exception), str(result.output)])
8084
)
8185

8286
@patch("data_pipelines_cli.cli_commands.lint.BUILD_DIR", pathlib.Path("/a/b/c/d/e/f"))
8387
def test_lint_sqls_with_errors(self):
88+
fix_mock = MagicMock()
8489
with patch("pathlib.Path.cwd", lambda: pathlib.Path(self.dbt_project_tmp_dir)), patch(
8590
"sqlfluff.lint", self._lint_with_errors
86-
):
91+
), patch("data_pipelines_cli.cli_commands.lint._fix_sql_files", fix_mock):
8792
runner = CliRunner()
88-
result = runner.invoke(_cli, ["lint"])
93+
result = runner.invoke(_cli, ["lint", "--no-fix"])
8994
self.assertEqual(
9095
1, result.exit_code, msg="\n".join([str(result.exception), str(result.output)])
9196
)
9297
self.assertIsInstance(result.exception, SQLLintError)
9398
self.assertSetEqual(set(map(lambda t: t[1], self.sqls)), set(self.linted_sqls))
9499
self.assertIn("L: 1 | P: 2 | C:L213 | description\n", result.output)
95100
self.assertIn("L: 2 | P: 4 | C:L137 | some other description\n", result.output)
101+
self.assertFalse(fix_mock.called)
96102

97103
@patch("data_pipelines_cli.cli_commands.lint.BUILD_DIR", pathlib.Path("/a/b/c/d/e/f"))
98104
def test_lint_sqls_with_errors_raises(self):
99105
with patch("pathlib.Path.cwd", lambda: pathlib.Path(self.dbt_project_tmp_dir)), patch(
100106
"sqlfluff.lint", self._lint_with_errors
101107
):
102108
with self.assertRaises(SQLLintError):
103-
lint()
109+
lint(False, None, None)
104110

105111
@patch("data_pipelines_cli.cli_commands.lint.BUILD_DIR", pathlib.Path("/a/b/c/d/e/f"))
106112
def test_lint_sqls_without_errors(self):
107113
with patch("pathlib.Path.cwd", lambda: pathlib.Path(self.dbt_project_tmp_dir)), patch(
108114
"sqlfluff.lint", self._lint_without_errors
109115
):
110116
runner = CliRunner()
111-
result = runner.invoke(_cli, ["lint"])
117+
result = runner.invoke(_cli, ["lint", "--no-fix"])
112118
self.assertEqual(
113119
0, result.exit_code, msg="\n".join([str(result.exception), str(result.output)])
114120
)
115121
self.assertSetEqual(set(map(lambda t: t[1], self.sqls)), set(self.linted_sqls))
116122

123+
@patch("data_pipelines_cli.cli_commands.lint.BUILD_DIR", pathlib.Path("/a/b/c/d/e/f"))
124+
def test_fix_sqls(self):
125+
with patch("pathlib.Path.cwd", lambda: pathlib.Path(self.dbt_project_tmp_dir)), patch(
126+
"sqlfluff.lint", lambda _sql_content, *_args, **_kwargs: []
127+
), patch("sqlfluff.fix", self._fix):
128+
runner = CliRunner()
129+
result = runner.invoke(_cli, ["lint"])
130+
self.assertEqual(
131+
0, result.exit_code, msg="\n".join([str(result.exception), str(result.output)])
132+
)
133+
134+
for sql_name, sql_content in self.sqls:
135+
with open(
136+
pathlib.Path(self.dbt_project_tmp_dir).joinpath("models", sql_name), "r"
137+
) as fixed_file:
138+
fixed_content = fixed_file.read()
139+
self.assertEqual(
140+
str(len(sql_content)),
141+
fixed_content,
142+
msg="Contents of an expected value and "
143+
f"the fixed {sql_name} file are different",
144+
)
145+
117146

118147
class LintHelperFunctionsTestCase(unittest.TestCase):
119148
def test_get_dialect(self):

0 commit comments

Comments
 (0)