From 8278f5421892a6494282148a644782d87b059a83 Mon Sep 17 00:00:00 2001 From: snoyer Date: Sun, 11 Jan 2026 11:40:25 +0400 Subject: [PATCH] improve arg spec guessing --- src/argh/assembling.py | 63 +++++++++++++++++++------------------- tests/test_typing_hints.py | 23 +++++++++++--- 2 files changed, 49 insertions(+), 37 deletions(-) diff --git a/src/argh/assembling.py b/src/argh/assembling.py index fdfc4a0..aab8fc7 100644 --- a/src/argh/assembling.py +++ b/src/argh/assembling.py @@ -28,6 +28,7 @@ List, Literal, Optional, + Sequence, Tuple, Union, get_args, @@ -739,7 +740,8 @@ class ArgumentNameMappingError(AssemblingError): ... class TypingHintArgSpecGuesser: - BASIC_TYPES = (str, int, float, bool) + SEQUENCE_BUT_BASIC = str + """type(s) that are subclasses of Sequence but we consider basic""" @classmethod def typing_hint_to_arg_spec_params( @@ -748,16 +750,12 @@ def typing_hint_to_arg_spec_params( origin = get_origin(type_def) args = get_args(type_def) - # `str` - if type_def in cls.BASIC_TYPES: - return { - "type": type_def - # "type": _parse_basic_type(type_def) - } - - # `list` - if type_def in (list, List): - return {"nargs": ZERO_OR_MORE} + # `list`, `list[...]`, Sequence[...], ... + if cls._is_sequence(type_def): + retval = {"nargs": ZERO_OR_MORE} + if args and not cls._is_sequence(args[0]): + retval["type"] = args[0] + return retval # `Literal["a", "b"]` if origin == Literal: @@ -767,35 +765,36 @@ def typing_hint_to_arg_spec_params( if any(origin is t for t in UNION_TYPES): retval = {} first_subtype = args[0] - if first_subtype in cls.BASIC_TYPES: + if first_subtype is not type(None) and not cls._is_sequence(first_subtype): retval["type"] = first_subtype - if first_subtype in (list, List): + if cls._is_sequence(first_subtype): retval["nargs"] = ZERO_OR_MORE - if first_subtype != List and get_origin(first_subtype) == list: - retval["nargs"] = ZERO_OR_MORE - item_type = cls._extract_item_type_from_list_type(first_subtype) - if item_type: - retval["type"] = item_type + first_subtype_args = get_args(first_subtype) + if first_subtype_args and not cls._is_sequence(first_subtype_args[0]): + retval["type"] = first_subtype_args[0] if type(None) in args: retval["required"] = False return retval - # `list[str]` - if origin == list: - retval = {} - retval["nargs"] = ZERO_OR_MORE - if args[0] in cls.BASIC_TYPES: - retval["type"] = args[0] - return retval - - return {} + # basic types + return {"type": type_def} @classmethod - def _extract_item_type_from_list_type(cls, type_def) -> Optional[type]: - args = get_args(type_def) - if args[0] in cls.BASIC_TYPES: - return args[0] - return None + def _is_sequence(cls, type_def): + def check(x): + if _safe_issubclass(x, Sequence): + return not _safe_issubclass(x, cls.SEQUENCE_BUT_BASIC) + else: + return False + + return check(type_def) or check(get_origin(type_def)) + + +def _safe_issubclass(cls, class_or_tuple: Any): + try: + return issubclass(cls, class_or_tuple) + except TypeError: # `cls` is not a class + return False diff --git a/tests/test_typing_hints.py b/tests/test_typing_hints.py index 6114edc..3e0f5c9 100644 --- a/tests/test_typing_hints.py +++ b/tests/test_typing_hints.py @@ -1,3 +1,4 @@ +from pathlib import Path from typing import List, Literal, Optional, Union import pytest @@ -5,7 +6,11 @@ from argh.assembling import TypingHintArgSpecGuesser -@pytest.mark.parametrize("arg_type", TypingHintArgSpecGuesser.BASIC_TYPES) +class CustomSimpleType: + def __init__(self, string: str) -> None: ... + + +@pytest.mark.parametrize("arg_type", (str, int, float, bool, Path, CustomSimpleType)) def test_simple_types(arg_type): guess = TypingHintArgSpecGuesser.typing_hint_to_arg_spec_params @@ -46,6 +51,9 @@ def test_list(): assert guess(List[list]) == {"nargs": "*"} assert guess(List[tuple]) == {"nargs": "*"} + assert guess(List[Path]) == {"nargs": "*", "type": Path} + assert guess(List[CustomSimpleType]) == {"nargs": "*", "type": CustomSimpleType} + def test_literal(): guess = TypingHintArgSpecGuesser.typing_hint_to_arg_spec_params @@ -55,8 +63,13 @@ def test_literal(): assert guess(Literal[1]) == {"choices": (1,), "type": int} -@pytest.mark.parametrize("arg_type", (dict, tuple)) -def test_unusable_types(arg_type): +@pytest.mark.parametrize( + "arg_type, expected", + [ + (dict, {"type": dict}), + (tuple, {"nargs": "*"}), + ], +) +def test_unusable_types(arg_type, expected): guess = TypingHintArgSpecGuesser.typing_hint_to_arg_spec_params - - assert guess(arg_type) == {} + assert guess(arg_type) == expected