Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 61 additions & 1 deletion src/pydantic_typer/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,13 @@
)
from typing_extensions import Annotated

from pydantic_typer.utils import copy_type, deep_update, inspect_signature
from pydantic._internal._model_construction import ModelMetaclass
from pydantic_yaml import parse_yaml_file_as
from typing import get_type_hints
from inspect import signature, Parameter
import makefun

from pydantic_typer.utils import copy_type, deep_update, inspect_signature, _clear_empty_dictionaries

PYDANTIC_FIELD_SEPARATOR = "."

Expand Down Expand Up @@ -285,6 +291,60 @@ def wrapper(*args, **kwargs): # type: ignore[no-untyped-def]
wrapper.__annotations__ = {k: v.annotation for k, v in new_signature.parameters.items()}
return wrapper

def wrap_with_model_yaml(function : Callable) -> Callable:
"""From an input function with some model inputs, create
a function with the same model inputs and corresponding arguments
to load in yaml files which represent that model.
This allows behavior like configargparse, such that mixed use
of configuration files and command line arguments is enabled.

Args:
function : callable
The base function, which should have some arguments
typed as pydantic model, as for
other pydantic_typer calls.

Returns:
Callable
The input function, with added options to give the path
to a yaml which can configure some or all of the model arguments.
The signature of this function will be:
f(config_{model_1}, config_{model_2}...,{model_1}, {model_2},...)
When called, it will automatically load the contents of the yaml,
then update with any model arguments passed normally.
It will then proceed to execution of the function with this
combined model.
"""
function_signature = signature(function)
type_hints = get_type_hints(function)
model_arguments = []
config_parameters = []
for parameter_name, type_hint in type_hints.items():
if not isinstance(type_hint, ModelMetaclass):
continue
config_name = f"config_{parameter_name}"
model_arguments.append((parameter_name, config_name))
config_parameters.append(
Parameter(
config_name,
kind=Parameter.POSITIONAL_OR_KEYWORD,
annotation=str
)
)
new_signature = makefun.add_signature_parameters(function_signature, first=config_parameters)

@makefun.with_signature(func_signature=new_signature)
def config_added_function(**kwargs):
for model_arg, config_name in model_arguments:
model_type = type_hints[model_arg]
working_model = parse_yaml_file_as(model_type, kwargs[config_name])
argument_updates = _clear_empty_dictionaries(kwargs[model_arg].model_dump(exclude_defaults=True))
working_model = working_model.model_copy(update=argument_updates)
kwargs[model_arg] = working_model
kwargs.pop(config_name)
return function(**kwargs)

return config_added_function

class Typer(TyperBase):
@copy_type(TyperBase.command)
Expand Down
22 changes: 22 additions & 0 deletions src/pydantic_typer/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,28 @@ def inspect_signature(func: Callable[..., Any]) -> inspect.Signature: # pragma:
signature = raw_signature.replace(parameters=resolved_params)
return signature

def _clear_empty_dictionaries(dictionary : dict) -> dict:
"""A utility to recursively remove empty dictionaries
so they will not overwrite anything when used for updates.

Args:
dictionary : dict
The dictionary in question, which may have any heterogeneous contents

Returns:
The input dictionary with any empty dictionaries removed recursively
"""
new_dict = {}
for key,val in dictionary.items():
if isinstance(val, dict):
new_val = _clear_empty_dictionaries(val)
if not new_val:
continue
else:
new_dict[key] = new_val
else:
new_dict[key] = val
return new_dict

_T = TypeVar("_T")

Expand Down