diff --git a/pydantic_settings/sources/providers/cli.py b/pydantic_settings/sources/providers/cli.py index bd22e31..87d0d58 100644 --- a/pydantic_settings/sources/providers/cli.py +++ b/pydantic_settings/sources/providers/cli.py @@ -956,7 +956,7 @@ def _add_parser_args( self._convert_bool_flag(arg.kwargs, field_info, model_default) - if arg.is_parser_submodel: + if arg.is_parser_submodel and not getattr(field_info.annotation, '__pydantic_root_model__', False): self._add_parser_submodels( parser, model, @@ -1107,6 +1107,7 @@ def _add_parser_submodels( model_group_kwargs['description'] = CLI_SUPPRESS if not self.cli_avoid_json: added_args.append(arg_names[0]) + kwargs['required'] = False kwargs['nargs'] = '?' kwargs['const'] = '{}' kwargs['help'] = ( @@ -1205,8 +1206,12 @@ def _metavar_format_recurse(self, obj: Any) -> str: ) elif obj is type(None): return self.cli_parse_none_str - elif is_model_class(obj): - return 'JSON' + elif is_model_class(obj) or is_pydantic_dataclass(obj): + return ( + self._metavar_format_recurse(_get_model_fields(obj)['root'].annotation) + if getattr(obj, '__pydantic_root_model__', False) + else 'JSON' + ) elif isinstance(obj, type): return obj.__qualname__ else: diff --git a/tests/test_source_cli.py b/tests/test_source_cli.py index d250322..1fe98f9 100644 --- a/tests/test_source_cli.py +++ b/tests/test_source_cli.py @@ -19,6 +19,7 @@ DirectoryPath, Discriminator, Field, + RootModel, Tag, ValidationError, field_validator, @@ -1778,13 +1779,19 @@ class Settings(BaseSettings): def test_cli_enforce_required(env): + class MyRootModel(RootModel[str]): + root: str + class Settings(BaseSettings, cli_exit_on_error=False): my_required_field: str + my_root_model_required_field: MyRootModel env.set('MY_REQUIRED_FIELD', 'hello from environment') + env.set('MY_ROOT_MODEL_REQUIRED_FIELD', 'hi from environment') assert Settings(_cli_parse_args=[], _cli_enforce_required=False).model_dump() == { - 'my_required_field': 'hello from environment' + 'my_required_field': 'hello from environment', + 'my_root_model_required_field': 'hi from environment', } with pytest.raises( @@ -1792,6 +1799,11 @@ class Settings(BaseSettings, cli_exit_on_error=False): ): Settings(_cli_parse_args=[], _cli_enforce_required=True).model_dump() + with pytest.raises( + SettingsError, match='error parsing CLI: the following arguments are required: --my_root_model_required_field' + ): + Settings(_cli_parse_args=['--my_required_field', 'hello from cli'], _cli_enforce_required=True).model_dump() + def test_cli_exit_on_error(capsys, monkeypatch): class Settings(BaseSettings, cli_parse_args=True): ...