def get_stdlib_path() -> Path:
    """
    Return the path of the standard library's `builtins.aaa` file.

    The containing folder is read from the AAA_STDLIB_PATH environment variable.
    Raises AaaEnvironmentError (chained to the KeyError) when it is not set.
    """
    try:
        stdlib_root = Path(os.environ["AAA_STDLIB_PATH"])
    except KeyError as e:
        message = (
            "Environment variable AAA_STDLIB_PATH is not set.\n"
            "Cannot find standard library!"
        )
        raise AaaEnvironmentError(message) from e

    return stdlib_root / "builtins.aaa"
def format_diff(before: str, after: str) -> str:  # pragma: nocover
    """
    Build a colorized unified diff between two versions of a text.

    Every diff line is wrapped in the color codes that DIFF_COLORS associates
    with its first character. The result always ends with a newline, so two
    identical inputs yield just "\n".
    """
    # Both inputs are split into newline-free lines, so the control lines
    # ("---", "+++", "@@") must be newline-free too. With the default
    # lineterm="\n" they would carry an embedded newline, which would end up
    # before the color reset code and add blank lines to the output.
    diff = unified_diff(before.split("\n"), after.split("\n"), lineterm="")

    output_lines: List[str] = []

    for diff_line in diff:
        prefix, suffix = DIFF_COLORS[diff_line[0]]
        output_lines.append(prefix + diff_line + suffix)

    return "\n".join(output_lines) + "\n"
def get_formatted(self, *, force_parse_as_builtins_file: bool = False) -> str:
    """
    Tokenize and parse `self.file` and return its formatted source code.

    Comment and whitespace tokens are stored on the instance; all remaining
    tokens are stored in `self.filtered_tokens` and handed to the parser.

    The file is parsed as the builtins file when `force_parse_as_builtins_file`
    is set or when it is the file returned by `get_stdlib_path()`, otherwise it
    is parsed as a regular file.
    """
    tokenizer = Tokenizer(self.file, self.verbose)

    tokens = tokenizer.tokenize_unfiltered()

    self.comments = [token for token in tokens if token.type == TokenType.COMMENT]

    self.whitespace = [
        token for token in tokens if token.type == TokenType.WHITESPACE
    ]

    self.filtered_tokens = [
        token
        for token in tokens
        if token.type not in [TokenType.COMMENT, TokenType.WHITESPACE]
    ]

    parser = SingleFileParser(self.file, self.filtered_tokens, self.verbose)

    # The forced flag is checked first, so AAA_STDLIB_PATH is only required
    # when it is actually needed. Both sides are resolved: comparing a resolved
    # path with a relative or symlinked stdlib path would never match.
    is_builtins_file = (
        force_parse_as_builtins_file
        or self.file.resolve() == get_stdlib_path().resolve()
    )

    if is_builtins_file:
        parsed_file = parser.parse_builtins_file()
    else:
        parsed_file = parser.parse_regular_file()

    # TODO re-add comments

    return self._format_parsed_file(parsed_file)
def _format_import(self, import_: Import) -> str:
    """
    Format one import statement, with its items sorted by original name.

    All items go on a single line when that line (newline included) fits in
    MAX_LINE_LENGTH. Otherwise every item gets its own indented line with a
    trailing comma.
    """
    header = f'from "{import_.source}" import'

    ordered_items = sorted(
        import_.imported_items, key=lambda imported: imported.original.name
    )

    # An alias is only written out when it differs from the original name.
    rendered = [
        entry.original.name
        if entry.original == entry.imported
        else f"{entry.original.name} as {entry.imported.name}"
        for entry in ordered_items
    ]

    single_line = f"{header} {', '.join(rendered)}\n"
    if len(single_line) <= MAX_LINE_LENGTH:
        return single_line

    return header + "\n" + "".join(f"{INDENT}{entry},\n" for entry in rendered)
