11import pathlib
22import sys
3- from typing import List , Tuple , cast
3+ from typing import Any , Callable , Dict , List , Optional , Tuple , TypeVar , cast
44
55import click
66import sqlfluff
77import yaml
88from sqlfluff .core import SQLFluffUserError
99
1010from ..cli_constants import BUILD_DIR
11- from ..cli_utils import echo_error , echo_suberror , echo_subinfo , echo_warning
11+ from ..cli_utils import echo_error , echo_info , echo_suberror , echo_subinfo , echo_warning
1212from ..config_generation import read_dictionary_from_config_directory
1313from ..errors import SQLLintError
1414from ..func_utils import flat_map
1515
1616if sys .version_info >= (3 , 8 ):
17- from typing import TypedDict # pylint: disable=no-name-in-module
17+ from typing import Protocol , TypedDict # pylint: disable=no-name-in-module
1818else :
19- from typing_extensions import TypedDict
19+ from typing_extensions import Protocol , TypedDict
20+
21+
22+ T = TypeVar ("T" , covariant = True )
23+ S = TypeVar ("S" )
2024
2125
2226class LintResult (TypedDict ):
@@ -26,6 +30,19 @@ class LintResult(TypedDict):
2630 description : str
2731
2832
33+ # Utility class representing sqlfluff `lint` and `fix` types
34+ class SqlFluffCallable (Protocol [T ]):
35+ def __call__ (
36+ self ,
37+ sql : str ,
38+ dialect : str = "ansi" ,
39+ rules : Optional [List [str ]] = None ,
40+ exclude_rules : Optional [List [str ]] = None ,
41+ config_path : Optional [str ] = None ,
42+ ) -> T :
43+ ...
44+
45+
2946def _get_dialect_or_default () -> str :
3047 """Read ``dbt.yml`` config file and return its ``target_type`` or just the ``ansi``."""
3148 env , dbt_filename = "base" , "dbt.yml"
@@ -54,8 +71,13 @@ def _get_source_tests_paths() -> List[pathlib.Path]:
5471 return list (map (lambda dir_name : pathlib .Path .cwd ().joinpath (dir_name ), dir_names ))
5572
5673
57- def _lint_sql_files (dialect : str ) -> List [Tuple [pathlib .Path , List [LintResult ]]]:
58- lint_results = []
74+ def _process_sql_files (
75+ dialect : str ,
76+ include_rules : Optional [List [str ]],
77+ exclude_rules : Optional [List [str ]],
78+ sqlfluff_fn : SqlFluffCallable [S ],
79+ result_callback : Callable [[pathlib .Path , S ], None ],
80+ ) -> None :
5981 sql_file_paths : List [pathlib .Path ] = flat_map (
6082 lambda dir_path : dir_path .rglob ("*.sql" ), _get_source_tests_paths ()
6183 )
@@ -64,21 +86,45 @@ def _lint_sql_files(dialect: str) -> List[Tuple[pathlib.Path, List[LintResult]]]
6486 sql_file_str = sql_file .read ()
6587
6688 try :
67- lint_result = sqlfluff .lint (sql_file_str , dialect = dialect )
89+ result = sqlfluff_fn (
90+ sql_file_str , dialect = dialect , rules = include_rules , exclude_rules = exclude_rules
91+ )
6892 except SQLFluffUserError : # dialect does not exist, try default instead
6993 echo_warning (
7094 f"Dialect { dialect } did not get recognized. "
7195 'Linting using default one ("ansi") instead.'
7296 )
73- lint_result = sqlfluff .lint (sql_file_str )
97+ result = sqlfluff_fn (sql_file_str )
98+
99+ result_callback (sql_file_path , result )
100+
74101
102+ def _fix_sql_files (
103+ dialect : str , include_rules : Optional [List [str ]], exclude_rules : Optional [List [str ]]
104+ ) -> None :
105+ def result_callback (sql_file_path : pathlib .Path , fix_result : str ) -> None :
106+ with open (sql_file_path , "w" ) as sql_file :
107+ sql_file .write (fix_result )
108+
109+ _process_sql_files (dialect , include_rules , exclude_rules , sqlfluff .fix , result_callback )
110+
111+
112+ def _lint_sql_files (
113+ dialect : str , include_rules : Optional [List [str ]], exclude_rules : Optional [List [str ]]
114+ ) -> List [Tuple [pathlib .Path , List [LintResult ]]]:
115+ lint_results = []
116+
117+ def result_callback (sql_file_path : pathlib .Path , lint_result : List [Dict [str , Any ]]) -> None :
118+ nonlocal lint_results
75119 if len (lint_result ) > 0 :
76120 lint_results .append (
77121 (
78122 sql_file_path .relative_to (pathlib .Path .cwd ()),
79123 cast (List [LintResult ], lint_result ),
80124 )
81125 )
126+
127+ _process_sql_files (dialect , include_rules , exclude_rules , sqlfluff .lint , result_callback )
82128 return lint_results
83129
84130
@@ -102,17 +148,58 @@ def _print_lint_results(lint_results: List[Tuple[pathlib.Path, List[LintResult]]
102148 click .echo ("" )
103149
104150
105- def lint () -> None :
151+ def lint (fix : bool , include_rules : Optional [ List [ str ]], exclude_rules : Optional [ List [ str ]] ) -> None :
106152 """
107153 Lint and format SQL.
154+
155+ :param fix: Whether to lint and fix linting errors, or just lint.
156+ :type fix: bool
157+ :param include_rules: A subset of rules to lint with.
158+ :type include_rules: Optional[List[str]]
159+ :param exclude_rules: A subset of rules not to lint with.
160+ :type exclude_rules: Optional[List[str]]
108161 """
162+ echo_info ("Linting SQLs:" )
109163 dialect = _get_dialect_or_default ()
110- lint_results = _lint_sql_files (dialect )
164+ if fix :
165+ echo_subinfo ("Attempting to fix SQLs. Not every error can be automatically fixed." )
166+ _fix_sql_files (dialect , include_rules , exclude_rules )
167+ echo_subinfo ("Linting SQLs." )
168+ lint_results = _lint_sql_files (dialect , include_rules , exclude_rules )
111169 _print_lint_results (lint_results )
112170 if len (lint_results ) > 0 :
113171 raise SQLLintError (list (map (lambda tup : str (tup [0 ]), lint_results )))
114172
115173
116- @click .command (name = "lint" , help = "Lint and format SQL" )
117- def lint_command () -> None :
118- lint ()
174+ @click .command (
175+ name = "lint" ,
176+ short_help = "Lint and format SQL" ,
177+ help = "Lint and format SQL using SQLFluff.\n \n "
178+ "For more information on rules and the workings of SQLFluff, "
179+ "refer to https://docs.sqlfluff.com/" ,
180+ )
181+ @click .option (
182+ "--no-fix" ,
183+ is_flag = True ,
184+ default = False ,
185+ type = bool ,
186+ help = "Whether to lint and fix linting errors, or just lint." ,
187+ )
188+ @click .option (
189+ "--rules" ,
190+ required = False ,
191+ type = str ,
192+ help = "A subset of rules to lint with, as string of rules separated by a comma." ,
193+ )
194+ @click .option (
195+ "--exclude-rules" ,
196+ required = False ,
197+ type = str ,
198+ help = "A subset of rules not to lint with, as string of rules separated by a comma." ,
199+ )
200+ def lint_command (no_fix : bool , rules : Optional [str ], exclude_rules : Optional [str ]) -> None :
201+ lint (
202+ not no_fix ,
203+ rules .split ("," ) if rules else None ,
204+ exclude_rules .split ("," ) if exclude_rules else None ,
205+ )
0 commit comments