diff --git a/src/pydantic_typer/main.py b/src/pydantic_typer/main.py index 316d6eb..c403300 100644 --- a/src/pydantic_typer/main.py +++ b/src/pydantic_typer/main.py @@ -21,7 +21,13 @@ ) from typing_extensions import Annotated -from pydantic_typer.utils import copy_type, deep_update, inspect_signature +from pydantic._internal._model_construction import ModelMetaclass +from pydantic_yaml import parse_yaml_file_as +from typing import get_type_hints +from inspect import signature, Parameter +import makefun + +from pydantic_typer.utils import copy_type, deep_update, inspect_signature, _clear_empty_dictionaries PYDANTIC_FIELD_SEPARATOR = "." @@ -285,6 +291,60 @@ def wrapper(*args, **kwargs): # type: ignore[no-untyped-def] wrapper.__annotations__ = {k: v.annotation for k, v in new_signature.parameters.items()} return wrapper +def wrap_with_model_yaml(function : Callable) -> Callable: + """From an input function with some model inputs, create + a function with the same model inputs and corresponding arguments + to load in yaml files which represent that model. + This allows behavior like configargparse, such that mixed use + of configuration files and command line arguments is enabled. + + Args: + function : callable + The base function, which should have some arguments + typed as pydantic model, as for + other pydantic_typer calls. + + Returns: + Callable + The input function, with added options to give the path + to a yaml which can configure some or all of the model arguments. + The signature of this function will be: + f(config_{model_1}, config_{model_2}...,{model_1}, {model_2},...) + When called, it will automatically load the contents of the yaml, + then update with any model arguments passed normally. + It will then proceed to execution of the function with this + combined model. + """ + function_signature = signature(function) + type_hints = get_type_hints(function) + model_arguments = [] + config_parameters = [] + for parameter_name, type_hint in type_hints.items(): + if not isinstance(type_hint, ModelMetaclass): + continue + config_name = f"config_{parameter_name}" + model_arguments.append((parameter_name, config_name)) + config_parameters.append( + Parameter( + config_name, + kind=Parameter.POSITIONAL_OR_KEYWORD, + annotation=str + ) + ) + new_signature = makefun.add_signature_parameters(function_signature, first=config_parameters) + + @makefun.with_signature(func_signature=new_signature) + def config_added_function(**kwargs): + for model_arg, config_name in model_arguments: + model_type = type_hints[model_arg] + working_model = parse_yaml_file_as(model_type, kwargs[config_name]) + argument_updates = _clear_empty_dictionaries(kwargs[model_arg].model_dump(exclude_defaults=True)) + working_model = working_model.model_copy(update=argument_updates) + kwargs[model_arg] = working_model + kwargs.pop(config_name) + return function(**kwargs) + + return config_added_function class Typer(TyperBase): @copy_type(TyperBase.command) diff --git a/src/pydantic_typer/utils.py b/src/pydantic_typer/utils.py index 8a07b80..5565c85 100644 --- a/src/pydantic_typer/utils.py +++ b/src/pydantic_typer/utils.py @@ -44,6 +44,28 @@ def inspect_signature(func: Callable[..., Any]) -> inspect.Signature: # pragma: signature = raw_signature.replace(parameters=resolved_params) return signature +def _clear_empty_dictionaries(dictionary : dict) -> dict: + """A utility to recursively remove empty dictionaries + so they will not overwrite anything when used for updates. + + Args: + dictionary : dict + The dictionary in question, which may have any heterogeneous contents + + Returns: + The input dictionary with any empty dictionaries removed recursively + """ + new_dict = {} + for key,val in dictionary.items(): + if isinstance(val, dict): + new_val = _clear_empty_dictionaries(val) + if not new_val: + continue + else: + new_dict[key] = new_val + else: + new_dict[key] = val + return new_dict _T = TypeVar("_T")