enum_variant.associated_data: + return prefix + ",\n" + + associated_data: List[str] = [ + self._format_type_or_function_pointer(item) + for item in enum_variant.associated_data + ] + + # Special case: remove unnecessary brackets + if len(enum_variant.associated_data) == 1: + return prefix + " as " + associated_data[0] + ",\n" + + one_line = prefix + " as { " + ", ".join(associated_data) + " },\n" + + if len(one_line) <= MAX_LINE_LENGTH: + return one_line + + return ( + prefix + + " as {\n" + + "".join(f"{2*INDENT}{item},\n" for item in associated_data) + + f"{INDENT}}}\n" + ) + + def _format_enum_definition(self, enum: Enum) -> str: + code = f"enum {enum.identifier.name} {{\n" + + for variant in enum.variants: + code += self._format_enum_variant(variant) + + code += "}\n" + + return code + + def _format_function_declaration(self, function: Function) -> str: + prefix = "fn " + + formatted_params = "" + if function.type_params: + formatted_params = ( + "[" + + ", ".join(param.identifier.name for param in function.type_params) + + "]" + ) + + if function.struct_name: + prefix += f"{function.struct_name.name}{formatted_params}:{function.func_name.name}" + + else: + prefix += f"{function.func_name.name}{formatted_params}" + + argument_items: List[str] = [] + for argument in function.arguments: + argument_type = self._format_type_or_function_pointer(argument.type) + argument_items.append(f"{argument.identifier.name} as {argument_type}") + + if isinstance(function.return_types, Never): + return_type_items = ["never"] + else: + return_type_items = [ + self._format_type_or_function_pointer(return_type) + for return_type in function.return_types + ] + + single_line = prefix + if argument_items: + single_line += " args " + ", ".join(argument_items) + + if return_type_items: + single_line += " return " + ", ".join(return_type_items) + + if len(single_line) <= MAX_LINE_LENGTH: + return single_line + + multi_line_arguments = "" + multi_line_return_types = "" + + if 
argument_items: + multi_line_arguments = f"{INDENT}args " + ", ".join(argument_items) + "\n" + + if len(multi_line_arguments) > MAX_LINE_LENGTH: + multi_line_arguments = f"{INDENT}args\n" + for item in argument_items: + multi_line_arguments += f"{2*INDENT}{item},\n" + + if return_type_items: + multi_line_return_types = ( + f"{INDENT}return " + ", ".join(return_type_items) + "\n" + ) + + if len(multi_line_return_types) > MAX_LINE_LENGTH: + multi_line_return_types = f"{INDENT}return\n" + for item in return_type_items: + multi_line_return_types += f"{2*INDENT}{item},\n" + + multi_line = prefix + "\n" + multi_line_arguments + multi_line_return_types + + if multi_line.endswith("\n"): + multi_line = multi_line[:-1] + + return multi_line + + def _format_function_body_as_one_liner(self, body: FunctionBody) -> str: + return self._format_function_body_non_block_items( + body.items, 0, force_one_liner=True + ) + + def _get_non_block_items_slice_end(self, body: FunctionBody, index: int) -> int: + while True: + if index >= len(body.items): + return index + + if type(body.items[index]) in self.block_formatters: + return index + + index += 1 + + def _format_function_body(self, body: FunctionBody, indent_level: int) -> str: + index = 0 + code = "" + + while True: + try: + item = body.items[index] + except IndexError: + break + + if index > 0: + token = Token(item.position, TokenType.WHITESPACE, "") + prev_token_index = bisect_left(self.filtered_tokens, token) - 1 + prev_token = self.filtered_tokens[prev_token_index] + + # We keep at most one empty line between items. 
def _format_function_body_non_block_items(
    self,
    items: List[FunctionBodyItem],
    indent_level: int,
    *,
    force_one_liner: bool = False,
) -> str:
    """
    Format a run of function body items that are not blocks.

    With `force_one_liner` the formatted items are joined by single spaces and
    returned without indentation or trailing newline; OneLinerError is raised
    when an item has no non-block formatter.

    Otherwise the items are laid out over indented lines. Items that shared a
    source line stay together while the line fits, a source line break is
    kept, and any larger gap is reduced to one empty line.
    """
    indent = indent_level * INDENT
    rendered: List[str] = []

    for item in items:
        try:
            render = self.non_block_formatters[type(item)]
        except KeyError:
            if force_one_liner:
                raise OneLinerError
            raise NotImplementedError  # pragma: nocover

        rendered.append(render(item, indent_level))

    if force_one_liner:
        return " ".join(rendered)

    finished_lines: List[str] = []
    current = indent + rendered[0]

    for previous_item, item, text in zip(items, items[1:], rendered[1:]):
        # This works because WHITESPACE is the only Token that can contain newlines
        line_gap = item.position.line - previous_item.position.line

        # NOTE(review): this threshold is one character stricter than the
        # `<= MAX_LINE_LENGTH` checks used by other formatters - confirm intent.
        if line_gap == 0 and len(current + " " + text + "\n") < MAX_LINE_LENGTH:
            current += " " + text
            continue

        finished_lines.append(current + "\n")
        if line_gap not in (0, 1):
            finished_lines.append("\n")
        current = indent + text

    finished_lines.append(current + "\n")
    return "".join(finished_lines)
def _format_branch(self, item: Branch, indent_level: int) -> str:
    """
    Format an if / else branch.

    The condition is placed on the `if` line when its formatted form is a
    single line and the resulting line fits in MAX_LINE_LENGTH. Otherwise the
    condition is written as an indented body between `if` and `{`. The else
    part is only emitted when `item.else_body` is truthy.
    """
    indent = indent_level * INDENT

    # Formatted once and reused for both the one-line and multi-line forms.
    formatted_condition = self._format_function_body(
        item.condition, indent_level + 1
    )

    formatted_if = indent + "if\n" + formatted_condition + indent + "{\n"

    stripped_condition = formatted_condition.strip()
    if "\n" not in stripped_condition:
        one_line_if = indent + "if " + stripped_condition + " {\n"
        if len(one_line_if) <= MAX_LINE_LENGTH:
            formatted_if = one_line_if

    code = formatted_if + self._format_function_body(item.if_body, indent_level + 1)

    if item.else_body:
        code += (
            indent
            + "} else {\n"
            + self._format_function_body(item.else_body, indent_level + 1)
        )

    code += indent + "}\n"

    return code
def _format_match_case_block(self, item: CaseBlock, indent_level: int) -> str:
    """
    Format a `case` block of a match statement.

    The block is written on one line when its body can be formatted as a
    one-liner and the line fits in MAX_LINE_LENGTH. Otherwise the body is
    written as an indented block.
    """
    indent = indent_level * INDENT
    label = item.label

    header = f"{indent}case {label.enum_name.name}:{label.variant_name.name}"
    if label.variables:
        names = ", ".join(variable.name for variable in label.variables)
        header = f"{header} as {names}"
    header += " {"

    try:
        body_one_liner = self._format_function_body_as_one_liner(item.body)
    except OneLinerError:
        body_one_liner = None

    if body_one_liner is not None:
        candidate = f"{header} {body_one_liner} }}\n"
        if len(candidate) <= MAX_LINE_LENGTH:
            return candidate

    body = self._format_function_body(item.body, indent_level + 1)
    return f"{header}\n{body}{indent}}}\n"
def _format_match_block(self, item: MatchBlock, indent_level: int) -> str:
    """
    Format a match statement: the `match {` line, every case / default block
    one level deeper (in source order) and the closing brace.
    """
    indent = indent_level * INDENT
    parts = [indent + "match {\n"]

    for block in item.blocks:
        if not isinstance(block, CaseBlock):
            assert isinstance(block, DefaultBlock)
            parts.append(self._format_match_default_block(block, indent_level + 1))
            continue

        parts.append(self._format_match_case_block(block, indent_level + 1))

    parts.append(indent + "}\n")
    return "".join(parts)
def as_aaa_literal(self) -> str:
    """
    Return `self.value` as a double-quoted literal, based on Python's repr().

    A repr that already uses double quotes is returned unchanged. A repr that
    uses single quotes gets double quotes instead, and every double quote
    inside it is escaped with a backslash.
    """
    # TODO add test for this
    # NOTE(review): other escapes produced by repr() (such as \' or \x00) are
    # kept as-is - confirm they are valid in Aaa.
    python_repr = repr(self.value)

    if not python_repr.startswith("'"):
        return python_repr

    inner = python_repr[1:-1]
    return '"{}"'.format(inner.replace('"', '\\"'))
def __lt__(self, other: "Token") -> bool:
    """
    Order tokens by their position.

    Returns NotImplemented for operands that are not tokens, so Python can try
    the reflected comparison or raise TypeError instead of failing with an
    AttributeError on `other.position`.
    """
    if not isinstance(other, Token):
        return NotImplemented

    return self.position < other.position
