Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions ariadne_codegen/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ def _format_code(code: str, *, remove_unused_imports: bool = True) -> str:
input=code,
capture_output=True,
text=True,
encoding="utf-8",
check=False,
timeout=_SUBPROCESS_TIMEOUT,
)
Expand Down
73 changes: 73 additions & 0 deletions tests/client_generators/test_custom_arguments.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
import ast

from graphql import build_schema

from ariadne_codegen.client_generators.custom_arguments import ArgumentGenerator
from ariadne_codegen.codegen import generate_import_from

from ..utils import compare_ast


def _expected_enum_import(name: str) -> ast.ImportFrom:
"""Matches generate_import_from for GraphQLEnumType in _parse_graphql_type_name."""
return generate_import_from(names=[name], from_="enums", level=1)


def test_generate_arguments_records_enums_submodule_import_for_enum_field_argument():
schema = build_schema(
"""
schema { query: Query }
enum Color { RED GREEN }
type Query {
paint(color: Color!): String
}
"""
)
field = schema.query_type.fields["paint"]
generator = ArgumentGenerator(custom_scalars={}, convert_to_snake_case=False)

generator.generate_arguments(field.args)

expected = _expected_enum_import("Color")
assert len(generator.imports) == 1
assert compare_ast(generator.imports[0], expected)


def test_generate_arguments_records_enums_submodule_import_for_list_of_enum_arguments():
schema = build_schema(
"""
schema { query: Query }
enum SortOrder { ASC DESC }
type Query {
items(order: [SortOrder!]!): String
}
"""
)
field = schema.query_type.fields["items"]
generator = ArgumentGenerator(custom_scalars={}, convert_to_snake_case=False)

generator.generate_arguments(field.args)

expected = _expected_enum_import("SortOrder")
assert len(generator.imports) == 1
assert compare_ast(generator.imports[0], expected)


def test_generate_arguments_records_enums_submodule_import_for_optional_enum_argument():
schema = build_schema(
"""
schema { query: Query }
enum Priority { LOW HIGH }
type Query {
task(priority: Priority): String
}
"""
)
field = schema.query_type.fields["task"]
generator = ArgumentGenerator(custom_scalars={}, convert_to_snake_case=False)

generator.generate_arguments(field.args)

expected = _expected_enum_import("Priority")
assert len(generator.imports) == 1
assert compare_ast(generator.imports[0], expected)
25 changes: 25 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import ast
import subprocess
from textwrap import dedent

import pytest

from ariadne_codegen.utils import (
_format_code,
add_extra_to_base_model,
ast_to_str,
convert_to_multiline_string,
Expand Down Expand Up @@ -78,6 +80,29 @@ class TestClass(Efg):
assert not_used_imported_class not in generated_code


def test_ast_to_str_non_ascii_unicode_round_trip_issue_422():
"""Regression for mirumee/ariadne-codegen#422 (Windows cp1252 / ruff stdin)."""
description = "商店 line: émoji 🛍️ — characters outside cp1252"
module = ast.parse(f'"""{description}"""')
generated = ast_to_str(module, remove_unused_imports=False)
assert description in generated


def test_format_code_ruff_format_uses_utf8_encoding_issue_422(mocker):
"""Ensure ruff format stdin/stdout use UTF-8 (mirumee/ariadne-codegen#422)."""
spy = mocker.patch("ariadne_codegen.utils.subprocess.run", wraps=subprocess.run)

_format_code("x = 1\n")

format_calls = [
call
for call in spy.call_args_list
if tuple(call[0][0][2:4]) == ("ruff", "format")
]
assert format_calls, "expected a ruff format subprocess.run"
assert format_calls[-1][1].get("encoding") == "utf-8"


@pytest.mark.parametrize(
"name, expected_result",
[
Expand Down
Loading