Skip to content

Commit b971369

Browse files
fix: compat with Python 3.14
1 parent f6b66d4 commit b971369

File tree

2 files changed

+12
-7
lines changed

2 files changed

+12
-7
lines changed

src/llama_stack_client/_models.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
import os
1010
import inspect
11+
import weakref
1112
from typing import TYPE_CHECKING, Any, Type, Union, Generic, TypeVar, Callable, Optional, cast
1213
from datetime import date, datetime
1314
from typing_extensions import (
@@ -579,6 +580,9 @@ class CachedDiscriminatorType(Protocol):
579580
__discriminator__: DiscriminatorDetails
580581

581582

583+
DISCRIMINATOR_CACHE: weakref.WeakKeyDictionary[type, DiscriminatorDetails] = weakref.WeakKeyDictionary()
584+
585+
582586
class DiscriminatorDetails:
583587
field_name: str
584588
"""The name of the discriminator field in the variant class, e.g.
@@ -621,8 +625,9 @@ def __init__(
621625

622626

623627
def _build_discriminated_union_meta(*, union: type, meta_annotations: tuple[Any, ...]) -> DiscriminatorDetails | None:
624-
if isinstance(union, CachedDiscriminatorType):
625-
return union.__discriminator__
628+
cached = DISCRIMINATOR_CACHE.get(union)
629+
if cached is not None:
630+
return cached
626631

627632
discriminator_field_name: str | None = None
628633

@@ -675,7 +680,7 @@ def _build_discriminated_union_meta(*, union: type, meta_annotations: tuple[Any,
675680
discriminator_field=discriminator_field_name,
676681
discriminator_alias=discriminator_alias,
677682
)
678-
cast(CachedDiscriminatorType, union).__discriminator__ = details
683+
DISCRIMINATOR_CACHE.setdefault(union, details)
679684
return details
680685

681686

tests/test_models.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515

1616
from llama_stack_client._utils import PropertyInfo
1717
from llama_stack_client._compat import PYDANTIC_V1, parse_obj, model_dump, model_json
18-
from llama_stack_client._models import BaseModel, construct_type
18+
from llama_stack_client._models import DISCRIMINATOR_CACHE, BaseModel, construct_type
1919

2020

2121
class BasicModel(BaseModel):
@@ -815,7 +815,7 @@ class B(BaseModel):
815815

816816
UnionType = cast(Any, Union[A, B])
817817

818-
assert not hasattr(UnionType, "__discriminator__")
818+
assert not DISCRIMINATOR_CACHE.get(UnionType)
819819

820820
m = construct_type(
821821
value={"type": "b", "data": "foo"}, type_=cast(Any, Annotated[UnionType, PropertyInfo(discriminator="type")])
@@ -824,7 +824,7 @@ class B(BaseModel):
824824
assert m.type == "b"
825825
assert m.data == "foo" # type: ignore[comparison-overlap]
826826

827-
discriminator = UnionType.__discriminator__
827+
discriminator = DISCRIMINATOR_CACHE.get(UnionType)
828828
assert discriminator is not None
829829

830830
m = construct_type(
@@ -836,7 +836,7 @@ class B(BaseModel):
836836

837837
# if the discriminator details object stays the same between invocations then
838838
# we hit the cache
839-
assert UnionType.__discriminator__ is discriminator
839+
assert DISCRIMINATOR_CACHE.get(UnionType) is discriminator
840840

841841

842842
@pytest.mark.skipif(PYDANTIC_V1, reason="TypeAliasType is not supported in Pydantic v1")

0 commit comments

Comments
 (0)