def format_aaa_source(code: str) -> str:
    """
    Format `code` as a regular Aaa source file and return the result.

    The code is written to a temporary file because AaaFormatter works on
    files. The handle is closed before the file is written and read, and the
    file is removed afterwards so test runs do not leave temp files behind.
    """
    with NamedTemporaryFile(delete=False) as temp_file:
        # temp_file.name is already an absolute path.
        file = Path(temp_file.name)

    try:
        file.write_text(code)
        return AaaFormatter(file).get_formatted()
    finally:
        file.unlink()
NamedTemporaryFile(delete=False) + file = Path(gettempdir()) / temp_file.name + file.write_text(code) + return AaaFormatter(file).get_formatted(force_parse_as_builtins_file=True) + + +def test_format_empty_file() -> None: + assert "" == format_aaa_source("") + + +# TODO prevent string reformatting in this file + + +@pytest.mark.parametrize( + ["expected_format", "code"], + [ + ('from "a" import b\n', 'from "a" import b'), + ('from "a" import b\n', 'from "a" import b '), + ('from "a" import b\n', ' from "a" import b'), + ('from "a" import b\n', '\tfrom\n"a" import b\t'), + ('from "a" import b, c\n', 'from "a" import b from "a" import c'), + ('from "a" import b, c\n', 'from "a" import c from "a" import b'), + ( + 'from "a" import\n' + + " aaaaaaaaa as bbbbbbbbb,\n" + + " ccccccccc as ddddddddd,\n" + + " eeeeeeeee as fffffffff,\n" + + " ggggggggg,\n", + 'from "a" import aaaaaaaaa as bbbbbbbbb, ccccccccc as ddddddddd, eeeeeeeee as fffffffff, ggggggggg', + ), + ], +) +def test_format_imports(code: str, expected_format: str) -> None: + assert expected_format == format_aaa_source(code) + + +@pytest.mark.parametrize( + ["expected_format", "code"], + [ + ("struct Foo {}\n", "struct Foo { }"), + ("struct Foo {\n value as int,\n}\n", "struct Foo { value as int }"), + ( + "struct Foo {\n func as fn[int][bool],\n}\n", + "struct Foo { func as fn [ int ] [ bool ] }", + ), + ], +) +def test_format_struct(code: str, expected_format: str) -> None: + assert expected_format == format_aaa_source(code) + + +@pytest.mark.parametrize( + ["expected_format", "code"], + [ + ("enum Foo {\n value as int,\n}\n", "enum Foo { value as int }"), + ("enum Foo {\n value as int,\n}\n", "enum Foo { value as { int } }"), + ( + "enum Foo {\n value as { int, bool },\n}\n", + "enum Foo { value as { int, bool } }", + ), + ( + "enum Foo {\n" + + " value as {\n" + + " int,\n" + + " bool,\n" + + " VeryLongStructType,\n" + + " VeryLongStructType,\n" + + " VeryLongStructType,\n" + + " VeryLongStructType,\n" + + 
" }\n" + + "}\n", + "enum Foo { value as { int, bool, VeryLongStructType, VeryLongStructType, VeryLongStructType, VeryLongStructType } }", + ), + ("enum Foo {\n value,\n}\n", "enum Foo { value }"), + ], +) +def test_format_enum(code: str, expected_format: str) -> None: + assert expected_format == format_aaa_source(code) + + +@pytest.mark.parametrize( + ["expected_format", "code"], + [ + ("fn foo\n", "fn foo"), + ("fn foo args bar as int\n", "fn foo args bar as int"), + ("fn foo args bar as const int\n", "fn foo args bar as const int"), + ( + "fn foo args bar as int, bar as int\n", + "fn foo args bar as int , bar as int", + ), + ("fn foo return int\n", "fn foo return int"), + ("fn foo return int, int\n", "fn foo return int , int"), + ("fn foo[T]\n", "fn foo [ T ]"), + ( + "fn Bar:foo args bar as Bar return int\n", + "fn Bar : foo args bar as Bar return int", + ), + ( + "fn foo args foo as int return never\n", + "fn foo args foo as int return never", + ), + ( + "fn foo\n" + + " args a as VeryLongStructType, b as VeryLongStructType, c as VeryLongStructType\n" + + " return int\n", + "fn foo args a as VeryLongStructType, b as VeryLongStructType, c as VeryLongStructType return int", + ), + ( + "fn foo\n" + + " args\n" + + " a as VeryLongStructType,\n" + + " b as VeryLongStructType,\n" + + " c as VeryLongStructType,\n" + + " d as VeryLongStructType,\n" + + " return int\n", + "fn foo args a as VeryLongStructType, b as VeryLongStructType, c as VeryLongStructType, d as VeryLongStructType return int", + ), + ( + "fn foo\n" + + " args a as int\n" + + " return VeryLongStructType, VeryLongStructType, VeryLongStructType, int\n", + "fn foo args a as int return VeryLongStructType, VeryLongStructType, VeryLongStructType, int", + ), + ( + "fn foo\n" + + " args a as int\n" + + " return\n" + + " VeryLongStructType,\n" + + " VeryLongStructType,\n" + + " VeryLongStructType,\n" + + " VeryLongStructType,\n", + "fn foo args a as int return VeryLongStructType, VeryLongStructType, 
VeryLongStructType, VeryLongStructType", + ), + ], +) +def test_format_function_declaration(code: str, expected_format: str) -> None: + assert expected_format == format_aaa_builtins(code) + + +@pytest.mark.parametrize( + ["expected_format", "code"], + [ + ( + "fn foo {\n 3\n}\n", + "fn foo { 3 }", + ), + ( + "fn foo {\n false\n}\n", + "fn foo { false }", + ), + ( + 'fn foo {\n "Hello world"\n}\n', + 'fn foo { "Hello world" }', + ), + ( + "fn foo {\n call\n}\n", + "fn foo { call }", + ), + ( + "fn foo {\n return\n}\n", + "fn foo { return }", + ), + ( + 'fn foo {\n "foo" fn\n}\n', + 'fn foo { "foo" fn }', + ), + ( + 'fn foo {\n "foo" ?\n}\n', + 'fn foo { "foo" ? }', + ), + ], +) +def test_format_function_simple(code: str, expected_format: str) -> None: + assert expected_format == format_aaa_source(code) + + +@pytest.mark.parametrize( + ["expected_format", "code"], + [ + ( + "fn foo {\n 3 false\n}\n", + "fn foo { 3 false }", + ), + ( + "fn foo {\n 3\n false\n}\n", + "fn foo { 3\nfalse }", + ), + ( + "fn foo {\n 3\n\n false\n}\n", + "fn foo { 3\n\nfalse }", + ), + ( + "fn foo {\n 3\n\n false\n}\n", + "fn foo { 3\n\n\nfalse }", + ), + ( + "fn foo {\n" + + " 3\n" + + " if true {\n" + + " nop\n" + + " }\n" + + "}\n", + "fn foo { 3 if true { nop } }", + ), + ( + "fn foo {\n" + + " 3\n" + + " if true {\n" + + " nop\n" + + " }\n" + + "}\n", + "fn foo { 3\nif true { nop } }", + ), + ( + "fn foo {\n" + + " 3\n" + + "\n" + + " if true {\n" + + " nop\n" + + " }\n" + + "}\n", + "fn foo { 3\n\nif true { nop } }", + ), + ], +) +def test_format_function_body_newlines(code: str, expected_format: str) -> None: + assert expected_format == format_aaa_source(code) + + +@pytest.mark.parametrize( + ["expected_format", "code"], + [ + ( + "fn foo {\n x <- { 3 }\n}\n", + "fn foo { x <- { 3 } }", + ), + ( + "fn foo {\n x, y <- { 3 4 }\n}\n", + "fn foo { x , y <- { 3 4 } }", + ), + ( + "fn foo {\n" + + " x, y <- {\n" + + " if true {\n" + + " nop\n" + + " }\n" + + " }\n" + + "}\n", + "fn foo { x 
, y <- { if true { nop } } }", + ), + ], +) +def test_format_function_with_assignment(code: str, expected_format: str) -> None: + assert expected_format == format_aaa_source(code) + + +@pytest.mark.parametrize( + ["expected_format", "code"], + [ + ( + "fn foo {\n" + " if true {\n" + " nop\n" + " }\n" + "}\n", + "fn foo { if true { nop } }", + ), + ( + "fn foo {\n" + + " if\n" + + " very_long_function very_long_function very_long_function very_long_function\n" + + " very_long_function\n" + + " {\n" + + " nop\n" + + " }\n" + + "}\n", + "fn foo { if very_long_function very_long_function very_long_function very_long_function very_long_function { nop } }", + ), + ( + "fn foo {\n" + + " if true {\n" + + " nop\n" + + " } else {\n" + + " nop\n" + + " }\n" + + "}\n", + "fn foo { if true { nop } else { nop } }", + ), + ( + "fn foo {\n" + + " if\n" + + " very_long_function very_long_function very_long_function very_long_function\n" + + " very_long_function\n" + + " {\n" + + " nop\n" + + " } else {\n" + + " nop\n" + + " }\n" + + "}\n", + "fn foo { if very_long_function very_long_function very_long_function very_long_function very_long_function { nop } else { nop } }", + ), + ], +) +def test_format_function_with_branch(code: str, expected_format: str) -> None: + assert expected_format == format_aaa_source(code) + + +@pytest.mark.parametrize( + ["expected_format", "code"], + [ + ( + "fn foo {\n" + " foreach {\n" + " nop\n" + " }\n" + "}\n", + "fn foo { foreach { nop } }", + ), + ], +) +def test_format_function_with_foreach(code: str, expected_format: str) -> None: + assert expected_format == format_aaa_source(code) + + +@pytest.mark.parametrize( + ["expected_format", "code"], + [ + ( + "fn foo {\n" + " foo\n" + "}\n", + "fn foo { foo }", + ), + ( + "fn foo {\n" + " foo[T]\n" + "}\n", + "fn foo { foo [ T ] }", + ), + ( + "fn foo {\n" + " bar:foo\n" + "}\n", + "fn foo { bar : foo }", + ), + ], +) +def test_format_function_with_function_call(code: str, expected_format: str) -> 
None: + assert expected_format == format_aaa_source(code) + + +@pytest.mark.parametrize( + ["expected_format", "code"], + [ + ( + "fn foo {\n" + " fn[][]\n" + "}\n", + "fn foo { fn [ ] [ ] }", + ), + ( + "fn foo {\n" + " fn[][never]\n" + "}\n", + "fn foo { fn [ ] [ never ] }", + ), + ( + "fn foo {\n" + " fn[][int]\n" + "}\n", + "fn foo { fn [ ] [ int ] }", + ), + ( + "fn foo {\n" + " fn[int, int][str, str]\n" + "}\n", + "fn foo { fn [ int , int ] [ str, str ] }", + ), + ], +) +def test_format_function_with_function_type_literal( + code: str, expected_format: str +) -> None: + assert expected_format == format_aaa_source(code) + + +@pytest.mark.parametrize( + ["expected_format", "code"], + [ + ( + "fn foo {\n" + + " match {\n" + + " case Foo:bar { nop }\n" + + " default { nop }\n" + + " }\n" + + "}\n", + "fn foo { match { case Foo:bar { nop } default { nop } } }", + ), + ( + "fn foo {\n" + + " match {\n" + + " case Foo:bar {\n" + + " very_long_function very_long_function very_long_function\n" + + " very_long_function\n" + + " }\n" + + " default {\n" + + " very_long_function very_long_function very_long_function\n" + + " very_long_function\n" + + " }\n" + + " }\n" + + "}\n", + "fn foo { match { case Foo:bar { " + + "very_long_function very_long_function very_long_function very_long_function " + + "} default { " + + " very_long_function very_long_function very_long_function very_long_function " + + "} } }", + ), + ( + "fn foo {\n" + + " match {\n" + + " case Foo:bar as foo, foo, foo { nop }\n" + + " }\n" + + "}\n", + "fn foo { match { case Foo:bar as foo , foo, foo { nop } } }", + ), + ( + "fn foo {\n" + + " match {\n" + + " case Foo:bar {\n" + + " if true {\n" + + " nop\n" + + " }\n" + + " }\n" + + " }\n" + + "}\n", + "fn foo { match { case Foo:bar { if true { nop } } } }", + ), + ( + "fn foo {\n" + + " match {\n" + + " default {\n" + + " if true {\n" + + " nop\n" + + " }\n" + + " }\n" + + " }\n" + + "}\n", + "fn foo { match { default { if true { nop } } } }", + ), + 
], +) +def test_format_function_with_match(code: str, expected_format: str) -> None: + assert expected_format == format_aaa_source(code) + + +@pytest.mark.parametrize( + ["expected_format", "code"], + [ + ( + "fn foo {\n" + ' foo "foo" { 3 } !\n' + "}\n", + 'fn foo { foo "foo" { 3 } ! }', + ), + ( + "fn foo {\n" + + " foo\n" + + ' "foo" {\n' + + " very_long_function very_long_function very_long_function very_long_function\n" + + " very_long_function\n" + + " } !\n" + + "}\n", + 'fn foo { foo "foo" { very_long_function very_long_function very_long_function very_long_function very_long_function } ! }', + ), + ( + "fn foo {\n" + + ' foo "foo" {\n' + + " if true {\n" + + " nop\n" + + " }\n" + + " } !\n" + + "}\n", + 'fn foo { foo "foo" { if true { nop } } ! }', + ), + ], +) +def test_format_function_with_struct_field_update( + code: str, expected_format: str +) -> None: + assert expected_format == format_aaa_source(code) + + +@pytest.mark.parametrize( + ["expected_format", "code"], + [ + ( + "fn foo {\n" + " use x {\n" + " nop\n" + " }\n" + "}\n", + "fn foo { use x { nop } }", + ), + ( + "fn foo {\n" + + " use\n" + + " very_long_var_name,\n" + + " very_long_var_name,\n" + + " very_long_var_name,\n" + + " very_long_var_name,\n" + + " {\n" + + " nop\n" + + " }\n" + + "}\n", + "fn foo { use very_long_var_name, very_long_var_name, very_long_var_name, very_long_var_name { nop } }", + ), + ], +) +def test_format_function_with_use_block(code: str, expected_format: str) -> None: + assert expected_format == format_aaa_source(code) + + +@pytest.mark.parametrize( + ["expected_format", "code"], + [ + ( + "fn foo {\n" + " while true {\n" + " nop\n" + " }\n" + "}\n", + "fn foo { while true { nop } }", + ), + ( + "fn foo {\n" + + " while\n" + + " if true {\n" + + " nop\n" + + " }\n" + + " {\n" + + " nop\n" + + " }\n" + + "}\n", + "fn foo { while if true { nop } { nop } }", + ), + ], +) +def test_format_function_with_while_loop(code: str, expected_format: str) -> None: + assert 
expected_format == format_aaa_source(code) From d243426ecd30dbce668c346691a48340d0bff66f Mon Sep 17 00:00:00 2001 From: Luuk Verweij Date: Sun, 3 Sep 2023 21:26:59 +0200 Subject: [PATCH 2/3] setup and document reformat Aaa source files on save --- README.md | 23 +- aaa/cross_referencer/cross_referencer.py | 6 +- aaa/formatter.py | 3 +- aaa/parser/models.py | 9 + aaa/parser/single_file_parser.py | 30 +- aaa/tokenizer/tokenizer.py | 5 +- examples/one_to_ten.aaa | 1 - manage.py | 12 +- tests/test_formatter.py | 1 + tests/test_single_file_parser.py | 392 +++++++++++++++++++++++ 10 files changed, 466 insertions(+), 16 deletions(-) diff --git a/README.md b/README.md index 3d6ba679..ee0a03b0 100644 --- a/README.md +++ b/README.md @@ -2,11 +2,12 @@ Stack-based language, like forth. ### Contents -The following tools for the Aaa language can be found is this repo +The following tools for the Aaa language can be found in this repository: * A [tokenizer](./aaa/tokenizer/) and [parser](./aaa/parser/) for Aaa. * A [type checker](./aaa/type_checker/) * A [transpiler to Rust](./aaa/transpiler/) * A [VS Code extension](./aaa-vscode-extension/README.md) for the Aaa language. +* A [formatter](./aaa/formatter.py) for Aaa source files. * A lot of tests, written both in python and Aaa ### Examples @@ -85,6 +86,26 @@ Now you can start running code in Aaa or develop the language! To enable syntax highlighting for VS Code, enable the [Aaa language extension](./aaa-vscode-extension/README.md) +### Setup format on save +For VS Code follow these steps: +* Install the [Run on Save](https://marketplace.visualstudio.com/items?itemName=pucelle.run-on-save) extension. +* Merge the JSON configuration below into your `*.code-workspace` JSON file. 
+ +```json +{ + "settings": { + "runOnSave.commands": [ + { + "match": ".*\\.aaa$", + "command": "cd ${workspaceFolder} && pdm run ./manage.py format --stdlib-path ./stdlib ${file} --fix", + "runIn": "backend", + "runningStatusMessage": "Formatting ${fileBasename}", + "finishStatusMessage": "${fileBasename} was formatted" + }, + ] + } +} +``` ### Aaa and porth After watching part of the [Youtube series](https://www.youtube.com/playlist?list=PLpM-Dvs8t0VbMZA7wW9aR3EtBqe2kinu4) on [porth](https://gitlab.com/tsoding/porth), I wanted to make my own stack-based language. Aaa and porth have some similarities, but obviously are not compatible with each other. No code was copied over from the porth repo. diff --git a/aaa/cross_referencer/cross_referencer.py b/aaa/cross_referencer/cross_referencer.py index 3b007a09..d8d8dded 100644 --- a/aaa/cross_referencer/cross_referencer.py +++ b/aaa/cross_referencer/cross_referencer.py @@ -744,7 +744,9 @@ def run(self) -> FunctionBody: def _resolve_function_body(self, parsed_body: parser.FunctionBody) -> FunctionBody: return FunctionBody( items=[ - self._resolve_function_body_item(item) for item in parsed_body.items + self._resolve_function_body_item(item) + for item in parsed_body.items + if type(item) != parser.Comment ], parsed=parsed_body, ) @@ -856,7 +858,7 @@ def _resolve_function_body_item( parser.WhileLoop: self._resolve_while_loop, } - assert set(resolve_functions.keys()) == set(parser.FunctionBodyItem.__args__) # type: ignore + assert set(resolve_functions.keys()) == set(parser.FunctionBodyItem.__args__) - {parser.Comment} # type: ignore return resolve_functions[type(parsed_item)](parsed_item) def _resolve_function_pointer_literal( diff --git a/aaa/formatter.py b/aaa/formatter.py index 9fd65f3d..5550d135 100644 --- a/aaa/formatter.py +++ b/aaa/formatter.py @@ -494,9 +494,8 @@ def _format_function_body(self, body: FunctionBody, indent_level: int) -> str: index += 1 else: next_block_index = 
self._get_non_block_items_slice_end(body, index) - - # TODO move computation of `non_block_items` to separate function non_block_items = body.items[index:next_block_index] + formatted_items = self._format_function_body_non_block_items( non_block_items, indent_level ) diff --git a/aaa/parser/models.py b/aaa/parser/models.py index 9ae7f4d7..3dae21da 100644 --- a/aaa/parser/models.py +++ b/aaa/parser/models.py @@ -195,11 +195,13 @@ def __init__( imports: List[Import], structs: List[Struct], enums: List[Enum], + comments: List[Comment], ) -> None: self.functions = functions self.imports = imports self.structs = structs self.enums = enums + self.comments = comments super().__init__(position) def dependencies(self) -> List[Path]: @@ -361,6 +363,12 @@ def __init__( super().__init__(position) +class Comment(AaaParseModel): + def __init__(self, position: Position, value: str) -> None: + self.value = value + super().__init__(position) + + class Never(AaaParseModel): ... @@ -370,6 +378,7 @@ class Never(AaaParseModel): | BooleanLiteral | Branch | Call + | Comment | ForeachLoop | FunctionCall | FunctionPointerTypeLiteral diff --git a/aaa/parser/single_file_parser.py b/aaa/parser/single_file_parser.py index 4fb383a5..dbae308b 100644 --- a/aaa/parser/single_file_parser.py +++ b/aaa/parser/single_file_parser.py @@ -16,6 +16,7 @@ Call, CaseBlock, CaseLabel, + Comment, DefaultBlock, Enum, EnumVariant, @@ -422,6 +423,15 @@ def _parse_function_declaration(self, offset: int) -> Tuple[Function, int]: self._print_parse_tree_node("FunctionDeclaration", start_offset, offset) return function, offset + def _parse_comment(self, offset: int) -> Tuple[Comment, int]: + start_offset = offset + token, offset = self._parse_token(offset, [TokenType.COMMENT]) + + comment = Comment(token.position, token.value) + self._print_parse_tree_node("Comment", start_offset, offset) + + return comment, offset + def _parse_builtins_file_root(self, offset: int) -> Tuple[ParsedFile, int]: start_offset = offset @@ 
-434,13 +444,16 @@ def _parse_builtins_file_root(self, offset: int) -> Tuple[ParsedFile, int]: functions: List[Function] = [] structs: List[Struct] = [] + comments: List[Comment] = [] while True: - try: + token = self._peek_token(offset) + + if not token: + break + + if token.type == TokenType.FUNCTION: function, offset = self._parse_function_declaration(offset) - except ParserBaseException: - pass - else: functions.append(function) continue @@ -460,6 +473,7 @@ def _parse_builtins_file_root(self, offset: int) -> Tuple[ParsedFile, int]: imports=[], structs=structs, enums=[], + comments=comments, ) self._print_parse_tree_node("ParsedFile", start_offset, offset) @@ -781,11 +795,14 @@ def _parse_function_body_item(self, offset: int) -> Tuple[FunctionBodyItem, int] item, offset = self._parse_call(offset) elif token.type == TokenType.FUNCTION: item, offset = self._parse_function_pointer_type_literal(offset) + elif token.type == TokenType.COMMENT: + item, offset = self._parse_comment(offset) else: raise ParserException( token, [ + TokenType.COMMENT, TokenType.FALSE, TokenType.FOREACH, TokenType.IDENTIFIER, @@ -866,6 +883,7 @@ def _parse_regular_file_root(self, offset: int) -> Tuple[ParsedFile, int]: structs: List[Struct] = [] imports: List[Import] = [] enums: List[Enum] = [] + comments: List[Comment] = [] while True: token = self._peek_token(offset) @@ -885,6 +903,9 @@ def _parse_regular_file_root(self, offset: int) -> Tuple[ParsedFile, int]: elif token.type == TokenType.ENUM: enum, offset = self._parse_enum_definition(offset) enums.append(enum) + elif token.type == TokenType.COMMENT: + comment, offset = self._parse_comment(offset) + comments.append(comment) else: break @@ -894,6 +915,7 @@ def _parse_regular_file_root(self, offset: int) -> Tuple[ParsedFile, int]: imports=imports, structs=structs, enums=enums, + comments=comments, ) self._print_parse_tree_node("ParsedFileRoot", start_offset, offset) diff --git a/aaa/tokenizer/tokenizer.py b/aaa/tokenizer/tokenizer.py index 
ef8899a0..1ce707bb 100644 --- a/aaa/tokenizer/tokenizer.py +++ b/aaa/tokenizer/tokenizer.py @@ -135,10 +135,7 @@ def run(self) -> List[Token]: filtered: List[Token] = [] for token in tokens: - if token.type not in [ - TokenType.WHITESPACE, - TokenType.COMMENT, - ]: + if token.type != TokenType.WHITESPACE: filtered.append(token) self._print_tokens(filtered) diff --git a/examples/one_to_ten.aaa b/examples/one_to_ten.aaa index 53732f51..6fb608e3 100755 --- a/examples/one_to_ten.aaa +++ b/examples/one_to_ten.aaa @@ -1,4 +1,3 @@ - struct Range { next_value as int, end as int, diff --git a/manage.py b/manage.py index 7ab87f80..2b6cfe19 100755 --- a/manage.py +++ b/manage.py @@ -1,6 +1,8 @@ #!/usr/bin/env -S python3 -u -from typing import Any +import os +from pathlib import Path +from typing import Any, Optional, Tuple import click @@ -45,7 +47,13 @@ def test(**kwargs: Any) -> None: @click.argument("files", type=click.Path(exists=True), nargs=-1) @click.option("--fix", is_flag=True, default=False) @click.option("--show-diff", is_flag=True, default=False) -def format(files: Tuple[str], fix: bool, show_diff: bool) -> None: +@click.option("--stdlib-path", type=click.Path(exists=True)) +def format( + files: Tuple[str], fix: bool, show_diff: bool, stdlib_path: Optional[str] +) -> None: + if stdlib_path: + os.environ["AAA_STDLIB_PATH"] = str(Path(stdlib_path).resolve()) + exit_code = format_source_files(files, fix, show_diff) exit(exit_code) diff --git a/tests/test_formatter.py b/tests/test_formatter.py index 26bb7968..e4046c8c 100644 --- a/tests/test_formatter.py +++ b/tests/test_formatter.py @@ -21,6 +21,7 @@ def format_aaa_builtins(code: str) -> str: def test_format_empty_file() -> None: + # Empty files don't get a trailing newline assert "" == format_aaa_source("") diff --git a/tests/test_single_file_parser.py b/tests/test_single_file_parser.py index c4e090dd..bec32fe3 100644 --- a/tests/test_single_file_parser.py +++ b/tests/test_single_file_parser.py @@ -1180,3 
+1180,395 @@ def test_parse_function_pointer_type_literal( else: with pytest.raises(expected_exception): parser._parse_assignment(0) + + +@pytest.mark.parametrize( + ["code"], + [ + ( + "// foo\n", # fmt: skip + ), + ( + "// foo\n" # fmt: skip + "fn main { nop }\n", + ), + ( + "fn main { nop }\n" # fmt: skip + "// foo\n", + ), + ], +) +def test_parse_comment_in_regular_file(code: str) -> None: + parser = parse_code(code) + parser.parse_regular_file() + + +@pytest.mark.parametrize( + ["code"], + [ + ( + "fn main {\n" # fmt: skip + " // foo\n" + "}\n", + ), + ( + "fn main {\n" # fmt: skip + " // foo\n" + " nop\n" + "}\n", + ), + ( + "fn main {\n" # fmt: skip + " nop\n" + " // foo\n" + "}\n", + ), + ], +) +def test_parse_comment_in_function_body(code: str) -> None: + parser = parse_code(code) + parser.parse_regular_file() + + +@pytest.mark.parametrize( + ["code"], + [ + ( + "fn main {\n" # fmt: skip + " // foo\n" + " while 1 2 { nop }\n" + "}\n", + ), + ( + "fn main {\n" # fmt: skip + " while\n" + " // foo\n" + " 1 2 { nop }\n" + "}\n", + ), + ( + "fn main {\n" # fmt: skip + " while 1 2\n" + " // foo\n" + " { nop }\n" + "}\n", + ), + ( + "fn main {\n" # fmt: skip + " while 1 2 {\n" + " // foo\n" + " nop }\n" + "}\n", + ), + ( + "fn main {\n" # fmt: skip + " while 1 2 { nop\n" + " // foo\n" + " }\n" + "}\n", + ), + ( + "fn main {\n" # fmt: skip + " while 1 2 { nop }\n" + " // foo\n" + "}\n", + ), + ], +) +def test_parse_comment_in_while_loop(code: str) -> None: + parser = parse_code(code) + parser.parse_regular_file() + + +@pytest.mark.parametrize( + ["code"], + [ + ( + "fn main {\n" # fmt: skip + " // foo\n" + " if true { nop } else { nop }\n" + "}\n", + ), + ( + "fn main {\n" # fmt: skip + " if \n" + " // foo\n" + " true { nop } else { nop }\n" + "}\n", + ), + ( + "fn main {\n" # fmt: skip + " if true \n" + " // foo\n" + " { nop } else { nop }\n" + "}\n", + ), + ( + "fn main {\n" # fmt: skip + " if true {\n" + " // foo\n" + " nop } else { nop }\n" + "}\n", + ), + ( + 
"fn main {\n" # fmt: skip + " if true { nop\n" + " // foo\n" + " } else { nop }\n" + "}\n", + ), + ( + "fn main {\n" # fmt: skip + " if true { nop }\n" + " // foo\n" + " else { nop }\n" + "}\n", + ), + ( + "fn main {\n" # fmt: skip + " if true { nop } else\n" + " // foo\n" + " { nop }\n" + "}\n", + ), + ( + "fn main {\n" # fmt: skip + " if true { nop } else {\n" + " // foo\n" + " nop }\n" + "}\n", + ), + ( + "fn main {\n" # fmt: skip + " if true { nop } else { nop\n" + " // foo\n" + " }\n" + "}\n", + ), + ( + "fn main {\n" # fmt: skip + " if true { nop } else { nop }\n" + " // foo\n" + "}\n", + ), + ], +) +def test_parse_comment_in_branch(code: str) -> None: + parser = parse_code(code) + parser.parse_regular_file() + + +@pytest.mark.skip +@pytest.mark.parametrize( + ["code"], + [], +) +def test_parse_comment_in_struct_field_query(code: str) -> None: + ... + + +@pytest.mark.skip +@pytest.mark.parametrize( + ["code"], + [], +) +def test_parse_comment_in_struct_field_update(code: str) -> None: + ... + + +@pytest.mark.skip +@pytest.mark.parametrize( + ["code"], + [], +) +def test_parse_comment_in_get_function_pointer(code: str) -> None: + ... + + +@pytest.mark.skip +@pytest.mark.parametrize( + ["code"], + [], +) +def test_parse_comment_in_return(code: str) -> None: + ... + + +@pytest.mark.skip +@pytest.mark.parametrize( + ["code"], + [], +) +def test_parse_comment_in_call(code: str) -> None: + ... + + +@pytest.mark.skip +@pytest.mark.parametrize( + ["code"], + [], +) +def test_parse_comment_in_argument(code: str) -> None: + ... + + +@pytest.mark.skip +@pytest.mark.parametrize( + ["code"], + [], +) +def test_parse_comment_in_function(code: str) -> None: + ... + + +@pytest.mark.skip +@pytest.mark.parametrize( + ["code"], + [], +) +def test_parse_comment_in_import_item(code: str) -> None: + ... + + +@pytest.mark.skip +@pytest.mark.parametrize( + ["code"], + [], +) +def test_parse_comment_in_import(code: str) -> None: + ... 
+ + +@pytest.mark.skip +@pytest.mark.parametrize( + ["code"], + [], +) +def test_parse_comment_in_struct(code: str) -> None: + ... + + +@pytest.mark.skip +@pytest.mark.parametrize( + ["code"], + [], +) +def test_parse_comment_in_parsed_file(code: str) -> None: + ... + + +@pytest.mark.skip +@pytest.mark.parametrize( + ["code"], + [], +) +def test_parse_comment_in_type_literal(code: str) -> None: + ... + + +@pytest.mark.skip +@pytest.mark.parametrize( + ["code"], + [], +) +def test_parse_comment_in_function_pointer_type_literal(code: str) -> None: + ... + + +@pytest.mark.skip +@pytest.mark.parametrize( + ["code"], + [], +) +def test_parse_comment_in_function_name(code: str) -> None: + ... + + +@pytest.mark.skip +@pytest.mark.parametrize( + ["code"], + [], +) +def test_parse_comment_in_function_call(code: str) -> None: + ... + + +@pytest.mark.skip +@pytest.mark.parametrize( + ["code"], + [], +) +def test_parse_comment_in_case_label(code: str) -> None: + ... + + +@pytest.mark.skip +@pytest.mark.parametrize( + ["code"], + [], +) +def test_parse_comment_in_foreach_loop(code: str) -> None: + ... + + +@pytest.mark.skip +@pytest.mark.parametrize( + ["code"], + [], +) +def test_parse_comment_in_use_block(code: str) -> None: + ... + + +@pytest.mark.skip +@pytest.mark.parametrize( + ["code"], + [], +) +def test_parse_comment_in_assignment(code: str) -> None: + ... + + +@pytest.mark.skip +@pytest.mark.parametrize( + ["code"], + [], +) +def test_parse_comment_in_case_block(code: str) -> None: + ... + + +@pytest.mark.skip +@pytest.mark.parametrize( + ["code"], + [], +) +def test_parse_comment_in_default_block(code: str) -> None: + ... + + +@pytest.mark.skip +@pytest.mark.parametrize( + ["code"], + [], +) +def test_parse_comment_in_match_block(code: str) -> None: + ... + + +@pytest.mark.skip +@pytest.mark.parametrize( + ["code"], + [], +) +def test_parse_comment_in_enum_variant(code: str) -> None: + ... 
+ + +@pytest.mark.skip +@pytest.mark.parametrize( + ["code"], + [], +) +def test_parse_comment_in_enum(code: str) -> None: + ... From c9c6d0fee14de8e6da476caca7daf62bb5ba092e Mon Sep 17 00:00:00 2001 From: Luuk Verweij Date: Sun, 8 Oct 2023 13:18:06 +0200 Subject: [PATCH 3/3] fix breaking tests --- aaa/parser/single_file_parser.py | 56 ++++++++++++++++++++++++++---- examples/selfhosting/tokenizer.aaa | 1 - tests/test_tokenizer.py | 2 +- 3 files changed, 50 insertions(+), 9 deletions(-) diff --git a/aaa/parser/single_file_parser.py b/aaa/parser/single_file_parser.py index dbae308b..3d5aefae 100644 --- a/aaa/parser/single_file_parser.py +++ b/aaa/parser/single_file_parser.py @@ -432,6 +432,19 @@ def _parse_comment(self, offset: int) -> Tuple[Comment, int]: return comment, offset + def _parse_comments(self, offset: int) -> Tuple[List[Comment], int]: + comments: List[Comment] = [] + + while True: + try: + comment, offset = self._parse_comment(offset) + except ParserBaseException: + break + else: + comments.append(comment) + + return comments, offset + def _parse_builtins_file_root(self, offset: int) -> Tuple[ParsedFile, int]: start_offset = offset @@ -457,6 +470,11 @@ def _parse_builtins_file_root(self, offset: int) -> Tuple[ParsedFile, int]: functions.append(function) continue + if token.type == TokenType.COMMENT: + comment, offset = self._parse_comment(offset) + comments.append(comment) + continue + try: struct, offset = self._parse_struct_declaration(offset) except ParserBaseException: @@ -683,11 +701,14 @@ def _parse_branch(self, offset: int) -> Tuple[Branch, int]: if_body, offset = self._parse_function_body(offset) _, offset = self._parse_token(offset, [TokenType.BLOCK_END]) + _, offset = self._parse_comments(offset) # TODO + token = self._peek_token(offset) else_body: Optional[FunctionBody] = None if token and token.type == TokenType.ELSE: _, offset = self._parse_token(offset, [TokenType.ELSE]) + _, offset = self._parse_comments(offset) # TODO _, offset = 
self._parse_token(offset, [TokenType.BLOCK_START]) else_body, offset = self._parse_function_body(offset) _, offset = self._parse_token(offset, [TokenType.BLOCK_END]) @@ -1029,6 +1050,9 @@ def _parse_match_block(self, offset: int) -> Tuple[MatchBlock, int]: block, offset = self._parse_case_block(offset) elif token.type == TokenType.DEFAULT: block, offset = self._parse_default_block(offset) + elif token.type == TokenType.COMMENT: + _, offset = self._parse_comment(offset) # TODO + continue else: raise ParserException(token, [TokenType.CASE, TokenType.DEFAULT]) blocks.append(block) @@ -1103,17 +1127,35 @@ def _parse_enum_variants(self, offset: int) -> Tuple[List[EnumVariant], int]: enum_variants.append(enum_variant) while True: - try: + token = self._peek_token_or_fail(offset) + + if token.type == TokenType.COMMA: _, offset = self._parse_token(offset, [TokenType.COMMA]) - except ParserBaseException: - break - try: - enum_variant, offset = self._parse_enum_variant(offset) - except ParserBaseException: + elif token.type == TokenType.IDENTIFIER: + try: + enum_variant, offset = self._parse_enum_variant(offset) + except ParserBaseException: + break + else: + enum_variants.append(enum_variant) + + elif token.type == TokenType.COMMENT: + _, offset = self._parse_comment(offset) # TODO + + elif token.type == TokenType.BLOCK_END: break + else: - enum_variants.append(enum_variant) + raise ParserException( + token, + [ + TokenType.BLOCK_END, + TokenType.COMMA, + TokenType.COMMENT, + TokenType.IDENTIFIER, + ], + ) self._print_parse_tree_node("EnumVariants", start_offset, offset) return enum_variants, offset diff --git a/examples/selfhosting/tokenizer.aaa b/examples/selfhosting/tokenizer.aaa index b02c4da6..2bf4f8b1 100644 --- a/examples/selfhosting/tokenizer.aaa +++ b/examples/selfhosting/tokenizer.aaa @@ -555,7 +555,6 @@ fn print_tokens args tokens as vec[Token] { token "type_" ? 
match { case TokenType:whitespace { nop } - case TokenType:comment { nop } default { token "position" ? use position { diff --git a/tests/test_tokenizer.py b/tests/test_tokenizer.py index f003ccb2..d6e6a341 100644 --- a/tests/test_tokenizer.py +++ b/tests/test_tokenizer.py @@ -60,7 +60,7 @@ def test_tokenizer_token_types(code: str, expected_token_type: TokenType) -> Non tokens = tokenizer.run() - if expected_token_type in [TokenType.WHITESPACE, TokenType.COMMENT]: + if expected_token_type == TokenType.WHITESPACE: assert 0 == len(tokens) else: