From 44a51789d99d2a93b6d09de814386c967fac7b2e Mon Sep 17 00:00:00 2001 From: Rhiannon Udall Date: Thu, 23 Jan 2025 11:09:20 -0800 Subject: [PATCH] Add wrap_with_model_yaml decorator This adds the decorator `wrap_with_model_yaml`, which takes functions whose signature includes arguments typed as a pydantic model. It adds to this function corresponding string arguments, presumed to point towards yaml files which represent instances of the corresponding model. These are then parsed as that model, and updated with any non-default arguments passed by the model a la the original signature. This amalgamated model is then passed to the original function. This is meant to fill the same niche as the configargparse package, which provides an analogous extension to argparse. --- src/pydantic_typer/main.py | 62 ++++++++++++++++++++++++++++++++++++- src/pydantic_typer/utils.py | 22 +++++++++++++ 2 files changed, 83 insertions(+), 1 deletion(-) diff --git a/src/pydantic_typer/main.py b/src/pydantic_typer/main.py index 316d6eb..c403300 100644 --- a/src/pydantic_typer/main.py +++ b/src/pydantic_typer/main.py @@ -21,7 +21,13 @@ ) from typing_extensions import Annotated -from pydantic_typer.utils import copy_type, deep_update, inspect_signature +from pydantic._internal._model_construction import ModelMetaclass +from pydantic_yaml import parse_yaml_file_as +from typing import get_type_hints +from inspect import signature, Parameter +import makefun + +from pydantic_typer.utils import copy_type, deep_update, inspect_signature, _clear_empty_dictionaries PYDANTIC_FIELD_SEPARATOR = "." @@ -285,6 +291,60 @@ def wrapper(*args, **kwargs): # type: ignore[no-untyped-def] wrapper.__annotations__ = {k: v.annotation for k, v in new_signature.parameters.items()} return wrapper +def wrap_with_model_yaml(function : Callable) -> Callable: + """From an input function with some model inputs, create + a function with the same model inputs and corresponding arguments + to load in yaml files which represent that model. + This allows behavior like configargparse, such that mixed use + of configuration files and command line arguments is enabled. + + Args: + function : callable + The base function, which should have some arguments + typed as pydantic model, as for + other pydantic_typer calls. + + Returns: + Callable + The input function, with added options to give the path + to a yaml which can configure some or all of the model arguments. + The signature of this function will be: + f(config_{model_1}, config_{model_2}...,{model_1}, {model_2},...) + When called, it will automatically load the contents of the yaml, + then update with any model arguments passed normally. + It will then proceed to execution of the function with this + combined model. + """ + function_signature = signature(function) + type_hints = get_type_hints(function) + model_arguments = [] + config_parameters = [] + for parameter_name, type_hint in type_hints.items(): + if not isinstance(type_hint, ModelMetaclass): + continue + config_name = f"config_{parameter_name}" + model_arguments.append((parameter_name, config_name)) + config_parameters.append( + Parameter( + config_name, + kind=Parameter.POSITIONAL_OR_KEYWORD, + annotation=str + ) + ) + new_signature = makefun.add_signature_parameters(function_signature, first=config_parameters) + + @makefun.with_signature(func_signature=new_signature) + def config_added_function(**kwargs): + for model_arg, config_name in model_arguments: + model_type = type_hints[model_arg] + working_model = parse_yaml_file_as(model_type, kwargs[config_name]) + argument_updates = _clear_empty_dictionaries(kwargs[model_arg].model_dump(exclude_defaults=True)) + working_model = working_model.model_copy(update=argument_updates) + kwargs[model_arg] = working_model + kwargs.pop(config_name) + return function(**kwargs) + + return config_added_function class Typer(TyperBase): @copy_type(TyperBase.command) diff --git a/src/pydantic_typer/utils.py b/src/pydantic_typer/utils.py index 8a07b80..5565c85 100644 --- a/src/pydantic_typer/utils.py +++ b/src/pydantic_typer/utils.py @@ -44,6 +44,28 @@ def inspect_signature(func: Callable[..., Any]) -> inspect.Signature: # pragma: signature = raw_signature.replace(parameters=resolved_params) return signature +def _clear_empty_dictionaries(dictionary : dict) -> dict: + """A utility to recursively remove empty dictionaries + so they will not overwrite anything when used for updates. + + Args: + dictionary : dict + The dictionary in question, which may have any heterogeneous contents + + Returns: + The input dictionary with any empty dictionaries removed recursively + """ + new_dict = {} + for key,val in dictionary.items(): + if isinstance(val, dict): + new_val = _clear_empty_dictionaries(val) + if not new_val: + continue + else: + new_dict[key] = new_val + else: + new_dict[key] = val + return new_dict _T = TypeVar("_T")