From a1044a400860a943f4d6bc1c3d65c22594926172 Mon Sep 17 00:00:00 2001 From: Benedikt Bartscher Date: Tue, 2 Jul 2024 16:02:52 +0200 Subject: [PATCH 1/2] improve BaseFile typing --- sqlalchemy_file/base.py | 52 ++++++++++++++++++++++++++++++++++++++++- 1 file changed, 51 insertions(+), 1 deletion(-) diff --git a/sqlalchemy_file/base.py b/sqlalchemy_file/base.py index ab34168..ab4d6d2 100644 --- a/sqlalchemy_file/base.py +++ b/sqlalchemy_file/base.py @@ -1,5 +1,21 @@ import typing -from typing import Any +from typing import Any, overload, Literal, override + + +STR_ATTRS = Literal[ + "url", + "filename", + "content_type", + "file_id", + "upload_storage", + "uploaded_at", + "path", + "url", +] +BOOL_ATTRS = Literal["saved"] +INT_ATTRS = Literal["size"] +DICT_ATTRS = Literal["meta_data"] +STR_LIST_ATTRS = Literal["files"] class BaseFile(typing.Dict[str, Any]): @@ -11,15 +27,47 @@ class BaseFile(typing.Dict[str, Any]): """ + @overload + def __getitem__(self, key: STR_ATTRS) -> str: ... + + @overload + def __getitem__(self, key: INT_ATTRS) -> int: ... + + @overload + def __getitem__(self, key: DICT_ATTRS) -> dict[str, str]: ... + + @overload + def __getitem__(self, key: STR_LIST_ATTRS) -> list[str]: ... + + @overload + def __getitem__(self, key: BOOL_ATTRS) -> bool: ... + + @override def __getitem__(self, key: str) -> Any: return dict.__getitem__(self, key) + @overload + def __getattr__(self, name: STR_ATTRS) -> str: ... + + @overload + def __getattr__(self, name: INT_ATTRS) -> int: ... + + @overload + def __getattr__(self, name: DICT_ATTRS) -> dict[str, str]: ... + + @overload + def __getattr__(self, name: STR_LIST_ATTRS) -> list[str]: ... + + @overload + def __getattr__(self, name: BOOL_ATTRS) -> bool: ... + def __getattr__(self, name: str) -> Any: try: return self[name] except KeyError: raise AttributeError(name) + @override def __setitem__(self, key: str, value: Any) -> None: if getattr(self, "_frozen", False): raise TypeError("Already saved files are immutable") @@ -27,6 +75,7 @@ def __setitem__(self, key: str, value: Any) -> None: __setattr__ = __setitem__ + @override def __delattr__(self, name: str) -> None: if getattr(self, "_frozen", False): raise TypeError("Already saved files are immutable") @@ -36,6 +85,7 @@ def __delattr__(self, name: str) -> None: except KeyError: raise AttributeError(name) + @override def __delitem__(self, key: str) -> None: if object.__getattribute__(self, "_frozen"): raise TypeError("Already saved files are immutable") From de02d06ebb674b13a157fed32c69a5156d9bdc47 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 2 Jul 2024 14:03:19 +0000 Subject: [PATCH 2/2] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- sqlalchemy_file/base.py | 33 +++++++++++++++++++++------------ 1 file changed, 21 insertions(+), 12 deletions(-) diff --git a/sqlalchemy_file/base.py b/sqlalchemy_file/base.py index ab4d6d2..a44172e 100644 --- a/sqlalchemy_file/base.py +++ b/sqlalchemy_file/base.py @@ -1,6 +1,5 @@ import typing -from typing import Any, overload, Literal, override - +from typing import Any, Literal, overload, override STR_ATTRS = Literal[ "url", @@ -28,38 +27,48 @@ class BaseFile(typing.Dict[str, Any]): """ @overload - def __getitem__(self, key: STR_ATTRS) -> str: ... + def __getitem__(self, key: STR_ATTRS) -> str: + ... @overload - def __getitem__(self, key: INT_ATTRS) -> int: ... + def __getitem__(self, key: INT_ATTRS) -> int: + ... @overload - def __getitem__(self, key: DICT_ATTRS) -> dict[str, str]: ... + def __getitem__(self, key: DICT_ATTRS) -> dict[str, str]: + ... @overload - def __getitem__(self, key: STR_LIST_ATTRS) -> list[str]: ... + def __getitem__(self, key: STR_LIST_ATTRS) -> list[str]: + ... @overload - def __getitem__(self, key: BOOL_ATTRS) -> bool: ... + def __getitem__(self, key: BOOL_ATTRS) -> bool: + ... @override def __getitem__(self, key: str) -> Any: return dict.__getitem__(self, key) @overload - def __getattr__(self, name: STR_ATTRS) -> str: ... + def __getattr__(self, name: STR_ATTRS) -> str: + ... @overload - def __getattr__(self, name: INT_ATTRS) -> int: ... + def __getattr__(self, name: INT_ATTRS) -> int: + ... @overload - def __getattr__(self, name: DICT_ATTRS) -> dict[str, str]: ... + def __getattr__(self, name: DICT_ATTRS) -> dict[str, str]: + ... @overload - def __getattr__(self, name: STR_LIST_ATTRS) -> list[str]: ... + def __getattr__(self, name: STR_LIST_ATTRS) -> list[str]: + ... @overload - def __getattr__(self, name: BOOL_ATTRS) -> bool: ... + def __getattr__(self, name: BOOL_ATTRS) -> bool: + ... def __getattr__(self, name: str) -> Any: try: