Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions skyvern/client/core/pydantic_utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from typing import Any, Callable, ClassVar, Dict, List, Mapping, Optional, Set, Tuple, Type, TypeVar, Union, cast

import pydantic
from skyvern.client.core.serialization import convert_and_respect_annotation_metadata

IS_PYDANTIC_V2 = pydantic.VERSION.startswith("2.")

Expand Down Expand Up @@ -78,10 +79,9 @@ def model_construct(cls: Type["Model"], _fields_set: Optional[Set[str]] = None,

@classmethod
def construct(cls: Type["Model"], _fields_set: Optional[Set[str]] = None, **values: Any) -> "Model":
dealiased_object = convert_and_respect_annotation_metadata(object_=values, annotation=cls, direction="read")
if IS_PYDANTIC_V2:
return super().model_construct(_fields_set, **dealiased_object) # type: ignore[misc]
return super().construct(_fields_set, **dealiased_object)
return super().model_construct(_fields_set, **values) # type: ignore[misc]
return super().construct(_fields_set, **values)

def json(self, **kwargs: Any) -> str:
kwargs_with_defaults = {
Expand Down
122 changes: 65 additions & 57 deletions skyvern/client/core/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,9 @@ def convert_and_respect_annotation_metadata(
inner_type = annotation

clean_type = _remove_annotations(inner_type)
# Pydantic models
origin = typing_extensions.get_origin(clean_type)

# Pydantic models
if (
inspect.isclass(clean_type)
Expand All @@ -70,79 +73,84 @@ def convert_and_respect_annotation_metadata(
if typing_extensions.is_typeddict(clean_type) and isinstance(object_, typing.Mapping):
return _convert_mapping(object_, clean_type, direction)

if (
typing_extensions.get_origin(clean_type) == typing.Dict
or typing_extensions.get_origin(clean_type) == dict
or clean_type == typing.Dict
) and isinstance(object_, typing.Dict):
key_type = typing_extensions.get_args(clean_type)[0]
value_type = typing_extensions.get_args(clean_type)[1]

return {
key: convert_and_respect_annotation_metadata(
object_=value,
annotation=annotation,
inner_type=value_type,
direction=direction,
)
for key, value in object_.items()
}

# If you're iterating on a string, do not bother to coerce it to a sequence.
if not isinstance(object_, str):
if (
typing_extensions.get_origin(clean_type) == typing.Set
or typing_extensions.get_origin(clean_type) == set
or clean_type == typing.Set
) and isinstance(object_, typing.Set):
inner_type = typing_extensions.get_args(clean_type)[0]
# Dicts
if origin in {dict, typing.Dict} or clean_type in {dict, typing.Dict}:
if isinstance(object_, typing.Dict):
args = typing_extensions.get_args(clean_type)
key_type = args[0] if args else typing.Any
value_type = args[1] if args else typing.Any
return {
convert_and_respect_annotation_metadata(
object_=item,
key: convert_and_respect_annotation_metadata(
object_=value,
annotation=annotation,
inner_type=inner_type,
inner_type=value_type,
direction=direction,
)
for item in object_
for key, value in object_.items()
}
elif (
(
typing_extensions.get_origin(clean_type) == typing.List
or typing_extensions.get_origin(clean_type) == list
or clean_type == typing.List
)
and isinstance(object_, typing.List)
) or (
(
typing_extensions.get_origin(clean_type) == typing.Sequence
or typing_extensions.get_origin(clean_type) == collections.abc.Sequence
or clean_type == typing.Sequence
)
and isinstance(object_, typing.Sequence)
):
inner_type = typing_extensions.get_args(clean_type)[0]
return [
convert_and_respect_annotation_metadata(
object_=item,
annotation=annotation,
inner_type=inner_type,
direction=direction,
)
for item in object_
]

if typing_extensions.get_origin(clean_type) == typing.Union:
# Sets and sequences
if not isinstance(object_, str):
# Sets
if origin in {set, typing.Set} or clean_type in {set, typing.Set}:
if isinstance(object_, typing.Set):
args = typing_extensions.get_args(clean_type)
inner_type_set = args[0] if args else typing.Any
return {
convert_and_respect_annotation_metadata(
object_=item,
annotation=annotation,
inner_type=inner_type_set,
direction=direction,
)
for item in object_
}
# Lists
elif origin in {list, typing.List} or clean_type in {list, typing.List}:
if isinstance(object_, typing.List):
args = typing_extensions.get_args(clean_type)
inner_type_list = args[0] if args else typing.Any
return [
convert_and_respect_annotation_metadata(
object_=item,
annotation=annotation,
inner_type=inner_type_list,
direction=direction,
)
for item in object_
]
# Sequences
elif origin in {collections.abc.Sequence, typing.Sequence} or clean_type in {collections.abc.Sequence, typing.Sequence}:
if isinstance(object_, typing.Sequence):
args = typing_extensions.get_args(clean_type)
inner_type_seq = args[0] if args else typing.Any
return [
convert_and_respect_annotation_metadata(
object_=item,
annotation=annotation,
inner_type=inner_type_seq,
direction=direction,
)
for item in object_
]

# Unions
if origin is typing.Union:
# Try to convert keys against all member types in the union
# We should be able to ~relatively~ safely try to convert keys against all
# member types in the union, the edge case here is if one member aliases a field
# of the same name to a different name from another member
# Or if another member aliases a field of the same name that another member does not.
for member in typing_extensions.get_args(clean_type):
object_ = convert_and_respect_annotation_metadata(
new_object = convert_and_respect_annotation_metadata(
object_=object_,
annotation=annotation,
inner_type=member,
direction=direction,
)
# Only return early if conversion occurs
if new_object != object_:
return new_object
return object_

annotated_type = _get_annotation(annotation)
Expand Down