diff --git a/awsclilinter/MANIFEST.in b/awsclilinter/MANIFEST.in index 5de61f385a65..a7aeb71e7d91 100644 --- a/awsclilinter/MANIFEST.in +++ b/awsclilinter/MANIFEST.in @@ -1,4 +1,3 @@ include README.md -recursive-include awsclilinter *.py -recursive-exclude tests * +include LICENSE.txt recursive-exclude examples * diff --git a/awsclilinter/README.md b/awsclilinter/README.md index a9636356d408..3c71e60f6f96 100644 --- a/awsclilinter/README.md +++ b/awsclilinter/README.md @@ -67,7 +67,8 @@ In interactive mode, you can: - Press `y` to accept the current change - Press `n` to skip the current change - Press `u` to accept all remaining changes -- Press `s` to save and exit (applies all accepted changes so far) +- Press `s` to save and exit +- Press `q` to quit without saving ## Development diff --git a/awsclilinter/awsclilinter/cli.py b/awsclilinter/awsclilinter/cli.py index 5bc7d26ac981..96b4a6f5edb1 100644 --- a/awsclilinter/awsclilinter/cli.py +++ b/awsclilinter/awsclilinter/cli.py @@ -2,11 +2,15 @@ import difflib import sys from pathlib import Path -from typing import List +from typing import Dict, List, Optional, Tuple -from awsclilinter.linter import ScriptLinter -from awsclilinter.rules import LintFinding +from ast_grep_py.ast_grep_py import SgRoot + +from awsclilinter import linter +from awsclilinter.linter import parse +from awsclilinter.rules import LintFinding, LintRule from awsclilinter.rules.base64_rule import Base64BinaryFormatRule +from awsclilinter.rules.pagination_rule import PaginationRule # ANSI color codes RED = "\033[31m" @@ -18,13 +22,17 @@ CONTEXT_SIZE = 3 -def get_user_choice(prompt: str) -> str: +def prompt_user_choice_interactive_mode() -> str: """Get user input for interactive mode.""" while True: - choice = input(prompt).lower().strip() - if choice in ["y", "n", "u", "s"]: + choice = ( + input("\nApply this fix? [y]es, [n]o, [u]pdate all, [s]ave and exit, [q]uit: ") + .lower() + .strip() + ) + if choice in ["y", "n", "u", "s", "q"]: return choice - print("Invalid choice. Please enter y, n, u, or s.") + print("Invalid choice. Please enter y, n, u, s, or q.") def display_finding(finding: LintFinding, index: int, total: int, script_content: str): @@ -75,22 +83,89 @@ def display_finding(finding: LintFinding, index: int, total: int, script_content print(line, end="") -def interactive_mode(findings: List[LintFinding], script_content: str) -> List[LintFinding]: - """Run interactive mode and return accepted findings.""" - accepted = [] - for i, finding in enumerate(findings, 1): - display_finding(finding, i, len(findings), script_content) - choice = get_user_choice("\nApply this fix? [y]es, [n]o, [u]pdate all, [s]ave and exit: ") +def apply_all_fixes( + findings_with_rules: List[Tuple[LintFinding, LintRule]], + ast: SgRoot, +) -> str: + """Apply all fixes using rule-by-rule processing. + + Since multiple rules can target the same command, we must process one rule + at a time and re-parse the updated script between rules to get fresh Edit objects. + + Args: + findings_with_rules: List of findings and their rules. + ast: Current script represented as an AST. + + Returns: + Modified script content + """ + current_ast = ast + + # Group findings by rule + findings_by_rule: Dict[str, List[LintFinding]] = {} + for finding, rule in findings_with_rules: + if rule.name not in findings_by_rule: + findings_by_rule[rule.name] = [] + findings_by_rule[rule.name].append(finding) + + # Process one rule at a time, re-parsing between rules + for rule in findings_by_rule: + updated_script = linter.apply_fixes(current_ast, findings_by_rule[rule]) + current_ast = parse(updated_script) + return current_ast.root().text() + - if choice == "y": - accepted.append(finding) - elif choice == "u": - accepted.extend(findings[i - 1 :]) - break - elif choice == "s": - break +def interactive_mode_for_rule( + findings: List[LintFinding], + ast: SgRoot, + finding_offset: int, + total_findings: int, +) -> Tuple[SgRoot, bool, Optional[str]]: + """Run interactive mode for a single rule's findings. - return accepted + Args: + findings: List of findings for this rule. + ast: Current script content, represented as an AST. + finding_offset: Offset for display numbering. + total_findings: Total number of findings across all rules. + + Returns: + Tuple of (ast, changes_made, last_choice) + ast is the resulting AST from this interactive mode execution. + changes_made whether the AST was updated based on user choice. + last_choice is the last choice entered by the user. + """ + accepted_findings: List[LintFinding] = [] + last_choice: Optional[str] = None + + for i, finding in enumerate(findings): + display_finding(finding, finding_offset + i + 1, total_findings, ast.root().text()) + last_choice = prompt_user_choice_interactive_mode() + + if last_choice == "y": + accepted_findings.append(finding) + elif last_choice == "n": + pass # Skip this finding + elif last_choice == "u": + # Accept this and all remaining findings for this rule. + accepted_findings.extend(findings[i:]) + if accepted_findings: + ast = parse(linter.apply_fixes(ast, accepted_findings)) + return ast, True, last_choice + elif last_choice == "s": + # Apply accepted findings and stop processing + if accepted_findings: + ast = parse(linter.apply_fixes(ast, accepted_findings)) + return ast, len(accepted_findings) > 0, last_choice + elif last_choice == "q": + print("Quit without saving.") + sys.exit(0) + + if accepted_findings: + ast = parse(linter.apply_fixes(ast, accepted_findings)) + return ast, True, last_choice + + return ast, False, last_choice def main(): @@ -127,31 +202,80 @@ def main(): script_content = script_path.read_text() - rules = [Base64BinaryFormatRule()] - linter = ScriptLinter(rules) - findings = linter.lint(script_content) - - if not findings: - print("No issues found.") - return + rules = [Base64BinaryFormatRule(), PaginationRule()] if args.interactive: - findings = interactive_mode(findings, script_content) - if not findings: + # Do an initial parse-and-lint with all the rules simultaneously to compute the total + # number of findings for displaying progress in interactive mode. + current_ast = parse(script_content) + findings_with_rules = linter.lint(current_ast, rules) + + if not findings_with_rules: + print("No issues found.") + return + + # Process one rule at a time, re-parsing between rules + current_script = script_content + any_changes = False + finding_offset = 0 + + # Calculate total findings for display + total_findings = len(findings_with_rules) + + for rule_index, rule in enumerate(rules): + # Lint for this specific rule with current script state + rule_findings = linter.lint_for_rule(current_ast, rule) + + if not rule_findings: + continue + + current_ast, changes_made, last_choice = interactive_mode_for_rule( + rule_findings, current_ast, finding_offset, total_findings + ) + + if changes_made: + current_script = current_ast.root().text() + any_changes = True + + finding_offset += len(rule_findings) + + if last_choice == "s": + break + + # If user chose 'u', auto-apply all remaining rules + if last_choice == "u": + for remaining_rule in rules[rule_index + 1 :]: + remaining_findings = linter.lint_for_rule(current_ast, remaining_rule) + if remaining_findings: + current_script = linter.apply_fixes(current_ast, remaining_findings) + any_changes = True + break + + if not any_changes: print("No changes accepted.") return - if args.fix or args.output or args.interactive: - # Interactive mode is functionally equivalent to --fix, except the user - # can select a subset of the changes to apply. - fixed_content = linter.apply_fixes(script_content, findings) output_path = Path(args.output) if args.output else script_path - output_path.write_text(fixed_content) + output_path.write_text(current_script) + print(f"Fixed script written to: {output_path}") + elif args.fix or args.output: + current_ast = parse(script_content) + findings_with_rules = linter.lint(current_ast, rules) + updated_script = apply_all_fixes(findings_with_rules, current_ast) + output_path = Path(args.output) if args.output else script_path + output_path.write_text(updated_script) print(f"Fixed script written to: {output_path}") else: - print(f"\nFound {len(findings)} issue(s):\n") - for i, finding in enumerate(findings, 1): - display_finding(finding, i, len(findings), script_content) + current_ast = parse(script_content) + findings_with_rules = linter.lint(current_ast, rules) + + if not findings_with_rules: + print("No issues found.") + return + + print(f"\nFound {len(findings_with_rules)} issue(s):\n") + for i, (finding, _) in enumerate(findings_with_rules, 1): + display_finding(finding, i, len(findings_with_rules), script_content) print("\n\nRun with --fix to apply changes or --interactive to review each change.") diff --git a/awsclilinter/awsclilinter/linter.py b/awsclilinter/awsclilinter/linter.py index 25b7c311e26b..8f28928465dc 100644 --- a/awsclilinter/awsclilinter/linter.py +++ b/awsclilinter/awsclilinter/linter.py @@ -1,26 +1,53 @@ -from typing import List +from typing import List, Tuple -from ast_grep_py import SgRoot +from ast_grep_py.ast_grep_py import SgRoot from awsclilinter.rules import LintFinding, LintRule -class ScriptLinter: - """Linter for bash scripts to detect AWS CLI v1 to v2 migration issues.""" +def parse(script_content: str) -> SgRoot: + """Parse the bash script content into an AST.""" + return SgRoot(script_content, "bash") - def __init__(self, rules: List[LintRule]): - self.rules = rules - def lint(self, script_content: str) -> List[LintFinding]: - """Lint the script and return all findings.""" - root = SgRoot(script_content, "bash") - findings = [] - for rule in self.rules: - findings.extend(rule.check(root)) - return sorted(findings, key=lambda f: (f.line_start, f.line_end)) +def lint(ast: SgRoot, rules: List[LintRule]) -> List[Tuple[LintFinding, LintRule]]: + """Lint the AST and return all findings with their associated rules.""" + findings_with_rules = [] + for rule in rules: + findings = rule.check(ast) + for finding in findings: + findings_with_rules.append((finding, rule)) + return sorted(findings_with_rules, key=lambda fr: (fr[0].edit.start_pos, fr[0].edit.end_pos)) - def apply_fixes(self, script_content: str, findings: List[LintFinding]) -> str: - """Apply fixes to the script content.""" - root = SgRoot(script_content, "bash") - node = root.root() - return node.commit_edits([f.edit for f in findings]) + +def lint_for_rule(ast: SgRoot, rule: LintRule) -> List[LintFinding]: + """Lint the script for a single rule. + + Args: + ast: The AST to lint for the rule. + rule: The rule to check. + + Returns: + List of findings for this rule, sorted by position (ascending) + """ + findings = rule.check(ast) + return sorted(findings, key=lambda f: (f.edit.start_pos, f.edit.end_pos)) + + +def apply_fixes(ast: SgRoot, findings: List[LintFinding]) -> str: + """Apply to the AST for a single rule. + + Args: + ast: The AST representation of the script to apply fixes to. + findings: List of findings from a single rule to apply. + + Returns: + Modified script content + """ + root = ast.root() + if not findings: + return root.text() + + # Collect all edits - they should be non-overlapping within a single rule + edits = [f.edit for f in findings] + return root.commit_edits(edits) diff --git a/awsclilinter/awsclilinter/rules/base.py b/awsclilinter/awsclilinter/rules/base.py index 5900a2c09ff8..4d512784a6b2 100644 --- a/awsclilinter/awsclilinter/rules/base.py +++ b/awsclilinter/awsclilinter/rules/base.py @@ -34,5 +34,9 @@ def description(self) -> str: @abstractmethod def check(self, root: SgRoot) -> List[LintFinding]: - """Check the AST root for violations and return findings.""" + """Check the AST root for violations and return findings. + + Args: + root: The AST root to check + """ pass diff --git a/awsclilinter/awsclilinter/rules/base64_rule.py b/awsclilinter/awsclilinter/rules/base64_rule.py index 6da9ad61fcf8..980d275243e9 100644 --- a/awsclilinter/awsclilinter/rules/base64_rule.py +++ b/awsclilinter/awsclilinter/rules/base64_rule.py @@ -6,9 +6,8 @@ class Base64BinaryFormatRule(LintRule): - """Detects AWS CLI commands with file:// that need --cli-binary-format. This is a best-effort - attempt at statically detecting the breaking change with how AWS CLI v2 treats binary - parameters.""" + """Detects any AWS CLI command that does not specify the --cli-binary-format. This mitigates + the breaking change with how AWS CLI v2 treats binary parameters.""" @property def name(self) -> str: @@ -23,13 +22,12 @@ def description(self) -> str: ) def check(self, root: SgRoot) -> List[LintFinding]: - """Check for AWS CLI commands with file:// missing --cli-binary-format.""" + """Check for AWS CLI commands missing --cli-binary-format.""" node = root.root() base64_broken_nodes = node.find_all( all=[ {"kind": "command"}, {"pattern": "aws $SERVICE $OPERATION $$$ARGS"}, - {"has": {"kind": "word", "regex": r"\Afile://"}}, {"not": {"has": {"kind": "word", "pattern": "--cli-binary-format"}}}, ] ) diff --git a/awsclilinter/awsclilinter/rules/pagination_rule.py b/awsclilinter/awsclilinter/rules/pagination_rule.py new file mode 100644 index 000000000000..f7416f0c4943 --- /dev/null +++ b/awsclilinter/awsclilinter/rules/pagination_rule.py @@ -0,0 +1,50 @@ +from typing import List + +from ast_grep_py.ast_grep_py import SgRoot + +from awsclilinter.rules import LintFinding, LintRule + + +class PaginationRule(LintRule): + """Detects AWS CLI commands missing --no-cli-paginate flag.""" + + @property + def name(self) -> str: + return "no-cli-paginate" + + @property + def description(self) -> str: + return ( + "AWS CLI v2 uses pagination by default for commands that return large result sets. " + "Add --no-cli-paginate to disable pagination and match v1 behavior." + ) + + def check(self, root: SgRoot) -> List[LintFinding]: + """Check for AWS CLI commands missing --no-cli-paginate.""" + node = root.root() + nodes = node.find_all( + all=[ + {"kind": "command"}, + {"pattern": "aws $SERVICE $OPERATION $$$ARGS"}, + {"not": {"has": {"kind": "word", "pattern": "--no-cli-paginate"}}}, + ] + ) + + findings = [] + for stmt in nodes: + original = stmt.text() + suggested = original + " --no-cli-paginate" + edit = stmt.replace(suggested) + + findings.append( + LintFinding( + line_start=stmt.range().start.line, + line_end=stmt.range().end.line, + edit=edit, + original_text=original, + rule_name=self.name, + description=self.description, + ) + ) + + return findings diff --git a/awsclilinter/pyproject.toml b/awsclilinter/pyproject.toml index 3b070b8cfe53..bd891e3fe965 100644 --- a/awsclilinter/pyproject.toml +++ b/awsclilinter/pyproject.toml @@ -59,3 +59,6 @@ testpaths = ["tests"] python_files = ["test_*.py"] python_classes = ["Test*"] python_functions = ["test_*"] + +[tool.setuptools] +packages = ["awsclilinter"] \ No newline at end of file diff --git a/awsclilinter/tests/test_cli.py b/awsclilinter/tests/test_cli.py index 8e507a382008..3e43d066a61a 100644 --- a/awsclilinter/tests/test_cli.py +++ b/awsclilinter/tests/test_cli.py @@ -66,7 +66,9 @@ def test_fix_mode(self, tmp_path): with patch("sys.argv", ["upgrade-aws-cli", "--script", str(script_file), "--fix"]): main() fixed_content = script_file.read_text() + # 1 command, 2 rules = 2 flags added assert "--cli-binary-format" in fixed_content + assert "--no-cli-paginate" in fixed_content def test_output_mode(self, tmp_path): """Test output mode creates new file.""" @@ -82,7 +84,10 @@ def test_output_mode(self, tmp_path): ): main() assert output_file.exists() - assert "--cli-binary-format" in output_file.read_text() + content = output_file.read_text() + # 1 command, 2 rules = 2 flags added + assert "--cli-binary-format" in content + assert "--no-cli-paginate" in content def test_interactive_mode_accept_all(self, tmp_path): """Test interactive mode with 'y' to accept all changes.""" @@ -105,10 +110,13 @@ def test_interactive_mode_accept_all(self, tmp_path): str(output_file), ], ): - with patch("builtins.input", side_effect=["y", "y"]): + with patch("builtins.input", side_effect=["y", "y", "y", "y"]): main() fixed_content = output_file.read_text() + print(fixed_content) + # 2 commands, 2 rules = 4 findings, so 2 of each flag assert fixed_content.count("--cli-binary-format") == 2 + assert fixed_content.count("--no-cli-paginate") == 2 def test_interactive_mode_reject_all(self, tmp_path, capsys): """Test interactive mode with 'n' to reject all changes.""" @@ -146,7 +154,9 @@ def test_interactive_mode_update_all(self, tmp_path): with patch("builtins.input", return_value="u"): main() fixed_content = output_file.read_text() + # 2 commands, 2 rules = 4 findings, so 2 of each flag assert fixed_content.count("--cli-binary-format") == 2 + assert fixed_content.count("--no-cli-paginate") == 2 def test_interactive_mode_save_and_exit(self, tmp_path): """Test interactive mode with 's' to save and exit.""" @@ -173,4 +183,32 @@ def test_interactive_mode_save_and_exit(self, tmp_path): main() fixed_content = output_file.read_text() # Only first change should be applied since we pressed 's' on the second + # First finding is binary-params-base64 for cmd1 + assert "--cli-binary-format" in fixed_content assert fixed_content.count("--cli-binary-format") == 1 + + def test_interactive_mode_quit(self, tmp_path): + """Test interactive mode with 'q' to quit without saving.""" + script_file = tmp_path / "test.sh" + output_file = tmp_path / "output.sh" + script_file.write_text( + "aws secretsmanager put-secret-value --secret-id secret1213 --secret-binary file://data.json" + ) + + with patch( + "sys.argv", + [ + "upgrade-aws-cli", + "--script", + str(script_file), + "--interactive", + "--output", + str(output_file), + ], + ): + with patch("builtins.input", return_value="q"): + with pytest.raises(SystemExit) as exc_info: + main() + assert exc_info.value.code == 0 + # Output file should not exist since we quit without saving + assert not output_file.exists() diff --git a/awsclilinter/tests/test_linter.py b/awsclilinter/tests/test_linter.py index 1f3afb274259..86fb0fed0202 100644 --- a/awsclilinter/tests/test_linter.py +++ b/awsclilinter/tests/test_linter.py @@ -1,26 +1,28 @@ -from awsclilinter.linter import ScriptLinter +from awsclilinter import linter from awsclilinter.rules.base64_rule import Base64BinaryFormatRule -class TestScriptLinter: - """Test cases for ScriptLinter.""" +class TestLinter: + """Test cases for linter functions.""" def test_lint_finds_issues(self): """Test that linter finds issues in script.""" script = "aws secretsmanager put-secret-value --secret-id secret1213 --secret-binary file://data.json" - linter = ScriptLinter([Base64BinaryFormatRule()]) - findings = linter.lint(script) + ast = linter.parse(script) + findings_with_rules = linter.lint(ast, [Base64BinaryFormatRule()]) - assert len(findings) == 1 - assert findings[0].rule_name == "binary-params-base64" - assert "file://" in findings[0].original_text + assert len(findings_with_rules) == 1 + finding, rule = findings_with_rules[0] + assert finding.rule_name == "binary-params-base64" + assert "aws" in finding.original_text def test_apply_fixes(self): """Test that fixes are applied correctly.""" script = "aws secretsmanager put-secret-value --secret-id secret1213 --secret-binary file://data.json" - linter = ScriptLinter([Base64BinaryFormatRule()]) - findings = linter.lint(script) - fixed = linter.apply_fixes(script, findings) + ast = linter.parse(script) + findings_with_rules = linter.lint(ast, [Base64BinaryFormatRule()]) + findings = [f for f, _ in findings_with_rules] + fixed = linter.apply_fixes(ast, findings) assert "--cli-binary-format raw-in-base64-out" in fixed assert "file://data.json" in fixed @@ -33,7 +35,8 @@ def test_multiple_issues(self): " aws kinesis put-record --stream-name samplestream " "--data file://data --partition-key samplepartitionkey" ) - linter = ScriptLinter([Base64BinaryFormatRule()]) - findings = linter.lint(script) + ast = linter.parse(script) + findings_with_rules = linter.lint(ast, [Base64BinaryFormatRule()]) - assert len(findings) == 2 + # 2 commands, 1 rule = 2 findings + assert len(findings_with_rules) == 2 diff --git a/awsclilinter/tests/test_rules.py b/awsclilinter/tests/test_rules.py index dc886b4b6fb0..74c6ad66afd3 100644 --- a/awsclilinter/tests/test_rules.py +++ b/awsclilinter/tests/test_rules.py @@ -32,16 +32,3 @@ def test_no_detection_with_flag(self): findings = rule.check(root) assert len(findings) == 0 - - def test_no_detection_without_file_protocol(self): - """Test no detection when file:// is not used. Even though the breaking change may - still occur without the use of file://, only the case where file:// is used can be detected - statically.""" - script = ( - "aws secretsmanager put-secret-value --secret-id secret1213 --secret-string secret123" - ) - root = SgRoot(script, "bash") - rule = Base64BinaryFormatRule() - findings = rule.check(root) - - assert len(findings) == 0