diff --git a/.env b/.env new file mode 100644 index 00000000..9a54d9a1 --- /dev/null +++ b/.env @@ -0,0 +1,4 @@ +AUTH0_CLIENT_ID=J7Gt1oCReu3zJssOrG0osSxHFDTfMG63 +AUTH0_CLIENT_SECRET=4aC4oLZ41Oc7LIzfqYwo8suBItabyn7UspIfkBKKhz1srFkQCYUdPODTLBb_prW9 +AUTH0_DOMAIN=dev-h42ipc15kqwo8cw5.us.auth0.com +APP_SECRET_KEY=bd62c75d593e9e02fc505633ce4511d4bfaf22408a02f697ec7537cf258b9b8f \ No newline at end of file diff --git a/.venv/Lib/site-packages/Authlib-1.3.0.dist-info/INSTALLER b/.venv/Lib/site-packages/Authlib-1.3.0.dist-info/INSTALLER new file mode 100644 index 00000000..a1b589e3 --- /dev/null +++ b/.venv/Lib/site-packages/Authlib-1.3.0.dist-info/INSTALLER @@ -0,0 +1 @@ +pip diff --git a/.venv/Lib/site-packages/Authlib-1.3.0.dist-info/LICENSE b/.venv/Lib/site-packages/Authlib-1.3.0.dist-info/LICENSE new file mode 100644 index 00000000..42441994 --- /dev/null +++ b/.venv/Lib/site-packages/Authlib-1.3.0.dist-info/LICENSE @@ -0,0 +1,29 @@ +BSD 3-Clause License + +Copyright (c) 2017, Hsiaoming Yang +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +* Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + +* Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +* Neither the name of the copyright holder nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/.venv/Lib/site-packages/Authlib-1.3.0.dist-info/METADATA b/.venv/Lib/site-packages/Authlib-1.3.0.dist-info/METADATA new file mode 100644 index 00000000..b2beee6f --- /dev/null +++ b/.venv/Lib/site-packages/Authlib-1.3.0.dist-info/METADATA @@ -0,0 +1,104 @@ +Metadata-Version: 2.1 +Name: Authlib +Version: 1.3.0 +Summary: The ultimate Python library in building OAuth and OpenID Connect servers and clients. +Author-email: Hsiaoming Yang +License: BSD-3-Clause +Project-URL: Documentation, https://docs.authlib.org/ +Project-URL: Purchase, https://authlib.org/plans +Project-URL: Issues, https://github.com/lepture/authlib/issues +Project-URL: Source, https://github.com/lepture/authlib +Project-URL: Donate, https://github.com/sponsors/lepture +Project-URL: Blog, https://blog.authlib.org/ +Classifier: Development Status :: 5 - Production/Stable +Classifier: Environment :: Console +Classifier: Environment :: Web Environment +Classifier: Intended Audience :: Developers +Classifier: License :: OSI Approved :: BSD License +Classifier: Operating System :: OS Independent +Classifier: Programming Language :: Python +Classifier: Programming Language :: Python :: 3 +Classifier: Programming Language :: Python :: 3.8 +Classifier: Programming Language :: Python :: 3.9 +Classifier: Programming Language :: Python :: 3.10 +Classifier: Programming Language :: Python :: 3.11 +Classifier: Programming Language :: Python :: Implementation :: CPython +Classifier: Topic :: Security +Classifier: Topic :: Security :: Cryptography +Classifier: Topic :: Internet :: WWW/HTTP :: Dynamic Content +Classifier: Topic :: Internet :: WWW/HTTP :: WSGI :: Application +Requires-Python: >=3.8 +Description-Content-Type: text/x-rst +License-File: LICENSE +Requires-Dist: cryptography + +Authlib +======= + +The ultimate Python library in building OAuth and OpenID Connect servers. +JWS, JWK, JWA, JWT are included. + +Useful Links +------------ + +1. Homepage: https://authlib.org/ +2. Documentation: https://docs.authlib.org/ +3. Purchase Commercial License: https://authlib.org/plans +4. Blog: https://blog.authlib.org/ +5. More Repositories: https://github.com/authlib +6. Twitter: https://twitter.com/authlib +7. Donate: https://www.patreon.com/lepture + +Specifications +-------------- + +- RFC5849: The OAuth 1.0 Protocol +- RFC6749: The OAuth 2.0 Authorization Framework +- RFC6750: The OAuth 2.0 Authorization Framework: Bearer Token Usage +- RFC7009: OAuth 2.0 Token Revocation +- RFC7515: JSON Web Signature +- RFC7516: JSON Web Encryption +- RFC7517: JSON Web Key +- RFC7518: JSON Web Algorithms +- RFC7519: JSON Web Token +- RFC7521: Assertion Framework for OAuth 2.0 Client Authentication and Authorization Grants +- RFC7523: JSON Web Token (JWT) Profile for OAuth 2.0 Client Authentication and Authorization Grants +- RFC7591: OAuth 2.0 Dynamic Client Registration Protocol +- RFC7636: Proof Key for Code Exchange by OAuth Public Clients +- RFC7638: JSON Web Key (JWK) Thumbprint +- RFC7662: OAuth 2.0 Token Introspection +- RFC8037: CFRG Elliptic Curve Diffie-Hellman (ECDH) and Signatures in JSON Object Signing and Encryption (JOSE) +- RFC8414: OAuth 2.0 Authorization Server Metadata +- RFC8628: OAuth 2.0 Device Authorization Grant +- OpenID Connect 1.0 +- OpenID Connect Discovery 1.0 +- draft-madden-jose-ecdh-1pu-04: Public Key Authenticated Encryption for JOSE: ECDH-1PU + +Implementations +--------------- + +- Requests OAuth 1 Session +- Requests OAuth 2 Session +- Requests Assertion Session +- HTTPX OAuth 1 Session +- HTTPX OAuth 2 Session +- HTTPX Assertion Session +- Flask OAuth 1/2 Client +- Django OAuth 1/2 Client +- Starlette OAuth 1/2 Client +- Flask OAuth 1.0 Server +- Flask OAuth 2.0 Server +- Flask OpenID Connect 1.0 +- Django OAuth 1.0 Server +- Django OAuth 2.0 Server +- Django OpenID Connect 1.0 + +License +------- + +Authlib is licensed under BSD. Please see LICENSE for licensing details. + +If this license does not fit your company, consider to purchase a commercial +license. Find more information on `Authlib Plans`_. + +.. _`Authlib Plans`: https://authlib.org/plans diff --git a/.venv/Lib/site-packages/Authlib-1.3.0.dist-info/RECORD b/.venv/Lib/site-packages/Authlib-1.3.0.dist-info/RECORD new file mode 100644 index 00000000..a975ebaa --- /dev/null +++ b/.venv/Lib/site-packages/Authlib-1.3.0.dist-info/RECORD @@ -0,0 +1,387 @@ +Authlib-1.3.0.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4 +Authlib-1.3.0.dist-info/LICENSE,sha256=jhtIUY3pxs0Ay0jH_luAI_2Q1VUsoS6-c2Kg3zDdvkU,1514 +Authlib-1.3.0.dist-info/METADATA,sha256=_uPrSFp-N_CnFcdce_ly9u374GhUS8V56C_2tNPI5IU,3756 +Authlib-1.3.0.dist-info/RECORD,, +Authlib-1.3.0.dist-info/REQUESTED,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +Authlib-1.3.0.dist-info/WHEEL,sha256=-G_t0oGuE7UD0DrSpVZnq1hHMBV9DD2XkS5v7XpmTnk,110 +Authlib-1.3.0.dist-info/top_level.txt,sha256=Rj3mJn0jhRuCs6x7ysI6hYE2PePbuxey6y6jswadAEY,8 +authlib/__init__.py,sha256=CoObQJQX-YGSJy-HWbJPtK6XbpfKDBc21DJwjhLnIcM,476 +authlib/__pycache__/__init__.cpython-311.pyc,, +authlib/__pycache__/consts.cpython-311.pyc,, +authlib/__pycache__/deprecate.cpython-311.pyc,, +authlib/common/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +authlib/common/__pycache__/__init__.cpython-311.pyc,, +authlib/common/__pycache__/encoding.cpython-311.pyc,, +authlib/common/__pycache__/errors.cpython-311.pyc,, +authlib/common/__pycache__/security.cpython-311.pyc,, +authlib/common/__pycache__/urls.cpython-311.pyc,, +authlib/common/encoding.py,sha256=sdiaZwuXZI-ruNPGAhJ0oIuZTcrzdawneS_PoJAnHYk,1546 +authlib/common/errors.py,sha256=z8kGl0qRBnimrMYqVgi1aqLsqSng8YaMtcqCy6MHff8,1684 +authlib/common/security.py,sha256=2xcxtJWVE26kosNJTWtnN3skeSzm3Jjtpm4wxoTCBYs,493 +authlib/common/urls.py,sha256=gUpc_VB9emhmCE0EunlxDiQHHZhegGYdLVPT-qoEkco,4501 +authlib/consts.py,sha256=WZWJbuAh8iBsbBm-KFWNYuT3bVDUtzMpNKydkisu3qw,300 +authlib/deprecate.py,sha256=dIjr5VmDMK3bua0cOzJh0Q2RAlAtMhW6iM6ENIynIQ8,481 +authlib/integrations/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +authlib/integrations/__pycache__/__init__.cpython-311.pyc,, +authlib/integrations/base_client/__init__.py,sha256=xCaZt-rH5n4g0tBVgrM4KKMzWUQ6NHUcLIHwlIuKqMM,653 +authlib/integrations/base_client/__pycache__/__init__.cpython-311.pyc,, +authlib/integrations/base_client/__pycache__/async_app.cpython-311.pyc,, +authlib/integrations/base_client/__pycache__/async_openid.cpython-311.pyc,, +authlib/integrations/base_client/__pycache__/errors.cpython-311.pyc,, +authlib/integrations/base_client/__pycache__/framework_integration.cpython-311.pyc,, +authlib/integrations/base_client/__pycache__/registry.cpython-311.pyc,, +authlib/integrations/base_client/__pycache__/sync_app.cpython-311.pyc,, +authlib/integrations/base_client/__pycache__/sync_openid.cpython-311.pyc,, +authlib/integrations/base_client/async_app.py,sha256=3MbucTGkyEBz8W7SIvJSwYxdGR_8wquxNKee8BHt4i8,5847 +authlib/integrations/base_client/async_openid.py,sha256=-OZl3g_8EYJNlLxCukcPAvOPd-o-OWP9b-updMs6Z-c,2803 +authlib/integrations/base_client/errors.py,sha256=fwXW7ldF-TeCIHeANGWYqv5hhaFpXzLLsGRUBwgcy4c,632 +authlib/integrations/base_client/framework_integration.py,sha256=12rBh8a-cj2r0mJkFKhREH49gv2D2v5aF1UZSUAD430,1871 +authlib/integrations/base_client/registry.py,sha256=FAjZBN0n_e-3MbC2ZYCUSUxeHCrKSdoAE21taWxSAeA,4273 +authlib/integrations/base_client/sync_app.py,sha256=eZajg-ruTMgQAIeaANMqOlYvzA8qqyzH1S7vvL3jx6w,12408 +authlib/integrations/base_client/sync_openid.py,sha256=zZZ9vp0M_w8xpZU3tRahN715Q26tfwC-JCV5WtbMUcM,2721 +authlib/integrations/django_client/__init__.py,sha256=ff_Kol2-pT-7E0zav2A6tkgqMT1YYLpmVuNaKh-Fz5g,458 +authlib/integrations/django_client/__pycache__/__init__.cpython-311.pyc,, +authlib/integrations/django_client/__pycache__/apps.cpython-311.pyc,, +authlib/integrations/django_client/__pycache__/integration.cpython-311.pyc,, +authlib/integrations/django_client/apps.py,sha256=AR7-2aa1xMJj6UX_dl8PreAVw_H_rxABPV9paqioOuw,3334 +authlib/integrations/django_client/integration.py,sha256=TfgtYs2X_IwismSdX8TI73EcpoYCfjS9OovOdEP8XLA,650 +authlib/integrations/django_oauth1/__init__.py,sha256=yp66WLC43YdsICGgVDbu6AfIRyPR17M43umIUcxjH10,221 +authlib/integrations/django_oauth1/__pycache__/__init__.cpython-311.pyc,, +authlib/integrations/django_oauth1/__pycache__/authorization_server.cpython-311.pyc,, +authlib/integrations/django_oauth1/__pycache__/nonce.cpython-311.pyc,, +authlib/integrations/django_oauth1/__pycache__/resource_protector.cpython-311.pyc,, +authlib/integrations/django_oauth1/authorization_server.py,sha256=UKkJuE4nwTPNrC8emL5ENqxwP1t3-Qu7kZn6c59zdng,4536 +authlib/integrations/django_oauth1/nonce.py,sha256=m6j4FWsSeQ1S-LJEgF4BF0TPGRZQuseOumkrbuF6KhY,396 +authlib/integrations/django_oauth1/resource_protector.py,sha256=TL1kRvuuF91kZBBe4FZzo7nY0iOKpGWIdjKvdlOcXAY,2343 +authlib/integrations/django_oauth2/__init__.py,sha256=HGqxRud5D9EGZIthXmySYhB7d-90qzZlXZDrd4mPHnQ,278 +authlib/integrations/django_oauth2/__pycache__/__init__.cpython-311.pyc,, +authlib/integrations/django_oauth2/__pycache__/authorization_server.cpython-311.pyc,, +authlib/integrations/django_oauth2/__pycache__/endpoints.cpython-311.pyc,, +authlib/integrations/django_oauth2/__pycache__/requests.cpython-311.pyc,, +authlib/integrations/django_oauth2/__pycache__/resource_protector.cpython-311.pyc,, +authlib/integrations/django_oauth2/__pycache__/signals.cpython-311.pyc,, +authlib/integrations/django_oauth2/authorization_server.py,sha256=rtlQlDyeNME3Hdv4Tnu5tYBZ2zJ9GUn-kHEvNCF5Jh8,4388 +authlib/integrations/django_oauth2/endpoints.py,sha256=lKkDmQklHNTCXK_L-6-_PrHa6XvnFBT7a2BAciYkv7o,1853 +authlib/integrations/django_oauth2/requests.py,sha256=_KpI8ecABAZ026y3OcjyNfYNK1MyiDhVH5cFIO5rrSs,1023 +authlib/integrations/django_oauth2/resource_protector.py,sha256=O3Snq0LO4divy6sZwDjh4UoMoDVn3qhLvGILaqNUWnM,2595 +authlib/integrations/django_oauth2/signals.py,sha256=8SlnOsi1IuBPmrCi7dOLXK70N_m9y6B3msjMCtBMnSk,236 +authlib/integrations/flask_client/__init__.py,sha256=DCSIvVck7aBh-zDiGILsYJp_pJSl1qHzR3p-pDGfhNk,1677 +authlib/integrations/flask_client/__pycache__/__init__.cpython-311.pyc,, +authlib/integrations/flask_client/__pycache__/apps.cpython-311.pyc,, +authlib/integrations/flask_client/__pycache__/integration.cpython-311.pyc,, +authlib/integrations/flask_client/apps.py,sha256=4LWY81JpXgg8C5TliVOsObegPqgzZdt9Jyr9rJx2x0w,3607 +authlib/integrations/flask_client/integration.py,sha256=T_O0-YZbezOvvyJUJZKhW_lyHTbSkEZvdQkAGwHLyjY,805 +authlib/integrations/flask_oauth1/__init__.py,sha256=PGDVdNJ9oGs5bYJr7oGdpQX0tE2nQJNcHJ8t-SxEAEY,260 +authlib/integrations/flask_oauth1/__pycache__/__init__.cpython-311.pyc,, +authlib/integrations/flask_oauth1/__pycache__/authorization_server.cpython-311.pyc,, +authlib/integrations/flask_oauth1/__pycache__/cache.cpython-311.pyc,, +authlib/integrations/flask_oauth1/__pycache__/resource_protector.cpython-311.pyc,, +authlib/integrations/flask_oauth1/authorization_server.py,sha256=qa0C73Qu7lEOM24-dYQ6wOqmXoHiSdlwxcmtqTMYNug,6299 +authlib/integrations/flask_oauth1/cache.py,sha256=x1bOuGhHKrCUO0X2XrR80f898ca8sM2edQfT6nuDFwY,2996 +authlib/integrations/flask_oauth1/resource_protector.py,sha256=feMGEpwRI5DfgJjJhyumI3XdoaAEjhZHzId_-sFDkBE,3805 +authlib/integrations/flask_oauth2/__init__.py,sha256=8xa6R7Otk9DSHzKFcyGeDv-H7OLUbFwtF73ybpe9jNY,243 +authlib/integrations/flask_oauth2/__pycache__/__init__.cpython-311.pyc,, +authlib/integrations/flask_oauth2/__pycache__/authorization_server.cpython-311.pyc,, +authlib/integrations/flask_oauth2/__pycache__/errors.cpython-311.pyc,, +authlib/integrations/flask_oauth2/__pycache__/requests.cpython-311.pyc,, +authlib/integrations/flask_oauth2/__pycache__/resource_protector.cpython-311.pyc,, +authlib/integrations/flask_oauth2/__pycache__/signals.cpython-311.pyc,, +authlib/integrations/flask_oauth2/authorization_server.py,sha256=Lz4b_77aGIJ7Rmh3LxWQzWZeyWhtKpFpjBzy18JRibI,5859 +authlib/integrations/flask_oauth2/errors.py,sha256=d4YT-I_atPUQjV8ohrHCossYpTvErOscXwBhCyLzidA,1090 +authlib/integrations/flask_oauth2/requests.py,sha256=Z78A_rCTmmJW6FJ-PfWQTe5sdYHW3hYZg2s4SYgDkIk,765 +authlib/integrations/flask_oauth2/resource_protector.py,sha256=Kpgw5UAy1dPXv3T2f5o2KbPXO6naRmsc25KL3VuMkC4,3843 +authlib/integrations/flask_oauth2/signals.py,sha256=CKao8F778CkUzl7mKjF96smr7WKJ0nfxBT0onVFq10Y,341 +authlib/integrations/httpx_client/__init__.py,sha256=zuO_FIAdLEJ9Ch25kTa-Nsyi2gqkPuAsXdHAeYc0nBI,804 +authlib/integrations/httpx_client/__pycache__/__init__.cpython-311.pyc,, +authlib/integrations/httpx_client/__pycache__/assertion_client.cpython-311.pyc,, +authlib/integrations/httpx_client/__pycache__/oauth1_client.cpython-311.pyc,, +authlib/integrations/httpx_client/__pycache__/oauth2_client.cpython-311.pyc,, +authlib/integrations/httpx_client/__pycache__/utils.cpython-311.pyc,, +authlib/integrations/httpx_client/assertion_client.py,sha256=1T2ZEM3cboubUu2Zme-6KKTY1Y06BCIQMhHSQ_LC6ZY,3179 +authlib/integrations/httpx_client/oauth1_client.py,sha256=rkb-wwh3Oqryj3QZop9bo3q0td2OjFptf6rDX5VUUBY,4084 +authlib/integrations/httpx_client/oauth2_client.py,sha256=d7VLXPo_6Gboko6jVn9ooFRVFJWNKm9bLHKcsP_lzlY,8439 +authlib/integrations/httpx_client/utils.py,sha256=XF8d4xz4d7tDoaaq3LQ3GZonED1_EtW8zH_FoTcUPL4,888 +authlib/integrations/requests_client/__init__.py,sha256=Nco-Q1_wOswQ9qCgBgh0rVWPYO9ddEKnlnnJ9rGZE_U,652 +authlib/integrations/requests_client/__pycache__/__init__.cpython-311.pyc,, +authlib/integrations/requests_client/__pycache__/assertion_session.cpython-311.pyc,, +authlib/integrations/requests_client/__pycache__/oauth1_session.cpython-311.pyc,, +authlib/integrations/requests_client/__pycache__/oauth2_session.cpython-311.pyc,, +authlib/integrations/requests_client/__pycache__/utils.cpython-311.pyc,, +authlib/integrations/requests_client/assertion_session.py,sha256=8TAEs9s_SuB1-LymwOgfkxUOgEWCWX0Lp7eagbw3xZo,1832 +authlib/integrations/requests_client/oauth1_session.py,sha256=aXC3vFLy4ytV8G3yL-on6dJitNfdKmkttcItTHTejfU,2209 +authlib/integrations/requests_client/oauth2_session.py,sha256=4USZnWGgw8v-67Os0y17UeNi0n2JjONvRoFvtbMYAv4,4498 +authlib/integrations/requests_client/utils.py,sha256=4ohGF-9JUR9Ayw63glEednsZSbscrSD4Rx0BPkPrSeA,274 +authlib/integrations/sqla_oauth2/__init__.py,sha256=d8g3ipiiPtyk1JnmTDzeAKCdm-ozjNoERDWLEkpBgL8,548 +authlib/integrations/sqla_oauth2/__pycache__/__init__.cpython-311.pyc,, +authlib/integrations/sqla_oauth2/__pycache__/client_mixin.cpython-311.pyc,, +authlib/integrations/sqla_oauth2/__pycache__/functions.cpython-311.pyc,, +authlib/integrations/sqla_oauth2/__pycache__/tokens_mixins.cpython-311.pyc,, +authlib/integrations/sqla_oauth2/client_mixin.py,sha256=qxoJLpnH3ncdvQPYnk8HoqllOoiUZAl_md3RteELQpQ,4116 +authlib/integrations/sqla_oauth2/functions.py,sha256=0smHFpiQ6srSLR-ynLCmbFIq7iPvbBjoQUDibtD_4e8,3195 +authlib/integrations/sqla_oauth2/tokens_mixins.py,sha256=9cBNIUxKLsaP8DTkOKlkZQ1m5guxeOnpJ9t-K8Kuy0c,2008 +authlib/integrations/starlette_client/__init__.py,sha256=3bOTtevT4LmMurKF6dz0HusPbiVrPjBLVRuf6P0wHXk,666 +authlib/integrations/starlette_client/__pycache__/__init__.cpython-311.pyc,, +authlib/integrations/starlette_client/__pycache__/apps.cpython-311.pyc,, +authlib/integrations/starlette_client/__pycache__/integration.cpython-311.pyc,, +authlib/integrations/starlette_client/apps.py,sha256=BHO3DRhLv8x14DAF6CCJKaTVI6Lr9WKh4vIm4eTJPjY,3556 +authlib/integrations/starlette_client/integration.py,sha256=b2ZUKcX_L9P3DA7tdeSP3mRCmiPfzwUwVzp9cepsBA0,1964 +authlib/jose/__init__.py,sha256=qpIwbdODthy50cit_FeHECvwRGd7pIABiYEA4aiFV48,1399 +authlib/jose/__pycache__/__init__.cpython-311.pyc,, +authlib/jose/__pycache__/errors.cpython-311.pyc,, +authlib/jose/__pycache__/jwk.cpython-311.pyc,, +authlib/jose/__pycache__/util.cpython-311.pyc,, +authlib/jose/drafts/__init__.py,sha256=A--_H-kGg_s_Yf4m_P3Whb7HRQbKynDGtQYZlnOtkFc,518 +authlib/jose/drafts/__pycache__/__init__.cpython-311.pyc,, +authlib/jose/drafts/__pycache__/_jwe_algorithms.cpython-311.pyc,, +authlib/jose/drafts/__pycache__/_jwe_enc_cryptodome.cpython-311.pyc,, +authlib/jose/drafts/__pycache__/_jwe_enc_cryptography.cpython-311.pyc,, +authlib/jose/drafts/_jwe_algorithms.py,sha256=zbDaWazb4HdYtQmstf0MofDDaMexVwqkmnfIPfc6luY,6917 +authlib/jose/drafts/_jwe_enc_cryptodome.py,sha256=a4Vb0AUZWZWlCoTKqOfBaLZ-o1D1bgS0hKuBKJT7Kqo,1860 +authlib/jose/drafts/_jwe_enc_cryptography.py,sha256=N2Bm9zp7MMSe_kbPwjyUAyD3MMOTj5BmpEUc1k_IqUw,1743 +authlib/jose/errors.py,sha256=j-vg5TV7uiIS-xBgq-0pwxuerCDV31n5VKerCsddJR4,2977 +authlib/jose/jwk.py,sha256=VZiMATxt4UPyU6spf7wKE0rwql4udvl3jZqW85Endwo,490 +authlib/jose/rfc7515/__init__.py,sha256=0NhWGkry69LiJH6cAkwIUNmQUpuGh463TNwIpoQjtE4,360 +authlib/jose/rfc7515/__pycache__/__init__.cpython-311.pyc,, +authlib/jose/rfc7515/__pycache__/jws.cpython-311.pyc,, +authlib/jose/rfc7515/__pycache__/models.cpython-311.pyc,, +authlib/jose/rfc7515/jws.py,sha256=0JEEGLkI6vsERYdViRGRrZBQD9hJUS9j8OPoOxD0-2c,11270 +authlib/jose/rfc7515/models.py,sha256=B3HygZbTamiXc-rhnncMwCHL1LlblU69TCkyV2gezTo,2445 +authlib/jose/rfc7516/__init__.py,sha256=SCAxvSIWD0NF2_Gcq1BrIkKx3Bqr_6WYhFpDFt3AQ6A,465 +authlib/jose/rfc7516/__pycache__/__init__.cpython-311.pyc,, +authlib/jose/rfc7516/__pycache__/jwe.cpython-311.pyc,, +authlib/jose/rfc7516/__pycache__/models.cpython-311.pyc,, +authlib/jose/rfc7516/jwe.py,sha256=j26ji9QSRTuWbH3E3vz10YtDzIoteqZ5bY5VDA7Vp7U,29706 +authlib/jose/rfc7516/models.py,sha256=RdtLB8_KzxZ35DcgkXFCy25rkXEFp9tg22m5oqeBv-A,4341 +authlib/jose/rfc7517/__init__.py,sha256=LfowmYyTdC0t5wB6ptJs45Qkj-6nRmHmhRONAh-UFxY,424 +authlib/jose/rfc7517/__pycache__/__init__.cpython-311.pyc,, +authlib/jose/rfc7517/__pycache__/_cryptography_key.cpython-311.pyc,, +authlib/jose/rfc7517/__pycache__/asymmetric_key.cpython-311.pyc,, +authlib/jose/rfc7517/__pycache__/base_key.cpython-311.pyc,, +authlib/jose/rfc7517/__pycache__/jwk.cpython-311.pyc,, +authlib/jose/rfc7517/__pycache__/key_set.cpython-311.pyc,, +authlib/jose/rfc7517/_cryptography_key.py,sha256=1-EQ1YD7ZR7Gp7FBntrWEN-76Su_vlunuzC77VPJ_sg,1257 +authlib/jose/rfc7517/asymmetric_key.py,sha256=cP5ka_7Ez5gUvVsb54WqVsuuSLDxXGwMsqGqXAvTQlw,6229 +authlib/jose/rfc7517/base_key.py,sha256=udNnaEw7_pHCfUPNc2SyfgVwZDm32DWH7p3J3Ssgk80,3261 +authlib/jose/rfc7517/jwk.py,sha256=xUPIgYiAuyo7kDHFuVfaKeyY3MY9tJURfCIMZMo70o0,2042 +authlib/jose/rfc7517/key_set.py,sha256=CAStzaaCCq5IcPwc_6I3lqYWJdhw1E_JuMqouAF0IW8,883 +authlib/jose/rfc7518/__init__.py,sha256=69w8-62wS4DJNaCCHBUYiueiORtRl6Mlek36CNoTKOQ,879 +authlib/jose/rfc7518/__pycache__/__init__.cpython-311.pyc,, +authlib/jose/rfc7518/__pycache__/ec_key.cpython-311.pyc,, +authlib/jose/rfc7518/__pycache__/jwe_algs.cpython-311.pyc,, +authlib/jose/rfc7518/__pycache__/jwe_encs.cpython-311.pyc,, +authlib/jose/rfc7518/__pycache__/jwe_zips.cpython-311.pyc,, +authlib/jose/rfc7518/__pycache__/jws_algs.cpython-311.pyc,, +authlib/jose/rfc7518/__pycache__/oct_key.cpython-311.pyc,, +authlib/jose/rfc7518/__pycache__/rsa_key.cpython-311.pyc,, +authlib/jose/rfc7518/__pycache__/util.cpython-311.pyc,, +authlib/jose/rfc7518/ec_key.py,sha256=_cot3xwrpuER_tj8eiHkx-L2EJHnYSy4iWieD0RvYKo,3511 +authlib/jose/rfc7518/jwe_algs.py,sha256=yOgfh7WYo7MCHENw4EiucP9ep5nbCd1-zKmyUSbn6Ko,11311 +authlib/jose/rfc7518/jwe_encs.py,sha256=0Mjl9wD3VYpc8ApgjWF_T3n7m2S-qtNH-cnyI-P8pAw,5047 +authlib/jose/rfc7518/jwe_zips.py,sha256=9Qron3QW-1GHqoZHGa6EjLF1VAMMlB2gvHlr2anWqtY,561 +authlib/jose/rfc7518/jws_algs.py,sha256=dfWu1QwmMtlJDB4J1JFnmKxYbCzXNmicn8djnKKNqMM,6493 +authlib/jose/rfc7518/oct_key.py,sha256=gegg3PLgdDgC9J87h4oJ5hNf06j_j-nf4C6cTI-N6yU,2375 +authlib/jose/rfc7518/rsa_key.py,sha256=hxCJAs-Ljl_RRv6D5ezAhQCpjwYhp2OU3zXtCqfO_6k,4192 +authlib/jose/rfc7518/util.py,sha256=LpOgX10QHuqss6x015QPgApTlnYJ_cQyd0hzeYALnaE,265 +authlib/jose/rfc7519/__init__.py,sha256=vJKdsUGkdKlGBKctdh5CLM2jc20903rIfAlsHjbr-hA,309 +authlib/jose/rfc7519/__pycache__/__init__.cpython-311.pyc,, +authlib/jose/rfc7519/__pycache__/claims.cpython-311.pyc,, +authlib/jose/rfc7519/__pycache__/jwt.cpython-311.pyc,, +authlib/jose/rfc7519/claims.py,sha256=CCm4EJiv_BKxaLFhc-34sH5cIQ6Y960I1yRaB3d9_vM,8709 +authlib/jose/rfc7519/jwt.py,sha256=myFXsrFfQLV6xxsOAVv--tQF4FuwhLtvCajZtRcIKSs,5950 +authlib/jose/rfc8037/__init__.py,sha256=MV25hs0RY6HU0pE5UKB2hmRI0Avv-V8BaPtNxIshh0A,119 +authlib/jose/rfc8037/__pycache__/__init__.cpython-311.pyc,, +authlib/jose/rfc8037/__pycache__/jws_eddsa.cpython-311.pyc,, +authlib/jose/rfc8037/__pycache__/okp_key.cpython-311.pyc,, +authlib/jose/rfc8037/jws_eddsa.py,sha256=dpg6ZKOdpqGcp582mP5yqsRLlnC81y_Mnl3xL0_tgLQ,716 +authlib/jose/rfc8037/okp_key.py,sha256=SlpI-u0J0l1K0J5rh1L9g7yiUqEn6YTpmoJc7tpRmhw,3560 +authlib/jose/util.py,sha256=eZGduiUbhgqYYxzjT-b1OaKZFTTICZE16Sq1UpyBK48,1065 +authlib/oauth1/__init__.py,sha256=8c5_O3G8lWOUqd3NgWR8CGpCnKEubMA4jeVNHocRQvQ,735 +authlib/oauth1/__pycache__/__init__.cpython-311.pyc,, +authlib/oauth1/__pycache__/client.cpython-311.pyc,, +authlib/oauth1/__pycache__/errors.cpython-311.pyc,, +authlib/oauth1/client.py,sha256=Vyo3cfAzU1f1WavJ-WosZy3B_eV0RShW_19sx1kc5I8,6524 +authlib/oauth1/errors.py,sha256=pg0NaUgENjfTN_ba50_yQB9aSNe5Mte5MDlikFuypBY,46 +authlib/oauth1/rfc5849/__init__.py,sha256=S4rdtQiDd83QsR1hBqdsvuI-VgmgoG7rFuJH2fLP_K4,1036 +authlib/oauth1/rfc5849/__pycache__/__init__.cpython-311.pyc,, +authlib/oauth1/rfc5849/__pycache__/authorization_server.cpython-311.pyc,, +authlib/oauth1/rfc5849/__pycache__/base_server.cpython-311.pyc,, +authlib/oauth1/rfc5849/__pycache__/client_auth.cpython-311.pyc,, +authlib/oauth1/rfc5849/__pycache__/errors.cpython-311.pyc,, +authlib/oauth1/rfc5849/__pycache__/models.cpython-311.pyc,, +authlib/oauth1/rfc5849/__pycache__/parameters.cpython-311.pyc,, +authlib/oauth1/rfc5849/__pycache__/resource_protector.cpython-311.pyc,, +authlib/oauth1/rfc5849/__pycache__/rsa.cpython-311.pyc,, +authlib/oauth1/rfc5849/__pycache__/signature.cpython-311.pyc,, +authlib/oauth1/rfc5849/__pycache__/util.cpython-311.pyc,, +authlib/oauth1/rfc5849/__pycache__/wrapper.cpython-311.pyc,, +authlib/oauth1/rfc5849/authorization_server.py,sha256=nHnxaNgOfTZ0FwnFXe0--A81-rgFTBoT72OSDmjzZl8,13869 +authlib/oauth1/rfc5849/base_server.py,sha256=-9Il95qaOhoFEQYJPkSMqQailAfQRYBtbdt1b0dZagc,3849 +authlib/oauth1/rfc5849/client_auth.py,sha256=SnCp8R50CAJxyHMtkuYN5ZvbHJcLfnldctVaZtz9hnY,6920 +authlib/oauth1/rfc5849/errors.py,sha256=z7LM1IMKzzTb3vnjRziRJk1tH0dQZm1axHnw8DXKXtA,2303 +authlib/oauth1/rfc5849/models.py,sha256=qdNruenkQVjrOjRkPku4EtU7ak3giIpxwirrRJd8Hi0,3418 +authlib/oauth1/rfc5849/parameters.py,sha256=bHgF_EwyJqWi1rvqYsbDKWf9I7z5RJh1OUpNvARHYEg,3455 +authlib/oauth1/rfc5849/resource_protector.py,sha256=v4EsTBnqYAQMarwukpYmgiyT3J9khnp8Apj6n_RRabI,1258 +authlib/oauth1/rfc5849/rsa.py,sha256=z2cx8e-p2pdvzx_WQm9mcMzkGfsBJniWoy6SLn6XZa4,896 +authlib/oauth1/rfc5849/signature.py,sha256=z1xPJmbA19zXJsfOdHSUOdIjAeMYUauSJm3Q1COvnpY,14123 +authlib/oauth1/rfc5849/util.py,sha256=89kP-xwQHop8SaBjCEkppxzYO7GQFbmi3ekVyI-zFvE,136 +authlib/oauth1/rfc5849/wrapper.py,sha256=FCyqlpSnCmNT5U0H_cM2yqkdTALoRLFnOu409xUBOQ0,3945 +authlib/oauth2/__init__.py,sha256=SlhZAaE8Tudl1W5KE57rfitY-9lLSvHJa02blVg9kJo,423 +authlib/oauth2/__pycache__/__init__.cpython-311.pyc,, +authlib/oauth2/__pycache__/auth.cpython-311.pyc,, +authlib/oauth2/__pycache__/base.cpython-311.pyc,, +authlib/oauth2/__pycache__/client.cpython-311.pyc,, +authlib/oauth2/auth.py,sha256=n-QNzkXpy6uh6vDwpqDMhvMiBjE0n2Fv7PX8pVO6R6Q,3484 +authlib/oauth2/base.py,sha256=cadG08-t_9mhgwzdo73bLeAa1Qvbw3qj4j-g1tbl6uo,958 +authlib/oauth2/client.py,sha256=2betIXTZq9WZGgrW8x5j87Z5lPZDuNUtKh3wEuO3z0w,17626 +authlib/oauth2/rfc6749/__init__.py,sha256=2WVrM30q6NM-DbUjEX2wUwqJwrR0nBr-9HFvjP80b5A,2323 +authlib/oauth2/rfc6749/__pycache__/__init__.cpython-311.pyc,, +authlib/oauth2/rfc6749/__pycache__/authenticate_client.cpython-311.pyc,, +authlib/oauth2/rfc6749/__pycache__/authorization_server.cpython-311.pyc,, +authlib/oauth2/rfc6749/__pycache__/errors.cpython-311.pyc,, +authlib/oauth2/rfc6749/__pycache__/models.cpython-311.pyc,, +authlib/oauth2/rfc6749/__pycache__/parameters.cpython-311.pyc,, +authlib/oauth2/rfc6749/__pycache__/requests.cpython-311.pyc,, +authlib/oauth2/rfc6749/__pycache__/resource_protector.cpython-311.pyc,, +authlib/oauth2/rfc6749/__pycache__/token_endpoint.cpython-311.pyc,, +authlib/oauth2/rfc6749/__pycache__/util.cpython-311.pyc,, +authlib/oauth2/rfc6749/__pycache__/wrappers.cpython-311.pyc,, +authlib/oauth2/rfc6749/authenticate_client.py,sha256=8ZmUuxuodBQeRe2Z4vMXi_OdPjTJbbnSnY7V8ODNNZ4,3748 +authlib/oauth2/rfc6749/authorization_server.py,sha256=u4D4_hN9G8Zw3q_qKIUhOoeoQVrE5E3YGpaHyWZhEAk,11898 +authlib/oauth2/rfc6749/errors.py,sha256=J93QcmMsS_eBgsBa_P4wepkAw87C7ZDW8-JGvc3X6c8,7371 +authlib/oauth2/rfc6749/grants/__init__.py,sha256=jJZOtFgyBztuIGQxcuK1qtvnR3hWRDU-G59RuyWmqKI,1314 +authlib/oauth2/rfc6749/grants/__pycache__/__init__.cpython-311.pyc,, +authlib/oauth2/rfc6749/grants/__pycache__/authorization_code.cpython-311.pyc,, +authlib/oauth2/rfc6749/grants/__pycache__/base.cpython-311.pyc,, +authlib/oauth2/rfc6749/grants/__pycache__/client_credentials.cpython-311.pyc,, +authlib/oauth2/rfc6749/grants/__pycache__/implicit.cpython-311.pyc,, +authlib/oauth2/rfc6749/grants/__pycache__/refresh_token.cpython-311.pyc,, +authlib/oauth2/rfc6749/grants/__pycache__/resource_owner_password_credentials.cpython-311.pyc,, +authlib/oauth2/rfc6749/grants/authorization_code.py,sha256=ktdVV-2_Mp5tl7TTSj2Jxx4xpV53Mmouzb0APZmmaS8,15367 +authlib/oauth2/rfc6749/grants/base.py,sha256=IQpjcs1CtiUOkJIkGb1T-oWIlKC8NfHb4toGwNxGXsg,5132 +authlib/oauth2/rfc6749/grants/client_credentials.py,sha256=OeG_V5aR-oHxkpj5ejOo9pqfa7wu1rKRCzUZ6F9ukJk,3892 +authlib/oauth2/rfc6749/grants/implicit.py,sha256=qNBYKp8GG3sNLZSfC1VL6xRiloTXDl3c1qPXsGf3XF8,9297 +authlib/oauth2/rfc6749/grants/refresh_token.py,sha256=fBAcFH35PhdW15rkQwpK_tH88JdOCRH7XFX8vcfLRZY,6432 +authlib/oauth2/rfc6749/grants/resource_owner_password_credentials.py,sha256=bmOAvYmjEhOlV77xsKQy7JndZNspfaeYkQXUGVao0Ak,5755 +authlib/oauth2/rfc6749/models.py,sha256=lpqMhIwZbnEJazleH4cM3G3jQh14Zs5Sx_csuPy1Z64,7502 +authlib/oauth2/rfc6749/parameters.py,sha256=qjvtW7-APiQ91MVm4cu0LzHNTNFC_3HLZ44uxhlZzog,8308 +authlib/oauth2/rfc6749/requests.py,sha256=LWUOIkySStPI0J_nDm2gHMa4QdUa7Vmeqwre1XIYzgs,2148 +authlib/oauth2/rfc6749/resource_protector.py,sha256=G62crN3HobWrU9DnSNm3iEKXls2cutvpkSEdXfqP_Pg,5309 +authlib/oauth2/rfc6749/token_endpoint.py,sha256=IUx79XxDS12s3QHQukzCwqXS3s1K9l3OrkvlbApnLS0,1103 +authlib/oauth2/rfc6749/util.py,sha256=xzebTUJciyJ9qy1jrYqiXfNBq6GfqCkb9Ok0XMoPWFc,1122 +authlib/oauth2/rfc6749/wrappers.py,sha256=USYF6s4Enpgk883Ye3-vg7_0-2iX7fgeW0-ZRG2kDQE,688 +authlib/oauth2/rfc6750/__init__.py,sha256=NmL2KnczR-sddyzk3lZAnFT1Aj5nJ21B77W05gMLoVU,635 +authlib/oauth2/rfc6750/__pycache__/__init__.cpython-311.pyc,, +authlib/oauth2/rfc6750/__pycache__/errors.cpython-311.pyc,, +authlib/oauth2/rfc6750/__pycache__/parameters.cpython-311.pyc,, +authlib/oauth2/rfc6750/__pycache__/token.cpython-311.pyc,, +authlib/oauth2/rfc6750/__pycache__/validator.cpython-311.pyc,, +authlib/oauth2/rfc6750/errors.py,sha256=8jmPrtinFNL3qkxuTlKWscsWmVjdUhYPNWYOZXes90g,2827 +authlib/oauth2/rfc6750/parameters.py,sha256=-9xyt87e5uHFcjvibmYZN-CIEvxy105Hzt4VaHJ6ICI,1204 +authlib/oauth2/rfc6750/token.py,sha256=5ciVR8G3AR0xqCDPTHP858UViaE_qhA5rKkjk0vcA0E,3350 +authlib/oauth2/rfc6750/validator.py,sha256=Nb1Tg_uz5ZBWRZqzbS0Ag3_zq0GlznA0YkTGSP7TBc8,1377 +authlib/oauth2/rfc7009/__init__.py,sha256=kXO-Miq9N7fFStVJp6uqRgbleitHc_DJII5QyDCvQ10,353 +authlib/oauth2/rfc7009/__pycache__/__init__.cpython-311.pyc,, +authlib/oauth2/rfc7009/__pycache__/parameters.cpython-311.pyc,, +authlib/oauth2/rfc7009/__pycache__/revocation.cpython-311.pyc,, +authlib/oauth2/rfc7009/parameters.py,sha256=klTaHudte5Oncfw6M5Mr6LL9Y-cVnN29Th1ix-uVYAU,854 +authlib/oauth2/rfc7009/revocation.py,sha256=VTA6JLVKlyedN6SUlHO-_v0J0MGU8r7jVxPLgnFpy40,4061 +authlib/oauth2/rfc7521/__init__.py,sha256=sI-EfGOAZTLM-LxfBfYEXU_74cd3ib63g8fkGu39PJA,67 +authlib/oauth2/rfc7521/__pycache__/__init__.cpython-311.pyc,, +authlib/oauth2/rfc7521/__pycache__/client.cpython-311.pyc,, +authlib/oauth2/rfc7521/client.py,sha256=MAqHEfl6f11AmZYxwT2GNqV1xA64hUO4-ksxujSR8dA,2683 +authlib/oauth2/rfc7523/__init__.py,sha256=0vllgNHhuHqWawuXy7vKtBT7wjHrn3YRA6ftmAkKOkA,852 +authlib/oauth2/rfc7523/__pycache__/__init__.cpython-311.pyc,, +authlib/oauth2/rfc7523/__pycache__/assertion.cpython-311.pyc,, +authlib/oauth2/rfc7523/__pycache__/auth.cpython-311.pyc,, +authlib/oauth2/rfc7523/__pycache__/client.cpython-311.pyc,, +authlib/oauth2/rfc7523/__pycache__/jwt_bearer.cpython-311.pyc,, +authlib/oauth2/rfc7523/__pycache__/token.cpython-311.pyc,, +authlib/oauth2/rfc7523/__pycache__/validator.cpython-311.pyc,, +authlib/oauth2/rfc7523/assertion.py,sha256=ZFklvOm9ulLdUCoWn245igaK4LjXsFIJn5GcWJRl7iU,2024 +authlib/oauth2/rfc7523/auth.py,sha256=owAbFXkftxI79uPjchtfsFmvhUAquyA02McmaovVrNw,3346 +authlib/oauth2/rfc7523/client.py,sha256=dfleCVFY3VXJk6OKXVZBXKSVUY4817jqtD6qWA6nO9I,4388 +authlib/oauth2/rfc7523/jwt_bearer.py,sha256=e-G12oauDZiz8QTWrHXIXGuLcOpXgqZTrw5y3jjlGoQ,6532 +authlib/oauth2/rfc7523/token.py,sha256=Ovy75iNBC6BhMx44sZ6OjsWXxxIoZUrsYiCNqeGR1Fs,3321 +authlib/oauth2/rfc7523/validator.py,sha256=VXx11aa1nF3cp-fdsbQSkx0oXchiArhCtIGdnK-A9Ks,1602 +authlib/oauth2/rfc7591/__init__.py,sha256=Wb4PkmuGy8makvgceRVV1472cGL1udCS2bD9SGbqJyg,667 +authlib/oauth2/rfc7591/__pycache__/__init__.cpython-311.pyc,, +authlib/oauth2/rfc7591/__pycache__/claims.cpython-311.pyc,, +authlib/oauth2/rfc7591/__pycache__/endpoint.cpython-311.pyc,, +authlib/oauth2/rfc7591/__pycache__/errors.cpython-311.pyc,, +authlib/oauth2/rfc7591/claims.py,sha256=4eMNWjY0Osh4IvdxAZE-bHXf0gqfk0VoERxtkSL07AQ,9809 +authlib/oauth2/rfc7591/endpoint.py,sha256=3EYBDrBhp8G7iK12hBZBns_95yDmLeXzEiP5I1JRW6g,8104 +authlib/oauth2/rfc7591/errors.py,sha256=nx_1-wyT1kznHXh1R91uwa1gbRJ5YKUxWY1SiFVRTWg,1098 +authlib/oauth2/rfc7592/__init__.py,sha256=4du9AjzDNRXbmtkTWBPbpXz5ivugseKIuXtp3oAsk4w,315 +authlib/oauth2/rfc7592/__pycache__/__init__.cpython-311.pyc,, +authlib/oauth2/rfc7592/__pycache__/endpoint.cpython-311.pyc,, +authlib/oauth2/rfc7592/endpoint.py,sha256=Ec_QncWPPtK-1hM6KVlBM_94ufb_ci6ZfC8ydQxC8yE,9759 +authlib/oauth2/rfc7636/__init__.py,sha256=U-568f2yopL7jOGbL-6js5AMm9mSu74PgK21AkmfAN8,340 +authlib/oauth2/rfc7636/__pycache__/__init__.cpython-311.pyc,, +authlib/oauth2/rfc7636/__pycache__/challenge.cpython-311.pyc,, +authlib/oauth2/rfc7636/challenge.py,sha256=MugGeYxa9Vwz9_D5EtQD3FekN29nmmtQ2wj6zRaKyFk,5173 +authlib/oauth2/rfc7662/__init__.py,sha256=buYHq9DwjRGLIENJkDJ0-9VNyo_8Sui--5C3WyFc6O4,423 +authlib/oauth2/rfc7662/__pycache__/__init__.cpython-311.pyc,, +authlib/oauth2/rfc7662/__pycache__/introspection.cpython-311.pyc,, +authlib/oauth2/rfc7662/__pycache__/models.cpython-311.pyc,, +authlib/oauth2/rfc7662/__pycache__/token_validator.cpython-311.pyc,, +authlib/oauth2/rfc7662/introspection.py,sha256=BZ5ET90mL22R3wp5OobUnAXwzrbTFJkqcyZ8uuixPqQ,5244 +authlib/oauth2/rfc7662/models.py,sha256=vlmy-_UJmzy4rFjOjLsKpybpodoMsxt4FRi-URDoCKQ,868 +authlib/oauth2/rfc7662/token_validator.py,sha256=x7s4funp-7edBTWs64hsTkxZLxf5O8yPKkDdKPwMMWA,1339 +authlib/oauth2/rfc8414/__init__.py,sha256=ilEoqCBQ-lqXM7RBVlqAzbHbW39d_5oTKbKzJU4bfIY,361 +authlib/oauth2/rfc8414/__pycache__/__init__.cpython-311.pyc,, +authlib/oauth2/rfc8414/__pycache__/models.cpython-311.pyc,, +authlib/oauth2/rfc8414/__pycache__/well_known.cpython-311.pyc,, +authlib/oauth2/rfc8414/models.py,sha256=-ajtcMT5OleAEwnUNzvRGVdpDEDJ2nfhXn9WuLPNwpE,17450 +authlib/oauth2/rfc8414/well_known.py,sha256=-2eObCyYbBJHBokeBfuVHLqL2gvaLOK-aOVcy4M-5i8,727 +authlib/oauth2/rfc8628/__init__.py,sha256=TqyFJ_dAH5XjFdHFzxuiYrfpvX8rR0UiCyDoEPXPhhc,678 +authlib/oauth2/rfc8628/__pycache__/__init__.cpython-311.pyc,, +authlib/oauth2/rfc8628/__pycache__/device_code.cpython-311.pyc,, +authlib/oauth2/rfc8628/__pycache__/endpoint.cpython-311.pyc,, +authlib/oauth2/rfc8628/__pycache__/errors.cpython-311.pyc,, +authlib/oauth2/rfc8628/__pycache__/models.cpython-311.pyc,, +authlib/oauth2/rfc8628/device_code.py,sha256=drtONNm2ZNzqNN5Ts93GEE68RaTVqhJFJPUnb16STT0,7715 +authlib/oauth2/rfc8628/endpoint.py,sha256=bA4U9n7I6rt1FblM5eVTx3jhfbJ0-7ptCD9kUV_wLuo,7107 +authlib/oauth2/rfc8628/errors.py,sha256=c761U-GWU58bmfyEPw1VveaOkmOgXElBVlWaE-tXMTg,919 +authlib/oauth2/rfc8628/models.py,sha256=o5734crJs0HZ8d_Iyqoh5O3i_D4_Pf2nml-HU62Re80,827 +authlib/oauth2/rfc8693/__init__.py,sha256=Omkb_UdGC1dsrqP8SUD1f562SnEyOs_pd5G03pouMrE,182 +authlib/oauth2/rfc8693/__pycache__/__init__.cpython-311.pyc,, +authlib/oauth2/rfc9068/__init__.py,sha256=XWpuGChgg0Rvyy7ZZ3c26q5LiTcQQbZboJknSCpKg4s,332 +authlib/oauth2/rfc9068/__pycache__/__init__.cpython-311.pyc,, +authlib/oauth2/rfc9068/__pycache__/claims.cpython-311.pyc,, +authlib/oauth2/rfc9068/__pycache__/introspection.cpython-311.pyc,, +authlib/oauth2/rfc9068/__pycache__/revocation.cpython-311.pyc,, +authlib/oauth2/rfc9068/__pycache__/token.cpython-311.pyc,, +authlib/oauth2/rfc9068/__pycache__/token_validator.cpython-311.pyc,, +authlib/oauth2/rfc9068/claims.py,sha256=omSrkPIBtOwPGmgIl8ERc8IAk9bzAewoQw9fkvDZauE,1866 +authlib/oauth2/rfc9068/introspection.py,sha256=rJ7QfzYM0ynQFkzYJ-icEByeaBzX1AcgxgLyQGwwOt8,4251 +authlib/oauth2/rfc9068/revocation.py,sha256=vqwfdE5sWEL2bJUrjoYVLllkFaAOt_BuSxU9UxQFnPs,2521 +authlib/oauth2/rfc9068/token.py,sha256=LZgwzqRtjiaZv-EKwUSl07IM2v7XfvgTioTmfwgqFoE,8639 +authlib/oauth2/rfc9068/token_validator.py,sha256=I8ZWuEGMByR0KV4KxUnr-iLHTITBvCf9lfzzpgh_Dbk,6890 +authlib/oidc/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +authlib/oidc/__pycache__/__init__.cpython-311.pyc,, +authlib/oidc/core/__init__.py,sha256=NX_zhew3om9I0qVozcXyHshqqCqbk2V1qhiHO8Q4Ztk,650 +authlib/oidc/core/__pycache__/__init__.cpython-311.pyc,, +authlib/oidc/core/__pycache__/claims.cpython-311.pyc,, +authlib/oidc/core/__pycache__/errors.cpython-311.pyc,, +authlib/oidc/core/__pycache__/models.cpython-311.pyc,, +authlib/oidc/core/__pycache__/util.cpython-311.pyc,, +authlib/oidc/core/claims.py,sha256=AvHWg9GPMjQCLDS8uKtOxlXi_prz-fyWGJAF34vD5w0,10187 +authlib/oidc/core/errors.py,sha256=2nZdVIlTweNwXapWa7QodgCR0_5neGlCbGB2U4xUT6o,2883 +authlib/oidc/core/grants/__init__.py,sha256=E6KAaqJMyRYWD2635gdcfE-oN_RRaIAL_imHATP12Uo,226 +authlib/oidc/core/grants/__pycache__/__init__.cpython-311.pyc,, +authlib/oidc/core/grants/__pycache__/code.cpython-311.pyc,, +authlib/oidc/core/grants/__pycache__/hybrid.cpython-311.pyc,, +authlib/oidc/core/grants/__pycache__/implicit.cpython-311.pyc,, +authlib/oidc/core/grants/__pycache__/util.cpython-311.pyc,, +authlib/oidc/core/grants/code.py,sha256=Jl7Asl_PzMPHHaileADIkSn8cODbYCLSC14F4zkUy4U,4848 +authlib/oidc/core/grants/hybrid.py,sha256=t20y1cX83vX-fgHL0v8JOtrjcs1pbe4zhMRnYarHKFM,3362 +authlib/oidc/core/grants/implicit.py,sha256=wzpMiiC0z31dvSrZVrZThBL5WsTWowvlu2MN5JI97dM,5288 +authlib/oidc/core/grants/util.py,sha256=0Zmyd_ofSrzAyJYf7dxtuQ7wcx2shkha4aDIIIeo-F8,4122 +authlib/oidc/core/models.py,sha256=eYGC-NcIuu1m50h5o1bJFwUjwGF9l-AytTic8B07zmU,413 +authlib/oidc/core/util.py,sha256=LggTWoq-qHvF5JW6MsMxBa76rxVlxCARIlQXkxkZD_o,382 +authlib/oidc/discovery/__init__.py,sha256=sI40mT-HXoRPqxAE5UfPg-r3xa_wrQZ8jIuN5n5nm70,321 +authlib/oidc/discovery/__pycache__/__init__.cpython-311.pyc,, +authlib/oidc/discovery/__pycache__/models.cpython-311.pyc,, +authlib/oidc/discovery/__pycache__/well_known.cpython-311.pyc,, +authlib/oidc/discovery/models.py,sha256=dNpNiDyss3T-6FGkoWe_t0_D2xTHSppd4wMY_BsoQNg,12575 +authlib/oidc/discovery/well_known.py,sha256=3VmfTsVReChVVxBuN1eoL7fIRHL0-tN0aXH2XVKeiDM,574 diff --git a/.venv/Lib/site-packages/Authlib-1.3.0.dist-info/REQUESTED b/.venv/Lib/site-packages/Authlib-1.3.0.dist-info/REQUESTED new file mode 100644 index 00000000..e69de29b diff --git a/.venv/Lib/site-packages/Authlib-1.3.0.dist-info/WHEEL b/.venv/Lib/site-packages/Authlib-1.3.0.dist-info/WHEEL new file mode 100644 index 00000000..4724c457 --- /dev/null +++ b/.venv/Lib/site-packages/Authlib-1.3.0.dist-info/WHEEL @@ -0,0 +1,6 @@ +Wheel-Version: 1.0 +Generator: bdist_wheel (0.42.0) +Root-Is-Purelib: true +Tag: py2-none-any +Tag: py3-none-any + diff --git a/.venv/Lib/site-packages/Authlib-1.3.0.dist-info/top_level.txt b/.venv/Lib/site-packages/Authlib-1.3.0.dist-info/top_level.txt new file mode 100644 index 00000000..b91e7e46 --- /dev/null +++ b/.venv/Lib/site-packages/Authlib-1.3.0.dist-info/top_level.txt @@ -0,0 +1 @@ +authlib diff --git a/.venv/Lib/site-packages/__pycache__/six.cpython-311.pyc b/.venv/Lib/site-packages/__pycache__/six.cpython-311.pyc new file mode 100644 index 00000000..fa547f19 Binary files /dev/null and b/.venv/Lib/site-packages/__pycache__/six.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/_cffi_backend.cp311-win_amd64.pyd b/.venv/Lib/site-packages/_cffi_backend.cp311-win_amd64.pyd new file mode 100644 index 00000000..6d30f70e Binary files /dev/null and b/.venv/Lib/site-packages/_cffi_backend.cp311-win_amd64.pyd differ diff --git a/.venv/Lib/site-packages/authlib/__init__.py b/.venv/Lib/site-packages/authlib/__init__.py new file mode 100644 index 00000000..2a2e5adc --- /dev/null +++ b/.venv/Lib/site-packages/authlib/__init__.py @@ -0,0 +1,17 @@ +""" + authlib + ~~~~~~~ + + The ultimate Python library in building OAuth 1.0, OAuth 2.0 and OpenID + Connect clients and providers. It covers from low level specification + implementation to high level framework integrations. + + :copyright: (c) 2017 by Hsiaoming Yang. + :license: BSD, see LICENSE for more details. +""" +from .consts import version, homepage, author + +__version__ = version +__homepage__ = homepage +__author__ = author +__license__ = 'BSD-3-Clause' diff --git a/.venv/Lib/site-packages/authlib/__pycache__/__init__.cpython-311.pyc b/.venv/Lib/site-packages/authlib/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 00000000..6a82191d Binary files /dev/null and b/.venv/Lib/site-packages/authlib/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/authlib/__pycache__/consts.cpython-311.pyc b/.venv/Lib/site-packages/authlib/__pycache__/consts.cpython-311.pyc new file mode 100644 index 00000000..988cc316 Binary files /dev/null and b/.venv/Lib/site-packages/authlib/__pycache__/consts.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/authlib/__pycache__/deprecate.cpython-311.pyc b/.venv/Lib/site-packages/authlib/__pycache__/deprecate.cpython-311.pyc new file mode 100644 index 00000000..eb83ab32 Binary files /dev/null and b/.venv/Lib/site-packages/authlib/__pycache__/deprecate.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/authlib/common/__init__.py b/.venv/Lib/site-packages/authlib/common/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/.venv/Lib/site-packages/authlib/common/__pycache__/__init__.cpython-311.pyc b/.venv/Lib/site-packages/authlib/common/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 00000000..3ae18f90 Binary files /dev/null and b/.venv/Lib/site-packages/authlib/common/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/authlib/common/__pycache__/encoding.cpython-311.pyc b/.venv/Lib/site-packages/authlib/common/__pycache__/encoding.cpython-311.pyc new file mode 100644 index 00000000..869a0697 Binary files /dev/null and b/.venv/Lib/site-packages/authlib/common/__pycache__/encoding.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/authlib/common/__pycache__/errors.cpython-311.pyc b/.venv/Lib/site-packages/authlib/common/__pycache__/errors.cpython-311.pyc new file mode 100644 index 00000000..d7b849f3 Binary files /dev/null and b/.venv/Lib/site-packages/authlib/common/__pycache__/errors.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/authlib/common/__pycache__/security.cpython-311.pyc b/.venv/Lib/site-packages/authlib/common/__pycache__/security.cpython-311.pyc new file mode 100644 index 00000000..4487dbcf Binary files /dev/null and b/.venv/Lib/site-packages/authlib/common/__pycache__/security.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/authlib/common/__pycache__/urls.cpython-311.pyc b/.venv/Lib/site-packages/authlib/common/__pycache__/urls.cpython-311.pyc new file mode 100644 index 00000000..2555f8e6 Binary files /dev/null and b/.venv/Lib/site-packages/authlib/common/__pycache__/urls.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/authlib/common/encoding.py b/.venv/Lib/site-packages/authlib/common/encoding.py new file mode 100644 index 00000000..f450ca47 --- /dev/null +++ b/.venv/Lib/site-packages/authlib/common/encoding.py @@ -0,0 +1,66 @@ +import json +import base64 +import struct + + +def to_bytes(x, charset='utf-8', errors='strict'): + if x is None: + return None + if isinstance(x, bytes): + return x + if isinstance(x, str): + return x.encode(charset, errors) + if isinstance(x, (int, float)): + return str(x).encode(charset, errors) + return bytes(x) + + +def to_unicode(x, charset='utf-8', errors='strict'): + if x is None or isinstance(x, str): + return x + if isinstance(x, bytes): + return x.decode(charset, errors) + return str(x) + + +def to_native(x, encoding='ascii'): + if isinstance(x, str): + return x + return x.decode(encoding) + + +def json_loads(s): + return json.loads(s) + + +def json_dumps(data, ensure_ascii=False): + return json.dumps(data, ensure_ascii=ensure_ascii, separators=(',', ':')) + + +def urlsafe_b64decode(s): + s += b'=' * (-len(s) % 4) + return base64.urlsafe_b64decode(s) + + +def urlsafe_b64encode(s): + return base64.urlsafe_b64encode(s).rstrip(b'=') + + +def base64_to_int(s): + data = urlsafe_b64decode(to_bytes(s, charset='ascii')) + buf = struct.unpack('%sB' % len(data), data) + return int(''.join(["%02x" % byte for byte in buf]), 16) + + +def int_to_base64(num): + if num < 0: + raise ValueError('Must be a positive integer') + + s = num.to_bytes((num.bit_length() + 7) // 8, 'big', signed=False) + return to_unicode(urlsafe_b64encode(s)) + + +def json_b64encode(text): + if isinstance(text, dict): + text = json_dumps(text) + return urlsafe_b64encode(to_bytes(text)) diff --git a/.venv/Lib/site-packages/authlib/common/errors.py b/.venv/Lib/site-packages/authlib/common/errors.py new file mode 100644 index 00000000..56515bab --- /dev/null +++ b/.venv/Lib/site-packages/authlib/common/errors.py @@ -0,0 +1,63 @@ +from authlib.consts import default_json_headers + + +class AuthlibBaseError(Exception): + """Base Exception for all errors in Authlib.""" + + #: short-string error code + error = None + #: long-string to describe this error + description = '' + #: web page that describes this error + uri = None + + def __init__(self, error=None, description=None, uri=None): + if error is not None: + self.error = error + if description is not None: + self.description = description + if uri is not None: + self.uri = uri + + message = f'{self.error}: {self.description}' + super().__init__(message) + + def __repr__(self): + return f'<{self.__class__.__name__} "{self.error}">' + + +class AuthlibHTTPError(AuthlibBaseError): + #: HTTP status code + status_code = 400 + + def __init__(self, error=None, description=None, uri=None, + status_code=None): + super().__init__(error, description, uri) + if status_code is not None: + self.status_code = status_code + + def get_error_description(self): + return self.description + + def get_body(self): + error = [('error', self.error)] + + if self.description: + error.append(('error_description', self.description)) + + if self.uri: + error.append(('error_uri', self.uri)) + return error + + def get_headers(self): + return default_json_headers[:] + + def __call__(self, uri=None): + self.uri = uri + body = dict(self.get_body()) + headers = self.get_headers() + return self.status_code, body, headers + + +class ContinueIteration(AuthlibBaseError): + pass diff --git a/.venv/Lib/site-packages/authlib/common/security.py b/.venv/Lib/site-packages/authlib/common/security.py new file mode 100644 index 00000000..b05ea144 --- /dev/null +++ b/.venv/Lib/site-packages/authlib/common/security.py @@ -0,0 +1,19 @@ +import os +import string +import random + +UNICODE_ASCII_CHARACTER_SET = string.ascii_letters + string.digits + + +def generate_token(length=30, chars=UNICODE_ASCII_CHARACTER_SET): + rand = random.SystemRandom() + return ''.join(rand.choice(chars) for _ in range(length)) + + +def is_secure_transport(uri): + """Check if the uri is over ssl.""" + if os.getenv('AUTHLIB_INSECURE_TRANSPORT'): + return True + + uri = uri.lower() + return uri.startswith(('https://', 'http://localhost:')) diff --git a/.venv/Lib/site-packages/authlib/common/urls.py b/.venv/Lib/site-packages/authlib/common/urls.py new file mode 100644 index 00000000..1d1847fa --- /dev/null +++ b/.venv/Lib/site-packages/authlib/common/urls.py @@ -0,0 +1,146 @@ +""" + authlib.util.urls + ~~~~~~~~~~~~~~~~~ + + Wrapper functions for URL encoding and decoding. +""" + +import re +from urllib.parse import quote as _quote +from urllib.parse import unquote as _unquote +from urllib.parse import urlencode as _urlencode +import urllib.parse as urlparse + +from .encoding import to_unicode, to_bytes + +always_safe = ( + 'ABCDEFGHIJKLMNOPQRSTUVWXYZ' + 'abcdefghijklmnopqrstuvwxyz' + '0123456789_.-' +) +urlencoded = set(always_safe) | set('=&;:%+~,*@!()/?') +INVALID_HEX_PATTERN = re.compile(r'%[^0-9A-Fa-f]|%[0-9A-Fa-f][^0-9A-Fa-f]') + + +def url_encode(params): + encoded = [] + for k, v in params: + encoded.append((to_bytes(k), to_bytes(v))) + return to_unicode(_urlencode(encoded)) + + +def url_decode(query): + """Decode a query string in x-www-form-urlencoded format into a sequence + of two-element tuples. + + Unlike urlparse.parse_qsl(..., strict_parsing=True) urldecode will enforce + correct formatting of the query string by validation. If validation fails + a ValueError will be raised. urllib.parse_qsl will only raise errors if + any of name-value pairs omits the equals sign. + """ + # Check if query contains invalid characters + if query and not set(query) <= urlencoded: + error = ("Error trying to decode a non urlencoded string. " + "Found invalid characters: %s " + "in the string: '%s'. " + "Please ensure the request/response body is " + "x-www-form-urlencoded.") + raise ValueError(error % (set(query) - urlencoded, query)) + + # Check for correctly hex encoded values using a regular expression + # All encoded values begin with % followed by two hex characters + # correct = %00, %A0, %0A, %FF + # invalid = %G0, %5H, %PO + if INVALID_HEX_PATTERN.search(query): + raise ValueError('Invalid hex encoding in query string.') + + # We encode to utf-8 prior to parsing because parse_qsl behaves + # differently on unicode input in python 2 and 3. + # Python 2.7 + # >>> urlparse.parse_qsl(u'%E5%95%A6%E5%95%A6') + # u'\xe5\x95\xa6\xe5\x95\xa6' + # Python 2.7, non unicode input gives the same + # >>> urlparse.parse_qsl('%E5%95%A6%E5%95%A6') + # '\xe5\x95\xa6\xe5\x95\xa6' + # but now we can decode it to unicode + # >>> urlparse.parse_qsl('%E5%95%A6%E5%95%A6').decode('utf-8') + # u'\u5566\u5566' + # Python 3.3 however + # >>> urllib.parse.parse_qsl(u'%E5%95%A6%E5%95%A6') + # u'\u5566\u5566' + + # We want to allow queries such as "c2" whereas urlparse.parse_qsl + # with the strict_parsing flag will not. + params = urlparse.parse_qsl(query, keep_blank_values=True) + + # unicode all the things + decoded = [] + for k, v in params: + decoded.append((to_unicode(k), to_unicode(v))) + return decoded + + +def add_params_to_qs(query, params): + """Extend a query with a list of two-tuples.""" + if isinstance(params, dict): + params = params.items() + + qs = urlparse.parse_qsl(query, keep_blank_values=True) + qs.extend(params) + return url_encode(qs) + + +def add_params_to_uri(uri, params, fragment=False): + """Add a list of two-tuples to the uri query components.""" + sch, net, path, par, query, fra = urlparse.urlparse(uri) + if fragment: + fra = add_params_to_qs(fra, params) + else: + query = add_params_to_qs(query, params) + return urlparse.urlunparse((sch, net, path, par, query, fra)) + + +def quote(s, safe=b'/'): + return to_unicode(_quote(to_bytes(s), safe)) + + +def unquote(s): + return to_unicode(_unquote(s)) + + +def quote_url(s): + return quote(s, b'~@#$&()*!+=:;,.?/\'') + + +def extract_params(raw): + """Extract parameters and return them as a list of 2-tuples. + + Will successfully extract parameters from urlencoded query strings, + dicts, or lists of 2-tuples. Empty strings/dicts/lists will return an + empty list of parameters. Any other input will result in a return + value of None. + """ + if isinstance(raw, (list, tuple)): + try: + raw = dict(raw) + except (TypeError, ValueError): + return None + + if isinstance(raw, dict): + params = [] + for k, v in raw.items(): + params.append((to_unicode(k), to_unicode(v))) + return params + + if not raw: + return None + + try: + return url_decode(raw) + except ValueError: + return None + + +def is_valid_url(url): + parsed = urlparse.urlparse(url) + return parsed.scheme and parsed.hostname diff --git a/.venv/Lib/site-packages/authlib/consts.py b/.venv/Lib/site-packages/authlib/consts.py new file mode 100644 index 00000000..e310e793 --- /dev/null +++ b/.venv/Lib/site-packages/authlib/consts.py @@ -0,0 +1,11 @@ +name = 'Authlib' +version = '1.3.0' +author = 'Hsiaoming Yang ' +homepage = 'https://authlib.org/' +default_user_agent = f'{name}/{version} (+{homepage})' + +default_json_headers = [ + ('Content-Type', 'application/json'), + ('Cache-Control', 'no-store'), + ('Pragma', 'no-cache'), +] diff --git a/.venv/Lib/site-packages/authlib/deprecate.py b/.venv/Lib/site-packages/authlib/deprecate.py new file mode 100644 index 00000000..7d581d69 --- /dev/null +++ b/.venv/Lib/site-packages/authlib/deprecate.py @@ -0,0 +1,16 @@ +import warnings + + +class AuthlibDeprecationWarning(DeprecationWarning): + pass + + +warnings.simplefilter('always', AuthlibDeprecationWarning) + + +def deprecate(message, version=None, link_uid=None, link_file=None): + if version: + message += f'\nIt will be compatible before version {version}.' + if link_uid and link_file: + message += f'\nRead more ' + warnings.warn(AuthlibDeprecationWarning(message), stacklevel=2) diff --git a/.venv/Lib/site-packages/authlib/integrations/__init__.py b/.venv/Lib/site-packages/authlib/integrations/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/.venv/Lib/site-packages/authlib/integrations/__pycache__/__init__.cpython-311.pyc b/.venv/Lib/site-packages/authlib/integrations/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 00000000..bbf0c0dd Binary files /dev/null and b/.venv/Lib/site-packages/authlib/integrations/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/authlib/integrations/base_client/__init__.py b/.venv/Lib/site-packages/authlib/integrations/base_client/__init__.py new file mode 100644 index 00000000..077301f2 --- /dev/null +++ b/.venv/Lib/site-packages/authlib/integrations/base_client/__init__.py @@ -0,0 +1,18 @@ +from .registry import BaseOAuth +from .sync_app import BaseApp, OAuth1Mixin, OAuth2Mixin +from .sync_openid import OpenIDMixin +from .framework_integration import FrameworkIntegration +from .errors import ( + OAuthError, MissingRequestTokenError, MissingTokenError, + TokenExpiredError, InvalidTokenError, UnsupportedTokenTypeError, + MismatchingStateError, +) + +__all__ = [ + 'BaseOAuth', + 'BaseApp', 'OAuth1Mixin', 'OAuth2Mixin', + 'OpenIDMixin', 'FrameworkIntegration', + 'OAuthError', 'MissingRequestTokenError', 'MissingTokenError', + 'TokenExpiredError', 'InvalidTokenError', 'UnsupportedTokenTypeError', + 'MismatchingStateError', +] diff --git a/.venv/Lib/site-packages/authlib/integrations/base_client/__pycache__/__init__.cpython-311.pyc b/.venv/Lib/site-packages/authlib/integrations/base_client/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 00000000..348bb9e9 Binary files /dev/null and b/.venv/Lib/site-packages/authlib/integrations/base_client/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/authlib/integrations/base_client/__pycache__/async_app.cpython-311.pyc b/.venv/Lib/site-packages/authlib/integrations/base_client/__pycache__/async_app.cpython-311.pyc new file mode 100644 index 00000000..ef500c9b Binary files /dev/null and b/.venv/Lib/site-packages/authlib/integrations/base_client/__pycache__/async_app.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/authlib/integrations/base_client/__pycache__/async_openid.cpython-311.pyc b/.venv/Lib/site-packages/authlib/integrations/base_client/__pycache__/async_openid.cpython-311.pyc new file mode 100644 index 00000000..186d8f72 Binary files /dev/null and b/.venv/Lib/site-packages/authlib/integrations/base_client/__pycache__/async_openid.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/authlib/integrations/base_client/__pycache__/errors.cpython-311.pyc b/.venv/Lib/site-packages/authlib/integrations/base_client/__pycache__/errors.cpython-311.pyc new file mode 100644 index 00000000..5c3ef8b1 Binary files /dev/null and b/.venv/Lib/site-packages/authlib/integrations/base_client/__pycache__/errors.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/authlib/integrations/base_client/__pycache__/framework_integration.cpython-311.pyc b/.venv/Lib/site-packages/authlib/integrations/base_client/__pycache__/framework_integration.cpython-311.pyc new file mode 100644 index 00000000..89b78971 Binary files /dev/null and b/.venv/Lib/site-packages/authlib/integrations/base_client/__pycache__/framework_integration.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/authlib/integrations/base_client/__pycache__/registry.cpython-311.pyc b/.venv/Lib/site-packages/authlib/integrations/base_client/__pycache__/registry.cpython-311.pyc new file mode 100644 index 00000000..7d2c8ead Binary files /dev/null and b/.venv/Lib/site-packages/authlib/integrations/base_client/__pycache__/registry.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/authlib/integrations/base_client/__pycache__/sync_app.cpython-311.pyc b/.venv/Lib/site-packages/authlib/integrations/base_client/__pycache__/sync_app.cpython-311.pyc new file mode 100644 index 00000000..5634bec3 Binary files /dev/null and b/.venv/Lib/site-packages/authlib/integrations/base_client/__pycache__/sync_app.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/authlib/integrations/base_client/__pycache__/sync_openid.cpython-311.pyc b/.venv/Lib/site-packages/authlib/integrations/base_client/__pycache__/sync_openid.cpython-311.pyc new file mode 100644 index 00000000..21d71087 Binary files /dev/null and b/.venv/Lib/site-packages/authlib/integrations/base_client/__pycache__/sync_openid.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/authlib/integrations/base_client/async_app.py b/.venv/Lib/site-packages/authlib/integrations/base_client/async_app.py new file mode 100644 index 00000000..640896e7 --- /dev/null +++ b/.venv/Lib/site-packages/authlib/integrations/base_client/async_app.py @@ -0,0 +1,144 @@ +import time +import logging +from authlib.common.urls import urlparse +from .errors import ( + MissingRequestTokenError, + MissingTokenError, +) +from .sync_app import OAuth1Base, OAuth2Base + +log = logging.getLogger(__name__) + +__all__ = ['AsyncOAuth1Mixin', 'AsyncOAuth2Mixin'] + + +class AsyncOAuth1Mixin(OAuth1Base): + async def request(self, method, url, token=None, **kwargs): + async with self._get_oauth_client() as session: + return await _http_request(self, session, method, url, token, kwargs) + + async def create_authorization_url(self, redirect_uri=None, **kwargs): + """Generate the authorization url and state for HTTP redirect. + + :param redirect_uri: Callback or redirect URI for authorization. + :param kwargs: Extra parameters to include. + :return: dict + """ + if not self.authorize_url: + raise RuntimeError('Missing "authorize_url" value') + + if self.authorize_params: + kwargs.update(self.authorize_params) + + async with self._get_oauth_client() as client: + client.redirect_uri = redirect_uri + params = {} + if self.request_token_params: + params.update(self.request_token_params) + request_token = await client.fetch_request_token(self.request_token_url, **params) + log.debug(f'Fetch request token: {request_token!r}') + url = client.create_authorization_url(self.authorize_url, **kwargs) + state = request_token['oauth_token'] + return {'url': url, 'request_token': request_token, 'state': state} + + async def fetch_access_token(self, request_token=None, **kwargs): + """Fetch access token in one step. + + :param request_token: A previous request token for OAuth 1. + :param kwargs: Extra parameters to fetch access token. + :return: A token dict. + """ + async with self._get_oauth_client() as client: + if request_token is None: + raise MissingRequestTokenError() + # merge request token with verifier + token = {} + token.update(request_token) + token.update(kwargs) + client.token = token + params = self.access_token_params or {} + token = await client.fetch_access_token(self.access_token_url, **params) + return token + + +class AsyncOAuth2Mixin(OAuth2Base): + async def _on_update_token(self, token, refresh_token=None, access_token=None): + if self._update_token: + await self._update_token( + token, + refresh_token=refresh_token, + access_token=access_token, + ) + + async def load_server_metadata(self): + if self._server_metadata_url and '_loaded_at' not in self.server_metadata: + async with self.client_cls(**self.client_kwargs) as client: + resp = await client.request('GET', self._server_metadata_url, withhold_token=True) + resp.raise_for_status() + metadata = resp.json() + metadata['_loaded_at'] = time.time() + self.server_metadata.update(metadata) + return self.server_metadata + + async def request(self, method, url, token=None, **kwargs): + metadata = await self.load_server_metadata() + async with self._get_oauth_client(**metadata) as session: + return await _http_request(self, session, method, url, token, kwargs) + + async def create_authorization_url(self, redirect_uri=None, **kwargs): + """Generate the authorization url and state for HTTP redirect. + + :param redirect_uri: Callback or redirect URI for authorization. + :param kwargs: Extra parameters to include. + :return: dict + """ + metadata = await self.load_server_metadata() + authorization_endpoint = self.authorize_url or metadata.get('authorization_endpoint') + if not authorization_endpoint: + raise RuntimeError('Missing "authorize_url" value') + + if self.authorize_params: + kwargs.update(self.authorize_params) + + async with self._get_oauth_client(**metadata) as client: + client.redirect_uri = redirect_uri + return self._create_oauth2_authorization_url( + client, authorization_endpoint, **kwargs) + + async def fetch_access_token(self, redirect_uri=None, **kwargs): + """Fetch access token in the final step. + + :param redirect_uri: Callback or Redirect URI that is used in + previous :meth:`authorize_redirect`. + :param kwargs: Extra parameters to fetch access token. + :return: A token dict. + """ + metadata = await self.load_server_metadata() + token_endpoint = self.access_token_url or metadata.get('token_endpoint') + async with self._get_oauth_client(**metadata) as client: + if redirect_uri is not None: + client.redirect_uri = redirect_uri + params = {} + if self.access_token_params: + params.update(self.access_token_params) + params.update(kwargs) + token = await client.fetch_token(token_endpoint, **params) + return token + + +async def _http_request(ctx, session, method, url, token, kwargs): + request = kwargs.pop('request', None) + withhold_token = kwargs.get('withhold_token') + if ctx.api_base_url and not url.startswith(('https://', 'http://')): + url = urlparse.urljoin(ctx.api_base_url, url) + + if withhold_token: + return await session.request(method, url, **kwargs) + + if token is None and ctx._fetch_token and request: + token = await ctx._fetch_token(request) + if token is None: + raise MissingTokenError() + + session.token = token + return await session.request(method, url, **kwargs) diff --git a/.venv/Lib/site-packages/authlib/integrations/base_client/async_openid.py b/.venv/Lib/site-packages/authlib/integrations/base_client/async_openid.py new file mode 100644 index 00000000..68100f2f --- /dev/null +++ b/.venv/Lib/site-packages/authlib/integrations/base_client/async_openid.py @@ -0,0 +1,79 @@ +from authlib.jose import JsonWebToken, JsonWebKey +from authlib.oidc.core import UserInfo, CodeIDToken, ImplicitIDToken + +__all__ = ['AsyncOpenIDMixin'] + + +class AsyncOpenIDMixin: + async def fetch_jwk_set(self, force=False): + metadata = await self.load_server_metadata() + jwk_set = metadata.get('jwks') + if jwk_set and not force: + return jwk_set + + uri = metadata.get('jwks_uri') + if not uri: + raise RuntimeError('Missing "jwks_uri" in metadata') + + async with self.client_cls(**self.client_kwargs) as client: + resp = await client.request('GET', uri, withhold_token=True) + resp.raise_for_status() + jwk_set = resp.json() + + self.server_metadata['jwks'] = jwk_set + return jwk_set + + async def userinfo(self, **kwargs): + """Fetch user info from ``userinfo_endpoint``.""" + metadata = await self.load_server_metadata() + resp = await self.get(metadata['userinfo_endpoint'], **kwargs) + resp.raise_for_status() + data = resp.json() + return UserInfo(data) + + async def parse_id_token(self, token, nonce, claims_options=None): + """Return an instance of UserInfo from token's ``id_token``.""" + claims_params = dict( + nonce=nonce, + client_id=self.client_id, + ) + if 'access_token' in token: + claims_params['access_token'] = token['access_token'] + claims_cls = CodeIDToken + else: + claims_cls = ImplicitIDToken + + metadata = await self.load_server_metadata() + if claims_options is None and 'issuer' in metadata: + claims_options = {'iss': {'values': [metadata['issuer']]}} + + alg_values = metadata.get('id_token_signing_alg_values_supported') + if not alg_values: + alg_values = ['RS256'] + + jwt = JsonWebToken(alg_values) + + jwk_set = await self.fetch_jwk_set() + try: + claims = jwt.decode( + token['id_token'], + key=JsonWebKey.import_key_set(jwk_set), + claims_cls=claims_cls, + claims_options=claims_options, + claims_params=claims_params, + ) + except ValueError: + jwk_set = await self.fetch_jwk_set(force=True) + claims = jwt.decode( + token['id_token'], + key=JsonWebKey.import_key_set(jwk_set), + claims_cls=claims_cls, + claims_options=claims_options, + claims_params=claims_params, + ) + + # https://github.com/lepture/authlib/issues/259 + if claims.get('nonce_supported') is False: + claims.params['nonce'] = None + claims.validate(leeway=120) + return UserInfo(claims) diff --git a/.venv/Lib/site-packages/authlib/integrations/base_client/errors.py b/.venv/Lib/site-packages/authlib/integrations/base_client/errors.py new file mode 100644 index 00000000..bb4dd2b1 --- /dev/null +++ b/.venv/Lib/site-packages/authlib/integrations/base_client/errors.py @@ -0,0 +1,30 @@ +from authlib.common.errors import AuthlibBaseError + + +class OAuthError(AuthlibBaseError): + error = 'oauth_error' + + +class MissingRequestTokenError(OAuthError): + error = 'missing_request_token' + + +class MissingTokenError(OAuthError): + error = 'missing_token' + + +class TokenExpiredError(OAuthError): + error = 'token_expired' + + +class InvalidTokenError(OAuthError): + error = 'token_invalid' + + +class UnsupportedTokenTypeError(OAuthError): + error = 'unsupported_token_type' + + +class MismatchingStateError(OAuthError): + error = 'mismatching_state' + description = 'CSRF Warning! State not equal in request and response.' diff --git a/.venv/Lib/site-packages/authlib/integrations/base_client/framework_integration.py b/.venv/Lib/site-packages/authlib/integrations/base_client/framework_integration.py new file mode 100644 index 00000000..9243e8f0 --- /dev/null +++ b/.venv/Lib/site-packages/authlib/integrations/base_client/framework_integration.py @@ -0,0 +1,64 @@ +import json +import time + + +class FrameworkIntegration: + expires_in = 3600 + + def __init__(self, name, cache=None): + self.name = name + self.cache = cache + + def _get_cache_data(self, key): + value = self.cache.get(key) + if not value: + return None + try: + return json.loads(value) + except (TypeError, ValueError): + return None + + def _clear_session_state(self, session): + now = time.time() + for key in dict(session): + if '_authlib_' in key: + # TODO: remove in future + session.pop(key) + elif key.startswith('_state_'): + value = session[key] + exp = value.get('exp') + if not exp or exp < now: + session.pop(key) + + def get_state_data(self, session, state): + key = f'_state_{self.name}_{state}' + if self.cache: + value = self._get_cache_data(key) + else: + value = session.get(key) + if value: + return value.get('data') + return None + + def set_state_data(self, session, state, data): + key = f'_state_{self.name}_{state}' + if self.cache: + self.cache.set(key, json.dumps({'data': data}), self.expires_in) + else: + now = time.time() + session[key] = {'data': data, 'exp': now + self.expires_in} + + def clear_state_data(self, session, state): + key = f'_state_{self.name}_{state}' + if self.cache: + self.cache.delete(key) + else: + session.pop(key, None) + self._clear_session_state(session) + + def update_token(self, token, refresh_token=None, access_token=None): + raise NotImplementedError() + + @staticmethod + def load_config(oauth, name, params): + raise NotImplementedError() diff --git a/.venv/Lib/site-packages/authlib/integrations/base_client/registry.py b/.venv/Lib/site-packages/authlib/integrations/base_client/registry.py new file mode 100644 index 00000000..68d1be5d --- /dev/null +++ b/.venv/Lib/site-packages/authlib/integrations/base_client/registry.py @@ -0,0 +1,131 @@ +import functools +from .framework_integration import FrameworkIntegration + +__all__ = ['BaseOAuth'] + + +OAUTH_CLIENT_PARAMS = ( + 'client_id', 'client_secret', + 'request_token_url', 'request_token_params', + 'access_token_url', 'access_token_params', + 'refresh_token_url', 'refresh_token_params', + 'authorize_url', 'authorize_params', + 'api_base_url', 'client_kwargs', + 'server_metadata_url', +) + + +class BaseOAuth: + """Registry for oauth clients. + + Create an instance for registry:: + + oauth = OAuth() + """ + oauth1_client_cls = None + oauth2_client_cls = None + framework_integration_cls = FrameworkIntegration + + def __init__(self, cache=None, fetch_token=None, update_token=None): + self._registry = {} + self._clients = {} + self.cache = cache + self.fetch_token = fetch_token + self.update_token = update_token + + def create_client(self, name): + """Create or get the given named OAuth client. For instance, the + OAuth registry has ``.register`` a twitter client, developers may + access the client with:: + + client = oauth.create_client('twitter') + + :param: name: Name of the remote application + :return: OAuth remote app + """ + if name in self._clients: + return self._clients[name] + + if name not in self._registry: + return None + + overwrite, config = self._registry[name] + client_cls = config.pop('client_cls', None) + + if client_cls and client_cls.OAUTH_APP_CONFIG: + kwargs = client_cls.OAUTH_APP_CONFIG + kwargs.update(config) + else: + kwargs = config + + kwargs = self.generate_client_kwargs(name, overwrite, **kwargs) + framework = self.framework_integration_cls(name, self.cache) + if client_cls: + client = client_cls(framework, name, **kwargs) + elif kwargs.get('request_token_url'): + client = self.oauth1_client_cls(framework, name, **kwargs) + else: + client = self.oauth2_client_cls(framework, name, **kwargs) + + self._clients[name] = client + return client + + def register(self, name, overwrite=False, **kwargs): + """Registers a new remote application. + + :param name: Name of the remote application. + :param overwrite: Overwrite existing config with framework settings. + :param kwargs: Parameters for :class:`RemoteApp`. + + Find parameters for the given remote app class. When a remote app is + registered, it can be accessed with *named* attribute:: + + oauth.register('twitter', client_id='', ...) + oauth.twitter.get('timeline') + """ + self._registry[name] = (overwrite, kwargs) + return self.create_client(name) + + def generate_client_kwargs(self, name, overwrite, **kwargs): + fetch_token = kwargs.pop('fetch_token', None) + update_token = kwargs.pop('update_token', None) + + config = self.load_config(name, OAUTH_CLIENT_PARAMS) + if config: + kwargs = _config_client(config, kwargs, overwrite) + + if not fetch_token and self.fetch_token: + fetch_token = functools.partial(self.fetch_token, name) + + kwargs['fetch_token'] = fetch_token + + if not kwargs.get('request_token_url'): + if not update_token and self.update_token: + update_token = functools.partial(self.update_token, name) + + kwargs['update_token'] = update_token + return kwargs + + def load_config(self, name, params): + return self.framework_integration_cls.load_config(self, name, params) + + def __getattr__(self, key): + try: + return object.__getattribute__(self, key) + except AttributeError: + if key in self._registry: + return self.create_client(key) + raise AttributeError('No such client: %s' % key) + + +def _config_client(config, kwargs, overwrite): + for k in OAUTH_CLIENT_PARAMS: + v = config.get(k, None) + if k not in kwargs: + kwargs[k] = v + elif overwrite and v: + if isinstance(kwargs[k], dict): + kwargs[k].update(v) + else: + kwargs[k] = v + return kwargs diff --git a/.venv/Lib/site-packages/authlib/integrations/base_client/sync_app.py b/.venv/Lib/site-packages/authlib/integrations/base_client/sync_app.py new file mode 100644 index 00000000..50fa27a7 --- /dev/null +++ b/.venv/Lib/site-packages/authlib/integrations/base_client/sync_app.py @@ -0,0 +1,343 @@ +import time +import logging +from authlib.common.urls import urlparse +from authlib.consts import default_user_agent +from authlib.common.security import generate_token +from .errors import ( + MismatchingStateError, + MissingRequestTokenError, + MissingTokenError, +) + +log = logging.getLogger(__name__) + + +class BaseApp: + client_cls = None + OAUTH_APP_CONFIG = None + + def request(self, method, url, token=None, **kwargs): + raise NotImplementedError() + + def get(self, url, **kwargs): + """Invoke GET http request. + + If ``api_base_url`` configured, shortcut is available:: + + client.get('users/lepture') + """ + return self.request('GET', url, **kwargs) + + def post(self, url, **kwargs): + """Invoke POST http request. + + If ``api_base_url`` configured, shortcut is available:: + + client.post('timeline', json={'text': 'Hi'}) + """ + return self.request('POST', url, **kwargs) + + def patch(self, url, **kwargs): + """Invoke PATCH http request. + + If ``api_base_url`` configured, shortcut is available:: + + client.patch('profile', json={'name': 'Hsiaoming Yang'}) + """ + return self.request('PATCH', url, **kwargs) + + def put(self, url, **kwargs): + """Invoke PUT http request. + + If ``api_base_url`` configured, shortcut is available:: + + client.put('profile', json={'name': 'Hsiaoming Yang'}) + """ + return self.request('PUT', url, **kwargs) + + def delete(self, url, **kwargs): + """Invoke DELETE http request. + + If ``api_base_url`` configured, shortcut is available:: + + client.delete('posts/123') + """ + return self.request('DELETE', url, **kwargs) + + +class _RequestMixin: + def _get_requested_token(self, request): + if self._fetch_token and request: + return self._fetch_token(request) + + def _send_token_request(self, session, method, url, token, kwargs): + request = kwargs.pop('request', None) + withhold_token = kwargs.get('withhold_token') + if self.api_base_url and not url.startswith(('https://', 'http://')): + url = urlparse.urljoin(self.api_base_url, url) + + if withhold_token: + return session.request(method, url, **kwargs) + + if token is None: + token = self._get_requested_token(request) + + if token is None: + raise MissingTokenError() + + session.token = token + return session.request(method, url, **kwargs) + + +class OAuth1Base: + client_cls = None + + def __init__( + self, framework, name=None, fetch_token=None, + client_id=None, client_secret=None, + request_token_url=None, request_token_params=None, + access_token_url=None, access_token_params=None, + authorize_url=None, authorize_params=None, + api_base_url=None, client_kwargs=None, user_agent=None, **kwargs): + self.framework = framework + self.name = name + self.client_id = client_id + self.client_secret = client_secret + self.request_token_url = request_token_url + self.request_token_params = request_token_params + self.access_token_url = access_token_url + self.access_token_params = access_token_params + self.authorize_url = authorize_url + self.authorize_params = authorize_params + self.api_base_url = api_base_url + self.client_kwargs = client_kwargs or {} + + self._fetch_token = fetch_token + self._user_agent = user_agent or default_user_agent + self._kwargs = kwargs + + def _get_oauth_client(self): + session = self.client_cls(self.client_id, self.client_secret, **self.client_kwargs) + session.headers['User-Agent'] = self._user_agent + return session + + +class OAuth1Mixin(_RequestMixin, OAuth1Base): + def request(self, method, url, token=None, **kwargs): + with self._get_oauth_client() as session: + return self._send_token_request(session, method, url, token, kwargs) + + def create_authorization_url(self, redirect_uri=None, **kwargs): + """Generate the authorization url and state for HTTP redirect. + + :param redirect_uri: Callback or redirect URI for authorization. + :param kwargs: Extra parameters to include. + :return: dict + """ + if not self.authorize_url: + raise RuntimeError('Missing "authorize_url" value') + + if self.authorize_params: + kwargs.update(self.authorize_params) + + with self._get_oauth_client() as client: + client.redirect_uri = redirect_uri + params = self.request_token_params or {} + request_token = client.fetch_request_token(self.request_token_url, **params) + log.debug(f'Fetch request token: {request_token!r}') + url = client.create_authorization_url(self.authorize_url, **kwargs) + state = request_token['oauth_token'] + return {'url': url, 'request_token': request_token, 'state': state} + + def fetch_access_token(self, request_token=None, **kwargs): + """Fetch access token in one step. + + :param request_token: A previous request token for OAuth 1. + :param kwargs: Extra parameters to fetch access token. + :return: A token dict. + """ + with self._get_oauth_client() as client: + if request_token is None: + raise MissingRequestTokenError() + # merge request token with verifier + token = {} + token.update(request_token) + token.update(kwargs) + client.token = token + params = self.access_token_params or {} + token = client.fetch_access_token(self.access_token_url, **params) + return token + + +class OAuth2Base: + client_cls = None + + def __init__( + self, framework, name=None, fetch_token=None, update_token=None, + client_id=None, client_secret=None, + access_token_url=None, access_token_params=None, + authorize_url=None, authorize_params=None, + api_base_url=None, client_kwargs=None, server_metadata_url=None, + compliance_fix=None, client_auth_methods=None, user_agent=None, **kwargs): + self.framework = framework + self.name = name + self.client_id = client_id + self.client_secret = client_secret + self.access_token_url = access_token_url + self.access_token_params = access_token_params + self.authorize_url = authorize_url + self.authorize_params = authorize_params + self.api_base_url = api_base_url + self.client_kwargs = client_kwargs or {} + + self.compliance_fix = compliance_fix + self.client_auth_methods = client_auth_methods + self._fetch_token = fetch_token + self._update_token = update_token + self._user_agent = user_agent or default_user_agent + + self._server_metadata_url = server_metadata_url + self.server_metadata = kwargs + + def _on_update_token(self, token, refresh_token=None, access_token=None): + raise NotImplementedError() + + def _get_oauth_client(self, **metadata): + client_kwargs = {} + client_kwargs.update(self.client_kwargs) + client_kwargs.update(metadata) + + if self.authorize_url: + client_kwargs['authorization_endpoint'] = self.authorize_url + if self.access_token_url: + client_kwargs['token_endpoint'] = self.access_token_url + + session = self.client_cls( + client_id=self.client_id, + client_secret=self.client_secret, + update_token=self._on_update_token, + **client_kwargs + ) + if self.client_auth_methods: + for f in self.client_auth_methods: + session.register_client_auth_method(f) + + if self.compliance_fix: + self.compliance_fix(session) + + session.headers['User-Agent'] = self._user_agent + return session + + @staticmethod + def _format_state_params(state_data, params): + if state_data is None: + raise MismatchingStateError() + + code_verifier = state_data.get('code_verifier') + if code_verifier: + params['code_verifier'] = code_verifier + + redirect_uri = state_data.get('redirect_uri') + if redirect_uri: + params['redirect_uri'] = redirect_uri + return params + + @staticmethod + def _create_oauth2_authorization_url(client, authorization_endpoint, **kwargs): + rv = {} + if client.code_challenge_method: + code_verifier = kwargs.get('code_verifier') + if not code_verifier: + code_verifier = generate_token(48) + kwargs['code_verifier'] = code_verifier + rv['code_verifier'] = code_verifier + log.debug(f'Using code_verifier: {code_verifier!r}') + + scope = kwargs.get('scope', client.scope) + if scope and 'openid' in scope.split(): + # this is an OpenID Connect service + nonce = kwargs.get('nonce') + if not nonce: + nonce = generate_token(20) + kwargs['nonce'] = nonce + rv['nonce'] = nonce + + url, state = client.create_authorization_url( + authorization_endpoint, **kwargs) + rv['url'] = url + rv['state'] = state + return rv + + +class OAuth2Mixin(_RequestMixin, OAuth2Base): + def _on_update_token(self, token, refresh_token=None, access_token=None): + if callable(self._update_token): + self._update_token( + token, + refresh_token=refresh_token, + access_token=access_token, + ) + self.framework.update_token( + token, + refresh_token=refresh_token, + access_token=access_token, + ) + + def request(self, method, url, token=None, **kwargs): + metadata = self.load_server_metadata() + with self._get_oauth_client(**metadata) as session: + return self._send_token_request(session, method, url, token, kwargs) + + def load_server_metadata(self): + if self._server_metadata_url and '_loaded_at' not in self.server_metadata: + with self.client_cls(**self.client_kwargs) as session: + resp = session.request('GET', self._server_metadata_url, withhold_token=True) + resp.raise_for_status() + metadata = resp.json() + + metadata['_loaded_at'] = time.time() + self.server_metadata.update(metadata) + return self.server_metadata + + def create_authorization_url(self, redirect_uri=None, **kwargs): + """Generate the authorization url and state for HTTP redirect. + + :param redirect_uri: Callback or redirect URI for authorization. + :param kwargs: Extra parameters to include. + :return: dict + """ + metadata = self.load_server_metadata() + authorization_endpoint = self.authorize_url or metadata.get('authorization_endpoint') + + if not authorization_endpoint: + raise RuntimeError('Missing "authorize_url" value') + + if self.authorize_params: + kwargs.update(self.authorize_params) + + + with self._get_oauth_client(**metadata) as client: + if redirect_uri is not None: + client.redirect_uri = redirect_uri + return self._create_oauth2_authorization_url( + client, authorization_endpoint, **kwargs) + + def fetch_access_token(self, redirect_uri=None, **kwargs): + """Fetch access token in the final step. + + :param redirect_uri: Callback or Redirect URI that is used in + previous :meth:`authorize_redirect`. + :param kwargs: Extra parameters to fetch access token. + :return: A token dict. + """ + metadata = self.load_server_metadata() + token_endpoint = self.access_token_url or metadata.get('token_endpoint') + with self._get_oauth_client(**metadata) as client: + if redirect_uri is not None: + client.redirect_uri = redirect_uri + params = {} + if self.access_token_params: + params.update(self.access_token_params) + params.update(kwargs) + token = client.fetch_token(token_endpoint, **params) + return token diff --git a/.venv/Lib/site-packages/authlib/integrations/base_client/sync_openid.py b/.venv/Lib/site-packages/authlib/integrations/base_client/sync_openid.py new file mode 100644 index 00000000..ac51907a --- /dev/null +++ b/.venv/Lib/site-packages/authlib/integrations/base_client/sync_openid.py @@ -0,0 +1,77 @@ +from authlib.jose import jwt, JsonWebToken, JsonWebKey +from authlib.oidc.core import UserInfo, CodeIDToken, ImplicitIDToken + + +class OpenIDMixin: + def fetch_jwk_set(self, force=False): + metadata = self.load_server_metadata() + jwk_set = metadata.get('jwks') + if jwk_set and not force: + return jwk_set + + uri = metadata.get('jwks_uri') + if not uri: + raise RuntimeError('Missing "jwks_uri" in metadata') + + with self.client_cls(**self.client_kwargs) as session: + resp = session.request('GET', uri, withhold_token=True) + resp.raise_for_status() + jwk_set = resp.json() + + self.server_metadata['jwks'] = jwk_set + return jwk_set + + def userinfo(self, **kwargs): + """Fetch user info from ``userinfo_endpoint``.""" + metadata = self.load_server_metadata() + resp = self.get(metadata['userinfo_endpoint'], **kwargs) + resp.raise_for_status() + data = resp.json() + return UserInfo(data) + + def parse_id_token(self, token, nonce, claims_options=None, leeway=120): + """Return an instance of UserInfo from token's ``id_token``.""" + if 'id_token' not in token: + return None + + def load_key(header, _): + jwk_set = JsonWebKey.import_key_set(self.fetch_jwk_set()) + try: + return jwk_set.find_by_kid(header.get('kid')) + except ValueError: + # re-try with new jwk set + jwk_set = JsonWebKey.import_key_set(self.fetch_jwk_set(force=True)) + return jwk_set.find_by_kid(header.get('kid')) + + claims_params = dict( + nonce=nonce, + client_id=self.client_id, + ) + if 'access_token' in token: + claims_params['access_token'] = token['access_token'] + claims_cls = CodeIDToken + else: + claims_cls = ImplicitIDToken + + metadata = self.load_server_metadata() + if claims_options is None and 'issuer' in metadata: + claims_options = {'iss': {'values': [metadata['issuer']]}} + + alg_values = metadata.get('id_token_signing_alg_values_supported') + if alg_values: + _jwt = JsonWebToken(alg_values) + else: + _jwt = jwt + + claims = _jwt.decode( + token['id_token'], key=load_key, + claims_cls=claims_cls, + claims_options=claims_options, + claims_params=claims_params, + ) + # https://github.com/lepture/authlib/issues/259 + if claims.get('nonce_supported') is False: + claims.params['nonce'] = None + + claims.validate(leeway=leeway) + return UserInfo(claims) diff --git a/.venv/Lib/site-packages/authlib/integrations/django_client/__init__.py b/.venv/Lib/site-packages/authlib/integrations/django_client/__init__.py new file mode 100644 index 00000000..5839c945 --- /dev/null +++ b/.venv/Lib/site-packages/authlib/integrations/django_client/__init__.py @@ -0,0 +1,19 @@ +# flake8: noqa + +from .integration import DjangoIntegration, token_update +from .apps import DjangoOAuth1App, DjangoOAuth2App +from ..base_client import BaseOAuth, OAuthError + + +class OAuth(BaseOAuth): + oauth1_client_cls = DjangoOAuth1App + oauth2_client_cls = DjangoOAuth2App + framework_integration_cls = DjangoIntegration + + +__all__ = [ + 'OAuth', + 'DjangoOAuth1App', 'DjangoOAuth2App', + 'DjangoIntegration', + 'token_update', 'OAuthError', +] diff --git a/.venv/Lib/site-packages/authlib/integrations/django_client/__pycache__/__init__.cpython-311.pyc b/.venv/Lib/site-packages/authlib/integrations/django_client/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 00000000..9286f7a0 Binary files /dev/null and b/.venv/Lib/site-packages/authlib/integrations/django_client/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/authlib/integrations/django_client/__pycache__/apps.cpython-311.pyc b/.venv/Lib/site-packages/authlib/integrations/django_client/__pycache__/apps.cpython-311.pyc new file mode 100644 index 00000000..a2c2352c Binary files /dev/null and b/.venv/Lib/site-packages/authlib/integrations/django_client/__pycache__/apps.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/authlib/integrations/django_client/__pycache__/integration.cpython-311.pyc b/.venv/Lib/site-packages/authlib/integrations/django_client/__pycache__/integration.cpython-311.pyc new file mode 100644 index 00000000..4578fc7e Binary files /dev/null and b/.venv/Lib/site-packages/authlib/integrations/django_client/__pycache__/integration.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/authlib/integrations/django_client/apps.py b/.venv/Lib/site-packages/authlib/integrations/django_client/apps.py new file mode 100644 index 00000000..07bdf719 --- /dev/null +++ b/.venv/Lib/site-packages/authlib/integrations/django_client/apps.py @@ -0,0 +1,87 @@ +from django.http import HttpResponseRedirect +from ..requests_client import OAuth1Session, OAuth2Session +from ..base_client import ( + BaseApp, OAuthError, + OAuth1Mixin, OAuth2Mixin, OpenIDMixin, +) + + +class DjangoAppMixin: + def save_authorize_data(self, request, **kwargs): + state = kwargs.pop('state', None) + if state: + self.framework.set_state_data(request.session, state, kwargs) + else: + raise RuntimeError('Missing state value') + + def authorize_redirect(self, request, redirect_uri=None, **kwargs): + """Create a HTTP Redirect for Authorization Endpoint. + + :param request: HTTP request instance from Django view. + :param redirect_uri: Callback or redirect URI for authorization. + :param kwargs: Extra parameters to include. + :return: A HTTP redirect response. + """ + rv = self.create_authorization_url(redirect_uri, **kwargs) + self.save_authorize_data(request, redirect_uri=redirect_uri, **rv) + return HttpResponseRedirect(rv['url']) + + +class DjangoOAuth1App(DjangoAppMixin, OAuth1Mixin, BaseApp): + client_cls = OAuth1Session + + def authorize_access_token(self, request, **kwargs): + """Fetch access token in one step. + + :param request: HTTP request instance from Django view. + :return: A token dict. + """ + params = request.GET.dict() + state = params.get('oauth_token') + if not state: + raise OAuthError(description='Missing "oauth_token" parameter') + + data = self.framework.get_state_data(request.session, state) + if not data: + raise OAuthError(description='Missing "request_token" in temporary data') + + params['request_token'] = data['request_token'] + params.update(kwargs) + self.framework.clear_state_data(request.session, state) + return self.fetch_access_token(**params) + + +class DjangoOAuth2App(DjangoAppMixin, OAuth2Mixin, OpenIDMixin, BaseApp): + client_cls = OAuth2Session + + def authorize_access_token(self, request, **kwargs): + """Fetch access token in one step. + + :param request: HTTP request instance from Django view. + :return: A token dict. + """ + if request.method == 'GET': + error = request.GET.get('error') + if error: + description = request.GET.get('error_description') + raise OAuthError(error=error, description=description) + params = { + 'code': request.GET.get('code'), + 'state': request.GET.get('state'), + } + else: + params = { + 'code': request.POST.get('code'), + 'state': request.POST.get('state'), + } + + claims_options = kwargs.pop('claims_options', None) + state_data = self.framework.get_state_data(request.session, params.get('state')) + self.framework.clear_state_data(request.session, params.get('state')) + params = self._format_state_params(state_data, params) + token = self.fetch_access_token(**params, **kwargs) + + if 'id_token' in token and 'nonce' in state_data: + userinfo = self.parse_id_token(token, nonce=state_data['nonce'], claims_options=claims_options) + token['userinfo'] = userinfo + return token diff --git a/.venv/Lib/site-packages/authlib/integrations/django_client/integration.py b/.venv/Lib/site-packages/authlib/integrations/django_client/integration.py new file mode 100644 index 00000000..2ff03dea --- /dev/null +++ b/.venv/Lib/site-packages/authlib/integrations/django_client/integration.py @@ -0,0 +1,22 @@ +from django.conf import settings +from django.dispatch import Signal +from ..base_client import FrameworkIntegration + +token_update = Signal() + + +class DjangoIntegration(FrameworkIntegration): + def update_token(self, token, refresh_token=None, access_token=None): + token_update.send( + sender=self.__class__, + name=self.name, + token=token, + refresh_token=refresh_token, + access_token=access_token, + ) + + @staticmethod + def load_config(oauth, name, params): + config = getattr(settings, 'AUTHLIB_OAUTH_CLIENTS', None) + if config: + return config.get(name) diff --git a/.venv/Lib/site-packages/authlib/integrations/django_oauth1/__init__.py b/.venv/Lib/site-packages/authlib/integrations/django_oauth1/__init__.py new file mode 100644 index 00000000..39f0e130 --- /dev/null +++ b/.venv/Lib/site-packages/authlib/integrations/django_oauth1/__init__.py @@ -0,0 +1,9 @@ +# flake8: noqa + +from .authorization_server import ( + BaseServer, CacheAuthorizationServer +) +from .resource_protector import ResourceProtector + + +__all__ = ['BaseServer', 'CacheAuthorizationServer', 'ResourceProtector'] diff --git a/.venv/Lib/site-packages/authlib/integrations/django_oauth1/__pycache__/__init__.cpython-311.pyc b/.venv/Lib/site-packages/authlib/integrations/django_oauth1/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 00000000..32c63685 Binary files /dev/null and b/.venv/Lib/site-packages/authlib/integrations/django_oauth1/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/authlib/integrations/django_oauth1/__pycache__/authorization_server.cpython-311.pyc b/.venv/Lib/site-packages/authlib/integrations/django_oauth1/__pycache__/authorization_server.cpython-311.pyc new file mode 100644 index 00000000..d9e391c6 Binary files /dev/null and b/.venv/Lib/site-packages/authlib/integrations/django_oauth1/__pycache__/authorization_server.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/authlib/integrations/django_oauth1/__pycache__/nonce.cpython-311.pyc b/.venv/Lib/site-packages/authlib/integrations/django_oauth1/__pycache__/nonce.cpython-311.pyc new file mode 100644 index 00000000..6a865a21 Binary files /dev/null and b/.venv/Lib/site-packages/authlib/integrations/django_oauth1/__pycache__/nonce.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/authlib/integrations/django_oauth1/__pycache__/resource_protector.cpython-311.pyc b/.venv/Lib/site-packages/authlib/integrations/django_oauth1/__pycache__/resource_protector.cpython-311.pyc new file mode 100644 index 00000000..71f52d91 Binary files /dev/null and b/.venv/Lib/site-packages/authlib/integrations/django_oauth1/__pycache__/resource_protector.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/authlib/integrations/django_oauth1/authorization_server.py b/.venv/Lib/site-packages/authlib/integrations/django_oauth1/authorization_server.py new file mode 100644 index 00000000..70c2b6bc --- /dev/null +++ b/.venv/Lib/site-packages/authlib/integrations/django_oauth1/authorization_server.py @@ -0,0 +1,125 @@ +import logging +from authlib.oauth1 import ( + OAuth1Request, + AuthorizationServer as _AuthorizationServer, +) +from authlib.oauth1 import TemporaryCredential +from authlib.common.security import generate_token +from authlib.common.urls import url_encode +from django.core.cache import cache +from django.conf import settings +from django.http import HttpResponse +from .nonce import exists_nonce_in_cache + +log = logging.getLogger(__name__) + + +class BaseServer(_AuthorizationServer): + def __init__(self, client_model, token_model, token_generator=None): + self.client_model = client_model + self.token_model = token_model + + if token_generator is None: + def token_generator(): + return { + 'oauth_token': generate_token(42), + 'oauth_token_secret': generate_token(48) + } + + self.token_generator = token_generator + self._config = getattr(settings, 'AUTHLIB_OAUTH1_PROVIDER', {}) + self._nonce_expires_in = self._config.get('nonce_expires_in', 86400) + methods = self._config.get('signature_methods') + if methods: + self.SUPPORTED_SIGNATURE_METHODS = methods + + def get_client_by_id(self, client_id): + try: + return self.client_model.objects.get(client_id=client_id) + except self.client_model.DoesNotExist: + return None + + def exists_nonce(self, nonce, request): + return exists_nonce_in_cache(nonce, request, self._nonce_expires_in) + + def create_token_credential(self, request): + temporary_credential = request.credential + token = self.token_generator() + item = self.token_model( + oauth_token=token['oauth_token'], + oauth_token_secret=token['oauth_token_secret'], + user_id=temporary_credential.get_user_id(), + client_id=temporary_credential.get_client_id() + ) + item.save() + return item + + def check_authorization_request(self, request): + req = self.create_oauth1_request(request) + self.validate_authorization_request(req) + return req + + def create_oauth1_request(self, request): + if request.method == 'POST': + body = request.POST.dict() + else: + body = None + url = request.build_absolute_uri() + return OAuth1Request(request.method, url, body, request.headers) + + def handle_response(self, status_code, payload, headers): + resp = HttpResponse(url_encode(payload), status=status_code) + for k, v in headers: + resp[k] = v + return resp + + +class CacheAuthorizationServer(BaseServer): + def __init__(self, client_model, token_model, token_generator=None): + super().__init__( + client_model, token_model, token_generator) + self._temporary_expires_in = self._config.get( + 'temporary_credential_expires_in', 86400) + self._temporary_credential_key_prefix = self._config.get( + 'temporary_credential_key_prefix', 'temporary_credential:') + + def create_temporary_credential(self, request): + key_prefix = self._temporary_credential_key_prefix + token = self.token_generator() + + client_id = request.client_id + redirect_uri = request.redirect_uri + key = key_prefix + token['oauth_token'] + token['client_id'] = client_id + if redirect_uri: + token['oauth_callback'] = redirect_uri + + cache.set(key, token, timeout=self._temporary_expires_in) + return TemporaryCredential(token) + + def get_temporary_credential(self, request): + if not request.token: + return None + + key_prefix = self._temporary_credential_key_prefix + key = key_prefix + request.token + value = cache.get(key) + if value: + return TemporaryCredential(value) + + def delete_temporary_credential(self, request): + if request.token: + key_prefix = self._temporary_credential_key_prefix + key = key_prefix + request.token + cache.delete(key) + + def create_authorization_verifier(self, request): + key_prefix = self._temporary_credential_key_prefix + verifier = generate_token(36) + credential = request.credential + user = request.user + key = key_prefix + credential.get_oauth_token() + credential['oauth_verifier'] = verifier + credential['user_id'] = user.pk + cache.set(key, credential, timeout=self._temporary_expires_in) + return verifier diff --git a/.venv/Lib/site-packages/authlib/integrations/django_oauth1/nonce.py b/.venv/Lib/site-packages/authlib/integrations/django_oauth1/nonce.py new file mode 100644 index 00000000..0bd70e31 --- /dev/null +++ b/.venv/Lib/site-packages/authlib/integrations/django_oauth1/nonce.py @@ -0,0 +1,15 @@ +from django.core.cache import cache + + +def exists_nonce_in_cache(nonce, request, timeout): + key_prefix = 'nonce:' + timestamp = request.timestamp + client_id = request.client_id + token = request.token + key = f'{key_prefix}{nonce}-{timestamp}-{client_id}' + if token: + key = f'{key}-{token}' + + rv = bool(cache.get(key)) + cache.set(key, 1, timeout=timeout) + return rv diff --git a/.venv/Lib/site-packages/authlib/integrations/django_oauth1/resource_protector.py b/.venv/Lib/site-packages/authlib/integrations/django_oauth1/resource_protector.py new file mode 100644 index 00000000..77f3d81f --- /dev/null +++ b/.venv/Lib/site-packages/authlib/integrations/django_oauth1/resource_protector.py @@ -0,0 +1,64 @@ +import functools +from authlib.oauth1.errors import OAuth1Error +from authlib.oauth1 import ResourceProtector as _ResourceProtector +from django.http import JsonResponse +from django.conf import settings +from .nonce import exists_nonce_in_cache + + +class ResourceProtector(_ResourceProtector): + def __init__(self, client_model, token_model): + self.client_model = client_model + self.token_model = token_model + + config = getattr(settings, 'AUTHLIB_OAUTH1_PROVIDER', {}) + methods = config.get('signature_methods', []) + if methods and isinstance(methods, (list, tuple)): + self.SUPPORTED_SIGNATURE_METHODS = methods + + self._nonce_expires_in = config.get('nonce_expires_in', 86400) + + def get_client_by_id(self, client_id): + try: + return self.client_model.objects.get(client_id=client_id) + except self.client_model.DoesNotExist: + return None + + def get_token_credential(self, request): + try: + return self.token_model.objects.get( + client_id=request.client_id, + oauth_token=request.token + ) + except self.token_model.DoesNotExist: + return None + + def exists_nonce(self, nonce, request): + return exists_nonce_in_cache(nonce, request, self._nonce_expires_in) + + def acquire_credential(self, request): + if request.method in ['POST', 'PUT']: + body = request.POST.dict() + else: + body = None + + url = request.build_absolute_uri() + req = self.validate_request(request.method, url, body, request.headers) + return req.credential + + def __call__(self, realm=None): + def wrapper(f): + @functools.wraps(f) + def decorated(request, *args, **kwargs): + try: + credential = self.acquire_credential(request) + request.oauth1_credential = credential + except OAuth1Error as error: + body = dict(error.get_body()) + resp = JsonResponse(body, status=error.status_code) + resp['Cache-Control'] = 'no-store' + resp['Pragma'] = 'no-cache' + return resp + return f(request, *args, **kwargs) + return decorated + return wrapper diff --git a/.venv/Lib/site-packages/authlib/integrations/django_oauth2/__init__.py b/.venv/Lib/site-packages/authlib/integrations/django_oauth2/__init__.py new file mode 100644 index 00000000..05c1fdfe --- /dev/null +++ b/.venv/Lib/site-packages/authlib/integrations/django_oauth2/__init__.py @@ -0,0 +1,10 @@ +# flake8: noqa + +from .authorization_server import AuthorizationServer +from .resource_protector import ResourceProtector, BearerTokenValidator +from .endpoints import RevocationEndpoint +from .signals import ( + client_authenticated, + token_authenticated, + token_revoked +) diff --git a/.venv/Lib/site-packages/authlib/integrations/django_oauth2/__pycache__/__init__.cpython-311.pyc b/.venv/Lib/site-packages/authlib/integrations/django_oauth2/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 00000000..e97ba00c Binary files /dev/null and b/.venv/Lib/site-packages/authlib/integrations/django_oauth2/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/authlib/integrations/django_oauth2/__pycache__/authorization_server.cpython-311.pyc b/.venv/Lib/site-packages/authlib/integrations/django_oauth2/__pycache__/authorization_server.cpython-311.pyc new file mode 100644 index 00000000..8db7e12c Binary files /dev/null and b/.venv/Lib/site-packages/authlib/integrations/django_oauth2/__pycache__/authorization_server.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/authlib/integrations/django_oauth2/__pycache__/endpoints.cpython-311.pyc b/.venv/Lib/site-packages/authlib/integrations/django_oauth2/__pycache__/endpoints.cpython-311.pyc new file mode 100644 index 00000000..73abf664 Binary files /dev/null and b/.venv/Lib/site-packages/authlib/integrations/django_oauth2/__pycache__/endpoints.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/authlib/integrations/django_oauth2/__pycache__/requests.cpython-311.pyc b/.venv/Lib/site-packages/authlib/integrations/django_oauth2/__pycache__/requests.cpython-311.pyc new file mode 100644 index 00000000..7da7a989 Binary files /dev/null and b/.venv/Lib/site-packages/authlib/integrations/django_oauth2/__pycache__/requests.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/authlib/integrations/django_oauth2/__pycache__/resource_protector.cpython-311.pyc b/.venv/Lib/site-packages/authlib/integrations/django_oauth2/__pycache__/resource_protector.cpython-311.pyc new file mode 100644 index 00000000..6ca39f62 Binary files /dev/null and b/.venv/Lib/site-packages/authlib/integrations/django_oauth2/__pycache__/resource_protector.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/authlib/integrations/django_oauth2/__pycache__/signals.cpython-311.pyc b/.venv/Lib/site-packages/authlib/integrations/django_oauth2/__pycache__/signals.cpython-311.pyc new file mode 100644 index 00000000..eb39d91c Binary files /dev/null and b/.venv/Lib/site-packages/authlib/integrations/django_oauth2/__pycache__/signals.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/authlib/integrations/django_oauth2/authorization_server.py b/.venv/Lib/site-packages/authlib/integrations/django_oauth2/authorization_server.py new file mode 100644 index 00000000..08a27595 --- /dev/null +++ b/.venv/Lib/site-packages/authlib/integrations/django_oauth2/authorization_server.py @@ -0,0 +1,118 @@ +from django.http import HttpResponse +from django.utils.module_loading import import_string +from django.conf import settings +from authlib.oauth2 import ( + AuthorizationServer as _AuthorizationServer, +) +from authlib.oauth2.rfc6750 import BearerTokenGenerator +from authlib.common.security import generate_token as _generate_token +from authlib.common.encoding import json_dumps +from .requests import DjangoOAuth2Request, DjangoJsonRequest +from .signals import client_authenticated, token_revoked + + +class AuthorizationServer(_AuthorizationServer): + """Django implementation of :class:`authlib.oauth2.rfc6749.AuthorizationServer`. + Initialize it with client model and token model:: + + from authlib.integrations.django_oauth2 import AuthorizationServer + from your_project.models import OAuth2Client, OAuth2Token + + server = AuthorizationServer(OAuth2Client, OAuth2Token) + """ + + def __init__(self, client_model, token_model): + self.config = getattr(settings, 'AUTHLIB_OAUTH2_PROVIDER', {}) + self.client_model = client_model + self.token_model = token_model + scopes_supported = self.config.get('scopes_supported') + super().__init__(scopes_supported=scopes_supported) + # add default token generator + self.register_token_generator('default', self.create_bearer_token_generator()) + + def query_client(self, client_id): + """Default method for ``AuthorizationServer.query_client``. Developers MAY + rewrite this function to meet their own needs. + """ + try: + return self.client_model.objects.get(client_id=client_id) + except self.client_model.DoesNotExist: + return None + + def save_token(self, token, request): + """Default method for ``AuthorizationServer.save_token``. Developers MAY + rewrite this function to meet their own needs. + """ + client = request.client + if request.user: + user_id = request.user.pk + else: + user_id = client.user_id + item = self.token_model( + client_id=client.client_id, + user_id=user_id, + **token + ) + item.save() + return item + + def create_oauth2_request(self, request): + return DjangoOAuth2Request(request) + + def create_json_request(self, request): + return DjangoJsonRequest(request) + + def handle_response(self, status_code, payload, headers): + if isinstance(payload, dict): + payload = json_dumps(payload) + resp = HttpResponse(payload, status=status_code) + for k, v in headers: + resp[k] = v + return resp + + def send_signal(self, name, *args, **kwargs): + if name == 'after_authenticate_client': + client_authenticated.send(sender=self.__class__, *args, **kwargs) + elif name == 'after_revoke_token': + token_revoked.send(sender=self.__class__, *args, **kwargs) + + def create_bearer_token_generator(self): + """Default method to create BearerToken generator.""" + conf = self.config.get('access_token_generator', True) + access_token_generator = create_token_generator(conf, 42) + + conf = self.config.get('refresh_token_generator', False) + refresh_token_generator = create_token_generator(conf, 48) + + conf = self.config.get('token_expires_in') + expires_generator = create_token_expires_in_generator(conf) + + return BearerTokenGenerator( + access_token_generator=access_token_generator, + refresh_token_generator=refresh_token_generator, + expires_generator=expires_generator, + ) + + +def create_token_generator(token_generator_conf, length=42): + if callable(token_generator_conf): + return token_generator_conf + + if isinstance(token_generator_conf, str): + return import_string(token_generator_conf) + elif token_generator_conf is True: + def token_generator(*args, **kwargs): + return _generate_token(length) + return token_generator + + +def create_token_expires_in_generator(expires_in_conf=None): + data = {} + data.update(BearerTokenGenerator.GRANT_TYPES_EXPIRES_IN) + if expires_in_conf: + data.update(expires_in_conf) + + def expires_in(client, grant_type): + return data.get(grant_type, BearerTokenGenerator.DEFAULT_EXPIRES_IN) + + return expires_in diff --git a/.venv/Lib/site-packages/authlib/integrations/django_oauth2/endpoints.py b/.venv/Lib/site-packages/authlib/integrations/django_oauth2/endpoints.py new file mode 100644 index 00000000..686675d5 --- /dev/null +++ b/.venv/Lib/site-packages/authlib/integrations/django_oauth2/endpoints.py @@ -0,0 +1,56 @@ +from authlib.oauth2.rfc7009 import RevocationEndpoint as _RevocationEndpoint + + +class RevocationEndpoint(_RevocationEndpoint): + """The revocation endpoint for OAuth authorization servers allows clients + to notify the authorization server that a previously obtained refresh or + access token is no longer needed. + + Register it into authorization server, and create token endpoint response + for token revocation:: + + from django.views.decorators.http import require_http_methods + + # see register into authorization server instance + server.register_endpoint(RevocationEndpoint) + + @require_http_methods(["POST"]) + def revoke_token(request): + return server.create_endpoint_response( + RevocationEndpoint.ENDPOINT_NAME, + request + ) + """ + + def query_token(self, token, token_type_hint): + """Query requested token from database.""" + token_model = self.server.token_model + if token_type_hint == 'access_token': + rv = _query_access_token(token_model, token) + elif token_type_hint == 'refresh_token': + rv = _query_refresh_token(token_model, token) + else: + rv = _query_access_token(token_model, token) + if not rv: + rv = _query_refresh_token(token_model, token) + + return rv + + def revoke_token(self, token, request): + """Mark the give token as revoked.""" + token.revoked = True + token.save() + + +def _query_access_token(token_model, token): + try: + return token_model.objects.get(access_token=token) + except token_model.DoesNotExist: + return None + + +def _query_refresh_token(token_model, token): + try: + return token_model.objects.get(refresh_token=token) + except token_model.DoesNotExist: + return None diff --git a/.venv/Lib/site-packages/authlib/integrations/django_oauth2/requests.py b/.venv/Lib/site-packages/authlib/integrations/django_oauth2/requests.py new file mode 100644 index 00000000..e9f2d95a --- /dev/null +++ b/.venv/Lib/site-packages/authlib/integrations/django_oauth2/requests.py @@ -0,0 +1,35 @@ +from django.http import HttpRequest +from django.utils.functional import cached_property +from authlib.common.encoding import json_loads +from authlib.oauth2.rfc6749 import OAuth2Request, JsonRequest + + +class DjangoOAuth2Request(OAuth2Request): + def __init__(self, request: HttpRequest): + super().__init__(request.method, request.build_absolute_uri(), None, request.headers) + self._request = request + + @property + def args(self): + return self._request.GET + + @property + def form(self): + return self._request.POST + + @cached_property + def data(self): + data = {} + data.update(self._request.GET.dict()) + data.update(self._request.POST.dict()) + return data + + +class DjangoJsonRequest(JsonRequest): + def __init__(self, request: HttpRequest): + super().__init__(request.method, request.build_absolute_uri(), None, request.headers) + self._request = request + + @cached_property + def data(self): + return json_loads(self._request.body) diff --git a/.venv/Lib/site-packages/authlib/integrations/django_oauth2/resource_protector.py b/.venv/Lib/site-packages/authlib/integrations/django_oauth2/resource_protector.py new file mode 100644 index 00000000..b89257ba --- /dev/null +++ b/.venv/Lib/site-packages/authlib/integrations/django_oauth2/resource_protector.py @@ -0,0 +1,75 @@ +import functools +from django.http import JsonResponse +from authlib.oauth2 import ( + OAuth2Error, + ResourceProtector as _ResourceProtector, +) +from authlib.oauth2.rfc6749 import ( + MissingAuthorizationError, +) +from authlib.oauth2.rfc6750 import ( + BearerTokenValidator as _BearerTokenValidator +) +from .requests import DjangoJsonRequest +from .signals import token_authenticated + + +class ResourceProtector(_ResourceProtector): + def acquire_token(self, request, scopes=None, **kwargs): + """A method to acquire current valid token with the given scope. + + :param request: Django HTTP request instance + :param scopes: a list of scope values + :return: token object + """ + req = DjangoJsonRequest(request) + # backward compatibility + kwargs['scopes'] = scopes + for claim in kwargs: + if isinstance(kwargs[claim], str): + kwargs[claim] = [kwargs[claim]] + token = self.validate_request(request=req, **kwargs) + token_authenticated.send(sender=self.__class__, token=token) + return token + + def __call__(self, scopes=None, optional=False, **kwargs): + claims = kwargs + # backward compatibility + claims['scopes'] = scopes + def wrapper(f): + @functools.wraps(f) + def decorated(request, *args, **kwargs): + try: + token = self.acquire_token(request, **claims) + request.oauth_token = token + except MissingAuthorizationError as error: + if optional: + request.oauth_token = None + return f(request, *args, **kwargs) + return return_error_response(error) + except OAuth2Error as error: + return return_error_response(error) + return f(request, *args, **kwargs) + return decorated + return wrapper + + +class BearerTokenValidator(_BearerTokenValidator): + def __init__(self, token_model, realm=None, **extra_attributes): + self.token_model = token_model + super().__init__(realm, **extra_attributes) + + def authenticate_token(self, token_string): + try: + return self.token_model.objects.get(access_token=token_string) + except self.token_model.DoesNotExist: + return None + + +def return_error_response(error): + body = dict(error.get_body()) + resp = JsonResponse(body, status=error.status_code) + headers = error.get_headers() + for k, v in headers: + resp[k] = v + return resp diff --git a/.venv/Lib/site-packages/authlib/integrations/django_oauth2/signals.py b/.venv/Lib/site-packages/authlib/integrations/django_oauth2/signals.py new file mode 100644 index 00000000..0e9c2659 --- /dev/null +++ b/.venv/Lib/site-packages/authlib/integrations/django_oauth2/signals.py @@ -0,0 +1,11 @@ +from django.dispatch import Signal + + +#: signal when client is authenticated +client_authenticated = Signal() + +#: signal when token is revoked +token_revoked = Signal() + +#: signal when token is authenticated +token_authenticated = Signal() diff --git a/.venv/Lib/site-packages/authlib/integrations/flask_client/__init__.py b/.venv/Lib/site-packages/authlib/integrations/flask_client/__init__.py new file mode 100644 index 00000000..ecdca2df --- /dev/null +++ b/.venv/Lib/site-packages/authlib/integrations/flask_client/__init__.py @@ -0,0 +1,51 @@ +from werkzeug.local import LocalProxy +from .integration import FlaskIntegration, token_update +from .apps import FlaskOAuth1App, FlaskOAuth2App +from ..base_client import BaseOAuth, OAuthError + + +class OAuth(BaseOAuth): + oauth1_client_cls = FlaskOAuth1App + oauth2_client_cls = FlaskOAuth2App + framework_integration_cls = FlaskIntegration + + def __init__(self, app=None, cache=None, fetch_token=None, update_token=None): + super().__init__( + cache=cache, fetch_token=fetch_token, update_token=update_token) + self.app = app + if app: + self.init_app(app) + + def init_app(self, app, cache=None, fetch_token=None, update_token=None): + """Initialize lazy for Flask app. This is usually used for Flask application + factory pattern. + """ + self.app = app + if cache is not None: + self.cache = cache + + if fetch_token: + self.fetch_token = fetch_token + if update_token: + self.update_token = update_token + + app.extensions = getattr(app, 'extensions', {}) + app.extensions['authlib.integrations.flask_client'] = self + + def create_client(self, name): + if not self.app: + raise RuntimeError('OAuth is not init with Flask app.') + return super().create_client(name) + + def register(self, name, overwrite=False, **kwargs): + self._registry[name] = (overwrite, kwargs) + if self.app: + return self.create_client(name) + return LocalProxy(lambda: self.create_client(name)) + + +__all__ = [ + 'OAuth', 'FlaskIntegration', + 'FlaskOAuth1App', 'FlaskOAuth2App', + 'token_update', 'OAuthError', +] diff --git a/.venv/Lib/site-packages/authlib/integrations/flask_client/__pycache__/__init__.cpython-311.pyc b/.venv/Lib/site-packages/authlib/integrations/flask_client/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 00000000..336618e8 Binary files /dev/null and b/.venv/Lib/site-packages/authlib/integrations/flask_client/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/authlib/integrations/flask_client/__pycache__/apps.cpython-311.pyc b/.venv/Lib/site-packages/authlib/integrations/flask_client/__pycache__/apps.cpython-311.pyc new file mode 100644 index 00000000..65d25cec Binary files /dev/null and b/.venv/Lib/site-packages/authlib/integrations/flask_client/__pycache__/apps.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/authlib/integrations/flask_client/__pycache__/integration.cpython-311.pyc b/.venv/Lib/site-packages/authlib/integrations/flask_client/__pycache__/integration.cpython-311.pyc new file mode 100644 index 00000000..2f95ddd9 Binary files /dev/null and b/.venv/Lib/site-packages/authlib/integrations/flask_client/__pycache__/integration.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/authlib/integrations/flask_client/apps.py b/.venv/Lib/site-packages/authlib/integrations/flask_client/apps.py new file mode 100644 index 00000000..7567f4b3 --- /dev/null +++ b/.venv/Lib/site-packages/authlib/integrations/flask_client/apps.py @@ -0,0 +1,107 @@ +from flask import g, redirect, request, session +from ..requests_client import OAuth1Session, OAuth2Session +from ..base_client import ( + BaseApp, OAuthError, + OAuth1Mixin, OAuth2Mixin, OpenIDMixin, +) + + +class FlaskAppMixin: + @property + def token(self): + attr = f'_oauth_token_{self.name}' + token = g.get(attr) + if token: + return token + if self._fetch_token: + token = self._fetch_token() + self.token = token + return token + + @token.setter + def token(self, token): + attr = f'_oauth_token_{self.name}' + setattr(g, attr, token) + + def _get_requested_token(self, *args, **kwargs): + return self.token + + def save_authorize_data(self, **kwargs): + state = kwargs.pop('state', None) + if state: + self.framework.set_state_data(session, state, kwargs) + else: + raise RuntimeError('Missing state value') + + def authorize_redirect(self, redirect_uri=None, **kwargs): + """Create a HTTP Redirect for Authorization Endpoint. + + :param redirect_uri: Callback or redirect URI for authorization. + :param kwargs: Extra parameters to include. + :return: A HTTP redirect response. + """ + rv = self.create_authorization_url(redirect_uri, **kwargs) + self.save_authorize_data(redirect_uri=redirect_uri, **rv) + return redirect(rv['url']) + + +class FlaskOAuth1App(FlaskAppMixin, OAuth1Mixin, BaseApp): + client_cls = OAuth1Session + + def authorize_access_token(self, **kwargs): + """Fetch access token in one step. + + :return: A token dict. + """ + params = request.args.to_dict(flat=True) + state = params.get('oauth_token') + if not state: + raise OAuthError(description='Missing "oauth_token" parameter') + + data = self.framework.get_state_data(session, state) + if not data: + raise OAuthError(description='Missing "request_token" in temporary data') + + params['request_token'] = data['request_token'] + params.update(kwargs) + self.framework.clear_state_data(session, state) + token = self.fetch_access_token(**params) + self.token = token + return token + + +class FlaskOAuth2App(FlaskAppMixin, OAuth2Mixin, OpenIDMixin, BaseApp): + client_cls = OAuth2Session + + def authorize_access_token(self, **kwargs): + """Fetch access token in one step. + + :return: A token dict. + """ + if request.method == 'GET': + error = request.args.get('error') + if error: + description = request.args.get('error_description') + raise OAuthError(error=error, description=description) + + params = { + 'code': request.args['code'], + 'state': request.args.get('state'), + } + else: + params = { + 'code': request.form['code'], + 'state': request.form.get('state'), + } + + claims_options = kwargs.pop('claims_options', None) + state_data = self.framework.get_state_data(session, params.get('state')) + self.framework.clear_state_data(session, params.get('state')) + params = self._format_state_params(state_data, params) + token = self.fetch_access_token(**params, **kwargs) + self.token = token + + if 'id_token' in token and 'nonce' in state_data: + userinfo = self.parse_id_token(token, nonce=state_data['nonce'], claims_options=claims_options) + token['userinfo'] = userinfo + return token diff --git a/.venv/Lib/site-packages/authlib/integrations/flask_client/integration.py b/.venv/Lib/site-packages/authlib/integrations/flask_client/integration.py new file mode 100644 index 00000000..f4ea57e3 --- /dev/null +++ b/.venv/Lib/site-packages/authlib/integrations/flask_client/integration.py @@ -0,0 +1,28 @@ +from flask import current_app +from flask.signals import Namespace +from ..base_client import FrameworkIntegration + +_signal = Namespace() +#: signal when token is updated +token_update = _signal.signal('token_update') + + +class FlaskIntegration(FrameworkIntegration): + def update_token(self, token, refresh_token=None, access_token=None): + token_update.send( + current_app, + name=self.name, + token=token, + refresh_token=refresh_token, + access_token=access_token, + ) + + @staticmethod + def load_config(oauth, name, params): + rv = {} + for k in params: + conf_key = f'{name}_{k}'.upper() + v = oauth.app.config.get(conf_key, None) + if v is not None: + rv[k] = v + return rv diff --git a/.venv/Lib/site-packages/authlib/integrations/flask_oauth1/__init__.py b/.venv/Lib/site-packages/authlib/integrations/flask_oauth1/__init__.py new file mode 100644 index 00000000..780b0594 --- /dev/null +++ b/.venv/Lib/site-packages/authlib/integrations/flask_oauth1/__init__.py @@ -0,0 +1,9 @@ +# flake8: noqa + +from .authorization_server import AuthorizationServer +from .resource_protector import ResourceProtector, current_credential +from .cache import ( + register_nonce_hooks, + register_temporary_credential_hooks, + create_exists_nonce_func, +) diff --git a/.venv/Lib/site-packages/authlib/integrations/flask_oauth1/__pycache__/__init__.cpython-311.pyc b/.venv/Lib/site-packages/authlib/integrations/flask_oauth1/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 00000000..f01f1f62 Binary files /dev/null and b/.venv/Lib/site-packages/authlib/integrations/flask_oauth1/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/authlib/integrations/flask_oauth1/__pycache__/authorization_server.cpython-311.pyc b/.venv/Lib/site-packages/authlib/integrations/flask_oauth1/__pycache__/authorization_server.cpython-311.pyc new file mode 100644 index 00000000..4a1f0c3d Binary files /dev/null and b/.venv/Lib/site-packages/authlib/integrations/flask_oauth1/__pycache__/authorization_server.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/authlib/integrations/flask_oauth1/__pycache__/cache.cpython-311.pyc b/.venv/Lib/site-packages/authlib/integrations/flask_oauth1/__pycache__/cache.cpython-311.pyc new file mode 100644 index 00000000..ef86ad5d Binary files /dev/null and b/.venv/Lib/site-packages/authlib/integrations/flask_oauth1/__pycache__/cache.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/authlib/integrations/flask_oauth1/__pycache__/resource_protector.cpython-311.pyc b/.venv/Lib/site-packages/authlib/integrations/flask_oauth1/__pycache__/resource_protector.cpython-311.pyc new file mode 100644 index 00000000..13ea15ad Binary files /dev/null and b/.venv/Lib/site-packages/authlib/integrations/flask_oauth1/__pycache__/resource_protector.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/authlib/integrations/flask_oauth1/authorization_server.py b/.venv/Lib/site-packages/authlib/integrations/flask_oauth1/authorization_server.py new file mode 100644 index 00000000..3a2a5600 --- /dev/null +++ b/.venv/Lib/site-packages/authlib/integrations/flask_oauth1/authorization_server.py @@ -0,0 +1,182 @@ +import logging +from werkzeug.utils import import_string +from flask import Response +from flask import request as flask_req +from authlib.oauth1 import ( + OAuth1Request, + AuthorizationServer as _AuthorizationServer, +) +from authlib.common.security import generate_token +from authlib.common.urls import url_encode + +log = logging.getLogger(__name__) + + +class AuthorizationServer(_AuthorizationServer): + """Flask implementation of :class:`authlib.rfc5849.AuthorizationServer`. + Initialize it with Flask app instance, client model class and cache:: + + server = AuthorizationServer(app=app, query_client=query_client) + # or initialize lazily + server = AuthorizationServer() + server.init_app(app, query_client=query_client) + + :param app: A Flask app instance + :param query_client: A function to get client by client_id. The client + model class MUST implement the methods described by + :class:`~authlib.oauth1.rfc5849.ClientMixin`. + :param token_generator: A function to generate token + """ + + def __init__(self, app=None, query_client=None, token_generator=None): + self.app = app + self.query_client = query_client + self.token_generator = token_generator + + self._hooks = { + 'exists_nonce': None, + 'create_temporary_credential': None, + 'get_temporary_credential': None, + 'delete_temporary_credential': None, + 'create_authorization_verifier': None, + 'create_token_credential': None, + } + if app is not None: + self.init_app(app) + + def init_app(self, app, query_client=None, token_generator=None): + if query_client is not None: + self.query_client = query_client + if token_generator is not None: + self.token_generator = token_generator + + if self.token_generator is None: + self.token_generator = self.create_token_generator(app) + + methods = app.config.get('OAUTH1_SUPPORTED_SIGNATURE_METHODS') + if methods and isinstance(methods, (list, tuple)): + self.SUPPORTED_SIGNATURE_METHODS = methods + + self.app = app + + def register_hook(self, name, func): + if name not in self._hooks: + raise ValueError('Invalid "name" of hook') + self._hooks[name] = func + + def create_token_generator(self, app): + token_generator = app.config.get('OAUTH1_TOKEN_GENERATOR') + + if isinstance(token_generator, str): + token_generator = import_string(token_generator) + else: + length = app.config.get('OAUTH1_TOKEN_LENGTH', 42) + + def token_generator(): + return generate_token(length) + + secret_generator = app.config.get('OAUTH1_TOKEN_SECRET_GENERATOR') + if isinstance(secret_generator, str): + secret_generator = import_string(secret_generator) + else: + length = app.config.get('OAUTH1_TOKEN_SECRET_LENGTH', 48) + + def secret_generator(): + return generate_token(length) + + def create_token(): + return { + 'oauth_token': token_generator(), + 'oauth_token_secret': secret_generator() + } + return create_token + + def get_client_by_id(self, client_id): + return self.query_client(client_id) + + def exists_nonce(self, nonce, request): + func = self._hooks['exists_nonce'] + if callable(func): + timestamp = request.timestamp + client_id = request.client_id + token = request.token + return func(nonce, timestamp, client_id, token) + + raise RuntimeError('"exists_nonce" hook is required.') + + def create_temporary_credential(self, request): + func = self._hooks['create_temporary_credential'] + if callable(func): + token = self.token_generator() + return func(token, request.client_id, request.redirect_uri) + raise RuntimeError( + '"create_temporary_credential" hook is required.' + ) + + def get_temporary_credential(self, request): + func = self._hooks['get_temporary_credential'] + if callable(func): + return func(request.token) + + raise RuntimeError( + '"get_temporary_credential" hook is required.' + ) + + def delete_temporary_credential(self, request): + func = self._hooks['delete_temporary_credential'] + if callable(func): + return func(request.token) + + raise RuntimeError( + '"delete_temporary_credential" hook is required.' + ) + + def create_authorization_verifier(self, request): + func = self._hooks['create_authorization_verifier'] + if callable(func): + verifier = generate_token(36) + func(request.credential, request.user, verifier) + return verifier + + raise RuntimeError( + '"create_authorization_verifier" hook is required.' + ) + + def create_token_credential(self, request): + func = self._hooks['create_token_credential'] + if callable(func): + temporary_credential = request.credential + token = self.token_generator() + return func(token, temporary_credential) + + raise RuntimeError( + '"create_token_credential" hook is required.' + ) + + def check_authorization_request(self): + req = self.create_oauth1_request(None) + self.validate_authorization_request(req) + return req + + def create_authorization_response(self, request=None, grant_user=None): + return super()\ + .create_authorization_response(request, grant_user) + + def create_token_response(self, request=None): + return super().create_token_response(request) + + def create_oauth1_request(self, request): + if request is None: + request = flask_req + if request.method in ('POST', 'PUT'): + body = request.form.to_dict(flat=True) + else: + body = None + return OAuth1Request(request.method, request.url, body, request.headers) + + def handle_response(self, status_code, payload, headers): + return Response( + url_encode(payload), + status=status_code, + headers=headers + ) diff --git a/.venv/Lib/site-packages/authlib/integrations/flask_oauth1/cache.py b/.venv/Lib/site-packages/authlib/integrations/flask_oauth1/cache.py new file mode 100644 index 00000000..fdfc9a5a --- /dev/null +++ b/.venv/Lib/site-packages/authlib/integrations/flask_oauth1/cache.py @@ -0,0 +1,80 @@ +from authlib.oauth1 import TemporaryCredential + + +def register_temporary_credential_hooks( + authorization_server, cache, key_prefix='temporary_credential:'): + """Register temporary credential related hooks to authorization server. + + :param authorization_server: AuthorizationServer instance + :param cache: Cache instance + :param key_prefix: key prefix for temporary credential + """ + + def create_temporary_credential(token, client_id, redirect_uri): + key = key_prefix + token['oauth_token'] + token['client_id'] = client_id + if redirect_uri: + token['oauth_callback'] = redirect_uri + + cache.set(key, token, timeout=86400) # cache for one day + return TemporaryCredential(token) + + def get_temporary_credential(oauth_token): + if not oauth_token: + return None + key = key_prefix + oauth_token + value = cache.get(key) + if value: + return TemporaryCredential(value) + + def delete_temporary_credential(oauth_token): + if oauth_token: + key = key_prefix + oauth_token + cache.delete(key) + + def create_authorization_verifier(credential, grant_user, verifier): + key = key_prefix + credential.get_oauth_token() + credential['oauth_verifier'] = verifier + credential['user_id'] = grant_user.get_user_id() + cache.set(key, credential, timeout=86400) + return credential + + authorization_server.register_hook( + 'create_temporary_credential', create_temporary_credential) + authorization_server.register_hook( + 'get_temporary_credential', get_temporary_credential) + authorization_server.register_hook( + 'delete_temporary_credential', delete_temporary_credential) + authorization_server.register_hook( + 'create_authorization_verifier', create_authorization_verifier) + + +def create_exists_nonce_func(cache, key_prefix='nonce:', expires=86400): + """Create an ``exists_nonce`` function that can be used in hooks and + resource protector. + + :param cache: Cache instance + :param key_prefix: key prefix for temporary credential + :param expires: Expire time for nonce + """ + def exists_nonce(nonce, timestamp, client_id, oauth_token): + key = f'{key_prefix}{nonce}-{timestamp}-{client_id}' + if oauth_token: + key = f'{key}-{oauth_token}' + rv = cache.has(key) + cache.set(key, 1, timeout=expires) + return rv + return exists_nonce + + +def register_nonce_hooks( + authorization_server, cache, key_prefix='nonce:', expires=86400): + """Register nonce related hooks to authorization server. + + :param authorization_server: AuthorizationServer instance + :param cache: Cache instance + :param key_prefix: key prefix for temporary credential + :param expires: Expire time for nonce + """ + exists_nonce = create_exists_nonce_func(cache, key_prefix, expires) + authorization_server.register_hook('exists_nonce', exists_nonce) diff --git a/.venv/Lib/site-packages/authlib/integrations/flask_oauth1/resource_protector.py b/.venv/Lib/site-packages/authlib/integrations/flask_oauth1/resource_protector.py new file mode 100644 index 00000000..c941eb42 --- /dev/null +++ b/.venv/Lib/site-packages/authlib/integrations/flask_oauth1/resource_protector.py @@ -0,0 +1,113 @@ +import functools +from flask import g, json, Response +from flask import request as _req +from werkzeug.local import LocalProxy +from authlib.consts import default_json_headers +from authlib.oauth1 import ResourceProtector as _ResourceProtector +from authlib.oauth1.errors import OAuth1Error + + +class ResourceProtector(_ResourceProtector): + """A protecting method for resource servers. Initialize a resource + protector with the these method: + + 1. query_client + 2. query_token, + 3. exists_nonce + + Usually, a ``query_client`` method would look like (if using SQLAlchemy):: + + def query_client(client_id): + return Client.query.filter_by(client_id=client_id).first() + + A ``query_token`` method accept two parameters, ``client_id`` and ``oauth_token``:: + + def query_token(client_id, oauth_token): + return Token.query.filter_by(client_id=client_id, oauth_token=oauth_token).first() + + And for ``exists_nonce``, if using cache, we have a built-in hook to create this method:: + + from authlib.integrations.flask_oauth1 import create_exists_nonce_func + + exists_nonce = create_exists_nonce_func(cache) + + Then initialize the resource protector with those methods:: + + require_oauth = ResourceProtector( + app, query_client=query_client, + query_token=query_token, exists_nonce=exists_nonce, + ) + """ + def __init__(self, app=None, query_client=None, + query_token=None, exists_nonce=None): + self.query_client = query_client + self.query_token = query_token + self._exists_nonce = exists_nonce + + self.app = app + if app: + self.init_app(app) + + def init_app(self, app, query_client=None, query_token=None, + exists_nonce=None): + if query_client is not None: + self.query_client = query_client + if query_token is not None: + self.query_token = query_token + if exists_nonce is not None: + self._exists_nonce = exists_nonce + + methods = app.config.get('OAUTH1_SUPPORTED_SIGNATURE_METHODS') + if methods and isinstance(methods, (list, tuple)): + self.SUPPORTED_SIGNATURE_METHODS = methods + + self.app = app + + def get_client_by_id(self, client_id): + return self.query_client(client_id) + + def get_token_credential(self, request): + return self.query_token(request.client_id, request.token) + + def exists_nonce(self, nonce, request): + if not self._exists_nonce: + raise RuntimeError('"exists_nonce" function is required.') + + timestamp = request.timestamp + client_id = request.client_id + token = request.token + return self._exists_nonce(nonce, timestamp, client_id, token) + + def acquire_credential(self): + req = self.validate_request( + _req.method, + _req.url, + _req.form.to_dict(flat=True), + _req.headers + ) + g.authlib_server_oauth1_credential = req.credential + return req.credential + + def __call__(self, scope=None): + def wrapper(f): + @functools.wraps(f) + def decorated(*args, **kwargs): + try: + self.acquire_credential() + except OAuth1Error as error: + body = dict(error.get_body()) + return Response( + json.dumps(body), + status=error.status_code, + headers=default_json_headers, + ) + return f(*args, **kwargs) + return decorated + return wrapper + + +def _get_current_credential(): + return g.get('authlib_server_oauth1_credential') + + +current_credential = LocalProxy(_get_current_credential) diff --git a/.venv/Lib/site-packages/authlib/integrations/flask_oauth2/__init__.py b/.venv/Lib/site-packages/authlib/integrations/flask_oauth2/__init__.py new file mode 100644 index 00000000..170a7190 --- /dev/null +++ b/.venv/Lib/site-packages/authlib/integrations/flask_oauth2/__init__.py @@ -0,0 +1,12 @@ +# flake8: noqa + +from .authorization_server import AuthorizationServer +from .resource_protector import ( + ResourceProtector, + current_token, +) +from .signals import ( + client_authenticated, + token_authenticated, + token_revoked, +) diff --git a/.venv/Lib/site-packages/authlib/integrations/flask_oauth2/__pycache__/__init__.cpython-311.pyc b/.venv/Lib/site-packages/authlib/integrations/flask_oauth2/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 00000000..783bd270 Binary files /dev/null and b/.venv/Lib/site-packages/authlib/integrations/flask_oauth2/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/authlib/integrations/flask_oauth2/__pycache__/authorization_server.cpython-311.pyc b/.venv/Lib/site-packages/authlib/integrations/flask_oauth2/__pycache__/authorization_server.cpython-311.pyc new file mode 100644 index 00000000..cec81608 Binary files /dev/null and b/.venv/Lib/site-packages/authlib/integrations/flask_oauth2/__pycache__/authorization_server.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/authlib/integrations/flask_oauth2/__pycache__/errors.cpython-311.pyc b/.venv/Lib/site-packages/authlib/integrations/flask_oauth2/__pycache__/errors.cpython-311.pyc new file mode 100644 index 00000000..9a7cf734 Binary files /dev/null and b/.venv/Lib/site-packages/authlib/integrations/flask_oauth2/__pycache__/errors.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/authlib/integrations/flask_oauth2/__pycache__/requests.cpython-311.pyc b/.venv/Lib/site-packages/authlib/integrations/flask_oauth2/__pycache__/requests.cpython-311.pyc new file mode 100644 index 00000000..bc0a973f Binary files /dev/null and b/.venv/Lib/site-packages/authlib/integrations/flask_oauth2/__pycache__/requests.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/authlib/integrations/flask_oauth2/__pycache__/resource_protector.cpython-311.pyc b/.venv/Lib/site-packages/authlib/integrations/flask_oauth2/__pycache__/resource_protector.cpython-311.pyc new file mode 100644 index 00000000..5cfdcd0a Binary files /dev/null and b/.venv/Lib/site-packages/authlib/integrations/flask_oauth2/__pycache__/resource_protector.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/authlib/integrations/flask_oauth2/__pycache__/signals.cpython-311.pyc b/.venv/Lib/site-packages/authlib/integrations/flask_oauth2/__pycache__/signals.cpython-311.pyc new file mode 100644 index 00000000..ac7a5f7e Binary files /dev/null and b/.venv/Lib/site-packages/authlib/integrations/flask_oauth2/__pycache__/signals.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/authlib/integrations/flask_oauth2/authorization_server.py b/.venv/Lib/site-packages/authlib/integrations/flask_oauth2/authorization_server.py new file mode 100644 index 00000000..14510b27 --- /dev/null +++ b/.venv/Lib/site-packages/authlib/integrations/flask_oauth2/authorization_server.py @@ -0,0 +1,159 @@ +from werkzeug.utils import import_string +from flask import Response, json +from flask import request as flask_req +from authlib.oauth2 import ( + AuthorizationServer as _AuthorizationServer, +) +from authlib.oauth2.rfc6750 import BearerTokenGenerator +from authlib.common.security import generate_token +from .requests import FlaskOAuth2Request, FlaskJsonRequest +from .signals import client_authenticated, token_revoked + + +class AuthorizationServer(_AuthorizationServer): + """Flask implementation of :class:`authlib.oauth2.rfc6749.AuthorizationServer`. + Initialize it with ``query_client``, ``save_token`` methods and Flask + app instance:: + + def query_client(client_id): + return Client.query.filter_by(client_id=client_id).first() + + def save_token(token, request): + if request.user: + user_id = request.user.id + else: + user_id = None + client = request.client + tok = Token( + client_id=client.client_id, + user_id=user.id, + **token + ) + db.session.add(tok) + db.session.commit() + + server = AuthorizationServer(app, query_client, save_token) + # or initialize lazily + server = AuthorizationServer() + server.init_app(app, query_client, save_token) + """ + + def __init__(self, app=None, query_client=None, save_token=None): + super().__init__() + self._query_client = query_client + self._save_token = save_token + self._error_uris = None + if app is not None: + self.init_app(app) + + def init_app(self, app, query_client=None, save_token=None): + """Initialize later with Flask app instance.""" + if query_client is not None: + self._query_client = query_client + if save_token is not None: + self._save_token = save_token + + self.register_token_generator('default', self.create_bearer_token_generator(app.config)) + self.scopes_supported = app.config.get('OAUTH2_SCOPES_SUPPORTED') + self._error_uris = app.config.get('OAUTH2_ERROR_URIS') + + def query_client(self, client_id): + return self._query_client(client_id) + + def save_token(self, token, request): + return self._save_token(token, request) + + def get_error_uri(self, request, error): + if self._error_uris: + uris = dict(self._error_uris) + return uris.get(error.error) + + def create_oauth2_request(self, request): + return FlaskOAuth2Request(flask_req) + + def create_json_request(self, request): + return FlaskJsonRequest(flask_req) + + def handle_response(self, status_code, payload, headers): + if isinstance(payload, dict): + payload = json.dumps(payload) + return Response(payload, status=status_code, headers=headers) + + def send_signal(self, name, *args, **kwargs): + if name == 'after_authenticate_client': + client_authenticated.send(self, *args, **kwargs) + elif name == 'after_revoke_token': + token_revoked.send(self, *args, **kwargs) + + def create_bearer_token_generator(self, config): + """Create a generator function for generating ``token`` value. This + method will create a Bearer Token generator with + :class:`authlib.oauth2.rfc6750.BearerToken`. + + Configurable settings: + + 1. OAUTH2_ACCESS_TOKEN_GENERATOR: Boolean or import string, default is True. + 2. OAUTH2_REFRESH_TOKEN_GENERATOR: Boolean or import string, default is False. + 3. OAUTH2_TOKEN_EXPIRES_IN: Dict or import string, default is None. + + By default, it will not generate ``refresh_token``, which can be turn on by + configure ``OAUTH2_REFRESH_TOKEN_GENERATOR``. + + Here are some examples of the token generator:: + + OAUTH2_ACCESS_TOKEN_GENERATOR = 'your_project.generators.gen_token' + + # and in module `your_project.generators`, you can define: + + def gen_token(client, grant_type, user, scope): + # generate token according to these parameters + token = create_random_token() + return f'{client.id}-{user.id}-{token}' + + Here is an example of ``OAUTH2_TOKEN_EXPIRES_IN``:: + + OAUTH2_TOKEN_EXPIRES_IN = { + 'authorization_code': 864000, + 'urn:ietf:params:oauth:grant-type:jwt-bearer': 3600, + } + """ + conf = config.get('OAUTH2_ACCESS_TOKEN_GENERATOR', True) + access_token_generator = create_token_generator(conf, 42) + + conf = config.get('OAUTH2_REFRESH_TOKEN_GENERATOR', False) + refresh_token_generator = create_token_generator(conf, 48) + + expires_conf = config.get('OAUTH2_TOKEN_EXPIRES_IN') + expires_generator = create_token_expires_in_generator(expires_conf) + return BearerTokenGenerator( + access_token_generator, + refresh_token_generator, + expires_generator + ) + + +def create_token_expires_in_generator(expires_in_conf=None): + if isinstance(expires_in_conf, str): + return import_string(expires_in_conf) + + data = {} + data.update(BearerTokenGenerator.GRANT_TYPES_EXPIRES_IN) + if isinstance(expires_in_conf, dict): + data.update(expires_in_conf) + + def expires_in(client, grant_type): + return data.get(grant_type, BearerTokenGenerator.DEFAULT_EXPIRES_IN) + + return expires_in + + +def create_token_generator(token_generator_conf, length=42): + if callable(token_generator_conf): + return token_generator_conf + + if isinstance(token_generator_conf, str): + return import_string(token_generator_conf) + elif token_generator_conf is True: + def token_generator(*args, **kwargs): + return generate_token(length) + return token_generator diff --git a/.venv/Lib/site-packages/authlib/integrations/flask_oauth2/errors.py b/.venv/Lib/site-packages/authlib/integrations/flask_oauth2/errors.py new file mode 100644 index 00000000..fb2f3a1f --- /dev/null +++ b/.venv/Lib/site-packages/authlib/integrations/flask_oauth2/errors.py @@ -0,0 +1,39 @@ +import importlib + +import werkzeug +from werkzeug.exceptions import HTTPException + +_version = importlib.metadata.version('werkzeug').split('.')[0] + +if _version in ('0', '1'): + class _HTTPException(HTTPException): + def __init__(self, code, body, headers, response=None): + super().__init__(None, response) + self.code = code + + self.body = body + self.headers = headers + + def get_body(self, environ=None): + return self.body + + def get_headers(self, environ=None): + return self.headers +else: + class _HTTPException(HTTPException): + def __init__(self, code, body, headers, response=None): + super().__init__(None, response) + self.code = code + + self.body = body + self.headers = headers + + def get_body(self, environ=None, scope=None): + return self.body + + def get_headers(self, environ=None, scope=None): + return self.headers + + +def raise_http_exception(status, body, headers): + raise _HTTPException(status, body, headers) diff --git a/.venv/Lib/site-packages/authlib/integrations/flask_oauth2/requests.py b/.venv/Lib/site-packages/authlib/integrations/flask_oauth2/requests.py new file mode 100644 index 00000000..0c2ab561 --- /dev/null +++ b/.venv/Lib/site-packages/authlib/integrations/flask_oauth2/requests.py @@ -0,0 +1,30 @@ +from flask.wrappers import Request +from authlib.oauth2.rfc6749 import OAuth2Request, JsonRequest + + +class FlaskOAuth2Request(OAuth2Request): + def __init__(self, request: Request): + super().__init__(request.method, request.url, None, request.headers) + self._request = request + + @property + def args(self): + return self._request.args + + @property + def form(self): + return self._request.form + + @property + def data(self): + return self._request.values + + +class FlaskJsonRequest(JsonRequest): + def __init__(self, request: Request): + super().__init__(request.method, request.url, None, request.headers) + self._request = request + + @property + def data(self): + return self._request.get_json() diff --git a/.venv/Lib/site-packages/authlib/integrations/flask_oauth2/resource_protector.py b/.venv/Lib/site-packages/authlib/integrations/flask_oauth2/resource_protector.py new file mode 100644 index 00000000..be2b3fa2 --- /dev/null +++ b/.venv/Lib/site-packages/authlib/integrations/flask_oauth2/resource_protector.py @@ -0,0 +1,114 @@ +import functools +from contextlib import contextmanager +from flask import g, json +from flask import request as _req +from werkzeug.local import LocalProxy +from authlib.oauth2 import ( + OAuth2Error, + ResourceProtector as _ResourceProtector +) +from authlib.oauth2.rfc6749 import ( + MissingAuthorizationError, +) +from .requests import FlaskJsonRequest +from .signals import token_authenticated +from .errors import raise_http_exception + + +class ResourceProtector(_ResourceProtector): + """A protecting method for resource servers. Creating a ``require_oauth`` + decorator easily with ResourceProtector:: + + from authlib.integrations.flask_oauth2 import ResourceProtector + + require_oauth = ResourceProtector() + + # add bearer token validator + from authlib.oauth2.rfc6750 import BearerTokenValidator + from project.models import Token + + class MyBearerTokenValidator(BearerTokenValidator): + def authenticate_token(self, token_string): + return Token.query.filter_by(access_token=token_string).first() + + require_oauth.register_token_validator(MyBearerTokenValidator()) + + # protect resource with require_oauth + + @app.route('/user') + @require_oauth(['profile']) + def user_profile(): + user = User.get(current_token.user_id) + return jsonify(user.to_dict()) + + """ + def raise_error_response(self, error): + """Raise HTTPException for OAuth2Error. Developers can re-implement + this method to customize the error response. + + :param error: OAuth2Error + :raise: HTTPException + """ + status = error.status_code + body = json.dumps(dict(error.get_body())) + headers = error.get_headers() + raise_http_exception(status, body, headers) + + def acquire_token(self, scopes=None, **kwargs): + """A method to acquire current valid token with the given scope. + + :param scopes: a list of scope values + :return: token object + """ + request = FlaskJsonRequest(_req) + # backward compatibility + kwargs['scopes'] = scopes + for claim in kwargs: + if isinstance(kwargs[claim], str): + kwargs[claim] = [kwargs[claim]] + token = self.validate_request(request=request, **kwargs) + token_authenticated.send(self, token=token) + g.authlib_server_oauth2_token = token + return token + + @contextmanager + def acquire(self, scopes=None): + """The with statement of ``require_oauth``. Instead of using a + decorator, you can use a with statement instead:: + + @app.route('/api/user') + def user_api(): + with require_oauth.acquire('profile') as token: + user = User.get(token.user_id) + return jsonify(user.to_dict()) + """ + try: + yield self.acquire_token(scopes) + except OAuth2Error as error: + self.raise_error_response(error) + + def __call__(self, scopes=None, optional=False, **kwargs): + claims = kwargs + # backward compatibility + claims['scopes'] = scopes + def wrapper(f): + @functools.wraps(f) + def decorated(*args, **kwargs): + try: + self.acquire_token(**claims) + except MissingAuthorizationError as error: + if optional: + return f(*args, **kwargs) + self.raise_error_response(error) + except OAuth2Error as error: + self.raise_error_response(error) + return f(*args, **kwargs) + return decorated + return wrapper + + +def _get_current_token(): + return g.get('authlib_server_oauth2_token') + + +current_token = LocalProxy(_get_current_token) diff --git a/.venv/Lib/site-packages/authlib/integrations/flask_oauth2/signals.py b/.venv/Lib/site-packages/authlib/integrations/flask_oauth2/signals.py new file mode 100644 index 00000000..c61e0119 --- /dev/null +++ b/.venv/Lib/site-packages/authlib/integrations/flask_oauth2/signals.py @@ -0,0 +1,12 @@ +from flask.signals import Namespace + +_signal = Namespace() + +#: signal when client is authenticated +client_authenticated = _signal.signal('client_authenticated') + +#: signal when token is revoked +token_revoked = _signal.signal('token_revoked') + +#: signal when token is authenticated +token_authenticated = _signal.signal('token_authenticated') diff --git a/.venv/Lib/site-packages/authlib/integrations/httpx_client/__init__.py b/.venv/Lib/site-packages/authlib/integrations/httpx_client/__init__.py new file mode 100644 index 00000000..3b5437cc --- /dev/null +++ b/.venv/Lib/site-packages/authlib/integrations/httpx_client/__init__.py @@ -0,0 +1,25 @@ +from authlib.oauth1 import ( + SIGNATURE_HMAC_SHA1, + SIGNATURE_RSA_SHA1, + SIGNATURE_PLAINTEXT, + SIGNATURE_TYPE_HEADER, + SIGNATURE_TYPE_QUERY, + SIGNATURE_TYPE_BODY, +) +from .oauth1_client import OAuth1Auth, AsyncOAuth1Client, OAuth1Client +from .oauth2_client import ( + OAuth2Auth, OAuth2Client, OAuth2ClientAuth, + AsyncOAuth2Client, +) +from .assertion_client import AssertionClient, AsyncAssertionClient +from ..base_client import OAuthError + + +__all__ = [ + 'OAuthError', + 'OAuth1Auth', 'AsyncOAuth1Client', + 'SIGNATURE_HMAC_SHA1', 'SIGNATURE_RSA_SHA1', 'SIGNATURE_PLAINTEXT', + 'SIGNATURE_TYPE_HEADER', 'SIGNATURE_TYPE_QUERY', 'SIGNATURE_TYPE_BODY', + 'OAuth2Auth', 'OAuth2ClientAuth', 'OAuth2Client', 'AsyncOAuth2Client', + 'AssertionClient', 'AsyncAssertionClient', +] diff --git a/.venv/Lib/site-packages/authlib/integrations/httpx_client/__pycache__/__init__.cpython-311.pyc b/.venv/Lib/site-packages/authlib/integrations/httpx_client/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 00000000..56e199f8 Binary files /dev/null and b/.venv/Lib/site-packages/authlib/integrations/httpx_client/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/authlib/integrations/httpx_client/__pycache__/assertion_client.cpython-311.pyc b/.venv/Lib/site-packages/authlib/integrations/httpx_client/__pycache__/assertion_client.cpython-311.pyc new file mode 100644 index 00000000..22e61d27 Binary files /dev/null and b/.venv/Lib/site-packages/authlib/integrations/httpx_client/__pycache__/assertion_client.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/authlib/integrations/httpx_client/__pycache__/oauth1_client.cpython-311.pyc b/.venv/Lib/site-packages/authlib/integrations/httpx_client/__pycache__/oauth1_client.cpython-311.pyc new file mode 100644 index 00000000..c761e190 Binary files /dev/null and b/.venv/Lib/site-packages/authlib/integrations/httpx_client/__pycache__/oauth1_client.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/authlib/integrations/httpx_client/__pycache__/oauth2_client.cpython-311.pyc b/.venv/Lib/site-packages/authlib/integrations/httpx_client/__pycache__/oauth2_client.cpython-311.pyc new file mode 100644 index 00000000..c471f4a6 Binary files /dev/null and b/.venv/Lib/site-packages/authlib/integrations/httpx_client/__pycache__/oauth2_client.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/authlib/integrations/httpx_client/__pycache__/utils.cpython-311.pyc b/.venv/Lib/site-packages/authlib/integrations/httpx_client/__pycache__/utils.cpython-311.pyc new file mode 100644 index 00000000..942e499e Binary files /dev/null and b/.venv/Lib/site-packages/authlib/integrations/httpx_client/__pycache__/utils.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/authlib/integrations/httpx_client/assertion_client.py b/.venv/Lib/site-packages/authlib/integrations/httpx_client/assertion_client.py new file mode 100644 index 00000000..83dc58b2 --- /dev/null +++ b/.venv/Lib/site-packages/authlib/integrations/httpx_client/assertion_client.py @@ -0,0 +1,81 @@ +import httpx +from httpx import Response, USE_CLIENT_DEFAULT +from authlib.oauth2.rfc7521 import AssertionClient as _AssertionClient +from authlib.oauth2.rfc7523 import JWTBearerGrant +from .utils import extract_client_kwargs +from .oauth2_client import OAuth2Auth +from ..base_client import OAuthError + +__all__ = ['AsyncAssertionClient'] + + +class AsyncAssertionClient(_AssertionClient, httpx.AsyncClient): + token_auth_class = OAuth2Auth + oauth_error_class = OAuthError + JWT_BEARER_GRANT_TYPE = JWTBearerGrant.GRANT_TYPE + ASSERTION_METHODS = { + JWT_BEARER_GRANT_TYPE: JWTBearerGrant.sign, + } + DEFAULT_GRANT_TYPE = JWT_BEARER_GRANT_TYPE + + def __init__(self, token_endpoint, issuer, subject, audience=None, grant_type=None, + claims=None, token_placement='header', scope=None, **kwargs): + + client_kwargs = extract_client_kwargs(kwargs) + httpx.AsyncClient.__init__(self, **client_kwargs) + + _AssertionClient.__init__( + self, session=None, + token_endpoint=token_endpoint, issuer=issuer, subject=subject, + audience=audience, grant_type=grant_type, claims=claims, + token_placement=token_placement, scope=scope, **kwargs + ) + + async def request(self, method, url, withhold_token=False, auth=USE_CLIENT_DEFAULT, **kwargs) -> Response: + """Send request with auto refresh token feature.""" + if not withhold_token and auth is USE_CLIENT_DEFAULT: + if not self.token or self.token.is_expired(): + await self.refresh_token() + + auth = self.token_auth + return await super().request( + method, url, auth=auth, **kwargs) + + async def _refresh_token(self, data): + resp = await self.request( + 'POST', self.token_endpoint, data=data, withhold_token=True) + + return self.parse_response_token(resp) + + +class AssertionClient(_AssertionClient, httpx.Client): + token_auth_class = OAuth2Auth + oauth_error_class = OAuthError + JWT_BEARER_GRANT_TYPE = JWTBearerGrant.GRANT_TYPE + ASSERTION_METHODS = { + JWT_BEARER_GRANT_TYPE: JWTBearerGrant.sign, + } + DEFAULT_GRANT_TYPE = JWT_BEARER_GRANT_TYPE + + def __init__(self, token_endpoint, issuer, subject, audience=None, grant_type=None, + claims=None, token_placement='header', scope=None, **kwargs): + + client_kwargs = extract_client_kwargs(kwargs) + httpx.Client.__init__(self, **client_kwargs) + + _AssertionClient.__init__( + self, session=self, + token_endpoint=token_endpoint, issuer=issuer, subject=subject, + audience=audience, grant_type=grant_type, claims=claims, + token_placement=token_placement, scope=scope, **kwargs + ) + + def request(self, method, url, withhold_token=False, auth=USE_CLIENT_DEFAULT, **kwargs): + """Send request with auto refresh token feature.""" + if not withhold_token and auth is USE_CLIENT_DEFAULT: + if not self.token or self.token.is_expired(): + self.refresh_token() + + auth = self.token_auth + return super().request( + method, url, auth=auth, **kwargs) diff --git a/.venv/Lib/site-packages/authlib/integrations/httpx_client/oauth1_client.py b/.venv/Lib/site-packages/authlib/integrations/httpx_client/oauth1_client.py new file mode 100644 index 00000000..ce031c97 --- /dev/null +++ b/.venv/Lib/site-packages/authlib/integrations/httpx_client/oauth1_client.py @@ -0,0 +1,102 @@ +import typing +import httpx +from httpx import Auth, Request, Response +from authlib.oauth1 import ( + SIGNATURE_HMAC_SHA1, + SIGNATURE_TYPE_HEADER, +) +from authlib.common.encoding import to_unicode +from authlib.oauth1 import ClientAuth +from authlib.oauth1.client import OAuth1Client as _OAuth1Client +from .utils import build_request, extract_client_kwargs +from ..base_client import OAuthError + + +class OAuth1Auth(Auth, ClientAuth): + """Signs the httpx request using OAuth 1 (RFC5849)""" + requires_request_body = True + + def auth_flow(self, request: Request) -> typing.Generator[Request, Response, None]: + url, headers, body = self.prepare( + request.method, str(request.url), request.headers, request.content) + headers['Content-Length'] = str(len(body)) + yield build_request(url=url, headers=headers, body=body, initial_request=request) + + +class AsyncOAuth1Client(_OAuth1Client, httpx.AsyncClient): + auth_class = OAuth1Auth + + def __init__(self, client_id, client_secret=None, + token=None, token_secret=None, + redirect_uri=None, rsa_key=None, verifier=None, + signature_method=SIGNATURE_HMAC_SHA1, + signature_type=SIGNATURE_TYPE_HEADER, + force_include_body=False, **kwargs): + + _client_kwargs = extract_client_kwargs(kwargs) + httpx.AsyncClient.__init__(self, **_client_kwargs) + + _OAuth1Client.__init__( + self, None, + client_id=client_id, client_secret=client_secret, + token=token, token_secret=token_secret, + redirect_uri=redirect_uri, rsa_key=rsa_key, verifier=verifier, + signature_method=signature_method, signature_type=signature_type, + force_include_body=force_include_body, **kwargs) + + async def fetch_access_token(self, url, verifier=None, **kwargs): + """Method for fetching an access token from the token endpoint. + + This is the final step in the OAuth 1 workflow. An access token is + obtained using all previously obtained credentials, including the + verifier from the authorization step. + + :param url: Access Token endpoint. + :param verifier: A verifier string to prove authorization was granted. + :param kwargs: Extra parameters to include for fetching access token. + :return: A token dict. + """ + if verifier: + self.auth.verifier = verifier + if not self.auth.verifier: + self.handle_error('missing_verifier', 'Missing "verifier" value') + token = await self._fetch_token(url, **kwargs) + self.auth.verifier = None + return token + + async def _fetch_token(self, url, **kwargs): + resp = await self.post(url, **kwargs) + text = await resp.aread() + token = self.parse_response_token(resp.status_code, to_unicode(text)) + self.token = token + return token + + @staticmethod + def handle_error(error_type, error_description): + raise OAuthError(error_type, error_description) + + +class OAuth1Client(_OAuth1Client, httpx.Client): + auth_class = OAuth1Auth + + def __init__(self, client_id, client_secret=None, + token=None, token_secret=None, + redirect_uri=None, rsa_key=None, verifier=None, + signature_method=SIGNATURE_HMAC_SHA1, + signature_type=SIGNATURE_TYPE_HEADER, + force_include_body=False, **kwargs): + + _client_kwargs = extract_client_kwargs(kwargs) + httpx.Client.__init__(self, **_client_kwargs) + + _OAuth1Client.__init__( + self, self, + client_id=client_id, client_secret=client_secret, + token=token, token_secret=token_secret, + redirect_uri=redirect_uri, rsa_key=rsa_key, verifier=verifier, + signature_method=signature_method, signature_type=signature_type, + force_include_body=force_include_body, **kwargs) + + @staticmethod + def handle_error(error_type, error_description): + raise OAuthError(error_type, error_description) diff --git a/.venv/Lib/site-packages/authlib/integrations/httpx_client/oauth2_client.py b/.venv/Lib/site-packages/authlib/integrations/httpx_client/oauth2_client.py new file mode 100644 index 00000000..d4ee0f58 --- /dev/null +++ b/.venv/Lib/site-packages/authlib/integrations/httpx_client/oauth2_client.py @@ -0,0 +1,220 @@ +import typing +from contextlib import asynccontextmanager + +import httpx +from httpx import Auth, Request, Response, USE_CLIENT_DEFAULT +from anyio import Lock # Import after httpx so import errors refer to httpx +from authlib.common.urls import url_decode +from authlib.oauth2.client import OAuth2Client as _OAuth2Client +from authlib.oauth2.auth import ClientAuth, TokenAuth +from .utils import HTTPX_CLIENT_KWARGS, build_request +from ..base_client import ( + OAuthError, + InvalidTokenError, + MissingTokenError, + UnsupportedTokenTypeError, +) + +__all__ = [ + 'OAuth2Auth', 'OAuth2ClientAuth', + 'AsyncOAuth2Client', 'OAuth2Client', +] + + +class OAuth2Auth(Auth, TokenAuth): + """Sign requests for OAuth 2.0, currently only bearer token is supported.""" + requires_request_body = True + + def auth_flow(self, request: Request) -> typing.Generator[Request, Response, None]: + try: + url, headers, body = self.prepare( + str(request.url), request.headers, request.content) + headers['Content-Length'] = str(len(body)) + yield build_request(url=url, headers=headers, body=body, initial_request=request) + except KeyError as error: + description = f'Unsupported token_type: {str(error)}' + raise UnsupportedTokenTypeError(description=description) + + +class OAuth2ClientAuth(Auth, ClientAuth): + requires_request_body = True + + def auth_flow(self, request: Request) -> typing.Generator[Request, Response, None]: + url, headers, body = self.prepare( + request.method, str(request.url), request.headers, request.content) + headers['Content-Length'] = str(len(body)) + yield build_request(url=url, headers=headers, body=body, initial_request=request) + + +class AsyncOAuth2Client(_OAuth2Client, httpx.AsyncClient): + SESSION_REQUEST_PARAMS = HTTPX_CLIENT_KWARGS + + client_auth_class = OAuth2ClientAuth + token_auth_class = OAuth2Auth + oauth_error_class = OAuthError + + def __init__(self, client_id=None, client_secret=None, + token_endpoint_auth_method=None, + revocation_endpoint_auth_method=None, + scope=None, redirect_uri=None, + token=None, token_placement='header', + update_token=None, **kwargs): + + # extract httpx.Client kwargs + client_kwargs = self._extract_session_request_params(kwargs) + httpx.AsyncClient.__init__(self, **client_kwargs) + + # We use a Lock to synchronize coroutines to prevent + # multiple concurrent attempts to refresh the same token + self._token_refresh_lock = Lock() + + _OAuth2Client.__init__( + self, session=None, + client_id=client_id, client_secret=client_secret, + token_endpoint_auth_method=token_endpoint_auth_method, + revocation_endpoint_auth_method=revocation_endpoint_auth_method, + scope=scope, redirect_uri=redirect_uri, + token=token, token_placement=token_placement, + update_token=update_token, **kwargs + ) + + async def request(self, method, url, withhold_token=False, auth=USE_CLIENT_DEFAULT, **kwargs): + if not withhold_token and auth is USE_CLIENT_DEFAULT: + if not self.token: + raise MissingTokenError() + + await self.ensure_active_token(self.token) + + auth = self.token_auth + + return await super().request( + method, url, auth=auth, **kwargs) + + @asynccontextmanager + async def stream(self, method, url, withhold_token=False, auth=USE_CLIENT_DEFAULT, **kwargs): + if not withhold_token and auth is USE_CLIENT_DEFAULT: + if not self.token: + raise MissingTokenError() + + await self.ensure_active_token(self.token) + + auth = self.token_auth + + async with super().stream( + method, url, auth=auth, **kwargs) as resp: + yield resp + + async def ensure_active_token(self, token): + async with self._token_refresh_lock: + if self.token.is_expired(): + refresh_token = token.get('refresh_token') + url = self.metadata.get('token_endpoint') + if refresh_token and url: + await self.refresh_token(url, refresh_token=refresh_token) + elif self.metadata.get('grant_type') == 'client_credentials': + access_token = token['access_token'] + new_token = await self.fetch_token(url, grant_type='client_credentials') + if self.update_token: + await self.update_token(new_token, access_token=access_token) + else: + raise InvalidTokenError() + + async def _fetch_token(self, url, body='', headers=None, auth=USE_CLIENT_DEFAULT, + method='POST', **kwargs): + if method.upper() == 'POST': + resp = await self.post( + url, data=dict(url_decode(body)), headers=headers, + auth=auth, **kwargs) + else: + if '?' in url: + url = '&'.join([url, body]) + else: + url = '?'.join([url, body]) + resp = await self.get(url, headers=headers, auth=auth, **kwargs) + + for hook in self.compliance_hook['access_token_response']: + resp = hook(resp) + + return self.parse_response_token(resp) + + async def _refresh_token(self, url, refresh_token=None, body='', + headers=None, auth=USE_CLIENT_DEFAULT, **kwargs): + resp = await self.post( + url, data=dict(url_decode(body)), headers=headers, + auth=auth, **kwargs) + + for hook in self.compliance_hook['refresh_token_response']: + resp = hook(resp) + + token = self.parse_response_token(resp) + if 'refresh_token' not in token: + self.token['refresh_token'] = refresh_token + + if self.update_token: + await self.update_token(self.token, refresh_token=refresh_token) + + return self.token + + def _http_post(self, url, body=None, auth=USE_CLIENT_DEFAULT, headers=None, **kwargs): + return self.post( + url, data=dict(url_decode(body)), + headers=headers, auth=auth, **kwargs) + + +class OAuth2Client(_OAuth2Client, httpx.Client): + SESSION_REQUEST_PARAMS = HTTPX_CLIENT_KWARGS + + client_auth_class = OAuth2ClientAuth + token_auth_class = OAuth2Auth + oauth_error_class = OAuthError + + def __init__(self, client_id=None, client_secret=None, + token_endpoint_auth_method=None, + revocation_endpoint_auth_method=None, + scope=None, redirect_uri=None, + token=None, token_placement='header', + update_token=None, **kwargs): + + # extract httpx.Client kwargs + client_kwargs = self._extract_session_request_params(kwargs) + httpx.Client.__init__(self, **client_kwargs) + + _OAuth2Client.__init__( + self, session=self, + client_id=client_id, client_secret=client_secret, + token_endpoint_auth_method=token_endpoint_auth_method, + revocation_endpoint_auth_method=revocation_endpoint_auth_method, + scope=scope, redirect_uri=redirect_uri, + token=token, token_placement=token_placement, + update_token=update_token, **kwargs + ) + + @staticmethod + def handle_error(error_type, error_description): + raise OAuthError(error_type, error_description) + + def request(self, method, url, withhold_token=False, auth=USE_CLIENT_DEFAULT, **kwargs): + if not withhold_token and auth is USE_CLIENT_DEFAULT: + if not self.token: + raise MissingTokenError() + + if not self.ensure_active_token(self.token): + raise InvalidTokenError() + + auth = self.token_auth + + return super().request( + method, url, auth=auth, **kwargs) + + def stream(self, method, url, withhold_token=False, auth=USE_CLIENT_DEFAULT, **kwargs): + if not withhold_token and auth is USE_CLIENT_DEFAULT: + if not self.token: + raise MissingTokenError() + + if not self.ensure_active_token(self.token): + raise InvalidTokenError() + + auth = self.token_auth + + return super().stream( + method, url, auth=auth, **kwargs) diff --git a/.venv/Lib/site-packages/authlib/integrations/httpx_client/utils.py b/.venv/Lib/site-packages/authlib/integrations/httpx_client/utils.py new file mode 100644 index 00000000..8f19f37b --- /dev/null +++ b/.venv/Lib/site-packages/authlib/integrations/httpx_client/utils.py @@ -0,0 +1,30 @@ +from httpx import Request + +HTTPX_CLIENT_KWARGS = [ + 'headers', 'cookies', 'verify', 'cert', 'http1', 'http2', + 'proxies', 'timeout', 'follow_redirects', 'limits', 'max_redirects', + 'event_hooks', 'base_url', 'transport', 'app', 'trust_env', +] + + +def extract_client_kwargs(kwargs): + client_kwargs = {} + for k in HTTPX_CLIENT_KWARGS: + if k in kwargs: + client_kwargs[k] = kwargs.pop(k) + return client_kwargs + + +def build_request(url, headers, body, initial_request: Request) -> Request: + """Make sure that all the data from initial request is passed to the updated object""" + updated_request = Request( + method=initial_request.method, + url=url, + headers=headers, + content=body + ) + + if hasattr(initial_request, 'extensions'): + updated_request.extensions = initial_request.extensions + + return updated_request diff --git a/.venv/Lib/site-packages/authlib/integrations/requests_client/__init__.py b/.venv/Lib/site-packages/authlib/integrations/requests_client/__init__.py new file mode 100644 index 00000000..fcbdec32 --- /dev/null +++ b/.venv/Lib/site-packages/authlib/integrations/requests_client/__init__.py @@ -0,0 +1,22 @@ +from .oauth1_session import OAuth1Session, OAuth1Auth +from .oauth2_session import OAuth2Session, OAuth2Auth +from .assertion_session import AssertionSession +from ..base_client import OAuthError +from authlib.oauth1 import ( + SIGNATURE_HMAC_SHA1, + SIGNATURE_RSA_SHA1, + SIGNATURE_PLAINTEXT, + SIGNATURE_TYPE_HEADER, + SIGNATURE_TYPE_QUERY, + SIGNATURE_TYPE_BODY, +) + + +__all__ = [ + 'OAuthError', + 'OAuth1Session', 'OAuth1Auth', + 'SIGNATURE_HMAC_SHA1', 'SIGNATURE_RSA_SHA1', 'SIGNATURE_PLAINTEXT', + 'SIGNATURE_TYPE_HEADER', 'SIGNATURE_TYPE_QUERY', 'SIGNATURE_TYPE_BODY', + 'OAuth2Session', 'OAuth2Auth', + 'AssertionSession', +] diff --git a/.venv/Lib/site-packages/authlib/integrations/requests_client/__pycache__/__init__.cpython-311.pyc b/.venv/Lib/site-packages/authlib/integrations/requests_client/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 00000000..7d93605f Binary files /dev/null and b/.venv/Lib/site-packages/authlib/integrations/requests_client/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/authlib/integrations/requests_client/__pycache__/assertion_session.cpython-311.pyc b/.venv/Lib/site-packages/authlib/integrations/requests_client/__pycache__/assertion_session.cpython-311.pyc new file mode 100644 index 00000000..caaf0174 Binary files /dev/null and b/.venv/Lib/site-packages/authlib/integrations/requests_client/__pycache__/assertion_session.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/authlib/integrations/requests_client/__pycache__/oauth1_session.cpython-311.pyc b/.venv/Lib/site-packages/authlib/integrations/requests_client/__pycache__/oauth1_session.cpython-311.pyc new file mode 100644 index 00000000..9812dfc6 Binary files /dev/null and b/.venv/Lib/site-packages/authlib/integrations/requests_client/__pycache__/oauth1_session.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/authlib/integrations/requests_client/__pycache__/oauth2_session.cpython-311.pyc b/.venv/Lib/site-packages/authlib/integrations/requests_client/__pycache__/oauth2_session.cpython-311.pyc new file mode 100644 index 00000000..29d94c8d Binary files /dev/null and b/.venv/Lib/site-packages/authlib/integrations/requests_client/__pycache__/oauth2_session.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/authlib/integrations/requests_client/__pycache__/utils.cpython-311.pyc b/.venv/Lib/site-packages/authlib/integrations/requests_client/__pycache__/utils.cpython-311.pyc new file mode 100644 index 00000000..16bd926e Binary files /dev/null and b/.venv/Lib/site-packages/authlib/integrations/requests_client/__pycache__/utils.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/authlib/integrations/requests_client/assertion_session.py b/.venv/Lib/site-packages/authlib/integrations/requests_client/assertion_session.py new file mode 100644 index 00000000..d07c0016 --- /dev/null +++ b/.venv/Lib/site-packages/authlib/integrations/requests_client/assertion_session.py @@ -0,0 +1,46 @@ +from requests import Session +from authlib.oauth2.rfc7521 import AssertionClient +from authlib.oauth2.rfc7523 import JWTBearerGrant +from .oauth2_session import OAuth2Auth +from .utils import update_session_configure + + +class AssertionAuth(OAuth2Auth): + def ensure_active_token(self): + if not self.token or self.token.is_expired() and self.client: + return self.client.refresh_token() + + +class AssertionSession(AssertionClient, Session): + """Constructs a new Assertion Framework for OAuth 2.0 Authorization Grants + per RFC7521_. + + .. _RFC7521: https://tools.ietf.org/html/rfc7521 + """ + token_auth_class = AssertionAuth + JWT_BEARER_GRANT_TYPE = JWTBearerGrant.GRANT_TYPE + ASSERTION_METHODS = { + JWT_BEARER_GRANT_TYPE: JWTBearerGrant.sign, + } + DEFAULT_GRANT_TYPE = JWT_BEARER_GRANT_TYPE + + def __init__(self, token_endpoint, issuer, subject, audience=None, grant_type=None, + claims=None, token_placement='header', scope=None, default_timeout=None, **kwargs): + Session.__init__(self) + self.default_timeout = default_timeout + update_session_configure(self, kwargs) + AssertionClient.__init__( + self, session=self, + token_endpoint=token_endpoint, issuer=issuer, subject=subject, + audience=audience, grant_type=grant_type, claims=claims, + token_placement=token_placement, scope=scope, **kwargs + ) + + def request(self, method, url, withhold_token=False, auth=None, **kwargs): + """Send request with auto refresh token feature.""" + if self.default_timeout: + kwargs.setdefault('timeout', self.default_timeout) + if not withhold_token and auth is None: + auth = self.token_auth + return super().request( + method, url, auth=auth, **kwargs) diff --git a/.venv/Lib/site-packages/authlib/integrations/requests_client/oauth1_session.py b/.venv/Lib/site-packages/authlib/integrations/requests_client/oauth1_session.py new file mode 100644 index 00000000..8c49fa98 --- /dev/null +++ b/.venv/Lib/site-packages/authlib/integrations/requests_client/oauth1_session.py @@ -0,0 +1,59 @@ +from requests import Session +from requests.auth import AuthBase +from authlib.oauth1 import ( + SIGNATURE_HMAC_SHA1, + SIGNATURE_TYPE_HEADER, +) +from authlib.common.encoding import to_native +from authlib.oauth1 import ClientAuth +from authlib.oauth1.client import OAuth1Client +from ..base_client import OAuthError +from .utils import update_session_configure + + +class OAuth1Auth(AuthBase, ClientAuth): + """Signs the request using OAuth 1 (RFC5849)""" + + def __call__(self, req): + url, headers, body = self.prepare( + req.method, req.url, req.headers, req.body) + + req.url = to_native(url) + req.prepare_headers(headers) + if body: + req.body = body + return req + + +class OAuth1Session(OAuth1Client, Session): + auth_class = OAuth1Auth + + def __init__(self, client_id, client_secret=None, + token=None, token_secret=None, + redirect_uri=None, rsa_key=None, verifier=None, + signature_method=SIGNATURE_HMAC_SHA1, + signature_type=SIGNATURE_TYPE_HEADER, + force_include_body=False, **kwargs): + Session.__init__(self) + update_session_configure(self, kwargs) + OAuth1Client.__init__( + self, session=self, + client_id=client_id, client_secret=client_secret, + token=token, token_secret=token_secret, + redirect_uri=redirect_uri, rsa_key=rsa_key, verifier=verifier, + signature_method=signature_method, signature_type=signature_type, + force_include_body=force_include_body, **kwargs) + + def rebuild_auth(self, prepared_request, response): + """When being redirected we should always strip Authorization + header, since nonce may not be reused as per OAuth spec. + """ + if 'Authorization' in prepared_request.headers: + # If we get redirected to a new host, we should strip out + # any authentication headers. + prepared_request.headers.pop('Authorization', True) + prepared_request.prepare_auth(self.auth) + + @staticmethod + def handle_error(error_type, error_description): + raise OAuthError(error_type, error_description) diff --git a/.venv/Lib/site-packages/authlib/integrations/requests_client/oauth2_session.py b/.venv/Lib/site-packages/authlib/integrations/requests_client/oauth2_session.py new file mode 100644 index 00000000..9e2426a2 --- /dev/null +++ b/.venv/Lib/site-packages/authlib/integrations/requests_client/oauth2_session.py @@ -0,0 +1,110 @@ +from requests import Session +from requests.auth import AuthBase +from authlib.oauth2.client import OAuth2Client +from authlib.oauth2.auth import ClientAuth, TokenAuth +from ..base_client import ( + OAuthError, + InvalidTokenError, + MissingTokenError, + UnsupportedTokenTypeError, +) +from .utils import update_session_configure + +__all__ = ['OAuth2Session', 'OAuth2Auth'] + + +class OAuth2Auth(AuthBase, TokenAuth): + """Sign requests for OAuth 2.0, currently only bearer token is supported.""" + + def ensure_active_token(self): + if self.client and not self.client.ensure_active_token(self.token): + raise InvalidTokenError() + + def __call__(self, req): + self.ensure_active_token() + try: + req.url, req.headers, req.body = self.prepare( + req.url, req.headers, req.body) + except KeyError as error: + description = f'Unsupported token_type: {str(error)}' + raise UnsupportedTokenTypeError(description=description) + return req + + +class OAuth2ClientAuth(AuthBase, ClientAuth): + """Attaches OAuth Client Authentication to the given Request object. + """ + def __call__(self, req): + req.url, req.headers, req.body = self.prepare( + req.method, req.url, req.headers, req.body + ) + return req + + +class OAuth2Session(OAuth2Client, Session): + """Construct a new OAuth 2 client requests session. + + :param client_id: Client ID, which you get from client registration. + :param client_secret: Client Secret, which you get from registration. + :param authorization_endpoint: URL of the authorization server's + authorization endpoint. + :param token_endpoint: URL of the authorization server's token endpoint. + :param token_endpoint_auth_method: client authentication method for + token endpoint. + :param revocation_endpoint: URL of the authorization server's OAuth 2.0 + revocation endpoint. + :param revocation_endpoint_auth_method: client authentication method for + revocation endpoint. + :param scope: Scope that you needed to access user resources. + :param state: Shared secret to prevent CSRF attack. + :param redirect_uri: Redirect URI you registered as callback. + :param token: A dict of token attributes such as ``access_token``, + ``token_type`` and ``expires_at``. + :param token_placement: The place to put token in HTTP request. Available + values: "header", "body", "uri". + :param update_token: A function for you to update token. It accept a + :class:`OAuth2Token` as parameter. + :param default_timeout: If settled, every requests will have a default timeout. + """ + client_auth_class = OAuth2ClientAuth + token_auth_class = OAuth2Auth + oauth_error_class = OAuthError + SESSION_REQUEST_PARAMS = ( + 'allow_redirects', 'timeout', 'cookies', 'files', + 'proxies', 'hooks', 'stream', 'verify', 'cert', 'json' + ) + + def __init__(self, client_id=None, client_secret=None, + token_endpoint_auth_method=None, + revocation_endpoint_auth_method=None, + scope=None, state=None, redirect_uri=None, + token=None, token_placement='header', + update_token=None, default_timeout=None, **kwargs): + Session.__init__(self) + self.default_timeout = default_timeout + update_session_configure(self, kwargs) + + OAuth2Client.__init__( + self, session=self, + client_id=client_id, client_secret=client_secret, + token_endpoint_auth_method=token_endpoint_auth_method, + revocation_endpoint_auth_method=revocation_endpoint_auth_method, + scope=scope, state=state, redirect_uri=redirect_uri, + token=token, token_placement=token_placement, + update_token=update_token, **kwargs + ) + + def fetch_access_token(self, url=None, **kwargs): + """Alias for fetch_token.""" + return self.fetch_token(url, **kwargs) + + def request(self, method, url, withhold_token=False, auth=None, **kwargs): + """Send request with auto refresh token feature (if available).""" + if self.default_timeout: + kwargs.setdefault('timeout', self.default_timeout) + if not withhold_token and auth is None: + if not self.token: + raise MissingTokenError() + auth = self.token_auth + return super().request( + method, url, auth=auth, **kwargs) diff --git a/.venv/Lib/site-packages/authlib/integrations/requests_client/utils.py b/.venv/Lib/site-packages/authlib/integrations/requests_client/utils.py new file mode 100644 index 00000000..53a07db3 --- /dev/null +++ b/.venv/Lib/site-packages/authlib/integrations/requests_client/utils.py @@ -0,0 +1,10 @@ +REQUESTS_SESSION_KWARGS = [ + 'proxies', 'hooks', 'stream', 'verify', 'cert', + 'max_redirects', 'trust_env', +] + + +def update_session_configure(session, kwargs): + for k in REQUESTS_SESSION_KWARGS: + if k in kwargs: + setattr(session, k, kwargs.pop(k)) diff --git a/.venv/Lib/site-packages/authlib/integrations/sqla_oauth2/__init__.py b/.venv/Lib/site-packages/authlib/integrations/sqla_oauth2/__init__.py new file mode 100644 index 00000000..1964aa1a --- /dev/null +++ b/.venv/Lib/site-packages/authlib/integrations/sqla_oauth2/__init__.py @@ -0,0 +1,17 @@ +from .client_mixin import OAuth2ClientMixin +from .tokens_mixins import OAuth2AuthorizationCodeMixin, OAuth2TokenMixin +from .functions import ( + create_query_client_func, + create_save_token_func, + create_query_token_func, + create_revocation_endpoint, + create_bearer_token_validator, +) + + +__all__ = [ + 'OAuth2ClientMixin', 'OAuth2AuthorizationCodeMixin', 'OAuth2TokenMixin', + 'create_query_client_func', 'create_save_token_func', + 'create_query_token_func', 'create_revocation_endpoint', + 'create_bearer_token_validator', +] diff --git a/.venv/Lib/site-packages/authlib/integrations/sqla_oauth2/__pycache__/__init__.cpython-311.pyc b/.venv/Lib/site-packages/authlib/integrations/sqla_oauth2/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 00000000..190f0192 Binary files /dev/null and b/.venv/Lib/site-packages/authlib/integrations/sqla_oauth2/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/authlib/integrations/sqla_oauth2/__pycache__/client_mixin.cpython-311.pyc b/.venv/Lib/site-packages/authlib/integrations/sqla_oauth2/__pycache__/client_mixin.cpython-311.pyc new file mode 100644 index 00000000..17bc311c Binary files /dev/null and b/.venv/Lib/site-packages/authlib/integrations/sqla_oauth2/__pycache__/client_mixin.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/authlib/integrations/sqla_oauth2/__pycache__/functions.cpython-311.pyc b/.venv/Lib/site-packages/authlib/integrations/sqla_oauth2/__pycache__/functions.cpython-311.pyc new file mode 100644 index 00000000..0c720456 Binary files /dev/null and b/.venv/Lib/site-packages/authlib/integrations/sqla_oauth2/__pycache__/functions.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/authlib/integrations/sqla_oauth2/__pycache__/tokens_mixins.cpython-311.pyc b/.venv/Lib/site-packages/authlib/integrations/sqla_oauth2/__pycache__/tokens_mixins.cpython-311.pyc new file mode 100644 index 00000000..56e04af9 Binary files /dev/null and b/.venv/Lib/site-packages/authlib/integrations/sqla_oauth2/__pycache__/tokens_mixins.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/authlib/integrations/sqla_oauth2/client_mixin.py b/.venv/Lib/site-packages/authlib/integrations/sqla_oauth2/client_mixin.py new file mode 100644 index 00000000..28505cda --- /dev/null +++ b/.venv/Lib/site-packages/authlib/integrations/sqla_oauth2/client_mixin.py @@ -0,0 +1,138 @@ +import secrets + +from sqlalchemy import Column, String, Text, Integer +from authlib.common.encoding import json_loads, json_dumps +from authlib.oauth2.rfc6749 import ClientMixin +from authlib.oauth2.rfc6749 import scope_to_list, list_to_scope + + +class OAuth2ClientMixin(ClientMixin): + client_id = Column(String(48), index=True) + client_secret = Column(String(120)) + client_id_issued_at = Column(Integer, nullable=False, default=0) + client_secret_expires_at = Column(Integer, nullable=False, default=0) + _client_metadata = Column('client_metadata', Text) + + @property + def client_info(self): + """Implementation for Client Info in OAuth 2.0 Dynamic Client + Registration Protocol via `Section 3.2.1`_. + + .. _`Section 3.2.1`: https://tools.ietf.org/html/rfc7591#section-3.2.1 + """ + return dict( + client_id=self.client_id, + client_secret=self.client_secret, + client_id_issued_at=self.client_id_issued_at, + client_secret_expires_at=self.client_secret_expires_at, + ) + + @property + def client_metadata(self): + if 'client_metadata' in self.__dict__: + return self.__dict__['client_metadata'] + if self._client_metadata: + data = json_loads(self._client_metadata) + self.__dict__['client_metadata'] = data + return data + return {} + + def set_client_metadata(self, value): + self._client_metadata = json_dumps(value) + if 'client_metadata' in self.__dict__: + del self.__dict__['client_metadata'] + + @property + def redirect_uris(self): + return self.client_metadata.get('redirect_uris', []) + + @property + def token_endpoint_auth_method(self): + return self.client_metadata.get( + 'token_endpoint_auth_method', + 'client_secret_basic' + ) + + @property + def grant_types(self): + return self.client_metadata.get('grant_types', []) + + @property + def response_types(self): + return self.client_metadata.get('response_types', []) + + @property + def client_name(self): + return self.client_metadata.get('client_name') + + @property + def client_uri(self): + return self.client_metadata.get('client_uri') + + @property + def logo_uri(self): + return self.client_metadata.get('logo_uri') + + @property + def scope(self): + return self.client_metadata.get('scope', '') + + @property + def contacts(self): + return self.client_metadata.get('contacts', []) + + @property + def tos_uri(self): + return self.client_metadata.get('tos_uri') + + @property + def policy_uri(self): + return self.client_metadata.get('policy_uri') + + @property + def jwks_uri(self): + return self.client_metadata.get('jwks_uri') + + @property + def jwks(self): + return self.client_metadata.get('jwks', []) + + @property + def software_id(self): + return self.client_metadata.get('software_id') + + @property + def software_version(self): + return self.client_metadata.get('software_version') + + def get_client_id(self): + return self.client_id + + def get_default_redirect_uri(self): + if self.redirect_uris: + return self.redirect_uris[0] + + def get_allowed_scope(self, scope): + if not scope: + return '' + allowed = set(self.scope.split()) + scopes = scope_to_list(scope) + return list_to_scope([s for s in scopes if s in allowed]) + + def check_redirect_uri(self, redirect_uri): + return redirect_uri in self.redirect_uris + + def check_client_secret(self, client_secret): + return secrets.compare_digest(self.client_secret, client_secret) + + def check_endpoint_auth_method(self, method, endpoint): + if endpoint == 'token': + return self.token_endpoint_auth_method == method + # TODO + return True + + def check_response_type(self, response_type): + return response_type in self.response_types + + def check_grant_type(self, grant_type): + return grant_type in self.grant_types diff --git a/.venv/Lib/site-packages/authlib/integrations/sqla_oauth2/functions.py b/.venv/Lib/site-packages/authlib/integrations/sqla_oauth2/functions.py new file mode 100644 index 00000000..74f10712 --- /dev/null +++ b/.venv/Lib/site-packages/authlib/integrations/sqla_oauth2/functions.py @@ -0,0 +1,101 @@ +import time + + +def create_query_client_func(session, client_model): + """Create an ``query_client`` function that can be used in authorization + server. + + :param session: SQLAlchemy session + :param client_model: Client model class + """ + def query_client(client_id): + q = session.query(client_model) + return q.filter_by(client_id=client_id).first() + return query_client + + +def create_save_token_func(session, token_model): + """Create an ``save_token`` function that can be used in authorization + server. + + :param session: SQLAlchemy session + :param token_model: Token model class + """ + def save_token(token, request): + if request.user: + user_id = request.user.get_user_id() + else: + user_id = None + client = request.client + item = token_model( + client_id=client.client_id, + user_id=user_id, + **token + ) + session.add(item) + session.commit() + return save_token + + +def create_query_token_func(session, token_model): + """Create an ``query_token`` function for revocation, introspection + token endpoints. + + :param session: SQLAlchemy session + :param token_model: Token model class + """ + def query_token(token, token_type_hint): + q = session.query(token_model) + if token_type_hint == 'access_token': + return q.filter_by(access_token=token).first() + elif token_type_hint == 'refresh_token': + return q.filter_by(refresh_token=token).first() + # without token_type_hint + item = q.filter_by(access_token=token).first() + if item: + return item + return q.filter_by(refresh_token=token).first() + return query_token + + +def create_revocation_endpoint(session, token_model): + """Create a revocation endpoint class with SQLAlchemy session + and token model. + + :param session: SQLAlchemy session + :param token_model: Token model class + """ + from authlib.oauth2.rfc7009 import RevocationEndpoint + query_token = create_query_token_func(session, token_model) + + class _RevocationEndpoint(RevocationEndpoint): + def query_token(self, token, token_type_hint): + return query_token(token, token_type_hint) + + def revoke_token(self, token, request): + now = int(time.time()) + hint = request.form.get('token_type_hint') + token.access_token_revoked_at = now + if hint != 'access_token': + token.refresh_token_revoked_at = now + session.add(token) + session.commit() + + return _RevocationEndpoint + + +def create_bearer_token_validator(session, token_model): + """Create an bearer token validator class with SQLAlchemy session + and token model. + + :param session: SQLAlchemy session + :param token_model: Token model class + """ + from authlib.oauth2.rfc6750 import BearerTokenValidator + + class _BearerTokenValidator(BearerTokenValidator): + def authenticate_token(self, token_string): + q = session.query(token_model) + return q.filter_by(access_token=token_string).first() + + return _BearerTokenValidator diff --git a/.venv/Lib/site-packages/authlib/integrations/sqla_oauth2/tokens_mixins.py b/.venv/Lib/site-packages/authlib/integrations/sqla_oauth2/tokens_mixins.py new file mode 100644 index 00000000..28cee892 --- /dev/null +++ b/.venv/Lib/site-packages/authlib/integrations/sqla_oauth2/tokens_mixins.py @@ -0,0 +1,70 @@ +import time +from sqlalchemy import Column, String, Text, Integer +from authlib.oauth2.rfc6749 import ( + TokenMixin, + AuthorizationCodeMixin, +) + + +class OAuth2AuthorizationCodeMixin(AuthorizationCodeMixin): + code = Column(String(120), unique=True, nullable=False) + client_id = Column(String(48)) + redirect_uri = Column(Text, default='') + response_type = Column(Text, default='') + scope = Column(Text, default='') + nonce = Column(Text) + auth_time = Column( + Integer, nullable=False, + default=lambda: int(time.time()) + ) + + code_challenge = Column(Text) + code_challenge_method = Column(String(48)) + + def is_expired(self): + return self.auth_time + 300 < time.time() + + def get_redirect_uri(self): + return self.redirect_uri + + def get_scope(self): + return self.scope + + def get_auth_time(self): + return self.auth_time + + def get_nonce(self): + return self.nonce + + +class OAuth2TokenMixin(TokenMixin): + client_id = Column(String(48)) + token_type = Column(String(40)) + access_token = Column(String(255), unique=True, nullable=False) + refresh_token = Column(String(255), index=True) + scope = Column(Text, default='') + issued_at = Column( + Integer, nullable=False, default=lambda: int(time.time()) + ) + access_token_revoked_at = Column(Integer, nullable=False, default=0) + refresh_token_revoked_at = Column(Integer, nullable=False, default=0) + expires_in = Column(Integer, nullable=False, default=0) + + def check_client(self, client): + return self.client_id == client.get_client_id() + + def get_scope(self): + return self.scope + + def get_expires_in(self): + return self.expires_in + + def is_revoked(self): + return self.access_token_revoked_at or self.refresh_token_revoked_at + + def is_expired(self): + if not self.expires_in: + return False + + expires_at = self.issued_at + self.expires_in + return expires_at < time.time() diff --git a/.venv/Lib/site-packages/authlib/integrations/starlette_client/__init__.py b/.venv/Lib/site-packages/authlib/integrations/starlette_client/__init__.py new file mode 100644 index 00000000..7546c547 --- /dev/null +++ b/.venv/Lib/site-packages/authlib/integrations/starlette_client/__init__.py @@ -0,0 +1,22 @@ +# flake8: noqa + +from ..base_client import BaseOAuth, OAuthError +from .integration import StarletteIntegration +from .apps import StarletteOAuth1App, StarletteOAuth2App + + +class OAuth(BaseOAuth): + oauth1_client_cls = StarletteOAuth1App + oauth2_client_cls = StarletteOAuth2App + framework_integration_cls = StarletteIntegration + + def __init__(self, config=None, cache=None, fetch_token=None, update_token=None): + super().__init__( + cache=cache, fetch_token=fetch_token, update_token=update_token) + self.config = config + + +__all__ = [ + 'OAuth', 'OAuthError', + 'StarletteIntegration', 'StarletteOAuth1App', 'StarletteOAuth2App', +] diff --git a/.venv/Lib/site-packages/authlib/integrations/starlette_client/__pycache__/__init__.cpython-311.pyc b/.venv/Lib/site-packages/authlib/integrations/starlette_client/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 00000000..9f445b74 Binary files /dev/null and b/.venv/Lib/site-packages/authlib/integrations/starlette_client/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/authlib/integrations/starlette_client/__pycache__/apps.cpython-311.pyc b/.venv/Lib/site-packages/authlib/integrations/starlette_client/__pycache__/apps.cpython-311.pyc new file mode 100644 index 00000000..24e4407f Binary files /dev/null and b/.venv/Lib/site-packages/authlib/integrations/starlette_client/__pycache__/apps.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/authlib/integrations/starlette_client/__pycache__/integration.cpython-311.pyc b/.venv/Lib/site-packages/authlib/integrations/starlette_client/__pycache__/integration.cpython-311.pyc new file mode 100644 index 00000000..65cdf540 Binary files /dev/null and b/.venv/Lib/site-packages/authlib/integrations/starlette_client/__pycache__/integration.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/authlib/integrations/starlette_client/apps.py b/.venv/Lib/site-packages/authlib/integrations/starlette_client/apps.py new file mode 100644 index 00000000..114cbaff --- /dev/null +++ b/.venv/Lib/site-packages/authlib/integrations/starlette_client/apps.py @@ -0,0 +1,86 @@ +from starlette.datastructures import URL +from starlette.responses import RedirectResponse +from ..base_client import OAuthError +from ..base_client import BaseApp +from ..base_client.async_app import AsyncOAuth1Mixin, AsyncOAuth2Mixin +from ..base_client.async_openid import AsyncOpenIDMixin +from ..httpx_client import AsyncOAuth1Client, AsyncOAuth2Client + + +class StarletteAppMixin: + async def save_authorize_data(self, request, **kwargs): + state = kwargs.pop('state', None) + if state: + if self.framework.cache: + session = None + else: + session = request.session + await self.framework.set_state_data(session, state, kwargs) + else: + raise RuntimeError('Missing state value') + + async def authorize_redirect(self, request, redirect_uri=None, **kwargs): + """Create a HTTP Redirect for Authorization Endpoint. + + :param request: HTTP request instance from Starlette view. + :param redirect_uri: Callback or redirect URI for authorization. + :param kwargs: Extra parameters to include. + :return: A HTTP redirect response. + """ + + # Handle Starlette >= 0.26.0 where redirect_uri may now be a URL and not a string + if redirect_uri and isinstance(redirect_uri, URL): + redirect_uri = str(redirect_uri) + rv = await self.create_authorization_url(redirect_uri, **kwargs) + await self.save_authorize_data(request, redirect_uri=redirect_uri, **rv) + return RedirectResponse(rv['url'], status_code=302) + + +class StarletteOAuth1App(StarletteAppMixin, AsyncOAuth1Mixin, BaseApp): + client_cls = AsyncOAuth1Client + + async def authorize_access_token(self, request, **kwargs): + params = dict(request.query_params) + state = params.get('oauth_token') + if not state: + raise OAuthError(description='Missing "oauth_token" parameter') + + data = await self.framework.get_state_data(request.session, state) + if not data: + raise OAuthError(description='Missing "request_token" in temporary data') + + params['request_token'] = data['request_token'] + params.update(kwargs) + await self.framework.clear_state_data(request.session, state) + return await self.fetch_access_token(**params) + + +class StarletteOAuth2App(StarletteAppMixin, AsyncOAuth2Mixin, AsyncOpenIDMixin, BaseApp): + client_cls = AsyncOAuth2Client + + async def authorize_access_token(self, request, **kwargs): + error = request.query_params.get('error') + if error: + description = request.query_params.get('error_description') + raise OAuthError(error=error, description=description) + + params = { + 'code': request.query_params.get('code'), + 'state': request.query_params.get('state'), + } + + if self.framework.cache: + session = None + else: + session = request.session + + claims_options = kwargs.pop('claims_options', None) + state_data = await self.framework.get_state_data(session, params.get('state')) + await self.framework.clear_state_data(session, params.get('state')) + params = self._format_state_params(state_data, params) + token = await self.fetch_access_token(**params, **kwargs) + + if 'id_token' in token and 'nonce' in state_data: + userinfo = await self.parse_id_token(token, nonce=state_data['nonce'], claims_options=claims_options) + token['userinfo'] = userinfo + return token diff --git a/.venv/Lib/site-packages/authlib/integrations/starlette_client/integration.py b/.venv/Lib/site-packages/authlib/integrations/starlette_client/integration.py new file mode 100644 index 00000000..04ffd786 --- /dev/null +++ b/.venv/Lib/site-packages/authlib/integrations/starlette_client/integration.py @@ -0,0 +1,66 @@ +import json +import time +from typing import ( + Any, + Dict, + Hashable, + Optional, +) + +from ..base_client import FrameworkIntegration + + +class StarletteIntegration(FrameworkIntegration): + async def _get_cache_data(self, key: Hashable): + value = await self.cache.get(key) + if not value: + return None + try: + return json.loads(value) + except (TypeError, ValueError): + return None + + async def get_state_data(self, session: Optional[Dict[str, Any]], state: str) -> Dict[str, Any]: + key = f'_state_{self.name}_{state}' + if self.cache: + value = await self._get_cache_data(key) + elif session is not None: + value = session.get(key) + else: + value = None + + if value: + return value.get('data') + return None + + async def set_state_data(self, session: Optional[Dict[str, Any]], state: str, data: Any): + key = f'_state_{self.name}_{state}' + if self.cache: + await self.cache.set(key, json.dumps({'data': data}), self.expires_in) + elif session is not None: + now = time.time() + session[key] = {'data': data, 'exp': now + self.expires_in} + + async def clear_state_data(self, session: Optional[Dict[str, Any]], state: str): + key = f'_state_{self.name}_{state}' + if self.cache: + await self.cache.delete(key) + elif session is not None: + session.pop(key, None) + self._clear_session_state(session) + + def update_token(self, token, refresh_token=None, access_token=None): + pass + + @staticmethod + def load_config(oauth, name, params): + if not oauth.config: + return {} + + rv = {} + for k in params: + conf_key = f'{name}_{k}'.upper() + v = oauth.config.get(conf_key, default=None) + if v is not None: + rv[k] = v + return rv diff --git a/.venv/Lib/site-packages/authlib/jose/__init__.py b/.venv/Lib/site-packages/authlib/jose/__init__.py new file mode 100644 index 00000000..2d6638a0 --- /dev/null +++ b/.venv/Lib/site-packages/authlib/jose/__init__.py @@ -0,0 +1,60 @@ +""" + authlib.jose + ~~~~~~~~~~~~ + + JOSE implementation in Authlib. Tracking the status of JOSE specs at + https://tools.ietf.org/wg/jose/ +""" +from .rfc7515 import ( + JsonWebSignature, JWSAlgorithm, JWSHeader, JWSObject, +) +from .rfc7516 import ( + JsonWebEncryption, JWEAlgorithm, JWEEncAlgorithm, JWEZipAlgorithm, +) +from .rfc7517 import Key, KeySet, JsonWebKey +from .rfc7518 import ( + register_jws_rfc7518, + register_jwe_rfc7518, + ECDHESAlgorithm, + OctKey, + RSAKey, + ECKey, +) +from .rfc7519 import JsonWebToken, BaseClaims, JWTClaims +from .rfc8037 import OKPKey, register_jws_rfc8037 + +from .errors import JoseError + +# register algorithms +register_jws_rfc7518(JsonWebSignature) +register_jws_rfc8037(JsonWebSignature) + +register_jwe_rfc7518(JsonWebEncryption) + +# attach algorithms +ECDHESAlgorithm.ALLOWED_KEY_CLS = (ECKey, OKPKey) + +# register supported keys +JsonWebKey.JWK_KEY_CLS = { + OctKey.kty: OctKey, + RSAKey.kty: RSAKey, + ECKey.kty: ECKey, + OKPKey.kty: OKPKey, +} + +jwt = JsonWebToken(list(JsonWebSignature.ALGORITHMS_REGISTRY.keys())) + + +__all__ = [ + 'JoseError', + + 'JsonWebSignature', 'JWSAlgorithm', 'JWSHeader', 'JWSObject', + 'JsonWebEncryption', 'JWEAlgorithm', 'JWEEncAlgorithm', 'JWEZipAlgorithm', + + 'JsonWebKey', 'Key', 'KeySet', + + 'OctKey', 'RSAKey', 'ECKey', 'OKPKey', + + 'JsonWebToken', 'BaseClaims', 'JWTClaims', + 'jwt', +] diff --git a/.venv/Lib/site-packages/authlib/jose/__pycache__/__init__.cpython-311.pyc b/.venv/Lib/site-packages/authlib/jose/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 00000000..4d6b0f4e Binary files /dev/null and b/.venv/Lib/site-packages/authlib/jose/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/authlib/jose/__pycache__/errors.cpython-311.pyc b/.venv/Lib/site-packages/authlib/jose/__pycache__/errors.cpython-311.pyc new file mode 100644 index 00000000..3912694c Binary files /dev/null and b/.venv/Lib/site-packages/authlib/jose/__pycache__/errors.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/authlib/jose/__pycache__/jwk.cpython-311.pyc b/.venv/Lib/site-packages/authlib/jose/__pycache__/jwk.cpython-311.pyc new file mode 100644 index 00000000..969174a0 Binary files /dev/null and b/.venv/Lib/site-packages/authlib/jose/__pycache__/jwk.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/authlib/jose/__pycache__/util.cpython-311.pyc b/.venv/Lib/site-packages/authlib/jose/__pycache__/util.cpython-311.pyc new file mode 100644 index 00000000..767c03b2 Binary files /dev/null and b/.venv/Lib/site-packages/authlib/jose/__pycache__/util.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/authlib/jose/drafts/__init__.py b/.venv/Lib/site-packages/authlib/jose/drafts/__init__.py new file mode 100644 index 00000000..3044585e --- /dev/null +++ b/.venv/Lib/site-packages/authlib/jose/drafts/__init__.py @@ -0,0 +1,17 @@ +from ._jwe_algorithms import JWE_DRAFT_ALG_ALGORITHMS +from ._jwe_enc_cryptography import C20PEncAlgorithm +try: + from ._jwe_enc_cryptodome import XC20PEncAlgorithm +except ImportError: + XC20PEncAlgorithm = None + + +def register_jwe_draft(cls): + for alg in JWE_DRAFT_ALG_ALGORITHMS: + cls.register_algorithm(alg) + + cls.register_algorithm(C20PEncAlgorithm(256)) # C20P + if XC20PEncAlgorithm is not None: + cls.register_algorithm(XC20PEncAlgorithm(256)) # XC20P + +__all__ = ['register_jwe_draft'] diff --git a/.venv/Lib/site-packages/authlib/jose/drafts/__pycache__/__init__.cpython-311.pyc b/.venv/Lib/site-packages/authlib/jose/drafts/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 00000000..7a97f0f1 Binary files /dev/null and b/.venv/Lib/site-packages/authlib/jose/drafts/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/authlib/jose/drafts/__pycache__/_jwe_algorithms.cpython-311.pyc b/.venv/Lib/site-packages/authlib/jose/drafts/__pycache__/_jwe_algorithms.cpython-311.pyc new file mode 100644 index 00000000..ffd86bf1 Binary files /dev/null and b/.venv/Lib/site-packages/authlib/jose/drafts/__pycache__/_jwe_algorithms.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/authlib/jose/drafts/__pycache__/_jwe_enc_cryptodome.cpython-311.pyc b/.venv/Lib/site-packages/authlib/jose/drafts/__pycache__/_jwe_enc_cryptodome.cpython-311.pyc new file mode 100644 index 00000000..62f58e3b Binary files /dev/null and b/.venv/Lib/site-packages/authlib/jose/drafts/__pycache__/_jwe_enc_cryptodome.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/authlib/jose/drafts/__pycache__/_jwe_enc_cryptography.cpython-311.pyc b/.venv/Lib/site-packages/authlib/jose/drafts/__pycache__/_jwe_enc_cryptography.cpython-311.pyc new file mode 100644 index 00000000..7401bf0a Binary files /dev/null and b/.venv/Lib/site-packages/authlib/jose/drafts/__pycache__/_jwe_enc_cryptography.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/authlib/jose/drafts/_jwe_algorithms.py b/.venv/Lib/site-packages/authlib/jose/drafts/_jwe_algorithms.py new file mode 100644 index 00000000..c01b7e7d --- /dev/null +++ b/.venv/Lib/site-packages/authlib/jose/drafts/_jwe_algorithms.py @@ -0,0 +1,188 @@ +import struct +from cryptography.hazmat.backends import default_backend +from cryptography.hazmat.primitives import hashes +from cryptography.hazmat.primitives.kdf.concatkdf import ConcatKDFHash + +from authlib.jose.errors import InvalidEncryptionAlgorithmForECDH1PUWithKeyWrappingError +from authlib.jose.rfc7516 import JWEAlgorithmWithTagAwareKeyAgreement +from authlib.jose.rfc7518 import AESAlgorithm, CBCHS2EncAlgorithm, ECKey, u32be_len_input +from authlib.jose.rfc8037 import OKPKey + + +class ECDH1PUAlgorithm(JWEAlgorithmWithTagAwareKeyAgreement): + EXTRA_HEADERS = ['epk', 'apu', 'apv', 'skid'] + ALLOWED_KEY_CLS = (ECKey, OKPKey) + + # https://datatracker.ietf.org/doc/html/draft-madden-jose-ecdh-1pu-04 + def __init__(self, key_size=None): + if key_size is None: + self.name = 'ECDH-1PU' + self.description = 'ECDH-1PU in the Direct Key Agreement mode' + else: + self.name = f'ECDH-1PU+A{key_size}KW' + self.description = ( + 'ECDH-1PU using Concat KDF and CEK wrapped ' + 'with A{}KW').format(key_size) + self.key_size = key_size + self.aeskw = AESAlgorithm(key_size) + + def prepare_key(self, raw_data): + if isinstance(raw_data, self.ALLOWED_KEY_CLS): + return raw_data + return ECKey.import_key(raw_data) + + def generate_preset(self, enc_alg, key): + epk = self._generate_ephemeral_key(key) + h = self._prepare_headers(epk) + preset = {'epk': epk, 'header': h} + if self.key_size is not None: + cek = enc_alg.generate_cek() + preset['cek'] = cek + return preset + + def compute_shared_key(self, shared_key_e, shared_key_s): + return shared_key_e + shared_key_s + + def compute_fixed_info(self, headers, bit_size, tag): + if tag is None: + cctag = b'' + else: + cctag = u32be_len_input(tag) + + # AlgorithmID + if self.key_size is None: + alg_id = u32be_len_input(headers['enc']) + else: + alg_id = u32be_len_input(headers['alg']) + + # PartyUInfo + apu_info = u32be_len_input(headers.get('apu'), True) + + # PartyVInfo + apv_info = u32be_len_input(headers.get('apv'), True) + + # SuppPubInfo + pub_info = struct.pack('>I', bit_size) + cctag + + return alg_id + apu_info + apv_info + pub_info + + def compute_derived_key(self, shared_key, fixed_info, bit_size): + ckdf = ConcatKDFHash( + algorithm=hashes.SHA256(), + length=bit_size // 8, + otherinfo=fixed_info, + backend=default_backend() + ) + return ckdf.derive(shared_key) + + def deliver_at_sender(self, sender_static_key, sender_ephemeral_key, recipient_pubkey, headers, bit_size, tag): + shared_key_s = sender_static_key.exchange_shared_key(recipient_pubkey) + shared_key_e = sender_ephemeral_key.exchange_shared_key(recipient_pubkey) + shared_key = self.compute_shared_key(shared_key_e, shared_key_s) + + fixed_info = self.compute_fixed_info(headers, bit_size, tag) + + return self.compute_derived_key(shared_key, fixed_info, bit_size) + + def deliver_at_recipient(self, recipient_key, sender_static_pubkey, sender_ephemeral_pubkey, headers, bit_size, tag): + shared_key_s = recipient_key.exchange_shared_key(sender_static_pubkey) + shared_key_e = recipient_key.exchange_shared_key(sender_ephemeral_pubkey) + shared_key = self.compute_shared_key(shared_key_e, shared_key_s) + + fixed_info = self.compute_fixed_info(headers, bit_size, tag) + + return self.compute_derived_key(shared_key, fixed_info, bit_size) + + def _generate_ephemeral_key(self, key): + return key.generate_key(key['crv'], is_private=True) + + def _prepare_headers(self, epk): + # REQUIRED_JSON_FIELDS contains only public fields + pub_epk = {k: epk[k] for k in epk.REQUIRED_JSON_FIELDS} + pub_epk['kty'] = epk.kty + return {'epk': pub_epk} + + def generate_keys_and_prepare_headers(self, enc_alg, key, sender_key, preset=None): + if not isinstance(enc_alg, CBCHS2EncAlgorithm): + raise InvalidEncryptionAlgorithmForECDH1PUWithKeyWrappingError() + + if preset and 'epk' in preset: + epk = preset['epk'] + h = {} + else: + epk = self._generate_ephemeral_key(key) + h = self._prepare_headers(epk) + + if preset and 'cek' in preset: + cek = preset['cek'] + else: + cek = enc_alg.generate_cek() + + return {'epk': epk, 'cek': cek, 'header': h} + + def _agree_upon_key_at_sender(self, enc_alg, headers, key, sender_key, epk, tag=None): + if self.key_size is None: + bit_size = enc_alg.CEK_SIZE + else: + bit_size = self.key_size + + public_key = key.get_op_key('wrapKey') + + return self.deliver_at_sender(sender_key, epk, public_key, headers, bit_size, tag) + + def _wrap_cek(self, cek, dk): + kek = self.aeskw.prepare_key(dk) + return self.aeskw.wrap_cek(cek, kek) + + def agree_upon_key_and_wrap_cek(self, enc_alg, headers, key, sender_key, epk, cek, tag): + dk = self._agree_upon_key_at_sender(enc_alg, headers, key, sender_key, epk, tag) + return self._wrap_cek(cek, dk) + + def wrap(self, enc_alg, headers, key, sender_key, preset=None): + # In this class this method is used in direct key agreement mode only + if self.key_size is not None: + raise RuntimeError('Invalid algorithm state detected') + + if preset and 'epk' in preset: + epk = preset['epk'] + h = {} + else: + epk = self._generate_ephemeral_key(key) + h = self._prepare_headers(epk) + + dk = self._agree_upon_key_at_sender(enc_alg, headers, key, sender_key, epk) + + return {'ek': b'', 'cek': dk, 'header': h} + + def unwrap(self, enc_alg, ek, headers, key, sender_key, tag=None): + if 'epk' not in headers: + raise ValueError('Missing "epk" in headers') + + if self.key_size is None: + bit_size = enc_alg.CEK_SIZE + else: + bit_size = self.key_size + + sender_pubkey = sender_key.get_op_key('wrapKey') + epk = key.import_key(headers['epk']) + epk_pubkey = epk.get_op_key('wrapKey') + dk = self.deliver_at_recipient(key, sender_pubkey, epk_pubkey, headers, bit_size, tag) + + if self.key_size is None: + return dk + + kek = self.aeskw.prepare_key(dk) + return self.aeskw.unwrap(enc_alg, ek, headers, kek) + + +JWE_DRAFT_ALG_ALGORITHMS = [ + ECDH1PUAlgorithm(None), # ECDH-1PU + ECDH1PUAlgorithm(128), # ECDH-1PU+A128KW + ECDH1PUAlgorithm(192), # ECDH-1PU+A192KW + ECDH1PUAlgorithm(256), # ECDH-1PU+A256KW +] + + +def register_jwe_alg_draft(cls): + for alg in JWE_DRAFT_ALG_ALGORITHMS: + cls.register_algorithm(alg) diff --git a/.venv/Lib/site-packages/authlib/jose/drafts/_jwe_enc_cryptodome.py b/.venv/Lib/site-packages/authlib/jose/drafts/_jwe_enc_cryptodome.py new file mode 100644 index 00000000..cb6fceaf --- /dev/null +++ b/.venv/Lib/site-packages/authlib/jose/drafts/_jwe_enc_cryptodome.py @@ -0,0 +1,52 @@ +""" + authlib.jose.draft + ~~~~~~~~~~~~~~~~~~~~ + + Content Encryption per `Section 4`_. + + .. _`Section 4`: https://datatracker.ietf.org/doc/html/draft-amringer-jose-chacha-02#section-4 +""" +from authlib.jose.rfc7516 import JWEEncAlgorithm +from Cryptodome.Cipher import ChaCha20_Poly1305 as Cryptodome_ChaCha20_Poly1305 + + +class XC20PEncAlgorithm(JWEEncAlgorithm): + # Use of an IV of size 192 bits is REQUIRED with this algorithm. + # https://datatracker.ietf.org/doc/html/draft-amringer-jose-chacha-02#section-4.1 + IV_SIZE = 192 + + def __init__(self, key_size): + self.name = 'XC20P' + self.description = 'XChaCha20-Poly1305' + self.key_size = key_size + self.CEK_SIZE = key_size + + def encrypt(self, msg, aad, iv, key): + """Content Encryption with AEAD_XCHACHA20_POLY1305 + + :param msg: text to be encrypt in bytes + :param aad: additional authenticated data in bytes + :param iv: initialization vector in bytes + :param key: encrypted key in bytes + :return: (ciphertext, tag) + """ + self.check_iv(iv) + chacha = Cryptodome_ChaCha20_Poly1305.new(key=key, nonce=iv) + chacha.update(aad) + ciphertext, tag = chacha.encrypt_and_digest(msg) + return ciphertext, tag + + def decrypt(self, ciphertext, aad, iv, tag, key): + """Content Decryption with AEAD_XCHACHA20_POLY1305 + + :param ciphertext: ciphertext in bytes + :param aad: additional authenticated data in bytes + :param iv: initialization vector in bytes + :param tag: authentication tag in bytes + :param key: encrypted key in bytes + :return: message + """ + self.check_iv(iv) + chacha = Cryptodome_ChaCha20_Poly1305.new(key=key, nonce=iv) + chacha.update(aad) + return chacha.decrypt_and_verify(ciphertext, tag) diff --git a/.venv/Lib/site-packages/authlib/jose/drafts/_jwe_enc_cryptography.py b/.venv/Lib/site-packages/authlib/jose/drafts/_jwe_enc_cryptography.py new file mode 100644 index 00000000..1b0c852b --- /dev/null +++ b/.venv/Lib/site-packages/authlib/jose/drafts/_jwe_enc_cryptography.py @@ -0,0 +1,50 @@ +""" + authlib.jose.draft + ~~~~~~~~~~~~~~~~~~~~ + + Content Encryption per `Section 4`_. + + .. _`Section 4`: https://datatracker.ietf.org/doc/html/draft-amringer-jose-chacha-02#section-4 +""" +from cryptography.hazmat.primitives.ciphers.aead import ChaCha20Poly1305 +from authlib.jose.rfc7516 import JWEEncAlgorithm + + +class C20PEncAlgorithm(JWEEncAlgorithm): + # Use of an IV of size 96 bits is REQUIRED with this algorithm. + # https://datatracker.ietf.org/doc/html/draft-amringer-jose-chacha-02#section-4.1 + IV_SIZE = 96 + + def __init__(self, key_size): + self.name = 'C20P' + self.description = 'ChaCha20-Poly1305' + self.key_size = key_size + self.CEK_SIZE = key_size + + def encrypt(self, msg, aad, iv, key): + """Content Encryption with AEAD_CHACHA20_POLY1305 + + :param msg: text to be encrypt in bytes + :param aad: additional authenticated data in bytes + :param iv: initialization vector in bytes + :param key: encrypted key in bytes + :return: (ciphertext, tag) + """ + self.check_iv(iv) + chacha = ChaCha20Poly1305(key) + ciphertext = chacha.encrypt(iv, msg, aad) + return ciphertext[:-16], ciphertext[-16:] + + def decrypt(self, ciphertext, aad, iv, tag, key): + """Content Decryption with AEAD_CHACHA20_POLY1305 + + :param ciphertext: ciphertext in bytes + :param aad: additional authenticated data in bytes + :param iv: initialization vector in bytes + :param tag: authentication tag in bytes + :param key: encrypted key in bytes + :return: message + """ + self.check_iv(iv) + chacha = ChaCha20Poly1305(key) + return chacha.decrypt(iv, ciphertext + tag, aad) diff --git a/.venv/Lib/site-packages/authlib/jose/errors.py b/.venv/Lib/site-packages/authlib/jose/errors.py new file mode 100644 index 00000000..fb02eb4e --- /dev/null +++ b/.venv/Lib/site-packages/authlib/jose/errors.py @@ -0,0 +1,113 @@ +from authlib.common.errors import AuthlibBaseError + + +class JoseError(AuthlibBaseError): + pass + + +class DecodeError(JoseError): + error = 'decode_error' + + +class MissingAlgorithmError(JoseError): + error = 'missing_algorithm' + + +class UnsupportedAlgorithmError(JoseError): + error = 'unsupported_algorithm' + + +class BadSignatureError(JoseError): + error = 'bad_signature' + + def __init__(self, result): + super().__init__() + self.result = result + + +class InvalidHeaderParameterNameError(JoseError): + error = 'invalid_header_parameter_name' + + def __init__(self, name): + description = f'Invalid Header Parameter Name: {name}' + super().__init__( + description=description) + + +class InvalidEncryptionAlgorithmForECDH1PUWithKeyWrappingError(JoseError): + error = 'invalid_encryption_algorithm_for_ECDH_1PU_with_key_wrapping' + + def __init__(self): + description = 'In key agreement with key wrapping mode ECDH-1PU algorithm ' \ + 'only supports AES_CBC_HMAC_SHA2 family encryption algorithms' + super().__init__( + description=description) + + +class InvalidAlgorithmForMultipleRecipientsMode(JoseError): + error = 'invalid_algorithm_for_multiple_recipients_mode' + + def __init__(self, alg): + description = f'{alg} algorithm cannot be used in multiple recipients mode' + super().__init__( + description=description) + + +class KeyMismatchError(JoseError): + error = 'key_mismatch_error' + description = 'Key does not match to any recipient' + + +class MissingEncryptionAlgorithmError(JoseError): + error = 'missing_encryption_algorithm' + description = 'Missing "enc" in header' + + +class UnsupportedEncryptionAlgorithmError(JoseError): + error = 'unsupported_encryption_algorithm' + description = 'Unsupported "enc" value in header' + + +class UnsupportedCompressionAlgorithmError(JoseError): + error = 'unsupported_compression_algorithm' + description = 'Unsupported "zip" value in header' + + +class InvalidUseError(JoseError): + error = 'invalid_use' + description = 'Key "use" is not valid for your usage' + + +class InvalidClaimError(JoseError): + error = 'invalid_claim' + + def __init__(self, claim): + self.claim_name = claim + description = f'Invalid claim "{claim}"' + super().__init__(description=description) + + +class MissingClaimError(JoseError): + error = 'missing_claim' + + def __init__(self, claim): + description = f'Missing "{claim}" claim' + super().__init__(description=description) + + +class InsecureClaimError(JoseError): + error = 'insecure_claim' + + def __init__(self, claim): + description = f'Insecure claim "{claim}"' + super().__init__(description=description) + + +class ExpiredTokenError(JoseError): + error = 'expired_token' + description = 'The token is expired' + + +class InvalidTokenError(JoseError): + error = 'invalid_token' + description = 'The token is not valid yet' diff --git a/.venv/Lib/site-packages/authlib/jose/jwk.py b/.venv/Lib/site-packages/authlib/jose/jwk.py new file mode 100644 index 00000000..bc3b6eb5 --- /dev/null +++ b/.venv/Lib/site-packages/authlib/jose/jwk.py @@ -0,0 +1,19 @@ +from authlib.deprecate import deprecate +from .rfc7517 import JsonWebKey + + +def loads(obj, kid=None): + deprecate('Please use ``JsonWebKey`` directly.') + key_set = JsonWebKey.import_key_set(obj) + if key_set: + return key_set.find_by_kid(kid) + return JsonWebKey.import_key(obj) + + +def dumps(key, kty=None, **params): + deprecate('Please use ``JsonWebKey`` directly.') + if kty: + params['kty'] = kty + + key = JsonWebKey.import_key(key, params) + return dict(key) diff --git a/.venv/Lib/site-packages/authlib/jose/rfc7515/__init__.py b/.venv/Lib/site-packages/authlib/jose/rfc7515/__init__.py new file mode 100644 index 00000000..5f8e0f5f --- /dev/null +++ b/.venv/Lib/site-packages/authlib/jose/rfc7515/__init__.py @@ -0,0 +1,18 @@ +""" + authlib.jose.rfc7515 + ~~~~~~~~~~~~~~~~~~~~~ + + This module represents a direct implementation of + JSON Web Signature (JWS). + + https://tools.ietf.org/html/rfc7515 +""" + +from .jws import JsonWebSignature +from .models import JWSAlgorithm, JWSHeader, JWSObject + + +__all__ = [ + 'JsonWebSignature', + 'JWSAlgorithm', 'JWSHeader', 'JWSObject' +] diff --git a/.venv/Lib/site-packages/authlib/jose/rfc7515/__pycache__/__init__.cpython-311.pyc b/.venv/Lib/site-packages/authlib/jose/rfc7515/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 00000000..8e9b69e3 Binary files /dev/null and b/.venv/Lib/site-packages/authlib/jose/rfc7515/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/authlib/jose/rfc7515/__pycache__/jws.cpython-311.pyc b/.venv/Lib/site-packages/authlib/jose/rfc7515/__pycache__/jws.cpython-311.pyc new file mode 100644 index 00000000..38d55ec3 Binary files /dev/null and b/.venv/Lib/site-packages/authlib/jose/rfc7515/__pycache__/jws.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/authlib/jose/rfc7515/__pycache__/models.cpython-311.pyc b/.venv/Lib/site-packages/authlib/jose/rfc7515/__pycache__/models.cpython-311.pyc new file mode 100644 index 00000000..54cd816e Binary files /dev/null and b/.venv/Lib/site-packages/authlib/jose/rfc7515/__pycache__/models.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/authlib/jose/rfc7515/jws.py b/.venv/Lib/site-packages/authlib/jose/rfc7515/jws.py new file mode 100644 index 00000000..cf19c4ba --- /dev/null +++ b/.venv/Lib/site-packages/authlib/jose/rfc7515/jws.py @@ -0,0 +1,304 @@ +from authlib.common.encoding import ( + to_bytes, + to_unicode, + urlsafe_b64encode, + json_b64encode, +) +from authlib.jose.util import ( + extract_header, + extract_segment, ensure_dict, +) +from authlib.jose.errors import ( + DecodeError, + MissingAlgorithmError, + UnsupportedAlgorithmError, + BadSignatureError, + InvalidHeaderParameterNameError, +) +from .models import JWSHeader, JWSObject + + +class JsonWebSignature: + + #: Registered Header Parameter Names defined by Section 4.1 + REGISTERED_HEADER_PARAMETER_NAMES = frozenset([ + 'alg', 'jku', 'jwk', 'kid', + 'x5u', 'x5c', 'x5t', 'x5t#S256', + 'typ', 'cty', 'crit' + ]) + + #: Defined available JWS algorithms in the registry + ALGORITHMS_REGISTRY = {} + + def __init__(self, algorithms=None, private_headers=None): + self._private_headers = private_headers + self._algorithms = algorithms + + @classmethod + def register_algorithm(cls, algorithm): + if not algorithm or algorithm.algorithm_type != 'JWS': + raise ValueError( + f'Invalid algorithm for JWS, {algorithm!r}') + cls.ALGORITHMS_REGISTRY[algorithm.name] = algorithm + + def serialize_compact(self, protected, payload, key): + """Generate a JWS Compact Serialization. The JWS Compact Serialization + represents digitally signed or MACed content as a compact, URL-safe + string, per `Section 7.1`_. + + .. code-block:: text + + BASE64URL(UTF8(JWS Protected Header)) || '.' || + BASE64URL(JWS Payload) || '.' || + BASE64URL(JWS Signature) + + :param protected: A dict of protected header + :param payload: A bytes/string of payload + :param key: Private key used to generate signature + :return: byte + """ + jws_header = JWSHeader(protected, None) + self._validate_private_headers(protected) + algorithm, key = self._prepare_algorithm_key(protected, payload, key) + + protected_segment = json_b64encode(jws_header.protected) + payload_segment = urlsafe_b64encode(to_bytes(payload)) + + # calculate signature + signing_input = b'.'.join([protected_segment, payload_segment]) + signature = urlsafe_b64encode(algorithm.sign(signing_input, key)) + return b'.'.join([protected_segment, payload_segment, signature]) + + def deserialize_compact(self, s, key, decode=None): + """Exact JWS Compact Serialization, and validate with the given key. + If key is not provided, the returned dict will contain the signature, + and signing input values. Via `Section 7.1`_. + + :param s: text of JWS Compact Serialization + :param key: key used to verify the signature + :param decode: a function to decode payload data + :return: JWSObject + :raise: BadSignatureError + + .. _`Section 7.1`: https://tools.ietf.org/html/rfc7515#section-7.1 + """ + try: + s = to_bytes(s) + signing_input, signature_segment = s.rsplit(b'.', 1) + protected_segment, payload_segment = signing_input.split(b'.', 1) + except ValueError: + raise DecodeError('Not enough segments') + + protected = _extract_header(protected_segment) + jws_header = JWSHeader(protected, None) + + payload = _extract_payload(payload_segment) + if decode: + payload = decode(payload) + + signature = _extract_signature(signature_segment) + rv = JWSObject(jws_header, payload, 'compact') + algorithm, key = self._prepare_algorithm_key(jws_header, payload, key) + if algorithm.verify(signing_input, signature, key): + return rv + raise BadSignatureError(rv) + + def serialize_json(self, header_obj, payload, key): + """Generate a JWS JSON Serialization. The JWS JSON Serialization + represents digitally signed or MACed content as a JSON object, + per `Section 7.2`_. + + :param header_obj: A dict/list of header + :param payload: A string/dict of payload + :param key: Private key used to generate signature + :return: JWSObject + + Example ``header_obj`` of JWS JSON Serialization:: + + { + "protected: {"alg": "HS256"}, + "header": {"kid": "jose"} + } + + Pass a dict to generate flattened JSON Serialization, pass a list of + header dict to generate standard JSON Serialization. + """ + payload_segment = json_b64encode(payload) + + def _sign(jws_header): + self._validate_private_headers(jws_header) + _alg, _key = self._prepare_algorithm_key(jws_header, payload, key) + + protected_segment = json_b64encode(jws_header.protected) + signing_input = b'.'.join([protected_segment, payload_segment]) + signature = urlsafe_b64encode(_alg.sign(signing_input, _key)) + + rv = { + 'protected': to_unicode(protected_segment), + 'signature': to_unicode(signature) + } + if jws_header.header is not None: + rv['header'] = jws_header.header + return rv + + if isinstance(header_obj, dict): + data = _sign(JWSHeader.from_dict(header_obj)) + data['payload'] = to_unicode(payload_segment) + return data + + signatures = [_sign(JWSHeader.from_dict(h)) for h in header_obj] + return { + 'payload': to_unicode(payload_segment), + 'signatures': signatures + } + + def deserialize_json(self, obj, key, decode=None): + """Exact JWS JSON Serialization, and validate with the given key. + If key is not provided, it will return a dict without signature + verification. Header will still be validated. Via `Section 7.2`_. + + :param obj: text of JWS JSON Serialization + :param key: key used to verify the signature + :param decode: a function to decode payload data + :return: JWSObject + :raise: BadSignatureError + + .. _`Section 7.2`: https://tools.ietf.org/html/rfc7515#section-7.2 + """ + obj = ensure_dict(obj, 'JWS') + + payload_segment = obj.get('payload') + if payload_segment is None: + raise DecodeError('Missing "payload" value') + + payload_segment = to_bytes(payload_segment) + payload = _extract_payload(payload_segment) + if decode: + payload = decode(payload) + + if 'signatures' not in obj: + # flattened JSON JWS + jws_header, valid = self._validate_json_jws( + payload_segment, payload, obj, key) + + rv = JWSObject(jws_header, payload, 'flat') + if valid: + return rv + raise BadSignatureError(rv) + + headers = [] + is_valid = True + for header_obj in obj['signatures']: + jws_header, valid = self._validate_json_jws( + payload_segment, payload, header_obj, key) + headers.append(jws_header) + if not valid: + is_valid = False + + rv = JWSObject(headers, payload, 'json') + if is_valid: + return rv + raise BadSignatureError(rv) + + def serialize(self, header, payload, key): + """Generate a JWS Serialization. It will automatically generate a + Compact or JSON Serialization depending on the given header. If a + header is in a JSON header format, it will call + :meth:`serialize_json`, otherwise it will call + :meth:`serialize_compact`. + + :param header: A dict/list of header + :param payload: A string/dict of payload + :param key: Private key used to generate signature + :return: byte/dict + """ + if isinstance(header, (list, tuple)): + return self.serialize_json(header, payload, key) + if 'protected' in header: + return self.serialize_json(header, payload, key) + return self.serialize_compact(header, payload, key) + + def deserialize(self, s, key, decode=None): + """Deserialize JWS Serialization, both compact and JSON format. + It will automatically deserialize depending on the given JWS. + + :param s: text of JWS Compact/JSON Serialization + :param key: key used to verify the signature + :param decode: a function to decode payload data + :return: dict + :raise: BadSignatureError + + If key is not provided, it will still deserialize the serialization + without verification. + """ + if isinstance(s, dict): + return self.deserialize_json(s, key, decode) + + s = to_bytes(s) + if s.startswith(b'{') and s.endswith(b'}'): + return self.deserialize_json(s, key, decode) + return self.deserialize_compact(s, key, decode) + + def _prepare_algorithm_key(self, header, payload, key): + if 'alg' not in header: + raise MissingAlgorithmError() + + alg = header['alg'] + if self._algorithms is not None and alg not in self._algorithms: + raise UnsupportedAlgorithmError() + if alg not in self.ALGORITHMS_REGISTRY: + raise UnsupportedAlgorithmError() + + algorithm = self.ALGORITHMS_REGISTRY[alg] + if callable(key): + key = key(header, payload) + elif key is None and 'jwk' in header: + key = header['jwk'] + key = algorithm.prepare_key(key) + return algorithm, key + + def _validate_private_headers(self, header): + # only validate private headers when developers set + # private headers explicitly + if self._private_headers is not None: + names = self.REGISTERED_HEADER_PARAMETER_NAMES.copy() + names = names.union(self._private_headers) + + for k in header: + if k not in names: + raise InvalidHeaderParameterNameError(k) + + def _validate_json_jws(self, payload_segment, payload, header_obj, key): + protected_segment = header_obj.get('protected') + if not protected_segment: + raise DecodeError('Missing "protected" value') + + signature_segment = header_obj.get('signature') + if not signature_segment: + raise DecodeError('Missing "signature" value') + + protected_segment = to_bytes(protected_segment) + protected = _extract_header(protected_segment) + header = header_obj.get('header') + if header and not isinstance(header, dict): + raise DecodeError('Invalid "header" value') + + jws_header = JWSHeader(protected, header) + algorithm, key = self._prepare_algorithm_key(jws_header, payload, key) + signing_input = b'.'.join([protected_segment, payload_segment]) + signature = _extract_signature(to_bytes(signature_segment)) + if algorithm.verify(signing_input, signature, key): + return jws_header, True + return jws_header, False + + +def _extract_header(header_segment): + return extract_header(header_segment, DecodeError) + + +def _extract_signature(signature_segment): + return extract_segment(signature_segment, DecodeError, 'signature') + + +def _extract_payload(payload_segment): + return extract_segment(payload_segment, DecodeError, 'payload') diff --git a/.venv/Lib/site-packages/authlib/jose/rfc7515/models.py b/.venv/Lib/site-packages/authlib/jose/rfc7515/models.py new file mode 100644 index 00000000..5da3c7e0 --- /dev/null +++ b/.venv/Lib/site-packages/authlib/jose/rfc7515/models.py @@ -0,0 +1,81 @@ +class JWSAlgorithm: + """Interface for JWS algorithm. JWA specification (RFC7518) SHOULD + implement the algorithms for JWS with this base implementation. + """ + name = None + description = None + algorithm_type = 'JWS' + algorithm_location = 'alg' + + def prepare_key(self, raw_data): + """Prepare key for signing and verifying signature.""" + raise NotImplementedError() + + def sign(self, msg, key): + """Sign the text msg with a private/sign key. + + :param msg: message bytes to be signed + :param key: private key to sign the message + :return: bytes + """ + raise NotImplementedError + + def verify(self, msg, sig, key): + """Verify the signature of text msg with a public/verify key. + + :param msg: message bytes to be signed + :param sig: result signature to be compared + :param key: public key to verify the signature + :return: boolean + """ + raise NotImplementedError + + +class JWSHeader(dict): + """Header object for JWS. It combine the protected header and unprotected + header together. JWSHeader itself is a dict of the combined dict. e.g. + + >>> protected = {'alg': 'HS256'} + >>> header = {'kid': 'a'} + >>> jws_header = JWSHeader(protected, header) + >>> print(jws_header) + {'alg': 'HS256', 'kid': 'a'} + >>> jws_header.protected == protected + >>> jws_header.header == header + + :param protected: dict of protected header + :param header: dict of unprotected header + """ + def __init__(self, protected, header): + obj = {} + if protected: + obj.update(protected) + if header: + obj.update(header) + super().__init__(obj) + self.protected = protected + self.header = header + + @classmethod + def from_dict(cls, obj): + if isinstance(obj, cls): + return obj + return cls(obj.get('protected'), obj.get('header')) + + +class JWSObject(dict): + """A dict instance to represent a JWS object.""" + def __init__(self, header, payload, type='compact'): + super().__init__( + header=header, + payload=payload, + ) + self.header = header + self.payload = payload + self.type = type + + @property + def headers(self): + """Alias of ``header`` for JSON typed JWS.""" + if self.type == 'json': + return self['header'] diff --git a/.venv/Lib/site-packages/authlib/jose/rfc7516/__init__.py b/.venv/Lib/site-packages/authlib/jose/rfc7516/__init__.py new file mode 100644 index 00000000..4a024335 --- /dev/null +++ b/.venv/Lib/site-packages/authlib/jose/rfc7516/__init__.py @@ -0,0 +1,18 @@ +""" + authlib.jose.rfc7516 + ~~~~~~~~~~~~~~~~~~~~~ + + This module represents a direct implementation of + JSON Web Encryption (JWE). + + https://tools.ietf.org/html/rfc7516 +""" + +from .jwe import JsonWebEncryption +from .models import JWEAlgorithm, JWEAlgorithmWithTagAwareKeyAgreement, JWEEncAlgorithm, JWEZipAlgorithm + + +__all__ = [ + 'JsonWebEncryption', + 'JWEAlgorithm', 'JWEAlgorithmWithTagAwareKeyAgreement', 'JWEEncAlgorithm', 'JWEZipAlgorithm' +] diff --git a/.venv/Lib/site-packages/authlib/jose/rfc7516/__pycache__/__init__.cpython-311.pyc b/.venv/Lib/site-packages/authlib/jose/rfc7516/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 00000000..4bbdf9f4 Binary files /dev/null and b/.venv/Lib/site-packages/authlib/jose/rfc7516/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/authlib/jose/rfc7516/__pycache__/jwe.cpython-311.pyc b/.venv/Lib/site-packages/authlib/jose/rfc7516/__pycache__/jwe.cpython-311.pyc new file mode 100644 index 00000000..8538f6b1 Binary files /dev/null and b/.venv/Lib/site-packages/authlib/jose/rfc7516/__pycache__/jwe.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/authlib/jose/rfc7516/__pycache__/models.cpython-311.pyc b/.venv/Lib/site-packages/authlib/jose/rfc7516/__pycache__/models.cpython-311.pyc new file mode 100644 index 00000000..07ad29aa Binary files /dev/null and b/.venv/Lib/site-packages/authlib/jose/rfc7516/__pycache__/models.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/authlib/jose/rfc7516/jwe.py b/.venv/Lib/site-packages/authlib/jose/rfc7516/jwe.py new file mode 100644 index 00000000..084bccad --- /dev/null +++ b/.venv/Lib/site-packages/authlib/jose/rfc7516/jwe.py @@ -0,0 +1,722 @@ +from collections import OrderedDict +from copy import deepcopy + +from authlib.common.encoding import ( + to_bytes, urlsafe_b64encode, json_b64encode, to_unicode +) +from authlib.jose.rfc7516.models import JWEAlgorithmWithTagAwareKeyAgreement, JWESharedHeader, JWEHeader +from authlib.jose.util import ( + extract_header, + extract_segment, ensure_dict, +) +from authlib.jose.errors import ( + DecodeError, + MissingAlgorithmError, + UnsupportedAlgorithmError, + MissingEncryptionAlgorithmError, + UnsupportedEncryptionAlgorithmError, + UnsupportedCompressionAlgorithmError, + InvalidHeaderParameterNameError, InvalidAlgorithmForMultipleRecipientsMode, KeyMismatchError, +) + + +class JsonWebEncryption: + #: Registered Header Parameter Names defined by Section 4.1 + REGISTERED_HEADER_PARAMETER_NAMES = frozenset([ + 'alg', 'enc', 'zip', + 'jku', 'jwk', 'kid', + 'x5u', 'x5c', 'x5t', 'x5t#S256', + 'typ', 'cty', 'crit' + ]) + + ALG_REGISTRY = {} + ENC_REGISTRY = {} + ZIP_REGISTRY = {} + + def __init__(self, algorithms=None, private_headers=None): + self._algorithms = algorithms + self._private_headers = private_headers + + @classmethod + def register_algorithm(cls, algorithm): + """Register an algorithm for ``alg`` or ``enc`` or ``zip`` of JWE.""" + if not algorithm or algorithm.algorithm_type != 'JWE': + raise ValueError( + f'Invalid algorithm for JWE, {algorithm!r}') + + if algorithm.algorithm_location == 'alg': + cls.ALG_REGISTRY[algorithm.name] = algorithm + elif algorithm.algorithm_location == 'enc': + cls.ENC_REGISTRY[algorithm.name] = algorithm + elif algorithm.algorithm_location == 'zip': + cls.ZIP_REGISTRY[algorithm.name] = algorithm + + def serialize_compact(self, protected, payload, key, sender_key=None): + """Generate a JWE Compact Serialization. + + The JWE Compact Serialization represents encrypted content as a compact, + URL-safe string. This string is:: + + BASE64URL(UTF8(JWE Protected Header)) || '.' || + BASE64URL(JWE Encrypted Key) || '.' || + BASE64URL(JWE Initialization Vector) || '.' || + BASE64URL(JWE Ciphertext) || '.' || + BASE64URL(JWE Authentication Tag) + + Only one recipient is supported by the JWE Compact Serialization and + it provides no syntax to represent JWE Shared Unprotected Header, JWE + Per-Recipient Unprotected Header, or JWE AAD values. + + :param protected: A dict of protected header + :param payload: Payload (bytes or a value convertible to bytes) + :param key: Public key used to encrypt payload + :param sender_key: Sender's private key in case + JWEAlgorithmWithTagAwareKeyAgreement is used + :return: JWE compact serialization as bytes + """ + + # step 1: Prepare algorithms & key + alg = self.get_header_alg(protected) + enc = self.get_header_enc(protected) + zip_alg = self.get_header_zip(protected) + + self._validate_sender_key(sender_key, alg) + self._validate_private_headers(protected, alg) + + key = prepare_key(alg, protected, key) + if sender_key is not None: + sender_key = alg.prepare_key(sender_key) + + # self._post_validate_header(protected, algorithm) + + # step 2: Generate a random Content Encryption Key (CEK) + # use enc_alg.generate_cek() in scope of upcoming .wrap or .generate_keys_and_prepare_headers call + + # step 3: Encrypt the CEK with the recipient's public key + if isinstance(alg, JWEAlgorithmWithTagAwareKeyAgreement) and alg.key_size is not None: + # For a JWE algorithm with tag-aware key agreement in case key agreement with key wrapping mode is used: + # Defer key agreement with key wrapping until authentication tag is computed + prep = alg.generate_keys_and_prepare_headers(enc, key, sender_key) + epk = prep['epk'] + cek = prep['cek'] + protected.update(prep['header']) + else: + # In any other case: + # Keep the normal steps order defined by RFC 7516 + if isinstance(alg, JWEAlgorithmWithTagAwareKeyAgreement): + wrapped = alg.wrap(enc, protected, key, sender_key) + else: + wrapped = alg.wrap(enc, protected, key) + cek = wrapped['cek'] + ek = wrapped['ek'] + if 'header' in wrapped: + protected.update(wrapped['header']) + + # step 4: Generate a random JWE Initialization Vector + iv = enc.generate_iv() + + # step 5: Let the Additional Authenticated Data encryption parameter + # be ASCII(BASE64URL(UTF8(JWE Protected Header))) + protected_segment = json_b64encode(protected) + aad = to_bytes(protected_segment, 'ascii') + + # step 6: compress message if required + if zip_alg: + msg = zip_alg.compress(to_bytes(payload)) + else: + msg = to_bytes(payload) + + # step 7: perform encryption + ciphertext, tag = enc.encrypt(msg, aad, iv, cek) + + if isinstance(alg, JWEAlgorithmWithTagAwareKeyAgreement) and alg.key_size is not None: + # For a JWE algorithm with tag-aware key agreement in case key agreement with key wrapping mode is used: + # Perform key agreement with key wrapping deferred at step 3 + wrapped = alg.agree_upon_key_and_wrap_cek(enc, protected, key, sender_key, epk, cek, tag) + ek = wrapped['ek'] + + # step 8: build resulting message + return b'.'.join([ + protected_segment, + urlsafe_b64encode(ek), + urlsafe_b64encode(iv), + urlsafe_b64encode(ciphertext), + urlsafe_b64encode(tag) + ]) + + def serialize_json(self, header_obj, payload, keys, sender_key=None): + """Generate a JWE JSON Serialization (in fully general syntax). + + The JWE JSON Serialization represents encrypted content as a JSON + object. This representation is neither optimized for compactness nor + URL safe. + + The following members are defined for use in top-level JSON objects + used for the fully general JWE JSON Serialization syntax: + + protected + The "protected" member MUST be present and contain the value + BASE64URL(UTF8(JWE Protected Header)) when the JWE Protected + Header value is non-empty; otherwise, it MUST be absent. These + Header Parameter values are integrity protected. + + unprotected + The "unprotected" member MUST be present and contain the value JWE + Shared Unprotected Header when the JWE Shared Unprotected Header + value is non-empty; otherwise, it MUST be absent. This value is + represented as an unencoded JSON object, rather than as a string. + These Header Parameter values are not integrity protected. + + iv + The "iv" member MUST be present and contain the value + BASE64URL(JWE Initialization Vector) when the JWE Initialization + Vector value is non-empty; otherwise, it MUST be absent. + + aad + The "aad" member MUST be present and contain the value + BASE64URL(JWE AAD)) when the JWE AAD value is non-empty; + otherwise, it MUST be absent. A JWE AAD value can be included to + supply a base64url-encoded value to be integrity protected but not + encrypted. + + ciphertext + The "ciphertext" member MUST be present and contain the value + BASE64URL(JWE Ciphertext). + + tag + The "tag" member MUST be present and contain the value + BASE64URL(JWE Authentication Tag) when the JWE Authentication Tag + value is non-empty; otherwise, it MUST be absent. + + recipients + The "recipients" member value MUST be an array of JSON objects. + Each object contains information specific to a single recipient. + This member MUST be present with exactly one array element per + recipient, even if some or all of the array element values are the + empty JSON object "{}" (which can happen when all Header Parameter + values are shared between all recipients and when no encrypted key + is used, such as when doing Direct Encryption). + + The following members are defined for use in the JSON objects that + are elements of the "recipients" array: + + header + The "header" member MUST be present and contain the value JWE Per- + Recipient Unprotected Header when the JWE Per-Recipient + Unprotected Header value is non-empty; otherwise, it MUST be + absent. This value is represented as an unencoded JSON object, + rather than as a string. These Header Parameter values are not + integrity protected. + + encrypted_key + The "encrypted_key" member MUST be present and contain the value + BASE64URL(JWE Encrypted Key) when the JWE Encrypted Key value is + non-empty; otherwise, it MUST be absent. + + This implementation assumes that "alg" and "enc" header fields are + contained in the protected or shared unprotected header. + + :param header_obj: A dict of headers (in addition optionally contains JWE AAD) + :param payload: Payload (bytes or a value convertible to bytes) + :param keys: Public keys (or a single public key) used to encrypt payload + :param sender_key: Sender's private key in case + JWEAlgorithmWithTagAwareKeyAgreement is used + :return: JWE JSON serialization (in fully general syntax) as dict + + Example of `header_obj`:: + + { + "protected": { + "alg": "ECDH-1PU+A128KW", + "enc": "A256CBC-HS512", + "apu": "QWxpY2U", + "apv": "Qm9iIGFuZCBDaGFybGll" + }, + "unprotected": { + "jku": "https://alice.example.com/keys.jwks" + }, + "recipients": [ + { + "header": { + "kid": "bob-key-2" + } + }, + { + "header": { + "kid": "2021-05-06" + } + } + ], + "aad": b'Authenticate me too.' + } + """ + if not isinstance(keys, list): # single key + keys = [keys] + + if not keys: + raise ValueError("No keys have been provided") + + header_obj = deepcopy(header_obj) + + shared_header = JWESharedHeader.from_dict(header_obj) + + recipients = header_obj.get('recipients') + if recipients is None: + recipients = [{} for _ in keys] + for i in range(len(recipients)): + if recipients[i] is None: + recipients[i] = {} + if 'header' not in recipients[i]: + recipients[i]['header'] = {} + + jwe_aad = header_obj.get('aad') + + if len(keys) != len(recipients): + raise ValueError("Count of recipient keys {} does not equal to count of recipients {}" + .format(len(keys), len(recipients))) + + # step 1: Prepare algorithms & key + alg = self.get_header_alg(shared_header) + enc = self.get_header_enc(shared_header) + zip_alg = self.get_header_zip(shared_header) + + self._validate_sender_key(sender_key, alg) + self._validate_private_headers(shared_header, alg) + for recipient in recipients: + self._validate_private_headers(recipient['header'], alg) + + for i in range(len(keys)): + keys[i] = prepare_key(alg, recipients[i]['header'], keys[i]) + if sender_key is not None: + sender_key = alg.prepare_key(sender_key) + + # self._post_validate_header(protected, algorithm) + + # step 2: Generate a random Content Encryption Key (CEK) + # use enc_alg.generate_cek() in scope of upcoming .wrap or .generate_keys_and_prepare_headers call + + # step 3: Encrypt the CEK with the recipient's public key + preset = alg.generate_preset(enc, keys[0]) + if 'cek' in preset: + cek = preset['cek'] + else: + cek = None + if len(keys) > 1 and cek is None: + raise InvalidAlgorithmForMultipleRecipientsMode(alg.name) + if 'header' in preset: + shared_header.update_protected(preset['header']) + + if isinstance(alg, JWEAlgorithmWithTagAwareKeyAgreement) and alg.key_size is not None: + # For a JWE algorithm with tag-aware key agreement in case key agreement with key wrapping mode is used: + # Defer key agreement with key wrapping until authentication tag is computed + epks = [] + for i in range(len(keys)): + prep = alg.generate_keys_and_prepare_headers(enc, keys[i], sender_key, preset) + if cek is None: + cek = prep['cek'] + epks.append(prep['epk']) + recipients[i]['header'].update(prep['header']) + else: + # In any other case: + # Keep the normal steps order defined by RFC 7516 + for i in range(len(keys)): + if isinstance(alg, JWEAlgorithmWithTagAwareKeyAgreement): + wrapped = alg.wrap(enc, shared_header, keys[i], sender_key, preset) + else: + wrapped = alg.wrap(enc, shared_header, keys[i], preset) + if cek is None: + cek = wrapped['cek'] + recipients[i]['encrypted_key'] = wrapped['ek'] + if 'header' in wrapped: + recipients[i]['header'].update(wrapped['header']) + + # step 4: Generate a random JWE Initialization Vector + iv = enc.generate_iv() + + # step 5: Compute the Encoded Protected Header value + # BASE64URL(UTF8(JWE Protected Header)). If the JWE Protected Header + # is not present, let this value be the empty string. + # Let the Additional Authenticated Data encryption parameter be + # ASCII(Encoded Protected Header). However, if a JWE AAD value is + # present, instead let the Additional Authenticated Data encryption + # parameter be ASCII(Encoded Protected Header || '.' || BASE64URL(JWE AAD)). + aad = json_b64encode(shared_header.protected) if shared_header.protected else b'' + if jwe_aad is not None: + aad += b'.' + urlsafe_b64encode(jwe_aad) + aad = to_bytes(aad, 'ascii') + + # step 6: compress message if required + if zip_alg: + msg = zip_alg.compress(to_bytes(payload)) + else: + msg = to_bytes(payload) + + # step 7: perform encryption + ciphertext, tag = enc.encrypt(msg, aad, iv, cek) + + if isinstance(alg, JWEAlgorithmWithTagAwareKeyAgreement) and alg.key_size is not None: + # For a JWE algorithm with tag-aware key agreement in case key agreement with key wrapping mode is used: + # Perform key agreement with key wrapping deferred at step 3 + for i in range(len(keys)): + wrapped = alg.agree_upon_key_and_wrap_cek(enc, shared_header, keys[i], sender_key, epks[i], cek, tag) + recipients[i]['encrypted_key'] = wrapped['ek'] + + # step 8: build resulting message + obj = OrderedDict() + + if shared_header.protected: + obj['protected'] = to_unicode(json_b64encode(shared_header.protected)) + + if shared_header.unprotected: + obj['unprotected'] = shared_header.unprotected + + for recipient in recipients: + if not recipient['header']: + del recipient['header'] + recipient['encrypted_key'] = to_unicode(urlsafe_b64encode(recipient['encrypted_key'])) + for member in set(recipient.keys()): + if member not in {'header', 'encrypted_key'}: + del recipient[member] + obj['recipients'] = recipients + + if jwe_aad is not None: + obj['aad'] = to_unicode(urlsafe_b64encode(jwe_aad)) + + obj['iv'] = to_unicode(urlsafe_b64encode(iv)) + + obj['ciphertext'] = to_unicode(urlsafe_b64encode(ciphertext)) + + obj['tag'] = to_unicode(urlsafe_b64encode(tag)) + + return obj + + def serialize(self, header, payload, key, sender_key=None): + """Generate a JWE Serialization. + + It will automatically generate a compact or JSON serialization depending + on `header` argument. If `header` is a dict with "protected", + "unprotected" and/or "recipients" keys, it will call `serialize_json`, + otherwise it will call `serialize_compact`. + + :param header: A dict of header(s) + :param payload: Payload (bytes or a value convertible to bytes) + :param key: Public key(s) used to encrypt payload + :param sender_key: Sender's private key in case + JWEAlgorithmWithTagAwareKeyAgreement is used + :return: JWE compact serialization as bytes or + JWE JSON serialization as dict + """ + if 'protected' in header or 'unprotected' in header or 'recipients' in header: + return self.serialize_json(header, payload, key, sender_key) + + return self.serialize_compact(header, payload, key, sender_key) + + def deserialize_compact(self, s, key, decode=None, sender_key=None): + """Extract JWE Compact Serialization. + + :param s: JWE Compact Serialization as bytes + :param key: Private key used to decrypt payload + (optionally can be a tuple of kid and essentially key) + :param decode: Function to decode payload data + :param sender_key: Sender's public key in case + JWEAlgorithmWithTagAwareKeyAgreement is used + :return: dict with `header` and `payload` keys where `header` value is + a dict containing protected header fields + """ + try: + s = to_bytes(s) + protected_s, ek_s, iv_s, ciphertext_s, tag_s = s.rsplit(b'.') + except ValueError: + raise DecodeError('Not enough segments') + + protected = extract_header(protected_s, DecodeError) + ek = extract_segment(ek_s, DecodeError, 'encryption key') + iv = extract_segment(iv_s, DecodeError, 'initialization vector') + ciphertext = extract_segment(ciphertext_s, DecodeError, 'ciphertext') + tag = extract_segment(tag_s, DecodeError, 'authentication tag') + + alg = self.get_header_alg(protected) + enc = self.get_header_enc(protected) + zip_alg = self.get_header_zip(protected) + + self._validate_sender_key(sender_key, alg) + self._validate_private_headers(protected, alg) + + if isinstance(key, tuple) and len(key) == 2: + # Ignore separately provided kid, extract essentially key only + key = key[1] + + key = prepare_key(alg, protected, key) + + if sender_key is not None: + sender_key = alg.prepare_key(sender_key) + + if isinstance(alg, JWEAlgorithmWithTagAwareKeyAgreement): + # For a JWE algorithm with tag-aware key agreement: + if alg.key_size is not None: + # In case key agreement with key wrapping mode is used: + # Provide authentication tag to .unwrap method + cek = alg.unwrap(enc, ek, protected, key, sender_key, tag) + else: + # Otherwise, don't provide authentication tag to .unwrap method + cek = alg.unwrap(enc, ek, protected, key, sender_key) + else: + # For any other JWE algorithm: + # Don't provide authentication tag to .unwrap method + cek = alg.unwrap(enc, ek, protected, key) + + aad = to_bytes(protected_s, 'ascii') + msg = enc.decrypt(ciphertext, aad, iv, tag, cek) + + if zip_alg: + payload = zip_alg.decompress(to_bytes(msg)) + else: + payload = msg + + if decode: + payload = decode(payload) + return {'header': protected, 'payload': payload} + + def deserialize_json(self, obj, key, decode=None, sender_key=None): + """Extract JWE JSON Serialization. + + :param obj: JWE JSON Serialization as dict or str + :param key: Private key used to decrypt payload + (optionally can be a tuple of kid and essentially key) + :param decode: Function to decode payload data + :param sender_key: Sender's public key in case + JWEAlgorithmWithTagAwareKeyAgreement is used + :return: dict with `header` and `payload` keys where `header` value is + a dict containing `protected`, `unprotected`, `recipients` and/or + `aad` keys + """ + obj = ensure_dict(obj, 'JWE') + obj = deepcopy(obj) + + if 'protected' in obj: + protected = extract_header(to_bytes(obj['protected']), DecodeError) + else: + protected = None + + unprotected = obj.get('unprotected') + + recipients = obj['recipients'] + for recipient in recipients: + if 'header' not in recipient: + recipient['header'] = {} + recipient['encrypted_key'] = extract_segment( + to_bytes(recipient['encrypted_key']), DecodeError, 'encrypted key') + + if 'aad' in obj: + jwe_aad = extract_segment(to_bytes(obj['aad']), DecodeError, 'JWE AAD') + else: + jwe_aad = None + + iv = extract_segment(to_bytes(obj['iv']), DecodeError, 'initialization vector') + + ciphertext = extract_segment(to_bytes(obj['ciphertext']), DecodeError, 'ciphertext') + + tag = extract_segment(to_bytes(obj['tag']), DecodeError, 'authentication tag') + + shared_header = JWESharedHeader(protected, unprotected) + + alg = self.get_header_alg(shared_header) + enc = self.get_header_enc(shared_header) + zip_alg = self.get_header_zip(shared_header) + + self._validate_sender_key(sender_key, alg) + self._validate_private_headers(shared_header, alg) + for recipient in recipients: + self._validate_private_headers(recipient['header'], alg) + + kid = None + if isinstance(key, tuple) and len(key) == 2: + # Extract separately provided kid and essentially key + kid = key[0] + key = key[1] + + key = alg.prepare_key(key) + + if kid is None: + # If kid has not been provided separately, try to get it from key itself + kid = key.kid + + if sender_key is not None: + sender_key = alg.prepare_key(sender_key) + + def _unwrap_with_sender_key_and_tag(ek, header): + return alg.unwrap(enc, ek, header, key, sender_key, tag) + + def _unwrap_with_sender_key_and_without_tag(ek, header): + return alg.unwrap(enc, ek, header, key, sender_key) + + def _unwrap_without_sender_key_and_tag(ek, header): + return alg.unwrap(enc, ek, header, key) + + def _unwrap_for_matching_recipient(unwrap_func): + if kid is not None: + for recipient in recipients: + if recipient['header'].get('kid') == kid: + header = JWEHeader(protected, unprotected, recipient['header']) + return unwrap_func(recipient['encrypted_key'], header) + + # Since no explicit match has been found, iterate over all the recipients + error = None + for recipient in recipients: + header = JWEHeader(protected, unprotected, recipient['header']) + try: + return unwrap_func(recipient['encrypted_key'], header) + except Exception as e: + error = e + else: + if error is None: + raise KeyMismatchError() + else: + raise error + + if isinstance(alg, JWEAlgorithmWithTagAwareKeyAgreement): + # For a JWE algorithm with tag-aware key agreement: + if alg.key_size is not None: + # In case key agreement with key wrapping mode is used: + # Provide authentication tag to .unwrap method + cek = _unwrap_for_matching_recipient(_unwrap_with_sender_key_and_tag) + else: + # Otherwise, don't provide authentication tag to .unwrap method + cek = _unwrap_for_matching_recipient(_unwrap_with_sender_key_and_without_tag) + else: + # For any other JWE algorithm: + # Don't provide authentication tag to .unwrap method + cek = _unwrap_for_matching_recipient(_unwrap_without_sender_key_and_tag) + + aad = to_bytes(obj.get('protected', '')) + if 'aad' in obj: + aad += b'.' + to_bytes(obj['aad']) + aad = to_bytes(aad, 'ascii') + + msg = enc.decrypt(ciphertext, aad, iv, tag, cek) + + if zip_alg: + payload = zip_alg.decompress(to_bytes(msg)) + else: + payload = msg + + if decode: + payload = decode(payload) + + for recipient in recipients: + if not recipient['header']: + del recipient['header'] + for member in set(recipient.keys()): + if member != 'header': + del recipient[member] + + header = {} + if protected: + header['protected'] = protected + if unprotected: + header['unprotected'] = unprotected + header['recipients'] = recipients + if jwe_aad is not None: + header['aad'] = jwe_aad + + return { + 'header': header, + 'payload': payload + } + + def deserialize(self, obj, key, decode=None, sender_key=None): + """Extract a JWE Serialization. + + It supports both compact and JSON serialization. + + :param obj: JWE compact serialization as bytes or + JWE JSON serialization as dict or str + :param key: Private key used to decrypt payload + (optionally can be a tuple of kid and essentially key) + :param decode: Function to decode payload data + :param sender_key: Sender's public key in case + JWEAlgorithmWithTagAwareKeyAgreement is used + :return: dict with `header` and `payload` keys + """ + if isinstance(obj, dict): + return self.deserialize_json(obj, key, decode, sender_key) + + obj = to_bytes(obj) + if obj.startswith(b'{') and obj.endswith(b'}'): + return self.deserialize_json(obj, key, decode, sender_key) + + return self.deserialize_compact(obj, key, decode, sender_key) + + @staticmethod + def parse_json(obj): + """Parse JWE JSON Serialization. + + :param obj: JWE JSON Serialization as str or dict + :return: Parsed JWE JSON Serialization as dict if `obj` is an str, + or `obj` as is if `obj` is already a dict + """ + return ensure_dict(obj, 'JWE') + + def get_header_alg(self, header): + if 'alg' not in header: + raise MissingAlgorithmError() + + alg = header['alg'] + if self._algorithms is not None and alg not in self._algorithms: + raise UnsupportedAlgorithmError() + if alg not in self.ALG_REGISTRY: + raise UnsupportedAlgorithmError() + return self.ALG_REGISTRY[alg] + + def get_header_enc(self, header): + if 'enc' not in header: + raise MissingEncryptionAlgorithmError() + enc = header['enc'] + if self._algorithms is not None and enc not in self._algorithms: + raise UnsupportedEncryptionAlgorithmError() + if enc not in self.ENC_REGISTRY: + raise UnsupportedEncryptionAlgorithmError() + return self.ENC_REGISTRY[enc] + + def get_header_zip(self, header): + if 'zip' in header: + z = header['zip'] + if self._algorithms is not None and z not in self._algorithms: + raise UnsupportedCompressionAlgorithmError() + if z not in self.ZIP_REGISTRY: + raise UnsupportedCompressionAlgorithmError() + return self.ZIP_REGISTRY[z] + + def _validate_sender_key(self, sender_key, alg): + if isinstance(alg, JWEAlgorithmWithTagAwareKeyAgreement): + if sender_key is None: + raise ValueError("{} algorithm requires sender_key but passed sender_key value is None" + .format(alg.name)) + else: + if sender_key is not None: + raise ValueError("{} algorithm does not use sender_key but passed sender_key value is not None" + .format(alg.name)) + + def _validate_private_headers(self, header, alg): + # only validate private headers when developers set + # private headers explicitly + if self._private_headers is None: + return + + names = self.REGISTERED_HEADER_PARAMETER_NAMES.copy() + names = names.union(self._private_headers) + + if alg.EXTRA_HEADERS: + names = names.union(alg.EXTRA_HEADERS) + + for k in header: + if k not in names: + raise InvalidHeaderParameterNameError(k) + + +def prepare_key(alg, header, key): + if callable(key): + key = key(header, None) + elif key is None and 'jwk' in header: + key = header['jwk'] + return alg.prepare_key(key) diff --git a/.venv/Lib/site-packages/authlib/jose/rfc7516/models.py b/.venv/Lib/site-packages/authlib/jose/rfc7516/models.py new file mode 100644 index 00000000..279563cf --- /dev/null +++ b/.venv/Lib/site-packages/authlib/jose/rfc7516/models.py @@ -0,0 +1,148 @@ +import os +from abc import ABCMeta + + +class JWEAlgorithmBase(metaclass=ABCMeta): + """Base interface for all JWE algorithms. + """ + EXTRA_HEADERS = None + + name = None + description = None + algorithm_type = 'JWE' + algorithm_location = 'alg' + + def prepare_key(self, raw_data): + raise NotImplementedError + + def generate_preset(self, enc_alg, key): + raise NotImplementedError + + +class JWEAlgorithm(JWEAlgorithmBase, metaclass=ABCMeta): + """Interface for JWE algorithm conforming to RFC7518. + JWA specification (RFC7518) SHOULD implement the algorithms for JWE with this base implementation. + """ + def wrap(self, enc_alg, headers, key, preset=None): + raise NotImplementedError + + def unwrap(self, enc_alg, ek, headers, key): + raise NotImplementedError + + +class JWEAlgorithmWithTagAwareKeyAgreement(JWEAlgorithmBase, metaclass=ABCMeta): + """Interface for JWE algorithm with tag-aware key agreement (in key agreement with key wrapping mode). + ECDH-1PU is an example of such an algorithm. + """ + def generate_keys_and_prepare_headers(self, enc_alg, key, sender_key, preset=None): + raise NotImplementedError + + def agree_upon_key_and_wrap_cek(self, enc_alg, headers, key, sender_key, epk, cek, tag): + raise NotImplementedError + + def wrap(self, enc_alg, headers, key, sender_key, preset=None): + raise NotImplementedError + + def unwrap(self, enc_alg, ek, headers, key, sender_key, tag=None): + raise NotImplementedError + + +class JWEEncAlgorithm: + name = None + description = None + algorithm_type = 'JWE' + algorithm_location = 'enc' + + IV_SIZE = None + CEK_SIZE = None + + def generate_cek(self): + return os.urandom(self.CEK_SIZE // 8) + + def generate_iv(self): + return os.urandom(self.IV_SIZE // 8) + + def check_iv(self, iv): + if len(iv) * 8 != self.IV_SIZE: + raise ValueError('Invalid "iv" size') + + def encrypt(self, msg, aad, iv, key): + """Encrypt the given "msg" text. + + :param msg: text to be encrypt in bytes + :param aad: additional authenticated data in bytes + :param iv: initialization vector in bytes + :param key: encrypted key in bytes + :return: (ciphertext, tag) + """ + raise NotImplementedError + + def decrypt(self, ciphertext, aad, iv, tag, key): + """Decrypt the given cipher text. + + :param ciphertext: ciphertext in bytes + :param aad: additional authenticated data in bytes + :param iv: initialization vector in bytes + :param tag: authentication tag in bytes + :param key: encrypted key in bytes + :return: message + """ + raise NotImplementedError + + +class JWEZipAlgorithm: + name = None + description = None + algorithm_type = 'JWE' + algorithm_location = 'zip' + + def compress(self, s): + raise NotImplementedError + + def decompress(self, s): + raise NotImplementedError + + +class JWESharedHeader(dict): + """Shared header object for JWE. + + Combines protected header and shared unprotected header together. + """ + def __init__(self, protected, unprotected): + obj = {} + if protected: + obj.update(protected) + if unprotected: + obj.update(unprotected) + super().__init__(obj) + self.protected = protected if protected else {} + self.unprotected = unprotected if unprotected else {} + + def update_protected(self, addition): + self.update(addition) + self.protected.update(addition) + + @classmethod + def from_dict(cls, obj): + if isinstance(obj, cls): + return obj + return cls(obj.get('protected'), obj.get('unprotected')) + + +class JWEHeader(dict): + """Header object for JWE. + + Combines protected header, shared unprotected header and specific recipient's unprotected header together. + """ + def __init__(self, protected, unprotected, header): + obj = {} + if protected: + obj.update(protected) + if unprotected: + obj.update(unprotected) + if header: + obj.update(header) + super().__init__(obj) + self.protected = protected if protected else {} + self.unprotected = unprotected if unprotected else {} + self.header = header if header else {} diff --git a/.venv/Lib/site-packages/authlib/jose/rfc7517/__init__.py b/.venv/Lib/site-packages/authlib/jose/rfc7517/__init__.py new file mode 100644 index 00000000..d3fbbb2d --- /dev/null +++ b/.venv/Lib/site-packages/authlib/jose/rfc7517/__init__.py @@ -0,0 +1,17 @@ +""" + authlib.jose.rfc7517 + ~~~~~~~~~~~~~~~~~~~~~ + + This module represents a direct implementation of + JSON Web Key (JWK). + + https://tools.ietf.org/html/rfc7517 +""" +from ._cryptography_key import load_pem_key +from .base_key import Key +from .asymmetric_key import AsymmetricKey +from .key_set import KeySet +from .jwk import JsonWebKey + + +__all__ = ['Key', 'AsymmetricKey', 'KeySet', 'JsonWebKey', 'load_pem_key'] diff --git a/.venv/Lib/site-packages/authlib/jose/rfc7517/__pycache__/__init__.cpython-311.pyc b/.venv/Lib/site-packages/authlib/jose/rfc7517/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 00000000..9221268d Binary files /dev/null and b/.venv/Lib/site-packages/authlib/jose/rfc7517/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/authlib/jose/rfc7517/__pycache__/_cryptography_key.cpython-311.pyc b/.venv/Lib/site-packages/authlib/jose/rfc7517/__pycache__/_cryptography_key.cpython-311.pyc new file mode 100644 index 00000000..d9c4f902 Binary files /dev/null and b/.venv/Lib/site-packages/authlib/jose/rfc7517/__pycache__/_cryptography_key.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/authlib/jose/rfc7517/__pycache__/asymmetric_key.cpython-311.pyc b/.venv/Lib/site-packages/authlib/jose/rfc7517/__pycache__/asymmetric_key.cpython-311.pyc new file mode 100644 index 00000000..98071278 Binary files /dev/null and b/.venv/Lib/site-packages/authlib/jose/rfc7517/__pycache__/asymmetric_key.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/authlib/jose/rfc7517/__pycache__/base_key.cpython-311.pyc b/.venv/Lib/site-packages/authlib/jose/rfc7517/__pycache__/base_key.cpython-311.pyc new file mode 100644 index 00000000..5c7aa0d7 Binary files /dev/null and b/.venv/Lib/site-packages/authlib/jose/rfc7517/__pycache__/base_key.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/authlib/jose/rfc7517/__pycache__/jwk.cpython-311.pyc b/.venv/Lib/site-packages/authlib/jose/rfc7517/__pycache__/jwk.cpython-311.pyc new file mode 100644 index 00000000..10fa2e33 Binary files /dev/null and b/.venv/Lib/site-packages/authlib/jose/rfc7517/__pycache__/jwk.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/authlib/jose/rfc7517/__pycache__/key_set.cpython-311.pyc b/.venv/Lib/site-packages/authlib/jose/rfc7517/__pycache__/key_set.cpython-311.pyc new file mode 100644 index 00000000..d72a4374 Binary files /dev/null and b/.venv/Lib/site-packages/authlib/jose/rfc7517/__pycache__/key_set.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/authlib/jose/rfc7517/_cryptography_key.py b/.venv/Lib/site-packages/authlib/jose/rfc7517/_cryptography_key.py new file mode 100644 index 00000000..f7194a37 --- /dev/null +++ b/.venv/Lib/site-packages/authlib/jose/rfc7517/_cryptography_key.py @@ -0,0 +1,34 @@ +from cryptography.x509 import load_pem_x509_certificate +from cryptography.hazmat.primitives.serialization import ( + load_pem_private_key, load_pem_public_key, load_ssh_public_key, +) +from cryptography.hazmat.backends import default_backend +from authlib.common.encoding import to_bytes + + +def load_pem_key(raw, ssh_type=None, key_type=None, password=None): + raw = to_bytes(raw) + + if ssh_type and raw.startswith(ssh_type): + return load_ssh_public_key(raw, backend=default_backend()) + + if key_type == 'public': + return load_pem_public_key(raw, backend=default_backend()) + + if key_type == 'private' or password is not None: + return load_pem_private_key(raw, password=password, backend=default_backend()) + + if b'PUBLIC' in raw: + return load_pem_public_key(raw, backend=default_backend()) + + if b'PRIVATE' in raw: + return load_pem_private_key(raw, password=password, backend=default_backend()) + + if b'CERTIFICATE' in raw: + cert = load_pem_x509_certificate(raw, default_backend()) + return cert.public_key() + + try: + return load_pem_private_key(raw, password=password, backend=default_backend()) + except ValueError: + return load_pem_public_key(raw, backend=default_backend()) diff --git a/.venv/Lib/site-packages/authlib/jose/rfc7517/asymmetric_key.py b/.venv/Lib/site-packages/authlib/jose/rfc7517/asymmetric_key.py new file mode 100644 index 00000000..35b1937c --- /dev/null +++ b/.venv/Lib/site-packages/authlib/jose/rfc7517/asymmetric_key.py @@ -0,0 +1,191 @@ +from authlib.common.encoding import to_bytes +from cryptography.hazmat.primitives.serialization import ( + Encoding, PrivateFormat, PublicFormat, + BestAvailableEncryption, NoEncryption, +) +from ._cryptography_key import load_pem_key +from .base_key import Key + + +class AsymmetricKey(Key): + """This is the base class for a JSON Web Key.""" + PUBLIC_KEY_FIELDS = [] + PRIVATE_KEY_FIELDS = [] + PRIVATE_KEY_CLS = bytes + PUBLIC_KEY_CLS = bytes + SSH_PUBLIC_PREFIX = b'' + + def __init__(self, private_key=None, public_key=None, options=None): + super().__init__(options) + self.private_key = private_key + self.public_key = public_key + + @property + def public_only(self): + if self.private_key: + return False + if 'd' in self.tokens: + return False + return True + + def get_op_key(self, operation): + """Get the raw key for the given key_op. This method will also + check if the given key_op is supported by this key. + + :param operation: key operation value, such as "sign", "encrypt". + :return: raw key + """ + self.check_key_op(operation) + if operation in self.PUBLIC_KEY_OPS: + return self.get_public_key() + return self.get_private_key() + + def get_public_key(self): + if self.public_key: + return self.public_key + + private_key = self.get_private_key() + if private_key: + return private_key.public_key() + + return self.public_key + + def get_private_key(self): + if self.private_key: + return self.private_key + + if self.tokens: + self.load_raw_key() + return self.private_key + + def load_raw_key(self): + if 'd' in self.tokens: + self.private_key = self.load_private_key() + else: + self.public_key = self.load_public_key() + + def load_dict_key(self): + if self.private_key: + self._dict_data.update(self.dumps_private_key()) + else: + self._dict_data.update(self.dumps_public_key()) + + def dumps_private_key(self): + raise NotImplementedError() + + def dumps_public_key(self): + raise NotImplementedError() + + def load_private_key(self): + raise NotImplementedError() + + def load_public_key(self): + raise NotImplementedError() + + def as_dict(self, is_private=False, **params): + """Represent this key as a dict of the JSON Web Key.""" + tokens = self.tokens + if is_private and 'd' not in tokens: + raise ValueError('This is a public key') + + kid = tokens.get('kid') + if 'd' in tokens and not is_private: + # filter out private fields + tokens = {k: tokens[k] for k in tokens if k in self.PUBLIC_KEY_FIELDS} + tokens['kty'] = self.kty + if kid: + tokens['kid'] = kid + + if not kid: + tokens['kid'] = self.thumbprint() + + tokens.update(params) + return tokens + + def as_key(self, is_private=False): + """Represent this key as raw key.""" + if is_private: + return self.get_private_key() + return self.get_public_key() + + def as_bytes(self, encoding=None, is_private=False, password=None): + """Export key into PEM/DER format bytes. + + :param encoding: "PEM" or "DER" + :param is_private: export private key or public key + :param password: encrypt private key with password + :return: bytes + """ + + if encoding is None or encoding == 'PEM': + encoding = Encoding.PEM + elif encoding == 'DER': + encoding = Encoding.DER + else: + raise ValueError(f'Invalid encoding: {encoding!r}') + + raw_key = self.as_key(is_private) + if is_private: + if not raw_key: + raise ValueError('This is a public key') + if password is None: + encryption_algorithm = NoEncryption() + else: + encryption_algorithm = BestAvailableEncryption(to_bytes(password)) + return raw_key.private_bytes( + encoding=encoding, + format=PrivateFormat.PKCS8, + encryption_algorithm=encryption_algorithm, + ) + return raw_key.public_bytes( + encoding=encoding, + format=PublicFormat.SubjectPublicKeyInfo, + ) + + def as_pem(self, is_private=False, password=None): + return self.as_bytes(is_private=is_private, password=password) + + def as_der(self, is_private=False, password=None): + return self.as_bytes(encoding='DER', is_private=is_private, password=password) + + @classmethod + def import_dict_key(cls, raw, options=None): + cls.check_required_fields(raw) + key = cls(options=options) + key._dict_data = raw + return key + + @classmethod + def import_key(cls, raw, options=None): + if isinstance(raw, cls): + if options is not None: + raw.options.update(options) + return raw + + if isinstance(raw, cls.PUBLIC_KEY_CLS): + key = cls(public_key=raw, options=options) + elif isinstance(raw, cls.PRIVATE_KEY_CLS): + key = cls(private_key=raw, options=options) + elif isinstance(raw, dict): + key = cls.import_dict_key(raw, options) + else: + if options is not None: + password = options.pop('password', None) + else: + password = None + raw_key = load_pem_key(raw, cls.SSH_PUBLIC_PREFIX, password=password) + if isinstance(raw_key, cls.PUBLIC_KEY_CLS): + key = cls(public_key=raw_key, options=options) + elif isinstance(raw_key, cls.PRIVATE_KEY_CLS): + key = cls(private_key=raw_key, options=options) + else: + raise ValueError('Invalid data for importing key') + return key + + @classmethod + def validate_raw_key(cls, key): + return isinstance(key, cls.PUBLIC_KEY_CLS) or isinstance(key, cls.PRIVATE_KEY_CLS) + + @classmethod + def generate_key(cls, crv_or_size, options=None, is_private=False): + raise NotImplementedError() diff --git a/.venv/Lib/site-packages/authlib/jose/rfc7517/base_key.py b/.venv/Lib/site-packages/authlib/jose/rfc7517/base_key.py new file mode 100644 index 00000000..1afe8d48 --- /dev/null +++ b/.venv/Lib/site-packages/authlib/jose/rfc7517/base_key.py @@ -0,0 +1,118 @@ +import hashlib +from collections import OrderedDict +from authlib.common.encoding import ( + json_dumps, + to_bytes, + to_unicode, + urlsafe_b64encode, +) +from ..errors import InvalidUseError + + +class Key: + """This is the base class for a JSON Web Key.""" + kty = '_' + + ALLOWED_PARAMS = [ + 'use', 'key_ops', 'alg', 'kid', + 'x5u', 'x5c', 'x5t', 'x5t#S256' + ] + + PRIVATE_KEY_OPS = [ + 'sign', 'decrypt', 'unwrapKey', + ] + PUBLIC_KEY_OPS = [ + 'verify', 'encrypt', 'wrapKey', + ] + + REQUIRED_JSON_FIELDS = [] + + def __init__(self, options=None): + self.options = options or {} + self._dict_data = {} + + @property + def tokens(self): + if not self._dict_data: + self.load_dict_key() + + rv = dict(self._dict_data) + rv['kty'] = self.kty + for k in self.ALLOWED_PARAMS: + if k not in rv and k in self.options: + rv[k] = self.options[k] + return rv + + @property + def kid(self): + return self.tokens.get('kid') + + def keys(self): + return self.tokens.keys() + + def __getitem__(self, item): + return self.tokens[item] + + @property + def public_only(self): + raise NotImplementedError() + + def load_raw_key(self): + raise NotImplementedError() + + def load_dict_key(self): + raise NotImplementedError() + + def check_key_op(self, operation): + """Check if the given key_op is supported by this key. + + :param operation: key operation value, such as "sign", "encrypt". + :raise: ValueError + """ + key_ops = self.tokens.get('key_ops') + if key_ops is not None and operation not in key_ops: + raise ValueError(f'Unsupported key_op "{operation}"') + + if operation in self.PRIVATE_KEY_OPS and self.public_only: + raise ValueError(f'Invalid key_op "{operation}" for public key') + + use = self.tokens.get('use') + if use: + if operation in ['sign', 'verify']: + if use != 'sig': + raise InvalidUseError() + elif operation in ['decrypt', 'encrypt', 'wrapKey', 'unwrapKey']: + if use != 'enc': + raise InvalidUseError() + + def as_dict(self, is_private=False, **params): + raise NotImplementedError() + + def as_json(self, is_private=False, **params): + """Represent this key as a JSON string.""" + obj = self.as_dict(is_private, **params) + return json_dumps(obj) + + def thumbprint(self): + """Implementation of RFC7638 JSON Web Key (JWK) Thumbprint.""" + fields = list(self.REQUIRED_JSON_FIELDS) + fields.append('kty') + fields.sort() + data = OrderedDict() + + for k in fields: + data[k] = self.tokens[k] + + json_data = json_dumps(data) + digest_data = hashlib.sha256(to_bytes(json_data)).digest() + return to_unicode(urlsafe_b64encode(digest_data)) + + @classmethod + def check_required_fields(cls, data): + for k in cls.REQUIRED_JSON_FIELDS: + if k not in data: + raise ValueError(f'Missing required field: "{k}"') + + @classmethod + def validate_raw_key(cls, key): + raise NotImplementedError() diff --git a/.venv/Lib/site-packages/authlib/jose/rfc7517/jwk.py b/.venv/Lib/site-packages/authlib/jose/rfc7517/jwk.py new file mode 100644 index 00000000..b1578c49 --- /dev/null +++ b/.venv/Lib/site-packages/authlib/jose/rfc7517/jwk.py @@ -0,0 +1,64 @@ +from authlib.common.encoding import json_loads +from .key_set import KeySet +from ._cryptography_key import load_pem_key + + +class JsonWebKey: + JWK_KEY_CLS = {} + + @classmethod + def generate_key(cls, kty, crv_or_size, options=None, is_private=False): + """Generate a Key with the given key type, curve name or bit size. + + :param kty: string of ``oct``, ``RSA``, ``EC``, ``OKP`` + :param crv_or_size: curve name or bit size + :param options: a dict of other options for Key + :param is_private: create a private key or public key + :return: Key instance + """ + key_cls = cls.JWK_KEY_CLS[kty] + return key_cls.generate_key(crv_or_size, options, is_private) + + @classmethod + def import_key(cls, raw, options=None): + """Import a Key from bytes, string, PEM or dict. + + :return: Key instance + """ + kty = None + if options is not None: + kty = options.get('kty') + + if kty is None and isinstance(raw, dict): + kty = raw.get('kty') + + if kty is None: + raw_key = load_pem_key(raw) + for _kty in cls.JWK_KEY_CLS: + key_cls = cls.JWK_KEY_CLS[_kty] + if key_cls.validate_raw_key(raw_key): + return key_cls.import_key(raw_key, options) + + key_cls = cls.JWK_KEY_CLS[kty] + return key_cls.import_key(raw, options) + + @classmethod + def import_key_set(cls, raw): + """Import KeySet from string, dict or a list of keys. + + :return: KeySet instance + """ + raw = _transform_raw_key(raw) + if isinstance(raw, dict) and 'keys' in raw: + keys = raw.get('keys') + return KeySet([cls.import_key(k) for k in keys]) + raise ValueError('Invalid key set format') + + +def _transform_raw_key(raw): + if isinstance(raw, str) and \ + raw.startswith('{') and raw.endswith('}'): + return json_loads(raw) + elif isinstance(raw, (tuple, list)): + return {'keys': raw} + return raw diff --git a/.venv/Lib/site-packages/authlib/jose/rfc7517/key_set.py b/.venv/Lib/site-packages/authlib/jose/rfc7517/key_set.py new file mode 100644 index 00000000..3416ce9b --- /dev/null +++ b/.venv/Lib/site-packages/authlib/jose/rfc7517/key_set.py @@ -0,0 +1,29 @@ +from authlib.common.encoding import json_dumps + + +class KeySet: + """This class represents a JSON Web Key Set.""" + + def __init__(self, keys): + self.keys = keys + + def as_dict(self, is_private=False, **params): + """Represent this key as a dict of the JSON Web Key Set.""" + return {'keys': [k.as_dict(is_private, **params) for k in self.keys]} + + def as_json(self, is_private=False, **params): + """Represent this key set as a JSON string.""" + obj = self.as_dict(is_private, **params) + return json_dumps(obj) + + def find_by_kid(self, kid): + """Find the key matches the given kid value. + + :param kid: A string of kid + :return: Key instance + :raise: ValueError + """ + for k in self.keys: + if k.kid == kid: + return k + raise ValueError('Invalid JSON Web Key Set') diff --git a/.venv/Lib/site-packages/authlib/jose/rfc7518/__init__.py b/.venv/Lib/site-packages/authlib/jose/rfc7518/__init__.py new file mode 100644 index 00000000..360f6c68 --- /dev/null +++ b/.venv/Lib/site-packages/authlib/jose/rfc7518/__init__.py @@ -0,0 +1,35 @@ +from .oct_key import OctKey +from .rsa_key import RSAKey +from .ec_key import ECKey +from .jws_algs import JWS_ALGORITHMS +from .jwe_algs import JWE_ALG_ALGORITHMS, AESAlgorithm, ECDHESAlgorithm, u32be_len_input +from .jwe_encs import JWE_ENC_ALGORITHMS, CBCHS2EncAlgorithm +from .jwe_zips import DeflateZipAlgorithm + + +def register_jws_rfc7518(cls): + for algorithm in JWS_ALGORITHMS: + cls.register_algorithm(algorithm) + + +def register_jwe_rfc7518(cls): + for algorithm in JWE_ALG_ALGORITHMS: + cls.register_algorithm(algorithm) + + for algorithm in JWE_ENC_ALGORITHMS: + cls.register_algorithm(algorithm) + + cls.register_algorithm(DeflateZipAlgorithm()) + + +__all__ = [ + 'register_jws_rfc7518', + 'register_jwe_rfc7518', + 'OctKey', + 'RSAKey', + 'ECKey', + 'u32be_len_input', + 'AESAlgorithm', + 'ECDHESAlgorithm', + 'CBCHS2EncAlgorithm', +] diff --git a/.venv/Lib/site-packages/authlib/jose/rfc7518/__pycache__/__init__.cpython-311.pyc b/.venv/Lib/site-packages/authlib/jose/rfc7518/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 00000000..6eef9470 Binary files /dev/null and b/.venv/Lib/site-packages/authlib/jose/rfc7518/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/authlib/jose/rfc7518/__pycache__/ec_key.cpython-311.pyc b/.venv/Lib/site-packages/authlib/jose/rfc7518/__pycache__/ec_key.cpython-311.pyc new file mode 100644 index 00000000..88feed75 Binary files /dev/null and b/.venv/Lib/site-packages/authlib/jose/rfc7518/__pycache__/ec_key.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/authlib/jose/rfc7518/__pycache__/jwe_algs.cpython-311.pyc b/.venv/Lib/site-packages/authlib/jose/rfc7518/__pycache__/jwe_algs.cpython-311.pyc new file mode 100644 index 00000000..331f7e30 Binary files /dev/null and b/.venv/Lib/site-packages/authlib/jose/rfc7518/__pycache__/jwe_algs.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/authlib/jose/rfc7518/__pycache__/jwe_encs.cpython-311.pyc b/.venv/Lib/site-packages/authlib/jose/rfc7518/__pycache__/jwe_encs.cpython-311.pyc new file mode 100644 index 00000000..3b4422a6 Binary files /dev/null and b/.venv/Lib/site-packages/authlib/jose/rfc7518/__pycache__/jwe_encs.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/authlib/jose/rfc7518/__pycache__/jwe_zips.cpython-311.pyc b/.venv/Lib/site-packages/authlib/jose/rfc7518/__pycache__/jwe_zips.cpython-311.pyc new file mode 100644 index 00000000..a9cfa0ff Binary files /dev/null and b/.venv/Lib/site-packages/authlib/jose/rfc7518/__pycache__/jwe_zips.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/authlib/jose/rfc7518/__pycache__/jws_algs.cpython-311.pyc b/.venv/Lib/site-packages/authlib/jose/rfc7518/__pycache__/jws_algs.cpython-311.pyc new file mode 100644 index 00000000..345f1649 Binary files /dev/null and b/.venv/Lib/site-packages/authlib/jose/rfc7518/__pycache__/jws_algs.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/authlib/jose/rfc7518/__pycache__/oct_key.cpython-311.pyc b/.venv/Lib/site-packages/authlib/jose/rfc7518/__pycache__/oct_key.cpython-311.pyc new file mode 100644 index 00000000..344d0113 Binary files /dev/null and b/.venv/Lib/site-packages/authlib/jose/rfc7518/__pycache__/oct_key.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/authlib/jose/rfc7518/__pycache__/rsa_key.cpython-311.pyc b/.venv/Lib/site-packages/authlib/jose/rfc7518/__pycache__/rsa_key.cpython-311.pyc new file mode 100644 index 00000000..9cdbf6ca Binary files /dev/null and b/.venv/Lib/site-packages/authlib/jose/rfc7518/__pycache__/rsa_key.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/authlib/jose/rfc7518/__pycache__/util.cpython-311.pyc b/.venv/Lib/site-packages/authlib/jose/rfc7518/__pycache__/util.cpython-311.pyc new file mode 100644 index 00000000..5376d4e3 Binary files /dev/null and b/.venv/Lib/site-packages/authlib/jose/rfc7518/__pycache__/util.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/authlib/jose/rfc7518/ec_key.py b/.venv/Lib/site-packages/authlib/jose/rfc7518/ec_key.py new file mode 100644 index 00000000..05f0c044 --- /dev/null +++ b/.venv/Lib/site-packages/authlib/jose/rfc7518/ec_key.py @@ -0,0 +1,101 @@ +from cryptography.hazmat.primitives.asymmetric import ec +from cryptography.hazmat.primitives.asymmetric.ec import ( + EllipticCurvePublicKey, EllipticCurvePrivateKeyWithSerialization, + EllipticCurvePrivateNumbers, EllipticCurvePublicNumbers, + SECP256R1, SECP384R1, SECP521R1, SECP256K1, +) +from cryptography.hazmat.backends import default_backend +from authlib.common.encoding import base64_to_int, int_to_base64 +from ..rfc7517 import AsymmetricKey + + +class ECKey(AsymmetricKey): + """Key class of the ``EC`` key type.""" + + kty = 'EC' + DSS_CURVES = { + 'P-256': SECP256R1, + 'P-384': SECP384R1, + 'P-521': SECP521R1, + # https://tools.ietf.org/html/rfc8812#section-3.1 + 'secp256k1': SECP256K1, + } + CURVES_DSS = { + SECP256R1.name: 'P-256', + SECP384R1.name: 'P-384', + SECP521R1.name: 'P-521', + SECP256K1.name: 'secp256k1', + } + REQUIRED_JSON_FIELDS = ['crv', 'x', 'y'] + + PUBLIC_KEY_FIELDS = REQUIRED_JSON_FIELDS + PRIVATE_KEY_FIELDS = ['crv', 'd', 'x', 'y'] + + PUBLIC_KEY_CLS = EllipticCurvePublicKey + PRIVATE_KEY_CLS = EllipticCurvePrivateKeyWithSerialization + SSH_PUBLIC_PREFIX = b'ecdsa-sha2-' + + def exchange_shared_key(self, pubkey): + # # used in ECDHESAlgorithm + private_key = self.get_private_key() + if private_key: + return private_key.exchange(ec.ECDH(), pubkey) + raise ValueError('Invalid key for exchanging shared key') + + @property + def curve_key_size(self): + raw_key = self.get_private_key() + if not raw_key: + raw_key = self.public_key + return raw_key.curve.key_size + + def load_private_key(self): + curve = self.DSS_CURVES[self._dict_data['crv']]() + public_numbers = EllipticCurvePublicNumbers( + base64_to_int(self._dict_data['x']), + base64_to_int(self._dict_data['y']), + curve, + ) + private_numbers = EllipticCurvePrivateNumbers( + base64_to_int(self.tokens['d']), + public_numbers + ) + return private_numbers.private_key(default_backend()) + + def load_public_key(self): + curve = self.DSS_CURVES[self._dict_data['crv']]() + public_numbers = EllipticCurvePublicNumbers( + base64_to_int(self._dict_data['x']), + base64_to_int(self._dict_data['y']), + curve, + ) + return public_numbers.public_key(default_backend()) + + def dumps_private_key(self): + numbers = self.private_key.private_numbers() + return { + 'crv': self.CURVES_DSS[self.private_key.curve.name], + 'x': int_to_base64(numbers.public_numbers.x), + 'y': int_to_base64(numbers.public_numbers.y), + 'd': int_to_base64(numbers.private_value), + } + + def dumps_public_key(self): + numbers = self.public_key.public_numbers() + return { + 'crv': self.CURVES_DSS[numbers.curve.name], + 'x': int_to_base64(numbers.x), + 'y': int_to_base64(numbers.y) + } + + @classmethod + def generate_key(cls, crv='P-256', options=None, is_private=False) -> 'ECKey': + if crv not in cls.DSS_CURVES: + raise ValueError(f'Invalid crv value: "{crv}"') + raw_key = ec.generate_private_key( + curve=cls.DSS_CURVES[crv](), + backend=default_backend(), + ) + if not is_private: + raw_key = raw_key.public_key() + return cls.import_key(raw_key, options=options) diff --git a/.venv/Lib/site-packages/authlib/jose/rfc7518/jwe_algs.py b/.venv/Lib/site-packages/authlib/jose/rfc7518/jwe_algs.py new file mode 100644 index 00000000..b57654a9 --- /dev/null +++ b/.venv/Lib/site-packages/authlib/jose/rfc7518/jwe_algs.py @@ -0,0 +1,349 @@ +import os +import struct +from cryptography.hazmat.primitives.asymmetric import padding +from cryptography.hazmat.primitives import hashes +from cryptography.hazmat.backends import default_backend +from cryptography.hazmat.primitives.keywrap import ( + aes_key_wrap, + aes_key_unwrap +) +from cryptography.hazmat.primitives.ciphers import Cipher +from cryptography.hazmat.primitives.ciphers.algorithms import AES +from cryptography.hazmat.primitives.ciphers.modes import GCM +from cryptography.hazmat.primitives.kdf.concatkdf import ConcatKDFHash +from authlib.common.encoding import ( + to_bytes, to_native, + urlsafe_b64decode, + urlsafe_b64encode +) +from authlib.jose.rfc7516 import JWEAlgorithm +from .rsa_key import RSAKey +from .ec_key import ECKey +from .oct_key import OctKey + + +class DirectAlgorithm(JWEAlgorithm): + name = 'dir' + description = 'Direct use of a shared symmetric key' + + def prepare_key(self, raw_data): + return OctKey.import_key(raw_data) + + def generate_preset(self, enc_alg, key): + return {} + + def wrap(self, enc_alg, headers, key, preset=None): + cek = key.get_op_key('encrypt') + if len(cek) * 8 != enc_alg.CEK_SIZE: + raise ValueError('Invalid "cek" length') + return {'ek': b'', 'cek': cek} + + def unwrap(self, enc_alg, ek, headers, key): + cek = key.get_op_key('decrypt') + if len(cek) * 8 != enc_alg.CEK_SIZE: + raise ValueError('Invalid "cek" length') + return cek + + +class RSAAlgorithm(JWEAlgorithm): + #: A key of size 2048 bits or larger MUST be used with these algorithms + #: RSA1_5, RSA-OAEP, RSA-OAEP-256 + key_size = 2048 + + def __init__(self, name, description, pad_fn): + self.name = name + self.description = description + self.padding = pad_fn + + def prepare_key(self, raw_data): + return RSAKey.import_key(raw_data) + + def generate_preset(self, enc_alg, key): + cek = enc_alg.generate_cek() + return {'cek': cek} + + def wrap(self, enc_alg, headers, key, preset=None): + if preset and 'cek' in preset: + cek = preset['cek'] + else: + cek = enc_alg.generate_cek() + + op_key = key.get_op_key('wrapKey') + if op_key.key_size < self.key_size: + raise ValueError('A key of size 2048 bits or larger MUST be used') + ek = op_key.encrypt(cek, self.padding) + return {'ek': ek, 'cek': cek} + + def unwrap(self, enc_alg, ek, headers, key): + # it will raise ValueError if failed + op_key = key.get_op_key('unwrapKey') + cek = op_key.decrypt(ek, self.padding) + if len(cek) * 8 != enc_alg.CEK_SIZE: + raise ValueError('Invalid "cek" length') + return cek + + +class AESAlgorithm(JWEAlgorithm): + def __init__(self, key_size): + self.name = f'A{key_size}KW' + self.description = f'AES Key Wrap using {key_size}-bit key' + self.key_size = key_size + + def prepare_key(self, raw_data): + return OctKey.import_key(raw_data) + + def generate_preset(self, enc_alg, key): + cek = enc_alg.generate_cek() + return {'cek': cek} + + def _check_key(self, key): + if len(key) * 8 != self.key_size: + raise ValueError( + f'A key of size {self.key_size} bits is required.') + + def wrap_cek(self, cek, key): + op_key = key.get_op_key('wrapKey') + self._check_key(op_key) + ek = aes_key_wrap(op_key, cek, default_backend()) + return {'ek': ek, 'cek': cek} + + def wrap(self, enc_alg, headers, key, preset=None): + if preset and 'cek' in preset: + cek = preset['cek'] + else: + cek = enc_alg.generate_cek() + return self.wrap_cek(cek, key) + + def unwrap(self, enc_alg, ek, headers, key): + op_key = key.get_op_key('unwrapKey') + self._check_key(op_key) + cek = aes_key_unwrap(op_key, ek, default_backend()) + if len(cek) * 8 != enc_alg.CEK_SIZE: + raise ValueError('Invalid "cek" length') + return cek + + +class AESGCMAlgorithm(JWEAlgorithm): + EXTRA_HEADERS = frozenset(['iv', 'tag']) + + def __init__(self, key_size): + self.name = f'A{key_size}GCMKW' + self.description = f'Key wrapping with AES GCM using {key_size}-bit key' + self.key_size = key_size + + def prepare_key(self, raw_data): + return OctKey.import_key(raw_data) + + def generate_preset(self, enc_alg, key): + cek = enc_alg.generate_cek() + return {'cek': cek} + + def _check_key(self, key): + if len(key) * 8 != self.key_size: + raise ValueError( + f'A key of size {self.key_size} bits is required.') + + def wrap(self, enc_alg, headers, key, preset=None): + if preset and 'cek' in preset: + cek = preset['cek'] + else: + cek = enc_alg.generate_cek() + + op_key = key.get_op_key('wrapKey') + self._check_key(op_key) + + #: https://tools.ietf.org/html/rfc7518#section-4.7.1.1 + #: The "iv" (initialization vector) Header Parameter value is the + #: base64url-encoded representation of the 96-bit IV value + iv_size = 96 + iv = os.urandom(iv_size // 8) + + cipher = Cipher(AES(op_key), GCM(iv), backend=default_backend()) + enc = cipher.encryptor() + ek = enc.update(cek) + enc.finalize() + + h = { + 'iv': to_native(urlsafe_b64encode(iv)), + 'tag': to_native(urlsafe_b64encode(enc.tag)) + } + return {'ek': ek, 'cek': cek, 'header': h} + + def unwrap(self, enc_alg, ek, headers, key): + op_key = key.get_op_key('unwrapKey') + self._check_key(op_key) + + iv = headers.get('iv') + if not iv: + raise ValueError('Missing "iv" in headers') + + tag = headers.get('tag') + if not tag: + raise ValueError('Missing "tag" in headers') + + iv = urlsafe_b64decode(to_bytes(iv)) + tag = urlsafe_b64decode(to_bytes(tag)) + + cipher = Cipher(AES(op_key), GCM(iv, tag), backend=default_backend()) + d = cipher.decryptor() + cek = d.update(ek) + d.finalize() + if len(cek) * 8 != enc_alg.CEK_SIZE: + raise ValueError('Invalid "cek" length') + return cek + + +class ECDHESAlgorithm(JWEAlgorithm): + EXTRA_HEADERS = ['epk', 'apu', 'apv'] + ALLOWED_KEY_CLS = ECKey + + # https://tools.ietf.org/html/rfc7518#section-4.6 + def __init__(self, key_size=None): + if key_size is None: + self.name = 'ECDH-ES' + self.description = 'ECDH-ES in the Direct Key Agreement mode' + else: + self.name = f'ECDH-ES+A{key_size}KW' + self.description = ( + 'ECDH-ES using Concat KDF and CEK wrapped ' + 'with A{}KW').format(key_size) + self.key_size = key_size + self.aeskw = AESAlgorithm(key_size) + + def prepare_key(self, raw_data): + if isinstance(raw_data, self.ALLOWED_KEY_CLS): + return raw_data + return ECKey.import_key(raw_data) + + def generate_preset(self, enc_alg, key): + epk = self._generate_ephemeral_key(key) + h = self._prepare_headers(epk) + preset = {'epk': epk, 'header': h} + if self.key_size is not None: + cek = enc_alg.generate_cek() + preset['cek'] = cek + return preset + + def compute_fixed_info(self, headers, bit_size): + # AlgorithmID + if self.key_size is None: + alg_id = u32be_len_input(headers['enc']) + else: + alg_id = u32be_len_input(headers['alg']) + + # PartyUInfo + apu_info = u32be_len_input(headers.get('apu'), True) + + # PartyVInfo + apv_info = u32be_len_input(headers.get('apv'), True) + + # SuppPubInfo + pub_info = struct.pack('>I', bit_size) + + return alg_id + apu_info + apv_info + pub_info + + def compute_derived_key(self, shared_key, fixed_info, bit_size): + ckdf = ConcatKDFHash( + algorithm=hashes.SHA256(), + length=bit_size // 8, + otherinfo=fixed_info, + backend=default_backend() + ) + return ckdf.derive(shared_key) + + def deliver(self, key, pubkey, headers, bit_size): + shared_key = key.exchange_shared_key(pubkey) + fixed_info = self.compute_fixed_info(headers, bit_size) + return self.compute_derived_key(shared_key, fixed_info, bit_size) + + def _generate_ephemeral_key(self, key): + return key.generate_key(key['crv'], is_private=True) + + def _prepare_headers(self, epk): + # REQUIRED_JSON_FIELDS contains only public fields + pub_epk = {k: epk[k] for k in epk.REQUIRED_JSON_FIELDS} + pub_epk['kty'] = epk.kty + return {'epk': pub_epk} + + def wrap(self, enc_alg, headers, key, preset=None): + if self.key_size is None: + bit_size = enc_alg.CEK_SIZE + else: + bit_size = self.key_size + + if preset and 'epk' in preset: + epk = preset['epk'] + h = {} + else: + epk = self._generate_ephemeral_key(key) + h = self._prepare_headers(epk) + + public_key = key.get_op_key('wrapKey') + dk = self.deliver(epk, public_key, headers, bit_size) + + if self.key_size is None: + return {'ek': b'', 'cek': dk, 'header': h} + + if preset and 'cek' in preset: + preset_for_kw = {'cek': preset['cek']} + else: + preset_for_kw = None + + kek = self.aeskw.prepare_key(dk) + rv = self.aeskw.wrap(enc_alg, headers, kek, preset_for_kw) + rv['header'] = h + return rv + + def unwrap(self, enc_alg, ek, headers, key): + if 'epk' not in headers: + raise ValueError('Missing "epk" in headers') + + if self.key_size is None: + bit_size = enc_alg.CEK_SIZE + else: + bit_size = self.key_size + + epk = key.import_key(headers['epk']) + public_key = epk.get_op_key('wrapKey') + dk = self.deliver(key, public_key, headers, bit_size) + + if self.key_size is None: + return dk + + kek = self.aeskw.prepare_key(dk) + return self.aeskw.unwrap(enc_alg, ek, headers, kek) + + +def u32be_len_input(s, base64=False): + if not s: + return b'\x00\x00\x00\x00' + if base64: + s = urlsafe_b64decode(to_bytes(s)) + else: + s = to_bytes(s) + return struct.pack('>I', len(s)) + s + + +JWE_ALG_ALGORITHMS = [ + DirectAlgorithm(), # dir + RSAAlgorithm('RSA1_5', 'RSAES-PKCS1-v1_5', padding.PKCS1v15()), + RSAAlgorithm( + 'RSA-OAEP', 'RSAES OAEP using default parameters', + padding.OAEP(padding.MGF1(hashes.SHA1()), hashes.SHA1(), None)), + RSAAlgorithm( + 'RSA-OAEP-256', 'RSAES OAEP using SHA-256 and MGF1 with SHA-256', + padding.OAEP(padding.MGF1(hashes.SHA256()), hashes.SHA256(), None)), + + AESAlgorithm(128), # A128KW + AESAlgorithm(192), # A192KW + AESAlgorithm(256), # A256KW + AESGCMAlgorithm(128), # A128GCMKW + AESGCMAlgorithm(192), # A192GCMKW + AESGCMAlgorithm(256), # A256GCMKW + ECDHESAlgorithm(None), # ECDH-ES + ECDHESAlgorithm(128), # ECDH-ES+A128KW + ECDHESAlgorithm(192), # ECDH-ES+A192KW + ECDHESAlgorithm(256), # ECDH-ES+A256KW +] + +# 'PBES2-HS256+A128KW': '', +# 'PBES2-HS384+A192KW': '', +# 'PBES2-HS512+A256KW': '', diff --git a/.venv/Lib/site-packages/authlib/jose/rfc7518/jwe_encs.py b/.venv/Lib/site-packages/authlib/jose/rfc7518/jwe_encs.py new file mode 100644 index 00000000..f951d101 --- /dev/null +++ b/.venv/Lib/site-packages/authlib/jose/rfc7518/jwe_encs.py @@ -0,0 +1,144 @@ +""" + authlib.jose.rfc7518 + ~~~~~~~~~~~~~~~~~~~~ + + Cryptographic Algorithms for Cryptographic Algorithms for Content + Encryption per `Section 5`_. + + .. _`Section 5`: https://tools.ietf.org/html/rfc7518#section-5 +""" +import hmac +import hashlib +from cryptography.hazmat.backends import default_backend +from cryptography.hazmat.primitives.ciphers import Cipher +from cryptography.hazmat.primitives.ciphers.algorithms import AES +from cryptography.hazmat.primitives.ciphers.modes import GCM, CBC +from cryptography.hazmat.primitives.padding import PKCS7 +from cryptography.exceptions import InvalidTag +from ..rfc7516 import JWEEncAlgorithm +from .util import encode_int + + +class CBCHS2EncAlgorithm(JWEEncAlgorithm): + # The IV used is a 128-bit value generated randomly or + # pseudo-randomly for use in the cipher. + IV_SIZE = 128 + + def __init__(self, key_size, hash_type): + self.name = f'A{key_size}CBC-HS{hash_type}' + tpl = 'AES_{}_CBC_HMAC_SHA_{} authenticated encryption algorithm' + self.description = tpl.format(key_size, hash_type) + + # bit length + self.key_size = key_size + # byte length + self.key_len = key_size // 8 + + self.CEK_SIZE = key_size * 2 + self.hash_alg = getattr(hashlib, f'sha{hash_type}') + + def _hmac(self, ciphertext, aad, iv, key): + al = encode_int(len(aad) * 8, 64) + msg = aad + iv + ciphertext + al + d = hmac.new(key, msg, self.hash_alg).digest() + return d[:self.key_len] + + def encrypt(self, msg, aad, iv, key): + """Key Encryption with AES_CBC_HMAC_SHA2. + + :param msg: text to be encrypt in bytes + :param aad: additional authenticated data in bytes + :param iv: initialization vector in bytes + :param key: encrypted key in bytes + :return: (ciphertext, iv, tag) + """ + self.check_iv(iv) + hkey = key[:self.key_len] + ekey = key[self.key_len:] + + pad = PKCS7(AES.block_size).padder() + padded_data = pad.update(msg) + pad.finalize() + + cipher = Cipher(AES(ekey), CBC(iv), backend=default_backend()) + enc = cipher.encryptor() + ciphertext = enc.update(padded_data) + enc.finalize() + tag = self._hmac(ciphertext, aad, iv, hkey) + return ciphertext, tag + + def decrypt(self, ciphertext, aad, iv, tag, key): + """Key Decryption with AES AES_CBC_HMAC_SHA2. + + :param ciphertext: ciphertext in bytes + :param aad: additional authenticated data in bytes + :param iv: initialization vector in bytes + :param tag: authentication tag in bytes + :param key: encrypted key in bytes + :return: message + """ + self.check_iv(iv) + hkey = key[:self.key_len] + dkey = key[self.key_len:] + + _tag = self._hmac(ciphertext, aad, iv, hkey) + if not hmac.compare_digest(_tag, tag): + raise InvalidTag() + + cipher = Cipher(AES(dkey), CBC(iv), backend=default_backend()) + d = cipher.decryptor() + data = d.update(ciphertext) + d.finalize() + unpad = PKCS7(AES.block_size).unpadder() + return unpad.update(data) + unpad.finalize() + + +class GCMEncAlgorithm(JWEEncAlgorithm): + # Use of an IV of size 96 bits is REQUIRED with this algorithm. + # https://tools.ietf.org/html/rfc7518#section-5.3 + IV_SIZE = 96 + + def __init__(self, key_size): + self.name = f'A{key_size}GCM' + self.description = f'AES GCM using {key_size}-bit key' + self.key_size = key_size + self.CEK_SIZE = key_size + + def encrypt(self, msg, aad, iv, key): + """Key Encryption with AES GCM + + :param msg: text to be encrypt in bytes + :param aad: additional authenticated data in bytes + :param iv: initialization vector in bytes + :param key: encrypted key in bytes + :return: (ciphertext, iv, tag) + """ + self.check_iv(iv) + cipher = Cipher(AES(key), GCM(iv), backend=default_backend()) + enc = cipher.encryptor() + enc.authenticate_additional_data(aad) + ciphertext = enc.update(msg) + enc.finalize() + return ciphertext, enc.tag + + def decrypt(self, ciphertext, aad, iv, tag, key): + """Key Decryption with AES GCM + + :param ciphertext: ciphertext in bytes + :param aad: additional authenticated data in bytes + :param iv: initialization vector in bytes + :param tag: authentication tag in bytes + :param key: encrypted key in bytes + :return: message + """ + self.check_iv(iv) + cipher = Cipher(AES(key), GCM(iv, tag), backend=default_backend()) + d = cipher.decryptor() + d.authenticate_additional_data(aad) + return d.update(ciphertext) + d.finalize() + + +JWE_ENC_ALGORITHMS = [ + CBCHS2EncAlgorithm(128, 256), # A128CBC-HS256 + CBCHS2EncAlgorithm(192, 384), # A192CBC-HS384 + CBCHS2EncAlgorithm(256, 512), # A256CBC-HS512 + GCMEncAlgorithm(128), # A128GCM + GCMEncAlgorithm(192), # A192GCM + GCMEncAlgorithm(256), # A256GCM +] diff --git a/.venv/Lib/site-packages/authlib/jose/rfc7518/jwe_zips.py b/.venv/Lib/site-packages/authlib/jose/rfc7518/jwe_zips.py new file mode 100644 index 00000000..23968610 --- /dev/null +++ b/.venv/Lib/site-packages/authlib/jose/rfc7518/jwe_zips.py @@ -0,0 +1,21 @@ +import zlib +from ..rfc7516 import JWEZipAlgorithm, JsonWebEncryption + + +class DeflateZipAlgorithm(JWEZipAlgorithm): + name = 'DEF' + description = 'DEFLATE' + + def compress(self, s): + """Compress bytes data with DEFLATE algorithm.""" + data = zlib.compress(s) + # drop gzip headers and tail + return data[2:-4] + + def decompress(self, s): + """Decompress DEFLATE bytes data.""" + return zlib.decompress(s, -zlib.MAX_WBITS) + + +def register_jwe_rfc7518(): + JsonWebEncryption.register_algorithm(DeflateZipAlgorithm()) diff --git a/.venv/Lib/site-packages/authlib/jose/rfc7518/jws_algs.py b/.venv/Lib/site-packages/authlib/jose/rfc7518/jws_algs.py new file mode 100644 index 00000000..2c028403 --- /dev/null +++ b/.venv/Lib/site-packages/authlib/jose/rfc7518/jws_algs.py @@ -0,0 +1,215 @@ +""" + authlib.jose.rfc7518 + ~~~~~~~~~~~~~~~~~~~~ + + "alg" (Algorithm) Header Parameter Values for JWS per `Section 3`_. + + .. _`Section 3`: https://tools.ietf.org/html/rfc7518#section-3 +""" + +import hmac +import hashlib +from cryptography.hazmat.primitives import hashes +from cryptography.hazmat.primitives.asymmetric.utils import ( + decode_dss_signature, encode_dss_signature +) +from cryptography.hazmat.primitives.asymmetric.ec import ECDSA +from cryptography.hazmat.primitives.asymmetric import padding +from cryptography.exceptions import InvalidSignature +from ..rfc7515 import JWSAlgorithm +from .oct_key import OctKey +from .rsa_key import RSAKey +from .ec_key import ECKey +from .util import encode_int, decode_int + + +class NoneAlgorithm(JWSAlgorithm): + name = 'none' + description = 'No digital signature or MAC performed' + + def prepare_key(self, raw_data): + return None + + def sign(self, msg, key): + return b'' + + def verify(self, msg, sig, key): + return False + + +class HMACAlgorithm(JWSAlgorithm): + """HMAC using SHA algorithms for JWS. Available algorithms: + + - HS256: HMAC using SHA-256 + - HS384: HMAC using SHA-384 + - HS512: HMAC using SHA-512 + """ + SHA256 = hashlib.sha256 + SHA384 = hashlib.sha384 + SHA512 = hashlib.sha512 + + def __init__(self, sha_type): + self.name = f'HS{sha_type}' + self.description = f'HMAC using SHA-{sha_type}' + self.hash_alg = getattr(self, f'SHA{sha_type}') + + def prepare_key(self, raw_data): + return OctKey.import_key(raw_data) + + def sign(self, msg, key): + # it is faster than the one in cryptography + op_key = key.get_op_key('sign') + return hmac.new(op_key, msg, self.hash_alg).digest() + + def verify(self, msg, sig, key): + op_key = key.get_op_key('verify') + v_sig = hmac.new(op_key, msg, self.hash_alg).digest() + return hmac.compare_digest(sig, v_sig) + + +class RSAAlgorithm(JWSAlgorithm): + """RSA using SHA algorithms for JWS. Available algorithms: + + - RS256: RSASSA-PKCS1-v1_5 using SHA-256 + - RS384: RSASSA-PKCS1-v1_5 using SHA-384 + - RS512: RSASSA-PKCS1-v1_5 using SHA-512 + """ + SHA256 = hashes.SHA256 + SHA384 = hashes.SHA384 + SHA512 = hashes.SHA512 + + def __init__(self, sha_type): + self.name = f'RS{sha_type}' + self.description = f'RSASSA-PKCS1-v1_5 using SHA-{sha_type}' + self.hash_alg = getattr(self, f'SHA{sha_type}') + self.padding = padding.PKCS1v15() + + def prepare_key(self, raw_data): + return RSAKey.import_key(raw_data) + + def sign(self, msg, key): + op_key = key.get_op_key('sign') + return op_key.sign(msg, self.padding, self.hash_alg()) + + def verify(self, msg, sig, key): + op_key = key.get_op_key('verify') + try: + op_key.verify(sig, msg, self.padding, self.hash_alg()) + return True + except InvalidSignature: + return False + + +class ECAlgorithm(JWSAlgorithm): + """ECDSA using SHA algorithms for JWS. Available algorithms: + + - ES256: ECDSA using P-256 and SHA-256 + - ES384: ECDSA using P-384 and SHA-384 + - ES512: ECDSA using P-521 and SHA-512 + """ + SHA256 = hashes.SHA256 + SHA384 = hashes.SHA384 + SHA512 = hashes.SHA512 + + def __init__(self, name, curve, sha_type): + self.name = name + self.curve = curve + self.description = f'ECDSA using {self.curve} and SHA-{sha_type}' + self.hash_alg = getattr(self, f'SHA{sha_type}') + + def prepare_key(self, raw_data): + key = ECKey.import_key(raw_data) + if key['crv'] != self.curve: + raise ValueError(f'Key for "{self.name}" not supported, only "{self.curve}" allowed') + return key + + def sign(self, msg, key): + op_key = key.get_op_key('sign') + der_sig = op_key.sign(msg, ECDSA(self.hash_alg())) + r, s = decode_dss_signature(der_sig) + size = key.curve_key_size + return encode_int(r, size) + encode_int(s, size) + + def verify(self, msg, sig, key): + key_size = key.curve_key_size + length = (key_size + 7) // 8 + + if len(sig) != 2 * length: + return False + + r = decode_int(sig[:length]) + s = decode_int(sig[length:]) + der_sig = encode_dss_signature(r, s) + + try: + op_key = key.get_op_key('verify') + op_key.verify(der_sig, msg, ECDSA(self.hash_alg())) + return True + except InvalidSignature: + return False + + +class RSAPSSAlgorithm(JWSAlgorithm): + """RSASSA-PSS using SHA algorithms for JWS. Available algorithms: + + - PS256: RSASSA-PSS using SHA-256 and MGF1 with SHA-256 + - PS384: RSASSA-PSS using SHA-384 and MGF1 with SHA-384 + - PS512: RSASSA-PSS using SHA-512 and MGF1 with SHA-512 + """ + SHA256 = hashes.SHA256 + SHA384 = hashes.SHA384 + SHA512 = hashes.SHA512 + + def __init__(self, sha_type): + self.name = f'PS{sha_type}' + tpl = 'RSASSA-PSS using SHA-{} and MGF1 with SHA-{}' + self.description = tpl.format(sha_type, sha_type) + self.hash_alg = getattr(self, f'SHA{sha_type}') + + def prepare_key(self, raw_data): + return RSAKey.import_key(raw_data) + + def sign(self, msg, key): + op_key = key.get_op_key('sign') + return op_key.sign( + msg, + padding.PSS( + mgf=padding.MGF1(self.hash_alg()), + salt_length=self.hash_alg.digest_size + ), + self.hash_alg() + ) + + def verify(self, msg, sig, key): + op_key = key.get_op_key('verify') + try: + op_key.verify( + sig, + msg, + padding.PSS( + mgf=padding.MGF1(self.hash_alg()), + salt_length=self.hash_alg.digest_size + ), + self.hash_alg() + ) + return True + except InvalidSignature: + return False + + +JWS_ALGORITHMS = [ + NoneAlgorithm(), # none + HMACAlgorithm(256), # HS256 + HMACAlgorithm(384), # HS384 + HMACAlgorithm(512), # HS512 + RSAAlgorithm(256), # RS256 + RSAAlgorithm(384), # RS384 + RSAAlgorithm(512), # RS512 + ECAlgorithm('ES256', 'P-256', 256), + ECAlgorithm('ES384', 'P-384', 384), + ECAlgorithm('ES512', 'P-521', 512), + ECAlgorithm('ES256K', 'secp256k1', 256), # defined in RFC8812 + RSAPSSAlgorithm(256), # PS256 + RSAPSSAlgorithm(384), # PS384 + RSAPSSAlgorithm(512), # PS512 +] diff --git a/.venv/Lib/site-packages/authlib/jose/rfc7518/oct_key.py b/.venv/Lib/site-packages/authlib/jose/rfc7518/oct_key.py new file mode 100644 index 00000000..1db321a7 --- /dev/null +++ b/.venv/Lib/site-packages/authlib/jose/rfc7518/oct_key.py @@ -0,0 +1,80 @@ +from authlib.common.encoding import ( + to_bytes, to_unicode, + urlsafe_b64encode, urlsafe_b64decode, +) +from authlib.common.security import generate_token +from ..rfc7517 import Key + + +class OctKey(Key): + """Key class of the ``oct`` key type.""" + + kty = 'oct' + REQUIRED_JSON_FIELDS = ['k'] + + def __init__(self, raw_key=None, options=None): + super().__init__(options) + self.raw_key = raw_key + + @property + def public_only(self): + return False + + def get_op_key(self, operation): + """Get the raw key for the given key_op. This method will also + check if the given key_op is supported by this key. + + :param operation: key operation value, such as "sign", "encrypt". + :return: raw key + """ + self.check_key_op(operation) + if not self.raw_key: + self.load_raw_key() + return self.raw_key + + def load_raw_key(self): + self.raw_key = urlsafe_b64decode(to_bytes(self.tokens['k'])) + + def load_dict_key(self): + k = to_unicode(urlsafe_b64encode(self.raw_key)) + self._dict_data = {'kty': self.kty, 'k': k} + + def as_dict(self, is_private=False, **params): + tokens = self.tokens + if 'kid' not in tokens: + tokens['kid'] = self.thumbprint() + + tokens.update(params) + return tokens + + @classmethod + def validate_raw_key(cls, key): + return isinstance(key, bytes) + + @classmethod + def import_key(cls, raw, options=None): + """Import a key from bytes, string, or dict data.""" + if isinstance(raw, cls): + if options is not None: + raw.options.update(options) + return raw + + if isinstance(raw, dict): + cls.check_required_fields(raw) + key = cls(options=options) + key._dict_data = raw + else: + raw_key = to_bytes(raw) + key = cls(raw_key=raw_key, options=options) + return key + + @classmethod + def generate_key(cls, key_size=256, options=None, is_private=True): + """Generate a ``OctKey`` with the given bit size.""" + if not is_private: + raise ValueError('oct key can not be generated as public') + + if key_size % 8 != 0: + raise ValueError('Invalid bit size for oct key') + + return cls.import_key(generate_token(key_size // 8), options) diff --git a/.venv/Lib/site-packages/authlib/jose/rfc7518/rsa_key.py b/.venv/Lib/site-packages/authlib/jose/rfc7518/rsa_key.py new file mode 100644 index 00000000..53bd9958 --- /dev/null +++ b/.venv/Lib/site-packages/authlib/jose/rfc7518/rsa_key.py @@ -0,0 +1,123 @@ +from cryptography.hazmat.primitives.asymmetric import rsa +from cryptography.hazmat.primitives.asymmetric.rsa import ( + RSAPublicKey, RSAPrivateKeyWithSerialization, + RSAPrivateNumbers, RSAPublicNumbers, + rsa_recover_prime_factors, rsa_crt_dmp1, rsa_crt_dmq1, rsa_crt_iqmp +) +from cryptography.hazmat.backends import default_backend +from authlib.common.encoding import base64_to_int, int_to_base64 +from ..rfc7517 import AsymmetricKey + + +class RSAKey(AsymmetricKey): + """Key class of the ``RSA`` key type.""" + + kty = 'RSA' + PUBLIC_KEY_CLS = RSAPublicKey + PRIVATE_KEY_CLS = RSAPrivateKeyWithSerialization + + PUBLIC_KEY_FIELDS = ['e', 'n'] + PRIVATE_KEY_FIELDS = ['d', 'dp', 'dq', 'e', 'n', 'p', 'q', 'qi'] + REQUIRED_JSON_FIELDS = ['e', 'n'] + SSH_PUBLIC_PREFIX = b'ssh-rsa' + + def dumps_private_key(self): + numbers = self.private_key.private_numbers() + return { + 'n': int_to_base64(numbers.public_numbers.n), + 'e': int_to_base64(numbers.public_numbers.e), + 'd': int_to_base64(numbers.d), + 'p': int_to_base64(numbers.p), + 'q': int_to_base64(numbers.q), + 'dp': int_to_base64(numbers.dmp1), + 'dq': int_to_base64(numbers.dmq1), + 'qi': int_to_base64(numbers.iqmp) + } + + def dumps_public_key(self): + numbers = self.public_key.public_numbers() + return { + 'n': int_to_base64(numbers.n), + 'e': int_to_base64(numbers.e) + } + + def load_private_key(self): + obj = self._dict_data + + if 'oth' in obj: # pragma: no cover + # https://tools.ietf.org/html/rfc7518#section-6.3.2.7 + raise ValueError('"oth" is not supported yet') + + public_numbers = RSAPublicNumbers( + base64_to_int(obj['e']), base64_to_int(obj['n'])) + + if has_all_prime_factors(obj): + numbers = RSAPrivateNumbers( + d=base64_to_int(obj['d']), + p=base64_to_int(obj['p']), + q=base64_to_int(obj['q']), + dmp1=base64_to_int(obj['dp']), + dmq1=base64_to_int(obj['dq']), + iqmp=base64_to_int(obj['qi']), + public_numbers=public_numbers) + else: + d = base64_to_int(obj['d']) + p, q = rsa_recover_prime_factors( + public_numbers.n, d, public_numbers.e) + numbers = RSAPrivateNumbers( + d=d, + p=p, + q=q, + dmp1=rsa_crt_dmp1(d, p), + dmq1=rsa_crt_dmq1(d, q), + iqmp=rsa_crt_iqmp(p, q), + public_numbers=public_numbers) + + return numbers.private_key(default_backend()) + + def load_public_key(self): + numbers = RSAPublicNumbers( + base64_to_int(self._dict_data['e']), + base64_to_int(self._dict_data['n']) + ) + return numbers.public_key(default_backend()) + + @classmethod + def generate_key(cls, key_size=2048, options=None, is_private=False) -> 'RSAKey': + if key_size < 512: + raise ValueError('key_size must not be less than 512') + if key_size % 8 != 0: + raise ValueError('Invalid key_size for RSAKey') + raw_key = rsa.generate_private_key( + public_exponent=65537, + key_size=key_size, + backend=default_backend(), + ) + if not is_private: + raw_key = raw_key.public_key() + return cls.import_key(raw_key, options=options) + + @classmethod + def import_dict_key(cls, raw, options=None): + cls.check_required_fields(raw) + key = cls(options=options) + key._dict_data = raw + if 'd' in raw and not has_all_prime_factors(raw): + # reload dict key + key.load_raw_key() + key.load_dict_key() + return key + + +def has_all_prime_factors(obj): + props = ['p', 'q', 'dp', 'dq', 'qi'] + props_found = [prop in obj for prop in props] + if all(props_found): + return True + + if any(props_found): + raise ValueError( + 'RSA key must include all parameters ' + 'if any are present besides d') + + return False diff --git a/.venv/Lib/site-packages/authlib/jose/rfc7518/util.py b/.venv/Lib/site-packages/authlib/jose/rfc7518/util.py new file mode 100644 index 00000000..d2d13ec1 --- /dev/null +++ b/.venv/Lib/site-packages/authlib/jose/rfc7518/util.py @@ -0,0 +1,12 @@ +import binascii + + +def encode_int(num, bits): + length = ((bits + 7) // 8) * 2 + padded_hex = '%0*x' % (length, num) + big_endian = binascii.a2b_hex(padded_hex.encode('ascii')) + return big_endian + + +def decode_int(b): + return int(binascii.b2a_hex(b), 16) diff --git a/.venv/Lib/site-packages/authlib/jose/rfc7519/__init__.py b/.venv/Lib/site-packages/authlib/jose/rfc7519/__init__.py new file mode 100644 index 00000000..5eea5b7f --- /dev/null +++ b/.venv/Lib/site-packages/authlib/jose/rfc7519/__init__.py @@ -0,0 +1,15 @@ +""" + authlib.jose.rfc7519 + ~~~~~~~~~~~~~~~~~~~~ + + This module represents a direct implementation of + JSON Web Token (JWT). + + https://tools.ietf.org/html/rfc7519 +""" + +from .jwt import JsonWebToken +from .claims import BaseClaims, JWTClaims + + +__all__ = ['JsonWebToken', 'BaseClaims', 'JWTClaims'] diff --git a/.venv/Lib/site-packages/authlib/jose/rfc7519/__pycache__/__init__.cpython-311.pyc b/.venv/Lib/site-packages/authlib/jose/rfc7519/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 00000000..1aff8b86 Binary files /dev/null and b/.venv/Lib/site-packages/authlib/jose/rfc7519/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/authlib/jose/rfc7519/__pycache__/claims.cpython-311.pyc b/.venv/Lib/site-packages/authlib/jose/rfc7519/__pycache__/claims.cpython-311.pyc new file mode 100644 index 00000000..26539704 Binary files /dev/null and b/.venv/Lib/site-packages/authlib/jose/rfc7519/__pycache__/claims.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/authlib/jose/rfc7519/__pycache__/jwt.cpython-311.pyc b/.venv/Lib/site-packages/authlib/jose/rfc7519/__pycache__/jwt.cpython-311.pyc new file mode 100644 index 00000000..e7779351 Binary files /dev/null and b/.venv/Lib/site-packages/authlib/jose/rfc7519/__pycache__/jwt.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/authlib/jose/rfc7519/claims.py b/.venv/Lib/site-packages/authlib/jose/rfc7519/claims.py new file mode 100644 index 00000000..6a9877bc --- /dev/null +++ b/.venv/Lib/site-packages/authlib/jose/rfc7519/claims.py @@ -0,0 +1,227 @@ +import time +from authlib.jose.errors import ( + MissingClaimError, + InvalidClaimError, + ExpiredTokenError, + InvalidTokenError, +) + + +class BaseClaims(dict): + """Payload claims for JWT, which contains a validate interface. + + :param payload: the payload dict of JWT + :param header: the header dict of JWT + :param options: validate options + :param params: other params + + An example on ``options`` parameter, the format is inspired by + `OpenID Connect Claims`_:: + + { + "iss": { + "essential": True, + "values": ["https://example.com", "https://example.org"] + }, + "sub": { + "essential": True + "value": "248289761001" + }, + "jti": { + "validate": validate_jti + } + } + + .. _`OpenID Connect Claims`: + http://openid.net/specs/openid-connect-core-1_0.html#IndividualClaimsRequests + """ + REGISTERED_CLAIMS = [] + + def __init__(self, payload, header, options=None, params=None): + super().__init__(payload) + self.header = header + self.options = options or {} + self.params = params or {} + + def __getattr__(self, key): + try: + return object.__getattribute__(self, key) + except AttributeError as error: + if key in self.REGISTERED_CLAIMS: + return self.get(key) + raise error + + def _validate_essential_claims(self): + for k in self.options: + if self.options[k].get('essential'): + if k not in self: + raise MissingClaimError(k) + elif not self.get(k): + raise InvalidClaimError(k) + + def _validate_claim_value(self, claim_name): + option = self.options.get(claim_name) + if not option: + return + + value = self.get(claim_name) + option_value = option.get('value') + if option_value and value != option_value: + raise InvalidClaimError(claim_name) + + option_values = option.get('values') + if option_values and value not in option_values: + raise InvalidClaimError(claim_name) + + validate = option.get('validate') + if validate and not validate(self, value): + raise InvalidClaimError(claim_name) + + def get_registered_claims(self): + rv = {} + for k in self.REGISTERED_CLAIMS: + if k in self: + rv[k] = self[k] + return rv + + +class JWTClaims(BaseClaims): + REGISTERED_CLAIMS = ['iss', 'sub', 'aud', 'exp', 'nbf', 'iat', 'jti'] + + def validate(self, now=None, leeway=0): + """Validate everything in claims payload.""" + self._validate_essential_claims() + + if now is None: + now = int(time.time()) + + self.validate_iss() + self.validate_sub() + self.validate_aud() + self.validate_exp(now, leeway) + self.validate_nbf(now, leeway) + self.validate_iat(now, leeway) + self.validate_jti() + + # Validate custom claims + for key in self.options.keys(): + if key not in self.REGISTERED_CLAIMS: + self._validate_claim_value(key) + + def validate_iss(self): + """The "iss" (issuer) claim identifies the principal that issued the + JWT. The processing of this claim is generally application specific. + The "iss" value is a case-sensitive string containing a StringOrURI + value. Use of this claim is OPTIONAL. + """ + self._validate_claim_value('iss') + + def validate_sub(self): + """The "sub" (subject) claim identifies the principal that is the + subject of the JWT. The claims in a JWT are normally statements + about the subject. The subject value MUST either be scoped to be + locally unique in the context of the issuer or be globally unique. + The processing of this claim is generally application specific. The + "sub" value is a case-sensitive string containing a StringOrURI + value. Use of this claim is OPTIONAL. + """ + self._validate_claim_value('sub') + + def validate_aud(self): + """The "aud" (audience) claim identifies the recipients that the JWT is + intended for. Each principal intended to process the JWT MUST + identify itself with a value in the audience claim. If the principal + processing the claim does not identify itself with a value in the + "aud" claim when this claim is present, then the JWT MUST be + rejected. In the general case, the "aud" value is an array of case- + sensitive strings, each containing a StringOrURI value. In the + special case when the JWT has one audience, the "aud" value MAY be a + single case-sensitive string containing a StringOrURI value. The + interpretation of audience values is generally application specific. + Use of this claim is OPTIONAL. + """ + aud_option = self.options.get('aud') + aud = self.get('aud') + if not aud_option or not aud: + return + + aud_values = aud_option.get('values') + if not aud_values: + aud_value = aud_option.get('value') + if aud_value: + aud_values = [aud_value] + + if not aud_values: + return + + if isinstance(self['aud'], list): + aud_list = self['aud'] + else: + aud_list = [self['aud']] + + if not any([v in aud_list for v in aud_values]): + raise InvalidClaimError('aud') + + def validate_exp(self, now, leeway): + """The "exp" (expiration time) claim identifies the expiration time on + or after which the JWT MUST NOT be accepted for processing. The + processing of the "exp" claim requires that the current date/time + MUST be before the expiration date/time listed in the "exp" claim. + Implementers MAY provide for some small leeway, usually no more than + a few minutes, to account for clock skew. Its value MUST be a number + containing a NumericDate value. Use of this claim is OPTIONAL. + """ + if 'exp' in self: + exp = self['exp'] + if not _validate_numeric_time(exp): + raise InvalidClaimError('exp') + if exp < (now - leeway): + raise ExpiredTokenError() + + def validate_nbf(self, now, leeway): + """The "nbf" (not before) claim identifies the time before which the JWT + MUST NOT be accepted for processing. The processing of the "nbf" + claim requires that the current date/time MUST be after or equal to + the not-before date/time listed in the "nbf" claim. Implementers MAY + provide for some small leeway, usually no more than a few minutes, to + account for clock skew. Its value MUST be a number containing a + NumericDate value. Use of this claim is OPTIONAL. + """ + if 'nbf' in self: + nbf = self['nbf'] + if not _validate_numeric_time(nbf): + raise InvalidClaimError('nbf') + if nbf > (now + leeway): + raise InvalidTokenError() + + def validate_iat(self, now, leeway): + """The "iat" (issued at) claim identifies the time at which the JWT was + issued. This claim can be used to determine the age of the JWT. + Implementers MAY provide for some small leeway, usually no more + than a few minutes, to account for clock skew. Its value MUST be a + number containing a NumericDate value. Use of this claim is OPTIONAL. + """ + if 'iat' in self: + iat = self['iat'] + if not _validate_numeric_time(iat): + raise InvalidClaimError('iat') + if iat > (now + leeway): + raise InvalidTokenError( + description='The token is not valid as it was issued in the future' + ) + + def validate_jti(self): + """The "jti" (JWT ID) claim provides a unique identifier for the JWT. + The identifier value MUST be assigned in a manner that ensures that + there is a negligible probability that the same value will be + accidentally assigned to a different data object; if the application + uses multiple issuers, collisions MUST be prevented among values + produced by different issuers as well. The "jti" claim can be used + to prevent the JWT from being replayed. The "jti" value is a case- + sensitive string. Use of this claim is OPTIONAL. + """ + self._validate_claim_value('jti') + + +def _validate_numeric_time(s): + return isinstance(s, (int, float)) diff --git a/.venv/Lib/site-packages/authlib/jose/rfc7519/jwt.py b/.venv/Lib/site-packages/authlib/jose/rfc7519/jwt.py new file mode 100644 index 00000000..ba27998b --- /dev/null +++ b/.venv/Lib/site-packages/authlib/jose/rfc7519/jwt.py @@ -0,0 +1,183 @@ +import re +import random +import datetime +import calendar +from authlib.common.encoding import ( + to_bytes, to_unicode, + json_loads, json_dumps, +) +from .claims import JWTClaims +from ..errors import DecodeError, InsecureClaimError +from ..rfc7515 import JsonWebSignature +from ..rfc7516 import JsonWebEncryption +from ..rfc7517 import KeySet, Key + + +class JsonWebToken: + SENSITIVE_NAMES = ('password', 'token', 'secret', 'secret_key') + # Thanks to sentry SensitiveDataFilter + SENSITIVE_VALUES = re.compile(r'|'.join([ + # http://www.richardsramblings.com/regex/credit-card-numbers/ + r'\b(?:3[47]\d|(?:4\d|5[1-5]|65)\d{2}|6011)\d{12}\b', + # various private keys + r'-----BEGIN[A-Z ]+PRIVATE KEY-----.+-----END[A-Z ]+PRIVATE KEY-----', + # social security numbers (US) + r'^\b(?!(000|666|9))\d{3}-(?!00)\d{2}-(?!0000)\d{4}\b', + ]), re.DOTALL) + + def __init__(self, algorithms, private_headers=None): + self._jws = JsonWebSignature(algorithms, private_headers=private_headers) + self._jwe = JsonWebEncryption(algorithms, private_headers=private_headers) + + def check_sensitive_data(self, payload): + """Check if payload contains sensitive information.""" + for k in payload: + # check claims key name + if k in self.SENSITIVE_NAMES: + raise InsecureClaimError(k) + + # check claims values + v = payload[k] + if isinstance(v, str) and self.SENSITIVE_VALUES.search(v): + raise InsecureClaimError(k) + + def encode(self, header, payload, key, check=True): + """Encode a JWT with the given header, payload and key. + + :param header: A dict of JWS header + :param payload: A dict to be encoded + :param key: key used to sign the signature + :param check: check if sensitive data in payload + :return: bytes + """ + header.setdefault('typ', 'JWT') + + for k in ['exp', 'iat', 'nbf']: + # convert datetime into timestamp + claim = payload.get(k) + if isinstance(claim, datetime.datetime): + payload[k] = calendar.timegm(claim.utctimetuple()) + + if check: + self.check_sensitive_data(payload) + + key = find_encode_key(key, header) + text = to_bytes(json_dumps(payload)) + if 'enc' in header: + return self._jwe.serialize_compact(header, text, key) + else: + return self._jws.serialize_compact(header, text, key) + + def decode(self, s, key, claims_cls=None, + claims_options=None, claims_params=None): + """Decode the JWT with the given key. This is similar with + :meth:`verify`, except that it will raise BadSignatureError when + signature doesn't match. + + :param s: text of JWT + :param key: key used to verify the signature + :param claims_cls: class to be used for JWT claims + :param claims_options: `options` parameters for claims_cls + :param claims_params: `params` parameters for claims_cls + :return: claims_cls instance + :raise: BadSignatureError + """ + if claims_cls is None: + claims_cls = JWTClaims + + if callable(key): + load_key = key + else: + load_key = create_load_key(prepare_raw_key(key)) + + s = to_bytes(s) + dot_count = s.count(b'.') + if dot_count == 2: + data = self._jws.deserialize_compact(s, load_key, decode_payload) + elif dot_count == 4: + data = self._jwe.deserialize_compact(s, load_key, decode_payload) + else: + raise DecodeError('Invalid input segments length') + return claims_cls( + data['payload'], data['header'], + options=claims_options, + params=claims_params, + ) + + +def decode_payload(bytes_payload): + try: + payload = json_loads(to_unicode(bytes_payload)) + except ValueError: + raise DecodeError('Invalid payload value') + if not isinstance(payload, dict): + raise DecodeError('Invalid payload type') + return payload + + +def prepare_raw_key(raw): + if isinstance(raw, KeySet): + return raw + + if isinstance(raw, str) and \ + raw.startswith('{') and raw.endswith('}'): + raw = json_loads(raw) + elif isinstance(raw, (tuple, list)): + raw = {'keys': raw} + return raw + + +def find_encode_key(key, header): + if isinstance(key, KeySet): + kid = header.get('kid') + if kid: + return key.find_by_kid(kid) + + rv = random.choice(key.keys) + # use side effect to add kid value into header + header['kid'] = rv.kid + return rv + + if isinstance(key, dict) and 'keys' in key: + keys = key['keys'] + kid = header.get('kid') + for k in keys: + if k.get('kid') == kid: + return k + + if not kid: + rv = random.choice(keys) + header['kid'] = rv['kid'] + return rv + raise ValueError('Invalid JSON Web Key Set') + + # append kid into header + if isinstance(key, dict) and 'kid' in key: + header['kid'] = key['kid'] + elif isinstance(key, Key) and key.kid: + header['kid'] = key.kid + return key + + +def create_load_key(key): + def load_key(header, payload): + if isinstance(key, KeySet): + return key.find_by_kid(header.get('kid')) + + if isinstance(key, dict) and 'keys' in key: + keys = key['keys'] + kid = header.get('kid') + + if kid is not None: + # look for the requested key + for k in keys: + if k.get('kid') == kid: + return k + else: + # use the only key + if len(keys) == 1: + return keys[0] + raise ValueError('Invalid JSON Web Key Set') + return key + + return load_key diff --git a/.venv/Lib/site-packages/authlib/jose/rfc8037/__init__.py b/.venv/Lib/site-packages/authlib/jose/rfc8037/__init__.py new file mode 100644 index 00000000..fd0f3fe4 --- /dev/null +++ b/.venv/Lib/site-packages/authlib/jose/rfc8037/__init__.py @@ -0,0 +1,5 @@ +from .okp_key import OKPKey +from .jws_eddsa import register_jws_rfc8037 + + +__all__ = ['register_jws_rfc8037', 'OKPKey'] diff --git a/.venv/Lib/site-packages/authlib/jose/rfc8037/__pycache__/__init__.cpython-311.pyc b/.venv/Lib/site-packages/authlib/jose/rfc8037/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 00000000..bd21d8d3 Binary files /dev/null and b/.venv/Lib/site-packages/authlib/jose/rfc8037/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/authlib/jose/rfc8037/__pycache__/jws_eddsa.cpython-311.pyc b/.venv/Lib/site-packages/authlib/jose/rfc8037/__pycache__/jws_eddsa.cpython-311.pyc new file mode 100644 index 00000000..24d23ab8 Binary files /dev/null and b/.venv/Lib/site-packages/authlib/jose/rfc8037/__pycache__/jws_eddsa.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/authlib/jose/rfc8037/__pycache__/okp_key.cpython-311.pyc b/.venv/Lib/site-packages/authlib/jose/rfc8037/__pycache__/okp_key.cpython-311.pyc new file mode 100644 index 00000000..2e2e2265 Binary files /dev/null and b/.venv/Lib/site-packages/authlib/jose/rfc8037/__pycache__/okp_key.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/authlib/jose/rfc8037/jws_eddsa.py b/.venv/Lib/site-packages/authlib/jose/rfc8037/jws_eddsa.py new file mode 100644 index 00000000..872da8e3 --- /dev/null +++ b/.venv/Lib/site-packages/authlib/jose/rfc8037/jws_eddsa.py @@ -0,0 +1,27 @@ +from cryptography.exceptions import InvalidSignature +from ..rfc7515 import JWSAlgorithm +from .okp_key import OKPKey + + +class EdDSAAlgorithm(JWSAlgorithm): + name = 'EdDSA' + description = 'Edwards-curve Digital Signature Algorithm for JWS' + + def prepare_key(self, raw_data): + return OKPKey.import_key(raw_data) + + def sign(self, msg, key): + op_key = key.get_op_key('sign') + return op_key.sign(msg) + + def verify(self, msg, sig, key): + op_key = key.get_op_key('verify') + try: + op_key.verify(sig, msg) + return True + except InvalidSignature: + return False + + +def register_jws_rfc8037(cls): + cls.register_algorithm(EdDSAAlgorithm()) diff --git a/.venv/Lib/site-packages/authlib/jose/rfc8037/okp_key.py b/.venv/Lib/site-packages/authlib/jose/rfc8037/okp_key.py new file mode 100644 index 00000000..40f74689 --- /dev/null +++ b/.venv/Lib/site-packages/authlib/jose/rfc8037/okp_key.py @@ -0,0 +1,103 @@ +from cryptography.hazmat.primitives.asymmetric.ed25519 import ( + Ed25519PublicKey, Ed25519PrivateKey +) +from cryptography.hazmat.primitives.asymmetric.ed448 import ( + Ed448PublicKey, Ed448PrivateKey +) +from cryptography.hazmat.primitives.asymmetric.x25519 import ( + X25519PublicKey, X25519PrivateKey +) +from cryptography.hazmat.primitives.asymmetric.x448 import ( + X448PublicKey, X448PrivateKey +) +from cryptography.hazmat.primitives.serialization import ( + Encoding, PublicFormat, PrivateFormat, NoEncryption +) +from authlib.common.encoding import ( + to_unicode, to_bytes, + urlsafe_b64decode, urlsafe_b64encode, +) +from ..rfc7517 import AsymmetricKey + + +PUBLIC_KEYS_MAP = { + 'Ed25519': Ed25519PublicKey, + 'Ed448': Ed448PublicKey, + 'X25519': X25519PublicKey, + 'X448': X448PublicKey, +} +PRIVATE_KEYS_MAP = { + 'Ed25519': Ed25519PrivateKey, + 'Ed448': Ed448PrivateKey, + 'X25519': X25519PrivateKey, + 'X448': X448PrivateKey, +} + + +class OKPKey(AsymmetricKey): + """Key class of the ``OKP`` key type.""" + + kty = 'OKP' + REQUIRED_JSON_FIELDS = ['crv', 'x'] + PUBLIC_KEY_FIELDS = REQUIRED_JSON_FIELDS + PRIVATE_KEY_FIELDS = ['crv', 'd'] + PUBLIC_KEY_CLS = tuple(PUBLIC_KEYS_MAP.values()) + PRIVATE_KEY_CLS = tuple(PRIVATE_KEYS_MAP.values()) + SSH_PUBLIC_PREFIX = b'ssh-ed25519' + + def exchange_shared_key(self, pubkey): + # used in ECDHESAlgorithm + private_key = self.get_private_key() + if private_key and isinstance(private_key, (X25519PrivateKey, X448PrivateKey)): + return private_key.exchange(pubkey) + raise ValueError('Invalid key for exchanging shared key') + + @staticmethod + def get_key_curve(key): + if isinstance(key, (Ed25519PublicKey, Ed25519PrivateKey)): + return 'Ed25519' + elif isinstance(key, (Ed448PublicKey, Ed448PrivateKey)): + return 'Ed448' + elif isinstance(key, (X25519PublicKey, X25519PrivateKey)): + return 'X25519' + elif isinstance(key, (X448PublicKey, X448PrivateKey)): + return 'X448' + + def load_private_key(self): + crv_key = PRIVATE_KEYS_MAP[self._dict_data['crv']] + d_bytes = urlsafe_b64decode(to_bytes(self._dict_data['d'])) + return crv_key.from_private_bytes(d_bytes) + + def load_public_key(self): + crv_key = PUBLIC_KEYS_MAP[self._dict_data['crv']] + x_bytes = urlsafe_b64decode(to_bytes(self._dict_data['x'])) + return crv_key.from_public_bytes(x_bytes) + + def dumps_private_key(self): + obj = self.dumps_public_key(self.private_key.public_key()) + d_bytes = self.private_key.private_bytes( + Encoding.Raw, + PrivateFormat.Raw, + NoEncryption() + ) + obj['d'] = to_unicode(urlsafe_b64encode(d_bytes)) + return obj + + def dumps_public_key(self, public_key=None): + if public_key is None: + public_key = self.public_key + x_bytes = public_key.public_bytes(Encoding.Raw, PublicFormat.Raw) + return { + 'crv': self.get_key_curve(public_key), + 'x': to_unicode(urlsafe_b64encode(x_bytes)), + } + + @classmethod + def generate_key(cls, crv='Ed25519', options=None, is_private=False) -> 'OKPKey': + if crv not in PRIVATE_KEYS_MAP: + raise ValueError(f'Invalid crv value: "{crv}"') + private_key_cls = PRIVATE_KEYS_MAP[crv] + raw_key = private_key_cls.generate() + if not is_private: + raw_key = raw_key.public_key() + return cls.import_key(raw_key, options=options) diff --git a/.venv/Lib/site-packages/authlib/jose/util.py b/.venv/Lib/site-packages/authlib/jose/util.py new file mode 100644 index 00000000..5b0c759f --- /dev/null +++ b/.venv/Lib/site-packages/authlib/jose/util.py @@ -0,0 +1,37 @@ +import binascii +from authlib.common.encoding import urlsafe_b64decode, json_loads, to_unicode +from authlib.jose.errors import DecodeError + + +def extract_header(header_segment, error_cls): + header_data = extract_segment(header_segment, error_cls, 'header') + + try: + header = json_loads(header_data.decode('utf-8')) + except ValueError as e: + raise error_cls(f'Invalid header string: {e}') + + if not isinstance(header, dict): + raise error_cls('Header must be a json object') + return header + + +def extract_segment(segment, error_cls, name='payload'): + try: + return urlsafe_b64decode(segment) + except (TypeError, binascii.Error): + msg = f'Invalid {name} padding' + raise error_cls(msg) + + +def ensure_dict(s, structure_name): + if not isinstance(s, dict): + try: + s = json_loads(to_unicode(s)) + except (ValueError, TypeError): + raise DecodeError(f'Invalid {structure_name}') + + if not isinstance(s, dict): + raise DecodeError(f'Invalid {structure_name}') + + return s diff --git a/.venv/Lib/site-packages/authlib/oauth1/__init__.py b/.venv/Lib/site-packages/authlib/oauth1/__init__.py new file mode 100644 index 00000000..c9a73ddf --- /dev/null +++ b/.venv/Lib/site-packages/authlib/oauth1/__init__.py @@ -0,0 +1,34 @@ +from .rfc5849 import ( + OAuth1Request, + ClientAuth, + SIGNATURE_HMAC_SHA1, + SIGNATURE_RSA_SHA1, + SIGNATURE_PLAINTEXT, + SIGNATURE_TYPE_HEADER, + SIGNATURE_TYPE_QUERY, + SIGNATURE_TYPE_BODY, + ClientMixin, + TemporaryCredentialMixin, + TokenCredentialMixin, + TemporaryCredential, + AuthorizationServer, + ResourceProtector, +) + +__all__ = [ + 'OAuth1Request', + 'ClientAuth', + 'SIGNATURE_HMAC_SHA1', + 'SIGNATURE_RSA_SHA1', + 'SIGNATURE_PLAINTEXT', + 'SIGNATURE_TYPE_HEADER', + 'SIGNATURE_TYPE_QUERY', + 'SIGNATURE_TYPE_BODY', + + 'ClientMixin', + 'TemporaryCredentialMixin', + 'TokenCredentialMixin', + 'TemporaryCredential', + 'AuthorizationServer', + 'ResourceProtector', +] diff --git a/.venv/Lib/site-packages/authlib/oauth1/__pycache__/__init__.cpython-311.pyc b/.venv/Lib/site-packages/authlib/oauth1/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 00000000..f80b60cc Binary files /dev/null and b/.venv/Lib/site-packages/authlib/oauth1/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/authlib/oauth1/__pycache__/client.cpython-311.pyc b/.venv/Lib/site-packages/authlib/oauth1/__pycache__/client.cpython-311.pyc new file mode 100644 index 00000000..6db5599e Binary files /dev/null and b/.venv/Lib/site-packages/authlib/oauth1/__pycache__/client.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/authlib/oauth1/__pycache__/errors.cpython-311.pyc b/.venv/Lib/site-packages/authlib/oauth1/__pycache__/errors.cpython-311.pyc new file mode 100644 index 00000000..a69bf043 Binary files /dev/null and b/.venv/Lib/site-packages/authlib/oauth1/__pycache__/errors.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/authlib/oauth1/client.py b/.venv/Lib/site-packages/authlib/oauth1/client.py new file mode 100644 index 00000000..1f74f321 --- /dev/null +++ b/.venv/Lib/site-packages/authlib/oauth1/client.py @@ -0,0 +1,172 @@ +from authlib.common.urls import ( + url_decode, + add_params_to_uri, + urlparse, +) +from authlib.common.encoding import json_loads +from .rfc5849 import ( + SIGNATURE_HMAC_SHA1, + SIGNATURE_TYPE_HEADER, + ClientAuth, +) + + +class OAuth1Client: + auth_class = ClientAuth + + def __init__(self, session, client_id, client_secret=None, + token=None, token_secret=None, + redirect_uri=None, rsa_key=None, verifier=None, + signature_method=SIGNATURE_HMAC_SHA1, + signature_type=SIGNATURE_TYPE_HEADER, + force_include_body=False, realm=None, **kwargs): + if not client_id: + raise ValueError('Missing "client_id"') + + self.session = session + self.auth = self.auth_class( + client_id, client_secret=client_secret, + token=token, token_secret=token_secret, + redirect_uri=redirect_uri, + signature_method=signature_method, + signature_type=signature_type, + rsa_key=rsa_key, + verifier=verifier, + realm=realm, + force_include_body=force_include_body + ) + self._kwargs = kwargs + + @property + def redirect_uri(self): + return self.auth.redirect_uri + + @redirect_uri.setter + def redirect_uri(self, uri): + self.auth.redirect_uri = uri + + @property + def token(self): + return dict( + oauth_token=self.auth.token, + oauth_token_secret=self.auth.token_secret, + oauth_verifier=self.auth.verifier + ) + + @token.setter + def token(self, token): + """This token setter is designed for an easy integration for + OAuthClient. Make sure both OAuth1Session and OAuth2Session + have token setters. + """ + if token is None: + self.auth.token = None + self.auth.token_secret = None + self.auth.verifier = None + elif 'oauth_token' in token: + self.auth.token = token['oauth_token'] + if 'oauth_token_secret' in token: + self.auth.token_secret = token['oauth_token_secret'] + if 'oauth_verifier' in token: + self.auth.verifier = token['oauth_verifier'] + else: + message = f'oauth_token is missing: {token!r}' + self.handle_error('missing_token', message) + + def create_authorization_url(self, url, request_token=None, **kwargs): + """Create an authorization URL by appending request_token and optional + kwargs to url. + + This is the second step in the OAuth 1 workflow. The user should be + redirected to this authorization URL, grant access to you, and then + be redirected back to you. The redirection back can either be specified + during client registration or by supplying a callback URI per request. + + :param url: The authorization endpoint URL. + :param request_token: The previously obtained request token. + :param kwargs: Optional parameters to append to the URL. + :returns: The authorization URL with new parameters embedded. + """ + kwargs['oauth_token'] = request_token or self.auth.token + if self.auth.redirect_uri: + kwargs['oauth_callback'] = self.auth.redirect_uri + return add_params_to_uri(url, kwargs.items()) + + def fetch_request_token(self, url, **kwargs): + """Method for fetching an access token from the token endpoint. + + This is the first step in the OAuth 1 workflow. A request token is + obtained by making a signed post request to url. The token is then + parsed from the application/x-www-form-urlencoded response and ready + to be used to construct an authorization url. + + :param url: Request Token endpoint. + :param kwargs: Extra parameters to include for fetching token. + :return: A Request Token dict. + """ + return self._fetch_token(url, **kwargs) + + def fetch_access_token(self, url, verifier=None, **kwargs): + """Method for fetching an access token from the token endpoint. + + This is the final step in the OAuth 1 workflow. An access token is + obtained using all previously obtained credentials, including the + verifier from the authorization step. + + :param url: Access Token endpoint. + :param verifier: A verifier string to prove authorization was granted. + :param kwargs: Extra parameters to include for fetching access token. + :return: A token dict. + """ + if verifier: + self.auth.verifier = verifier + if not self.auth.verifier: + self.handle_error('missing_verifier', 'Missing "verifier" value') + return self._fetch_token(url, **kwargs) + + def parse_authorization_response(self, url): + """Extract parameters from the post authorization redirect + response URL. + + :param url: The full URL that resulted from the user being redirected + back from the OAuth provider to you, the client. + :returns: A dict of parameters extracted from the URL. + """ + token = dict(url_decode(urlparse.urlparse(url).query)) + self.token = token + return token + + def _fetch_token(self, url, **kwargs): + resp = self.session.post(url, auth=self.auth, **kwargs) + token = self.parse_response_token(resp.status_code, resp.text) + self.token = token + self.auth.verifier = None + return token + + def parse_response_token(self, status_code, text): + if status_code >= 400: + message = ( + "Token request failed with code {}, " + "response was '{}'." + ).format(status_code, text) + self.handle_error('fetch_token_denied', message) + + try: + text = text.strip() + if text.startswith('{'): + token = json_loads(text) + else: + token = dict(url_decode(text)) + except (TypeError, ValueError) as e: + error = ( + "Unable to decode token from token response. " + "This is commonly caused by an unsuccessful request where" + " a non urlencoded error message is returned. " + "The decoding error was {}" + ).format(e) + raise ValueError(error) + return token + + @staticmethod + def handle_error(error_type, error_description): + raise ValueError(f'{error_type}: {error_description}') diff --git a/.venv/Lib/site-packages/authlib/oauth1/errors.py b/.venv/Lib/site-packages/authlib/oauth1/errors.py new file mode 100644 index 00000000..e7770da5 --- /dev/null +++ b/.venv/Lib/site-packages/authlib/oauth1/errors.py @@ -0,0 +1,3 @@ +# flake8: noqa + +from .rfc5849.errors import * diff --git a/.venv/Lib/site-packages/authlib/oauth1/rfc5849/__init__.py b/.venv/Lib/site-packages/authlib/oauth1/rfc5849/__init__.py new file mode 100644 index 00000000..1f029fbb --- /dev/null +++ b/.venv/Lib/site-packages/authlib/oauth1/rfc5849/__init__.py @@ -0,0 +1,45 @@ +""" + authlib.oauth1.rfc5849 + ~~~~~~~~~~~~~~~~~~~~~~ + + This module represents a direct implementation of The OAuth 1.0 Protocol. + + https://tools.ietf.org/html/rfc5849 +""" + +from .wrapper import OAuth1Request +from .client_auth import ClientAuth +from .signature import ( + SIGNATURE_HMAC_SHA1, + SIGNATURE_RSA_SHA1, + SIGNATURE_PLAINTEXT, + SIGNATURE_TYPE_HEADER, + SIGNATURE_TYPE_QUERY, + SIGNATURE_TYPE_BODY, +) +from .models import ( + ClientMixin, + TemporaryCredentialMixin, + TokenCredentialMixin, + TemporaryCredential, +) +from .authorization_server import AuthorizationServer +from .resource_protector import ResourceProtector + +__all__ = [ + 'OAuth1Request', + 'ClientAuth', + 'SIGNATURE_HMAC_SHA1', + 'SIGNATURE_RSA_SHA1', + 'SIGNATURE_PLAINTEXT', + 'SIGNATURE_TYPE_HEADER', + 'SIGNATURE_TYPE_QUERY', + 'SIGNATURE_TYPE_BODY', + + 'ClientMixin', + 'TemporaryCredentialMixin', + 'TokenCredentialMixin', + 'TemporaryCredential', + 'AuthorizationServer', + 'ResourceProtector', +] diff --git a/.venv/Lib/site-packages/authlib/oauth1/rfc5849/__pycache__/__init__.cpython-311.pyc b/.venv/Lib/site-packages/authlib/oauth1/rfc5849/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 00000000..62494829 Binary files /dev/null and b/.venv/Lib/site-packages/authlib/oauth1/rfc5849/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/authlib/oauth1/rfc5849/__pycache__/authorization_server.cpython-311.pyc b/.venv/Lib/site-packages/authlib/oauth1/rfc5849/__pycache__/authorization_server.cpython-311.pyc new file mode 100644 index 00000000..6c4e9142 Binary files /dev/null and b/.venv/Lib/site-packages/authlib/oauth1/rfc5849/__pycache__/authorization_server.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/authlib/oauth1/rfc5849/__pycache__/base_server.cpython-311.pyc b/.venv/Lib/site-packages/authlib/oauth1/rfc5849/__pycache__/base_server.cpython-311.pyc new file mode 100644 index 00000000..dc6f70bb Binary files /dev/null and b/.venv/Lib/site-packages/authlib/oauth1/rfc5849/__pycache__/base_server.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/authlib/oauth1/rfc5849/__pycache__/client_auth.cpython-311.pyc b/.venv/Lib/site-packages/authlib/oauth1/rfc5849/__pycache__/client_auth.cpython-311.pyc new file mode 100644 index 00000000..db9694a3 Binary files /dev/null and b/.venv/Lib/site-packages/authlib/oauth1/rfc5849/__pycache__/client_auth.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/authlib/oauth1/rfc5849/__pycache__/errors.cpython-311.pyc b/.venv/Lib/site-packages/authlib/oauth1/rfc5849/__pycache__/errors.cpython-311.pyc new file mode 100644 index 00000000..5453fa3b Binary files /dev/null and b/.venv/Lib/site-packages/authlib/oauth1/rfc5849/__pycache__/errors.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/authlib/oauth1/rfc5849/__pycache__/models.cpython-311.pyc b/.venv/Lib/site-packages/authlib/oauth1/rfc5849/__pycache__/models.cpython-311.pyc new file mode 100644 index 00000000..1a9b2889 Binary files /dev/null and b/.venv/Lib/site-packages/authlib/oauth1/rfc5849/__pycache__/models.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/authlib/oauth1/rfc5849/__pycache__/parameters.cpython-311.pyc b/.venv/Lib/site-packages/authlib/oauth1/rfc5849/__pycache__/parameters.cpython-311.pyc new file mode 100644 index 00000000..24f172f6 Binary files /dev/null and b/.venv/Lib/site-packages/authlib/oauth1/rfc5849/__pycache__/parameters.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/authlib/oauth1/rfc5849/__pycache__/resource_protector.cpython-311.pyc b/.venv/Lib/site-packages/authlib/oauth1/rfc5849/__pycache__/resource_protector.cpython-311.pyc new file mode 100644 index 00000000..1c01cfd5 Binary files /dev/null and b/.venv/Lib/site-packages/authlib/oauth1/rfc5849/__pycache__/resource_protector.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/authlib/oauth1/rfc5849/__pycache__/rsa.cpython-311.pyc b/.venv/Lib/site-packages/authlib/oauth1/rfc5849/__pycache__/rsa.cpython-311.pyc new file mode 100644 index 00000000..9ac0a18f Binary files /dev/null and b/.venv/Lib/site-packages/authlib/oauth1/rfc5849/__pycache__/rsa.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/authlib/oauth1/rfc5849/__pycache__/signature.cpython-311.pyc b/.venv/Lib/site-packages/authlib/oauth1/rfc5849/__pycache__/signature.cpython-311.pyc new file mode 100644 index 00000000..6ae075c6 Binary files /dev/null and b/.venv/Lib/site-packages/authlib/oauth1/rfc5849/__pycache__/signature.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/authlib/oauth1/rfc5849/__pycache__/util.cpython-311.pyc b/.venv/Lib/site-packages/authlib/oauth1/rfc5849/__pycache__/util.cpython-311.pyc new file mode 100644 index 00000000..5fb509f4 Binary files /dev/null and b/.venv/Lib/site-packages/authlib/oauth1/rfc5849/__pycache__/util.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/authlib/oauth1/rfc5849/__pycache__/wrapper.cpython-311.pyc b/.venv/Lib/site-packages/authlib/oauth1/rfc5849/__pycache__/wrapper.cpython-311.pyc new file mode 100644 index 00000000..df2afbb7 Binary files /dev/null and b/.venv/Lib/site-packages/authlib/oauth1/rfc5849/__pycache__/wrapper.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/authlib/oauth1/rfc5849/authorization_server.py b/.venv/Lib/site-packages/authlib/oauth1/rfc5849/authorization_server.py new file mode 100644 index 00000000..54cf7bab --- /dev/null +++ b/.venv/Lib/site-packages/authlib/oauth1/rfc5849/authorization_server.py @@ -0,0 +1,357 @@ +from authlib.common.urls import is_valid_url, add_params_to_uri +from .base_server import BaseServer +from .errors import ( + OAuth1Error, + InvalidRequestError, + MissingRequiredParameterError, + InvalidClientError, + InvalidTokenError, + AccessDeniedError, + MethodNotAllowedError, +) + + +class AuthorizationServer(BaseServer): + TOKEN_RESPONSE_HEADER = [ + ('Content-Type', 'application/x-www-form-urlencoded'), + ('Cache-Control', 'no-store'), + ('Pragma', 'no-cache'), + ] + + TEMPORARY_CREDENTIALS_METHOD = 'POST' + + def _get_client(self, request): + client = self.get_client_by_id(request.client_id) + request.client = client + return client + + def create_oauth1_request(self, request): + raise NotImplementedError() + + def handle_response(self, status_code, payload, headers): + raise NotImplementedError() + + def handle_error_response(self, error): + return self.handle_response( + error.status_code, + error.get_body(), + error.get_headers() + ) + + def validate_temporary_credentials_request(self, request): + """Validate HTTP request for temporary credentials.""" + + # The client obtains a set of temporary credentials from the server by + # making an authenticated (Section 3) HTTP "POST" request to the + # Temporary Credential Request endpoint (unless the server advertises + # another HTTP request method for the client to use). + if request.method.upper() != self.TEMPORARY_CREDENTIALS_METHOD: + raise MethodNotAllowedError() + + # REQUIRED parameter + if not request.client_id: + raise MissingRequiredParameterError('oauth_consumer_key') + + # REQUIRED parameter + oauth_callback = request.redirect_uri + if not request.redirect_uri: + raise MissingRequiredParameterError('oauth_callback') + + # An absolute URI or + # other means (the parameter value MUST be set to "oob" + if oauth_callback != 'oob' and not is_valid_url(oauth_callback): + raise InvalidRequestError('Invalid "oauth_callback" value') + + client = self._get_client(request) + if not client: + raise InvalidClientError() + + self.validate_timestamp_and_nonce(request) + self.validate_oauth_signature(request) + return request + + def create_temporary_credentials_response(self, request=None): + """Validate temporary credentials token request and create response + for temporary credentials token. Assume the endpoint of temporary + credentials request is ``https://photos.example.net/initiate``: + + .. code-block:: http + + POST /initiate HTTP/1.1 + Host: photos.example.net + Authorization: OAuth realm="Photos", + oauth_consumer_key="dpf43f3p2l4k3l03", + oauth_signature_method="HMAC-SHA1", + oauth_timestamp="137131200", + oauth_nonce="wIjqoS", + oauth_callback="http%3A%2F%2Fprinter.example.com%2Fready", + oauth_signature="74KNZJeDHnMBp0EMJ9ZHt%2FXKycU%3D" + + The server validates the request and replies with a set of temporary + credentials in the body of the HTTP response: + + .. code-block:: http + + HTTP/1.1 200 OK + Content-Type: application/x-www-form-urlencoded + + oauth_token=hh5s93j4hdidpola&oauth_token_secret=hdhd0244k9j7ao03& + oauth_callback_confirmed=true + + :param request: OAuth1Request instance. + :returns: (status_code, body, headers) + """ + try: + request = self.create_oauth1_request(request) + self.validate_temporary_credentials_request(request) + except OAuth1Error as error: + return self.handle_error_response(error) + + credential = self.create_temporary_credential(request) + payload = [ + ('oauth_token', credential.get_oauth_token()), + ('oauth_token_secret', credential.get_oauth_token_secret()), + ('oauth_callback_confirmed', True) + ] + return self.handle_response(200, payload, self.TOKEN_RESPONSE_HEADER) + + def validate_authorization_request(self, request): + """Validate the request for resource owner authorization.""" + if not request.token: + raise MissingRequiredParameterError('oauth_token') + + credential = self.get_temporary_credential(request) + if not credential: + raise InvalidTokenError() + + # assign credential for later use + request.credential = credential + return request + + def create_authorization_response(self, request, grant_user=None): + """Validate authorization request and create authorization response. + Assume the endpoint for authorization request is + ``https://photos.example.net/authorize``, the client redirects Jane's + user-agent to the server's Resource Owner Authorization endpoint to + obtain Jane's approval for accessing her private photos:: + + https://photos.example.net/authorize?oauth_token=hh5s93j4hdidpola + + The server requests Jane to sign in using her username and password + and if successful, asks her to approve granting 'printer.example.com' + access to her private photos. Jane approves the request and her + user-agent is redirected to the callback URI provided by the client + in the previous request (line breaks are for display purposes only):: + + http://printer.example.com/ready? + oauth_token=hh5s93j4hdidpola&oauth_verifier=hfdp7dh39dks9884 + + :param request: OAuth1Request instance. + :param grant_user: if granted, pass the grant user, otherwise None. + :returns: (status_code, body, headers) + """ + request = self.create_oauth1_request(request) + # authorize endpoint should try catch this error + self.validate_authorization_request(request) + + temporary_credentials = request.credential + redirect_uri = temporary_credentials.get_redirect_uri() + if not redirect_uri or redirect_uri == 'oob': + client_id = temporary_credentials.get_client_id() + client = self.get_client_by_id(client_id) + redirect_uri = client.get_default_redirect_uri() + + if grant_user is None: + error = AccessDeniedError() + location = add_params_to_uri(redirect_uri, error.get_body()) + return self.handle_response(302, '', [('Location', location)]) + + request.user = grant_user + verifier = self.create_authorization_verifier(request) + + params = [ + ('oauth_token', request.token), + ('oauth_verifier', verifier) + ] + location = add_params_to_uri(redirect_uri, params) + return self.handle_response(302, '', [('Location', location)]) + + def validate_token_request(self, request): + """Validate request for issuing token.""" + + if not request.client_id: + raise MissingRequiredParameterError('oauth_consumer_key') + + client = self._get_client(request) + if not client: + raise InvalidClientError() + + if not request.token: + raise MissingRequiredParameterError('oauth_token') + + token = self.get_temporary_credential(request) + if not token: + raise InvalidTokenError() + + verifier = request.oauth_params.get('oauth_verifier') + if not verifier: + raise MissingRequiredParameterError('oauth_verifier') + + if not token.check_verifier(verifier): + raise InvalidRequestError('Invalid "oauth_verifier"') + + request.credential = token + self.validate_timestamp_and_nonce(request) + self.validate_oauth_signature(request) + return request + + def create_token_response(self, request): + """Validate token request and create token response. Assuming the + endpoint of token request is ``https://photos.example.net/token``, + the callback request informs the client that Jane completed the + authorization process. The client then requests a set of token + credentials using its temporary credentials (over a secure Transport + Layer Security (TLS) channel): + + .. code-block:: http + + POST /token HTTP/1.1 + Host: photos.example.net + Authorization: OAuth realm="Photos", + oauth_consumer_key="dpf43f3p2l4k3l03", + oauth_token="hh5s93j4hdidpola", + oauth_signature_method="HMAC-SHA1", + oauth_timestamp="137131201", + oauth_nonce="walatlh", + oauth_verifier="hfdp7dh39dks9884", + oauth_signature="gKgrFCywp7rO0OXSjdot%2FIHF7IU%3D" + + The server validates the request and replies with a set of token + credentials in the body of the HTTP response: + + .. code-block:: http + + HTTP/1.1 200 OK + Content-Type: application/x-www-form-urlencoded + + oauth_token=nnch734d00sl2jdk&oauth_token_secret=pfkkdhi9sl3r4s00 + + :param request: OAuth1Request instance. + :returns: (status_code, body, headers) + """ + try: + request = self.create_oauth1_request(request) + except OAuth1Error as error: + return self.handle_error_response(error) + + try: + self.validate_token_request(request) + except OAuth1Error as error: + self.delete_temporary_credential(request) + return self.handle_error_response(error) + + credential = self.create_token_credential(request) + payload = [ + ('oauth_token', credential.get_oauth_token()), + ('oauth_token_secret', credential.get_oauth_token_secret()), + ] + self.delete_temporary_credential(request) + return self.handle_response(200, payload, self.TOKEN_RESPONSE_HEADER) + + def create_temporary_credential(self, request): + """Generate and save a temporary credential into database or cache. + A temporary credential is used for exchanging token credential. This + method should be re-implemented:: + + def create_temporary_credential(self, request): + oauth_token = generate_token(36) + oauth_token_secret = generate_token(48) + temporary_credential = TemporaryCredential( + oauth_token=oauth_token, + oauth_token_secret=oauth_token_secret, + client_id=request.client_id, + redirect_uri=request.redirect_uri, + ) + # if the credential has a save method + temporary_credential.save() + return temporary_credential + + :param request: OAuth1Request instance + :return: TemporaryCredential instance + """ + raise NotImplementedError() + + def get_temporary_credential(self, request): + """Get the temporary credential from database or cache. A temporary + credential should share the same methods as described in models of + ``TemporaryCredentialMixin``:: + + def get_temporary_credential(self, request): + key = 'a-key-prefix:{}'.format(request.token) + data = cache.get(key) + # TemporaryCredential shares methods from TemporaryCredentialMixin + return TemporaryCredential(data) + + :param request: OAuth1Request instance + :return: TemporaryCredential instance + """ + raise NotImplementedError() + + def delete_temporary_credential(self, request): + """Delete temporary credential from database or cache. For instance, + if temporary credential is saved in cache:: + + def delete_temporary_credential(self, request): + key = 'a-key-prefix:{}'.format(request.token) + cache.delete(key) + + :param request: OAuth1Request instance + """ + raise NotImplementedError() + + def create_authorization_verifier(self, request): + """Create and bind ``oauth_verifier`` to temporary credential. It + could be re-implemented in this way:: + + def create_authorization_verifier(self, request): + verifier = generate_token(36) + + temporary_credential = request.credential + user_id = request.user.id + + temporary_credential.user_id = user_id + temporary_credential.oauth_verifier = verifier + # if the credential has a save method + temporary_credential.save() + + # remember to return the verifier + return verifier + + :param request: OAuth1Request instance + :return: A string of ``oauth_verifier`` + """ + raise NotImplementedError() + + def create_token_credential(self, request): + """Create and save token credential into database. This method would + be re-implemented like this:: + + def create_token_credential(self, request): + oauth_token = generate_token(36) + oauth_token_secret = generate_token(48) + temporary_credential = request.credential + + token_credential = TokenCredential( + oauth_token=oauth_token, + oauth_token_secret=oauth_token_secret, + client_id=temporary_credential.get_client_id(), + user_id=temporary_credential.get_user_id() + ) + # if the credential has a save method + token_credential.save() + return token_credential + + :param request: OAuth1Request instance + :return: TokenCredential instance + """ + raise NotImplementedError() diff --git a/.venv/Lib/site-packages/authlib/oauth1/rfc5849/base_server.py b/.venv/Lib/site-packages/authlib/oauth1/rfc5849/base_server.py new file mode 100644 index 00000000..5d29deb9 --- /dev/null +++ b/.venv/Lib/site-packages/authlib/oauth1/rfc5849/base_server.py @@ -0,0 +1,119 @@ +import time +from .signature import ( + SIGNATURE_HMAC_SHA1, + SIGNATURE_PLAINTEXT, + SIGNATURE_RSA_SHA1, +) +from .signature import ( + verify_hmac_sha1, + verify_plaintext, + verify_rsa_sha1, +) +from .errors import ( + InvalidRequestError, + MissingRequiredParameterError, + UnsupportedSignatureMethodError, + InvalidNonceError, + InvalidSignatureError, +) + + +class BaseServer: + SIGNATURE_METHODS = { + SIGNATURE_HMAC_SHA1: verify_hmac_sha1, + SIGNATURE_RSA_SHA1: verify_rsa_sha1, + SIGNATURE_PLAINTEXT: verify_plaintext, + } + SUPPORTED_SIGNATURE_METHODS = [SIGNATURE_HMAC_SHA1] + EXPIRY_TIME = 300 + + @classmethod + def register_signature_method(cls, name, verify): + """Extend signature method verification. + + :param name: A string to represent signature method. + :param verify: A function to verify signature. + + The ``verify`` method accept ``OAuth1Request`` as parameter:: + + def verify_custom_method(request): + # verify this request, return True or False + return True + + Server.register_signature_method('custom-name', verify_custom_method) + """ + cls.SIGNATURE_METHODS[name] = verify + + def validate_timestamp_and_nonce(self, request): + """Validate ``oauth_timestamp`` and ``oauth_nonce`` in HTTP request. + + :param request: OAuth1Request instance + """ + timestamp = request.oauth_params.get('oauth_timestamp') + nonce = request.oauth_params.get('oauth_nonce') + + if request.signature_method == SIGNATURE_PLAINTEXT: + # The parameters MAY be omitted when using the "PLAINTEXT" + # signature method + if not timestamp and not nonce: + return + + if not timestamp: + raise MissingRequiredParameterError('oauth_timestamp') + + try: + # The timestamp value MUST be a positive integer + timestamp = int(timestamp) + if timestamp < 0: + raise InvalidRequestError('Invalid "oauth_timestamp" value') + + if self.EXPIRY_TIME and time.time() - timestamp > self.EXPIRY_TIME: + raise InvalidRequestError('Invalid "oauth_timestamp" value') + except (ValueError, TypeError): + raise InvalidRequestError('Invalid "oauth_timestamp" value') + + if not nonce: + raise MissingRequiredParameterError('oauth_nonce') + + if self.exists_nonce(nonce, request): + raise InvalidNonceError() + + def validate_oauth_signature(self, request): + """Validate ``oauth_signature`` from HTTP request. + + :param request: OAuth1Request instance + """ + method = request.signature_method + if not method: + raise MissingRequiredParameterError('oauth_signature_method') + + if method not in self.SUPPORTED_SIGNATURE_METHODS: + raise UnsupportedSignatureMethodError() + + if not request.signature: + raise MissingRequiredParameterError('oauth_signature') + + verify = self.SIGNATURE_METHODS.get(method) + if not verify: + raise UnsupportedSignatureMethodError() + + if not verify(request): + raise InvalidSignatureError() + + def get_client_by_id(self, client_id): + """Get client instance with the given ``client_id``. + + :param client_id: A string of client_id + :return: Client instance + """ + raise NotImplementedError() + + def exists_nonce(self, nonce, request): + """The nonce value MUST be unique across all requests with the same + timestamp, client credentials, and token combinations. + + :param nonce: A string value of ``oauth_nonce`` + :param request: OAuth1Request instance + :return: Boolean + """ + raise NotImplementedError() diff --git a/.venv/Lib/site-packages/authlib/oauth1/rfc5849/client_auth.py b/.venv/Lib/site-packages/authlib/oauth1/rfc5849/client_auth.py new file mode 100644 index 00000000..2c59b594 --- /dev/null +++ b/.venv/Lib/site-packages/authlib/oauth1/rfc5849/client_auth.py @@ -0,0 +1,187 @@ +import time +import base64 +import hashlib +from authlib.common.security import generate_token +from authlib.common.urls import extract_params +from authlib.common.encoding import to_native +from .wrapper import OAuth1Request +from .signature import ( + SIGNATURE_HMAC_SHA1, + SIGNATURE_PLAINTEXT, + SIGNATURE_RSA_SHA1, + SIGNATURE_TYPE_HEADER, + SIGNATURE_TYPE_BODY, + SIGNATURE_TYPE_QUERY, +) +from .signature import ( + sign_hmac_sha1, + sign_rsa_sha1, + sign_plaintext +) +from .parameters import ( + prepare_form_encoded_body, + prepare_headers, + prepare_request_uri_query, +) + + +CONTENT_TYPE_FORM_URLENCODED = 'application/x-www-form-urlencoded' +CONTENT_TYPE_MULTI_PART = 'multipart/form-data' + + +class ClientAuth: + SIGNATURE_METHODS = { + SIGNATURE_HMAC_SHA1: sign_hmac_sha1, + SIGNATURE_RSA_SHA1: sign_rsa_sha1, + SIGNATURE_PLAINTEXT: sign_plaintext, + } + + @classmethod + def register_signature_method(cls, name, sign): + """Extend client signature methods. + + :param name: A string to represent signature method. + :param sign: A function to generate signature. + + The ``sign`` method accept 2 parameters:: + + def custom_sign_method(client, request): + # client is the instance of Client. + return 'your-signed-string' + + Client.register_signature_method('custom-name', custom_sign_method) + """ + cls.SIGNATURE_METHODS[name] = sign + + def __init__(self, client_id, client_secret=None, + token=None, token_secret=None, + redirect_uri=None, rsa_key=None, verifier=None, + signature_method=SIGNATURE_HMAC_SHA1, + signature_type=SIGNATURE_TYPE_HEADER, + realm=None, force_include_body=False): + self.client_id = client_id + self.client_secret = client_secret + self.token = token + self.token_secret = token_secret + self.redirect_uri = redirect_uri + self.signature_method = signature_method + self.signature_type = signature_type + self.rsa_key = rsa_key + self.verifier = verifier + self.realm = realm + self.force_include_body = force_include_body + + def get_oauth_signature(self, method, uri, headers, body): + """Get an OAuth signature to be used in signing a request + + To satisfy `section 3.4.1.2`_ item 2, if the request argument's + headers dict attribute contains a Host item, its value will + replace any netloc part of the request argument's uri attribute + value. + + .. _`section 3.4.1.2`: https://tools.ietf.org/html/rfc5849#section-3.4.1.2 + """ + sign = self.SIGNATURE_METHODS.get(self.signature_method) + if not sign: + raise ValueError('Invalid signature method.') + + request = OAuth1Request(method, uri, body=body, headers=headers) + return sign(self, request) + + def get_oauth_params(self, nonce, timestamp): + oauth_params = [ + ('oauth_nonce', nonce), + ('oauth_timestamp', timestamp), + ('oauth_version', '1.0'), + ('oauth_signature_method', self.signature_method), + ('oauth_consumer_key', self.client_id), + ] + if self.token: + oauth_params.append(('oauth_token', self.token)) + if self.redirect_uri: + oauth_params.append(('oauth_callback', self.redirect_uri)) + if self.verifier: + oauth_params.append(('oauth_verifier', self.verifier)) + return oauth_params + + def _render(self, uri, headers, body, oauth_params): + if self.signature_type == SIGNATURE_TYPE_HEADER: + headers = prepare_headers(oauth_params, headers, realm=self.realm) + elif self.signature_type == SIGNATURE_TYPE_BODY: + if CONTENT_TYPE_FORM_URLENCODED in headers.get('Content-Type', ''): + decoded_body = extract_params(body) or [] + body = prepare_form_encoded_body(oauth_params, decoded_body) + headers['Content-Type'] = CONTENT_TYPE_FORM_URLENCODED + elif self.signature_type == SIGNATURE_TYPE_QUERY: + uri = prepare_request_uri_query(oauth_params, uri) + else: + raise ValueError('Unknown signature type specified.') + return uri, headers, body + + def sign(self, method, uri, headers, body): + """Sign the HTTP request, add OAuth parameters and signature. + + :param method: HTTP method of the request. + :param uri: URI of the HTTP request. + :param body: Body payload of the HTTP request. + :param headers: Headers of the HTTP request. + :return: uri, headers, body + """ + nonce = generate_nonce() + timestamp = generate_timestamp() + if body is None: + body = b'' + + # transform int to str + timestamp = str(timestamp) + + if headers is None: + headers = {} + + oauth_params = self.get_oauth_params(nonce, timestamp) + + # https://datatracker.ietf.org/doc/html/draft-eaton-oauth-bodyhash-00.html + # include oauth_body_hash + if body and headers.get('Content-Type') != CONTENT_TYPE_FORM_URLENCODED: + oauth_body_hash = base64.b64encode(hashlib.sha1(body).digest()) + oauth_params.append(('oauth_body_hash', oauth_body_hash.decode('utf-8'))) + + uri, headers, body = self._render(uri, headers, body, oauth_params) + + sig = self.get_oauth_signature(method, uri, headers, body) + oauth_params.append(('oauth_signature', sig)) + + uri, headers, body = self._render(uri, headers, body, oauth_params) + return uri, headers, body + + def prepare(self, method, uri, headers, body): + """Add OAuth parameters to the request. + + Parameters may be included from the body if the content-type is + urlencoded, if no content type is set, a guess is made. + """ + content_type = to_native(headers.get('Content-Type', '')) + if self.signature_type == SIGNATURE_TYPE_BODY: + content_type = CONTENT_TYPE_FORM_URLENCODED + elif not content_type and extract_params(body): + content_type = CONTENT_TYPE_FORM_URLENCODED + + if CONTENT_TYPE_FORM_URLENCODED in content_type: + headers['Content-Type'] = CONTENT_TYPE_FORM_URLENCODED + uri, headers, body = self.sign(method, uri, headers, body) + elif self.force_include_body: + # To allow custom clients to work on non form encoded bodies. + uri, headers, body = self.sign(method, uri, headers, body) + else: + # Omit body data in the signing of non form-encoded requests + uri, headers, _ = self.sign(method, uri, headers, b'') + body = b'' + return uri, headers, body + + +def generate_nonce(): + return generate_token() + + +def generate_timestamp(): + return str(int(time.time())) diff --git a/.venv/Lib/site-packages/authlib/oauth1/rfc5849/errors.py b/.venv/Lib/site-packages/authlib/oauth1/rfc5849/errors.py new file mode 100644 index 00000000..93396fce --- /dev/null +++ b/.venv/Lib/site-packages/authlib/oauth1/rfc5849/errors.py @@ -0,0 +1,89 @@ +""" + authlib.oauth1.rfc5849.errors + ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + RFC5849 has no definition on errors. This module is designed by + Authlib based on OAuth 1.0a `Section 10`_ with some changes. + + .. _`Section 10`: https://oauth.net/core/1.0a/#rfc.section.10 +""" +from authlib.common.errors import AuthlibHTTPError +from authlib.common.security import is_secure_transport + + +class OAuth1Error(AuthlibHTTPError): + def __init__(self, description=None, uri=None, status_code=None): + super().__init__(None, description, uri, status_code) + + def get_headers(self): + """Get a list of headers.""" + return [ + ('Content-Type', 'application/x-www-form-urlencoded'), + ('Cache-Control', 'no-store'), + ('Pragma', 'no-cache') + ] + + +class InsecureTransportError(OAuth1Error): + error = 'insecure_transport' + description = 'OAuth 2 MUST utilize https.' + + @classmethod + def check(cls, uri): + if not is_secure_transport(uri): + raise cls() + + +class InvalidRequestError(OAuth1Error): + error = 'invalid_request' + + +class UnsupportedParameterError(OAuth1Error): + error = 'unsupported_parameter' + + +class UnsupportedSignatureMethodError(OAuth1Error): + error = 'unsupported_signature_method' + + +class MissingRequiredParameterError(OAuth1Error): + error = 'missing_required_parameter' + + def __init__(self, key): + description = f'missing "{key}" in parameters' + super().__init__(description=description) + + +class DuplicatedOAuthProtocolParameterError(OAuth1Error): + error = 'duplicated_oauth_protocol_parameter' + + +class InvalidClientError(OAuth1Error): + error = 'invalid_client' + status_code = 401 + + +class InvalidTokenError(OAuth1Error): + error = 'invalid_token' + description = 'Invalid or expired "oauth_token" in parameters' + status_code = 401 + + +class InvalidSignatureError(OAuth1Error): + error = 'invalid_signature' + status_code = 401 + + +class InvalidNonceError(OAuth1Error): + error = 'invalid_nonce' + status_code = 401 + + +class AccessDeniedError(OAuth1Error): + error = 'access_denied' + description = 'The resource owner or authorization server denied the request' + + +class MethodNotAllowedError(OAuth1Error): + error = 'method_not_allowed' + status_code = 405 diff --git a/.venv/Lib/site-packages/authlib/oauth1/rfc5849/models.py b/.venv/Lib/site-packages/authlib/oauth1/rfc5849/models.py new file mode 100644 index 00000000..c9f3ea61 --- /dev/null +++ b/.venv/Lib/site-packages/authlib/oauth1/rfc5849/models.py @@ -0,0 +1,108 @@ +class ClientMixin: + def get_default_redirect_uri(self): + """A method to get client default redirect_uri. For instance, the + database table for client has a column called ``default_redirect_uri``:: + + def get_default_redirect_uri(self): + return self.default_redirect_uri + + :return: A URL string + """ + raise NotImplementedError() + + def get_client_secret(self): + """A method to return the client_secret of this client. For instance, + the database table has a column called ``client_secret``:: + + def get_client_secret(self): + return self.client_secret + """ + raise NotImplementedError() + + def get_rsa_public_key(self): + """A method to get the RSA public key for RSA-SHA1 signature method. + For instance, the value is saved on column ``rsa_public_key``:: + + def get_rsa_public_key(self): + return self.rsa_public_key + """ + raise NotImplementedError() + + +class TokenCredentialMixin: + def get_oauth_token(self): + """A method to get the value of ``oauth_token``. For instance, the + database table has a column called ``oauth_token``:: + + def get_oauth_token(self): + return self.oauth_token + + :return: A string + """ + raise NotImplementedError() + + def get_oauth_token_secret(self): + """A method to get the value of ``oauth_token_secret``. For instance, + the database table has a column called ``oauth_token_secret``:: + + def get_oauth_token_secret(self): + return self.oauth_token_secret + + :return: A string + """ + raise NotImplementedError() + + +class TemporaryCredentialMixin(TokenCredentialMixin): + def get_client_id(self): + """A method to get the client_id associated with this credential. + For instance, the table in the database has a column ``client_id``:: + + def get_client_id(self): + return self.client_id + """ + raise NotImplementedError() + + def get_redirect_uri(self): + """A method to get temporary credential's ``oauth_callback``. + For instance, the database table for temporary credential has a + column called ``oauth_callback``:: + + def get_redirect_uri(self): + return self.oauth_callback + + :return: A URL string + """ + raise NotImplementedError() + + def check_verifier(self, verifier): + """A method to check if the given verifier matches this temporary + credential. For instance that this temporary credential has recorded + the value in database as column ``oauth_verifier``:: + + def check_verifier(self, verifier): + return self.oauth_verifier == verifier + + :return: Boolean + """ + raise NotImplementedError() + + +class TemporaryCredential(dict, TemporaryCredentialMixin): + def get_client_id(self): + return self.get('client_id') + + def get_user_id(self): + return self.get('user_id') + + def get_redirect_uri(self): + return self.get('oauth_callback') + + def check_verifier(self, verifier): + return self.get('oauth_verifier') == verifier + + def get_oauth_token(self): + return self.get('oauth_token') + + def get_oauth_token_secret(self): + return self.get('oauth_token_secret') diff --git a/.venv/Lib/site-packages/authlib/oauth1/rfc5849/parameters.py b/.venv/Lib/site-packages/authlib/oauth1/rfc5849/parameters.py new file mode 100644 index 00000000..0e64e5c6 --- /dev/null +++ b/.venv/Lib/site-packages/authlib/oauth1/rfc5849/parameters.py @@ -0,0 +1,101 @@ +""" + authlib.spec.rfc5849.parameters + ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + This module contains methods related to `section 3.5`_ of the OAuth 1.0a spec. + + .. _`section 3.5`: https://tools.ietf.org/html/rfc5849#section-3.5 +""" +from authlib.common.urls import urlparse, url_encode, extract_params +from .util import escape + + +def prepare_headers(oauth_params, headers=None, realm=None): + """**Prepare the Authorization header.** + Per `section 3.5.1`_ of the spec. + + Protocol parameters can be transmitted using the HTTP "Authorization" + header field as defined by `RFC2617`_ with the auth-scheme name set to + "OAuth" (case insensitive). + + For example:: + + Authorization: OAuth realm="Photos", + oauth_consumer_key="dpf43f3p2l4k3l03", + oauth_signature_method="HMAC-SHA1", + oauth_timestamp="137131200", + oauth_nonce="wIjqoS", + oauth_callback="http%3A%2F%2Fprinter.example.com%2Fready", + oauth_signature="74KNZJeDHnMBp0EMJ9ZHt%2FXKycU%3D", + oauth_version="1.0" + + .. _`section 3.5.1`: https://tools.ietf.org/html/rfc5849#section-3.5.1 + .. _`RFC2617`: https://tools.ietf.org/html/rfc2617 + """ + headers = headers or {} + + # step 1, 2, 3 in Section 3.5.1 + header_parameters = ', '.join([ + f'{escape(k)}="{escape(v)}"' for k, v in oauth_params + if k.startswith('oauth_') + ]) + + # 4. The OPTIONAL "realm" parameter MAY be added and interpreted per + # `RFC2617 section 1.2`_. + # + # .. _`RFC2617 section 1.2`: https://tools.ietf.org/html/rfc2617#section-1.2 + if realm: + # NOTE: realm should *not* be escaped + header_parameters = f'realm="{realm}", ' + header_parameters + + # the auth-scheme name set to "OAuth" (case insensitive). + headers['Authorization'] = f'OAuth {header_parameters}' + return headers + + +def _append_params(oauth_params, params): + """Append OAuth params to an existing set of parameters. + + Both params and oauth_params is must be lists of 2-tuples. + + Per `section 3.5.2`_ and `3.5.3`_ of the spec. + + .. _`section 3.5.2`: https://tools.ietf.org/html/rfc5849#section-3.5.2 + .. _`3.5.3`: https://tools.ietf.org/html/rfc5849#section-3.5.3 + + """ + merged = list(params) + merged.extend(oauth_params) + # The request URI / entity-body MAY include other request-specific + # parameters, in which case, the protocol parameters SHOULD be appended + # following the request-specific parameters, properly separated by an "&" + # character (ASCII code 38) + merged.sort(key=lambda i: i[0].startswith('oauth_')) + return merged + + +def prepare_form_encoded_body(oauth_params, body): + """Prepare the Form-Encoded Body. + + Per `section 3.5.2`_ of the spec. + + .. _`section 3.5.2`: https://tools.ietf.org/html/rfc5849#section-3.5.2 + + """ + # append OAuth params to the existing body + return url_encode(_append_params(oauth_params, body)) + + +def prepare_request_uri_query(oauth_params, uri): + """Prepare the Request URI Query. + + Per `section 3.5.3`_ of the spec. + + .. _`section 3.5.3`: https://tools.ietf.org/html/rfc5849#section-3.5.3 + + """ + # append OAuth params to the existing set of query components + sch, net, path, par, query, fra = urlparse.urlparse(uri) + query = url_encode( + _append_params(oauth_params, extract_params(query) or [])) + return urlparse.urlunparse((sch, net, path, par, query, fra)) diff --git a/.venv/Lib/site-packages/authlib/oauth1/rfc5849/resource_protector.py b/.venv/Lib/site-packages/authlib/oauth1/rfc5849/resource_protector.py new file mode 100644 index 00000000..2b5d7819 --- /dev/null +++ b/.venv/Lib/site-packages/authlib/oauth1/rfc5849/resource_protector.py @@ -0,0 +1,41 @@ +from .base_server import BaseServer +from .wrapper import OAuth1Request +from .errors import ( + MissingRequiredParameterError, + InvalidClientError, + InvalidTokenError, +) + + +class ResourceProtector(BaseServer): + def validate_request(self, method, uri, body, headers): + request = OAuth1Request(method, uri, body, headers) + + if not request.client_id: + raise MissingRequiredParameterError('oauth_consumer_key') + + client = self.get_client_by_id(request.client_id) + if not client: + raise InvalidClientError() + request.client = client + + if not request.token: + raise MissingRequiredParameterError('oauth_token') + + token = self.get_token_credential(request) + if not token: + raise InvalidTokenError() + + request.credential = token + self.validate_timestamp_and_nonce(request) + self.validate_oauth_signature(request) + return request + + def get_token_credential(self, request): + """Fetch the token credential from data store like a database, + framework should implement this function. + + :param request: OAuth1Request instance + :return: Token model instance + """ + raise NotImplementedError() diff --git a/.venv/Lib/site-packages/authlib/oauth1/rfc5849/rsa.py b/.venv/Lib/site-packages/authlib/oauth1/rfc5849/rsa.py new file mode 100644 index 00000000..3785b0f7 --- /dev/null +++ b/.venv/Lib/site-packages/authlib/oauth1/rfc5849/rsa.py @@ -0,0 +1,29 @@ +from cryptography.hazmat.primitives import hashes +from cryptography.hazmat.backends import default_backend +from cryptography.hazmat.primitives.serialization import ( + load_pem_private_key, load_pem_public_key +) +from cryptography.hazmat.primitives.asymmetric import padding +from cryptography.exceptions import InvalidSignature +from authlib.common.encoding import to_bytes + + +def sign_sha1(msg, rsa_private_key): + key = load_pem_private_key( + to_bytes(rsa_private_key), + password=None, + backend=default_backend() + ) + return key.sign(msg, padding.PKCS1v15(), hashes.SHA1()) + + +def verify_sha1(sig, msg, rsa_public_key): + key = load_pem_public_key( + to_bytes(rsa_public_key), + backend=default_backend() + ) + try: + key.verify(sig, msg, padding.PKCS1v15(), hashes.SHA1()) + return True + except InvalidSignature: + return False diff --git a/.venv/Lib/site-packages/authlib/oauth1/rfc5849/signature.py b/.venv/Lib/site-packages/authlib/oauth1/rfc5849/signature.py new file mode 100644 index 00000000..bfb87fee --- /dev/null +++ b/.venv/Lib/site-packages/authlib/oauth1/rfc5849/signature.py @@ -0,0 +1,386 @@ +""" + authlib.oauth1.rfc5849.signature + ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + This module represents a direct implementation of `section 3.4`_ of the spec. + + .. _`section 3.4`: https://tools.ietf.org/html/rfc5849#section-3.4 +""" +import binascii +import hashlib +import hmac +from authlib.common.urls import urlparse +from authlib.common.encoding import to_unicode, to_bytes +from .util import escape, unescape + +SIGNATURE_HMAC_SHA1 = "HMAC-SHA1" +SIGNATURE_RSA_SHA1 = "RSA-SHA1" +SIGNATURE_PLAINTEXT = "PLAINTEXT" + +SIGNATURE_TYPE_HEADER = 'HEADER' +SIGNATURE_TYPE_QUERY = 'QUERY' +SIGNATURE_TYPE_BODY = 'BODY' + + +def construct_base_string(method, uri, params, host=None): + """Generate signature base string from request, per `Section 3.4.1`_. + + For example, the HTTP request:: + + POST /request?b5=%3D%253D&a3=a&c%40=&a2=r%20b HTTP/1.1 + Host: example.com + Content-Type: application/x-www-form-urlencoded + Authorization: OAuth realm="Example", + oauth_consumer_key="9djdj82h48djs9d2", + oauth_token="kkk9d7dh3k39sjv7", + oauth_signature_method="HMAC-SHA1", + oauth_timestamp="137131201", + oauth_nonce="7d8f3e4a", + oauth_signature="bYT5CMsGcbgUdFHObYMEfcx6bsw%3D" + + c2&a3=2+q + + is represented by the following signature base string (line breaks + are for display purposes only):: + + POST&http%3A%2F%2Fexample.com%2Frequest&a2%3Dr%2520b%26a3%3D2%2520q + %26a3%3Da%26b5%3D%253D%25253D%26c%2540%3D%26c2%3D%26oauth_consumer_ + key%3D9djdj82h48djs9d2%26oauth_nonce%3D7d8f3e4a%26oauth_signature_m + ethod%3DHMAC-SHA1%26oauth_timestamp%3D137131201%26oauth_token%3Dkkk + 9d7dh3k39sjv7 + + .. _`Section 3.4.1`: https://tools.ietf.org/html/rfc5849#section-3.4.1 + """ + + # Create base string URI per Section 3.4.1.2 + base_string_uri = normalize_base_string_uri(uri, host) + + # Cleanup parameter sources per Section 3.4.1.3.1 + unescaped_params = [] + for k, v in params: + # The "oauth_signature" parameter MUST be excluded from the signature + if k in ('oauth_signature', 'realm'): + continue + + # ensure oauth params are unescaped + if k.startswith('oauth_'): + v = unescape(v) + unescaped_params.append((k, v)) + + # Normalize parameters per Section 3.4.1.3.2 + normalized_params = normalize_parameters(unescaped_params) + + # construct base string + return '&'.join([ + escape(method.upper()), + escape(base_string_uri), + escape(normalized_params), + ]) + + +def normalize_base_string_uri(uri, host=None): + """Normalize Base String URI per `Section 3.4.1.2`_. + + For example, the HTTP request:: + + GET /r%20v/X?id=123 HTTP/1.1 + Host: EXAMPLE.COM:80 + + is represented by the base string URI: "http://example.com/r%20v/X". + + In another example, the HTTPS request:: + + GET /?q=1 HTTP/1.1 + Host: www.example.net:8080 + + is represented by the base string URI: "https://www.example.net:8080/". + + .. _`Section 3.4.1.2`: https://tools.ietf.org/html/rfc5849#section-3.4.1.2 + + The host argument overrides the netloc part of the uri argument. + """ + uri = to_unicode(uri) + scheme, netloc, path, params, query, fragment = urlparse.urlparse(uri) + + # The scheme, authority, and path of the request resource URI `RFC3986` + # are included by constructing an "http" or "https" URI representing + # the request resource (without the query or fragment) as follows: + # + # .. _`RFC3986`: https://tools.ietf.org/html/rfc3986 + + if not scheme or not netloc: + raise ValueError('uri must include a scheme and netloc') + + # Per `RFC 2616 section 5.1.2`_: + # + # Note that the absolute path cannot be empty; if none is present in + # the original URI, it MUST be given as "/" (the server root). + # + # .. _`RFC 2616 section 5.1.2`: https://tools.ietf.org/html/rfc2616#section-5.1.2 + if not path: + path = '/' + + # 1. The scheme and host MUST be in lowercase. + scheme = scheme.lower() + netloc = netloc.lower() + + # 2. The host and port values MUST match the content of the HTTP + # request "Host" header field. + if host is not None: + netloc = host.lower() + + # 3. The port MUST be included if it is not the default port for the + # scheme, and MUST be excluded if it is the default. Specifically, + # the port MUST be excluded when making an HTTP request `RFC2616`_ + # to port 80 or when making an HTTPS request `RFC2818`_ to port 443. + # All other non-default port numbers MUST be included. + # + # .. _`RFC2616`: https://tools.ietf.org/html/rfc2616 + # .. _`RFC2818`: https://tools.ietf.org/html/rfc2818 + default_ports = ( + ('http', '80'), + ('https', '443'), + ) + if ':' in netloc: + host, port = netloc.split(':', 1) + if (scheme, port) in default_ports: + netloc = host + + return urlparse.urlunparse((scheme, netloc, path, params, '', '')) + + +def normalize_parameters(params): + """Normalize parameters per `Section 3.4.1.3.2`_. + + For example, the list of parameters from the previous section would + be normalized as follows: + + Encoded:: + + +------------------------+------------------+ + | Name | Value | + +------------------------+------------------+ + | b5 | %3D%253D | + | a3 | a | + | c%40 | | + | a2 | r%20b | + | oauth_consumer_key | 9djdj82h48djs9d2 | + | oauth_token | kkk9d7dh3k39sjv7 | + | oauth_signature_method | HMAC-SHA1 | + | oauth_timestamp | 137131201 | + | oauth_nonce | 7d8f3e4a | + | c2 | | + | a3 | 2%20q | + +------------------------+------------------+ + + Sorted:: + + +------------------------+------------------+ + | Name | Value | + +------------------------+------------------+ + | a2 | r%20b | + | a3 | 2%20q | + | a3 | a | + | b5 | %3D%253D | + | c%40 | | + | c2 | | + | oauth_consumer_key | 9djdj82h48djs9d2 | + | oauth_nonce | 7d8f3e4a | + | oauth_signature_method | HMAC-SHA1 | + | oauth_timestamp | 137131201 | + | oauth_token | kkk9d7dh3k39sjv7 | + +------------------------+------------------+ + + Concatenated Pairs:: + + +-------------------------------------+ + | Name=Value | + +-------------------------------------+ + | a2=r%20b | + | a3=2%20q | + | a3=a | + | b5=%3D%253D | + | c%40= | + | c2= | + | oauth_consumer_key=9djdj82h48djs9d2 | + | oauth_nonce=7d8f3e4a | + | oauth_signature_method=HMAC-SHA1 | + | oauth_timestamp=137131201 | + | oauth_token=kkk9d7dh3k39sjv7 | + +-------------------------------------+ + + and concatenated together into a single string (line breaks are for + display purposes only):: + + a2=r%20b&a3=2%20q&a3=a&b5=%3D%253D&c%40=&c2=&oauth_consumer_key=9dj + dj82h48djs9d2&oauth_nonce=7d8f3e4a&oauth_signature_method=HMAC-SHA1 + &oauth_timestamp=137131201&oauth_token=kkk9d7dh3k39sjv7 + + .. _`Section 3.4.1.3.2`: https://tools.ietf.org/html/rfc5849#section-3.4.1.3.2 + """ + + # 1. First, the name and value of each parameter are encoded + # (`Section 3.6`_). + # + # .. _`Section 3.6`: https://tools.ietf.org/html/rfc5849#section-3.6 + key_values = [(escape(k), escape(v)) for k, v in params] + + # 2. The parameters are sorted by name, using ascending byte value + # ordering. If two or more parameters share the same name, they + # are sorted by their value. + key_values.sort() + + # 3. The name of each parameter is concatenated to its corresponding + # value using an "=" character (ASCII code 61) as a separator, even + # if the value is empty. + parameter_parts = [f'{k}={v}' for k, v in key_values] + + # 4. The sorted name/value pairs are concatenated together into a + # single string by using an "&" character (ASCII code 38) as + # separator. + return '&'.join(parameter_parts) + + +def generate_signature_base_string(request): + """Generate signature base string from request.""" + host = request.headers.get('Host', None) + return construct_base_string( + request.method, request.uri, request.params, host) + + +def hmac_sha1_signature(base_string, client_secret, token_secret): + """Generate signature via HMAC-SHA1 method, per `Section 3.4.2`_. + + The "HMAC-SHA1" signature method uses the HMAC-SHA1 signature + algorithm as defined in `RFC2104`_:: + + digest = HMAC-SHA1 (key, text) + + .. _`RFC2104`: https://tools.ietf.org/html/rfc2104 + .. _`Section 3.4.2`: https://tools.ietf.org/html/rfc5849#section-3.4.2 + """ + + # The HMAC-SHA1 function variables are used in following way: + + # text is set to the value of the signature base string from + # `Section 3.4.1.1`_. + # + # .. _`Section 3.4.1.1`: https://tools.ietf.org/html/rfc5849#section-3.4.1.1 + text = base_string + + # key is set to the concatenated values of: + # 1. The client shared-secret, after being encoded (`Section 3.6`_). + # + # .. _`Section 3.6`: https://tools.ietf.org/html/rfc5849#section-3.6 + key = escape(client_secret or '') + + # 2. An "&" character (ASCII code 38), which MUST be included + # even when either secret is empty. + key += '&' + + # 3. The token shared-secret, after being encoded (`Section 3.6`_). + # + # .. _`Section 3.6`: https://tools.ietf.org/html/rfc5849#section-3.6 + key += escape(token_secret or '') + + signature = hmac.new(to_bytes(key), to_bytes(text), hashlib.sha1) + + # digest is used to set the value of the "oauth_signature" protocol + # parameter, after the result octet string is base64-encoded + # per `RFC2045, Section 6.8`. + # + # .. _`RFC2045, Section 6.8`: https://tools.ietf.org/html/rfc2045#section-6.8 + sig = binascii.b2a_base64(signature.digest())[:-1] + return to_unicode(sig) + + +def rsa_sha1_signature(base_string, rsa_private_key): + """Generate signature via RSA-SHA1 method, per `Section 3.4.3`_. + + The "RSA-SHA1" signature method uses the RSASSA-PKCS1-v1_5 signature + algorithm as defined in `RFC3447, Section 8.2`_ (also known as + PKCS#1), using SHA-1 as the hash function for EMSA-PKCS1-v1_5. To + use this method, the client MUST have established client credentials + with the server that included its RSA public key (in a manner that is + beyond the scope of this specification). + + .. _`Section 3.4.3`: https://tools.ietf.org/html/rfc5849#section-3.4.3 + .. _`RFC3447, Section 8.2`: https://tools.ietf.org/html/rfc3447#section-8.2 + """ + from .rsa import sign_sha1 + base_string = to_bytes(base_string) + s = sign_sha1(to_bytes(base_string), rsa_private_key) + sig = binascii.b2a_base64(s)[:-1] + return to_unicode(sig) + + +def plaintext_signature(client_secret, token_secret): + """Generate signature via PLAINTEXT method, per `Section 3.4.4`_. + + The "PLAINTEXT" method does not employ a signature algorithm. It + MUST be used with a transport-layer mechanism such as TLS or SSL (or + sent over a secure channel with equivalent protections). It does not + utilize the signature base string or the "oauth_timestamp" and + "oauth_nonce" parameters. + + .. _`Section 3.4.4`: https://tools.ietf.org/html/rfc5849#section-3.4.4 + """ + + # The "oauth_signature" protocol parameter is set to the concatenated + # value of: + + # 1. The client shared-secret, after being encoded (`Section 3.6`_). + # + # .. _`Section 3.6`: https://tools.ietf.org/html/rfc5849#section-3.6 + signature = escape(client_secret or '') + + # 2. An "&" character (ASCII code 38), which MUST be included even + # when either secret is empty. + signature += '&' + + # 3. The token shared-secret, after being encoded (`Section 3.6`_). + # + # .. _`Section 3.6`: https://tools.ietf.org/html/rfc5849#section-3.6 + signature += escape(token_secret or '') + + return signature + + +def sign_hmac_sha1(client, request): + """Sign a HMAC-SHA1 signature.""" + base_string = generate_signature_base_string(request) + return hmac_sha1_signature( + base_string, client.client_secret, client.token_secret) + + +def sign_rsa_sha1(client, request): + """Sign a RSASSA-PKCS #1 v1.5 base64 encoded signature.""" + base_string = generate_signature_base_string(request) + return rsa_sha1_signature(base_string, client.rsa_key) + + +def sign_plaintext(client, request): + """Sign a PLAINTEXT signature.""" + return plaintext_signature(client.client_secret, client.token_secret) + + +def verify_hmac_sha1(request): + """Verify a HMAC-SHA1 signature.""" + base_string = generate_signature_base_string(request) + sig = hmac_sha1_signature( + base_string, request.client_secret, request.token_secret) + return hmac.compare_digest(sig, request.signature) + + +def verify_rsa_sha1(request): + """Verify a RSASSA-PKCS #1 v1.5 base64 encoded signature.""" + from .rsa import verify_sha1 + base_string = generate_signature_base_string(request) + sig = binascii.a2b_base64(to_bytes(request.signature)) + return verify_sha1(sig, to_bytes(base_string), request.rsa_public_key) + + +def verify_plaintext(request): + """Verify a PLAINTEXT signature.""" + sig = plaintext_signature(request.client_secret, request.token_secret) + return hmac.compare_digest(sig, request.signature) diff --git a/.venv/Lib/site-packages/authlib/oauth1/rfc5849/util.py b/.venv/Lib/site-packages/authlib/oauth1/rfc5849/util.py new file mode 100644 index 00000000..9383e22e --- /dev/null +++ b/.venv/Lib/site-packages/authlib/oauth1/rfc5849/util.py @@ -0,0 +1,9 @@ +from authlib.common.urls import quote, unquote + + +def escape(s): + return quote(s, safe=b'~') + + +def unescape(s): + return unquote(s) diff --git a/.venv/Lib/site-packages/authlib/oauth1/rfc5849/wrapper.py b/.venv/Lib/site-packages/authlib/oauth1/rfc5849/wrapper.py new file mode 100644 index 00000000..c03687ed --- /dev/null +++ b/.venv/Lib/site-packages/authlib/oauth1/rfc5849/wrapper.py @@ -0,0 +1,129 @@ +from urllib.request import parse_keqv_list, parse_http_list +from authlib.common.urls import ( + urlparse, extract_params, url_decode, +) +from .signature import ( + SIGNATURE_TYPE_QUERY, + SIGNATURE_TYPE_BODY, + SIGNATURE_TYPE_HEADER +) +from .errors import ( + InsecureTransportError, + DuplicatedOAuthProtocolParameterError +) +from .util import unescape + + +class OAuth1Request: + def __init__(self, method, uri, body=None, headers=None): + InsecureTransportError.check(uri) + self.method = method + self.uri = uri + self.body = body + self.headers = headers or {} + + # states namespaces + self.client = None + self.credential = None + self.user = None + + self.query = urlparse.urlparse(uri).query + self.query_params = url_decode(self.query) + self.body_params = extract_params(body) or [] + + self.auth_params, self.realm = _parse_authorization_header(headers) + self.signature_type, self.oauth_params = _parse_oauth_params( + self.query_params, self.body_params, self.auth_params) + + params = [] + params.extend(self.query_params) + params.extend(self.body_params) + params.extend(self.auth_params) + self.params = params + + @property + def client_id(self): + return self.oauth_params.get('oauth_consumer_key') + + @property + def client_secret(self): + if self.client: + return self.client.get_client_secret() + + @property + def rsa_public_key(self): + if self.client: + return self.client.get_rsa_public_key() + + @property + def timestamp(self): + return self.oauth_params.get('oauth_timestamp') + + @property + def redirect_uri(self): + return self.oauth_params.get('oauth_callback') + + @property + def signature(self): + return self.oauth_params.get('oauth_signature') + + @property + def signature_method(self): + return self.oauth_params.get('oauth_signature_method') + + @property + def token(self): + return self.oauth_params.get('oauth_token') + + @property + def token_secret(self): + if self.credential: + return self.credential.get_oauth_token_secret() + + +def _filter_oauth(params): + for k, v in params: + if k.startswith('oauth_'): + yield (k, v) + + +def _parse_authorization_header(headers): + """Parse an OAuth authorization header into a list of 2-tuples""" + authorization_header = headers.get('Authorization') + if not authorization_header: + return [], None + + auth_scheme = 'oauth ' + if authorization_header.lower().startswith(auth_scheme): + items = parse_http_list(authorization_header[len(auth_scheme):]) + try: + items = parse_keqv_list(items).items() + auth_params = [(unescape(k), unescape(v)) for k, v in items] + realm = dict(auth_params).get('realm') + return auth_params, realm + except (IndexError, ValueError): + pass + raise ValueError('Malformed authorization header') + + +def _parse_oauth_params(query_params, body_params, auth_params): + oauth_params_set = [ + (SIGNATURE_TYPE_QUERY, list(_filter_oauth(query_params))), + (SIGNATURE_TYPE_BODY, list(_filter_oauth(body_params))), + (SIGNATURE_TYPE_HEADER, list(_filter_oauth(auth_params))) + ] + oauth_params_set = [params for params in oauth_params_set if params[1]] + if len(oauth_params_set) > 1: + found_types = [p[0] for p in oauth_params_set] + raise DuplicatedOAuthProtocolParameterError( + '"oauth_" params must come from only 1 signature type ' + 'but were found in {}'.format(','.join(found_types)) + ) + + if oauth_params_set: + signature_type = oauth_params_set[0][0] + oauth_params = dict(oauth_params_set[0][1]) + else: + signature_type = None + oauth_params = {} + return signature_type, oauth_params diff --git a/.venv/Lib/site-packages/authlib/oauth2/__init__.py b/.venv/Lib/site-packages/authlib/oauth2/__init__.py new file mode 100644 index 00000000..05fdf30b --- /dev/null +++ b/.venv/Lib/site-packages/authlib/oauth2/__init__.py @@ -0,0 +1,16 @@ +from .base import OAuth2Error +from .auth import ClientAuth, TokenAuth +from .client import OAuth2Client +from .rfc6749 import ( + OAuth2Request, + JsonRequest, + AuthorizationServer, + ClientAuthentication, + ResourceProtector, +) + +__all__ = [ + 'OAuth2Error', 'ClientAuth', 'TokenAuth', 'OAuth2Client', + 'OAuth2Request', 'JsonRequest', 'AuthorizationServer', + 'ClientAuthentication', 'ResourceProtector', +] diff --git a/.venv/Lib/site-packages/authlib/oauth2/__pycache__/__init__.cpython-311.pyc b/.venv/Lib/site-packages/authlib/oauth2/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 00000000..b0f325d2 Binary files /dev/null and b/.venv/Lib/site-packages/authlib/oauth2/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/authlib/oauth2/__pycache__/auth.cpython-311.pyc b/.venv/Lib/site-packages/authlib/oauth2/__pycache__/auth.cpython-311.pyc new file mode 100644 index 00000000..94dd5738 Binary files /dev/null and b/.venv/Lib/site-packages/authlib/oauth2/__pycache__/auth.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/authlib/oauth2/__pycache__/base.cpython-311.pyc b/.venv/Lib/site-packages/authlib/oauth2/__pycache__/base.cpython-311.pyc new file mode 100644 index 00000000..d3713309 Binary files /dev/null and b/.venv/Lib/site-packages/authlib/oauth2/__pycache__/base.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/authlib/oauth2/__pycache__/client.cpython-311.pyc b/.venv/Lib/site-packages/authlib/oauth2/__pycache__/client.cpython-311.pyc new file mode 100644 index 00000000..ca39b199 Binary files /dev/null and b/.venv/Lib/site-packages/authlib/oauth2/__pycache__/client.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/authlib/oauth2/auth.py b/.venv/Lib/site-packages/authlib/oauth2/auth.py new file mode 100644 index 00000000..e4ad1804 --- /dev/null +++ b/.venv/Lib/site-packages/authlib/oauth2/auth.py @@ -0,0 +1,106 @@ +import base64 +from urllib.parse import quote +from authlib.common.urls import add_params_to_qs, add_params_to_uri +from authlib.common.encoding import to_bytes, to_native +from .rfc6749 import OAuth2Token +from .rfc6750 import add_bearer_token + + +def encode_client_secret_basic(client, method, uri, headers, body): + text = f'{quote(client.client_id)}:{quote(client.client_secret)}' + auth = to_native(base64.b64encode(to_bytes(text, 'latin1'))) + headers['Authorization'] = f'Basic {auth}' + return uri, headers, body + + +def encode_client_secret_post(client, method, uri, headers, body): + body = add_params_to_qs(body or '', [ + ('client_id', client.client_id), + ('client_secret', client.client_secret or '') + ]) + if 'Content-Length' in headers: + headers['Content-Length'] = str(len(body)) + return uri, headers, body + + +def encode_none(client, method, uri, headers, body): + if method == 'GET': + uri = add_params_to_uri(uri, [('client_id', client.client_id)]) + return uri, headers, body + body = add_params_to_qs(body, [('client_id', client.client_id)]) + if 'Content-Length' in headers: + headers['Content-Length'] = str(len(body)) + return uri, headers, body + + +class ClientAuth: + """Attaches OAuth Client Information to HTTP requests. + + :param client_id: Client ID, which you get from client registration. + :param client_secret: Client Secret, which you get from registration. + :param auth_method: Client auth method for token endpoint. The supported + methods for now: + + * client_secret_basic (default) + * client_secret_post + * none + """ + DEFAULT_AUTH_METHODS = { + 'client_secret_basic': encode_client_secret_basic, + 'client_secret_post': encode_client_secret_post, + 'none': encode_none, + } + + def __init__(self, client_id, client_secret, auth_method=None): + if auth_method is None: + auth_method = 'client_secret_basic' + + self.client_id = client_id + self.client_secret = client_secret + + if auth_method in self.DEFAULT_AUTH_METHODS: + auth_method = self.DEFAULT_AUTH_METHODS[auth_method] + + self.auth_method = auth_method + + def prepare(self, method, uri, headers, body): + return self.auth_method(self, method, uri, headers, body) + + +class TokenAuth: + """Attach token information to HTTP requests. + + :param token: A dict or OAuth2Token instance of an OAuth 2.0 token + :param token_placement: The placement of the token, default is ``header``, + available choices: + + * header (default) + * body + * uri + """ + DEFAULT_TOKEN_TYPE = 'bearer' + SIGN_METHODS = { + 'bearer': add_bearer_token + } + + def __init__(self, token, token_placement='header', client=None): + self.token = OAuth2Token.from_dict(token) + self.token_placement = token_placement + self.client = client + self.hooks = set() + + def set_token(self, token): + self.token = OAuth2Token.from_dict(token) + + def prepare(self, uri, headers, body): + token_type = self.token.get('token_type', self.DEFAULT_TOKEN_TYPE) + sign = self.SIGN_METHODS[token_type.lower()] + uri, headers, body = sign( + self.token['access_token'], + uri, headers, body, + self.token_placement) + + for hook in self.hooks: + uri, headers, body = hook(uri, headers, body) + + return uri, headers, body diff --git a/.venv/Lib/site-packages/authlib/oauth2/base.py b/.venv/Lib/site-packages/authlib/oauth2/base.py new file mode 100644 index 00000000..9bcb15f8 --- /dev/null +++ b/.venv/Lib/site-packages/authlib/oauth2/base.py @@ -0,0 +1,26 @@ +from authlib.common.errors import AuthlibHTTPError +from authlib.common.urls import add_params_to_uri + + +class OAuth2Error(AuthlibHTTPError): + def __init__(self, description=None, uri=None, + status_code=None, state=None, + redirect_uri=None, redirect_fragment=False, error=None): + super().__init__(error, description, uri, status_code) + self.state = state + self.redirect_uri = redirect_uri + self.redirect_fragment = redirect_fragment + + def get_body(self): + """Get a list of body.""" + error = super().get_body() + if self.state: + error.append(('state', self.state)) + return error + + def __call__(self, uri=None): + if self.redirect_uri: + params = self.get_body() + loc = add_params_to_uri(self.redirect_uri, params, self.redirect_fragment) + return 302, '', [('Location', loc)] + return super().__call__(uri=uri) diff --git a/.venv/Lib/site-packages/authlib/oauth2/client.py b/.venv/Lib/site-packages/authlib/oauth2/client.py new file mode 100644 index 00000000..7adb0c8e --- /dev/null +++ b/.venv/Lib/site-packages/authlib/oauth2/client.py @@ -0,0 +1,438 @@ +from authlib.common.security import generate_token +from authlib.common.urls import url_decode +from .rfc6749.parameters import ( + prepare_grant_uri, + prepare_token_request, + parse_authorization_code_response, + parse_implicit_response, +) +from .rfc7009 import prepare_revoke_token_request +from .rfc7636 import create_s256_code_challenge +from .auth import TokenAuth, ClientAuth +from .base import OAuth2Error + +DEFAULT_HEADERS = { + 'Accept': 'application/json', + 'Content-Type': 'application/x-www-form-urlencoded;charset=UTF-8' +} + + +class OAuth2Client: + """Construct a new OAuth 2 protocol client. + + :param session: Requests session object to communicate with + authorization server. + :param client_id: Client ID, which you get from client registration. + :param client_secret: Client Secret, which you get from registration. + :param token_endpoint_auth_method: client authentication method for + token endpoint. + :param revocation_endpoint_auth_method: client authentication method for + revocation endpoint. + :param scope: Scope that you needed to access user resources. + :param state: Shared secret to prevent CSRF attack. + :param redirect_uri: Redirect URI you registered as callback. + :param code_challenge_method: PKCE method name, only S256 is supported. + :param token: A dict of token attributes such as ``access_token``, + ``token_type`` and ``expires_at``. + :param token_placement: The place to put token in HTTP request. Available + values: "header", "body", "uri". + :param update_token: A function for you to update token. It accept a + :class:`OAuth2Token` as parameter. + """ + client_auth_class = ClientAuth + token_auth_class = TokenAuth + oauth_error_class = OAuth2Error + + EXTRA_AUTHORIZE_PARAMS = ( + 'response_mode', 'nonce', 'prompt', 'login_hint' + ) + SESSION_REQUEST_PARAMS = [] + + def __init__(self, session, client_id=None, client_secret=None, + token_endpoint_auth_method=None, + revocation_endpoint_auth_method=None, + scope=None, state=None, redirect_uri=None, code_challenge_method=None, + token=None, token_placement='header', update_token=None, **metadata): + + self.session = session + self.client_id = client_id + self.client_secret = client_secret + self.state = state + + if token_endpoint_auth_method is None: + if client_secret: + token_endpoint_auth_method = 'client_secret_basic' + else: + token_endpoint_auth_method = 'none' + + self.token_endpoint_auth_method = token_endpoint_auth_method + + if revocation_endpoint_auth_method is None: + if client_secret: + revocation_endpoint_auth_method = 'client_secret_basic' + else: + revocation_endpoint_auth_method = 'none' + + self.revocation_endpoint_auth_method = revocation_endpoint_auth_method + + self.scope = scope + self.redirect_uri = redirect_uri + self.code_challenge_method = code_challenge_method + + self.token_auth = self.token_auth_class(token, token_placement, self) + self.update_token = update_token + + token_updater = metadata.pop('token_updater', None) + if token_updater: + raise ValueError('update token has been redesigned, checkout the documentation') + + self.metadata = metadata + + self.compliance_hook = { + 'access_token_response': set(), + 'refresh_token_request': set(), + 'refresh_token_response': set(), + 'revoke_token_request': set(), + 'introspect_token_request': set(), + } + self._auth_methods = {} + + def register_client_auth_method(self, auth): + """Extend client authenticate for token endpoint. + + :param auth: an instance to sign the request + """ + if isinstance(auth, tuple): + self._auth_methods[auth[0]] = auth[1] + else: + self._auth_methods[auth.name] = auth + + def client_auth(self, auth_method): + if isinstance(auth_method, str) and auth_method in self._auth_methods: + auth_method = self._auth_methods[auth_method] + return self.client_auth_class( + client_id=self.client_id, + client_secret=self.client_secret, + auth_method=auth_method, + ) + + @property + def token(self): + return self.token_auth.token + + @token.setter + def token(self, token): + self.token_auth.set_token(token) + + def create_authorization_url(self, url, state=None, code_verifier=None, **kwargs): + """Generate an authorization URL and state. + + :param url: Authorization endpoint url, must be HTTPS. + :param state: An optional state string for CSRF protection. If not + given it will be generated for you. + :param code_verifier: An optional code_verifier for code challenge. + :param kwargs: Extra parameters to include. + :return: authorization_url, state + """ + if state is None: + state = generate_token() + + response_type = self.metadata.get('response_type', 'code') + response_type = kwargs.pop('response_type', response_type) + if 'redirect_uri' not in kwargs: + kwargs['redirect_uri'] = self.redirect_uri + if 'scope' not in kwargs: + kwargs['scope'] = self.scope + + if code_verifier and response_type == 'code' and self.code_challenge_method == 'S256': + kwargs['code_challenge'] = create_s256_code_challenge(code_verifier) + kwargs['code_challenge_method'] = self.code_challenge_method + + for k in self.EXTRA_AUTHORIZE_PARAMS: + if k not in kwargs and k in self.metadata: + kwargs[k] = self.metadata[k] + + uri = prepare_grant_uri( + url, client_id=self.client_id, response_type=response_type, + state=state, **kwargs) + return uri, state + + def fetch_token(self, url=None, body='', method='POST', headers=None, + auth=None, grant_type=None, state=None, **kwargs): + """Generic method for fetching an access token from the token endpoint. + + :param url: Access Token endpoint URL, if not configured, + ``authorization_response`` is used to extract token from + its fragment (implicit way). + :param body: Optional application/x-www-form-urlencoded body to add the + include in the token request. Prefer kwargs over body. + :param method: The HTTP method used to make the request. Defaults + to POST, but may also be GET. Other methods should + be added as needed. + :param headers: Dict to default request headers with. + :param auth: An auth tuple or method as accepted by requests. + :param grant_type: Use specified grant_type to fetch token + :return: A :class:`OAuth2Token` object (a dict too). + """ + state = state or self.state + # implicit grant_type + authorization_response = kwargs.pop('authorization_response', None) + if authorization_response and '#' in authorization_response: + return self.token_from_fragment(authorization_response, state) + + session_kwargs = self._extract_session_request_params(kwargs) + + if authorization_response and 'code=' in authorization_response: + grant_type = 'authorization_code' + params = parse_authorization_code_response( + authorization_response, + state=state, + ) + kwargs['code'] = params['code'] + + if grant_type is None: + grant_type = self.metadata.get('grant_type') + + if grant_type is None: + grant_type = _guess_grant_type(kwargs) + self.metadata['grant_type'] = grant_type + + body = self._prepare_token_endpoint_body(body, grant_type, **kwargs) + + if auth is None: + auth = self.client_auth(self.token_endpoint_auth_method) + + if headers is None: + headers = DEFAULT_HEADERS + + if url is None: + url = self.metadata.get('token_endpoint') + + return self._fetch_token( + url, body=body, auth=auth, method=method, + headers=headers, **session_kwargs + ) + + def token_from_fragment(self, authorization_response, state=None): + token = parse_implicit_response(authorization_response, state) + if 'error' in token: + raise self.oauth_error_class( + error=token['error'], + description=token.get('error_description') + ) + self.token = token + return token + + def refresh_token(self, url, refresh_token=None, body='', + auth=None, headers=None, **kwargs): + """Fetch a new access token using a refresh token. + + :param url: Refresh Token endpoint, must be HTTPS. + :param refresh_token: The refresh_token to use. + :param body: Optional application/x-www-form-urlencoded body to add the + include in the token request. Prefer kwargs over body. + :param auth: An auth tuple or method as accepted by requests. + :param headers: Dict to default request headers with. + :return: A :class:`OAuth2Token` object (a dict too). + """ + session_kwargs = self._extract_session_request_params(kwargs) + refresh_token = refresh_token or self.token.get('refresh_token') + if 'scope' not in kwargs and self.scope: + kwargs['scope'] = self.scope + body = prepare_token_request( + 'refresh_token', body, + refresh_token=refresh_token, **kwargs + ) + + if headers is None: + headers = DEFAULT_HEADERS.copy() + + for hook in self.compliance_hook['refresh_token_request']: + url, headers, body = hook(url, headers, body) + + if auth is None: + auth = self.client_auth(self.token_endpoint_auth_method) + + return self._refresh_token( + url, refresh_token=refresh_token, body=body, headers=headers, + auth=auth, **session_kwargs) + + def ensure_active_token(self, token): + if not token.is_expired(): + return True + refresh_token = token.get('refresh_token') + url = self.metadata.get('token_endpoint') + if refresh_token and url: + self.refresh_token(url, refresh_token=refresh_token) + return True + elif self.metadata.get('grant_type') == 'client_credentials': + access_token = token['access_token'] + new_token = self.fetch_token(url, grant_type='client_credentials') + if self.update_token: + self.update_token(new_token, access_token=access_token) + return True + + def revoke_token(self, url, token=None, token_type_hint=None, + body=None, auth=None, headers=None, **kwargs): + """Revoke token method defined via `RFC7009`_. + + :param url: Revoke Token endpoint, must be HTTPS. + :param token: The token to be revoked. + :param token_type_hint: The type of the token that to be revoked. + It can be "access_token" or "refresh_token". + :param body: Optional application/x-www-form-urlencoded body to add the + include in the token request. Prefer kwargs over body. + :param auth: An auth tuple or method as accepted by requests. + :param headers: Dict to default request headers with. + :return: Revocation Response + + .. _`RFC7009`: https://tools.ietf.org/html/rfc7009 + """ + return self._handle_token_hint( + 'revoke_token_request', url, + token=token, token_type_hint=token_type_hint, + body=body, auth=auth, headers=headers, **kwargs) + + def introspect_token(self, url, token=None, token_type_hint=None, + body=None, auth=None, headers=None, **kwargs): + """Implementation of OAuth 2.0 Token Introspection defined via `RFC7662`_. + + :param url: Introspection Endpoint, must be HTTPS. + :param token: The token to be introspected. + :param token_type_hint: The type of the token that to be revoked. + It can be "access_token" or "refresh_token". + :param body: Optional application/x-www-form-urlencoded body to add the + include in the token request. Prefer kwargs over body. + :param auth: An auth tuple or method as accepted by requests. + :param headers: Dict to default request headers with. + :return: Introspection Response + + .. _`RFC7662`: https://tools.ietf.org/html/rfc7662 + """ + return self._handle_token_hint( + 'introspect_token_request', url, + token=token, token_type_hint=token_type_hint, + body=body, auth=auth, headers=headers, **kwargs) + + def register_compliance_hook(self, hook_type, hook): + """Register a hook for request/response tweaking. + + Available hooks are: + + * access_token_response: invoked before token parsing. + * refresh_token_request: invoked before refreshing token. + * refresh_token_response: invoked before refresh token parsing. + * protected_request: invoked before making a request. + * revoke_token_request: invoked before revoking a token. + * introspect_token_request: invoked before introspecting a token. + """ + if hook_type == 'protected_request': + self.token_auth.hooks.add(hook) + return + + if hook_type not in self.compliance_hook: + raise ValueError('Hook type %s is not in %s.', + hook_type, self.compliance_hook) + self.compliance_hook[hook_type].add(hook) + + def parse_response_token(self, resp): + if resp.status_code >= 500: + resp.raise_for_status() + + token = resp.json() + if 'error' in token: + raise self.oauth_error_class( + error=token['error'], + description=token.get('error_description') + ) + self.token = token + return self.token + + def _fetch_token(self, url, body='', headers=None, auth=None, + method='POST', **kwargs): + + if method.upper() == 'POST': + resp = self.session.post( + url, data=dict(url_decode(body)), + headers=headers, auth=auth, **kwargs) + else: + if '?' in url: + url = '&'.join([url, body]) + else: + url = '?'.join([url, body]) + resp = self.session.request(method, url, headers=headers, auth=auth, **kwargs) + + for hook in self.compliance_hook['access_token_response']: + resp = hook(resp) + + return self.parse_response_token(resp) + + def _refresh_token(self, url, refresh_token=None, body='', headers=None, + auth=None, **kwargs): + resp = self._http_post(url, body=body, auth=auth, headers=headers, **kwargs) + + for hook in self.compliance_hook['refresh_token_response']: + resp = hook(resp) + + token = self.parse_response_token(resp) + if 'refresh_token' not in token: + self.token['refresh_token'] = refresh_token + + if callable(self.update_token): + self.update_token(self.token, refresh_token=refresh_token) + + return self.token + + def _handle_token_hint(self, hook, url, token=None, token_type_hint=None, + body=None, auth=None, headers=None, **kwargs): + if token is None and self.token: + token = self.token.get('refresh_token') or self.token.get('access_token') + + if body is None: + body = '' + + body, headers = prepare_revoke_token_request( + token, token_type_hint, body, headers) + + for hook in self.compliance_hook[hook]: + url, headers, body = hook(url, headers, body) + + if auth is None: + auth = self.client_auth(self.revocation_endpoint_auth_method) + + session_kwargs = self._extract_session_request_params(kwargs) + return self._http_post( + url, body, auth=auth, headers=headers, **session_kwargs) + + def _prepare_token_endpoint_body(self, body, grant_type, **kwargs): + if grant_type == 'authorization_code': + if 'redirect_uri' not in kwargs: + kwargs['redirect_uri'] = self.redirect_uri + return prepare_token_request(grant_type, body, **kwargs) + + if 'scope' not in kwargs and self.scope: + kwargs['scope'] = self.scope + return prepare_token_request(grant_type, body, **kwargs) + + def _extract_session_request_params(self, kwargs): + """Extract parameters for session object from the passing ``**kwargs``.""" + rv = {} + for k in self.SESSION_REQUEST_PARAMS: + if k in kwargs: + rv[k] = kwargs.pop(k) + return rv + + def _http_post(self, url, body=None, auth=None, headers=None, **kwargs): + return self.session.post( + url, data=dict(url_decode(body)), + headers=headers, auth=auth, **kwargs) + + +def _guess_grant_type(kwargs): + if 'code' in kwargs: + grant_type = 'authorization_code' + elif 'username' in kwargs and 'password' in kwargs: + grant_type = 'password' + else: + grant_type = 'client_credentials' + return grant_type diff --git a/.venv/Lib/site-packages/authlib/oauth2/rfc6749/__init__.py b/.venv/Lib/site-packages/authlib/oauth2/rfc6749/__init__.py new file mode 100644 index 00000000..e1748e3d --- /dev/null +++ b/.venv/Lib/site-packages/authlib/oauth2/rfc6749/__init__.py @@ -0,0 +1,83 @@ +""" + authlib.oauth2.rfc6749 + ~~~~~~~~~~~~~~~~~~~~~~ + + This module represents a direct implementation of + The OAuth 2.0 Authorization Framework. + + https://tools.ietf.org/html/rfc6749 +""" + +from .requests import OAuth2Request, JsonRequest +from .wrappers import OAuth2Token +from .errors import ( + OAuth2Error, + AccessDeniedError, + MissingAuthorizationError, + InvalidGrantError, + InvalidClientError, + InvalidRequestError, + InvalidScopeError, + InsecureTransportError, + UnauthorizedClientError, + UnsupportedResponseTypeError, + UnsupportedGrantTypeError, + UnsupportedTokenTypeError, + # exceptions for clients + MissingCodeException, + MissingTokenException, + MissingTokenTypeException, + MismatchingStateException, +) +from .models import ClientMixin, AuthorizationCodeMixin, TokenMixin +from .authenticate_client import ClientAuthentication +from .authorization_server import AuthorizationServer +from .resource_protector import ResourceProtector, TokenValidator +from .token_endpoint import TokenEndpoint +from .grants import ( + BaseGrant, + AuthorizationEndpointMixin, + TokenEndpointMixin, + AuthorizationCodeGrant, + ImplicitGrant, + ResourceOwnerPasswordCredentialsGrant, + ClientCredentialsGrant, + RefreshTokenGrant, +) +from .util import scope_to_list, list_to_scope + +__all__ = [ + 'OAuth2Token', + 'OAuth2Request', 'JsonRequest', + 'OAuth2Error', + 'AccessDeniedError', + 'MissingAuthorizationError', + 'InvalidGrantError', + 'InvalidClientError', + 'InvalidRequestError', + 'InvalidScopeError', + 'InsecureTransportError', + 'UnauthorizedClientError', + 'UnsupportedResponseTypeError', + 'UnsupportedGrantTypeError', + 'UnsupportedTokenTypeError', + 'MissingCodeException', + 'MissingTokenException', + 'MissingTokenTypeException', + 'MismatchingStateException', + 'ClientMixin', 'AuthorizationCodeMixin', 'TokenMixin', + 'ClientAuthentication', + 'AuthorizationServer', + 'ResourceProtector', + 'TokenValidator', + 'TokenEndpoint', + 'BaseGrant', + 'AuthorizationEndpointMixin', + 'TokenEndpointMixin', + 'AuthorizationCodeGrant', + 'ImplicitGrant', + 'ResourceOwnerPasswordCredentialsGrant', + 'ClientCredentialsGrant', + 'RefreshTokenGrant', + 'scope_to_list', 'list_to_scope', +] diff --git a/.venv/Lib/site-packages/authlib/oauth2/rfc6749/__pycache__/__init__.cpython-311.pyc b/.venv/Lib/site-packages/authlib/oauth2/rfc6749/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 00000000..b36d37af Binary files /dev/null and b/.venv/Lib/site-packages/authlib/oauth2/rfc6749/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/authlib/oauth2/rfc6749/__pycache__/authenticate_client.cpython-311.pyc b/.venv/Lib/site-packages/authlib/oauth2/rfc6749/__pycache__/authenticate_client.cpython-311.pyc new file mode 100644 index 00000000..d114b4f3 Binary files /dev/null and b/.venv/Lib/site-packages/authlib/oauth2/rfc6749/__pycache__/authenticate_client.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/authlib/oauth2/rfc6749/__pycache__/authorization_server.cpython-311.pyc b/.venv/Lib/site-packages/authlib/oauth2/rfc6749/__pycache__/authorization_server.cpython-311.pyc new file mode 100644 index 00000000..02fc4fcd Binary files /dev/null and b/.venv/Lib/site-packages/authlib/oauth2/rfc6749/__pycache__/authorization_server.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/authlib/oauth2/rfc6749/__pycache__/errors.cpython-311.pyc b/.venv/Lib/site-packages/authlib/oauth2/rfc6749/__pycache__/errors.cpython-311.pyc new file mode 100644 index 00000000..c914affc Binary files /dev/null and b/.venv/Lib/site-packages/authlib/oauth2/rfc6749/__pycache__/errors.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/authlib/oauth2/rfc6749/__pycache__/models.cpython-311.pyc b/.venv/Lib/site-packages/authlib/oauth2/rfc6749/__pycache__/models.cpython-311.pyc new file mode 100644 index 00000000..b33cb7c2 Binary files /dev/null and b/.venv/Lib/site-packages/authlib/oauth2/rfc6749/__pycache__/models.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/authlib/oauth2/rfc6749/__pycache__/parameters.cpython-311.pyc b/.venv/Lib/site-packages/authlib/oauth2/rfc6749/__pycache__/parameters.cpython-311.pyc new file mode 100644 index 00000000..ac2657a5 Binary files /dev/null and b/.venv/Lib/site-packages/authlib/oauth2/rfc6749/__pycache__/parameters.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/authlib/oauth2/rfc6749/__pycache__/requests.cpython-311.pyc b/.venv/Lib/site-packages/authlib/oauth2/rfc6749/__pycache__/requests.cpython-311.pyc new file mode 100644 index 00000000..bd583b95 Binary files /dev/null and b/.venv/Lib/site-packages/authlib/oauth2/rfc6749/__pycache__/requests.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/authlib/oauth2/rfc6749/__pycache__/resource_protector.cpython-311.pyc b/.venv/Lib/site-packages/authlib/oauth2/rfc6749/__pycache__/resource_protector.cpython-311.pyc new file mode 100644 index 00000000..771b2de3 Binary files /dev/null and b/.venv/Lib/site-packages/authlib/oauth2/rfc6749/__pycache__/resource_protector.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/authlib/oauth2/rfc6749/__pycache__/token_endpoint.cpython-311.pyc b/.venv/Lib/site-packages/authlib/oauth2/rfc6749/__pycache__/token_endpoint.cpython-311.pyc new file mode 100644 index 00000000..04342a6e Binary files /dev/null and b/.venv/Lib/site-packages/authlib/oauth2/rfc6749/__pycache__/token_endpoint.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/authlib/oauth2/rfc6749/__pycache__/util.cpython-311.pyc b/.venv/Lib/site-packages/authlib/oauth2/rfc6749/__pycache__/util.cpython-311.pyc new file mode 100644 index 00000000..8d6270ef Binary files /dev/null and b/.venv/Lib/site-packages/authlib/oauth2/rfc6749/__pycache__/util.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/authlib/oauth2/rfc6749/__pycache__/wrappers.cpython-311.pyc b/.venv/Lib/site-packages/authlib/oauth2/rfc6749/__pycache__/wrappers.cpython-311.pyc new file mode 100644 index 00000000..cc450a6d Binary files /dev/null and b/.venv/Lib/site-packages/authlib/oauth2/rfc6749/__pycache__/wrappers.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/authlib/oauth2/rfc6749/authenticate_client.py b/.venv/Lib/site-packages/authlib/oauth2/rfc6749/authenticate_client.py new file mode 100644 index 00000000..adcfd25f --- /dev/null +++ b/.venv/Lib/site-packages/authlib/oauth2/rfc6749/authenticate_client.py @@ -0,0 +1,103 @@ +""" + authlib.oauth2.rfc6749.authenticate_client + ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + Registry of client authentication methods, with 3 built-in methods: + + 1. client_secret_basic + 2. client_secret_post + 3. none + + The "client_secret_basic" method is used a lot in examples of `RFC6749`_, + but the concept of naming are introduced in `RFC7591`_. + + .. _`RFC6749`: https://tools.ietf.org/html/rfc6749 + .. _`RFC7591`: https://tools.ietf.org/html/rfc7591 +""" + +import logging +from .errors import InvalidClientError +from .util import extract_basic_authorization + +log = logging.getLogger(__name__) + +__all__ = ['ClientAuthentication'] + + +class ClientAuthentication: + def __init__(self, query_client): + self.query_client = query_client + self._methods = { + 'none': authenticate_none, + 'client_secret_basic': authenticate_client_secret_basic, + 'client_secret_post': authenticate_client_secret_post, + } + + def register(self, method, func): + self._methods[method] = func + + def authenticate(self, request, methods, endpoint): + for method in methods: + func = self._methods[method] + client = func(self.query_client, request) + if client and client.check_endpoint_auth_method(method, endpoint): + request.auth_method = method + return client + + if 'client_secret_basic' in methods: + raise InvalidClientError(state=request.state, status_code=401) + raise InvalidClientError(state=request.state) + + def __call__(self, request, methods, endpoint='token'): + return self.authenticate(request, methods, endpoint) + + +def authenticate_client_secret_basic(query_client, request): + """Authenticate client by ``client_secret_basic`` method. The client + uses HTTP Basic for authentication. + """ + client_id, client_secret = extract_basic_authorization(request.headers) + if client_id and client_secret: + client = _validate_client(query_client, client_id, request.state, 401) + if client.check_client_secret(client_secret): + log.debug(f'Authenticate {client_id} via "client_secret_basic" success') + return client + log.debug(f'Authenticate {client_id} via "client_secret_basic" failed') + + +def authenticate_client_secret_post(query_client, request): + """Authenticate client by ``client_secret_post`` method. The client + uses POST parameters for authentication. + """ + data = request.form + client_id = data.get('client_id') + client_secret = data.get('client_secret') + if client_id and client_secret: + client = _validate_client(query_client, client_id, request.state) + if client.check_client_secret(client_secret): + log.debug(f'Authenticate {client_id} via "client_secret_post" success') + return client + log.debug(f'Authenticate {client_id} via "client_secret_post" failed') + + +def authenticate_none(query_client, request): + """Authenticate public client by ``none`` method. The client + does not have a client secret. + """ + client_id = request.client_id + if client_id and not request.data.get('client_secret'): + client = _validate_client(query_client, client_id, request.state) + log.debug(f'Authenticate {client_id} via "none" success') + return client + log.debug(f'Authenticate {client_id} via "none" failed') + + +def _validate_client(query_client, client_id, state=None, status_code=400): + if client_id is None: + raise InvalidClientError(state=state, status_code=status_code) + + client = query_client(client_id) + if not client: + raise InvalidClientError(state=state, status_code=status_code) + + return client diff --git a/.venv/Lib/site-packages/authlib/oauth2/rfc6749/authorization_server.py b/.venv/Lib/site-packages/authlib/oauth2/rfc6749/authorization_server.py new file mode 100644 index 00000000..3190540e --- /dev/null +++ b/.venv/Lib/site-packages/authlib/oauth2/rfc6749/authorization_server.py @@ -0,0 +1,301 @@ +from authlib.common.errors import ContinueIteration +from .authenticate_client import ClientAuthentication +from .requests import OAuth2Request, JsonRequest +from .errors import ( + OAuth2Error, + InvalidScopeError, + UnsupportedResponseTypeError, + UnsupportedGrantTypeError, +) +from .util import scope_to_list + + +class AuthorizationServer: + """Authorization server that handles Authorization Endpoint and Token + Endpoint. + + :param scopes_supported: A list of supported scopes by this authorization server. + """ + def __init__(self, scopes_supported=None): + self.scopes_supported = scopes_supported + self._token_generators = {} + self._client_auth = None + self._authorization_grants = [] + self._token_grants = [] + self._endpoints = {} + + def query_client(self, client_id): + """Query OAuth client by client_id. The client model class MUST + implement the methods described by + :class:`~authlib.oauth2.rfc6749.ClientMixin`. + """ + raise NotImplementedError() + + def save_token(self, token, request): + """Define function to save the generated token into database.""" + raise NotImplementedError() + + def generate_token(self, grant_type, client, user=None, scope=None, + expires_in=None, include_refresh_token=True): + """Generate the token dict. + + :param grant_type: current requested grant_type. + :param client: the client that making the request. + :param user: current authorized user. + :param expires_in: if provided, use this value as expires_in. + :param scope: current requested scope. + :param include_refresh_token: should refresh_token be included. + :return: Token dict + """ + # generator for a specified grant type + func = self._token_generators.get(grant_type) + if not func: + # default generator for all grant types + func = self._token_generators.get('default') + if not func: + raise RuntimeError('No configured token generator') + + return func( + grant_type=grant_type, client=client, user=user, scope=scope, + expires_in=expires_in, include_refresh_token=include_refresh_token) + + def register_token_generator(self, grant_type, func): + """Register a function as token generator for the given ``grant_type``. + Developers MUST register a default token generator with a special + ``grant_type=default``:: + + def generate_bearer_token(grant_type, client, user=None, scope=None, + expires_in=None, include_refresh_token=True): + token = {'token_type': 'Bearer', 'access_token': ...} + if include_refresh_token: + token['refresh_token'] = ... + ... + return token + + authorization_server.register_token_generator('default', generate_bearer_token) + + If you register a generator for a certain grant type, that generator will only works + for the given grant type:: + + authorization_server.register_token_generator('client_credentials', generate_bearer_token) + + :param grant_type: string name of the grant type + :param func: a function to generate token + """ + self._token_generators[grant_type] = func + + def authenticate_client(self, request, methods, endpoint='token'): + """Authenticate client via HTTP request information with the given + methods, such as ``client_secret_basic``, ``client_secret_post``. + """ + if self._client_auth is None and self.query_client: + self._client_auth = ClientAuthentication(self.query_client) + return self._client_auth(request, methods, endpoint) + + def register_client_auth_method(self, method, func): + """Add more client auth method. The default methods are: + + * none: The client is a public client and does not have a client secret + * client_secret_post: The client uses the HTTP POST parameters + * client_secret_basic: The client uses HTTP Basic + + :param method: Name of the Auth method + :param func: Function to authenticate the client + + The auth method accept two parameters: ``query_client`` and ``request``, + an example for this method:: + + def authenticate_client_via_custom(query_client, request): + client_id = request.headers['X-Client-Id'] + client = query_client(client_id) + do_some_validation(client) + return client + + authorization_server.register_client_auth_method( + 'custom', authenticate_client_via_custom) + """ + if self._client_auth is None and self.query_client: + self._client_auth = ClientAuthentication(self.query_client) + + self._client_auth.register(method, func) + + def get_error_uri(self, request, error): + """Return a URI for the given error, framework may implement this method.""" + return None + + def send_signal(self, name, *args, **kwargs): + """Framework integration can re-implement this method to support + signal system. + """ + raise NotImplementedError() + + def create_oauth2_request(self, request) -> OAuth2Request: + """This method MUST be implemented in framework integrations. It is + used to create an OAuth2Request instance. + + :param request: the "request" instance in framework + :return: OAuth2Request instance + """ + raise NotImplementedError() + + def create_json_request(self, request) -> JsonRequest: + """This method MUST be implemented in framework integrations. It is + used to create an HttpRequest instance. + + :param request: the "request" instance in framework + :return: HttpRequest instance + """ + raise NotImplementedError() + + def handle_response(self, status, body, headers): + """Return HTTP response. Framework MUST implement this function.""" + raise NotImplementedError() + + def validate_requested_scope(self, scope, state=None): + """Validate if requested scope is supported by Authorization Server. + Developers CAN re-write this method to meet your needs. + """ + if scope and self.scopes_supported: + scopes = set(scope_to_list(scope)) + if not set(self.scopes_supported).issuperset(scopes): + raise InvalidScopeError(state=state) + + def register_grant(self, grant_cls, extensions=None): + """Register a grant class into the endpoint registry. Developers + can implement the grants in ``authlib.oauth2.rfc6749.grants`` and + register with this method:: + + class AuthorizationCodeGrant(grants.AuthorizationCodeGrant): + def authenticate_user(self, credential): + # ... + + authorization_server.register_grant(AuthorizationCodeGrant) + + :param grant_cls: a grant class. + :param extensions: extensions for the grant class. + """ + if hasattr(grant_cls, 'check_authorization_endpoint'): + self._authorization_grants.append((grant_cls, extensions)) + if hasattr(grant_cls, 'check_token_endpoint'): + self._token_grants.append((grant_cls, extensions)) + + def register_endpoint(self, endpoint): + """Add extra endpoint to authorization server. e.g. + RevocationEndpoint:: + + authorization_server.register_endpoint(RevocationEndpoint) + + :param endpoint_cls: A endpoint class or instance. + """ + if isinstance(endpoint, type): + endpoint = endpoint(self) + else: + endpoint.server = self + + endpoints = self._endpoints.setdefault(endpoint.ENDPOINT_NAME, []) + endpoints.append(endpoint) + + def get_authorization_grant(self, request): + """Find the authorization grant for current request. + + :param request: OAuth2Request instance. + :return: grant instance + """ + for (grant_cls, extensions) in self._authorization_grants: + if grant_cls.check_authorization_endpoint(request): + return _create_grant(grant_cls, extensions, request, self) + raise UnsupportedResponseTypeError(request.response_type) + + def get_consent_grant(self, request=None, end_user=None): + """Validate current HTTP request for authorization page. This page + is designed for resource owner to grant or deny the authorization. + """ + request = self.create_oauth2_request(request) + request.user = end_user + + grant = self.get_authorization_grant(request) + grant.validate_consent_request() + return grant + + def get_token_grant(self, request): + """Find the token grant for current request. + + :param request: OAuth2Request instance. + :return: grant instance + """ + for (grant_cls, extensions) in self._token_grants: + if grant_cls.check_token_endpoint(request): + return _create_grant(grant_cls, extensions, request, self) + raise UnsupportedGrantTypeError(request.grant_type) + + def create_endpoint_response(self, name, request=None): + """Validate endpoint request and create endpoint response. + + :param name: Endpoint name + :param request: HTTP request instance. + :return: Response + """ + if name not in self._endpoints: + raise RuntimeError(f'There is no "{name}" endpoint.') + + endpoints = self._endpoints[name] + for endpoint in endpoints: + request = endpoint.create_endpoint_request(request) + try: + return self.handle_response(*endpoint(request)) + except ContinueIteration: + continue + except OAuth2Error as error: + return self.handle_error_response(request, error) + + def create_authorization_response(self, request=None, grant_user=None): + """Validate authorization request and create authorization response. + + :param request: HTTP request instance. + :param grant_user: if granted, it is resource owner. If denied, + it is None. + :returns: Response + """ + if not isinstance(request, OAuth2Request): + request = self.create_oauth2_request(request) + + try: + grant = self.get_authorization_grant(request) + except UnsupportedResponseTypeError as error: + return self.handle_error_response(request, error) + + try: + redirect_uri = grant.validate_authorization_request() + args = grant.create_authorization_response(redirect_uri, grant_user) + return self.handle_response(*args) + except OAuth2Error as error: + return self.handle_error_response(request, error) + + def create_token_response(self, request=None): + """Validate token request and create token response. + + :param request: HTTP request instance + """ + request = self.create_oauth2_request(request) + try: + grant = self.get_token_grant(request) + except UnsupportedGrantTypeError as error: + return self.handle_error_response(request, error) + + try: + grant.validate_token_request() + args = grant.create_token_response() + return self.handle_response(*args) + except OAuth2Error as error: + return self.handle_error_response(request, error) + + def handle_error_response(self, request, error): + return self.handle_response(*error(self.get_error_uri(request, error))) + + +def _create_grant(grant_cls, extensions, request, server): + grant = grant_cls(request, server) + if extensions: + for ext in extensions: + ext(grant) + return grant diff --git a/.venv/Lib/site-packages/authlib/oauth2/rfc6749/errors.py b/.venv/Lib/site-packages/authlib/oauth2/rfc6749/errors.py new file mode 100644 index 00000000..63ffb47e --- /dev/null +++ b/.venv/Lib/site-packages/authlib/oauth2/rfc6749/errors.py @@ -0,0 +1,233 @@ +""" + authlib.oauth2.rfc6749.errors + ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + Implementation for OAuth 2 Error Response. A basic error has + parameters: + + error + REQUIRED. A single ASCII [USASCII] error code. + + error_description + OPTIONAL. Human-readable ASCII [USASCII] text providing + additional information, used to assist the client developer in + understanding the error that occurred. + + error_uri + OPTIONAL. A URI identifying a human-readable web page with + information about the error, used to provide the client + developer with additional information about the error. + Values for the "error_uri" parameter MUST conform to the + URI-reference syntax and thus MUST NOT include characters + outside the set %x21 / %x23-5B / %x5D-7E. + + state + REQUIRED if a "state" parameter was present in the client + authorization request. The exact value received from the + client. + + https://tools.ietf.org/html/rfc6749#section-5.2 + + :copyright: (c) 2017 by Hsiaoming Yang. +""" +from authlib.oauth2.base import OAuth2Error +from authlib.common.security import is_secure_transport + +__all__ = [ + 'OAuth2Error', + 'InsecureTransportError', 'InvalidRequestError', + 'InvalidClientError', 'UnauthorizedClientError', 'InvalidGrantError', + 'UnsupportedResponseTypeError', 'UnsupportedGrantTypeError', + 'InvalidScopeError', 'AccessDeniedError', + 'MissingAuthorizationError', 'UnsupportedTokenTypeError', + 'MissingCodeException', 'MissingTokenException', + 'MissingTokenTypeException', 'MismatchingStateException', +] + + +class InsecureTransportError(OAuth2Error): + error = 'insecure_transport' + description = 'OAuth 2 MUST utilize https.' + + @classmethod + def check(cls, uri): + """Check and raise InsecureTransportError with the given URI.""" + if not is_secure_transport(uri): + raise cls() + + +class InvalidRequestError(OAuth2Error): + """The request is missing a required parameter, includes an + unsupported parameter value (other than grant type), + repeats a parameter, includes multiple credentials, + utilizes more than one mechanism for authenticating the + client, or is otherwise malformed. + + https://tools.ietf.org/html/rfc6749#section-5.2 + """ + error = 'invalid_request' + + +class InvalidClientError(OAuth2Error): + """Client authentication failed (e.g., unknown client, no + client authentication included, or unsupported + authentication method). The authorization server MAY + return an HTTP 401 (Unauthorized) status code to indicate + which HTTP authentication schemes are supported. If the + client attempted to authenticate via the "Authorization" + request header field, the authorization server MUST + respond with an HTTP 401 (Unauthorized) status code and + include the "WWW-Authenticate" response header field + matching the authentication scheme used by the client. + + https://tools.ietf.org/html/rfc6749#section-5.2 + """ + error = 'invalid_client' + status_code = 400 + + def get_headers(self): + headers = super().get_headers() + if self.status_code == 401: + error_description = self.get_error_description() + # safe escape + error_description = error_description.replace('"', '|') + extras = [ + f'error="{self.error}"', + f'error_description="{error_description}"' + ] + headers.append( + ('WWW-Authenticate', 'Basic ' + ', '.join(extras)) + ) + return headers + + +class InvalidGrantError(OAuth2Error): + """The provided authorization grant (e.g., authorization + code, resource owner credentials) or refresh token is + invalid, expired, revoked, does not match the redirection + URI used in the authorization request, or was issued to + another client. + + https://tools.ietf.org/html/rfc6749#section-5.2 + """ + error = 'invalid_grant' + + +class UnauthorizedClientError(OAuth2Error): + """ The authenticated client is not authorized to use this + authorization grant type. + + https://tools.ietf.org/html/rfc6749#section-5.2 + """ + error = 'unauthorized_client' + + +class UnsupportedResponseTypeError(OAuth2Error): + """The authorization server does not support obtaining + an access token using this method.""" + error = 'unsupported_response_type' + + def __init__(self, response_type): + super().__init__() + self.response_type = response_type + + def get_error_description(self): + return f'response_type={self.response_type} is not supported' + + +class UnsupportedGrantTypeError(OAuth2Error): + """The authorization grant type is not supported by the + authorization server. + + https://tools.ietf.org/html/rfc6749#section-5.2 + """ + error = 'unsupported_grant_type' + + def __init__(self, grant_type): + super().__init__() + self.grant_type = grant_type + + def get_error_description(self): + return f'grant_type={self.grant_type} is not supported' + + +class InvalidScopeError(OAuth2Error): + """The requested scope is invalid, unknown, malformed, or + exceeds the scope granted by the resource owner. + + https://tools.ietf.org/html/rfc6749#section-5.2 + """ + error = 'invalid_scope' + description = 'The requested scope is invalid, unknown, or malformed.' + + +class AccessDeniedError(OAuth2Error): + """The resource owner or authorization server denied the request. + + Used in authorization endpoint for "code" and "implicit". Defined in + `Section 4.1.2.1`_. + + .. _`Section 4.1.2.1`: https://tools.ietf.org/html/rfc6749#section-4.1.2.1 + """ + error = 'access_denied' + description = 'The resource owner or authorization server denied the request' + + +# -- below are extended errors -- # + + +class ForbiddenError(OAuth2Error): + status_code = 401 + + def __init__(self, auth_type=None, realm=None): + super().__init__() + self.auth_type = auth_type + self.realm = realm + + def get_headers(self): + headers = super().get_headers() + if not self.auth_type: + return headers + + extras = [] + if self.realm: + extras.append(f'realm="{self.realm}"') + extras.append(f'error="{self.error}"') + error_description = self.description + extras.append(f'error_description="{error_description}"') + headers.append( + ('WWW-Authenticate', f'{self.auth_type} ' + ', '.join(extras)) + ) + return headers + + +class MissingAuthorizationError(ForbiddenError): + error = 'missing_authorization' + description = 'Missing "Authorization" in headers.' + + +class UnsupportedTokenTypeError(ForbiddenError): + error = 'unsupported_token_type' + + +# -- exceptions for clients -- # + + +class MissingCodeException(OAuth2Error): + error = 'missing_code' + description = 'Missing "code" in response.' + + +class MissingTokenException(OAuth2Error): + error = 'missing_token' + description = 'Missing "access_token" in response.' + + +class MissingTokenTypeException(OAuth2Error): + error = 'missing_token_type' + description = 'Missing "token_type" in response.' + + +class MismatchingStateException(OAuth2Error): + error = 'mismatching_state' + description = 'CSRF Warning! State not equal in request and response.' diff --git a/.venv/Lib/site-packages/authlib/oauth2/rfc6749/grants/__init__.py b/.venv/Lib/site-packages/authlib/oauth2/rfc6749/grants/__init__.py new file mode 100644 index 00000000..b1797565 --- /dev/null +++ b/.venv/Lib/site-packages/authlib/oauth2/rfc6749/grants/__init__.py @@ -0,0 +1,37 @@ +""" + authlib.oauth2.rfc6749.grants + ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + Implementation for `Section 4`_ of "Obtaining Authorization". + + To request an access token, the client obtains authorization from the + resource owner. The authorization is expressed in the form of an + authorization grant, which the client uses to request the access + token. OAuth defines four grant types: + + 1. authorization code + 2. implicit + 3. resource owner password credentials + 4. client credentials. + + It also provides an extension mechanism for defining additional grant + types. Authlib defines refresh_token as a grant type too. + + .. _`Section 4`: https://tools.ietf.org/html/rfc6749#section-4 +""" + +# flake8: noqa + +from .base import BaseGrant, AuthorizationEndpointMixin, TokenEndpointMixin +from .authorization_code import AuthorizationCodeGrant +from .implicit import ImplicitGrant +from .resource_owner_password_credentials import ResourceOwnerPasswordCredentialsGrant +from .client_credentials import ClientCredentialsGrant +from .refresh_token import RefreshTokenGrant + +__all__ = [ + 'BaseGrant', 'AuthorizationEndpointMixin', 'TokenEndpointMixin', + 'AuthorizationCodeGrant', 'ImplicitGrant', + 'ResourceOwnerPasswordCredentialsGrant', + 'ClientCredentialsGrant', 'RefreshTokenGrant', +] diff --git a/.venv/Lib/site-packages/authlib/oauth2/rfc6749/grants/__pycache__/__init__.cpython-311.pyc b/.venv/Lib/site-packages/authlib/oauth2/rfc6749/grants/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 00000000..9e38f0e2 Binary files /dev/null and b/.venv/Lib/site-packages/authlib/oauth2/rfc6749/grants/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/authlib/oauth2/rfc6749/grants/__pycache__/authorization_code.cpython-311.pyc b/.venv/Lib/site-packages/authlib/oauth2/rfc6749/grants/__pycache__/authorization_code.cpython-311.pyc new file mode 100644 index 00000000..de6235ef Binary files /dev/null and b/.venv/Lib/site-packages/authlib/oauth2/rfc6749/grants/__pycache__/authorization_code.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/authlib/oauth2/rfc6749/grants/__pycache__/base.cpython-311.pyc b/.venv/Lib/site-packages/authlib/oauth2/rfc6749/grants/__pycache__/base.cpython-311.pyc new file mode 100644 index 00000000..929adc0e Binary files /dev/null and b/.venv/Lib/site-packages/authlib/oauth2/rfc6749/grants/__pycache__/base.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/authlib/oauth2/rfc6749/grants/__pycache__/client_credentials.cpython-311.pyc b/.venv/Lib/site-packages/authlib/oauth2/rfc6749/grants/__pycache__/client_credentials.cpython-311.pyc new file mode 100644 index 00000000..8ecd110e Binary files /dev/null and b/.venv/Lib/site-packages/authlib/oauth2/rfc6749/grants/__pycache__/client_credentials.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/authlib/oauth2/rfc6749/grants/__pycache__/implicit.cpython-311.pyc b/.venv/Lib/site-packages/authlib/oauth2/rfc6749/grants/__pycache__/implicit.cpython-311.pyc new file mode 100644 index 00000000..1698f59d Binary files /dev/null and b/.venv/Lib/site-packages/authlib/oauth2/rfc6749/grants/__pycache__/implicit.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/authlib/oauth2/rfc6749/grants/__pycache__/refresh_token.cpython-311.pyc b/.venv/Lib/site-packages/authlib/oauth2/rfc6749/grants/__pycache__/refresh_token.cpython-311.pyc new file mode 100644 index 00000000..aa7d35e5 Binary files /dev/null and b/.venv/Lib/site-packages/authlib/oauth2/rfc6749/grants/__pycache__/refresh_token.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/authlib/oauth2/rfc6749/grants/__pycache__/resource_owner_password_credentials.cpython-311.pyc b/.venv/Lib/site-packages/authlib/oauth2/rfc6749/grants/__pycache__/resource_owner_password_credentials.cpython-311.pyc new file mode 100644 index 00000000..cf1b76d7 Binary files /dev/null and b/.venv/Lib/site-packages/authlib/oauth2/rfc6749/grants/__pycache__/resource_owner_password_credentials.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/authlib/oauth2/rfc6749/grants/authorization_code.py b/.venv/Lib/site-packages/authlib/oauth2/rfc6749/grants/authorization_code.py new file mode 100644 index 00000000..76a51de1 --- /dev/null +++ b/.venv/Lib/site-packages/authlib/oauth2/rfc6749/grants/authorization_code.py @@ -0,0 +1,378 @@ +import logging +from authlib.common.urls import add_params_to_uri +from authlib.common.security import generate_token +from .base import BaseGrant, AuthorizationEndpointMixin, TokenEndpointMixin +from ..errors import ( + OAuth2Error, + UnauthorizedClientError, + InvalidClientError, + InvalidGrantError, + InvalidRequestError, + AccessDeniedError, +) + +log = logging.getLogger(__name__) + + +class AuthorizationCodeGrant(BaseGrant, AuthorizationEndpointMixin, TokenEndpointMixin): + """The authorization code grant type is used to obtain both access + tokens and refresh tokens and is optimized for confidential clients. + Since this is a redirection-based flow, the client must be capable of + interacting with the resource owner's user-agent (typically a web + browser) and capable of receiving incoming requests (via redirection) + from the authorization server:: + + +----------+ + | Resource | + | Owner | + | | + +----------+ + ^ + | + (B) + +----|-----+ Client Identifier +---------------+ + | -+----(A)-- & Redirection URI ---->| | + | User- | | Authorization | + | Agent -+----(B)-- User authenticates --->| Server | + | | | | + | -+----(C)-- Authorization Code ---<| | + +-|----|---+ +---------------+ + | | ^ v + (A) (C) | | + | | | | + ^ v | | + +---------+ | | + | |>---(D)-- Authorization Code ---------' | + | Client | & Redirection URI | + | | | + | |<---(E)----- Access Token -------------------' + +---------+ (w/ Optional Refresh Token) + """ + #: Allowed client auth methods for token endpoint + TOKEN_ENDPOINT_AUTH_METHODS = ['client_secret_basic', 'client_secret_post'] + + #: Generated "code" length + AUTHORIZATION_CODE_LENGTH = 48 + + RESPONSE_TYPES = {'code'} + GRANT_TYPE = 'authorization_code' + + def validate_authorization_request(self): + """The client constructs the request URI by adding the following + parameters to the query component of the authorization endpoint URI + using the "application/x-www-form-urlencoded" format. + Per `Section 4.1.1`_. + + response_type + REQUIRED. Value MUST be set to "code". + + client_id + REQUIRED. The client identifier as described in Section 2.2. + + redirect_uri + OPTIONAL. As described in Section 3.1.2. + + scope + OPTIONAL. The scope of the access request as described by + Section 3.3. + + state + RECOMMENDED. An opaque value used by the client to maintain + state between the request and callback. The authorization + server includes this value when redirecting the user-agent back + to the client. The parameter SHOULD be used for preventing + cross-site request forgery as described in Section 10.12. + + The client directs the resource owner to the constructed URI using an + HTTP redirection response, or by other means available to it via the + user-agent. + + For example, the client directs the user-agent to make the following + HTTP request using TLS (with extra line breaks for display purposes + only): + + .. code-block:: http + + GET /authorize?response_type=code&client_id=s6BhdRkqt3&state=xyz + &redirect_uri=https%3A%2F%2Fclient%2Eexample%2Ecom%2Fcb HTTP/1.1 + Host: server.example.com + + The authorization server validates the request to ensure that all + required parameters are present and valid. If the request is valid, + the authorization server authenticates the resource owner and obtains + an authorization decision (by asking the resource owner or by + establishing approval via other means). + + .. _`Section 4.1.1`: https://tools.ietf.org/html/rfc6749#section-4.1.1 + """ + return validate_code_authorization_request(self) + + def create_authorization_response(self, redirect_uri: str, grant_user): + """If the resource owner grants the access request, the authorization + server issues an authorization code and delivers it to the client by + adding the following parameters to the query component of the + redirection URI using the "application/x-www-form-urlencoded" format. + Per `Section 4.1.2`_. + + code + REQUIRED. The authorization code generated by the + authorization server. The authorization code MUST expire + shortly after it is issued to mitigate the risk of leaks. A + maximum authorization code lifetime of 10 minutes is + RECOMMENDED. The client MUST NOT use the authorization code + more than once. If an authorization code is used more than + once, the authorization server MUST deny the request and SHOULD + revoke (when possible) all tokens previously issued based on + that authorization code. The authorization code is bound to + the client identifier and redirection URI. + state + REQUIRED if the "state" parameter was present in the client + authorization request. The exact value received from the + client. + + For example, the authorization server redirects the user-agent by + sending the following HTTP response. + + .. code-block:: http + + HTTP/1.1 302 Found + Location: https://client.example.com/cb?code=SplxlOBeZQQYbYS6WxSbIA + &state=xyz + + .. _`Section 4.1.2`: https://tools.ietf.org/html/rfc6749#section-4.1.2 + + :param redirect_uri: Redirect to the given URI for the authorization + :param grant_user: if resource owner granted the request, pass this + resource owner, otherwise pass None. + :returns: (status_code, body, headers) + """ + if not grant_user: + raise AccessDeniedError(state=self.request.state, redirect_uri=redirect_uri) + + self.request.user = grant_user + + code = self.generate_authorization_code() + self.save_authorization_code(code, self.request) + + params = [('code', code)] + if self.request.state: + params.append(('state', self.request.state)) + uri = add_params_to_uri(redirect_uri, params) + headers = [('Location', uri)] + return 302, '', headers + + def validate_token_request(self): + """The client makes a request to the token endpoint by sending the + following parameters using the "application/x-www-form-urlencoded" + format per `Section 4.1.3`_: + + grant_type + REQUIRED. Value MUST be set to "authorization_code". + + code + REQUIRED. The authorization code received from the + authorization server. + + redirect_uri + REQUIRED, if the "redirect_uri" parameter was included in the + authorization request as described in Section 4.1.1, and their + values MUST be identical. + + client_id + REQUIRED, if the client is not authenticating with the + authorization server as described in Section 3.2.1. + + If the client type is confidential or the client was issued client + credentials (or assigned other authentication requirements), the + client MUST authenticate with the authorization server as described + in Section 3.2.1. + + For example, the client makes the following HTTP request using TLS: + + .. code-block:: http + + POST /token HTTP/1.1 + Host: server.example.com + Authorization: Basic czZCaGRSa3F0MzpnWDFmQmF0M2JW + Content-Type: application/x-www-form-urlencoded + + grant_type=authorization_code&code=SplxlOBeZQQYbYS6WxSbIA + &redirect_uri=https%3A%2F%2Fclient%2Eexample%2Ecom%2Fcb + + .. _`Section 4.1.3`: https://tools.ietf.org/html/rfc6749#section-4.1.3 + """ + # ignore validate for grant_type, since it is validated by + # check_token_endpoint + + # authenticate the client if client authentication is included + client = self.authenticate_token_endpoint_client() + + log.debug('Validate token request of %r', client) + if not client.check_grant_type(self.GRANT_TYPE): + raise UnauthorizedClientError( + f'The client is not authorized to use "grant_type={self.GRANT_TYPE}"') + + code = self.request.form.get('code') + if code is None: + raise InvalidRequestError('Missing "code" in request.') + + # ensure that the authorization code was issued to the authenticated + # confidential client, or if the client is public, ensure that the + # code was issued to "client_id" in the request + authorization_code = self.query_authorization_code(code, client) + if not authorization_code: + raise InvalidGrantError('Invalid "code" in request.') + + # validate redirect_uri parameter + log.debug('Validate token redirect_uri of %r', client) + redirect_uri = self.request.redirect_uri + original_redirect_uri = authorization_code.get_redirect_uri() + if original_redirect_uri and redirect_uri != original_redirect_uri: + raise InvalidGrantError('Invalid "redirect_uri" in request.') + + # save for create_token_response + self.request.client = client + self.request.authorization_code = authorization_code + self.execute_hook('after_validate_token_request') + + def create_token_response(self): + """If the access token request is valid and authorized, the + authorization server issues an access token and optional refresh + token as described in Section 5.1. If the request client + authentication failed or is invalid, the authorization server returns + an error response as described in Section 5.2. Per `Section 4.1.4`_. + + An example successful response: + + .. code-block:: http + + HTTP/1.1 200 OK + Content-Type: application/json + Cache-Control: no-store + Pragma: no-cache + + { + "access_token":"2YotnFZFEjr1zCsicMWpAA", + "token_type":"example", + "expires_in":3600, + "refresh_token":"tGzv3JOkF0XG5Qx2TlKWIA", + "example_parameter":"example_value" + } + + :returns: (status_code, body, headers) + + .. _`Section 4.1.4`: https://tools.ietf.org/html/rfc6749#section-4.1.4 + """ + client = self.request.client + authorization_code = self.request.authorization_code + + user = self.authenticate_user(authorization_code) + if not user: + raise InvalidGrantError('There is no "user" for this code.') + self.request.user = user + + scope = authorization_code.get_scope() + token = self.generate_token( + user=user, + scope=scope, + include_refresh_token=client.check_grant_type('refresh_token'), + ) + log.debug('Issue token %r to %r', token, client) + + self.save_token(token) + self.execute_hook('process_token', token=token) + self.delete_authorization_code(authorization_code) + return 200, token, self.TOKEN_RESPONSE_HEADER + + def generate_authorization_code(self): + """"The method to generate "code" value for authorization code data. + Developers may rewrite this method, or customize the code length with:: + + class MyAuthorizationCodeGrant(AuthorizationCodeGrant): + AUTHORIZATION_CODE_LENGTH = 32 # default is 48 + """ + return generate_token(self.AUTHORIZATION_CODE_LENGTH) + + def save_authorization_code(self, code, request): + """Save authorization_code for later use. Developers MUST implement + it in subclass. Here is an example:: + + def save_authorization_code(self, code, request): + client = request.client + item = AuthorizationCode( + code=code, + client_id=client.client_id, + redirect_uri=request.redirect_uri, + scope=request.scope, + user_id=request.user.id, + ) + item.save() + """ + raise NotImplementedError() + + def query_authorization_code(self, code, client): # pragma: no cover + """Get authorization_code from previously savings. Developers MUST + implement it in subclass:: + + def query_authorization_code(self, code, client): + return Authorization.get(code=code, client_id=client.client_id) + + :param code: a string represent the code. + :param client: client related to this code. + :return: authorization_code object + """ + raise NotImplementedError() + + def delete_authorization_code(self, authorization_code): + """Delete authorization code from database or cache. Developers MUST + implement it in subclass, e.g.:: + + def delete_authorization_code(self, authorization_code): + authorization_code.delete() + + :param authorization_code: the instance of authorization_code + """ + raise NotImplementedError() + + def authenticate_user(self, authorization_code): + """Authenticate the user related to this authorization_code. Developers + MUST implement this method in subclass, e.g.:: + + def authenticate_user(self, authorization_code): + return User.get(authorization_code.user_id) + + :param authorization_code: AuthorizationCode object + :return: user + """ + raise NotImplementedError() + + +def validate_code_authorization_request(grant): + request = grant.request + client_id = request.client_id + log.debug('Validate authorization request of %r', client_id) + + if client_id is None: + raise InvalidClientError(state=request.state) + + client = grant.server.query_client(client_id) + if not client: + raise InvalidClientError(state=request.state) + + redirect_uri = grant.validate_authorization_redirect_uri(request, client) + response_type = request.response_type + if not client.check_response_type(response_type): + raise UnauthorizedClientError( + f'The client is not authorized to use "response_type={response_type}"', + state=grant.request.state, + redirect_uri=redirect_uri, + ) + + try: + grant.request.client = client + grant.validate_requested_scope() + grant.execute_hook('after_validate_authorization_request') + except OAuth2Error as error: + error.redirect_uri = redirect_uri + raise error + return redirect_uri diff --git a/.venv/Lib/site-packages/authlib/oauth2/rfc6749/grants/base.py b/.venv/Lib/site-packages/authlib/oauth2/rfc6749/grants/base.py new file mode 100644 index 00000000..0d2bf453 --- /dev/null +++ b/.venv/Lib/site-packages/authlib/oauth2/rfc6749/grants/base.py @@ -0,0 +1,148 @@ +from authlib.consts import default_json_headers +from ..requests import OAuth2Request +from ..errors import InvalidRequestError + + +class BaseGrant: + #: Allowed client auth methods for token endpoint + TOKEN_ENDPOINT_AUTH_METHODS = ['client_secret_basic'] + + #: Designed for which "grant_type" + GRANT_TYPE = None + + # NOTE: there is no charset for application/json, since + # application/json should always in UTF-8. + # The example on RFC is incorrect. + # https://tools.ietf.org/html/rfc4627 + TOKEN_RESPONSE_HEADER = default_json_headers + + def __init__(self, request: OAuth2Request, server): + self.prompt = None + self.redirect_uri = None + self.request = request + self.server = server + self._hooks = { + 'after_validate_authorization_request': set(), + 'after_validate_consent_request': set(), + 'after_validate_token_request': set(), + 'process_token': set(), + } + + @property + def client(self): + return self.request.client + + def generate_token(self, user=None, scope=None, grant_type=None, + expires_in=None, include_refresh_token=True): + if grant_type is None: + grant_type = self.GRANT_TYPE + return self.server.generate_token( + client=self.request.client, + grant_type=grant_type, + user=user, + scope=scope, + expires_in=expires_in, + include_refresh_token=include_refresh_token, + ) + + def authenticate_token_endpoint_client(self): + """Authenticate client with the given methods for token endpoint. + + For example, the client makes the following HTTP request using TLS: + + .. code-block:: http + + POST /token HTTP/1.1 + Host: server.example.com + Authorization: Basic czZCaGRSa3F0MzpnWDFmQmF0M2JW + Content-Type: application/x-www-form-urlencoded + + grant_type=authorization_code&code=SplxlOBeZQQYbYS6WxSbIA + &redirect_uri=https%3A%2F%2Fclient%2Eexample%2Ecom%2Fcb + + Default available methods are: "none", "client_secret_basic" and + "client_secret_post". + + :return: client + """ + client = self.server.authenticate_client( + self.request, self.TOKEN_ENDPOINT_AUTH_METHODS) + self.server.send_signal( + 'after_authenticate_client', + client=client, grant=self) + return client + + def save_token(self, token): + """A method to save token into database.""" + return self.server.save_token(token, self.request) + + def validate_requested_scope(self): + """Validate if requested scope is supported by Authorization Server.""" + scope = self.request.scope + state = self.request.state + return self.server.validate_requested_scope(scope, state) + + def register_hook(self, hook_type, hook): + if hook_type not in self._hooks: + raise ValueError('Hook type %s is not in %s.', + hook_type, self._hooks) + self._hooks[hook_type].add(hook) + + def execute_hook(self, hook_type, *args, **kwargs): + for hook in self._hooks[hook_type]: + hook(self, *args, **kwargs) + + +class TokenEndpointMixin: + #: Allowed HTTP methods of this token endpoint + TOKEN_ENDPOINT_HTTP_METHODS = ['POST'] + + #: Designed for which "grant_type" + GRANT_TYPE = None + + @classmethod + def check_token_endpoint(cls, request: OAuth2Request): + return request.grant_type == cls.GRANT_TYPE and \ + request.method in cls.TOKEN_ENDPOINT_HTTP_METHODS + + def validate_token_request(self): + raise NotImplementedError() + + def create_token_response(self): + raise NotImplementedError() + + +class AuthorizationEndpointMixin: + RESPONSE_TYPES = set() + ERROR_RESPONSE_FRAGMENT = False + + @classmethod + def check_authorization_endpoint(cls, request: OAuth2Request): + return request.response_type in cls.RESPONSE_TYPES + + @staticmethod + def validate_authorization_redirect_uri(request: OAuth2Request, client): + if request.redirect_uri: + if not client.check_redirect_uri(request.redirect_uri): + raise InvalidRequestError( + f'Redirect URI {request.redirect_uri} is not supported by client.', + state=request.state) + return request.redirect_uri + else: + redirect_uri = client.get_default_redirect_uri() + if not redirect_uri: + raise InvalidRequestError( + 'Missing "redirect_uri" in request.', + state=request.state) + return redirect_uri + + def validate_consent_request(self): + redirect_uri = self.validate_authorization_request() + self.execute_hook('after_validate_consent_request', redirect_uri) + self.redirect_uri = redirect_uri + + def validate_authorization_request(self): + raise NotImplementedError() + + def create_authorization_response(self, redirect_uri: str, grant_user): + raise NotImplementedError() diff --git a/.venv/Lib/site-packages/authlib/oauth2/rfc6749/grants/client_credentials.py b/.venv/Lib/site-packages/authlib/oauth2/rfc6749/grants/client_credentials.py new file mode 100644 index 00000000..57249cba --- /dev/null +++ b/.venv/Lib/site-packages/authlib/oauth2/rfc6749/grants/client_credentials.py @@ -0,0 +1,102 @@ +import logging +from .base import BaseGrant, TokenEndpointMixin +from ..errors import UnauthorizedClientError + +log = logging.getLogger(__name__) + + +class ClientCredentialsGrant(BaseGrant, TokenEndpointMixin): + """The client can request an access token using only its client + credentials (or other supported means of authentication) when the + client is requesting access to the protected resources under its + control, or those of another resource owner that have been previously + arranged with the authorization server. + + The client credentials grant type MUST only be used by confidential + clients:: + + +---------+ +---------------+ + | | | | + | |>--(A)- Client Authentication --->| Authorization | + | Client | | Server | + | |<--(B)---- Access Token ---------<| | + | | | | + +---------+ +---------------+ + + https://tools.ietf.org/html/rfc6749#section-4.4 + """ + GRANT_TYPE = 'client_credentials' + + def validate_token_request(self): + """The client makes a request to the token endpoint by adding the + following parameters using the "application/x-www-form-urlencoded" + format per Appendix B with a character encoding of UTF-8 in the HTTP + request entity-body: + + grant_type + REQUIRED. Value MUST be set to "client_credentials". + + scope + OPTIONAL. The scope of the access request as described by + Section 3.3. + + The client MUST authenticate with the authorization server as + described in Section 3.2.1. + + For example, the client makes the following HTTP request using + transport-layer security (with extra line breaks for display purposes + only): + + .. code-block:: http + + POST /token HTTP/1.1 + Host: server.example.com + Authorization: Basic czZCaGRSa3F0MzpnWDFmQmF0M2JW + Content-Type: application/x-www-form-urlencoded + + grant_type=client_credentials + + The authorization server MUST authenticate the client. + """ + + # ignore validate for grant_type, since it is validated by + # check_token_endpoint + client = self.authenticate_token_endpoint_client() + log.debug('Validate token request of %r', client) + + if not client.check_grant_type(self.GRANT_TYPE): + raise UnauthorizedClientError() + + self.request.client = client + self.validate_requested_scope() + + def create_token_response(self): + """If the access token request is valid and authorized, the + authorization server issues an access token as described in + Section 5.1. A refresh token SHOULD NOT be included. If the request + failed client authentication or is invalid, the authorization server + returns an error response as described in Section 5.2. + + An example successful response: + + .. code-block:: http + + HTTP/1.1 200 OK + Content-Type: application/json + Cache-Control: no-store + Pragma: no-cache + + { + "access_token":"2YotnFZFEjr1zCsicMWpAA", + "token_type":"example", + "expires_in":3600, + "example_parameter":"example_value" + } + + :returns: (status_code, body, headers) + """ + token = self.generate_token(scope=self.request.scope, include_refresh_token=False) + log.debug('Issue token %r to %r', token, self.client) + self.save_token(token) + self.execute_hook('process_token', self, token=token) + return 200, token, self.TOKEN_RESPONSE_HEADER diff --git a/.venv/Lib/site-packages/authlib/oauth2/rfc6749/grants/implicit.py b/.venv/Lib/site-packages/authlib/oauth2/rfc6749/grants/implicit.py new file mode 100644 index 00000000..75b12be4 --- /dev/null +++ b/.venv/Lib/site-packages/authlib/oauth2/rfc6749/grants/implicit.py @@ -0,0 +1,229 @@ +import logging +from authlib.common.urls import add_params_to_uri +from .base import BaseGrant, AuthorizationEndpointMixin +from ..errors import ( + OAuth2Error, + UnauthorizedClientError, + AccessDeniedError, +) + +log = logging.getLogger(__name__) + + +class ImplicitGrant(BaseGrant, AuthorizationEndpointMixin): + """The implicit grant type is used to obtain access tokens (it does not + support the issuance of refresh tokens) and is optimized for public + clients known to operate a particular redirection URI. These clients + are typically implemented in a browser using a scripting language + such as JavaScript. + + Since this is a redirection-based flow, the client must be capable of + interacting with the resource owner's user-agent (typically a web + browser) and capable of receiving incoming requests (via redirection) + from the authorization server. + + Unlike the authorization code grant type, in which the client makes + separate requests for authorization and for an access token, the + client receives the access token as the result of the authorization + request. + + The implicit grant type does not include client authentication, and + relies on the presence of the resource owner and the registration of + the redirection URI. Because the access token is encoded into the + redirection URI, it may be exposed to the resource owner and other + applications residing on the same device:: + + +----------+ + | Resource | + | Owner | + | | + +----------+ + ^ + | + (B) + +----|-----+ Client Identifier +---------------+ + | -+----(A)-- & Redirection URI --->| | + | User- | | Authorization | + | Agent -|----(B)-- User authenticates -->| Server | + | | | | + | |<---(C)--- Redirection URI ----<| | + | | with Access Token +---------------+ + | | in Fragment + | | +---------------+ + | |----(D)--- Redirection URI ---->| Web-Hosted | + | | without Fragment | Client | + | | | Resource | + | (F) |<---(E)------- Script ---------<| | + | | +---------------+ + +-|--------+ + | | + (A) (G) Access Token + | | + ^ v + +---------+ + | | + | Client | + | | + +---------+ + """ + #: authorization_code grant type has authorization endpoint + AUTHORIZATION_ENDPOINT = True + #: Allowed client auth methods for token endpoint + TOKEN_ENDPOINT_AUTH_METHODS = ['none'] + + RESPONSE_TYPES = {'token'} + GRANT_TYPE = 'implicit' + ERROR_RESPONSE_FRAGMENT = True + + def validate_authorization_request(self): + """The client constructs the request URI by adding the following + parameters to the query component of the authorization endpoint URI + using the "application/x-www-form-urlencoded" format. + Per `Section 4.2.1`_. + + response_type + REQUIRED. Value MUST be set to "token". + + client_id + REQUIRED. The client identifier as described in Section 2.2. + + redirect_uri + OPTIONAL. As described in Section 3.1.2. + + scope + OPTIONAL. The scope of the access request as described by + Section 3.3. + + state + RECOMMENDED. An opaque value used by the client to maintain + state between the request and callback. The authorization + server includes this value when redirecting the user-agent back + to the client. The parameter SHOULD be used for preventing + cross-site request forgery as described in Section 10.12. + + The client directs the resource owner to the constructed URI using an + HTTP redirection response, or by other means available to it via the + user-agent. + + For example, the client directs the user-agent to make the following + HTTP request using TLS: + + .. code-block:: http + + GET /authorize?response_type=token&client_id=s6BhdRkqt3&state=xyz + &redirect_uri=https%3A%2F%2Fclient%2Eexample%2Ecom%2Fcb HTTP/1.1 + Host: server.example.com + + .. _`Section 4.2.1`: https://tools.ietf.org/html/rfc6749#section-4.2.1 + """ + # ignore validate for response_type, since it is validated by + # check_authorization_endpoint + + # The implicit grant type is optimized for public clients + client = self.authenticate_token_endpoint_client() + log.debug('Validate authorization request of %r', client) + + redirect_uri = self.validate_authorization_redirect_uri( + self.request, client) + + response_type = self.request.response_type + if not client.check_response_type(response_type): + raise UnauthorizedClientError( + 'The client is not authorized to use ' + '"response_type={}"'.format(response_type), + state=self.request.state, + redirect_uri=redirect_uri, + redirect_fragment=True, + ) + + try: + self.request.client = client + self.validate_requested_scope() + self.execute_hook('after_validate_authorization_request') + except OAuth2Error as error: + error.redirect_uri = redirect_uri + error.redirect_fragment = True + raise error + return redirect_uri + + def create_authorization_response(self, redirect_uri, grant_user): + """If the resource owner grants the access request, the authorization + server issues an access token and delivers it to the client by adding + the following parameters to the fragment component of the redirection + URI using the "application/x-www-form-urlencoded" format. + Per `Section 4.2.2`_. + + access_token + REQUIRED. The access token issued by the authorization server. + + token_type + REQUIRED. The type of the token issued as described in + Section 7.1. Value is case insensitive. + + expires_in + RECOMMENDED. The lifetime in seconds of the access token. For + example, the value "3600" denotes that the access token will + expire in one hour from the time the response was generated. + If omitted, the authorization server SHOULD provide the + expiration time via other means or document the default value. + + scope + OPTIONAL, if identical to the scope requested by the client; + otherwise, REQUIRED. The scope of the access token as + described by Section 3.3. + + state + REQUIRED if the "state" parameter was present in the client + authorization request. The exact value received from the + client. + + The authorization server MUST NOT issue a refresh token. + + For example, the authorization server redirects the user-agent by + sending the following HTTP response: + + .. code-block:: http + + HTTP/1.1 302 Found + Location: http://example.com/cb#access_token=2YotnFZFEjr1zCsicMWpAA + &state=xyz&token_type=example&expires_in=3600 + + Developers should note that some user-agents do not support the + inclusion of a fragment component in the HTTP "Location" response + header field. Such clients will require using other methods for + redirecting the client than a 3xx redirection response -- for + example, returning an HTML page that includes a 'continue' button + with an action linked to the redirection URI. + + .. _`Section 4.2.2`: https://tools.ietf.org/html/rfc6749#section-4.2.2 + + :param redirect_uri: Redirect to the given URI for the authorization + :param grant_user: if resource owner granted the request, pass this + resource owner, otherwise pass None. + :returns: (status_code, body, headers) + """ + state = self.request.state + if grant_user: + self.request.user = grant_user + token = self.generate_token( + user=grant_user, + scope=self.request.scope, + include_refresh_token=False, + ) + log.debug('Grant token %r to %r', token, self.request.client) + + self.save_token(token) + self.execute_hook('process_token', token=token) + params = [(k, token[k]) for k in token] + if state: + params.append(('state', state)) + + uri = add_params_to_uri(redirect_uri, params, fragment=True) + headers = [('Location', uri)] + return 302, '', headers + else: + raise AccessDeniedError( + state=state, + redirect_uri=redirect_uri, + redirect_fragment=True + ) diff --git a/.venv/Lib/site-packages/authlib/oauth2/rfc6749/grants/refresh_token.py b/.venv/Lib/site-packages/authlib/oauth2/rfc6749/grants/refresh_token.py new file mode 100644 index 00000000..4df5b70e --- /dev/null +++ b/.venv/Lib/site-packages/authlib/oauth2/rfc6749/grants/refresh_token.py @@ -0,0 +1,179 @@ +""" + authlib.oauth2.rfc6749.grants.refresh_token + ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + A special grant endpoint for refresh_token grant_type. Refreshing an + Access Token per `Section 6`_. + + .. _`Section 6`: https://tools.ietf.org/html/rfc6749#section-6 +""" + +import logging +from .base import BaseGrant, TokenEndpointMixin +from ..util import scope_to_list +from ..errors import ( + InvalidRequestError, + InvalidScopeError, + InvalidGrantError, + UnauthorizedClientError, +) +log = logging.getLogger(__name__) + + +class RefreshTokenGrant(BaseGrant, TokenEndpointMixin): + """A special grant endpoint for refresh_token grant_type. Refreshing an + Access Token per `Section 6`_. + + .. _`Section 6`: https://tools.ietf.org/html/rfc6749#section-6 + """ + GRANT_TYPE = 'refresh_token' + + #: The authorization server MAY issue a new refresh token + INCLUDE_NEW_REFRESH_TOKEN = False + + def _validate_request_client(self): + # require client authentication for confidential clients or for any + # client that was issued client credentials (or with other + # authentication requirements) + client = self.authenticate_token_endpoint_client() + log.debug('Validate token request of %r', client) + + if not client.check_grant_type(self.GRANT_TYPE): + raise UnauthorizedClientError() + + return client + + def _validate_request_token(self, client): + refresh_token = self.request.form.get('refresh_token') + if refresh_token is None: + raise InvalidRequestError('Missing "refresh_token" in request.') + + token = self.authenticate_refresh_token(refresh_token) + if not token or not token.check_client(client): + raise InvalidGrantError() + return token + + def _validate_token_scope(self, token): + scope = self.request.scope + if not scope: + return + + original_scope = token.get_scope() + if not original_scope: + raise InvalidScopeError() + + original_scope = set(scope_to_list(original_scope)) + if not original_scope.issuperset(set(scope_to_list(scope))): + raise InvalidScopeError() + + def validate_token_request(self): + """If the authorization server issued a refresh token to the client, the + client makes a refresh request to the token endpoint by adding the + following parameters using the "application/x-www-form-urlencoded" + format per Appendix B with a character encoding of UTF-8 in the HTTP + request entity-body, per Section 6: + + grant_type + REQUIRED. Value MUST be set to "refresh_token". + + refresh_token + REQUIRED. The refresh token issued to the client. + + scope + OPTIONAL. The scope of the access request as described by + Section 3.3. The requested scope MUST NOT include any scope + not originally granted by the resource owner, and if omitted is + treated as equal to the scope originally granted by the + resource owner. + + + For example, the client makes the following HTTP request using + transport-layer security (with extra line breaks for display purposes + only): + + .. code-block:: http + + POST /token HTTP/1.1 + Host: server.example.com + Authorization: Basic czZCaGRSa3F0MzpnWDFmQmF0M2JW + Content-Type: application/x-www-form-urlencoded + + grant_type=refresh_token&refresh_token=tGzv3JOkF0XG5Qx2TlKWIA + """ + client = self._validate_request_client() + self.request.client = client + refresh_token = self._validate_request_token(client) + self._validate_token_scope(refresh_token) + self.request.refresh_token = refresh_token + + def create_token_response(self): + """If valid and authorized, the authorization server issues an access + token as described in Section 5.1. If the request failed + verification or is invalid, the authorization server returns an error + response as described in Section 5.2. + """ + refresh_token = self.request.refresh_token + user = self.authenticate_user(refresh_token) + if not user: + raise InvalidRequestError('There is no "user" for this token.') + + client = self.request.client + token = self.issue_token(user, refresh_token) + log.debug('Issue token %r to %r', token, client) + + self.request.user = user + self.save_token(token) + self.execute_hook('process_token', token=token) + self.revoke_old_credential(refresh_token) + return 200, token, self.TOKEN_RESPONSE_HEADER + + def issue_token(self, user, refresh_token): + scope = self.request.scope + if not scope: + scope = refresh_token.get_scope() + + token = self.generate_token( + user=user, + scope=scope, + include_refresh_token=self.INCLUDE_NEW_REFRESH_TOKEN, + ) + return token + + def authenticate_refresh_token(self, refresh_token): + """Get token information with refresh_token string. Developers MUST + implement this method in subclass:: + + def authenticate_refresh_token(self, refresh_token): + token = Token.get(refresh_token=refresh_token) + if token and not token.refresh_token_revoked: + return token + + :param refresh_token: The refresh token issued to the client + :return: token + """ + raise NotImplementedError() + + def authenticate_user(self, refresh_token): + """Authenticate the user related to this credential. Developers MUST + implement this method in subclass:: + + def authenticate_user(self, credential): + return User.get(credential.user_id) + + :param refresh_token: Token object + :return: user + """ + raise NotImplementedError() + + def revoke_old_credential(self, refresh_token): + """The authorization server MAY revoke the old refresh token after + issuing a new refresh token to the client. Developers MUST implement + this method in subclass:: + + def revoke_old_credential(self, refresh_token): + credential.revoked = True + credential.save() + + :param refresh_token: Token object + """ + raise NotImplementedError() diff --git a/.venv/Lib/site-packages/authlib/oauth2/rfc6749/grants/resource_owner_password_credentials.py b/.venv/Lib/site-packages/authlib/oauth2/rfc6749/grants/resource_owner_password_credentials.py new file mode 100644 index 00000000..41cabb62 --- /dev/null +++ b/.venv/Lib/site-packages/authlib/oauth2/rfc6749/grants/resource_owner_password_credentials.py @@ -0,0 +1,154 @@ +import logging +from .base import BaseGrant, TokenEndpointMixin +from ..errors import ( + UnauthorizedClientError, + InvalidRequestError, +) + +log = logging.getLogger(__name__) + + +class ResourceOwnerPasswordCredentialsGrant(BaseGrant, TokenEndpointMixin): + """The resource owner password credentials grant type is suitable in + cases where the resource owner has a trust relationship with the + client, such as the device operating system or a highly privileged + + application. The authorization server should take special care when + enabling this grant type and only allow it when other flows are not + viable. + + This grant type is suitable for clients capable of obtaining the + resource owner's credentials (username and password, typically using + an interactive form). It is also used to migrate existing clients + using direct authentication schemes such as HTTP Basic or Digest + authentication to OAuth by converting the stored credentials to an + access token:: + + +----------+ + | Resource | + | Owner | + | | + +----------+ + v + | Resource Owner + (A) Password Credentials + | + v + +---------+ +---------------+ + | |>--(B)---- Resource Owner ------->| | + | | Password Credentials | Authorization | + | Client | | Server | + | |<--(C)---- Access Token ---------<| | + | | (w/ Optional Refresh Token) | | + +---------+ +---------------+ + """ + GRANT_TYPE = 'password' + + def validate_token_request(self): + """The client makes a request to the token endpoint by adding the + following parameters using the "application/x-www-form-urlencoded" + format per Appendix B with a character encoding of UTF-8 in the HTTP + request entity-body: + + grant_type + REQUIRED. Value MUST be set to "password". + + username + REQUIRED. The resource owner username. + + password + REQUIRED. The resource owner password. + + scope + OPTIONAL. The scope of the access request as described by + Section 3.3. + + If the client type is confidential or the client was issued client + credentials (or assigned other authentication requirements), the + client MUST authenticate with the authorization server as described + in Section 3.2.1. + + For example, the client makes the following HTTP request using + transport-layer security (with extra line breaks for display purposes + only): + + .. code-block:: http + + POST /token HTTP/1.1 + Host: server.example.com + Authorization: Basic czZCaGRSa3F0MzpnWDFmQmF0M2JW + Content-Type: application/x-www-form-urlencoded + + grant_type=password&username=johndoe&password=A3ddj3w + """ + # ignore validate for grant_type, since it is validated by + # check_token_endpoint + client = self.authenticate_token_endpoint_client() + log.debug('Validate token request of %r', client) + + if not client.check_grant_type(self.GRANT_TYPE): + raise UnauthorizedClientError() + + params = self.request.form + if 'username' not in params: + raise InvalidRequestError('Missing "username" in request.') + if 'password' not in params: + raise InvalidRequestError('Missing "password" in request.') + + log.debug('Authenticate user of %r', params['username']) + user = self.authenticate_user( + params['username'], + params['password'] + ) + if not user: + raise InvalidRequestError( + 'Invalid "username" or "password" in request.', + ) + self.request.client = client + self.request.user = user + self.validate_requested_scope() + + def create_token_response(self): + """If the access token request is valid and authorized, the + authorization server issues an access token and optional refresh + token as described in Section 5.1. If the request failed client + authentication or is invalid, the authorization server returns an + error response as described in Section 5.2. + + An example successful response: + + .. code-block:: http + + HTTP/1.1 200 OK + Content-Type: application/json + Cache-Control: no-store + Pragma: no-cache + + { + "access_token":"2YotnFZFEjr1zCsicMWpAA", + "token_type":"example", + "expires_in":3600, + "refresh_token":"tGzv3JOkF0XG5Qx2TlKWIA", + "example_parameter":"example_value" + } + + :returns: (status_code, body, headers) + """ + user = self.request.user + scope = self.request.scope + token = self.generate_token(user=user, scope=scope) + log.debug('Issue token %r to %r', token, self.client) + self.save_token(token) + self.execute_hook('process_token', token=token) + return 200, token, self.TOKEN_RESPONSE_HEADER + + def authenticate_user(self, username, password): + """validate the resource owner password credentials using its + existing password validation algorithm:: + + def authenticate_user(self, username, password): + user = get_user_by_username(username) + if user.check_password(password): + return user + """ + raise NotImplementedError() diff --git a/.venv/Lib/site-packages/authlib/oauth2/rfc6749/models.py b/.venv/Lib/site-packages/authlib/oauth2/rfc6749/models.py new file mode 100644 index 00000000..fe4922bb --- /dev/null +++ b/.venv/Lib/site-packages/authlib/oauth2/rfc6749/models.py @@ -0,0 +1,228 @@ +""" + authlib.oauth2.rfc6749.models + ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + This module defines how to construct Client, AuthorizationCode and Token. +""" +from authlib.deprecate import deprecate + + +class ClientMixin: + """Implementation of OAuth 2 Client described in `Section 2`_ with + some methods to help validation. A client has at least these information: + + * client_id: A string represents client identifier. + * client_secret: A string represents client password. + * token_endpoint_auth_method: A way to authenticate client at token + endpoint. + + .. _`Section 2`: https://tools.ietf.org/html/rfc6749#section-2 + """ + + def get_client_id(self): + """A method to return client_id of the client. For instance, the value + in database is saved in a column called ``client_id``:: + + def get_client_id(self): + return self.client_id + + :return: string + """ + raise NotImplementedError() + + def get_default_redirect_uri(self): + """A method to get client default redirect_uri. For instance, the + database table for client has a column called ``default_redirect_uri``:: + + def get_default_redirect_uri(self): + return self.default_redirect_uri + + :return: A URL string + """ + raise NotImplementedError() + + def get_allowed_scope(self, scope): + """A method to return a list of requested scopes which are supported by + this client. For instance, there is a ``scope`` column:: + + def get_allowed_scope(self, scope): + if not scope: + return '' + allowed = set(scope_to_list(self.scope)) + return list_to_scope([s for s in scope.split() if s in allowed]) + + :param scope: the requested scope. + :return: string of scope + """ + raise NotImplementedError() + + def check_redirect_uri(self, redirect_uri): + """Validate redirect_uri parameter in Authorization Endpoints. For + instance, in the client table, there is an ``allowed_redirect_uris`` + column:: + + def check_redirect_uri(self, redirect_uri): + return redirect_uri in self.allowed_redirect_uris + + :param redirect_uri: A URL string for redirecting. + :return: bool + """ + raise NotImplementedError() + + def check_client_secret(self, client_secret): + """Check client_secret matching with the client. For instance, in + the client table, the column is called ``client_secret``:: + + import secrets + + def check_client_secret(self, client_secret): + return secrets.compare_digest(self.client_secret, client_secret) + + :param client_secret: A string of client secret + :return: bool + """ + raise NotImplementedError() + + def check_endpoint_auth_method(self, method, endpoint): + """Check if client support the given method for the given endpoint. + There is a ``token_endpoint_auth_method`` defined via `RFC7591`_. + Developers MAY re-implement this method with:: + + def check_endpoint_auth_method(self, method, endpoint): + if endpoint == 'token': + # if client table has ``token_endpoint_auth_method`` + return self.token_endpoint_auth_method == method + return True + + Method values defined by this specification are: + + * "none": The client is a public client as defined in OAuth 2.0, + and does not have a client secret. + + * "client_secret_post": The client uses the HTTP POST parameters + as defined in OAuth 2.0 + + * "client_secret_basic": The client uses HTTP Basic as defined in + OAuth 2.0 + + .. _`RFC7591`: https://tools.ietf.org/html/rfc7591 + """ + raise NotImplementedError() + + def check_token_endpoint_auth_method(self, method): + deprecate('Please implement ``check_endpoint_auth_method`` instead.') + return self.check_endpoint_auth_method(method, 'token') + + def check_response_type(self, response_type): + """Validate if the client can handle the given response_type. There + are two response types defined by RFC6749: code and token. For + instance, there is a ``allowed_response_types`` column in your client:: + + def check_response_type(self, response_type): + return response_type in self.response_types + + :param response_type: the requested response_type string. + :return: bool + """ + raise NotImplementedError() + + def check_grant_type(self, grant_type): + """Validate if the client can handle the given grant_type. There are + four grant types defined by RFC6749: + + * authorization_code + * implicit + * client_credentials + * password + + For instance, there is a ``allowed_grant_types`` column in your client:: + + def check_grant_type(self, grant_type): + return grant_type in self.grant_types + + :param grant_type: the requested grant_type string. + :return: bool + """ + raise NotImplementedError() + + +class AuthorizationCodeMixin: + def get_redirect_uri(self): + """A method to get authorization code's ``redirect_uri``. + For instance, the database table for authorization code has a + column called ``redirect_uri``:: + + def get_redirect_uri(self): + return self.redirect_uri + + :return: A URL string + """ + raise NotImplementedError() + + def get_scope(self): + """A method to get scope of the authorization code. For instance, + the column is called ``scope``:: + + def get_scope(self): + return self.scope + + :return: scope string + """ + raise NotImplementedError() + + +class TokenMixin: + def check_client(self, client): + """A method to check if this token is issued to the given client. + For instance, ``client_id`` is saved on token table:: + + def check_client(self, client): + return self.client_id == client.client_id + + :return: bool + """ + raise NotImplementedError() + + def get_scope(self): + """A method to get scope of the authorization code. For instance, + the column is called ``scope``:: + + def get_scope(self): + return self.scope + + :return: scope string + """ + raise NotImplementedError() + + def get_expires_in(self): + """A method to get the ``expires_in`` value of the token. e.g. + the column is called ``expires_in``:: + + def get_expires_in(self): + return self.expires_in + + :return: timestamp int + """ + raise NotImplementedError() + + def is_expired(self): + """A method to define if this token is expired. For instance, + there is a column ``expired_at`` in the table:: + + def is_expired(self): + return self.expired_at < now + + :return: boolean + """ + raise NotImplementedError() + + def is_revoked(self): + """A method to define if this token is revoked. For instance, + there is a boolean column ``revoked`` in the table:: + + def is_revoked(self): + return self.revoked + + :return: boolean + """ + raise NotImplementedError() diff --git a/.venv/Lib/site-packages/authlib/oauth2/rfc6749/parameters.py b/.venv/Lib/site-packages/authlib/oauth2/rfc6749/parameters.py new file mode 100644 index 00000000..8c3a5aa6 --- /dev/null +++ b/.venv/Lib/site-packages/authlib/oauth2/rfc6749/parameters.py @@ -0,0 +1,214 @@ +from authlib.common.urls import ( + urlparse, + add_params_to_uri, + add_params_to_qs, +) +from authlib.common.encoding import to_unicode +from .errors import ( + MissingCodeException, + MissingTokenException, + MissingTokenTypeException, + MismatchingStateException, +) +from .util import list_to_scope + + +def prepare_grant_uri(uri, client_id, response_type, redirect_uri=None, + scope=None, state=None, **kwargs): + """Prepare the authorization grant request URI. + + The client constructs the request URI by adding the following + parameters to the query component of the authorization endpoint URI + using the ``application/x-www-form-urlencoded`` format: + + :param uri: The authorize endpoint to fetch "code" or "token". + :param client_id: The client identifier as described in `Section 2.2`_. + :param response_type: To indicate which OAuth 2 grant/flow is required, + "code" and "token". + :param redirect_uri: The client provided URI to redirect back to after + authorization as described in `Section 3.1.2`_. + :param scope: The scope of the access request as described by + `Section 3.3`_. + :param state: An opaque value used by the client to maintain + state between the request and callback. The authorization + server includes this value when redirecting the user-agent + back to the client. The parameter SHOULD be used for + preventing cross-site request forgery as described in + `Section 10.12`_. + :param kwargs: Extra arguments to embed in the grant/authorization URL. + + An example of an authorization code grant authorization URL:: + + /authorize?response_type=code&client_id=s6BhdRkqt3&state=xyz + &redirect_uri=https%3A%2F%2Fclient%2Eexample%2Ecom%2Fcb + + .. _`Section 2.2`: https://tools.ietf.org/html/rfc6749#section-2.2 + .. _`Section 3.1.2`: https://tools.ietf.org/html/rfc6749#section-3.1.2 + .. _`Section 3.3`: https://tools.ietf.org/html/rfc6749#section-3.3 + .. _`section 10.12`: https://tools.ietf.org/html/rfc6749#section-10.12 + """ + params = [ + ('response_type', response_type), + ('client_id', client_id) + ] + + if redirect_uri: + params.append(('redirect_uri', redirect_uri)) + if scope: + params.append(('scope', list_to_scope(scope))) + if state: + params.append(('state', state)) + + for k in kwargs: + if kwargs[k] is not None: + params.append((to_unicode(k), kwargs[k])) + + return add_params_to_uri(uri, params) + + +def prepare_token_request(grant_type, body='', redirect_uri=None, **kwargs): + """Prepare the access token request. Per `Section 4.1.3`_. + + The client makes a request to the token endpoint by adding the + following parameters using the ``application/x-www-form-urlencoded`` + format in the HTTP request entity-body: + + :param grant_type: To indicate grant type being used, i.e. "password", + "authorization_code" or "client_credentials". + :param body: Existing request body to embed parameters in. + :param redirect_uri: If the "redirect_uri" parameter was included in the + authorization request as described in + `Section 4.1.1`_, and their values MUST be identical. + :param kwargs: Extra arguments to embed in the request body. + + An example of an authorization code token request body:: + + grant_type=authorization_code&code=SplxlOBeZQQYbYS6WxSbIA + &redirect_uri=https%3A%2F%2Fclient%2Eexample%2Ecom%2Fcb + + .. _`Section 4.1.1`: https://tools.ietf.org/html/rfc6749#section-4.1.1 + .. _`Section 4.1.3`: https://tools.ietf.org/html/rfc6749#section-4.1.3 + """ + params = [('grant_type', grant_type)] + + if redirect_uri: + params.append(('redirect_uri', redirect_uri)) + + if 'scope' in kwargs: + kwargs['scope'] = list_to_scope(kwargs['scope']) + + if grant_type == 'authorization_code' and 'code' not in kwargs: + raise MissingCodeException() + + for k in kwargs: + if kwargs[k]: + params.append((to_unicode(k), kwargs[k])) + + return add_params_to_qs(body, params) + + +def parse_authorization_code_response(uri, state=None): + """Parse authorization grant response URI into a dict. + + If the resource owner grants the access request, the authorization + server issues an authorization code and delivers it to the client by + adding the following parameters to the query component of the + redirection URI using the ``application/x-www-form-urlencoded`` format: + + **code** + REQUIRED. The authorization code generated by the + authorization server. The authorization code MUST expire + shortly after it is issued to mitigate the risk of leaks. A + maximum authorization code lifetime of 10 minutes is + RECOMMENDED. The client MUST NOT use the authorization code + more than once. If an authorization code is used more than + once, the authorization server MUST deny the request and SHOULD + revoke (when possible) all tokens previously issued based on + that authorization code. The authorization code is bound to + the client identifier and redirection URI. + + **state** + REQUIRED if the "state" parameter was present in the client + authorization request. The exact value received from the + client. + + :param uri: The full redirect URL back to the client. + :param state: The state parameter from the authorization request. + + For example, the authorization server redirects the user-agent by + sending the following HTTP response: + + .. code-block:: http + + HTTP/1.1 302 Found + Location: https://client.example.com/cb?code=SplxlOBeZQQYbYS6WxSbIA + &state=xyz + + """ + query = urlparse.urlparse(uri).query + params = dict(urlparse.parse_qsl(query)) + + if 'code' not in params: + raise MissingCodeException() + + params_state = params.get('state') + if state and params_state != state: + raise MismatchingStateException() + + return params + + +def parse_implicit_response(uri, state=None): + """Parse the implicit token response URI into a dict. + + If the resource owner grants the access request, the authorization + server issues an access token and delivers it to the client by adding + the following parameters to the fragment component of the redirection + URI using the ``application/x-www-form-urlencoded`` format: + + **access_token** + REQUIRED. The access token issued by the authorization server. + + **token_type** + REQUIRED. The type of the token issued as described in + Section 7.1. Value is case insensitive. + + **expires_in** + RECOMMENDED. The lifetime in seconds of the access token. For + example, the value "3600" denotes that the access token will + expire in one hour from the time the response was generated. + If omitted, the authorization server SHOULD provide the + expiration time via other means or document the default value. + + **scope** + OPTIONAL, if identical to the scope requested by the client, + otherwise REQUIRED. The scope of the access token as described + by Section 3.3. + + **state** + REQUIRED if the "state" parameter was present in the client + authorization request. The exact value received from the + client. + + Similar to the authorization code response, but with a full token provided + in the URL fragment: + + .. code-block:: http + + HTTP/1.1 302 Found + Location: http://example.com/cb#access_token=2YotnFZFEjr1zCsicMWpAA + &state=xyz&token_type=example&expires_in=3600 + """ + fragment = urlparse.urlparse(uri).fragment + params = dict(urlparse.parse_qsl(fragment, keep_blank_values=True)) + + if 'access_token' not in params: + raise MissingTokenException() + + if 'token_type' not in params: + raise MissingTokenTypeException() + + if state and params.get('state', None) != state: + raise MismatchingStateException() + + return params diff --git a/.venv/Lib/site-packages/authlib/oauth2/rfc6749/requests.py b/.venv/Lib/site-packages/authlib/oauth2/rfc6749/requests.py new file mode 100644 index 00000000..1c0e4859 --- /dev/null +++ b/.venv/Lib/site-packages/authlib/oauth2/rfc6749/requests.py @@ -0,0 +1,84 @@ +from authlib.common.encoding import json_loads +from authlib.common.urls import urlparse, url_decode +from .errors import InsecureTransportError + + +class OAuth2Request: + def __init__(self, method: str, uri: str, body=None, headers=None): + InsecureTransportError.check(uri) + #: HTTP method + self.method = method + self.uri = uri + self.body = body + #: HTTP headers + self.headers = headers or {} + + self.client = None + self.auth_method = None + self.user = None + self.authorization_code = None + self.refresh_token = None + self.credential = None + + @property + def args(self): + query = urlparse.urlparse(self.uri).query + return dict(url_decode(query)) + + @property + def form(self): + return self.body or {} + + @property + def data(self): + data = {} + data.update(self.args) + data.update(self.form) + return data + + @property + def client_id(self) -> str: + """The authorization server issues the registered client a client + identifier -- a unique string representing the registration + information provided by the client. The value is extracted from + request. + + :return: string + """ + return self.data.get('client_id') + + @property + def response_type(self) -> str: + rt = self.data.get('response_type') + if rt and ' ' in rt: + # sort multiple response types + return ' '.join(sorted(rt.split())) + return rt + + @property + def grant_type(self) -> str: + return self.form.get('grant_type') + + @property + def redirect_uri(self): + return self.data.get('redirect_uri') + + @property + def scope(self) -> str: + return self.data.get('scope') + + @property + def state(self): + return self.data.get('state') + + +class JsonRequest: + def __init__(self, method, uri, body=None, headers=None): + self.method = method + self.uri = uri + self.body = body + self.headers = headers or {} + + @property + def data(self): + return json_loads(self.body) diff --git a/.venv/Lib/site-packages/authlib/oauth2/rfc6749/resource_protector.py b/.venv/Lib/site-packages/authlib/oauth2/rfc6749/resource_protector.py new file mode 100644 index 00000000..60a85d80 --- /dev/null +++ b/.venv/Lib/site-packages/authlib/oauth2/rfc6749/resource_protector.py @@ -0,0 +1,140 @@ +""" + authlib.oauth2.rfc6749.resource_protector + ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + Implementation of Accessing Protected Resources per `Section 7`_. + + .. _`Section 7`: https://tools.ietf.org/html/rfc6749#section-7 +""" +from .util import scope_to_list +from .errors import MissingAuthorizationError, UnsupportedTokenTypeError + + +class TokenValidator: + """Base token validator class. Subclass this validator to register + into ResourceProtector instance. + """ + TOKEN_TYPE = 'bearer' + + def __init__(self, realm=None, **extra_attributes): + self.realm = realm + self.extra_attributes = extra_attributes + + @staticmethod + def scope_insufficient(token_scopes, required_scopes): + if not required_scopes: + return False + + token_scopes = scope_to_list(token_scopes) + if not token_scopes: + return True + + token_scopes = set(token_scopes) + for scope in required_scopes: + resource_scopes = set(scope_to_list(scope)) + if token_scopes.issuperset(resource_scopes): + return False + + return True + + def authenticate_token(self, token_string): + """A method to query token from database with the given token string. + Developers MUST re-implement this method. For instance:: + + def authenticate_token(self, token_string): + return get_token_from_database(token_string) + + :param token_string: A string to represent the access_token. + :return: token + """ + raise NotImplementedError() + + def validate_request(self, request): + """A method to validate if the HTTP request is valid or not. Developers MUST + re-implement this method. For instance, your server requires a + "X-Device-Version" in the header:: + + def validate_request(self, request): + if 'X-Device-Version' not in request.headers: + raise InvalidRequestError() + + Usually, you don't have to detect if the request is valid or not. If you have + to, you MUST re-implement this method. + + :param request: instance of HttpRequest + :raise: InvalidRequestError + """ + + def validate_token(self, token, scopes, request): + """A method to validate if the authorized token is valid, if it has the + permission on the given scopes. Developers MUST re-implement this method. + e.g, check if token is expired, revoked:: + + def validate_token(self, token, scopes, request): + if not token: + raise InvalidTokenError() + if token.is_expired() or token.is_revoked(): + raise InvalidTokenError() + if not match_token_scopes(token, scopes): + raise InsufficientScopeError() + """ + raise NotImplementedError() + + +class ResourceProtector: + def __init__(self): + self._token_validators = {} + self._default_realm = None + self._default_auth_type = None + + def register_token_validator(self, validator: TokenValidator): + """Register a token validator for a given Authorization type. + Authlib has a built-in BearerTokenValidator per rfc6750. + """ + if not self._default_auth_type: + self._default_realm = validator.realm + self._default_auth_type = validator.TOKEN_TYPE + + if validator.TOKEN_TYPE not in self._token_validators: + self._token_validators[validator.TOKEN_TYPE] = validator + + def get_token_validator(self, token_type): + """Get token validator from registry for the given token type.""" + validator = self._token_validators.get(token_type.lower()) + if not validator: + raise UnsupportedTokenTypeError(self._default_auth_type, self._default_realm) + return validator + + def parse_request_authorization(self, request): + """Parse the token and token validator from request Authorization header. + Here is an example of Authorization header:: + + Authorization: Bearer a-token-string + + This method will parse this header, if it can find the validator for + ``Bearer``, it will return the validator and ``a-token-string``. + + :return: validator, token_string + :raise: MissingAuthorizationError + :raise: UnsupportedTokenTypeError + """ + auth = request.headers.get('Authorization') + if not auth: + raise MissingAuthorizationError(self._default_auth_type, self._default_realm) + + # https://tools.ietf.org/html/rfc6749#section-7.1 + token_parts = auth.split(None, 1) + if len(token_parts) != 2: + raise UnsupportedTokenTypeError(self._default_auth_type, self._default_realm) + + token_type, token_string = token_parts + validator = self.get_token_validator(token_type) + return validator, token_string + + def validate_request(self, scopes, request, **kwargs): + """Validate the request and return a token.""" + validator, token_string = self.parse_request_authorization(request) + validator.validate_request(request) + token = validator.authenticate_token(token_string) + validator.validate_token(token, scopes, request, **kwargs) + return token diff --git a/.venv/Lib/site-packages/authlib/oauth2/rfc6749/token_endpoint.py b/.venv/Lib/site-packages/authlib/oauth2/rfc6749/token_endpoint.py new file mode 100644 index 00000000..0ede557f --- /dev/null +++ b/.venv/Lib/site-packages/authlib/oauth2/rfc6749/token_endpoint.py @@ -0,0 +1,32 @@ +class TokenEndpoint: + #: Endpoint name to be registered + ENDPOINT_NAME = None + #: Supported token types + SUPPORTED_TOKEN_TYPES = ('access_token', 'refresh_token') + #: Allowed client authenticate methods + CLIENT_AUTH_METHODS = ['client_secret_basic'] + + def __init__(self, server): + self.server = server + + def __call__(self, request): + # make it callable for authorization server + # ``create_endpoint_response`` + return self.create_endpoint_response(request) + + def create_endpoint_request(self, request): + return self.server.create_oauth2_request(request) + + def authenticate_endpoint_client(self, request): + """Authentication client for endpoint with ``CLIENT_AUTH_METHODS``. + """ + client = self.server.authenticate_client( + request, self.CLIENT_AUTH_METHODS, self.ENDPOINT_NAME) + request.client = client + return client + + def authenticate_token(self, request, client): + raise NotImplementedError() + + def create_endpoint_response(self, request): + raise NotImplementedError() diff --git a/.venv/Lib/site-packages/authlib/oauth2/rfc6749/util.py b/.venv/Lib/site-packages/authlib/oauth2/rfc6749/util.py new file mode 100644 index 00000000..a216fbf3 --- /dev/null +++ b/.venv/Lib/site-packages/authlib/oauth2/rfc6749/util.py @@ -0,0 +1,40 @@ +import base64 +import binascii +from authlib.common.encoding import to_unicode + + +def list_to_scope(scope): + """Convert a list of scopes to a space separated string.""" + if isinstance(scope, (set, tuple, list)): + return " ".join([to_unicode(s) for s in scope]) + if scope is None: + return scope + return to_unicode(scope) + + +def scope_to_list(scope): + """Convert a space separated string to a list of scopes.""" + if isinstance(scope, (tuple, list, set)): + return [to_unicode(s) for s in scope] + elif scope is None: + return None + return scope.strip().split() + + +def extract_basic_authorization(headers): + auth = headers.get('Authorization') + if not auth or ' ' not in auth: + return None, None + + auth_type, auth_token = auth.split(None, 1) + if auth_type.lower() != 'basic': + return None, None + + try: + query = to_unicode(base64.b64decode(auth_token)) + except (binascii.Error, TypeError): + return None, None + if ':' in query: + username, password = query.split(':', 1) + return username, password + return query, None diff --git a/.venv/Lib/site-packages/authlib/oauth2/rfc6749/wrappers.py b/.venv/Lib/site-packages/authlib/oauth2/rfc6749/wrappers.py new file mode 100644 index 00000000..2ecf8248 --- /dev/null +++ b/.venv/Lib/site-packages/authlib/oauth2/rfc6749/wrappers.py @@ -0,0 +1,23 @@ +import time + + +class OAuth2Token(dict): + def __init__(self, params): + if params.get('expires_at'): + params['expires_at'] = int(params['expires_at']) + elif params.get('expires_in'): + params['expires_at'] = int(time.time()) + \ + int(params['expires_in']) + super().__init__(params) + + def is_expired(self): + expires_at = self.get('expires_at') + if not expires_at: + return None + return expires_at < time.time() + + @classmethod + def from_dict(cls, token): + if isinstance(token, dict) and not isinstance(token, cls): + token = cls(token) + return token diff --git a/.venv/Lib/site-packages/authlib/oauth2/rfc6750/__init__.py b/.venv/Lib/site-packages/authlib/oauth2/rfc6750/__init__.py new file mode 100644 index 00000000..ef3880ba --- /dev/null +++ b/.venv/Lib/site-packages/authlib/oauth2/rfc6750/__init__.py @@ -0,0 +1,26 @@ +""" + authlib.oauth2.rfc6750 + ~~~~~~~~~~~~~~~~~~~~~~ + + This module represents a direct implementation of + The OAuth 2.0 Authorization Framework: Bearer Token Usage. + + https://tools.ietf.org/html/rfc6750 +""" + +from .errors import InvalidTokenError, InsufficientScopeError +from .parameters import add_bearer_token +from .token import BearerTokenGenerator +from .validator import BearerTokenValidator + +# TODO: add deprecation +BearerToken = BearerTokenGenerator + + +__all__ = [ + 'InvalidTokenError', 'InsufficientScopeError', + 'add_bearer_token', + 'BearerToken', + 'BearerTokenGenerator', + 'BearerTokenValidator', +] diff --git a/.venv/Lib/site-packages/authlib/oauth2/rfc6750/__pycache__/__init__.cpython-311.pyc b/.venv/Lib/site-packages/authlib/oauth2/rfc6750/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 00000000..0dd8b5a0 Binary files /dev/null and b/.venv/Lib/site-packages/authlib/oauth2/rfc6750/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/authlib/oauth2/rfc6750/__pycache__/errors.cpython-311.pyc b/.venv/Lib/site-packages/authlib/oauth2/rfc6750/__pycache__/errors.cpython-311.pyc new file mode 100644 index 00000000..d5678828 Binary files /dev/null and b/.venv/Lib/site-packages/authlib/oauth2/rfc6750/__pycache__/errors.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/authlib/oauth2/rfc6750/__pycache__/parameters.cpython-311.pyc b/.venv/Lib/site-packages/authlib/oauth2/rfc6750/__pycache__/parameters.cpython-311.pyc new file mode 100644 index 00000000..7adc265d Binary files /dev/null and b/.venv/Lib/site-packages/authlib/oauth2/rfc6750/__pycache__/parameters.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/authlib/oauth2/rfc6750/__pycache__/token.cpython-311.pyc b/.venv/Lib/site-packages/authlib/oauth2/rfc6750/__pycache__/token.cpython-311.pyc new file mode 100644 index 00000000..fd8ddb48 Binary files /dev/null and b/.venv/Lib/site-packages/authlib/oauth2/rfc6750/__pycache__/token.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/authlib/oauth2/rfc6750/__pycache__/validator.cpython-311.pyc b/.venv/Lib/site-packages/authlib/oauth2/rfc6750/__pycache__/validator.cpython-311.pyc new file mode 100644 index 00000000..230577d3 Binary files /dev/null and b/.venv/Lib/site-packages/authlib/oauth2/rfc6750/__pycache__/validator.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/authlib/oauth2/rfc6750/errors.py b/.venv/Lib/site-packages/authlib/oauth2/rfc6750/errors.py new file mode 100644 index 00000000..1be92a35 --- /dev/null +++ b/.venv/Lib/site-packages/authlib/oauth2/rfc6750/errors.py @@ -0,0 +1,80 @@ +""" + authlib.rfc6750.errors + ~~~~~~~~~~~~~~~~~~~~~~ + + OAuth Extensions Error Registration. When a request fails, + the resource server responds using the appropriate HTTP + status code and includes one of the following error codes + in the response. + + https://tools.ietf.org/html/rfc6750#section-6.2 + + :copyright: (c) 2017 by Hsiaoming Yang. +""" +from ..base import OAuth2Error + +__all__ = [ + 'InvalidTokenError', 'InsufficientScopeError' +] + + +class InvalidTokenError(OAuth2Error): + """The access token provided is expired, revoked, malformed, or + invalid for other reasons. The resource SHOULD respond with + the HTTP 401 (Unauthorized) status code. The client MAY + request a new access token and retry the protected resource + request. + + https://tools.ietf.org/html/rfc6750#section-3.1 + """ + error = 'invalid_token' + description = ( + 'The access token provided is expired, revoked, malformed, ' + 'or invalid for other reasons.' + ) + status_code = 401 + + def __init__(self, description=None, uri=None, status_code=None, + state=None, realm=None, **extra_attributes): + super().__init__( + description, uri, status_code, state) + self.realm = realm + self.extra_attributes = extra_attributes + + def get_headers(self): + """If the protected resource request does not include authentication + credentials or does not contain an access token that enables access + to the protected resource, the resource server MUST include the HTTP + "WWW-Authenticate" response header field; it MAY include it in + response to other conditions as well. + + https://tools.ietf.org/html/rfc6750#section-3 + """ + headers = super().get_headers() + + extras = [] + if self.realm: + extras.append(f'realm="{self.realm}"') + if self.extra_attributes: + extras.extend([f'{k}="{self.extra_attributes[k]}"' for k in self.extra_attributes]) + extras.append(f'error="{self.error}"') + error_description = self.get_error_description() + extras.append(f'error_description="{error_description}"') + headers.append( + ('WWW-Authenticate', 'Bearer ' + ', '.join(extras)) + ) + return headers + + +class InsufficientScopeError(OAuth2Error): + """The request requires higher privileges than provided by the + access token. The resource server SHOULD respond with the HTTP + 403 (Forbidden) status code and MAY include the "scope" + attribute with the scope necessary to access the protected + resource. + + https://tools.ietf.org/html/rfc6750#section-3.1 + """ + error = 'insufficient_scope' + description = 'The request requires higher privileges than provided by the access token.' + status_code = 403 diff --git a/.venv/Lib/site-packages/authlib/oauth2/rfc6750/parameters.py b/.venv/Lib/site-packages/authlib/oauth2/rfc6750/parameters.py new file mode 100644 index 00000000..8914a909 --- /dev/null +++ b/.venv/Lib/site-packages/authlib/oauth2/rfc6750/parameters.py @@ -0,0 +1,41 @@ +from authlib.common.urls import add_params_to_qs, add_params_to_uri + + +def add_to_uri(token, uri): + """Add a Bearer Token to the request URI. + Not recommended, use only if client can't use authorization header or body. + + http://www.example.com/path?access_token=h480djs93hd8 + """ + return add_params_to_uri(uri, [('access_token', token)]) + + +def add_to_headers(token, headers=None): + """Add a Bearer Token to the request URI. + Recommended method of passing bearer tokens. + + Authorization: Bearer h480djs93hd8 + """ + headers = headers or {} + headers['Authorization'] = f'Bearer {token}' + return headers + + +def add_to_body(token, body=None): + """Add a Bearer Token to the request body. + + access_token=h480djs93hd8 + """ + if body is None: + body = '' + return add_params_to_qs(body, [('access_token', token)]) + + +def add_bearer_token(token, uri, headers, body, placement='header'): + if placement in ('uri', 'url', 'query'): + uri = add_to_uri(token, uri) + elif placement in ('header', 'headers'): + headers = add_to_headers(token, headers) + elif placement == 'body': + body = add_to_body(token, body) + return uri, headers, body diff --git a/.venv/Lib/site-packages/authlib/oauth2/rfc6750/token.py b/.venv/Lib/site-packages/authlib/oauth2/rfc6750/token.py new file mode 100644 index 00000000..1ab4dc5b --- /dev/null +++ b/.venv/Lib/site-packages/authlib/oauth2/rfc6750/token.py @@ -0,0 +1,88 @@ +class BearerTokenGenerator: + """Bearer token generator which can create the payload for token response + by OAuth 2 server. A typical token response would be: + + .. code-block:: http + + HTTP/1.1 200 OK + Content-Type: application/json;charset=UTF-8 + Cache-Control: no-store + Pragma: no-cache + + { + "access_token":"mF_9.B5f-4.1JqM", + "token_type":"Bearer", + "expires_in":3600, + "refresh_token":"tGzv3JOkF0XG5Qx2TlKWIA" + } + """ + + #: default expires_in value + DEFAULT_EXPIRES_IN = 3600 + #: default expires_in value differentiate by grant_type + GRANT_TYPES_EXPIRES_IN = { + 'authorization_code': 864000, + 'implicit': 3600, + 'password': 864000, + 'client_credentials': 864000 + } + + def __init__(self, access_token_generator, + refresh_token_generator=None, + expires_generator=None): + self.access_token_generator = access_token_generator + self.refresh_token_generator = refresh_token_generator + self.expires_generator = expires_generator + + def _get_expires_in(self, client, grant_type): + if self.expires_generator is None: + expires_in = self.GRANT_TYPES_EXPIRES_IN.get( + grant_type, self.DEFAULT_EXPIRES_IN) + elif callable(self.expires_generator): + expires_in = self.expires_generator(client, grant_type) + elif isinstance(self.expires_generator, int): + expires_in = self.expires_generator + else: + expires_in = self.DEFAULT_EXPIRES_IN + return expires_in + + @staticmethod + def get_allowed_scope(client, scope): + if scope: + scope = client.get_allowed_scope(scope) + return scope + + def generate(self, grant_type, client, user=None, scope=None, + expires_in=None, include_refresh_token=True): + """Generate a bearer token for OAuth 2.0 authorization token endpoint. + + :param client: the client that making the request. + :param grant_type: current requested grant_type. + :param user: current authorized user. + :param expires_in: if provided, use this value as expires_in. + :param scope: current requested scope. + :param include_refresh_token: should refresh_token be included. + :return: Token dict + """ + scope = self.get_allowed_scope(client, scope) + access_token = self.access_token_generator( + client=client, grant_type=grant_type, user=user, scope=scope) + if expires_in is None: + expires_in = self._get_expires_in(client, grant_type) + + token = { + 'token_type': 'Bearer', + 'access_token': access_token, + } + if expires_in: + token['expires_in'] = expires_in + if include_refresh_token and self.refresh_token_generator: + token['refresh_token'] = self.refresh_token_generator( + client=client, grant_type=grant_type, user=user, scope=scope) + if scope: + token['scope'] = scope + return token + + def __call__(self, grant_type, client, user=None, scope=None, + expires_in=None, include_refresh_token=True): + return self.generate(grant_type, client, user, scope, expires_in, include_refresh_token) diff --git a/.venv/Lib/site-packages/authlib/oauth2/rfc6750/validator.py b/.venv/Lib/site-packages/authlib/oauth2/rfc6750/validator.py new file mode 100644 index 00000000..d4790145 --- /dev/null +++ b/.venv/Lib/site-packages/authlib/oauth2/rfc6750/validator.py @@ -0,0 +1,39 @@ +""" + authlib.oauth2.rfc6750.validator + ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + Validate Bearer Token for in request, scope and token. +""" + +from ..rfc6749 import TokenValidator +from .errors import ( + InvalidTokenError, + InsufficientScopeError +) + + +class BearerTokenValidator(TokenValidator): + TOKEN_TYPE = 'bearer' + + def authenticate_token(self, token_string): + """A method to query token from database with the given token string. + Developers MUST re-implement this method. For instance:: + + def authenticate_token(self, token_string): + return get_token_from_database(token_string) + + :param token_string: A string to represent the access_token. + :return: token + """ + raise NotImplementedError() + + def validate_token(self, token, scopes, request): + """Check if token is active and matches the requested scopes.""" + if not token: + raise InvalidTokenError(realm=self.realm, extra_attributes=self.extra_attributes) + if token.is_expired(): + raise InvalidTokenError(realm=self.realm, extra_attributes=self.extra_attributes) + if token.is_revoked(): + raise InvalidTokenError(realm=self.realm, extra_attributes=self.extra_attributes) + if self.scope_insufficient(token.get_scope(), scopes): + raise InsufficientScopeError() diff --git a/.venv/Lib/site-packages/authlib/oauth2/rfc7009/__init__.py b/.venv/Lib/site-packages/authlib/oauth2/rfc7009/__init__.py new file mode 100644 index 00000000..2b9c1202 --- /dev/null +++ b/.venv/Lib/site-packages/authlib/oauth2/rfc7009/__init__.py @@ -0,0 +1,14 @@ +""" + authlib.oauth2.rfc7009 + ~~~~~~~~~~~~~~~~~~~~~~ + + This module represents a direct implementation of + OAuth 2.0 Token Revocation. + + https://tools.ietf.org/html/rfc7009 +""" + +from .parameters import prepare_revoke_token_request +from .revocation import RevocationEndpoint + +__all__ = ['prepare_revoke_token_request', 'RevocationEndpoint'] diff --git a/.venv/Lib/site-packages/authlib/oauth2/rfc7009/__pycache__/__init__.cpython-311.pyc b/.venv/Lib/site-packages/authlib/oauth2/rfc7009/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 00000000..b7fa56c5 Binary files /dev/null and b/.venv/Lib/site-packages/authlib/oauth2/rfc7009/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/authlib/oauth2/rfc7009/__pycache__/parameters.cpython-311.pyc b/.venv/Lib/site-packages/authlib/oauth2/rfc7009/__pycache__/parameters.cpython-311.pyc new file mode 100644 index 00000000..9cf459b8 Binary files /dev/null and b/.venv/Lib/site-packages/authlib/oauth2/rfc7009/__pycache__/parameters.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/authlib/oauth2/rfc7009/__pycache__/revocation.cpython-311.pyc b/.venv/Lib/site-packages/authlib/oauth2/rfc7009/__pycache__/revocation.cpython-311.pyc new file mode 100644 index 00000000..bb048821 Binary files /dev/null and b/.venv/Lib/site-packages/authlib/oauth2/rfc7009/__pycache__/revocation.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/authlib/oauth2/rfc7009/parameters.py b/.venv/Lib/site-packages/authlib/oauth2/rfc7009/parameters.py new file mode 100644 index 00000000..2a829a75 --- /dev/null +++ b/.venv/Lib/site-packages/authlib/oauth2/rfc7009/parameters.py @@ -0,0 +1,25 @@ +from authlib.common.urls import add_params_to_qs + + +def prepare_revoke_token_request(token, token_type_hint=None, + body=None, headers=None): + """Construct request body and headers for revocation endpoint. + + :param token: access_token or refresh_token string. + :param token_type_hint: Optional, `access_token` or `refresh_token`. + :param body: current request body. + :param headers: current request headers. + :return: tuple of (body, headers) + + https://tools.ietf.org/html/rfc7009#section-2.1 + """ + params = [('token', token)] + if token_type_hint: + params.append(('token_type_hint', token_type_hint)) + + body = add_params_to_qs(body or '', params) + if headers is None: + headers = {} + + headers['Content-Type'] = 'application/x-www-form-urlencoded' + return body, headers diff --git a/.venv/Lib/site-packages/authlib/oauth2/rfc7009/revocation.py b/.venv/Lib/site-packages/authlib/oauth2/rfc7009/revocation.py new file mode 100644 index 00000000..f0984789 --- /dev/null +++ b/.venv/Lib/site-packages/authlib/oauth2/rfc7009/revocation.py @@ -0,0 +1,108 @@ +from authlib.consts import default_json_headers +from ..rfc6749 import TokenEndpoint +from ..rfc6749 import ( + InvalidRequestError, + UnsupportedTokenTypeError, +) + + +class RevocationEndpoint(TokenEndpoint): + """Implementation of revocation endpoint which is described in + `RFC7009`_. + + .. _RFC7009: https://tools.ietf.org/html/rfc7009 + """ + #: Endpoint name to be registered + ENDPOINT_NAME = 'revocation' + + def authenticate_token(self, request, client): + """The client constructs the request by including the following + parameters using the "application/x-www-form-urlencoded" format in + the HTTP request entity-body: + + token + REQUIRED. The token that the client wants to get revoked. + + token_type_hint + OPTIONAL. A hint about the type of the token submitted for + revocation. + """ + self.check_params(request, client) + token = self.query_token(request.form['token'], request.form.get('token_type_hint')) + if token and token.check_client(client): + return token + + def check_params(self, request, client): + if 'token' not in request.form: + raise InvalidRequestError() + + hint = request.form.get('token_type_hint') + if hint and hint not in self.SUPPORTED_TOKEN_TYPES: + raise UnsupportedTokenTypeError() + + def create_endpoint_response(self, request): + """Validate revocation request and create the response for revocation. + For example, a client may request the revocation of a refresh token + with the following request:: + + POST /revoke HTTP/1.1 + Host: server.example.com + Content-Type: application/x-www-form-urlencoded + Authorization: Basic czZCaGRSa3F0MzpnWDFmQmF0M2JW + + token=45ghiukldjahdnhzdauz&token_type_hint=refresh_token + + :returns: (status_code, body, headers) + """ + # The authorization server first validates the client credentials + client = self.authenticate_endpoint_client(request) + + # then verifies whether the token was issued to the client making + # the revocation request + token = self.authenticate_token(request, client) + + # the authorization server invalidates the token + if token: + self.revoke_token(token, request) + self.server.send_signal( + 'after_revoke_token', + token=token, + client=client, + ) + return 200, {}, default_json_headers + + def query_token(self, token_string, token_type_hint): + """Get the token from database/storage by the given token string. + Developers should implement this method:: + + def query_token(self, token_string, token_type_hint): + if token_type_hint == 'access_token': + return Token.query_by_access_token(token_string) + if token_type_hint == 'refresh_token': + return Token.query_by_refresh_token(token_string) + return Token.query_by_access_token(token_string) or \ + Token.query_by_refresh_token(token_string) + """ + raise NotImplementedError() + + def revoke_token(self, token, request): + """Mark token as revoked. Since token MUST be unique, it would be + dangerous to delete it. Consider this situation: + + 1. Jane obtained a token XYZ + 2. Jane revoked (deleted) token XYZ + 3. Bob generated a new token XYZ + 4. Jane can use XYZ to access Bob's resource + + It would be secure to mark a token as revoked:: + + def revoke_token(self, token, request): + hint = request.form.get('token_type_hint') + if hint == 'access_token': + token.access_token_revoked = True + else: + token.access_token_revoked = True + token.refresh_token_revoked = True + token.save() + """ + raise NotImplementedError() diff --git a/.venv/Lib/site-packages/authlib/oauth2/rfc7521/__init__.py b/.venv/Lib/site-packages/authlib/oauth2/rfc7521/__init__.py new file mode 100644 index 00000000..0dbe0b30 --- /dev/null +++ b/.venv/Lib/site-packages/authlib/oauth2/rfc7521/__init__.py @@ -0,0 +1,3 @@ +from .client import AssertionClient + +__all__ = ['AssertionClient'] diff --git a/.venv/Lib/site-packages/authlib/oauth2/rfc7521/__pycache__/__init__.cpython-311.pyc b/.venv/Lib/site-packages/authlib/oauth2/rfc7521/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 00000000..98eeef93 Binary files /dev/null and b/.venv/Lib/site-packages/authlib/oauth2/rfc7521/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/authlib/oauth2/rfc7521/__pycache__/client.cpython-311.pyc b/.venv/Lib/site-packages/authlib/oauth2/rfc7521/__pycache__/client.cpython-311.pyc new file mode 100644 index 00000000..3f8c34eb Binary files /dev/null and b/.venv/Lib/site-packages/authlib/oauth2/rfc7521/__pycache__/client.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/authlib/oauth2/rfc7521/client.py b/.venv/Lib/site-packages/authlib/oauth2/rfc7521/client.py new file mode 100644 index 00000000..e7ce2c3c --- /dev/null +++ b/.venv/Lib/site-packages/authlib/oauth2/rfc7521/client.py @@ -0,0 +1,91 @@ +from authlib.common.encoding import to_native +from authlib.oauth2.base import OAuth2Error + + +class AssertionClient: + """Constructs a new Assertion Framework for OAuth 2.0 Authorization Grants + per RFC7521_. + + .. _RFC7521: https://tools.ietf.org/html/rfc7521 + """ + DEFAULT_GRANT_TYPE = None + ASSERTION_METHODS = {} + token_auth_class = None + oauth_error_class = OAuth2Error + + def __init__(self, session, token_endpoint, issuer, subject, + audience=None, grant_type=None, claims=None, + token_placement='header', scope=None, **kwargs): + + self.session = session + + if audience is None: + audience = token_endpoint + + self.token_endpoint = token_endpoint + + if grant_type is None: + grant_type = self.DEFAULT_GRANT_TYPE + + self.grant_type = grant_type + + # https://tools.ietf.org/html/rfc7521#section-5.1 + self.issuer = issuer + self.subject = subject + self.audience = audience + self.claims = claims + self.scope = scope + if self.token_auth_class is not None: + self.token_auth = self.token_auth_class(None, token_placement, self) + self._kwargs = kwargs + + @property + def token(self): + return self.token_auth.token + + @token.setter + def token(self, token): + self.token_auth.set_token(token) + + def refresh_token(self): + """Using Assertions as Authorization Grants to refresh token as + described in `Section 4.1`_. + + .. _`Section 4.1`: https://tools.ietf.org/html/rfc7521#section-4.1 + """ + generate_assertion = self.ASSERTION_METHODS[self.grant_type] + assertion = generate_assertion( + issuer=self.issuer, + subject=self.subject, + audience=self.audience, + claims=self.claims, + **self._kwargs + ) + data = { + 'assertion': to_native(assertion), + 'grant_type': self.grant_type, + } + if self.scope: + data['scope'] = self.scope + + return self._refresh_token(data) + + def parse_response_token(self, resp): + if resp.status_code >= 500: + resp.raise_for_status() + + token = resp.json() + if 'error' in token: + raise self.oauth_error_class( + error=token['error'], + description=token.get('error_description') + ) + + self.token = token + return self.token + + def _refresh_token(self, data): + resp = self.session.request( + 'POST', self.token_endpoint, data=data, withhold_token=True) + + return self.parse_response_token(resp) diff --git a/.venv/Lib/site-packages/authlib/oauth2/rfc7523/__init__.py b/.venv/Lib/site-packages/authlib/oauth2/rfc7523/__init__.py new file mode 100644 index 00000000..ec9d3d32 --- /dev/null +++ b/.venv/Lib/site-packages/authlib/oauth2/rfc7523/__init__.py @@ -0,0 +1,37 @@ +""" + authlib.oauth2.rfc7523 + ~~~~~~~~~~~~~~~~~~~~~~ + + This module represents a direct implementation of + JSON Web Token (JWT) Profile for OAuth 2.0 Client + Authentication and Authorization Grants. + + https://tools.ietf.org/html/rfc7523 +""" + +from .jwt_bearer import JWTBearerGrant +from .client import ( + JWTBearerClientAssertion, +) +from .assertion import ( + client_secret_jwt_sign, + private_key_jwt_sign, +) +from .auth import ( + ClientSecretJWT, PrivateKeyJWT, +) +from .token import JWTBearerTokenGenerator +from .validator import JWTBearerToken, JWTBearerTokenValidator + +__all__ = [ + 'JWTBearerGrant', + 'JWTBearerClientAssertion', + 'client_secret_jwt_sign', + 'private_key_jwt_sign', + 'ClientSecretJWT', + 'PrivateKeyJWT', + + 'JWTBearerToken', + 'JWTBearerTokenGenerator', + 'JWTBearerTokenValidator', +] diff --git a/.venv/Lib/site-packages/authlib/oauth2/rfc7523/__pycache__/__init__.cpython-311.pyc b/.venv/Lib/site-packages/authlib/oauth2/rfc7523/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 00000000..379a23fd Binary files /dev/null and b/.venv/Lib/site-packages/authlib/oauth2/rfc7523/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/authlib/oauth2/rfc7523/__pycache__/assertion.cpython-311.pyc b/.venv/Lib/site-packages/authlib/oauth2/rfc7523/__pycache__/assertion.cpython-311.pyc new file mode 100644 index 00000000..d21c5058 Binary files /dev/null and b/.venv/Lib/site-packages/authlib/oauth2/rfc7523/__pycache__/assertion.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/authlib/oauth2/rfc7523/__pycache__/auth.cpython-311.pyc b/.venv/Lib/site-packages/authlib/oauth2/rfc7523/__pycache__/auth.cpython-311.pyc new file mode 100644 index 00000000..1bc9eb20 Binary files /dev/null and b/.venv/Lib/site-packages/authlib/oauth2/rfc7523/__pycache__/auth.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/authlib/oauth2/rfc7523/__pycache__/client.cpython-311.pyc b/.venv/Lib/site-packages/authlib/oauth2/rfc7523/__pycache__/client.cpython-311.pyc new file mode 100644 index 00000000..4f34f250 Binary files /dev/null and b/.venv/Lib/site-packages/authlib/oauth2/rfc7523/__pycache__/client.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/authlib/oauth2/rfc7523/__pycache__/jwt_bearer.cpython-311.pyc b/.venv/Lib/site-packages/authlib/oauth2/rfc7523/__pycache__/jwt_bearer.cpython-311.pyc new file mode 100644 index 00000000..145e2412 Binary files /dev/null and b/.venv/Lib/site-packages/authlib/oauth2/rfc7523/__pycache__/jwt_bearer.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/authlib/oauth2/rfc7523/__pycache__/token.cpython-311.pyc b/.venv/Lib/site-packages/authlib/oauth2/rfc7523/__pycache__/token.cpython-311.pyc new file mode 100644 index 00000000..f335590e Binary files /dev/null and b/.venv/Lib/site-packages/authlib/oauth2/rfc7523/__pycache__/token.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/authlib/oauth2/rfc7523/__pycache__/validator.cpython-311.pyc b/.venv/Lib/site-packages/authlib/oauth2/rfc7523/__pycache__/validator.cpython-311.pyc new file mode 100644 index 00000000..5746f5f8 Binary files /dev/null and b/.venv/Lib/site-packages/authlib/oauth2/rfc7523/__pycache__/validator.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/authlib/oauth2/rfc7523/assertion.py b/.venv/Lib/site-packages/authlib/oauth2/rfc7523/assertion.py new file mode 100644 index 00000000..0bb9fe7b --- /dev/null +++ b/.venv/Lib/site-packages/authlib/oauth2/rfc7523/assertion.py @@ -0,0 +1,66 @@ +import time +from authlib.jose import jwt +from authlib.common.security import generate_token + + +def sign_jwt_bearer_assertion( + key, issuer, audience, subject=None, issued_at=None, + expires_at=None, claims=None, header=None, **kwargs): + + if header is None: + header = {} + alg = kwargs.pop('alg', None) + if alg: + header['alg'] = alg + if 'alg' not in header: + raise ValueError('Missing "alg" in header') + + payload = {'iss': issuer, 'aud': audience} + + # subject is not required in Google service + if subject: + payload['sub'] = subject + + if not issued_at: + issued_at = int(time.time()) + + expires_in = kwargs.pop('expires_in', 3600) + if not expires_at: + expires_at = issued_at + expires_in + + payload['iat'] = issued_at + payload['exp'] = expires_at + + if claims: + payload.update(claims) + + return jwt.encode(header, payload, key) + + +def client_secret_jwt_sign(client_secret, client_id, token_endpoint, alg='HS256', + claims=None, **kwargs): + return _sign(client_secret, client_id, token_endpoint, alg, claims, **kwargs) + + +def private_key_jwt_sign(private_key, client_id, token_endpoint, alg='RS256', + claims=None, **kwargs): + return _sign(private_key, client_id, token_endpoint, alg, claims, **kwargs) + + +def _sign(key, client_id, token_endpoint, alg, claims=None, **kwargs): + # REQUIRED. Issuer. This MUST contain the client_id of the OAuth Client. + issuer = client_id + # REQUIRED. Subject. This MUST contain the client_id of the OAuth Client. + subject = client_id + # The Audience SHOULD be the URL of the Authorization Server's Token Endpoint. + audience = token_endpoint + + # jti is required + if claims is None: + claims = {} + if 'jti' not in claims: + claims['jti'] = generate_token(36) + + return sign_jwt_bearer_assertion( + key=key, issuer=issuer, audience=audience, subject=subject, + claims=claims, alg=alg, **kwargs) diff --git a/.venv/Lib/site-packages/authlib/oauth2/rfc7523/auth.py b/.venv/Lib/site-packages/authlib/oauth2/rfc7523/auth.py new file mode 100644 index 00000000..77644667 --- /dev/null +++ b/.venv/Lib/site-packages/authlib/oauth2/rfc7523/auth.py @@ -0,0 +1,94 @@ +from authlib.common.urls import add_params_to_qs +from .assertion import client_secret_jwt_sign, private_key_jwt_sign +from .client import ASSERTION_TYPE + + +class ClientSecretJWT: + """Authentication method for OAuth 2.0 Client. This authentication + method is called ``client_secret_jwt``, which is using ``client_id`` + and ``client_secret`` constructed with JWT to identify a client. + + Here is an example of use ``client_secret_jwt`` with Requests Session:: + + from authlib.integrations.requests_client import OAuth2Session + + token_endpoint = 'https://example.com/oauth/token' + session = OAuth2Session( + 'your-client-id', 'your-client-secret', + token_endpoint_auth_method='client_secret_jwt' + ) + session.register_client_auth_method(ClientSecretJWT(token_endpoint)) + session.fetch_token(token_endpoint) + + :param token_endpoint: A string URL of the token endpoint + :param claims: Extra JWT claims + :param headers: Extra JWT headers + :param alg: ``alg`` value, default is HS256 + """ + name = 'client_secret_jwt' + alg = 'HS256' + + def __init__(self, token_endpoint=None, claims=None, headers=None, alg=None): + self.token_endpoint = token_endpoint + self.claims = claims + self.headers = headers + if alg is not None: + self.alg = alg + + def sign(self, auth, token_endpoint): + return client_secret_jwt_sign( + auth.client_secret, + client_id=auth.client_id, + token_endpoint=token_endpoint, + claims=self.claims, + header=self.headers, + alg=self.alg, + ) + + def __call__(self, auth, method, uri, headers, body): + token_endpoint = self.token_endpoint + if not token_endpoint: + token_endpoint = uri + + client_assertion = self.sign(auth, token_endpoint) + body = add_params_to_qs(body or '', [ + ('client_assertion_type', ASSERTION_TYPE), + ('client_assertion', client_assertion) + ]) + return uri, headers, body + + +class PrivateKeyJWT(ClientSecretJWT): + """Authentication method for OAuth 2.0 Client. This authentication + method is called ``private_key_jwt``, which is using ``client_id`` + and ``private_key`` constructed with JWT to identify a client. + + Here is an example of use ``private_key_jwt`` with Requests Session:: + + from authlib.integrations.requests_client import OAuth2Session + + token_endpoint = 'https://example.com/oauth/token' + session = OAuth2Session( + 'your-client-id', 'your-client-private-key', + token_endpoint_auth_method='private_key_jwt' + ) + session.register_client_auth_method(PrivateKeyJWT(token_endpoint)) + session.fetch_token(token_endpoint) + + :param token_endpoint: A string URL of the token endpoint + :param claims: Extra JWT claims + :param headers: Extra JWT headers + :param alg: ``alg`` value, default is RS256 + """ + name = 'private_key_jwt' + alg = 'RS256' + + def sign(self, auth, token_endpoint): + return private_key_jwt_sign( + auth.client_secret, + client_id=auth.client_id, + token_endpoint=token_endpoint, + claims=self.claims, + header=self.headers, + alg=self.alg, + ) diff --git a/.venv/Lib/site-packages/authlib/oauth2/rfc7523/client.py b/.venv/Lib/site-packages/authlib/oauth2/rfc7523/client.py new file mode 100644 index 00000000..2a6a1bfc --- /dev/null +++ b/.venv/Lib/site-packages/authlib/oauth2/rfc7523/client.py @@ -0,0 +1,113 @@ +import logging +from authlib.jose import jwt +from authlib.jose.errors import JoseError +from ..rfc6749 import InvalidClientError + +ASSERTION_TYPE = 'urn:ietf:params:oauth:client-assertion-type:jwt-bearer' +log = logging.getLogger(__name__) + + +class JWTBearerClientAssertion: + """Implementation of Using JWTs for Client Authentication, which is + defined by RFC7523. + """ + #: Value of ``client_assertion_type`` of JWTs + CLIENT_ASSERTION_TYPE = ASSERTION_TYPE + #: Name of the client authentication method + CLIENT_AUTH_METHOD = 'client_assertion_jwt' + + def __init__(self, token_url, validate_jti=True): + self.token_url = token_url + self._validate_jti = validate_jti + + def __call__(self, query_client, request): + data = request.form + assertion_type = data.get('client_assertion_type') + assertion = data.get('client_assertion') + if assertion_type == ASSERTION_TYPE and assertion: + resolve_key = self.create_resolve_key_func(query_client, request) + self.process_assertion_claims(assertion, resolve_key) + return self.authenticate_client(request.client) + log.debug('Authenticate via %r failed', self.CLIENT_AUTH_METHOD) + + def create_claims_options(self): + """Create a claims_options for verify JWT payload claims. Developers + MAY overwrite this method to create a more strict options.""" + # https://tools.ietf.org/html/rfc7523#section-3 + # The Audience SHOULD be the URL of the Authorization Server's Token Endpoint + options = { + 'iss': {'essential': True, 'validate': _validate_iss}, + 'sub': {'essential': True}, + 'aud': {'essential': True, 'value': self.token_url}, + 'exp': {'essential': True}, + } + if self._validate_jti: + options['jti'] = {'essential': True, 'validate': self.validate_jti} + return options + + def process_assertion_claims(self, assertion, resolve_key): + """Extract JWT payload claims from request "assertion", per + `Section 3.1`_. + + :param assertion: assertion string value in the request + :param resolve_key: function to resolve the sign key + :return: JWTClaims + :raise: InvalidClientError + + .. _`Section 3.1`: https://tools.ietf.org/html/rfc7523#section-3.1 + """ + try: + claims = jwt.decode( + assertion, resolve_key, + claims_options=self.create_claims_options() + ) + claims.validate() + except JoseError as e: + log.debug('Assertion Error: %r', e) + raise InvalidClientError() + return claims + + def authenticate_client(self, client): + if client.check_endpoint_auth_method(self.CLIENT_AUTH_METHOD, 'token'): + return client + raise InvalidClientError() + + def create_resolve_key_func(self, query_client, request): + def resolve_key(headers, payload): + # https://tools.ietf.org/html/rfc7523#section-3 + # For client authentication, the subject MUST be the + # "client_id" of the OAuth client + client_id = payload['sub'] + client = query_client(client_id) + if not client: + raise InvalidClientError() + request.client = client + return self.resolve_client_public_key(client, headers) + return resolve_key + + def validate_jti(self, claims, jti): + """Validate if the given ``jti`` value is used before. Developers + MUST implement this method:: + + def validate_jti(self, claims, jti): + key = 'jti:{}-{}'.format(claims['sub'], jti) + if redis.get(key): + return False + redis.set(key, 1, ex=3600) + return True + """ + raise NotImplementedError() + + def resolve_client_public_key(self, client, headers): + """Resolve the client public key for verifying the JWT signature. + A client may have many public keys, in this case, we can retrieve it + via ``kid`` value in headers. Developers MUST implement this method:: + + def resolve_client_public_key(self, client, headers): + return client.public_key + """ + raise NotImplementedError() + + +def _validate_iss(claims, iss): + return claims['sub'] == iss diff --git a/.venv/Lib/site-packages/authlib/oauth2/rfc7523/jwt_bearer.py b/.venv/Lib/site-packages/authlib/oauth2/rfc7523/jwt_bearer.py new file mode 100644 index 00000000..fb672a92 --- /dev/null +++ b/.venv/Lib/site-packages/authlib/oauth2/rfc7523/jwt_bearer.py @@ -0,0 +1,182 @@ +import logging +from authlib.jose import jwt, JoseError +from ..rfc6749 import BaseGrant, TokenEndpointMixin +from ..rfc6749 import ( + UnauthorizedClientError, + InvalidRequestError, + InvalidGrantError, + InvalidClientError, +) +from .assertion import sign_jwt_bearer_assertion + +log = logging.getLogger(__name__) +JWT_BEARER_GRANT_TYPE = 'urn:ietf:params:oauth:grant-type:jwt-bearer' + + +class JWTBearerGrant(BaseGrant, TokenEndpointMixin): + GRANT_TYPE = JWT_BEARER_GRANT_TYPE + + #: Options for verifying JWT payload claims. Developers MAY + #: overwrite this constant to create a more strict options. + CLAIMS_OPTIONS = { + 'iss': {'essential': True}, + 'aud': {'essential': True}, + 'exp': {'essential': True}, + } + + @staticmethod + def sign(key, issuer, audience, subject=None, + issued_at=None, expires_at=None, claims=None, **kwargs): + return sign_jwt_bearer_assertion( + key, issuer, audience, subject, issued_at, + expires_at, claims, **kwargs) + + def process_assertion_claims(self, assertion): + """Extract JWT payload claims from request "assertion", per + `Section 3.1`_. + + :param assertion: assertion string value in the request + :return: JWTClaims + :raise: InvalidGrantError + + .. _`Section 3.1`: https://tools.ietf.org/html/rfc7523#section-3.1 + """ + try: + claims = jwt.decode( + assertion, self.resolve_public_key, + claims_options=self.CLAIMS_OPTIONS) + claims.validate() + except JoseError as e: + log.debug('Assertion Error: %r', e) + raise InvalidGrantError(description=e.description) + return claims + + def resolve_public_key(self, headers, payload): + client = self.resolve_issuer_client(payload['iss']) + return self.resolve_client_key(client, headers, payload) + + def validate_token_request(self): + """The client makes a request to the token endpoint by sending the + following parameters using the "application/x-www-form-urlencoded" + format per `Section 2.1`_: + + grant_type + REQUIRED. Value MUST be set to + "urn:ietf:params:oauth:grant-type:jwt-bearer". + + assertion + REQUIRED. Value MUST contain a single JWT. + + scope + OPTIONAL. + + The following example demonstrates an access token request with a JWT + as an authorization grant: + + .. code-block:: http + + POST /token.oauth2 HTTP/1.1 + Host: as.example.com + Content-Type: application/x-www-form-urlencoded + + grant_type=urn%3Aietf%3Aparams%3Aoauth%3Agrant-type%3Ajwt-bearer + &assertion=eyJhbGciOiJFUzI1NiIsImtpZCI6IjE2In0. + eyJpc3Mi[...omitted for brevity...]. + J9l-ZhwP[...omitted for brevity...] + + .. _`Section 2.1`: https://tools.ietf.org/html/rfc7523#section-2.1 + """ + assertion = self.request.form.get('assertion') + if not assertion: + raise InvalidRequestError('Missing "assertion" in request') + + claims = self.process_assertion_claims(assertion) + client = self.resolve_issuer_client(claims['iss']) + log.debug('Validate token request of %s', client) + + if not client.check_grant_type(self.GRANT_TYPE): + raise UnauthorizedClientError() + + self.request.client = client + self.validate_requested_scope() + + subject = claims.get('sub') + if subject: + user = self.authenticate_user(subject) + if not user: + raise InvalidGrantError(description='Invalid "sub" value in assertion') + + log.debug('Check client(%s) permission to User(%s)', client, user) + if not self.has_granted_permission(client, user): + raise InvalidClientError( + description='Client has no permission to access user data') + self.request.user = user + + def create_token_response(self): + """If valid and authorized, the authorization server issues an access + token. + """ + token = self.generate_token( + scope=self.request.scope, + user=self.request.user, + include_refresh_token=False, + ) + log.debug('Issue token %r to %r', token, self.request.client) + self.save_token(token) + return 200, token, self.TOKEN_RESPONSE_HEADER + + def resolve_issuer_client(self, issuer): + """Fetch client via "iss" in assertion claims. Developers MUST + implement this method in subclass, e.g.:: + + def resolve_issuer_client(self, issuer): + return Client.query_by_iss(issuer) + + :param issuer: "iss" value in assertion + :return: Client instance + """ + raise NotImplementedError() + + def resolve_client_key(self, client, headers, payload): + """Resolve client key to decode assertion data. Developers MUST + implement this method in subclass. For instance, there is a + "jwks" column on client table, e.g.:: + + def resolve_client_key(self, client, headers, payload): + # from authlib.jose import JsonWebKey + + key_set = JsonWebKey.import_key_set(client.jwks) + return key_set.find_by_kid(headers['kid']) + + :param client: instance of OAuth client model + :param headers: headers part of the JWT + :param payload: payload part of the JWT + :return: ``authlib.jose.Key`` instance + """ + raise NotImplementedError() + + def authenticate_user(self, subject): + """Authenticate user with the given assertion claims. Developers MUST + implement it in subclass, e.g.:: + + def authenticate_user(self, subject): + return User.get_by_sub(subject) + + :param subject: "sub" value in claims + :return: User instance + """ + raise NotImplementedError() + + def has_granted_permission(self, client, user): + """Check if the client has permission to access the given user's resource. + Developers MUST implement it in subclass, e.g.:: + + def has_granted_permission(self, client, user): + permission = ClientUserGrant.query(client=client, user=user) + return permission.granted + + :param client: instance of OAuth client model + :param user: instance of User model + :return: bool + """ + raise NotImplementedError() diff --git a/.venv/Lib/site-packages/authlib/oauth2/rfc7523/token.py b/.venv/Lib/site-packages/authlib/oauth2/rfc7523/token.py new file mode 100644 index 00000000..27fab5f4 --- /dev/null +++ b/.venv/Lib/site-packages/authlib/oauth2/rfc7523/token.py @@ -0,0 +1,93 @@ +import time +from authlib.common.encoding import to_native +from authlib.jose import jwt + + +class JWTBearerTokenGenerator: + """A JSON Web Token formatted bearer token generator for jwt-bearer grant type. + This token generator can be registered into authorization server:: + + authorization_server.register_token_generator( + 'urn:ietf:params:oauth:grant-type:jwt-bearer', + JWTBearerTokenGenerator(private_rsa_key), + ) + + In this way, we can generate the token into JWT format. And we don't have to + save this token into database, since it will be short time valid. Consider to + rewrite ``JWTBearerGrant.save_token``:: + + class MyJWTBearerGrant(JWTBearerGrant): + def save_token(self, token): + pass + + :param secret_key: private RSA key in bytes, JWK or JWK Set. + :param issuer: a string or URI of the issuer + :param alg: ``alg`` to use in JWT + """ + DEFAULT_EXPIRES_IN = 3600 + + def __init__(self, secret_key, issuer=None, alg='RS256'): + self.secret_key = secret_key + self.issuer = issuer + self.alg = alg + + @staticmethod + def get_allowed_scope(client, scope): + if scope: + scope = client.get_allowed_scope(scope) + return scope + + @staticmethod + def get_sub_value(user): + """Return user's ID as ``sub`` value in token payload. For instance:: + + @staticmethod + def get_sub_value(user): + return str(user.id) + """ + return user.get_user_id() + + def get_token_data(self, grant_type, client, expires_in, user=None, scope=None): + scope = self.get_allowed_scope(client, scope) + issued_at = int(time.time()) + data = { + 'scope': scope, + 'grant_type': grant_type, + 'iat': issued_at, + 'exp': issued_at + expires_in, + 'client_id': client.get_client_id(), + } + if self.issuer: + data['iss'] = self.issuer + if user: + data['sub'] = self.get_sub_value(user) + return data + + def generate(self, grant_type, client, user=None, scope=None, expires_in=None): + """Generate a bearer token for OAuth 2.0 authorization token endpoint. + + :param client: the client that making the request. + :param grant_type: current requested grant_type. + :param user: current authorized user. + :param expires_in: if provided, use this value as expires_in. + :param scope: current requested scope. + :return: Token dict + """ + if not expires_in: + expires_in = self.DEFAULT_EXPIRES_IN + + token_data = self.get_token_data(grant_type, client, expires_in, user, scope) + access_token = jwt.encode({'alg': self.alg}, token_data, key=self.secret_key, check=False) + token = { + 'token_type': 'Bearer', + 'access_token': to_native(access_token), + 'expires_in': expires_in + } + if scope: + token['scope'] = scope + return token + + def __call__(self, grant_type, client, user=None, scope=None, + expires_in=None, include_refresh_token=True): + # there is absolutely no refresh token in JWT format + return self.generate(grant_type, client, user, scope, expires_in) diff --git a/.venv/Lib/site-packages/authlib/oauth2/rfc7523/validator.py b/.venv/Lib/site-packages/authlib/oauth2/rfc7523/validator.py new file mode 100644 index 00000000..f2423b8a --- /dev/null +++ b/.venv/Lib/site-packages/authlib/oauth2/rfc7523/validator.py @@ -0,0 +1,54 @@ +import time +import logging +from authlib.jose import jwt, JoseError, JWTClaims +from ..rfc6749 import TokenMixin +from ..rfc6750 import BearerTokenValidator + +logger = logging.getLogger(__name__) + + +class JWTBearerToken(TokenMixin, JWTClaims): + def check_client(self, client): + return self['client_id'] == client.get_client_id() + + def get_scope(self): + return self.get('scope') + + def get_expires_in(self): + return self['exp'] - self['iat'] + + def is_expired(self): + return self['exp'] < time.time() + + def is_revoked(self): + return False + + +class JWTBearerTokenValidator(BearerTokenValidator): + TOKEN_TYPE = 'bearer' + token_cls = JWTBearerToken + + def __init__(self, public_key, issuer=None, realm=None, **extra_attributes): + super().__init__(realm, **extra_attributes) + self.public_key = public_key + claims_options = { + 'exp': {'essential': True}, + 'client_id': {'essential': True}, + 'grant_type': {'essential': True}, + } + if issuer: + claims_options['iss'] = {'essential': True, 'value': issuer} + self.claims_options = claims_options + + def authenticate_token(self, token_string): + try: + claims = jwt.decode( + token_string, self.public_key, + claims_options=self.claims_options, + claims_cls=self.token_cls, + ) + claims.validate() + return claims + except JoseError as error: + logger.debug('Authenticate token failed. %r', error) + return None diff --git a/.venv/Lib/site-packages/authlib/oauth2/rfc7591/__init__.py b/.venv/Lib/site-packages/authlib/oauth2/rfc7591/__init__.py new file mode 100644 index 00000000..8ebb0709 --- /dev/null +++ b/.venv/Lib/site-packages/authlib/oauth2/rfc7591/__init__.py @@ -0,0 +1,25 @@ +""" + authlib.oauth2.rfc7591 + ~~~~~~~~~~~~~~~~~~~~~~ + + This module represents a direct implementation of + OAuth 2.0 Dynamic Client Registration Protocol. + + https://tools.ietf.org/html/rfc7591 +""" + + +from .claims import ClientMetadataClaims +from .endpoint import ClientRegistrationEndpoint +from .errors import ( + InvalidRedirectURIError, + InvalidClientMetadataError, + InvalidSoftwareStatementError, + UnapprovedSoftwareStatementError, +) + +__all__ = [ + 'ClientMetadataClaims', 'ClientRegistrationEndpoint', + 'InvalidRedirectURIError', 'InvalidClientMetadataError', + 'InvalidSoftwareStatementError', 'UnapprovedSoftwareStatementError', +] diff --git a/.venv/Lib/site-packages/authlib/oauth2/rfc7591/__pycache__/__init__.cpython-311.pyc b/.venv/Lib/site-packages/authlib/oauth2/rfc7591/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 00000000..2203051a Binary files /dev/null and b/.venv/Lib/site-packages/authlib/oauth2/rfc7591/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/authlib/oauth2/rfc7591/__pycache__/claims.cpython-311.pyc b/.venv/Lib/site-packages/authlib/oauth2/rfc7591/__pycache__/claims.cpython-311.pyc new file mode 100644 index 00000000..1d9c3b18 Binary files /dev/null and b/.venv/Lib/site-packages/authlib/oauth2/rfc7591/__pycache__/claims.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/authlib/oauth2/rfc7591/__pycache__/endpoint.cpython-311.pyc b/.venv/Lib/site-packages/authlib/oauth2/rfc7591/__pycache__/endpoint.cpython-311.pyc new file mode 100644 index 00000000..2800807a Binary files /dev/null and b/.venv/Lib/site-packages/authlib/oauth2/rfc7591/__pycache__/endpoint.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/authlib/oauth2/rfc7591/__pycache__/errors.cpython-311.pyc b/.venv/Lib/site-packages/authlib/oauth2/rfc7591/__pycache__/errors.cpython-311.pyc new file mode 100644 index 00000000..c0732f09 Binary files /dev/null and b/.venv/Lib/site-packages/authlib/oauth2/rfc7591/__pycache__/errors.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/authlib/oauth2/rfc7591/claims.py b/.venv/Lib/site-packages/authlib/oauth2/rfc7591/claims.py new file mode 100644 index 00000000..b6157b52 --- /dev/null +++ b/.venv/Lib/site-packages/authlib/oauth2/rfc7591/claims.py @@ -0,0 +1,218 @@ +from authlib.jose import BaseClaims, JsonWebKey +from authlib.jose.errors import InvalidClaimError +from authlib.common.urls import is_valid_url + + +class ClientMetadataClaims(BaseClaims): + # https://tools.ietf.org/html/rfc7591#section-2 + REGISTERED_CLAIMS = [ + 'redirect_uris', + 'token_endpoint_auth_method', + 'grant_types', + 'response_types', + 'client_name', + 'client_uri', + 'logo_uri', + 'scope', + 'contacts', + 'tos_uri', + 'policy_uri', + 'jwks_uri', + 'jwks', + 'software_id', + 'software_version', + ] + + def validate(self): + self._validate_essential_claims() + self.validate_redirect_uris() + self.validate_token_endpoint_auth_method() + self.validate_grant_types() + self.validate_response_types() + self.validate_client_name() + self.validate_client_uri() + self.validate_logo_uri() + self.validate_scope() + self.validate_contacts() + self.validate_tos_uri() + self.validate_policy_uri() + self.validate_jwks_uri() + self.validate_jwks() + self.validate_software_id() + self.validate_software_version() + + def validate_redirect_uris(self): + """Array of redirection URI strings for use in redirect-based flows + such as the authorization code and implicit flows. As required by + Section 2 of OAuth 2.0 [RFC6749], clients using flows with + redirection MUST register their redirection URI values. + Authorization servers that support dynamic registration for + redirect-based flows MUST implement support for this metadata + value. + """ + uris = self.get('redirect_uris') + if uris: + for uri in uris: + self._validate_uri('redirect_uris', uri) + + def validate_token_endpoint_auth_method(self): + """String indicator of the requested authentication method for the + token endpoint. + """ + # If unspecified or omitted, the default is "client_secret_basic" + if 'token_endpoint_auth_method' not in self: + self['token_endpoint_auth_method'] = 'client_secret_basic' + self._validate_claim_value('token_endpoint_auth_method') + + def validate_grant_types(self): + """Array of OAuth 2.0 grant type strings that the client can use at + the token endpoint. + """ + self._validate_claim_value('grant_types') + + def validate_response_types(self): + """Array of the OAuth 2.0 response type strings that the client can + use at the authorization endpoint. + """ + self._validate_claim_value('response_types') + + def validate_client_name(self): + """Human-readable string name of the client to be presented to the + end-user during authorization. If omitted, the authorization + server MAY display the raw "client_id" value to the end-user + instead. It is RECOMMENDED that clients always send this field. + The value of this field MAY be internationalized, as described in + Section 2.2. + """ + + def validate_client_uri(self): + """URL string of a web page providing information about the client. + If present, the server SHOULD display this URL to the end-user in + a clickable fashion. It is RECOMMENDED that clients always send + this field. The value of this field MUST point to a valid web + page. The value of this field MAY be internationalized, as + described in Section 2.2. + """ + self._validate_uri('client_uri') + + def validate_logo_uri(self): + """URL string that references a logo for the client. If present, the + server SHOULD display this image to the end-user during approval. + The value of this field MUST point to a valid image file. The + value of this field MAY be internationalized, as described in + Section 2.2. + """ + self._validate_uri('logo_uri') + + def validate_scope(self): + """String containing a space-separated list of scope values (as + described in Section 3.3 of OAuth 2.0 [RFC6749]) that the client + can use when requesting access tokens. The semantics of values in + this list are service specific. If omitted, an authorization + server MAY register a client with a default set of scopes. + """ + self._validate_claim_value('scope') + + def validate_contacts(self): + """Array of strings representing ways to contact people responsible + for this client, typically email addresses. The authorization + server MAY make these contact addresses available to end-users for + support requests for the client. See Section 6 for information on + Privacy Considerations. + """ + if 'contacts' in self and not isinstance(self['contacts'], list): + raise InvalidClaimError('contacts') + + def validate_tos_uri(self): + """URL string that points to a human-readable terms of service + document for the client that describes a contractual relationship + between the end-user and the client that the end-user accepts when + authorizing the client. The authorization server SHOULD display + this URL to the end-user if it is provided. The value of this + field MUST point to a valid web page. The value of this field MAY + be internationalized, as described in Section 2.2. + """ + self._validate_uri('tos_uri') + + def validate_policy_uri(self): + """URL string that points to a human-readable privacy policy document + that describes how the deployment organization collects, uses, + retains, and discloses personal data. The authorization server + SHOULD display this URL to the end-user if it is provided. The + value of this field MUST point to a valid web page. The value of + this field MAY be internationalized, as described in Section 2.2. + """ + self._validate_uri('policy_uri') + + def validate_jwks_uri(self): + """URL string referencing the client's JSON Web Key (JWK) Set + [RFC7517] document, which contains the client's public keys. The + value of this field MUST point to a valid JWK Set document. These + keys can be used by higher-level protocols that use signing or + encryption. For instance, these keys might be used by some + applications for validating signed requests made to the token + endpoint when using JWTs for client authentication [RFC7523]. Use + of this parameter is preferred over the "jwks" parameter, as it + allows for easier key rotation. The "jwks_uri" and "jwks" + parameters MUST NOT both be present in the same request or + response. + """ + # TODO: use real HTTP library + self._validate_uri('jwks_uri') + + def validate_jwks(self): + """Client's JSON Web Key Set [RFC7517] document value, which contains + the client's public keys. The value of this field MUST be a JSON + object containing a valid JWK Set. These keys can be used by + higher-level protocols that use signing or encryption. This + parameter is intended to be used by clients that cannot use the + "jwks_uri" parameter, such as native clients that cannot host + public URLs. The "jwks_uri" and "jwks" parameters MUST NOT both + be present in the same request or response. + """ + if 'jwks' in self: + if 'jwks_uri' in self: + # The "jwks_uri" and "jwks" parameters MUST NOT both be present + raise InvalidClaimError('jwks') + + jwks = self['jwks'] + try: + key_set = JsonWebKey.import_key_set(jwks) + if not key_set: + raise InvalidClaimError('jwks') + except ValueError: + raise InvalidClaimError('jwks') + + def validate_software_id(self): + """A unique identifier string (e.g., a Universally Unique Identifier + (UUID)) assigned by the client developer or software publisher + used by registration endpoints to identify the client software to + be dynamically registered. Unlike "client_id", which is issued by + the authorization server and SHOULD vary between instances, the + "software_id" SHOULD remain the same for all instances of the + client software. The "software_id" SHOULD remain the same across + multiple updates or versions of the same piece of software. The + value of this field is not intended to be human readable and is + usually opaque to the client and authorization server. + """ + + def validate_software_version(self): + """A version identifier string for the client software identified by + "software_id". The value of the "software_version" SHOULD change + on any update to the client software identified by the same + "software_id". The value of this field is intended to be compared + using string equality matching and no other comparison semantics + are defined by this specification. The value of this field is + outside the scope of this specification, but it is not intended to + be human readable and is usually opaque to the client and + authorization server. The definition of what constitutes an + update to client software that would trigger a change to this + value is specific to the software itself and is outside the scope + of this specification. + """ + + def _validate_uri(self, key, uri=None): + if uri is None: + uri = self.get(key) + if uri and not is_valid_url(uri): + raise InvalidClaimError(key) diff --git a/.venv/Lib/site-packages/authlib/oauth2/rfc7591/endpoint.py b/.venv/Lib/site-packages/authlib/oauth2/rfc7591/endpoint.py new file mode 100644 index 00000000..d26e0614 --- /dev/null +++ b/.venv/Lib/site-packages/authlib/oauth2/rfc7591/endpoint.py @@ -0,0 +1,211 @@ +import os +import time +import binascii +from authlib.consts import default_json_headers +from authlib.common.security import generate_token +from authlib.jose import JsonWebToken, JoseError +from ..rfc6749 import AccessDeniedError, InvalidRequestError +from ..rfc6749 import scope_to_list +from .claims import ClientMetadataClaims +from .errors import ( + InvalidClientMetadataError, + UnapprovedSoftwareStatementError, + InvalidSoftwareStatementError, +) + + +class ClientRegistrationEndpoint: + """The client registration endpoint is an OAuth 2.0 endpoint designed to + allow a client to be registered with the authorization server. + """ + ENDPOINT_NAME = 'client_registration' + + #: The claims validation class + claims_class = ClientMetadataClaims + + #: Rewrite this value with a list to support ``software_statement`` + #: e.g. ``software_statement_alg_values_supported = ['RS256']`` + software_statement_alg_values_supported = None + + def __init__(self, server): + self.server = server + + def __call__(self, request): + return self.create_registration_response(request) + + def create_registration_response(self, request): + token = self.authenticate_token(request) + if not token: + raise AccessDeniedError() + + request.credential = token + + client_metadata = self.extract_client_metadata(request) + client_info = self.generate_client_info() + body = {} + body.update(client_metadata) + body.update(client_info) + client = self.save_client(client_info, client_metadata, request) + registration_info = self.generate_client_registration_info(client, request) + if registration_info: + body.update(registration_info) + return 201, body, default_json_headers + + def extract_client_metadata(self, request): + if not request.data: + raise InvalidRequestError() + + json_data = request.data.copy() + software_statement = json_data.pop('software_statement', None) + if software_statement and self.software_statement_alg_values_supported: + data = self.extract_software_statement(software_statement, request) + json_data.update(data) + + options = self.get_claims_options() + claims = self.claims_class(json_data, {}, options, self.get_server_metadata()) + try: + claims.validate() + except JoseError as error: + raise InvalidClientMetadataError(error.description) + return claims.get_registered_claims() + + def extract_software_statement(self, software_statement, request): + key = self.resolve_public_key(request) + if not key: + raise UnapprovedSoftwareStatementError() + + try: + jwt = JsonWebToken(self.software_statement_alg_values_supported) + claims = jwt.decode(software_statement, key) + # there is no need to validate claims + return claims + except JoseError: + raise InvalidSoftwareStatementError() + + def get_claims_options(self): + """Generate claims options validation from Authorization Server metadata.""" + metadata = self.get_server_metadata() + if not metadata: + return {} + + scopes_supported = metadata.get('scopes_supported') + response_types_supported = metadata.get('response_types_supported') + grant_types_supported = metadata.get('grant_types_supported') + auth_methods_supported = metadata.get('token_endpoint_auth_methods_supported') + options = {} + if scopes_supported is not None: + scopes_supported = set(scopes_supported) + + def _validate_scope(claims, value): + if not value: + return True + scopes = set(scope_to_list(value)) + return scopes_supported.issuperset(scopes) + + options['scope'] = {'validate': _validate_scope} + + if response_types_supported is not None: + response_types_supported = set(response_types_supported) + + def _validate_response_types(claims, value): + # If omitted, the default is that the client will use only the "code" + # response type. + response_types = set(value) if value else {"code"} + return response_types_supported.issuperset(response_types) + + options['response_types'] = {'validate': _validate_response_types} + + if grant_types_supported is not None: + grant_types_supported = set(grant_types_supported) + + def _validate_grant_types(claims, value): + # If omitted, the default behavior is that the client will use only + # the "authorization_code" Grant Type. + grant_types = set(value) if value else {"authorization_code"} + return grant_types_supported.issuperset(grant_types) + + options['grant_types'] = {'validate': _validate_grant_types} + + if auth_methods_supported is not None: + options['token_endpoint_auth_method'] = {'values': auth_methods_supported} + + return options + + def generate_client_info(self): + # https://tools.ietf.org/html/rfc7591#section-3.2.1 + client_id = self.generate_client_id() + client_secret = self.generate_client_secret() + client_id_issued_at = int(time.time()) + client_secret_expires_at = 0 + return dict( + client_id=client_id, + client_secret=client_secret, + client_id_issued_at=client_id_issued_at, + client_secret_expires_at=client_secret_expires_at, + ) + + def generate_client_registration_info(self, client, request): + """Generate ```registration_client_uri`` and ``registration_access_token`` + for RFC7592. This method returns ``None`` by default. Developers MAY rewrite + this method to return registration information.""" + return None + + def create_endpoint_request(self, request): + return self.server.create_json_request(request) + + def generate_client_id(self): + """Generate ``client_id`` value. Developers MAY rewrite this method + to use their own way to generate ``client_id``. + """ + return generate_token(42) + + def generate_client_secret(self): + """Generate ``client_secret`` value. Developers MAY rewrite this method + to use their own way to generate ``client_secret``. + """ + return binascii.hexlify(os.urandom(24)).decode('ascii') + + def get_server_metadata(self): + """Return server metadata which includes supported grant types, + response types and etc. + """ + raise NotImplementedError() + + def authenticate_token(self, request): + """Authenticate current credential who is requesting to register a client. + Developers MUST implement this method in subclass:: + + def authenticate_token(self, request): + auth = request.headers.get('Authorization') + return get_token_by_auth(auth) + + :return: token instance + """ + raise NotImplementedError() + + def resolve_public_key(self, request): + """Resolve a public key for decoding ``software_statement``. If + ``enable_software_statement=True``, developers MUST implement this + method in subclass:: + + def resolve_public_key(self, request): + return get_public_key_from_user(request.credential) + + :return: JWK or Key string + """ + raise NotImplementedError() + + def save_client(self, client_info, client_metadata, request): + """Save client into database. Developers MUST implement this method + in subclass:: + + def save_client(self, client_info, client_metadata, request): + client = OAuthClient( + client_id=client_info['client_id'], + client_secret=client_info['client_secret'], + ... + ) + client.save() + return client + """ + raise NotImplementedError() diff --git a/.venv/Lib/site-packages/authlib/oauth2/rfc7591/errors.py b/.venv/Lib/site-packages/authlib/oauth2/rfc7591/errors.py new file mode 100644 index 00000000..31693c04 --- /dev/null +++ b/.venv/Lib/site-packages/authlib/oauth2/rfc7591/errors.py @@ -0,0 +1,33 @@ +from ..rfc6749 import OAuth2Error + + +class InvalidRedirectURIError(OAuth2Error): + """The value of one or more redirection URIs is invalid. + https://tools.ietf.org/html/rfc7591#section-3.2.2 + """ + error = 'invalid_redirect_uri' + + +class InvalidClientMetadataError(OAuth2Error): + """The value of one of the client metadata fields is invalid and the + server has rejected this request. Note that an authorization + server MAY choose to substitute a valid value for any requested + parameter of a client's metadata. + https://tools.ietf.org/html/rfc7591#section-3.2.2 + """ + error = 'invalid_client_metadata' + + +class InvalidSoftwareStatementError(OAuth2Error): + """The software statement presented is invalid. + https://tools.ietf.org/html/rfc7591#section-3.2.2 + """ + error = 'invalid_software_statement' + + +class UnapprovedSoftwareStatementError(OAuth2Error): + """The software statement presented is not approved for use by this + authorization server. + https://tools.ietf.org/html/rfc7591#section-3.2.2 + """ + error = 'unapproved_software_statement' diff --git a/.venv/Lib/site-packages/authlib/oauth2/rfc7592/__init__.py b/.venv/Lib/site-packages/authlib/oauth2/rfc7592/__init__.py new file mode 100644 index 00000000..6a6457be --- /dev/null +++ b/.venv/Lib/site-packages/authlib/oauth2/rfc7592/__init__.py @@ -0,0 +1,13 @@ +""" + authlib.oauth2.rfc7592 + ~~~~~~~~~~~~~~~~~~~~~~ + + This module represents a direct implementation of + OAuth 2.0 Dynamic Client Registration Management Protocol. + + https://tools.ietf.org/html/rfc7592 +""" + +from .endpoint import ClientConfigurationEndpoint + +__all__ = ['ClientConfigurationEndpoint'] diff --git a/.venv/Lib/site-packages/authlib/oauth2/rfc7592/__pycache__/__init__.cpython-311.pyc b/.venv/Lib/site-packages/authlib/oauth2/rfc7592/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 00000000..8a198b04 Binary files /dev/null and b/.venv/Lib/site-packages/authlib/oauth2/rfc7592/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/authlib/oauth2/rfc7592/__pycache__/endpoint.cpython-311.pyc b/.venv/Lib/site-packages/authlib/oauth2/rfc7592/__pycache__/endpoint.cpython-311.pyc new file mode 100644 index 00000000..25605f94 Binary files /dev/null and b/.venv/Lib/site-packages/authlib/oauth2/rfc7592/__pycache__/endpoint.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/authlib/oauth2/rfc7592/endpoint.py b/.venv/Lib/site-packages/authlib/oauth2/rfc7592/endpoint.py new file mode 100644 index 00000000..cec9aad1 --- /dev/null +++ b/.venv/Lib/site-packages/authlib/oauth2/rfc7592/endpoint.py @@ -0,0 +1,256 @@ +from authlib.consts import default_json_headers +from authlib.jose import JoseError +from ..rfc7591.claims import ClientMetadataClaims +from ..rfc6749 import scope_to_list +from ..rfc6749 import AccessDeniedError +from ..rfc6749 import InvalidClientError +from ..rfc6749 import InvalidRequestError +from ..rfc6749 import UnauthorizedClientError +from ..rfc7591 import InvalidClientMetadataError + + +class ClientConfigurationEndpoint: + ENDPOINT_NAME = 'client_configuration' + + #: The claims validation class + claims_class = ClientMetadataClaims + + def __init__(self, server): + self.server = server + + def __call__(self, request): + return self.create_configuration_response(request) + + def create_configuration_response(self, request): + # This request is authenticated by the registration access token issued + # to the client. + token = self.authenticate_token(request) + if not token: + raise AccessDeniedError() + + request.credential = token + + client = self.authenticate_client(request) + if not client: + # If the client does not exist on this server, the server MUST respond + # with HTTP 401 Unauthorized and the registration access token used to + # make this request SHOULD be immediately revoked. + self.revoke_access_token(request, token) + raise InvalidClientError(status_code=401) + + if not self.check_permission(client, request): + # If the client does not have permission to read its record, the server + # MUST return an HTTP 403 Forbidden. + raise UnauthorizedClientError(status_code=403) + + request.client = client + + if request.method == 'GET': + return self.create_read_client_response(client, request) + elif request.method == 'DELETE': + return self.create_delete_client_response(client, request) + elif request.method == 'PUT': + return self.create_update_client_response(client, request) + + def create_endpoint_request(self, request): + return self.server.create_json_request(request) + + def create_read_client_response(self, client, request): + body = self.introspect_client(client) + body.update(self.generate_client_registration_info(client, request)) + return 200, body, default_json_headers + + def create_delete_client_response(self, client, request): + self.delete_client(client, request) + headers = [ + ('Cache-Control', 'no-store'), + ('Pragma', 'no-cache'), + ] + return 204, '', headers + + def create_update_client_response(self, client, request): + # The updated client metadata fields request MUST NOT include the + # 'registration_access_token', 'registration_client_uri', + # 'client_secret_expires_at', or 'client_id_issued_at' fields + must_not_include = ( + 'registration_access_token', + 'registration_client_uri', + 'client_secret_expires_at', + 'client_id_issued_at', + ) + for k in must_not_include: + if k in request.data: + raise InvalidRequestError() + + # The client MUST include its 'client_id' field in the request + client_id = request.data.get('client_id') + if not client_id: + raise InvalidRequestError() + if client_id != client.get_client_id(): + raise InvalidRequestError() + + # If the client includes the 'client_secret' field in the request, + # the value of this field MUST match the currently issued client + # secret for that client. + if 'client_secret' in request.data: + if not client.check_client_secret(request.data['client_secret']): + raise InvalidRequestError() + + client_metadata = self.extract_client_metadata(request) + client = self.update_client(client, client_metadata, request) + return self.create_read_client_response(client, request) + + def extract_client_metadata(self, request): + json_data = request.data.copy() + options = self.get_claims_options() + claims = self.claims_class(json_data, {}, options, self.get_server_metadata()) + + try: + claims.validate() + except JoseError as error: + raise InvalidClientMetadataError(error.description) + return claims.get_registered_claims() + + def get_claims_options(self): + metadata = self.get_server_metadata() + if not metadata: + return {} + + scopes_supported = metadata.get('scopes_supported') + response_types_supported = metadata.get('response_types_supported') + grant_types_supported = metadata.get('grant_types_supported') + auth_methods_supported = metadata.get('token_endpoint_auth_methods_supported') + options = {} + if scopes_supported is not None: + scopes_supported = set(scopes_supported) + + def _validate_scope(claims, value): + if not value: + return True + scopes = set(scope_to_list(value)) + return scopes_supported.issuperset(scopes) + + options['scope'] = {'validate': _validate_scope} + + if response_types_supported is not None: + response_types_supported = set(response_types_supported) + + def _validate_response_types(claims, value): + return response_types_supported.issuperset(set(value)) + + options['response_types'] = {'validate': _validate_response_types} + + if grant_types_supported is not None: + grant_types_supported = set(grant_types_supported) + + def _validate_grant_types(claims, value): + return grant_types_supported.issuperset(set(value)) + + options['grant_types'] = {'validate': _validate_grant_types} + + if auth_methods_supported is not None: + options['token_endpoint_auth_method'] = {'values': auth_methods_supported} + + return options + + def introspect_client(self, client): + return {**client.client_info, **client.client_metadata} + + def generate_client_registration_info(self, client, request): + """Generate ```registration_client_uri`` and ``registration_access_token`` + for RFC7592. By default this method returns the values sent in the current + request. Developers MUST rewrite this method to return different registration + information.:: + + def generate_client_registration_info(self, client, request):{ + access_token = request.headers['Authorization'].split(' ')[1] + return { + 'registration_client_uri': request.uri, + 'registration_access_token': access_token, + } + + :param client: the instance of OAuth client + :param request: formatted request instance + """ + raise NotImplementedError() + + def authenticate_token(self, request): + """Authenticate current credential who is requesting to register a client. + Developers MUST implement this method in subclass:: + + def authenticate_token(self, request): + auth = request.headers.get('Authorization') + return get_token_by_auth(auth) + + :return: token instance + """ + raise NotImplementedError() + + def authenticate_client(self, request): + """Read a client from the request payload. + Developers MUST implement this method in subclass:: + + def authenticate_client(self, request): + client_id = request.data.get('client_id') + return Client.get(client_id=client_id) + + :return: client instance + """ + raise NotImplementedError() + + def revoke_access_token(self, token, request): + """Revoke a token access in case an invalid client has been requested. + Developers MUST implement this method in subclass:: + + def revoke_access_token(self, token, request): + token.revoked = True + token.save() + + """ + raise NotImplementedError() + + def check_permission(self, client, request): + """Checks wether the current client is allowed to be accessed, edited + or deleted. Developers MUST implement it in subclass, e.g.:: + + def check_permission(self, client, request): + return client.editable + + :return: boolean + """ + raise NotImplementedError() + + def delete_client(self, client, request): + """Delete authorization code from database or cache. Developers MUST + implement it in subclass, e.g.:: + + def delete_client(self, client, request): + client.delete() + + :param client: the instance of OAuth client + :param request: formatted request instance + """ + raise NotImplementedError() + + def update_client(self, client, client_metadata, request): + """Update the client in the database. Developers MUST implement this method + in subclass:: + + def update_client(self, client, client_metadata, request): + client.set_client_metadata({**client.client_metadata, **client_metadata}) + client.save() + return client + + :param client: the instance of OAuth client + :param client_metadata: a dict of the client claims to update + :param request: formatted request instance + :return: client instance + """ + + raise NotImplementedError() + + def get_server_metadata(self): + """Return server metadata which includes supported grant types, + response types and etc. + """ + raise NotImplementedError() diff --git a/.venv/Lib/site-packages/authlib/oauth2/rfc7636/__init__.py b/.venv/Lib/site-packages/authlib/oauth2/rfc7636/__init__.py new file mode 100644 index 00000000..c03043bd --- /dev/null +++ b/.venv/Lib/site-packages/authlib/oauth2/rfc7636/__init__.py @@ -0,0 +1,13 @@ +""" + authlib.oauth2.rfc7636 + ~~~~~~~~~~~~~~~~~~~~~~ + + This module represents a direct implementation of + Proof Key for Code Exchange by OAuth Public Clients. + + https://tools.ietf.org/html/rfc7636 +""" + +from .challenge import CodeChallenge, create_s256_code_challenge + +__all__ = ['CodeChallenge', 'create_s256_code_challenge'] diff --git a/.venv/Lib/site-packages/authlib/oauth2/rfc7636/__pycache__/__init__.cpython-311.pyc b/.venv/Lib/site-packages/authlib/oauth2/rfc7636/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 00000000..c26cbbbb Binary files /dev/null and b/.venv/Lib/site-packages/authlib/oauth2/rfc7636/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/authlib/oauth2/rfc7636/__pycache__/challenge.cpython-311.pyc b/.venv/Lib/site-packages/authlib/oauth2/rfc7636/__pycache__/challenge.cpython-311.pyc new file mode 100644 index 00000000..70355b35 Binary files /dev/null and b/.venv/Lib/site-packages/authlib/oauth2/rfc7636/__pycache__/challenge.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/authlib/oauth2/rfc7636/challenge.py b/.venv/Lib/site-packages/authlib/oauth2/rfc7636/challenge.py new file mode 100644 index 00000000..8303092e --- /dev/null +++ b/.venv/Lib/site-packages/authlib/oauth2/rfc7636/challenge.py @@ -0,0 +1,138 @@ +import re +import hashlib +from authlib.common.encoding import to_bytes, to_unicode, urlsafe_b64encode +from ..rfc6749 import ( + InvalidRequestError, + InvalidGrantError, + OAuth2Request, +) + + +CODE_VERIFIER_PATTERN = re.compile(r'^[a-zA-Z0-9\-._~]{43,128}$') + + +def create_s256_code_challenge(code_verifier): + """Create S256 code_challenge with the given code_verifier.""" + data = hashlib.sha256(to_bytes(code_verifier, 'ascii')).digest() + return to_unicode(urlsafe_b64encode(data)) + + +def compare_plain_code_challenge(code_verifier, code_challenge): + # If the "code_challenge_method" from Section 4.3 was "plain", + # they are compared directly + return code_verifier == code_challenge + + +def compare_s256_code_challenge(code_verifier, code_challenge): + # BASE64URL-ENCODE(SHA256(ASCII(code_verifier))) == code_challenge + return create_s256_code_challenge(code_verifier) == code_challenge + + +class CodeChallenge: + """CodeChallenge extension to Authorization Code Grant. It is used to + improve the security of Authorization Code flow for public clients by + sending extra "code_challenge" and "code_verifier" to the authorization + server. + + The AuthorizationCodeGrant SHOULD save the ``code_challenge`` and + ``code_challenge_method`` into database when ``save_authorization_code``. + Then register this extension via:: + + server.register_grant( + AuthorizationCodeGrant, + [CodeChallenge(required=True)] + ) + """ + #: defaults to "plain" if not present in the request + DEFAULT_CODE_CHALLENGE_METHOD = 'plain' + #: supported ``code_challenge_method`` + SUPPORTED_CODE_CHALLENGE_METHOD = ['plain', 'S256'] + + CODE_CHALLENGE_METHODS = { + 'plain': compare_plain_code_challenge, + 'S256': compare_s256_code_challenge, + } + + def __init__(self, required=True): + self.required = required + + def __call__(self, grant): + grant.register_hook( + 'after_validate_authorization_request', + self.validate_code_challenge, + ) + grant.register_hook( + 'after_validate_token_request', + self.validate_code_verifier, + ) + + def validate_code_challenge(self, grant): + request: OAuth2Request = grant.request + challenge = request.data.get('code_challenge') + method = request.data.get('code_challenge_method') + if not challenge and not method: + return + + if not challenge: + raise InvalidRequestError('Missing "code_challenge"') + + if method and method not in self.SUPPORTED_CODE_CHALLENGE_METHOD: + raise InvalidRequestError('Unsupported "code_challenge_method"') + + def validate_code_verifier(self, grant): + request: OAuth2Request = grant.request + verifier = request.form.get('code_verifier') + + # public client MUST verify code challenge + if self.required and request.auth_method == 'none' and not verifier: + raise InvalidRequestError('Missing "code_verifier"') + + authorization_code = request.authorization_code + challenge = self.get_authorization_code_challenge(authorization_code) + + # ignore, it is the normal RFC6749 authorization_code request + if not challenge and not verifier: + return + + # challenge exists, code_verifier is required + if not verifier: + raise InvalidRequestError('Missing "code_verifier"') + + if not CODE_VERIFIER_PATTERN.match(verifier): + raise InvalidRequestError('Invalid "code_verifier"') + + # 4.6. Server Verifies code_verifier before Returning the Tokens + method = self.get_authorization_code_challenge_method(authorization_code) + if method is None: + method = self.DEFAULT_CODE_CHALLENGE_METHOD + + func = self.CODE_CHALLENGE_METHODS.get(method) + if not func: + raise RuntimeError(f'No verify method for "{method}"') + + # If the values are not equal, an error response indicating + # "invalid_grant" MUST be returned. + if not func(verifier, challenge): + raise InvalidGrantError(description='Code challenge failed.') + + def get_authorization_code_challenge(self, authorization_code): + """Get "code_challenge" associated with this authorization code. + Developers MAY re-implement it in subclass, the default logic:: + + def get_authorization_code_challenge(self, authorization_code): + return authorization_code.code_challenge + + :param authorization_code: the instance of authorization_code + """ + return authorization_code.code_challenge + + def get_authorization_code_challenge_method(self, authorization_code): + """Get "code_challenge_method" associated with this authorization code. + Developers MAY re-implement it in subclass, the default logic:: + + def get_authorization_code_challenge_method(self, authorization_code): + return authorization_code.code_challenge_method + + :param authorization_code: the instance of authorization_code + """ + return authorization_code.code_challenge_method diff --git a/.venv/Lib/site-packages/authlib/oauth2/rfc7662/__init__.py b/.venv/Lib/site-packages/authlib/oauth2/rfc7662/__init__.py new file mode 100644 index 00000000..045aeda5 --- /dev/null +++ b/.venv/Lib/site-packages/authlib/oauth2/rfc7662/__init__.py @@ -0,0 +1,15 @@ +""" + authlib.oauth2.rfc7662 + ~~~~~~~~~~~~~~~~~~~~~~ + + This module represents a direct implementation of + OAuth 2.0 Token Introspection. + + https://tools.ietf.org/html/rfc7662 +""" + +from .introspection import IntrospectionEndpoint +from .models import IntrospectionToken +from .token_validator import IntrospectTokenValidator + +__all__ = ['IntrospectionEndpoint', 'IntrospectionToken', 'IntrospectTokenValidator'] diff --git a/.venv/Lib/site-packages/authlib/oauth2/rfc7662/__pycache__/__init__.cpython-311.pyc b/.venv/Lib/site-packages/authlib/oauth2/rfc7662/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 00000000..97214944 Binary files /dev/null and b/.venv/Lib/site-packages/authlib/oauth2/rfc7662/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/authlib/oauth2/rfc7662/__pycache__/introspection.cpython-311.pyc b/.venv/Lib/site-packages/authlib/oauth2/rfc7662/__pycache__/introspection.cpython-311.pyc new file mode 100644 index 00000000..343d0f57 Binary files /dev/null and b/.venv/Lib/site-packages/authlib/oauth2/rfc7662/__pycache__/introspection.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/authlib/oauth2/rfc7662/__pycache__/models.cpython-311.pyc b/.venv/Lib/site-packages/authlib/oauth2/rfc7662/__pycache__/models.cpython-311.pyc new file mode 100644 index 00000000..b9e15300 Binary files /dev/null and b/.venv/Lib/site-packages/authlib/oauth2/rfc7662/__pycache__/models.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/authlib/oauth2/rfc7662/__pycache__/token_validator.cpython-311.pyc b/.venv/Lib/site-packages/authlib/oauth2/rfc7662/__pycache__/token_validator.cpython-311.pyc new file mode 100644 index 00000000..cba1cc69 Binary files /dev/null and b/.venv/Lib/site-packages/authlib/oauth2/rfc7662/__pycache__/token_validator.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/authlib/oauth2/rfc7662/introspection.py b/.venv/Lib/site-packages/authlib/oauth2/rfc7662/introspection.py new file mode 100644 index 00000000..515d6ca6 --- /dev/null +++ b/.venv/Lib/site-packages/authlib/oauth2/rfc7662/introspection.py @@ -0,0 +1,131 @@ +from authlib.consts import default_json_headers +from ..rfc6749 import ( + TokenEndpoint, + InvalidRequestError, + UnsupportedTokenTypeError, +) + + +class IntrospectionEndpoint(TokenEndpoint): + """Implementation of introspection endpoint which is described in + `RFC7662`_. + + .. _RFC7662: https://tools.ietf.org/html/rfc7662 + """ + #: Endpoint name to be registered + ENDPOINT_NAME = 'introspection' + + def authenticate_token(self, request, client): + """The protected resource calls the introspection endpoint using an HTTP + ``POST`` request with parameters sent as + "application/x-www-form-urlencoded" data. The protected resource sends a + parameter representing the token along with optional parameters + representing additional context that is known by the protected resource + to aid the authorization server in its response. + + token + **REQUIRED** The string value of the token. For access tokens, this + is the ``access_token`` value returned from the token endpoint + defined in OAuth 2.0. For refresh tokens, this is the + ``refresh_token`` value returned from the token endpoint as defined + in OAuth 2.0. + + token_type_hint + **OPTIONAL** A hint about the type of the token submitted for + introspection. + """ + + self.check_params(request, client) + token = self.query_token(request.form['token'], request.form.get('token_type_hint')) + if token and self.check_permission(token, client, request): + return token + + def check_params(self, request, client): + params = request.form + if 'token' not in params: + raise InvalidRequestError() + + hint = params.get('token_type_hint') + if hint and hint not in self.SUPPORTED_TOKEN_TYPES: + raise UnsupportedTokenTypeError() + + def create_endpoint_response(self, request): + """Validate introspection request and create the response. + + :returns: (status_code, body, headers) + """ + # The authorization server first validates the client credentials + client = self.authenticate_endpoint_client(request) + + # then verifies whether the token was issued to the client making + # the revocation request + token = self.authenticate_token(request, client) + + # the authorization server invalidates the token + body = self.create_introspection_payload(token) + return 200, body, default_json_headers + + def create_introspection_payload(self, token): + # the token is not active, does not exist on this server, or the + # protected resource is not allowed to introspect this particular + # token, then the authorization server MUST return an introspection + # response with the "active" field set to "false" + if not token: + return {'active': False} + if token.is_expired() or token.is_revoked(): + return {'active': False} + payload = self.introspect_token(token) + if 'active' not in payload: + payload['active'] = True + return payload + + def check_permission(self, token, client, request): + """Check if the request has permission to introspect the token. Developers + MUST implement this method:: + + def check_permission(self, token, client, request): + # only allow a special client to introspect the token + return client.client_id == 'introspection_client' + + :return: bool + """ + raise NotImplementedError() + + def query_token(self, token_string, token_type_hint): + """Get the token from database/storage by the given token string. + Developers should implement this method:: + + def query_token(self, token_string, token_type_hint): + if token_type_hint == 'access_token': + tok = Token.query_by_access_token(token_string) + elif token_type_hint == 'refresh_token': + tok = Token.query_by_refresh_token(token_string) + else: + tok = Token.query_by_access_token(token_string) + if not tok: + tok = Token.query_by_refresh_token(token_string) + return tok + """ + raise NotImplementedError() + + def introspect_token(self, token): + """Read given token and return its introspection metadata as a + dictionary following `Section 2.2`_:: + + def introspect_token(self, token): + return { + 'active': True, + 'client_id': token.client_id, + 'token_type': token.token_type, + 'username': get_token_username(token), + 'scope': token.get_scope(), + 'sub': get_token_user_sub(token), + 'aud': token.client_id, + 'iss': 'https://server.example.com/', + 'exp': token.expires_at, + 'iat': token.issued_at, + } + + .. _`Section 2.2`: https://tools.ietf.org/html/rfc7662#section-2.2 + """ + raise NotImplementedError() diff --git a/.venv/Lib/site-packages/authlib/oauth2/rfc7662/models.py b/.venv/Lib/site-packages/authlib/oauth2/rfc7662/models.py new file mode 100644 index 00000000..0f4f0c21 --- /dev/null +++ b/.venv/Lib/site-packages/authlib/oauth2/rfc7662/models.py @@ -0,0 +1,30 @@ +from ..rfc6749 import TokenMixin + + +class IntrospectionToken(dict, TokenMixin): + def get_client_id(self): + return self.get('client_id') + + def get_scope(self): + return self.get('scope') + + def get_expires_in(self): + # this method is only used in refresh token, + # no need to implement it + return 0 + + def get_expires_at(self): + return self.get('exp', 0) + + def __getattr__(self, key): + # https://tools.ietf.org/html/rfc7662#section-2.2 + available_keys = { + 'active', 'scope', 'client_id', 'username', 'token_type', + 'exp', 'iat', 'nbf', 'sub', 'aud', 'iss', 'jti' + } + try: + return object.__getattribute__(self, key) + except AttributeError as error: + if key in available_keys: + return self.get(key) + raise error diff --git a/.venv/Lib/site-packages/authlib/oauth2/rfc7662/token_validator.py b/.venv/Lib/site-packages/authlib/oauth2/rfc7662/token_validator.py new file mode 100644 index 00000000..882c8d91 --- /dev/null +++ b/.venv/Lib/site-packages/authlib/oauth2/rfc7662/token_validator.py @@ -0,0 +1,34 @@ +from ..rfc6749 import TokenValidator +from ..rfc6750 import ( + InvalidTokenError, + InsufficientScopeError +) + + +class IntrospectTokenValidator(TokenValidator): + TOKEN_TYPE = 'bearer' + + def introspect_token(self, token_string): + """Request introspection token endpoint with the given token string, + authorization server will return token information in JSON format. + Developers MUST implement this method before using it:: + + def introspect_token(self, token_string): + # for example, introspection token endpoint has limited + # internal IPs to access, so there is no need to add + # authentication. + url = 'https://example.com/oauth/introspect' + resp = requests.post(url, data={'token': token_string}) + resp.raise_for_status() + return resp.json() + """ + raise NotImplementedError() + + def authenticate_token(self, token_string): + return self.introspect_token(token_string) + + def validate_token(self, token, scopes, request): + if not token or not token['active']: + raise InvalidTokenError(realm=self.realm, extra_attributes=self.extra_attributes) + if self.scope_insufficient(token.get('scope'), scopes): + raise InsufficientScopeError() diff --git a/.venv/Lib/site-packages/authlib/oauth2/rfc8414/__init__.py b/.venv/Lib/site-packages/authlib/oauth2/rfc8414/__init__.py new file mode 100644 index 00000000..b1b151c5 --- /dev/null +++ b/.venv/Lib/site-packages/authlib/oauth2/rfc8414/__init__.py @@ -0,0 +1,15 @@ +""" + authlib.oauth2.rfc8414 + ~~~~~~~~~~~~~~~~~~~~~~ + + This module represents a direct implementation of + OAuth 2.0 Authorization Server Metadata. + + https://tools.ietf.org/html/rfc8414 +""" + +from .models import AuthorizationServerMetadata +from .well_known import get_well_known_url + + +__all__ = ['AuthorizationServerMetadata', 'get_well_known_url'] diff --git a/.venv/Lib/site-packages/authlib/oauth2/rfc8414/__pycache__/__init__.cpython-311.pyc b/.venv/Lib/site-packages/authlib/oauth2/rfc8414/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 00000000..cb0faf28 Binary files /dev/null and b/.venv/Lib/site-packages/authlib/oauth2/rfc8414/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/authlib/oauth2/rfc8414/__pycache__/models.cpython-311.pyc b/.venv/Lib/site-packages/authlib/oauth2/rfc8414/__pycache__/models.cpython-311.pyc new file mode 100644 index 00000000..e0f1c3b8 Binary files /dev/null and b/.venv/Lib/site-packages/authlib/oauth2/rfc8414/__pycache__/models.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/authlib/oauth2/rfc8414/__pycache__/well_known.cpython-311.pyc b/.venv/Lib/site-packages/authlib/oauth2/rfc8414/__pycache__/well_known.cpython-311.pyc new file mode 100644 index 00000000..ce2a0d96 Binary files /dev/null and b/.venv/Lib/site-packages/authlib/oauth2/rfc8414/__pycache__/well_known.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/authlib/oauth2/rfc8414/models.py b/.venv/Lib/site-packages/authlib/oauth2/rfc8414/models.py new file mode 100644 index 00000000..2dc790bd --- /dev/null +++ b/.venv/Lib/site-packages/authlib/oauth2/rfc8414/models.py @@ -0,0 +1,368 @@ +from authlib.common.urls import urlparse, is_valid_url +from authlib.common.security import is_secure_transport + + +class AuthorizationServerMetadata(dict): + """Define Authorization Server Metadata via `Section 2`_ in RFC8414_. + + .. _RFC8414: https://tools.ietf.org/html/rfc8414 + .. _`Section 2`: https://tools.ietf.org/html/rfc8414#section-2 + """ + REGISTRY_KEYS = [ + 'issuer', 'authorization_endpoint', 'token_endpoint', + 'jwks_uri', 'registration_endpoint', 'scopes_supported', + 'response_types_supported', 'response_modes_supported', + 'grant_types_supported', 'token_endpoint_auth_methods_supported', + 'token_endpoint_auth_signing_alg_values_supported', + 'service_documentation', 'ui_locales_supported', + 'op_policy_uri', 'op_tos_uri', 'revocation_endpoint', + 'revocation_endpoint_auth_methods_supported', + 'revocation_endpoint_auth_signing_alg_values_supported', + 'introspection_endpoint', + 'introspection_endpoint_auth_methods_supported', + 'introspection_endpoint_auth_signing_alg_values_supported', + 'code_challenge_methods_supported', + ] + + def validate_issuer(self): + """REQUIRED. The authorization server's issuer identifier, which is + a URL that uses the "https" scheme and has no query or fragment + components. + """ + issuer = self.get('issuer') + + #: 1. REQUIRED + if not issuer: + raise ValueError('"issuer" is required') + + parsed = urlparse.urlparse(issuer) + + #: 2. uses the "https" scheme + if not is_secure_transport(issuer): + raise ValueError('"issuer" MUST use "https" scheme') + + #: 3. has no query or fragment + if parsed.query or parsed.fragment: + raise ValueError('"issuer" has no query or fragment') + + def validate_authorization_endpoint(self): + """URL of the authorization server's authorization endpoint + [RFC6749]. This is REQUIRED unless no grant types are supported + that use the authorization endpoint. + """ + url = self.get('authorization_endpoint') + if url: + if not is_secure_transport(url): + raise ValueError( + '"authorization_endpoint" MUST use "https" scheme') + return + + grant_types_supported = set(self.grant_types_supported) + authorization_grant_types = {'authorization_code', 'implicit'} + if grant_types_supported & authorization_grant_types: + raise ValueError('"authorization_endpoint" is required') + + def validate_token_endpoint(self): + """URL of the authorization server's token endpoint [RFC6749]. This + is REQUIRED unless only the implicit grant type is supported. + """ + grant_types_supported = self.get('grant_types_supported') + if grant_types_supported and len(grant_types_supported) == 1 and \ + grant_types_supported[0] == 'implicit': + return + + url = self.get('token_endpoint') + if not url: + raise ValueError('"token_endpoint" is required') + + if not is_secure_transport(url): + raise ValueError('"token_endpoint" MUST use "https" scheme') + + def validate_jwks_uri(self): + """OPTIONAL. URL of the authorization server's JWK Set [JWK] + document. The referenced document contains the signing key(s) the + client uses to validate signatures from the authorization server. + This URL MUST use the "https" scheme. The JWK Set MAY also + contain the server's encryption key or keys, which are used by + clients to encrypt requests to the server. When both signing and + encryption keys are made available, a "use" (public key use) + parameter value is REQUIRED for all keys in the referenced JWK Set + to indicate each key's intended usage. + """ + url = self.get('jwks_uri') + if url and not is_secure_transport(url): + raise ValueError('"jwks_uri" MUST use "https" scheme') + + def validate_registration_endpoint(self): + """OPTIONAL. URL of the authorization server's OAuth 2.0 Dynamic + Client Registration endpoint [RFC7591]. + """ + url = self.get('registration_endpoint') + if url and not is_secure_transport(url): + raise ValueError( + '"registration_endpoint" MUST use "https" scheme') + + def validate_scopes_supported(self): + """RECOMMENDED. JSON array containing a list of the OAuth 2.0 + [RFC6749] "scope" values that this authorization server supports. + Servers MAY choose not to advertise some supported scope values + even when this parameter is used. + """ + validate_array_value(self, 'scopes_supported') + + def validate_response_types_supported(self): + """REQUIRED. JSON array containing a list of the OAuth 2.0 + "response_type" values that this authorization server supports. + The array values used are the same as those used with the + "response_types" parameter defined by "OAuth 2.0 Dynamic Client + Registration Protocol" [RFC7591]. + """ + response_types_supported = self.get('response_types_supported') + if not response_types_supported: + raise ValueError('"response_types_supported" is required') + if not isinstance(response_types_supported, list): + raise ValueError('"response_types_supported" MUST be JSON array') + + def validate_response_modes_supported(self): + """OPTIONAL. JSON array containing a list of the OAuth 2.0 + "response_mode" values that this authorization server supports, as + specified in "OAuth 2.0 Multiple Response Type Encoding Practices" + [OAuth.Responses]. If omitted, the default is "["query", + "fragment"]". The response mode value "form_post" is also defined + in "OAuth 2.0 Form Post Response Mode" [OAuth.Post]. + """ + validate_array_value(self, 'response_modes_supported') + + def validate_grant_types_supported(self): + """OPTIONAL. JSON array containing a list of the OAuth 2.0 grant + type values that this authorization server supports. The array + values used are the same as those used with the "grant_types" + parameter defined by "OAuth 2.0 Dynamic Client Registration + Protocol" [RFC7591]. If omitted, the default value is + "["authorization_code", "implicit"]". + """ + validate_array_value(self, 'grant_types_supported') + + def validate_token_endpoint_auth_methods_supported(self): + """OPTIONAL. JSON array containing a list of client authentication + methods supported by this token endpoint. Client authentication + method values are used in the "token_endpoint_auth_method" + parameter defined in Section 2 of [RFC7591]. If omitted, the + default is "client_secret_basic" -- the HTTP Basic Authentication + Scheme specified in Section 2.3.1 of OAuth 2.0 [RFC6749]. + """ + validate_array_value(self, 'token_endpoint_auth_methods_supported') + + def validate_token_endpoint_auth_signing_alg_values_supported(self): + """OPTIONAL. JSON array containing a list of the JWS signing + algorithms ("alg" values) supported by the token endpoint for the + signature on the JWT [JWT] used to authenticate the client at the + token endpoint for the "private_key_jwt" and "client_secret_jwt" + authentication methods. This metadata entry MUST be present if + either of these authentication methods are specified in the + "token_endpoint_auth_methods_supported" entry. No default + algorithms are implied if this entry is omitted. Servers SHOULD + support "RS256". The value "none" MUST NOT be used. + """ + _validate_alg_values( + self, + 'token_endpoint_auth_signing_alg_values_supported', + self.token_endpoint_auth_methods_supported + ) + + def validate_service_documentation(self): + """OPTIONAL. URL of a page containing human-readable information + that developers might want or need to know when using the + authorization server. In particular, if the authorization server + does not support Dynamic Client Registration, then information on + how to register clients needs to be provided in this + documentation. + """ + value = self.get('service_documentation') + if value and not is_valid_url(value): + raise ValueError('"service_documentation" MUST be a URL') + + def validate_ui_locales_supported(self): + """OPTIONAL. Languages and scripts supported for the user interface, + represented as a JSON array of language tag values from BCP 47 + [RFC5646]. If omitted, the set of supported languages and scripts + is unspecified. + """ + validate_array_value(self, 'ui_locales_supported') + + def validate_op_policy_uri(self): + """OPTIONAL. URL that the authorization server provides to the + person registering the client to read about the authorization + server's requirements on how the client can use the data provided + by the authorization server. The registration process SHOULD + display this URL to the person registering the client if it is + given. As described in Section 5, despite the identifier + "op_policy_uri" appearing to be OpenID-specific, its usage in this + specification is actually referring to a general OAuth 2.0 feature + that is not specific to OpenID Connect. + """ + value = self.get('op_policy_uri') + if value and not is_valid_url(value): + raise ValueError('"op_policy_uri" MUST be a URL') + + def validate_op_tos_uri(self): + """OPTIONAL. URL that the authorization server provides to the + person registering the client to read about the authorization + server's terms of service. The registration process SHOULD + display this URL to the person registering the client if it is + given. As described in Section 5, despite the identifier + "op_tos_uri", appearing to be OpenID-specific, its usage in this + specification is actually referring to a general OAuth 2.0 feature + that is not specific to OpenID Connect. + """ + value = self.get('op_tos_uri') + if value and not is_valid_url(value): + raise ValueError('"op_tos_uri" MUST be a URL') + + def validate_revocation_endpoint(self): + """OPTIONAL. URL of the authorization server's OAuth 2.0 revocation + endpoint [RFC7009].""" + url = self.get('revocation_endpoint') + if url and not is_secure_transport(url): + raise ValueError('"revocation_endpoint" MUST use "https" scheme') + + def validate_revocation_endpoint_auth_methods_supported(self): + """OPTIONAL. JSON array containing a list of client authentication + methods supported by this revocation endpoint. The valid client + authentication method values are those registered in the IANA + "OAuth Token Endpoint Authentication Methods" registry + [IANA.OAuth.Parameters]. If omitted, the default is + "client_secret_basic" -- the HTTP Basic Authentication Scheme + specified in Section 2.3.1 of OAuth 2.0 [RFC6749]. + """ + validate_array_value(self, 'revocation_endpoint_auth_methods_supported') + + def validate_revocation_endpoint_auth_signing_alg_values_supported(self): + """OPTIONAL. JSON array containing a list of the JWS signing + algorithms ("alg" values) supported by the revocation endpoint for + the signature on the JWT [JWT] used to authenticate the client at + the revocation endpoint for the "private_key_jwt" and + "client_secret_jwt" authentication methods. This metadata entry + MUST be present if either of these authentication methods are + specified in the "revocation_endpoint_auth_methods_supported" + entry. No default algorithms are implied if this entry is + omitted. The value "none" MUST NOT be used. + """ + _validate_alg_values( + self, + 'revocation_endpoint_auth_signing_alg_values_supported', + self.revocation_endpoint_auth_methods_supported + ) + + def validate_introspection_endpoint(self): + """OPTIONAL. URL of the authorization server's OAuth 2.0 + introspection endpoint [RFC7662]. + """ + url = self.get('introspection_endpoint') + if url and not is_secure_transport(url): + raise ValueError( + '"introspection_endpoint" MUST use "https" scheme') + + def validate_introspection_endpoint_auth_methods_supported(self): + """OPTIONAL. JSON array containing a list of client authentication + methods supported by this introspection endpoint. The valid + client authentication method values are those registered in the + IANA "OAuth Token Endpoint Authentication Methods" registry + [IANA.OAuth.Parameters] or those registered in the IANA "OAuth + Access Token Types" registry [IANA.OAuth.Parameters]. (These + values are and will remain distinct, due to Section 7.2.) If + omitted, the set of supported authentication methods MUST be + determined by other means. + """ + validate_array_value(self, 'introspection_endpoint_auth_methods_supported') + + def validate_introspection_endpoint_auth_signing_alg_values_supported(self): + """OPTIONAL. JSON array containing a list of the JWS signing + algorithms ("alg" values) supported by the introspection endpoint + for the signature on the JWT [JWT] used to authenticate the client + at the introspection endpoint for the "private_key_jwt" and + "client_secret_jwt" authentication methods. This metadata entry + MUST be present if either of these authentication methods are + specified in the "introspection_endpoint_auth_methods_supported" + entry. No default algorithms are implied if this entry is + omitted. The value "none" MUST NOT be used. + """ + _validate_alg_values( + self, + 'introspection_endpoint_auth_signing_alg_values_supported', + self.introspection_endpoint_auth_methods_supported + ) + + def validate_code_challenge_methods_supported(self): + """OPTIONAL. JSON array containing a list of Proof Key for Code + Exchange (PKCE) [RFC7636] code challenge methods supported by this + authorization server. Code challenge method values are used in + the "code_challenge_method" parameter defined in Section 4.3 of + [RFC7636]. The valid code challenge method values are those + registered in the IANA "PKCE Code Challenge Methods" registry + [IANA.OAuth.Parameters]. If omitted, the authorization server + does not support PKCE. + """ + validate_array_value(self, 'code_challenge_methods_supported') + + @property + def response_modes_supported(self): + #: If omitted, the default is ["query", "fragment"] + return self.get('response_modes_supported', ["query", "fragment"]) + + @property + def grant_types_supported(self): + #: If omitted, the default value is ["authorization_code", "implicit"] + return self.get('grant_types_supported', ["authorization_code", "implicit"]) + + @property + def token_endpoint_auth_methods_supported(self): + #: If omitted, the default is "client_secret_basic" + return self.get('token_endpoint_auth_methods_supported', ["client_secret_basic"]) + + @property + def revocation_endpoint_auth_methods_supported(self): + #: If omitted, the default is "client_secret_basic" + return self.get('revocation_endpoint_auth_methods_supported', ["client_secret_basic"]) + + @property + def introspection_endpoint_auth_methods_supported(self): + #: If omitted, the set of supported authentication methods MUST be + #: determined by other means + #: here, we use "client_secret_basic" + return self.get('introspection_endpoint_auth_methods_supported', ["client_secret_basic"]) + + def validate(self): + """Validate all server metadata value.""" + for key in self.REGISTRY_KEYS: + object.__getattribute__(self, f'validate_{key}')() + + def __getattr__(self, key): + try: + return object.__getattribute__(self, key) + except AttributeError as error: + if key in self.REGISTRY_KEYS: + return self.get(key) + raise error + + +def _validate_alg_values(data, key, auth_methods_supported): + value = data.get(key) + if value and not isinstance(value, list): + raise ValueError(f'"{key}" MUST be JSON array') + + auth_methods = set(auth_methods_supported) + jwt_auth_methods = {'private_key_jwt', 'client_secret_jwt'} + if auth_methods & jwt_auth_methods: + if not value: + raise ValueError(f'"{key}" is required') + + if value and 'none' in value: + raise ValueError( + f'the value "none" MUST NOT be used in "{key}"') + + +def validate_array_value(metadata, key): + values = metadata.get(key) + if values is not None and not isinstance(values, list): + raise ValueError(f'"{key}" MUST be JSON array') diff --git a/.venv/Lib/site-packages/authlib/oauth2/rfc8414/well_known.py b/.venv/Lib/site-packages/authlib/oauth2/rfc8414/well_known.py new file mode 100644 index 00000000..42d70b3b --- /dev/null +++ b/.venv/Lib/site-packages/authlib/oauth2/rfc8414/well_known.py @@ -0,0 +1,22 @@ +from authlib.common.urls import urlparse + + +def get_well_known_url(issuer, external=False, suffix='oauth-authorization-server'): + """Get well-known URI with issuer via `Section 3.1`_. + + .. _`Section 3.1`: https://tools.ietf.org/html/rfc8414#section-3.1 + + :param issuer: URL of the issuer + :param external: return full external url or not + :param suffix: well-known URI suffix for RFC8414 + :return: URL + """ + parsed = urlparse.urlparse(issuer) + path = parsed.path + if path and path != '/': + url_path = f'/.well-known/{suffix}{path}' + else: + url_path = f'/.well-known/{suffix}' + if not external: + return url_path + return parsed.scheme + '://' + parsed.netloc + url_path diff --git a/.venv/Lib/site-packages/authlib/oauth2/rfc8628/__init__.py b/.venv/Lib/site-packages/authlib/oauth2/rfc8628/__init__.py new file mode 100644 index 00000000..6ad59fdf --- /dev/null +++ b/.venv/Lib/site-packages/authlib/oauth2/rfc8628/__init__.py @@ -0,0 +1,22 @@ +""" + authlib.oauth2.rfc8628 + ~~~~~~~~~~~~~~~~~~~~~~ + + This module represents an implementation of + OAuth 2.0 Device Authorization Grant. + + https://tools.ietf.org/html/rfc8628 +""" + +from .endpoint import DeviceAuthorizationEndpoint +from .device_code import DeviceCodeGrant, DEVICE_CODE_GRANT_TYPE +from .models import DeviceCredentialMixin, DeviceCredentialDict +from .errors import AuthorizationPendingError, SlowDownError, ExpiredTokenError + + +__all__ = [ + 'DeviceAuthorizationEndpoint', + 'DeviceCodeGrant', 'DEVICE_CODE_GRANT_TYPE', + 'DeviceCredentialMixin', 'DeviceCredentialDict', + 'AuthorizationPendingError', 'SlowDownError', 'ExpiredTokenError', +] diff --git a/.venv/Lib/site-packages/authlib/oauth2/rfc8628/__pycache__/__init__.cpython-311.pyc b/.venv/Lib/site-packages/authlib/oauth2/rfc8628/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 00000000..52a278ee Binary files /dev/null and b/.venv/Lib/site-packages/authlib/oauth2/rfc8628/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/authlib/oauth2/rfc8628/__pycache__/device_code.cpython-311.pyc b/.venv/Lib/site-packages/authlib/oauth2/rfc8628/__pycache__/device_code.cpython-311.pyc new file mode 100644 index 00000000..4b50a809 Binary files /dev/null and b/.venv/Lib/site-packages/authlib/oauth2/rfc8628/__pycache__/device_code.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/authlib/oauth2/rfc8628/__pycache__/endpoint.cpython-311.pyc b/.venv/Lib/site-packages/authlib/oauth2/rfc8628/__pycache__/endpoint.cpython-311.pyc new file mode 100644 index 00000000..c23a35c1 Binary files /dev/null and b/.venv/Lib/site-packages/authlib/oauth2/rfc8628/__pycache__/endpoint.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/authlib/oauth2/rfc8628/__pycache__/errors.cpython-311.pyc b/.venv/Lib/site-packages/authlib/oauth2/rfc8628/__pycache__/errors.cpython-311.pyc new file mode 100644 index 00000000..a9976abc Binary files /dev/null and b/.venv/Lib/site-packages/authlib/oauth2/rfc8628/__pycache__/errors.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/authlib/oauth2/rfc8628/__pycache__/models.cpython-311.pyc b/.venv/Lib/site-packages/authlib/oauth2/rfc8628/__pycache__/models.cpython-311.pyc new file mode 100644 index 00000000..77c56c9b Binary files /dev/null and b/.venv/Lib/site-packages/authlib/oauth2/rfc8628/__pycache__/models.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/authlib/oauth2/rfc8628/device_code.py b/.venv/Lib/site-packages/authlib/oauth2/rfc8628/device_code.py new file mode 100644 index 00000000..68209170 --- /dev/null +++ b/.venv/Lib/site-packages/authlib/oauth2/rfc8628/device_code.py @@ -0,0 +1,183 @@ +import logging +from ..rfc6749.errors import ( + InvalidRequestError, + UnauthorizedClientError, + AccessDeniedError, +) +from ..rfc6749 import BaseGrant, TokenEndpointMixin +from .errors import ( + AuthorizationPendingError, + ExpiredTokenError, + SlowDownError, +) + +log = logging.getLogger(__name__) +DEVICE_CODE_GRANT_TYPE = 'urn:ietf:params:oauth:grant-type:device_code' + + +class DeviceCodeGrant(BaseGrant, TokenEndpointMixin): + """This OAuth 2.0 [RFC6749] protocol extension enables OAuth clients to + request user authorization from applications on devices that have + limited input capabilities or lack a suitable browser. Such devices + include smart TVs, media consoles, picture frames, and printers, + which lack an easy input method or a suitable browser required for + traditional OAuth interactions. Here is the authorization flow:: + + +----------+ +----------------+ + | |>---(A)-- Client Identifier --->| | + | | | | + | |<---(B)-- Device Code, ---<| | + | | User Code, | | + | Device | & Verification URI | | + | Client | | | + | | [polling] | | + | |>---(E)-- Device Code --->| | + | | & Client Identifier | | + | | | Authorization | + | |<---(F)-- Access Token ---<| Server | + +----------+ (& Optional Refresh Token) | | + v | | + : | | + (C) User Code & Verification URI | | + : | | + v | | + +----------+ | | + | End User | | | + | at |<---(D)-- End user reviews --->| | + | Browser | authorization request | | + +----------+ +----------------+ + + This DeviceCodeGrant is the implementation of step (E) and (F). + + (E) While the end user reviews the client's request (step D), the + client repeatedly polls the authorization server to find out if + the user completed the user authorization step. The client + includes the device code and its client identifier. + + (F) The authorization server validates the device code provided by + the client and responds with the access token if the client is + granted access, an error if they are denied access, or an + indication that the client should continue to poll. + """ + GRANT_TYPE = DEVICE_CODE_GRANT_TYPE + TOKEN_ENDPOINT_AUTH_METHODS = ['client_secret_basic', 'client_secret_post', 'none'] + + def validate_token_request(self): + """After displaying instructions to the user, the client creates an + access token request and sends it to the token endpoint with the + following parameters: + + grant_type + REQUIRED. Value MUST be set to + "urn:ietf:params:oauth:grant-type:device_code". + + device_code + REQUIRED. The device verification code, "device_code" from the + device authorization response. + + client_id + REQUIRED if the client is not authenticating with the + authorization server as described in Section 3.2.1. of [RFC6749]. + The client identifier as described in Section 2.2 of [RFC6749]. + + For example, the client makes the following HTTPS request:: + + POST /token HTTP/1.1 + Host: server.example.com + Content-Type: application/x-www-form-urlencoded + + grant_type=urn%3Aietf%3Aparams%3Aoauth%3Agrant-type%3Adevice_code + &device_code=GmRhmhcxhwAzkoEqiMEg_DnyEysNkuNhszIySk9eS + &client_id=1406020730 + """ + device_code = self.request.data.get('device_code') + if not device_code: + raise InvalidRequestError('Missing "device_code" in payload') + + client = self.authenticate_token_endpoint_client() + if not client.check_grant_type(self.GRANT_TYPE): + raise UnauthorizedClientError() + + credential = self.query_device_credential(device_code) + if not credential: + raise InvalidRequestError('Invalid "device_code" in payload') + + if credential.get_client_id() != client.get_client_id(): + raise UnauthorizedClientError() + + user = self.validate_device_credential(credential) + self.request.user = user + self.request.client = client + self.request.credential = credential + + def create_token_response(self): + """If the access token request is valid and authorized, the + authorization server issues an access token and optional refresh + token. + """ + client = self.request.client + scope = self.request.credential.get_scope() + token = self.generate_token( + user=self.request.user, + scope=scope, + include_refresh_token=client.check_grant_type('refresh_token'), + ) + log.debug('Issue token %r to %r', token, client) + self.save_token(token) + self.execute_hook('process_token', token=token) + return 200, token, self.TOKEN_RESPONSE_HEADER + + def validate_device_credential(self, credential): + if credential.is_expired(): + raise ExpiredTokenError() + + user_code = credential.get_user_code() + user_grant = self.query_user_grant(user_code) + + if user_grant is not None: + user, approved = user_grant + if not approved: + raise AccessDeniedError() + return user + + if self.should_slow_down(credential): + raise SlowDownError() + + raise AuthorizationPendingError() + + def query_device_credential(self, device_code): + """Get device credential from previously savings via ``DeviceAuthorizationEndpoint``. + Developers MUST implement it in subclass:: + + def query_device_credential(self, device_code): + return DeviceCredential.get(device_code) + + :param device_code: a string represent the code. + :return: DeviceCredential instance + """ + raise NotImplementedError() + + def query_user_grant(self, user_code): + """Get user and grant via the given user code. Developers MUST + implement it in subclass:: + + def query_user_grant(self, user_code): + # e.g. we saved user grant info in redis + data = redis.get('oauth_user_grant:' + user_code) + if not data: + return None + + user_id, allowed = data.split() + user = User.get(user_id) + return user, bool(allowed) + + Note, user grant information is saved by verification endpoint. + """ + raise NotImplementedError() + + def should_slow_down(self, credential): + """The authorization request is still pending and polling should + continue, but the interval MUST be increased by 5 seconds for this + and all subsequent requests. + """ + raise NotImplementedError() diff --git a/.venv/Lib/site-packages/authlib/oauth2/rfc8628/endpoint.py b/.venv/Lib/site-packages/authlib/oauth2/rfc8628/endpoint.py new file mode 100644 index 00000000..49221f09 --- /dev/null +++ b/.venv/Lib/site-packages/authlib/oauth2/rfc8628/endpoint.py @@ -0,0 +1,170 @@ +from authlib.consts import default_json_headers +from authlib.common.security import generate_token +from authlib.common.urls import add_params_to_uri + + +class DeviceAuthorizationEndpoint: + """This OAuth 2.0 [RFC6749] protocol extension enables OAuth clients to + request user authorization from applications on devices that have + limited input capabilities or lack a suitable browser. Such devices + include smart TVs, media consoles, picture frames, and printers, + which lack an easy input method or a suitable browser required for + traditional OAuth interactions. Here is the authorization flow:: + + +----------+ +----------------+ + | |>---(A)-- Client Identifier --->| | + | | | | + | |<---(B)-- Device Code, ---<| | + | | User Code, | | + | Device | & Verification URI | | + | Client | | | + | | [polling] | | + | |>---(E)-- Device Code --->| | + | | & Client Identifier | | + | | | Authorization | + | |<---(F)-- Access Token ---<| Server | + +----------+ (& Optional Refresh Token) | | + v | | + : | | + (C) User Code & Verification URI | | + : | | + v | | + +----------+ | | + | End User | | | + | at |<---(D)-- End user reviews --->| | + | Browser | authorization request | | + +----------+ +----------------+ + + This DeviceAuthorizationEndpoint is the implementation of step (A) and (B). + + (A) The client requests access from the authorization server and + includes its client identifier in the request. + + (B) The authorization server issues a device code and an end-user + code and provides the end-user verification URI. + """ + + ENDPOINT_NAME = 'device_authorization' + CLIENT_AUTH_METHODS = ['client_secret_basic', 'client_secret_post', 'none'] + + #: customize "user_code" type, string or digital + USER_CODE_TYPE = 'string' + + #: The lifetime in seconds of the "device_code" and "user_code" + EXPIRES_IN = 1800 + + #: The minimum amount of time in seconds that the client SHOULD + #: wait between polling requests to the token endpoint. + INTERVAL = 5 + + def __init__(self, server): + self.server = server + + def __call__(self, request): + # make it callable for authorization server + # ``create_endpoint_response`` + return self.create_endpoint_response(request) + + def create_endpoint_request(self, request): + return self.server.create_oauth2_request(request) + + def authenticate_client(self, request): + """client_id is REQUIRED **if the client is not** authenticating with the + authorization server as described in Section 3.2.1. of [RFC6749]. + + This means the endpoint support "none" authentication method. In this case, + this endpoint's auth methods are: + + - client_secret_basic + - client_secret_post + - none + + Developers change the value of ``CLIENT_AUTH_METHODS`` in subclass. For + instance:: + + class MyDeviceAuthorizationEndpoint(DeviceAuthorizationEndpoint): + # only support ``client_secret_basic`` auth method + CLIENT_AUTH_METHODS = ['client_secret_basic'] + """ + client = self.server.authenticate_client( + request, self.CLIENT_AUTH_METHODS, self.ENDPOINT_NAME) + request.client = client + return client + + def create_endpoint_response(self, request): + # https://tools.ietf.org/html/rfc8628#section-3.1 + + self.authenticate_client(request) + self.server.validate_requested_scope(request.scope) + + device_code = self.generate_device_code() + user_code = self.generate_user_code() + verification_uri = self.get_verification_uri() + verification_uri_complete = add_params_to_uri( + verification_uri, [('user_code', user_code)]) + + data = { + 'device_code': device_code, + 'user_code': user_code, + 'verification_uri': verification_uri, + 'verification_uri_complete': verification_uri_complete, + 'expires_in': self.EXPIRES_IN, + 'interval': self.INTERVAL, + } + + self.save_device_credential(request.client_id, request.scope, data) + return 200, data, default_json_headers + + def generate_user_code(self): + """A method to generate ``user_code`` value for device authorization + endpoint. This method will generate a random string like MQNA-JPOZ. + Developers can rewrite this method to create their own ``user_code``. + """ + # https://tools.ietf.org/html/rfc8628#section-6.1 + if self.USER_CODE_TYPE == 'digital': + return create_digital_user_code() + return create_string_user_code() + + def generate_device_code(self): + """A method to generate ``device_code`` value for device authorization + endpoint. This method will generate a random string of 42 characters. + Developers can rewrite this method to create their own ``device_code``. + """ + return generate_token(42) + + def get_verification_uri(self): + """Define the ``verification_uri`` of device authorization endpoint. + Developers MUST implement this method in subclass:: + + def get_verification_uri(self): + return 'https://your-company.com/active' + """ + raise NotImplementedError() + + def save_device_credential(self, client_id, scope, data): + """Save device token into database for later use. Developers MUST + implement this method in subclass:: + + def save_device_credential(self, client_id, scope, data): + item = DeviceCredential( + client_id=client_id, + scope=scope, + **data + ) + item.save() + """ + raise NotImplementedError() + + +def create_string_user_code(): + base = 'BCDFGHJKLMNPQRSTVWXZ' + return '-'.join([generate_token(4, base), generate_token(4, base)]) + + +def create_digital_user_code(): + base = '0123456789' + return '-'.join([ + generate_token(3, base), + generate_token(3, base), + generate_token(3, base), + ]) diff --git a/.venv/Lib/site-packages/authlib/oauth2/rfc8628/errors.py b/.venv/Lib/site-packages/authlib/oauth2/rfc8628/errors.py new file mode 100644 index 00000000..4a63db82 --- /dev/null +++ b/.venv/Lib/site-packages/authlib/oauth2/rfc8628/errors.py @@ -0,0 +1,27 @@ +from ..rfc6749.errors import OAuth2Error + +# https://tools.ietf.org/html/rfc8628#section-3.5 + + +class AuthorizationPendingError(OAuth2Error): + """The authorization request is still pending as the end user hasn't + yet completed the user-interaction steps (Section 3.3). + """ + error = 'authorization_pending' + + +class SlowDownError(OAuth2Error): + """A variant of "authorization_pending", the authorization request is + still pending and polling should continue, but the interval MUST + be increased by 5 seconds for this and all subsequent requests. + """ + error = 'slow_down' + + +class ExpiredTokenError(OAuth2Error): + """The "device_code" has expired, and the device authorization + session has concluded. The client MAY commence a new device + authorization request but SHOULD wait for user interaction before + restarting to avoid unnecessary polling. + """ + error = 'expired_token' diff --git a/.venv/Lib/site-packages/authlib/oauth2/rfc8628/models.py b/.venv/Lib/site-packages/authlib/oauth2/rfc8628/models.py new file mode 100644 index 00000000..39eb9a13 --- /dev/null +++ b/.venv/Lib/site-packages/authlib/oauth2/rfc8628/models.py @@ -0,0 +1,38 @@ +import time + + +class DeviceCredentialMixin: + def get_client_id(self): + raise NotImplementedError() + + def get_scope(self): + raise NotImplementedError() + + def get_user_code(self): + raise NotImplementedError() + + def is_expired(self): + raise NotImplementedError() + + +class DeviceCredentialDict(dict, DeviceCredentialMixin): + def get_client_id(self): + return self['client_id'] + + def get_scope(self): + return self.get('scope') + + def get_user_code(self): + return self['user_code'] + + def get_nonce(self): + return self.get('nonce') + + def get_auth_time(self): + return self.get('auth_time') + + def is_expired(self): + expires_at = self.get('expires_at') + if expires_at: + return expires_at < time.time() + return False diff --git a/.venv/Lib/site-packages/authlib/oauth2/rfc8693/__init__.py b/.venv/Lib/site-packages/authlib/oauth2/rfc8693/__init__.py new file mode 100644 index 00000000..1a74f856 --- /dev/null +++ b/.venv/Lib/site-packages/authlib/oauth2/rfc8693/__init__.py @@ -0,0 +1,9 @@ +""" + authlib.oauth2.rfc8693 + ~~~~~~~~~~~~~~~~~~~~~~ + + This module represents an implementation of + OAuth 2.0 Token Exchange. + + https://tools.ietf.org/html/rfc8693 +""" diff --git a/.venv/Lib/site-packages/authlib/oauth2/rfc8693/__pycache__/__init__.cpython-311.pyc b/.venv/Lib/site-packages/authlib/oauth2/rfc8693/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 00000000..94dcdbd9 Binary files /dev/null and b/.venv/Lib/site-packages/authlib/oauth2/rfc8693/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/authlib/oauth2/rfc9068/__init__.py b/.venv/Lib/site-packages/authlib/oauth2/rfc9068/__init__.py new file mode 100644 index 00000000..b914509a --- /dev/null +++ b/.venv/Lib/site-packages/authlib/oauth2/rfc9068/__init__.py @@ -0,0 +1,11 @@ +from .introspection import JWTIntrospectionEndpoint +from .revocation import JWTRevocationEndpoint +from .token import JWTBearerTokenGenerator +from .token_validator import JWTBearerTokenValidator + +__all__ = [ + 'JWTBearerTokenGenerator', + 'JWTBearerTokenValidator', + 'JWTIntrospectionEndpoint', + 'JWTRevocationEndpoint', +] diff --git a/.venv/Lib/site-packages/authlib/oauth2/rfc9068/__pycache__/__init__.cpython-311.pyc b/.venv/Lib/site-packages/authlib/oauth2/rfc9068/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 00000000..296599ac Binary files /dev/null and b/.venv/Lib/site-packages/authlib/oauth2/rfc9068/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/authlib/oauth2/rfc9068/__pycache__/claims.cpython-311.pyc b/.venv/Lib/site-packages/authlib/oauth2/rfc9068/__pycache__/claims.cpython-311.pyc new file mode 100644 index 00000000..778e7ab7 Binary files /dev/null and b/.venv/Lib/site-packages/authlib/oauth2/rfc9068/__pycache__/claims.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/authlib/oauth2/rfc9068/__pycache__/introspection.cpython-311.pyc b/.venv/Lib/site-packages/authlib/oauth2/rfc9068/__pycache__/introspection.cpython-311.pyc new file mode 100644 index 00000000..7635c856 Binary files /dev/null and b/.venv/Lib/site-packages/authlib/oauth2/rfc9068/__pycache__/introspection.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/authlib/oauth2/rfc9068/__pycache__/revocation.cpython-311.pyc b/.venv/Lib/site-packages/authlib/oauth2/rfc9068/__pycache__/revocation.cpython-311.pyc new file mode 100644 index 00000000..e9bd350c Binary files /dev/null and b/.venv/Lib/site-packages/authlib/oauth2/rfc9068/__pycache__/revocation.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/authlib/oauth2/rfc9068/__pycache__/token.cpython-311.pyc b/.venv/Lib/site-packages/authlib/oauth2/rfc9068/__pycache__/token.cpython-311.pyc new file mode 100644 index 00000000..54c38c2c Binary files /dev/null and b/.venv/Lib/site-packages/authlib/oauth2/rfc9068/__pycache__/token.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/authlib/oauth2/rfc9068/__pycache__/token_validator.cpython-311.pyc b/.venv/Lib/site-packages/authlib/oauth2/rfc9068/__pycache__/token_validator.cpython-311.pyc new file mode 100644 index 00000000..e666ee26 Binary files /dev/null and b/.venv/Lib/site-packages/authlib/oauth2/rfc9068/__pycache__/token_validator.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/authlib/oauth2/rfc9068/claims.py b/.venv/Lib/site-packages/authlib/oauth2/rfc9068/claims.py new file mode 100644 index 00000000..4dcfea8e --- /dev/null +++ b/.venv/Lib/site-packages/authlib/oauth2/rfc9068/claims.py @@ -0,0 +1,62 @@ +from authlib.jose.errors import InvalidClaimError +from authlib.jose.rfc7519 import JWTClaims + + +class JWTAccessTokenClaims(JWTClaims): + REGISTERED_CLAIMS = JWTClaims.REGISTERED_CLAIMS + [ + 'client_id', + 'auth_time', + 'acr', + 'amr', + 'scope', + 'groups', + 'roles', + 'entitlements', + ] + + def validate(self, **kwargs): + self.validate_typ() + + super().validate(**kwargs) + self.validate_client_id() + self.validate_auth_time() + self.validate_acr() + self.validate_amr() + self.validate_scope() + self.validate_groups() + self.validate_roles() + self.validate_entitlements() + + def validate_typ(self): + # The resource server MUST verify that the 'typ' header value is 'at+jwt' + # or 'application/at+jwt' and reject tokens carrying any other value. + if self.header['typ'].lower() not in ('at+jwt', 'application/at+jwt'): + raise InvalidClaimError('typ') + + def validate_client_id(self): + return self._validate_claim_value('client_id') + + def validate_auth_time(self): + auth_time = self.get('auth_time') + if auth_time and not isinstance(auth_time, (int, float)): + raise InvalidClaimError('auth_time') + + def validate_acr(self): + return self._validate_claim_value('acr') + + def validate_amr(self): + amr = self.get('amr') + if amr and not isinstance(self['amr'], list): + raise InvalidClaimError('amr') + + def validate_scope(self): + return self._validate_claim_value('scope') + + def validate_groups(self): + return self._validate_claim_value('groups') + + def validate_roles(self): + return self._validate_claim_value('roles') + + def validate_entitlements(self): + return self._validate_claim_value('entitlements') diff --git a/.venv/Lib/site-packages/authlib/oauth2/rfc9068/introspection.py b/.venv/Lib/site-packages/authlib/oauth2/rfc9068/introspection.py new file mode 100644 index 00000000..17b5eb5a --- /dev/null +++ b/.venv/Lib/site-packages/authlib/oauth2/rfc9068/introspection.py @@ -0,0 +1,126 @@ +from ..rfc7662 import IntrospectionEndpoint +from authlib.common.errors import ContinueIteration +from authlib.consts import default_json_headers +from authlib.jose.errors import ExpiredTokenError +from authlib.jose.errors import InvalidClaimError +from authlib.oauth2.rfc6750.errors import InvalidTokenError +from authlib.oauth2.rfc9068.token_validator import JWTBearerTokenValidator + + +class JWTIntrospectionEndpoint(IntrospectionEndpoint): + ''' + JWTIntrospectionEndpoint inherits from :ref:`specs/rfc7662` + :class:`~authlib.oauth2.rfc7662.IntrospectionEndpoint` and implements the machinery + to automatically process the JWT access tokens. + + :param issuer: The issuer identifier for which tokens will be introspected. + + :param \\*\\*kwargs: Other parameters are inherited from + :class:`~authlib.oauth2.rfc7662.introspection.IntrospectionEndpoint`. + + :: + + class MyJWTAccessTokenIntrospectionEndpoint(JWTRevocationEndpoint): + def get_jwks(self): + ... + + def get_username(self, user_id): + ... + + authorization_server.register_endpoint( + MyJWTAccessTokenIntrospectionEndpoint( + issuer="https://authorization-server.example.org", + ) + ) + authorization_server.register_endpoint(MyRefreshTokenIntrospectionEndpoint) + + ''' + + #: Endpoint name to be registered + ENDPOINT_NAME = 'introspection' + + def __init__(self, issuer, server=None, *args, **kwargs): + super().__init__(*args, server=server, **kwargs) + self.issuer = issuer + + def create_endpoint_response(self, request): + '''''' + # The authorization server first validates the client credentials + client = self.authenticate_endpoint_client(request) + + # then verifies whether the token was issued to the client making + # the revocation request + token = self.authenticate_token(request, client) + + # the authorization server invalidates the token + body = self.create_introspection_payload(token) + return 200, body, default_json_headers + + def authenticate_token(self, request, client): + '''''' + self.check_params(request, client) + + # do not attempt to decode refresh_tokens + if request.form.get('token_type_hint') not in ('access_token', None): + raise ContinueIteration() + + validator = JWTBearerTokenValidator(issuer=self.issuer, resource_server=None) + validator.get_jwks = self.get_jwks + try: + token = validator.authenticate_token(request.form['token']) + + # if the token is not a JWT, fall back to the regular flow + except InvalidTokenError: + raise ContinueIteration() + + if token and self.check_permission(token, client, request): + return token + + def create_introspection_payload(self, token): + if not token: + return {'active': False} + + try: + token.validate() + except ExpiredTokenError: + return {'active': False} + except InvalidClaimError as exc: + if exc.claim_name == 'iss': + raise ContinueIteration() + raise InvalidTokenError() + + + payload = { + 'active': True, + 'token_type': 'Bearer', + 'client_id': token['client_id'], + 'scope': token['scope'], + 'sub': token['sub'], + 'aud': token['aud'], + 'iss': token['iss'], + 'exp': token['exp'], + 'iat': token['iat'], + } + + if username := self.get_username(token['sub']): + payload['username'] = username + + return payload + + def get_jwks(self): + '''Return the JWKs that will be used to check the JWT access token signature. + Developers MUST re-implement this method:: + + def get_jwks(self): + return load_jwks("jwks.json") + ''' + raise NotImplementedError() + + def get_username(self, user_id: str) -> str: + '''Returns an username from a user ID. + Developers MAY re-implement this method:: + + def get_username(self, user_id): + return User.get(id=user_id).username + ''' + return None diff --git a/.venv/Lib/site-packages/authlib/oauth2/rfc9068/revocation.py b/.venv/Lib/site-packages/authlib/oauth2/rfc9068/revocation.py new file mode 100644 index 00000000..9453c79a --- /dev/null +++ b/.venv/Lib/site-packages/authlib/oauth2/rfc9068/revocation.py @@ -0,0 +1,70 @@ +from ..rfc6749 import UnsupportedTokenTypeError +from ..rfc7009 import RevocationEndpoint +from authlib.common.errors import ContinueIteration +from authlib.oauth2.rfc6750.errors import InvalidTokenError +from authlib.oauth2.rfc9068.token_validator import JWTBearerTokenValidator + + +class JWTRevocationEndpoint(RevocationEndpoint): + '''JWTRevocationEndpoint inherits from `RFC7009`_ + :class:`~authlib.oauth2.rfc7009.RevocationEndpoint`. + + The JWT access tokens cannot be revoked. + If the submitted token is a JWT access token, then revocation returns + a `invalid_token_error`. + + :param issuer: The issuer identifier. + + :param \\*\\*kwargs: Other parameters are inherited from + :class:`~authlib.oauth2.rfc7009.RevocationEndpoint`. + + Plain text access tokens and other kind of tokens such as refresh_tokens + will be ignored by this endpoint and passed to the next revocation endpoint:: + + class MyJWTAccessTokenRevocationEndpoint(JWTRevocationEndpoint): + def get_jwks(self): + ... + + authorization_server.register_endpoint( + MyJWTAccessTokenRevocationEndpoint( + issuer="https://authorization-server.example.org", + ) + ) + authorization_server.register_endpoint(MyRefreshTokenRevocationEndpoint) + + .. _RFC7009: https://tools.ietf.org/html/rfc7009 + ''' + + def __init__(self, issuer, server=None, *args, **kwargs): + super().__init__(*args, server=server, **kwargs) + self.issuer = issuer + + def authenticate_token(self, request, client): + '''''' + self.check_params(request, client) + + # do not attempt to revoke refresh_tokens + if request.form.get('token_type_hint') not in ('access_token', None): + raise ContinueIteration() + + validator = JWTBearerTokenValidator(issuer=self.issuer, resource_server=None) + validator.get_jwks = self.get_jwks + + try: + validator.authenticate_token(request.form['token']) + + # if the token is not a JWT, fall back to the regular flow + except InvalidTokenError: + raise ContinueIteration() + + # JWT access token cannot be revoked + raise UnsupportedTokenTypeError() + + def get_jwks(self): + '''Return the JWKs that will be used to check the JWT access token signature. + Developers MUST re-implement this method:: + + def get_jwks(self): + return load_jwks("jwks.json") + ''' + raise NotImplementedError() diff --git a/.venv/Lib/site-packages/authlib/oauth2/rfc9068/token.py b/.venv/Lib/site-packages/authlib/oauth2/rfc9068/token.py new file mode 100644 index 00000000..6751b88e --- /dev/null +++ b/.venv/Lib/site-packages/authlib/oauth2/rfc9068/token.py @@ -0,0 +1,218 @@ +import time +from typing import List +from typing import Optional +from typing import Union + +from authlib.common.security import generate_token +from authlib.jose import jwt +from authlib.oauth2.rfc6750.token import BearerTokenGenerator + + +class JWTBearerTokenGenerator(BearerTokenGenerator): + '''A JWT formatted access token generator. + + :param issuer: The issuer identifier. Will appear in the JWT ``iss`` claim. + + :param \\*\\*kwargs: Other parameters are inherited from + :class:`~authlib.oauth2.rfc6750.token.BearerTokenGenerator`. + + This token generator can be registered into the authorization server:: + + class MyJWTBearerTokenGenerator(JWTBearerTokenGenerator): + def get_jwks(self): + ... + + def get_extra_claims(self, client, grant_type, user, scope): + ... + + authorization_server.register_token_generator( + 'default', + MyJWTBearerTokenGenerator(issuer='https://authorization-server.example.org'), + ) + ''' + + def __init__( + self, + issuer, + alg='RS256', + refresh_token_generator=None, + expires_generator=None, + ): + super().__init__( + self.access_token_generator, refresh_token_generator, expires_generator + ) + self.issuer = issuer + self.alg = alg + + def get_jwks(self): + '''Return the JWKs that will be used to sign the JWT access token. + Developers MUST re-implement this method:: + + def get_jwks(self): + return load_jwks("jwks.json") + ''' + raise NotImplementedError() + + def get_extra_claims(self, client, grant_type, user, scope): + '''Return extra claims to add in the JWT access token. Developers MAY + re-implement this method to add identity claims like the ones in + :ref:`specs/oidc` ID Token, or any other arbitrary claims:: + + def get_extra_claims(self, client, grant_type, user, scope): + return generate_user_info(user, scope) + ''' + return {} + + def get_audiences(self, client, user, scope) -> Union[str, List[str]]: + '''Return the audience for the token. By default this simply returns + the client ID. Developpers MAY re-implement this method to add extra + audiences:: + + def get_audiences(self, client, user, scope): + return [ + client.get_client_id(), + resource_server.get_id(), + ] + ''' + return client.get_client_id() + + def get_acr(self, user) -> Optional[str]: + '''Authentication Context Class Reference. + Returns a user-defined case sensitive string indicating the class of + authentication the used performed. Token audience may refuse to give access to + some resources if some ACR criterias are not met. + :ref:`specs/oidc` defines one special value: ``0`` means that the user + authentication did not respect `ISO29115`_ level 1, and will be refused monetary + operations. Developers MAY re-implement this method:: + + def get_acr(self, user): + if user.insecure_session(): + return '0' + return 'urn:mace:incommon:iap:silver' + + .. _ISO29115: https://www.iso.org/standard/45138.html + ''' + return None + + def get_auth_time(self, user) -> Optional[int]: + '''User authentication time. + Time when the End-User authentication occurred. Its value is a JSON number + representing the number of seconds from 1970-01-01T0:0:0Z as measured in UTC + until the date/time. Developers MAY re-implement this method:: + + def get_auth_time(self, user): + return datetime.timestamp(user.get_auth_time()) + ''' + return None + + def get_amr(self, user) -> Optional[List[str]]: + '''Authentication Methods References. + Defined by :ref:`specs/oidc` as an option list of user-defined case-sensitive + strings indication which authentication methods have been used to authenticate + the user. Developers MAY re-implement this method:: + + def get_amr(self, user): + return ['2FA'] if user.has_2fa_enabled() else [] + ''' + return None + + def get_jti(self, client, grant_type, user, scope) -> str: + '''JWT ID. + Create an unique identifier for the token. Developers MAY re-implement + this method:: + + def get_jti(self, client, grant_type, user scope): + return generate_random_string(16) + ''' + return generate_token(16) + + def access_token_generator(self, client, grant_type, user, scope): + now = int(time.time()) + expires_in = now + self._get_expires_in(client, grant_type) + + token_data = { + 'iss': self.issuer, + 'exp': expires_in, + 'client_id': client.get_client_id(), + 'iat': now, + 'jti': self.get_jti(client, grant_type, user, scope), + 'scope': scope, + } + + # In cases of access tokens obtained through grants where a resource owner is + # involved, such as the authorization code grant, the value of 'sub' SHOULD + # correspond to the subject identifier of the resource owner. + + if user: + token_data['sub'] = user.get_user_id() + + # In cases of access tokens obtained through grants where no resource owner is + # involved, such as the client credentials grant, the value of 'sub' SHOULD + # correspond to an identifier the authorization server uses to indicate the + # client application. + + else: + token_data['sub'] = client.get_client_id() + + # If the request includes a 'resource' parameter (as defined in [RFC8707]), the + # resulting JWT access token 'aud' claim SHOULD have the same value as the + # 'resource' parameter in the request. + + # TODO: Implement this with RFC8707 + if False: # pragma: no cover + ... + + # If the request does not include a 'resource' parameter, the authorization + # server MUST use a default resource indicator in the 'aud' claim. If a 'scope' + # parameter is present in the request, the authorization server SHOULD use it to + # infer the value of the default resource indicator to be used in the 'aud' + # claim. The mechanism through which scopes are associated with default resource + # indicator values is outside the scope of this specification. + + else: + token_data['aud'] = self.get_audiences(client, user, scope) + + # If the values in the 'scope' parameter refer to different default resource + # indicator values, the authorization server SHOULD reject the request with + # 'invalid_scope' as described in Section 4.1.2.1 of [RFC6749]. + # TODO: Implement this with RFC8707 + + if auth_time := self.get_auth_time(user): + token_data['auth_time'] = auth_time + + # The meaning and processing of acr Claim Values is out of scope for this + # specification. + + if acr := self.get_acr(user): + token_data['acr'] = acr + + # The definition of particular values to be used in the amr Claim is beyond the + # scope of this specification. + + if amr := self.get_amr(user): + token_data['amr'] = amr + + # Authorization servers MAY return arbitrary attributes not defined in any + # existing specification, as long as the corresponding claim names are collision + # resistant or the access tokens are meant to be used only within a private + # subsystem. Please refer to Sections 4.2 and 4.3 of [RFC7519] for details. + + token_data.update(self.get_extra_claims(client, grant_type, user, scope)) + + # This specification registers the 'application/at+jwt' media type, which can + # be used to indicate that the content is a JWT access token. JWT access tokens + # MUST include this media type in the 'typ' header parameter to explicitly + # declare that the JWT represents an access token complying with this profile. + # Per the definition of 'typ' in Section 4.1.9 of [RFC7515], it is RECOMMENDED + # that the 'application/' prefix be omitted. Therefore, the 'typ' value used + # SHOULD be 'at+jwt'. + + header = {'alg': self.alg, 'typ': 'at+jwt'} + + access_token = jwt.encode( + header, + token_data, + key=self.get_jwks(), + check=False, + ) + return access_token.decode() diff --git a/.venv/Lib/site-packages/authlib/oauth2/rfc9068/token_validator.py b/.venv/Lib/site-packages/authlib/oauth2/rfc9068/token_validator.py new file mode 100644 index 00000000..dc152e28 --- /dev/null +++ b/.venv/Lib/site-packages/authlib/oauth2/rfc9068/token_validator.py @@ -0,0 +1,163 @@ +''' + authlib.oauth2.rfc9068.token_validator + ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + Implementation of Validating JWT Access Tokens per `Section 4`_. + + .. _`Section 7`: https://www.rfc-editor.org/rfc/rfc9068.html#name-validating-jwt-access-token +''' +from authlib.jose import jwt +from authlib.jose.errors import DecodeError +from authlib.jose.errors import JoseError +from authlib.oauth2.rfc6750.errors import InsufficientScopeError +from authlib.oauth2.rfc6750.errors import InvalidTokenError +from authlib.oauth2.rfc6750.validator import BearerTokenValidator +from .claims import JWTAccessTokenClaims + + +class JWTBearerTokenValidator(BearerTokenValidator): + '''JWTBearerTokenValidator can protect your resource server endpoints. + + :param issuer: The issuer from which tokens will be accepted. + :param resource_server: An identifier for the current resource server, + which must appear in the JWT ``aud`` claim. + + Developers needs to implement the missing methods:: + + class MyJWTBearerTokenValidator(JWTBearerTokenValidator): + def get_jwks(self): + ... + + require_oauth = ResourceProtector() + require_oauth.register_token_validator( + MyJWTBearerTokenValidator( + issuer='https://authorization-server.example.org', + resource_server='https://resource-server.example.org', + ) + ) + + You can then protect resources depending on the JWT `scope`, `groups`, + `roles` or `entitlements` claims:: + + @require_oauth( + scope='profile', + groups='admins', + roles='student', + entitlements='captain', + ) + def resource_endpoint(): + ... + ''' + + def __init__(self, issuer, resource_server, *args, **kwargs): + self.issuer = issuer + self.resource_server = resource_server + super().__init__(*args, **kwargs) + + def get_jwks(self): + '''Return the JWKs that will be used to check the JWT access token signature. + Developers MUST re-implement this method. Typically the JWKs are statically + stored in the resource server configuration, or dynamically downloaded and + cached using :ref:`specs/rfc8414`:: + + def get_jwks(self): + if 'jwks' in cache: + return cache.get('jwks') + + server_metadata = get_server_metadata(self.issuer) + jwks_uri = server_metadata.get('jwks_uri') + cache['jwks'] = requests.get(jwks_uri).json() + return cache['jwks'] + ''' + raise NotImplementedError() + + def validate_iss(self, claims, iss: 'str') -> bool: + # The issuer identifier for the authorization server (which is typically + # obtained during discovery) MUST exactly match the value of the 'iss' + # claim. + return iss == self.issuer + + def authenticate_token(self, token_string): + '''''' + # empty docstring avoids to display the irrelevant parent docstring + + claims_options = { + 'iss': {'essential': True, 'validate': self.validate_iss}, + 'exp': {'essential': True}, + 'aud': {'essential': True, 'value': self.resource_server}, + 'sub': {'essential': True}, + 'client_id': {'essential': True}, + 'iat': {'essential': True}, + 'jti': {'essential': True}, + 'auth_time': {'essential': False}, + 'acr': {'essential': False}, + 'amr': {'essential': False}, + 'scope': {'essential': False}, + 'groups': {'essential': False}, + 'roles': {'essential': False}, + 'entitlements': {'essential': False}, + } + jwks = self.get_jwks() + + # If the JWT access token is encrypted, decrypt it using the keys and algorithms + # that the resource server specified during registration. If encryption was + # negotiated with the authorization server at registration time and the incoming + # JWT access token is not encrypted, the resource server SHOULD reject it. + + # The resource server MUST validate the signature of all incoming JWT access + # tokens according to [RFC7515] using the algorithm specified in the JWT 'alg' + # Header Parameter. The resource server MUST reject any JWT in which the value + # of 'alg' is 'none'. The resource server MUST use the keys provided by the + # authorization server. + try: + return jwt.decode( + token_string, + key=jwks, + claims_cls=JWTAccessTokenClaims, + claims_options=claims_options, + ) + except DecodeError: + raise InvalidTokenError( + realm=self.realm, extra_attributes=self.extra_attributes + ) + + def validate_token( + self, token, scopes, request, groups=None, roles=None, entitlements=None + ): + '''''' + # empty docstring avoids to display the irrelevant parent docstring + try: + token.validate() + except JoseError as exc: + raise InvalidTokenError( + realm=self.realm, extra_attributes=self.extra_attributes + ) from exc + + # If an authorization request includes a scope parameter, the corresponding + # issued JWT access token SHOULD include a 'scope' claim as defined in Section + # 4.2 of [RFC8693]. All the individual scope strings in the 'scope' claim MUST + # have meaning for the resources indicated in the 'aud' claim. See Section 5 for + # more considerations about the relationship between scope strings and resources + # indicated by the 'aud' claim. + + if self.scope_insufficient(token.get('scope', []), scopes): + raise InsufficientScopeError() + + # Many authorization servers embed authorization attributes that go beyond the + # delegated scenarios described by [RFC7519] in the access tokens they issue. + # Typical examples include resource owner memberships in roles and groups that + # are relevant to the resource being accessed, entitlements assigned to the + # resource owner for the targeted resource that the authorization server knows + # about, and so on. An authorization server wanting to include such attributes + # in a JWT access token SHOULD use the 'groups', 'roles', and 'entitlements' + # attributes of the 'User' resource schema defined by Section 4.1.2 of + # [RFC7643]) as claim types. + + if self.scope_insufficient(token.get('groups'), groups): + raise InvalidTokenError() + + if self.scope_insufficient(token.get('roles'), roles): + raise InvalidTokenError() + + if self.scope_insufficient(token.get('entitlements'), entitlements): + raise InvalidTokenError() diff --git a/.venv/Lib/site-packages/authlib/oidc/__init__.py b/.venv/Lib/site-packages/authlib/oidc/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/.venv/Lib/site-packages/authlib/oidc/__pycache__/__init__.cpython-311.pyc b/.venv/Lib/site-packages/authlib/oidc/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 00000000..bc3fa9ee Binary files /dev/null and b/.venv/Lib/site-packages/authlib/oidc/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/authlib/oidc/core/__init__.py b/.venv/Lib/site-packages/authlib/oidc/core/__init__.py new file mode 100644 index 00000000..212ebc03 --- /dev/null +++ b/.venv/Lib/site-packages/authlib/oidc/core/__init__.py @@ -0,0 +1,23 @@ +""" + authlib.oidc.core + ~~~~~~~~~~~~~~~~~ + + OpenID Connect Core 1.0 Implementation. + + http://openid.net/specs/openid-connect-core-1_0.html +""" + +from .models import AuthorizationCodeMixin +from .claims import ( + IDToken, CodeIDToken, ImplicitIDToken, HybridIDToken, + UserInfo, get_claim_cls_by_response_type, +) +from .grants import OpenIDToken, OpenIDCode, OpenIDHybridGrant, OpenIDImplicitGrant + + +__all__ = [ + 'AuthorizationCodeMixin', + 'IDToken', 'CodeIDToken', 'ImplicitIDToken', 'HybridIDToken', + 'UserInfo', 'get_claim_cls_by_response_type', + 'OpenIDToken', 'OpenIDCode', 'OpenIDHybridGrant', 'OpenIDImplicitGrant', +] diff --git a/.venv/Lib/site-packages/authlib/oidc/core/__pycache__/__init__.cpython-311.pyc b/.venv/Lib/site-packages/authlib/oidc/core/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 00000000..6bea7d75 Binary files /dev/null and b/.venv/Lib/site-packages/authlib/oidc/core/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/authlib/oidc/core/__pycache__/claims.cpython-311.pyc b/.venv/Lib/site-packages/authlib/oidc/core/__pycache__/claims.cpython-311.pyc new file mode 100644 index 00000000..95e8a169 Binary files /dev/null and b/.venv/Lib/site-packages/authlib/oidc/core/__pycache__/claims.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/authlib/oidc/core/__pycache__/errors.cpython-311.pyc b/.venv/Lib/site-packages/authlib/oidc/core/__pycache__/errors.cpython-311.pyc new file mode 100644 index 00000000..5cb176ca Binary files /dev/null and b/.venv/Lib/site-packages/authlib/oidc/core/__pycache__/errors.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/authlib/oidc/core/__pycache__/models.cpython-311.pyc b/.venv/Lib/site-packages/authlib/oidc/core/__pycache__/models.cpython-311.pyc new file mode 100644 index 00000000..39748e60 Binary files /dev/null and b/.venv/Lib/site-packages/authlib/oidc/core/__pycache__/models.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/authlib/oidc/core/__pycache__/util.cpython-311.pyc b/.venv/Lib/site-packages/authlib/oidc/core/__pycache__/util.cpython-311.pyc new file mode 100644 index 00000000..0bdd3370 Binary files /dev/null and b/.venv/Lib/site-packages/authlib/oidc/core/__pycache__/util.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/authlib/oidc/core/claims.py b/.venv/Lib/site-packages/authlib/oidc/core/claims.py new file mode 100644 index 00000000..f8674585 --- /dev/null +++ b/.venv/Lib/site-packages/authlib/oidc/core/claims.py @@ -0,0 +1,242 @@ +import time +import hmac +from authlib.common.encoding import to_bytes +from authlib.jose import JWTClaims +from authlib.jose.errors import ( + MissingClaimError, + InvalidClaimError, +) +from .util import create_half_hash + +__all__ = [ + 'IDToken', 'CodeIDToken', 'ImplicitIDToken', 'HybridIDToken', + 'UserInfo', 'get_claim_cls_by_response_type' +] + +_REGISTERED_CLAIMS = [ + 'iss', 'sub', 'aud', 'exp', 'nbf', 'iat', + 'auth_time', 'nonce', 'acr', 'amr', 'azp', + 'at_hash', +] + + +class IDToken(JWTClaims): + ESSENTIAL_CLAIMS = ['iss', 'sub', 'aud', 'exp', 'iat'] + + def validate(self, now=None, leeway=0): + for k in self.ESSENTIAL_CLAIMS: + if k not in self: + raise MissingClaimError(k) + + self._validate_essential_claims() + if now is None: + now = int(time.time()) + + self.validate_iss() + self.validate_sub() + self.validate_aud() + self.validate_exp(now, leeway) + self.validate_nbf(now, leeway) + self.validate_iat(now, leeway) + self.validate_auth_time() + self.validate_nonce() + self.validate_acr() + self.validate_amr() + self.validate_azp() + self.validate_at_hash() + + def validate_auth_time(self): + """Time when the End-User authentication occurred. Its value is a JSON + number representing the number of seconds from 1970-01-01T0:0:0Z as + measured in UTC until the date/time. When a max_age request is made or + when auth_time is requested as an Essential Claim, then this Claim is + REQUIRED; otherwise, its inclusion is OPTIONAL. + """ + auth_time = self.get('auth_time') + if self.params.get('max_age') and not auth_time: + raise MissingClaimError('auth_time') + + if auth_time and not isinstance(auth_time, (int, float)): + raise InvalidClaimError('auth_time') + + def validate_nonce(self): + """String value used to associate a Client session with an ID Token, + and to mitigate replay attacks. The value is passed through unmodified + from the Authentication Request to the ID Token. If present in the ID + Token, Clients MUST verify that the nonce Claim Value is equal to the + value of the nonce parameter sent in the Authentication Request. If + present in the Authentication Request, Authorization Servers MUST + include a nonce Claim in the ID Token with the Claim Value being the + nonce value sent in the Authentication Request. Authorization Servers + SHOULD perform no other processing on nonce values used. The nonce + value is a case sensitive string. + """ + nonce_value = self.params.get('nonce') + if nonce_value: + if 'nonce' not in self: + raise MissingClaimError('nonce') + if nonce_value != self['nonce']: + raise InvalidClaimError('nonce') + + def validate_acr(self): + """OPTIONAL. Authentication Context Class Reference. String specifying + an Authentication Context Class Reference value that identifies the + Authentication Context Class that the authentication performed + satisfied. The value "0" indicates the End-User authentication did not + meet the requirements of `ISO/IEC 29115`_ level 1. Authentication + using a long-lived browser cookie, for instance, is one example where + the use of "level 0" is appropriate. Authentications with level 0 + SHOULD NOT be used to authorize access to any resource of any monetary + value. An absolute URI or an `RFC 6711`_ registered name SHOULD be + used as the acr value; registered names MUST NOT be used with a + different meaning than that which is registered. Parties using this + claim will need to agree upon the meanings of the values used, which + may be context-specific. The acr value is a case sensitive string. + + .. _`ISO/IEC 29115`: https://www.iso.org/standard/45138.html + .. _`RFC 6711`: https://tools.ietf.org/html/rfc6711 + """ + return self._validate_claim_value('acr') + + def validate_amr(self): + """OPTIONAL. Authentication Methods References. JSON array of strings + that are identifiers for authentication methods used in the + authentication. For instance, values might indicate that both password + and OTP authentication methods were used. The definition of particular + values to be used in the amr Claim is beyond the scope of this + specification. Parties using this claim will need to agree upon the + meanings of the values used, which may be context-specific. The amr + value is an array of case sensitive strings. + """ + amr = self.get('amr') + if amr and not isinstance(self['amr'], list): + raise InvalidClaimError('amr') + + def validate_azp(self): + """OPTIONAL. Authorized party - the party to which the ID Token was + issued. If present, it MUST contain the OAuth 2.0 Client ID of this + party. This Claim is only needed when the ID Token has a single + audience value and that audience is different than the authorized + party. It MAY be included even when the authorized party is the same + as the sole audience. The azp value is a case sensitive string + containing a StringOrURI value. + """ + aud = self.get('aud') + client_id = self.params.get('client_id') + required = False + if aud and client_id: + if isinstance(aud, list) and len(aud) == 1: + aud = aud[0] + if aud != client_id: + required = True + + azp = self.get('azp') + if required and not azp: + raise MissingClaimError('azp') + + if azp and client_id and azp != client_id: + raise InvalidClaimError('azp') + + def validate_at_hash(self): + """OPTIONAL. Access Token hash value. Its value is the base64url + encoding of the left-most half of the hash of the octets of the ASCII + representation of the access_token value, where the hash algorithm + used is the hash algorithm used in the alg Header Parameter of the + ID Token's JOSE Header. For instance, if the alg is RS256, hash the + access_token value with SHA-256, then take the left-most 128 bits and + base64url encode them. The at_hash value is a case sensitive string. + """ + access_token = self.params.get('access_token') + at_hash = self.get('at_hash') + if at_hash and access_token: + if not _verify_hash(at_hash, access_token, self.header['alg']): + raise InvalidClaimError('at_hash') + + +class CodeIDToken(IDToken): + RESPONSE_TYPES = ('code',) + REGISTERED_CLAIMS = _REGISTERED_CLAIMS + + +class ImplicitIDToken(IDToken): + RESPONSE_TYPES = ('id_token', 'id_token token') + ESSENTIAL_CLAIMS = ['iss', 'sub', 'aud', 'exp', 'iat', 'nonce'] + REGISTERED_CLAIMS = _REGISTERED_CLAIMS + + def validate_at_hash(self): + """If the ID Token is issued from the Authorization Endpoint with an + access_token value, which is the case for the response_type value + id_token token, this is REQUIRED; it MAY NOT be used when no Access + Token is issued, which is the case for the response_type value + id_token. + """ + access_token = self.params.get('access_token') + if access_token and 'at_hash' not in self: + raise MissingClaimError('at_hash') + super().validate_at_hash() + + +class HybridIDToken(ImplicitIDToken): + RESPONSE_TYPES = ('code id_token', 'code token', 'code id_token token') + REGISTERED_CLAIMS = _REGISTERED_CLAIMS + ['c_hash'] + + def validate(self, now=None, leeway=0): + super().validate(now=now, leeway=leeway) + self.validate_c_hash() + + def validate_c_hash(self): + """Code hash value. Its value is the base64url encoding of the + left-most half of the hash of the octets of the ASCII representation + of the code value, where the hash algorithm used is the hash algorithm + used in the alg Header Parameter of the ID Token's JOSE Header. For + instance, if the alg is HS512, hash the code value with SHA-512, then + take the left-most 256 bits and base64url encode them. The c_hash + value is a case sensitive string. + If the ID Token is issued from the Authorization Endpoint with a code, + which is the case for the response_type values code id_token and code + id_token token, this is REQUIRED; otherwise, its inclusion is OPTIONAL. + """ + code = self.params.get('code') + c_hash = self.get('c_hash') + if code: + if not c_hash: + raise MissingClaimError('c_hash') + if not _verify_hash(c_hash, code, self.header['alg']): + raise InvalidClaimError('c_hash') + + +class UserInfo(dict): + """The standard claims of a UserInfo object. Defined per `Section 5.1`_. + + .. _`Section 5.1`: http://openid.net/specs/openid-connect-core-1_0.html#StandardClaims + """ + + #: registered claims that UserInfo supports + REGISTERED_CLAIMS = [ + 'sub', 'name', 'given_name', 'family_name', 'middle_name', 'nickname', + 'preferred_username', 'profile', 'picture', 'website', 'email', + 'email_verified', 'gender', 'birthdate', 'zoneinfo', 'locale', + 'phone_number', 'phone_number_verified', 'address', 'updated_at', + ] + + def __getattr__(self, key): + try: + return object.__getattribute__(self, key) + except AttributeError as error: + if key in self.REGISTERED_CLAIMS: + return self.get(key) + raise error + + +def get_claim_cls_by_response_type(response_type): + claims_classes = (CodeIDToken, ImplicitIDToken, HybridIDToken) + for claims_cls in claims_classes: + if response_type in claims_cls.RESPONSE_TYPES: + return claims_cls + + +def _verify_hash(signature, s, alg): + hash_value = create_half_hash(s, alg) + if not hash_value: + return True + return hmac.compare_digest(hash_value, to_bytes(signature)) diff --git a/.venv/Lib/site-packages/authlib/oidc/core/errors.py b/.venv/Lib/site-packages/authlib/oidc/core/errors.py new file mode 100644 index 00000000..e5fb630e --- /dev/null +++ b/.venv/Lib/site-packages/authlib/oidc/core/errors.py @@ -0,0 +1,78 @@ +from authlib.oauth2 import OAuth2Error + + +class InteractionRequiredError(OAuth2Error): + """The Authorization Server requires End-User interaction of some form + to proceed. This error MAY be returned when the prompt parameter value + in the Authentication Request is none, but the Authentication Request + cannot be completed without displaying a user interface for End-User + interaction. + + http://openid.net/specs/openid-connect-core-1_0.html#AuthError + """ + error = 'interaction_required' + + +class LoginRequiredError(OAuth2Error): + """The Authorization Server requires End-User authentication. This error + MAY be returned when the prompt parameter value in the Authentication + Request is none, but the Authentication Request cannot be completed + without displaying a user interface for End-User authentication. + + http://openid.net/specs/openid-connect-core-1_0.html#AuthError + """ + error = 'login_required' + + +class AccountSelectionRequiredError(OAuth2Error): + """The End-User is REQUIRED to select a session at the Authorization + Server. The End-User MAY be authenticated at the Authorization Server + with different associated accounts, but the End-User did not select a + session. This error MAY be returned when the prompt parameter value in + the Authentication Request is none, but the Authentication Request cannot + be completed without displaying a user interface to prompt for a session + to use. + + http://openid.net/specs/openid-connect-core-1_0.html#AuthError + """ + error = 'account_selection_required' + + +class ConsentRequiredError(OAuth2Error): + """The Authorization Server requires End-User consent. This error MAY be + returned when the prompt parameter value in the Authentication Request is + none, but the Authentication Request cannot be completed without + displaying a user interface for End-User consent. + + http://openid.net/specs/openid-connect-core-1_0.html#AuthError + """ + error = 'consent_required' + + +class InvalidRequestURIError(OAuth2Error): + """The request_uri in the Authorization Request returns an error or + contains invalid data. + + http://openid.net/specs/openid-connect-core-1_0.html#AuthError + """ + error = 'invalid_request_uri' + + +class InvalidRequestObjectError(OAuth2Error): + """The request parameter contains an invalid Request Object.""" + error = 'invalid_request_object' + + +class RequestNotSupportedError(OAuth2Error): + """The OP does not support use of the request parameter.""" + error = 'request_not_supported' + + +class RequestURINotSupportedError(OAuth2Error): + """The OP does not support use of the request_uri parameter.""" + error = 'request_uri_not_supported' + + +class RegistrationNotSupportedError(OAuth2Error): + """The OP does not support use of the registration parameter.""" + error = 'registration_not_supported' diff --git a/.venv/Lib/site-packages/authlib/oidc/core/grants/__init__.py b/.venv/Lib/site-packages/authlib/oidc/core/grants/__init__.py new file mode 100644 index 00000000..8b4b0025 --- /dev/null +++ b/.venv/Lib/site-packages/authlib/oidc/core/grants/__init__.py @@ -0,0 +1,10 @@ +from .code import OpenIDToken, OpenIDCode +from .implicit import OpenIDImplicitGrant +from .hybrid import OpenIDHybridGrant + +__all__ = [ + 'OpenIDToken', + 'OpenIDCode', + 'OpenIDImplicitGrant', + 'OpenIDHybridGrant', +] diff --git a/.venv/Lib/site-packages/authlib/oidc/core/grants/__pycache__/__init__.cpython-311.pyc b/.venv/Lib/site-packages/authlib/oidc/core/grants/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 00000000..86b26167 Binary files /dev/null and b/.venv/Lib/site-packages/authlib/oidc/core/grants/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/authlib/oidc/core/grants/__pycache__/code.cpython-311.pyc b/.venv/Lib/site-packages/authlib/oidc/core/grants/__pycache__/code.cpython-311.pyc new file mode 100644 index 00000000..4afd65f5 Binary files /dev/null and b/.venv/Lib/site-packages/authlib/oidc/core/grants/__pycache__/code.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/authlib/oidc/core/grants/__pycache__/hybrid.cpython-311.pyc b/.venv/Lib/site-packages/authlib/oidc/core/grants/__pycache__/hybrid.cpython-311.pyc new file mode 100644 index 00000000..47fb4bd5 Binary files /dev/null and b/.venv/Lib/site-packages/authlib/oidc/core/grants/__pycache__/hybrid.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/authlib/oidc/core/grants/__pycache__/implicit.cpython-311.pyc b/.venv/Lib/site-packages/authlib/oidc/core/grants/__pycache__/implicit.cpython-311.pyc new file mode 100644 index 00000000..ba45e950 Binary files /dev/null and b/.venv/Lib/site-packages/authlib/oidc/core/grants/__pycache__/implicit.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/authlib/oidc/core/grants/__pycache__/util.cpython-311.pyc b/.venv/Lib/site-packages/authlib/oidc/core/grants/__pycache__/util.cpython-311.pyc new file mode 100644 index 00000000..2ebb6425 Binary files /dev/null and b/.venv/Lib/site-packages/authlib/oidc/core/grants/__pycache__/util.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/authlib/oidc/core/grants/code.py b/.venv/Lib/site-packages/authlib/oidc/core/grants/code.py new file mode 100644 index 00000000..9ac3bfbb --- /dev/null +++ b/.venv/Lib/site-packages/authlib/oidc/core/grants/code.py @@ -0,0 +1,142 @@ +""" + authlib.oidc.core.grants.code + ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + Implementation of Authentication using the Authorization Code Flow + per `Section 3.1`_. + + .. _`Section 3.1`: http://openid.net/specs/openid-connect-core-1_0.html#CodeFlowAuth +""" + +import logging +from authlib.oauth2.rfc6749 import OAuth2Request +from .util import ( + is_openid_scope, + validate_nonce, + validate_request_prompt, + generate_id_token, +) + +log = logging.getLogger(__name__) + + +class OpenIDToken: + def get_jwt_config(self, grant): # pragma: no cover + """Get the JWT configuration for OpenIDCode extension. The JWT + configuration will be used to generate ``id_token``. Developers + MUST implement this method in subclass, e.g.:: + + def get_jwt_config(self, grant): + return { + 'key': read_private_key_file(key_path), + 'alg': 'RS256', + 'iss': 'issuer-identity', + 'exp': 3600 + } + + :param grant: AuthorizationCodeGrant instance + :return: dict + """ + raise NotImplementedError() + + def generate_user_info(self, user, scope): + """Provide user information for the given scope. Developers + MUST implement this method in subclass, e.g.:: + + from authlib.oidc.core import UserInfo + + def generate_user_info(self, user, scope): + user_info = UserInfo(sub=user.id, name=user.name) + if 'email' in scope: + user_info['email'] = user.email + return user_info + + :param user: user instance + :param scope: scope of the token + :return: ``authlib.oidc.core.UserInfo`` instance + """ + raise NotImplementedError() + + def get_audiences(self, request): + """Parse `aud` value for id_token, default value is client id. Developers + MAY rewrite this method to provide a customized audience value. + """ + client = request.client + return [client.get_client_id()] + + def process_token(self, grant, token): + scope = token.get('scope') + if not scope or not is_openid_scope(scope): + # standard authorization code flow + return token + + request: OAuth2Request = grant.request + authorization_code = request.authorization_code + + config = self.get_jwt_config(grant) + config['aud'] = self.get_audiences(request) + + if authorization_code: + config['nonce'] = authorization_code.get_nonce() + config['auth_time'] = authorization_code.get_auth_time() + + user_info = self.generate_user_info(request.user, token['scope']) + id_token = generate_id_token(token, user_info, **config) + token['id_token'] = id_token + return token + + def __call__(self, grant): + grant.register_hook('process_token', self.process_token) + + +class OpenIDCode(OpenIDToken): + """An extension from OpenID Connect for "grant_type=code" request. Developers + MUST implement the missing methods:: + + class MyOpenIDCode(OpenIDCode): + def get_jwt_config(self, grant): + return {...} + + def exists_nonce(self, nonce, request): + return check_if_nonce_in_cache(request.client_id, nonce) + + def generate_user_info(self, user, scope): + return {...} + + The register this extension with AuthorizationCodeGrant:: + + authorization_server.register_grant(AuthorizationCodeGrant, extensions=[MyOpenIDCode()]) + """ + def __init__(self, require_nonce=False): + self.require_nonce = require_nonce + + def exists_nonce(self, nonce, request): + """Check if the given nonce is existing in your database. Developers + MUST implement this method in subclass, e.g.:: + + def exists_nonce(self, nonce, request): + exists = AuthorizationCode.query.filter_by( + client_id=request.client_id, nonce=nonce + ).first() + return bool(exists) + + :param nonce: A string of "nonce" parameter in request + :param request: OAuth2Request instance + :return: Boolean + """ + raise NotImplementedError() + + def validate_openid_authorization_request(self, grant): + validate_nonce(grant.request, self.exists_nonce, self.require_nonce) + + def __call__(self, grant): + grant.register_hook('process_token', self.process_token) + if is_openid_scope(grant.request.scope): + grant.register_hook( + 'after_validate_authorization_request', + self.validate_openid_authorization_request + ) + grant.register_hook( + 'after_validate_consent_request', + validate_request_prompt + ) diff --git a/.venv/Lib/site-packages/authlib/oidc/core/grants/hybrid.py b/.venv/Lib/site-packages/authlib/oidc/core/grants/hybrid.py new file mode 100644 index 00000000..384c8673 --- /dev/null +++ b/.venv/Lib/site-packages/authlib/oidc/core/grants/hybrid.py @@ -0,0 +1,90 @@ +import logging +from authlib.common.security import generate_token +from authlib.oauth2.rfc6749 import InvalidScopeError +from authlib.oauth2.rfc6749.grants.authorization_code import ( + validate_code_authorization_request +) +from .implicit import OpenIDImplicitGrant +from .util import is_openid_scope, validate_nonce + +log = logging.getLogger(__name__) + + +class OpenIDHybridGrant(OpenIDImplicitGrant): + #: Generated "code" length + AUTHORIZATION_CODE_LENGTH = 48 + + RESPONSE_TYPES = {'code id_token', 'code token', 'code id_token token'} + GRANT_TYPE = 'code' + DEFAULT_RESPONSE_MODE = 'fragment' + + def generate_authorization_code(self): + """"The method to generate "code" value for authorization code data. + Developers may rewrite this method, or customize the code length with:: + + class MyAuthorizationCodeGrant(AuthorizationCodeGrant): + AUTHORIZATION_CODE_LENGTH = 32 # default is 48 + """ + return generate_token(self.AUTHORIZATION_CODE_LENGTH) + + def save_authorization_code(self, code, request): + """Save authorization_code for later use. Developers MUST implement + it in subclass. Here is an example:: + + def save_authorization_code(self, code, request): + client = request.client + auth_code = AuthorizationCode( + code=code, + client_id=client.client_id, + redirect_uri=request.redirect_uri, + scope=request.scope, + nonce=request.data.get('nonce'), + user_id=request.user.id, + ) + auth_code.save() + """ + raise NotImplementedError() + + def validate_authorization_request(self): + if not is_openid_scope(self.request.scope): + raise InvalidScopeError( + 'Missing "openid" scope', + redirect_uri=self.request.redirect_uri, + redirect_fragment=True, + ) + self.register_hook( + 'after_validate_authorization_request', + lambda grant: validate_nonce( + grant.request, grant.exists_nonce, required=True) + ) + return validate_code_authorization_request(self) + + def create_granted_params(self, grant_user): + self.request.user = grant_user + client = self.request.client + code = self.generate_authorization_code() + self.save_authorization_code(code, self.request) + params = [('code', code)] + token = self.generate_token( + grant_type='implicit', + user=grant_user, + scope=self.request.scope, + include_refresh_token=False + ) + + response_types = self.request.response_type.split() + if 'token' in response_types: + log.debug('Grant token %r to %r', token, client) + self.server.save_token(token, self.request) + if 'id_token' in response_types: + token = self.process_implicit_token(token, code) + else: + # response_type is "code id_token" + token = { + 'expires_in': token['expires_in'], + 'scope': token['scope'] + } + token = self.process_implicit_token(token, code) + + params.extend([(k, token[k]) for k in token]) + return params diff --git a/.venv/Lib/site-packages/authlib/oidc/core/grants/implicit.py b/.venv/Lib/site-packages/authlib/oidc/core/grants/implicit.py new file mode 100644 index 00000000..15bc1fac --- /dev/null +++ b/.venv/Lib/site-packages/authlib/oidc/core/grants/implicit.py @@ -0,0 +1,150 @@ +import logging +from authlib.oauth2.rfc6749 import ( + OAuth2Error, + InvalidScopeError, + AccessDeniedError, + ImplicitGrant, +) +from .util import ( + is_openid_scope, + validate_nonce, + validate_request_prompt, + create_response_mode_response, + generate_id_token, +) + +log = logging.getLogger(__name__) + + +class OpenIDImplicitGrant(ImplicitGrant): + RESPONSE_TYPES = {'id_token token', 'id_token'} + DEFAULT_RESPONSE_MODE = 'fragment' + + def exists_nonce(self, nonce, request): + """Check if the given nonce is existing in your database. Developers + should implement this method in subclass, e.g.:: + + def exists_nonce(self, nonce, request): + exists = AuthorizationCode.query.filter_by( + client_id=request.client_id, nonce=nonce + ).first() + return bool(exists) + + :param nonce: A string of "nonce" parameter in request + :param request: OAuth2Request instance + :return: Boolean + """ + raise NotImplementedError() + + def get_jwt_config(self): + """Get the JWT configuration for OpenIDImplicitGrant. The JWT + configuration will be used to generate ``id_token``. Developers + MUST implement this method in subclass, e.g.:: + + def get_jwt_config(self): + return { + 'key': read_private_key_file(key_path), + 'alg': 'RS256', + 'iss': 'issuer-identity', + 'exp': 3600 + } + + :return: dict + """ + raise NotImplementedError() + + def generate_user_info(self, user, scope): + """Provide user information for the given scope. Developers + MUST implement this method in subclass, e.g.:: + + from authlib.oidc.core import UserInfo + + def generate_user_info(self, user, scope): + user_info = UserInfo(sub=user.id, name=user.name) + if 'email' in scope: + user_info['email'] = user.email + return user_info + + :param user: user instance + :param scope: scope of the token + :return: ``authlib.oidc.core.UserInfo`` instance + """ + raise NotImplementedError() + + def get_audiences(self, request): + """Parse `aud` value for id_token, default value is client id. Developers + MAY rewrite this method to provide a customized audience value. + """ + client = request.client + return [client.get_client_id()] + + def validate_authorization_request(self): + if not is_openid_scope(self.request.scope): + raise InvalidScopeError( + 'Missing "openid" scope', + redirect_uri=self.request.redirect_uri, + redirect_fragment=True, + ) + redirect_uri = super().validate_authorization_request() + try: + validate_nonce(self.request, self.exists_nonce, required=True) + except OAuth2Error as error: + error.redirect_uri = redirect_uri + error.redirect_fragment = True + raise error + return redirect_uri + + def validate_consent_request(self): + redirect_uri = self.validate_authorization_request() + validate_request_prompt(self, redirect_uri, redirect_fragment=True) + + def create_authorization_response(self, redirect_uri, grant_user): + state = self.request.state + if grant_user: + params = self.create_granted_params(grant_user) + if state: + params.append(('state', state)) + else: + error = AccessDeniedError(state=state) + params = error.get_body() + + # http://openid.net/specs/oauth-v2-multiple-response-types-1_0.html#ResponseModes + response_mode = self.request.data.get('response_mode', self.DEFAULT_RESPONSE_MODE) + return create_response_mode_response( + redirect_uri=redirect_uri, + params=params, + response_mode=response_mode, + ) + + def create_granted_params(self, grant_user): + self.request.user = grant_user + client = self.request.client + token = self.generate_token( + user=grant_user, + scope=self.request.scope, + include_refresh_token=False + ) + if self.request.response_type == 'id_token': + token = { + 'expires_in': token['expires_in'], + 'scope': token['scope'], + } + token = self.process_implicit_token(token) + else: + log.debug('Grant token %r to %r', token, client) + self.server.save_token(token, self.request) + token = self.process_implicit_token(token) + params = [(k, token[k]) for k in token] + return params + + def process_implicit_token(self, token, code=None): + config = self.get_jwt_config() + config['aud'] = self.get_audiences(self.request) + config['nonce'] = self.request.data.get('nonce') + if code is not None: + config['code'] = code + + user_info = self.generate_user_info(self.request.user, token['scope']) + id_token = generate_id_token(token, user_info, **config) + token['id_token'] = id_token + return token diff --git a/.venv/Lib/site-packages/authlib/oidc/core/grants/util.py b/.venv/Lib/site-packages/authlib/oidc/core/grants/util.py new file mode 100644 index 00000000..3b57dbe8 --- /dev/null +++ b/.venv/Lib/site-packages/authlib/oidc/core/grants/util.py @@ -0,0 +1,131 @@ +import time +from authlib.oauth2.rfc6749 import InvalidRequestError +from authlib.oauth2.rfc6749 import scope_to_list +from authlib.jose import jwt +from authlib.common.encoding import to_native +from authlib.common.urls import add_params_to_uri, quote_url +from ..util import create_half_hash +from ..errors import ( + LoginRequiredError, + AccountSelectionRequiredError, + ConsentRequiredError, +) + + +def is_openid_scope(scope): + scopes = scope_to_list(scope) + return scopes and 'openid' in scopes + + +def validate_request_prompt(grant, redirect_uri, redirect_fragment=False): + prompt = grant.request.data.get('prompt') + end_user = grant.request.user + if not prompt: + if not end_user: + grant.prompt = 'login' + return grant + + if prompt == 'none' and not end_user: + raise LoginRequiredError( + redirect_uri=redirect_uri, + redirect_fragment=redirect_fragment) + + prompts = prompt.split() + if 'none' in prompts and len(prompts) > 1: + # If this parameter contains none with any other value, + # an error is returned + raise InvalidRequestError( + 'Invalid "prompt" parameter.', + redirect_uri=redirect_uri, + redirect_fragment=redirect_fragment) + + prompt = _guess_prompt_value( + end_user, prompts, redirect_uri, redirect_fragment=redirect_fragment) + if prompt: + grant.prompt = prompt + return grant + + +def validate_nonce(request, exists_nonce, required=False): + nonce = request.data.get('nonce') + if not nonce: + if required: + raise InvalidRequestError('Missing "nonce" in request.') + return True + + if exists_nonce(nonce, request): + raise InvalidRequestError('Replay attack') + + +def generate_id_token( + token, user_info, key, iss, aud, alg='RS256', exp=3600, + nonce=None, auth_time=None, code=None): + + now = int(time.time()) + if auth_time is None: + auth_time = now + + payload = { + 'iss': iss, + 'aud': aud, + 'iat': now, + 'exp': now + exp, + 'auth_time': auth_time, + } + if nonce: + payload['nonce'] = nonce + + if code: + payload['c_hash'] = to_native(create_half_hash(code, alg)) + + access_token = token.get('access_token') + if access_token: + payload['at_hash'] = to_native(create_half_hash(access_token, alg)) + + payload.update(user_info) + return to_native(jwt.encode({'alg': alg}, payload, key)) + + +def create_response_mode_response(redirect_uri, params, response_mode): + if response_mode == 'form_post': + tpl = ( + 'Redirecting' + '' + '
{}
' + ) + inputs = ''.join([ + ''.format( + quote_url(k), quote_url(v)) + for k, v in params + ]) + body = tpl.format(quote_url(redirect_uri), inputs) + return 200, body, [('Content-Type', 'text/html; charset=utf-8')] + + if response_mode == 'query': + uri = add_params_to_uri(redirect_uri, params, fragment=False) + elif response_mode == 'fragment': + uri = add_params_to_uri(redirect_uri, params, fragment=True) + else: + raise InvalidRequestError('Invalid "response_mode" value') + + return 302, '', [('Location', uri)] + + +def _guess_prompt_value(end_user, prompts, redirect_uri, redirect_fragment): + # http://openid.net/specs/openid-connect-core-1_0.html#AuthRequest + + if not end_user and 'login' in prompts: + return 'login' + + if 'consent' in prompts: + if not end_user: + raise ConsentRequiredError( + redirect_uri=redirect_uri, + redirect_fragment=redirect_fragment) + return 'consent' + elif 'select_account' in prompts: + if not end_user: + raise AccountSelectionRequiredError( + redirect_uri=redirect_uri, + redirect_fragment=redirect_fragment) + return 'select_account' diff --git a/.venv/Lib/site-packages/authlib/oidc/core/models.py b/.venv/Lib/site-packages/authlib/oidc/core/models.py new file mode 100644 index 00000000..5f414050 --- /dev/null +++ b/.venv/Lib/site-packages/authlib/oidc/core/models.py @@ -0,0 +1,13 @@ +from authlib.oauth2.rfc6749 import ( + AuthorizationCodeMixin as _AuthorizationCodeMixin +) + + +class AuthorizationCodeMixin(_AuthorizationCodeMixin): + def get_nonce(self): + """Get "nonce" value of the authorization code object.""" + raise NotImplementedError() + + def get_auth_time(self): + """Get "auth_time" value of the authorization code object.""" + raise NotImplementedError() diff --git a/.venv/Lib/site-packages/authlib/oidc/core/util.py b/.venv/Lib/site-packages/authlib/oidc/core/util.py new file mode 100644 index 00000000..6df005d2 --- /dev/null +++ b/.venv/Lib/site-packages/authlib/oidc/core/util.py @@ -0,0 +1,12 @@ +import hashlib +from authlib.common.encoding import to_bytes, urlsafe_b64encode + + +def create_half_hash(s, alg): + hash_type = f'sha{alg[2:]}' + hash_alg = getattr(hashlib, hash_type, None) + if not hash_alg: + return None + data_digest = hash_alg(to_bytes(s)).digest() + slice_index = int(len(data_digest) / 2) + return urlsafe_b64encode(data_digest[:slice_index]) diff --git a/.venv/Lib/site-packages/authlib/oidc/discovery/__init__.py b/.venv/Lib/site-packages/authlib/oidc/discovery/__init__.py new file mode 100644 index 00000000..1e76401b --- /dev/null +++ b/.venv/Lib/site-packages/authlib/oidc/discovery/__init__.py @@ -0,0 +1,13 @@ +""" + authlib.oidc.discover + ~~~~~~~~~~~~~~~~~~~~~ + + OpenID Connect Discovery 1.0 Implementation. + + https://openid.net/specs/openid-connect-discovery-1_0.html +""" + +from .models import OpenIDProviderMetadata +from .well_known import get_well_known_url + +__all__ = ['OpenIDProviderMetadata', 'get_well_known_url'] diff --git a/.venv/Lib/site-packages/authlib/oidc/discovery/__pycache__/__init__.cpython-311.pyc b/.venv/Lib/site-packages/authlib/oidc/discovery/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 00000000..6be2443b Binary files /dev/null and b/.venv/Lib/site-packages/authlib/oidc/discovery/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/authlib/oidc/discovery/__pycache__/models.cpython-311.pyc b/.venv/Lib/site-packages/authlib/oidc/discovery/__pycache__/models.cpython-311.pyc new file mode 100644 index 00000000..21ae82cc Binary files /dev/null and b/.venv/Lib/site-packages/authlib/oidc/discovery/__pycache__/models.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/authlib/oidc/discovery/__pycache__/well_known.cpython-311.pyc b/.venv/Lib/site-packages/authlib/oidc/discovery/__pycache__/well_known.cpython-311.pyc new file mode 100644 index 00000000..fcb7bbdc Binary files /dev/null and b/.venv/Lib/site-packages/authlib/oidc/discovery/__pycache__/well_known.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/authlib/oidc/discovery/models.py b/.venv/Lib/site-packages/authlib/oidc/discovery/models.py new file mode 100644 index 00000000..d9329efd --- /dev/null +++ b/.venv/Lib/site-packages/authlib/oidc/discovery/models.py @@ -0,0 +1,283 @@ +from authlib.oauth2.rfc8414 import AuthorizationServerMetadata +from authlib.oauth2.rfc8414.models import validate_array_value + + +class OpenIDProviderMetadata(AuthorizationServerMetadata): + REGISTRY_KEYS = [ + 'issuer', 'authorization_endpoint', 'token_endpoint', + 'jwks_uri', 'registration_endpoint', 'scopes_supported', + 'response_types_supported', 'response_modes_supported', + 'grant_types_supported', + 'token_endpoint_auth_methods_supported', + 'token_endpoint_auth_signing_alg_values_supported', + 'service_documentation', 'ui_locales_supported', + 'op_policy_uri', 'op_tos_uri', + + # added by OpenID + 'acr_values_supported', 'subject_types_supported', + 'id_token_signing_alg_values_supported', + 'id_token_encryption_alg_values_supported', + 'id_token_encryption_enc_values_supported', + 'userinfo_signing_alg_values_supported', + 'userinfo_encryption_alg_values_supported', + 'userinfo_encryption_enc_values_supported', + 'request_object_signing_alg_values_supported', + 'request_object_encryption_alg_values_supported', + 'request_object_encryption_enc_values_supported', + 'display_values_supported', + 'claim_types_supported', + 'claims_supported', + 'claims_locales_supported', + 'claims_parameter_supported', + 'request_parameter_supported', + 'request_uri_parameter_supported', + 'require_request_uri_registration', + + # not defined by OpenID + # 'revocation_endpoint', + # 'revocation_endpoint_auth_methods_supported', + # 'revocation_endpoint_auth_signing_alg_values_supported', + # 'introspection_endpoint', + # 'introspection_endpoint_auth_methods_supported', + # 'introspection_endpoint_auth_signing_alg_values_supported', + # 'code_challenge_methods_supported', + ] + + def validate_jwks_uri(self): + # REQUIRED in OpenID Connect + jwks_uri = self.get('jwks_uri') + if jwks_uri is None: + raise ValueError('"jwks_uri" is required') + return super().validate_jwks_uri() + + def validate_acr_values_supported(self): + """OPTIONAL. JSON array containing a list of the Authentication + Context Class References that this OP supports. + """ + validate_array_value(self, 'acr_values_supported') + + def validate_subject_types_supported(self): + """REQUIRED. JSON array containing a list of the Subject Identifier + types that this OP supports. Valid types include pairwise and public. + """ + # 1. REQUIRED + values = self.get('subject_types_supported') + if values is None: + raise ValueError('"subject_types_supported" is required') + + # 2. JSON array + if not isinstance(values, list): + raise ValueError('"subject_types_supported" MUST be JSON array') + + # 3. Valid types include pairwise and public + valid_types = {'pairwise', 'public'} + if not valid_types.issuperset(set(values)): + raise ValueError( + '"subject_types_supported" contains invalid values') + + def validate_id_token_signing_alg_values_supported(self): + """REQUIRED. JSON array containing a list of the JWS signing + algorithms (alg values) supported by the OP for the ID Token to + encode the Claims in a JWT [JWT]. The algorithm RS256 MUST be + included. The value none MAY be supported, but MUST NOT be used + unless the Response Type used returns no ID Token from the + Authorization Endpoint (such as when using the Authorization + Code Flow). + """ + # 1. REQUIRED + values = self.get('id_token_signing_alg_values_supported') + if values is None: + raise ValueError('"id_token_signing_alg_values_supported" is required') + + # 2. JSON array + if not isinstance(values, list): + raise ValueError('"id_token_signing_alg_values_supported" MUST be JSON array') + + # 3. The algorithm RS256 MUST be included + if 'RS256' not in values: + raise ValueError( + '"RS256" MUST be included in "id_token_signing_alg_values_supported"') + + def validate_id_token_encryption_alg_values_supported(self): + """OPTIONAL. JSON array containing a list of the JWE encryption + algorithms (alg values) supported by the OP for the ID Token to + encode the Claims in a JWT. + """ + validate_array_value(self, 'id_token_encryption_alg_values_supported') + + def validate_id_token_encryption_enc_values_supported(self): + """OPTIONAL. JSON array containing a list of the JWE encryption + algorithms (enc values) supported by the OP for the ID Token to + encode the Claims in a JWT. + """ + validate_array_value(self, 'id_token_encryption_enc_values_supported') + + def validate_userinfo_signing_alg_values_supported(self): + """OPTIONAL. JSON array containing a list of the JWS signing + algorithms (alg values) [JWA] supported by the UserInfo Endpoint + to encode the Claims in a JWT. The value none MAY be included. + """ + validate_array_value(self, 'userinfo_signing_alg_values_supported') + + def validate_userinfo_encryption_alg_values_supported(self): + """OPTIONAL. JSON array containing a list of the JWE encryption + algorithms (alg values) [JWA] supported by the UserInfo Endpoint + to encode the Claims in a JWT. + """ + validate_array_value(self, 'userinfo_encryption_alg_values_supported') + + def validate_userinfo_encryption_enc_values_supported(self): + """OPTIONAL. JSON array containing a list of the JWE encryption + algorithms (enc values) [JWA] supported by the UserInfo Endpoint + to encode the Claims in a JWT. + """ + validate_array_value(self, 'userinfo_encryption_enc_values_supported') + + def validate_request_object_signing_alg_values_supported(self): + """OPTIONAL. JSON array containing a list of the JWS signing + algorithms (alg values) supported by the OP for Request Objects, + which are described in Section 6.1 of OpenID Connect Core 1.0. + These algorithms are used both when the Request Object is passed + by value (using the request parameter) and when it is passed by + reference (using the request_uri parameter). Servers SHOULD support + none and RS256. + """ + values = self.get('request_object_signing_alg_values_supported') + if not values: + return + + if not isinstance(values, list): + raise ValueError('"request_object_signing_alg_values_supported" MUST be JSON array') + + # Servers SHOULD support none and RS256 + if 'none' not in values or 'RS256' not in values: + raise ValueError( + '"request_object_signing_alg_values_supported" ' + 'SHOULD support none and RS256') + + def validate_request_object_encryption_alg_values_supported(self): + """OPTIONAL. JSON array containing a list of the JWE encryption + algorithms (alg values) supported by the OP for Request Objects. + These algorithms are used both when the Request Object is passed + by value and when it is passed by reference. + """ + validate_array_value(self, 'request_object_encryption_alg_values_supported') + + def validate_request_object_encryption_enc_values_supported(self): + """OPTIONAL. JSON array containing a list of the JWE encryption + algorithms (enc values) supported by the OP for Request Objects. + These algorithms are used both when the Request Object is passed + by value and when it is passed by reference. + """ + validate_array_value(self, 'request_object_encryption_enc_values_supported') + + def validate_display_values_supported(self): + """OPTIONAL. JSON array containing a list of the display parameter + values that the OpenID Provider supports. These values are described + in Section 3.1.2.1 of OpenID Connect Core 1.0. + """ + values = self.get('display_values_supported') + if not values: + return + + if not isinstance(values, list): + raise ValueError('"display_values_supported" MUST be JSON array') + + valid_values = {'page', 'popup', 'touch', 'wap'} + if not valid_values.issuperset(set(values)): + raise ValueError('"display_values_supported" contains invalid values') + + def validate_claim_types_supported(self): + """OPTIONAL. JSON array containing a list of the Claim Types that + the OpenID Provider supports. These Claim Types are described in + Section 5.6 of OpenID Connect Core 1.0. Values defined by this + specification are normal, aggregated, and distributed. If omitted, + the implementation supports only normal Claims. + """ + values = self.get('claim_types_supported') + if not values: + return + + if not isinstance(values, list): + raise ValueError('"claim_types_supported" MUST be JSON array') + + valid_values = {'normal', 'aggregated', 'distributed'} + if not valid_values.issuperset(set(values)): + raise ValueError('"claim_types_supported" contains invalid values') + + def validate_claims_supported(self): + """RECOMMENDED. JSON array containing a list of the Claim Names + of the Claims that the OpenID Provider MAY be able to supply values + for. Note that for privacy or other reasons, this might not be an + exhaustive list. + """ + validate_array_value(self, 'claims_supported') + + def validate_claims_locales_supported(self): + """OPTIONAL. Languages and scripts supported for values in Claims + being returned, represented as a JSON array of BCP47 [RFC5646] + language tag values. Not all languages and scripts are necessarily + supported for all Claim values. + """ + validate_array_value(self, 'claims_locales_supported') + + def validate_claims_parameter_supported(self): + """OPTIONAL. Boolean value specifying whether the OP supports use of + the claims parameter, with true indicating support. If omitted, the + default value is false. + """ + _validate_boolean_value(self, 'claims_parameter_supported') + + def validate_request_parameter_supported(self): + """OPTIONAL. Boolean value specifying whether the OP supports use of + the request parameter, with true indicating support. If omitted, the + default value is false. + """ + _validate_boolean_value(self, 'request_parameter_supported') + + def validate_request_uri_parameter_supported(self): + """OPTIONAL. Boolean value specifying whether the OP supports use of + the request_uri parameter, with true indicating support. If omitted, + the default value is true. + """ + _validate_boolean_value(self, 'request_uri_parameter_supported') + + def validate_require_request_uri_registration(self): + """OPTIONAL. Boolean value specifying whether the OP requires any + request_uri values used to be pre-registered using the request_uris + registration parameter. Pre-registration is REQUIRED when the value + is true. If omitted, the default value is false. + """ + _validate_boolean_value(self, 'require_request_uri_registration') + + @property + def claim_types_supported(self): + # If omitted, the implementation supports only normal Claims + return self.get('claim_types_supported', ['normal']) + + @property + def claims_parameter_supported(self): + # If omitted, the default value is false. + return self.get('claims_parameter_supported', False) + + @property + def request_parameter_supported(self): + # If omitted, the default value is false. + return self.get('request_parameter_supported', False) + + @property + def request_uri_parameter_supported(self): + # If omitted, the default value is true. + return self.get('request_uri_parameter_supported', True) + + @property + def require_request_uri_registration(self): + # If omitted, the default value is false. + return self.get('require_request_uri_registration', False) + + +def _validate_boolean_value(metadata, key): + if key not in metadata: + return + if metadata[key] not in (True, False): + raise ValueError(f'"{key}" MUST be boolean') diff --git a/.venv/Lib/site-packages/authlib/oidc/discovery/well_known.py b/.venv/Lib/site-packages/authlib/oidc/discovery/well_known.py new file mode 100644 index 00000000..e3087a14 --- /dev/null +++ b/.venv/Lib/site-packages/authlib/oidc/discovery/well_known.py @@ -0,0 +1,17 @@ +from authlib.common.urls import urlparse + + +def get_well_known_url(issuer, external=False): + """Get well-known URI with issuer via Section 4.1. + + :param issuer: URL of the issuer + :param external: return full external url or not + :return: URL + """ + # https://openid.net/specs/openid-connect-discovery-1_0.html#ProviderConfigurationRequest + if external: + return issuer.rstrip('/') + '/.well-known/openid-configuration' + + parsed = urlparse.urlparse(issuer) + path = parsed.path + return path.rstrip('/') + '/.well-known/openid-configuration' diff --git a/.venv/Lib/site-packages/cffi-1.16.0.dist-info/INSTALLER b/.venv/Lib/site-packages/cffi-1.16.0.dist-info/INSTALLER new file mode 100644 index 00000000..a1b589e3 --- /dev/null +++ b/.venv/Lib/site-packages/cffi-1.16.0.dist-info/INSTALLER @@ -0,0 +1 @@ +pip diff --git a/.venv/Lib/site-packages/cffi-1.16.0.dist-info/LICENSE b/.venv/Lib/site-packages/cffi-1.16.0.dist-info/LICENSE new file mode 100644 index 00000000..29225eee --- /dev/null +++ b/.venv/Lib/site-packages/cffi-1.16.0.dist-info/LICENSE @@ -0,0 +1,26 @@ + +Except when otherwise stated (look for LICENSE files in directories or +information at the beginning of each file) all software and +documentation is licensed as follows: + + The MIT License + + Permission is hereby granted, free of charge, to any person + obtaining a copy of this software and associated documentation + files (the "Software"), to deal in the Software without + restriction, including without limitation the rights to use, + copy, modify, merge, publish, distribute, sublicense, and/or + sell copies of the Software, and to permit persons to whom the + Software is furnished to do so, subject to the following conditions: + + The above copyright notice and this permission notice shall be included + in all copies or substantial portions of the Software. + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL + THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER + DEALINGS IN THE SOFTWARE. + diff --git a/.venv/Lib/site-packages/cffi-1.16.0.dist-info/METADATA b/.venv/Lib/site-packages/cffi-1.16.0.dist-info/METADATA new file mode 100644 index 00000000..f582bfbb --- /dev/null +++ b/.venv/Lib/site-packages/cffi-1.16.0.dist-info/METADATA @@ -0,0 +1,39 @@ +Metadata-Version: 2.1 +Name: cffi +Version: 1.16.0 +Summary: Foreign Function Interface for Python calling C code. +Home-page: http://cffi.readthedocs.org +Author: Armin Rigo, Maciej Fijalkowski +Author-email: python-cffi@googlegroups.com +License: MIT +Project-URL: Documentation, http://cffi.readthedocs.org/ +Project-URL: Source Code, https://github.com/python-cffi/cffi +Project-URL: Issue Tracker, https://github.com/python-cffi/cffi/issues +Project-URL: Changelog, https://cffi.readthedocs.io/en/latest/whatsnew.html +Project-URL: Downloads, https://github.com/python-cffi/cffi/releases +Project-URL: Contact, https://groups.google.com/forum/#!forum/python-cffi +Classifier: Programming Language :: Python +Classifier: Programming Language :: Python :: 3 +Classifier: Programming Language :: Python :: 3.8 +Classifier: Programming Language :: Python :: 3.9 +Classifier: Programming Language :: Python :: 3.10 +Classifier: Programming Language :: Python :: 3.11 +Classifier: Programming Language :: Python :: 3.12 +Classifier: Programming Language :: Python :: Implementation :: CPython +Classifier: Programming Language :: Python :: Implementation :: PyPy +Classifier: License :: OSI Approved :: MIT License +Requires-Python: >=3.8 +License-File: LICENSE +Requires-Dist: pycparser + + +CFFI +==== + +Foreign Function Interface for Python calling C code. +Please see the `Documentation `_. + +Contact +------- + +`Mailing list `_ diff --git a/.venv/Lib/site-packages/cffi-1.16.0.dist-info/RECORD b/.venv/Lib/site-packages/cffi-1.16.0.dist-info/RECORD new file mode 100644 index 00000000..053bb7f5 --- /dev/null +++ b/.venv/Lib/site-packages/cffi-1.16.0.dist-info/RECORD @@ -0,0 +1,48 @@ +_cffi_backend.cp311-win_amd64.pyd,sha256=WXZ7CRiFm-3fKKfWalBDFBH_2UDDKz6DR-bZOLYPrN8,181248 +cffi-1.16.0.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4 +cffi-1.16.0.dist-info/LICENSE,sha256=esEZUOct9bRcUXFqeyLnuzSzJNZ_Bl4pOBUt1HLEgV8,1320 +cffi-1.16.0.dist-info/METADATA,sha256=JSVpnYGQLm_ch1eupKWpcINOEYPZdnejqtsSbz-LeJg,1519 +cffi-1.16.0.dist-info/RECORD,, +cffi-1.16.0.dist-info/WHEEL,sha256=badvNS-y9fEq0X-qzdZYvql_JFjI7Xfw-wR8FsjoK0I,102 +cffi-1.16.0.dist-info/entry_points.txt,sha256=y6jTxnyeuLnL-XJcDv8uML3n6wyYiGRg8MTp_QGJ9Ho,75 +cffi-1.16.0.dist-info/top_level.txt,sha256=rE7WR3rZfNKxWI9-jn6hsHCAl7MDkB-FmuQbxWjFehQ,19 +cffi/__init__.py,sha256=3woJiTGr2RG4_jOJxsYUhkS6p6N_oz8OBfXTiegoUwQ,527 +cffi/__pycache__/__init__.cpython-311.pyc,, +cffi/__pycache__/_imp_emulation.cpython-311.pyc,, +cffi/__pycache__/_shimmed_dist_utils.cpython-311.pyc,, +cffi/__pycache__/api.cpython-311.pyc,, +cffi/__pycache__/backend_ctypes.cpython-311.pyc,, +cffi/__pycache__/cffi_opcode.cpython-311.pyc,, +cffi/__pycache__/commontypes.cpython-311.pyc,, +cffi/__pycache__/cparser.cpython-311.pyc,, +cffi/__pycache__/error.cpython-311.pyc,, +cffi/__pycache__/ffiplatform.cpython-311.pyc,, +cffi/__pycache__/lock.cpython-311.pyc,, +cffi/__pycache__/model.cpython-311.pyc,, +cffi/__pycache__/pkgconfig.cpython-311.pyc,, +cffi/__pycache__/recompiler.cpython-311.pyc,, +cffi/__pycache__/setuptools_ext.cpython-311.pyc,, +cffi/__pycache__/vengine_cpy.cpython-311.pyc,, +cffi/__pycache__/vengine_gen.cpython-311.pyc,, +cffi/__pycache__/verifier.cpython-311.pyc,, +cffi/_cffi_errors.h,sha256=G0bGOb-6SNIO0UY8KEN3cM40Yd1JuR5bETQ8Ni5PxWY,4057 +cffi/_cffi_include.h,sha256=H7cgdZR-POwmUFrIup4jOGzmje8YoQHhN99gVFg7w08,15185 +cffi/_embedding.h,sha256=hWxZXkHEqIGELOPKKYCspuZS0su3M6Tc0sL0M2g7vzI,19337 +cffi/_imp_emulation.py,sha256=pGPNO0osgce1iFDCmVQgTEOTs4IJtJsiE8nGrFafjVA,3043 +cffi/_shimmed_dist_utils.py,sha256=39wmD8jFLis74FzCFesJ__UOPfWu4WiECozCarZcAXE,2048 +cffi/api.py,sha256=tqyTU_x_WPPFaYof4MC8PrlewL15nW9GGsr6eDJ1vsk,43050 +cffi/backend_ctypes.py,sha256=BHN3q2giL2_Y8wMDST2CIcc_qoMrs65qV9Ob5JvxBZ4,43575 +cffi/cffi_opcode.py,sha256=57P2NHLZkuTWueZybu5iosWljb6ocQmUXzGrCplrnyE,5911 +cffi/commontypes.py,sha256=mEZD4g0qtadnv6O6CEXvMQaJ1K6SRbG5S1h4YvVZHOU,2769 +cffi/cparser.py,sha256=CwVk2V3ATYlCoywG6zN35w6UQ7zj2EWX68KjoJp2Mzk,45237 +cffi/error.py,sha256=Bka7fSV22aIglTQDPIDfpnxTc1aWZLMQdQOJY-h_PUA,908 +cffi/ffiplatform.py,sha256=n1FZYFwdmzQ8bmBQWTvCWWGYu82alBTjZVc3gXjBv38,3697 +cffi/lock.py,sha256=vnbsel7392Ib8gGBifIfAfc7MHteSwd3nP725pvc25Q,777 +cffi/model.py,sha256=cafvJJEePx0PXq29IQNKpBsFabNqZk7DTkeDMZSz3nc,22408 +cffi/parse_c_type.h,sha256=fKYNqWNX5f9kZNNhbXcRLTOlpRGRhh8eCLyHmTXIZnQ,6157 +cffi/pkgconfig.py,sha256=9zDcDf0XKIJaxFHLg7e-W8-Xb8Yq5hdhqH7kLg-ugRo,4495 +cffi/recompiler.py,sha256=alZMkdmqd0dCgUQ0bxEfDNTlPUe4iGV57KUNij-vPUs,66182 +cffi/setuptools_ext.py,sha256=60wpMrriB-lC7RPG90Pc6dIbE2nVo6u1ye_S6Ei2_sM,9087 +cffi/vengine_cpy.py,sha256=vbMYENE5jlItSjb2k62AIoWbS66uBjLWOxzisRfUg5Q,44428 +cffi/vengine_gen.py,sha256=mykUhLFJIcV6AyQ5cMJ3n_7dbqw0a9WEjXW0E-WfgiI,27359 +cffi/verifier.py,sha256=Tt_mMNvykn3cFSAM7tSqaON2oaD1fbWyFqbZ2ugP98Y,11488 diff --git a/.venv/Lib/site-packages/cffi-1.16.0.dist-info/WHEEL b/.venv/Lib/site-packages/cffi-1.16.0.dist-info/WHEEL new file mode 100644 index 00000000..6d160455 --- /dev/null +++ b/.venv/Lib/site-packages/cffi-1.16.0.dist-info/WHEEL @@ -0,0 +1,5 @@ +Wheel-Version: 1.0 +Generator: bdist_wheel (0.41.2) +Root-Is-Purelib: false +Tag: cp311-cp311-win_amd64 + diff --git a/.venv/Lib/site-packages/cffi-1.16.0.dist-info/entry_points.txt b/.venv/Lib/site-packages/cffi-1.16.0.dist-info/entry_points.txt new file mode 100644 index 00000000..4b0274f2 --- /dev/null +++ b/.venv/Lib/site-packages/cffi-1.16.0.dist-info/entry_points.txt @@ -0,0 +1,2 @@ +[distutils.setup_keywords] +cffi_modules = cffi.setuptools_ext:cffi_modules diff --git a/.venv/Lib/site-packages/cffi-1.16.0.dist-info/top_level.txt b/.venv/Lib/site-packages/cffi-1.16.0.dist-info/top_level.txt new file mode 100644 index 00000000..f6457795 --- /dev/null +++ b/.venv/Lib/site-packages/cffi-1.16.0.dist-info/top_level.txt @@ -0,0 +1,2 @@ +_cffi_backend +cffi diff --git a/.venv/Lib/site-packages/cffi/__init__.py b/.venv/Lib/site-packages/cffi/__init__.py new file mode 100644 index 00000000..90dedf43 --- /dev/null +++ b/.venv/Lib/site-packages/cffi/__init__.py @@ -0,0 +1,14 @@ +__all__ = ['FFI', 'VerificationError', 'VerificationMissing', 'CDefError', + 'FFIError'] + +from .api import FFI +from .error import CDefError, FFIError, VerificationError, VerificationMissing +from .error import PkgConfigError + +__version__ = "1.16.0" +__version_info__ = (1, 16, 0) + +# The verifier module file names are based on the CRC32 of a string that +# contains the following version number. It may be older than __version__ +# if nothing is clearly incompatible. +__version_verifier_modules__ = "0.8.6" diff --git a/.venv/Lib/site-packages/cffi/__pycache__/__init__.cpython-311.pyc b/.venv/Lib/site-packages/cffi/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 00000000..f3afbb9e Binary files /dev/null and b/.venv/Lib/site-packages/cffi/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/cffi/__pycache__/_imp_emulation.cpython-311.pyc b/.venv/Lib/site-packages/cffi/__pycache__/_imp_emulation.cpython-311.pyc new file mode 100644 index 00000000..745ebf8d Binary files /dev/null and b/.venv/Lib/site-packages/cffi/__pycache__/_imp_emulation.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/cffi/__pycache__/_shimmed_dist_utils.cpython-311.pyc b/.venv/Lib/site-packages/cffi/__pycache__/_shimmed_dist_utils.cpython-311.pyc new file mode 100644 index 00000000..73e4dea7 Binary files /dev/null and b/.venv/Lib/site-packages/cffi/__pycache__/_shimmed_dist_utils.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/cffi/__pycache__/api.cpython-311.pyc b/.venv/Lib/site-packages/cffi/__pycache__/api.cpython-311.pyc new file mode 100644 index 00000000..6a00ae1e Binary files /dev/null and b/.venv/Lib/site-packages/cffi/__pycache__/api.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/cffi/__pycache__/backend_ctypes.cpython-311.pyc b/.venv/Lib/site-packages/cffi/__pycache__/backend_ctypes.cpython-311.pyc new file mode 100644 index 00000000..8ad26337 Binary files /dev/null and b/.venv/Lib/site-packages/cffi/__pycache__/backend_ctypes.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/cffi/__pycache__/cffi_opcode.cpython-311.pyc b/.venv/Lib/site-packages/cffi/__pycache__/cffi_opcode.cpython-311.pyc new file mode 100644 index 00000000..228d8615 Binary files /dev/null and b/.venv/Lib/site-packages/cffi/__pycache__/cffi_opcode.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/cffi/__pycache__/commontypes.cpython-311.pyc b/.venv/Lib/site-packages/cffi/__pycache__/commontypes.cpython-311.pyc new file mode 100644 index 00000000..6153d1b6 Binary files /dev/null and b/.venv/Lib/site-packages/cffi/__pycache__/commontypes.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/cffi/__pycache__/cparser.cpython-311.pyc b/.venv/Lib/site-packages/cffi/__pycache__/cparser.cpython-311.pyc new file mode 100644 index 00000000..a36775ed Binary files /dev/null and b/.venv/Lib/site-packages/cffi/__pycache__/cparser.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/cffi/__pycache__/error.cpython-311.pyc b/.venv/Lib/site-packages/cffi/__pycache__/error.cpython-311.pyc new file mode 100644 index 00000000..e9a1a7d0 Binary files /dev/null and b/.venv/Lib/site-packages/cffi/__pycache__/error.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/cffi/__pycache__/ffiplatform.cpython-311.pyc b/.venv/Lib/site-packages/cffi/__pycache__/ffiplatform.cpython-311.pyc new file mode 100644 index 00000000..d0209cbf Binary files /dev/null and b/.venv/Lib/site-packages/cffi/__pycache__/ffiplatform.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/cffi/__pycache__/lock.cpython-311.pyc b/.venv/Lib/site-packages/cffi/__pycache__/lock.cpython-311.pyc new file mode 100644 index 00000000..2a8ce440 Binary files /dev/null and b/.venv/Lib/site-packages/cffi/__pycache__/lock.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/cffi/__pycache__/model.cpython-311.pyc b/.venv/Lib/site-packages/cffi/__pycache__/model.cpython-311.pyc new file mode 100644 index 00000000..1e55ce6f Binary files /dev/null and b/.venv/Lib/site-packages/cffi/__pycache__/model.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/cffi/__pycache__/pkgconfig.cpython-311.pyc b/.venv/Lib/site-packages/cffi/__pycache__/pkgconfig.cpython-311.pyc new file mode 100644 index 00000000..53b03fdd Binary files /dev/null and b/.venv/Lib/site-packages/cffi/__pycache__/pkgconfig.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/cffi/__pycache__/recompiler.cpython-311.pyc b/.venv/Lib/site-packages/cffi/__pycache__/recompiler.cpython-311.pyc new file mode 100644 index 00000000..2feb4a1b Binary files /dev/null and b/.venv/Lib/site-packages/cffi/__pycache__/recompiler.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/cffi/__pycache__/setuptools_ext.cpython-311.pyc b/.venv/Lib/site-packages/cffi/__pycache__/setuptools_ext.cpython-311.pyc new file mode 100644 index 00000000..b0729848 Binary files /dev/null and b/.venv/Lib/site-packages/cffi/__pycache__/setuptools_ext.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/cffi/__pycache__/vengine_cpy.cpython-311.pyc b/.venv/Lib/site-packages/cffi/__pycache__/vengine_cpy.cpython-311.pyc new file mode 100644 index 00000000..cfa8c03d Binary files /dev/null and b/.venv/Lib/site-packages/cffi/__pycache__/vengine_cpy.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/cffi/__pycache__/vengine_gen.cpython-311.pyc b/.venv/Lib/site-packages/cffi/__pycache__/vengine_gen.cpython-311.pyc new file mode 100644 index 00000000..606a6020 Binary files /dev/null and b/.venv/Lib/site-packages/cffi/__pycache__/vengine_gen.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/cffi/__pycache__/verifier.cpython-311.pyc b/.venv/Lib/site-packages/cffi/__pycache__/verifier.cpython-311.pyc new file mode 100644 index 00000000..5e670094 Binary files /dev/null and b/.venv/Lib/site-packages/cffi/__pycache__/verifier.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/cffi/_cffi_errors.h b/.venv/Lib/site-packages/cffi/_cffi_errors.h new file mode 100644 index 00000000..158e0590 --- /dev/null +++ b/.venv/Lib/site-packages/cffi/_cffi_errors.h @@ -0,0 +1,149 @@ +#ifndef CFFI_MESSAGEBOX +# ifdef _MSC_VER +# define CFFI_MESSAGEBOX 1 +# else +# define CFFI_MESSAGEBOX 0 +# endif +#endif + + +#if CFFI_MESSAGEBOX +/* Windows only: logic to take the Python-CFFI embedding logic + initialization errors and display them in a background thread + with MessageBox. The idea is that if the whole program closes + as a result of this problem, then likely it is already a console + program and you can read the stderr output in the console too. + If it is not a console program, then it will likely show its own + dialog to complain, or generally not abruptly close, and for this + case the background thread should stay alive. +*/ +static void *volatile _cffi_bootstrap_text; + +static PyObject *_cffi_start_error_capture(void) +{ + PyObject *result = NULL; + PyObject *x, *m, *bi; + + if (InterlockedCompareExchangePointer(&_cffi_bootstrap_text, + (void *)1, NULL) != NULL) + return (PyObject *)1; + + m = PyImport_AddModule("_cffi_error_capture"); + if (m == NULL) + goto error; + + result = PyModule_GetDict(m); + if (result == NULL) + goto error; + +#if PY_MAJOR_VERSION >= 3 + bi = PyImport_ImportModule("builtins"); +#else + bi = PyImport_ImportModule("__builtin__"); +#endif + if (bi == NULL) + goto error; + PyDict_SetItemString(result, "__builtins__", bi); + Py_DECREF(bi); + + x = PyRun_String( + "import sys\n" + "class FileLike:\n" + " def write(self, x):\n" + " try:\n" + " of.write(x)\n" + " except: pass\n" + " self.buf += x\n" + " def flush(self):\n" + " pass\n" + "fl = FileLike()\n" + "fl.buf = ''\n" + "of = sys.stderr\n" + "sys.stderr = fl\n" + "def done():\n" + " sys.stderr = of\n" + " return fl.buf\n", /* make sure the returned value stays alive */ + Py_file_input, + result, result); + Py_XDECREF(x); + + error: + if (PyErr_Occurred()) + { + PyErr_WriteUnraisable(Py_None); + PyErr_Clear(); + } + return result; +} + +#pragma comment(lib, "user32.lib") + +static DWORD WINAPI _cffi_bootstrap_dialog(LPVOID ignored) +{ + Sleep(666); /* may be interrupted if the whole process is closing */ +#if PY_MAJOR_VERSION >= 3 + MessageBoxW(NULL, (wchar_t *)_cffi_bootstrap_text, + L"Python-CFFI error", + MB_OK | MB_ICONERROR); +#else + MessageBoxA(NULL, (char *)_cffi_bootstrap_text, + "Python-CFFI error", + MB_OK | MB_ICONERROR); +#endif + _cffi_bootstrap_text = NULL; + return 0; +} + +static void _cffi_stop_error_capture(PyObject *ecap) +{ + PyObject *s; + void *text; + + if (ecap == (PyObject *)1) + return; + + if (ecap == NULL) + goto error; + + s = PyRun_String("done()", Py_eval_input, ecap, ecap); + if (s == NULL) + goto error; + + /* Show a dialog box, but in a background thread, and + never show multiple dialog boxes at once. */ +#if PY_MAJOR_VERSION >= 3 + text = PyUnicode_AsWideCharString(s, NULL); +#else + text = PyString_AsString(s); +#endif + + _cffi_bootstrap_text = text; + + if (text != NULL) + { + HANDLE h; + h = CreateThread(NULL, 0, _cffi_bootstrap_dialog, + NULL, 0, NULL); + if (h != NULL) + CloseHandle(h); + } + /* decref the string, but it should stay alive as 'fl.buf' + in the small module above. It will really be freed only if + we later get another similar error. So it's a leak of at + most one copy of the small module. That's fine for this + situation which is usually a "fatal error" anyway. */ + Py_DECREF(s); + PyErr_Clear(); + return; + + error: + _cffi_bootstrap_text = NULL; + PyErr_Clear(); +} + +#else + +static PyObject *_cffi_start_error_capture(void) { return NULL; } +static void _cffi_stop_error_capture(PyObject *ecap) { } + +#endif diff --git a/.venv/Lib/site-packages/cffi/_cffi_include.h b/.venv/Lib/site-packages/cffi/_cffi_include.h new file mode 100644 index 00000000..e4c0a672 --- /dev/null +++ b/.venv/Lib/site-packages/cffi/_cffi_include.h @@ -0,0 +1,385 @@ +#define _CFFI_ + +/* We try to define Py_LIMITED_API before including Python.h. + + Mess: we can only define it if Py_DEBUG, Py_TRACE_REFS and + Py_REF_DEBUG are not defined. This is a best-effort approximation: + we can learn about Py_DEBUG from pyconfig.h, but it is unclear if + the same works for the other two macros. Py_DEBUG implies them, + but not the other way around. + + The implementation is messy (issue #350): on Windows, with _MSC_VER, + we have to define Py_LIMITED_API even before including pyconfig.h. + In that case, we guess what pyconfig.h will do to the macros above, + and check our guess after the #include. + + Note that on Windows, with CPython 3.x, you need >= 3.5 and virtualenv + version >= 16.0.0. With older versions of either, you don't get a + copy of PYTHON3.DLL in the virtualenv. We can't check the version of + CPython *before* we even include pyconfig.h. ffi.set_source() puts + a ``#define _CFFI_NO_LIMITED_API'' at the start of this file if it is + running on Windows < 3.5, as an attempt at fixing it, but that's + arguably wrong because it may not be the target version of Python. + Still better than nothing I guess. As another workaround, you can + remove the definition of Py_LIMITED_API here. + + See also 'py_limited_api' in cffi/setuptools_ext.py. +*/ +#if !defined(_CFFI_USE_EMBEDDING) && !defined(Py_LIMITED_API) +# ifdef _MSC_VER +# if !defined(_DEBUG) && !defined(Py_DEBUG) && !defined(Py_TRACE_REFS) && !defined(Py_REF_DEBUG) && !defined(_CFFI_NO_LIMITED_API) +# define Py_LIMITED_API +# endif +# include + /* sanity-check: Py_LIMITED_API will cause crashes if any of these + are also defined. Normally, the Python file PC/pyconfig.h does not + cause any of these to be defined, with the exception that _DEBUG + causes Py_DEBUG. Double-check that. */ +# ifdef Py_LIMITED_API +# if defined(Py_DEBUG) +# error "pyconfig.h unexpectedly defines Py_DEBUG, but Py_LIMITED_API is set" +# endif +# if defined(Py_TRACE_REFS) +# error "pyconfig.h unexpectedly defines Py_TRACE_REFS, but Py_LIMITED_API is set" +# endif +# if defined(Py_REF_DEBUG) +# error "pyconfig.h unexpectedly defines Py_REF_DEBUG, but Py_LIMITED_API is set" +# endif +# endif +# else +# include +# if !defined(Py_DEBUG) && !defined(Py_TRACE_REFS) && !defined(Py_REF_DEBUG) && !defined(_CFFI_NO_LIMITED_API) +# define Py_LIMITED_API +# endif +# endif +#endif + +#include +#ifdef __cplusplus +extern "C" { +#endif +#include +#include "parse_c_type.h" + +/* this block of #ifs should be kept exactly identical between + c/_cffi_backend.c, cffi/vengine_cpy.py, cffi/vengine_gen.py + and cffi/_cffi_include.h */ +#if defined(_MSC_VER) +# include /* for alloca() */ +# if _MSC_VER < 1600 /* MSVC < 2010 */ + typedef __int8 int8_t; + typedef __int16 int16_t; + typedef __int32 int32_t; + typedef __int64 int64_t; + typedef unsigned __int8 uint8_t; + typedef unsigned __int16 uint16_t; + typedef unsigned __int32 uint32_t; + typedef unsigned __int64 uint64_t; + typedef __int8 int_least8_t; + typedef __int16 int_least16_t; + typedef __int32 int_least32_t; + typedef __int64 int_least64_t; + typedef unsigned __int8 uint_least8_t; + typedef unsigned __int16 uint_least16_t; + typedef unsigned __int32 uint_least32_t; + typedef unsigned __int64 uint_least64_t; + typedef __int8 int_fast8_t; + typedef __int16 int_fast16_t; + typedef __int32 int_fast32_t; + typedef __int64 int_fast64_t; + typedef unsigned __int8 uint_fast8_t; + typedef unsigned __int16 uint_fast16_t; + typedef unsigned __int32 uint_fast32_t; + typedef unsigned __int64 uint_fast64_t; + typedef __int64 intmax_t; + typedef unsigned __int64 uintmax_t; +# else +# include +# endif +# if _MSC_VER < 1800 /* MSVC < 2013 */ +# ifndef __cplusplus + typedef unsigned char _Bool; +# endif +# endif +#else +# include +# if (defined (__SVR4) && defined (__sun)) || defined(_AIX) || defined(__hpux) +# include +# endif +#endif + +#ifdef __GNUC__ +# define _CFFI_UNUSED_FN __attribute__((unused)) +#else +# define _CFFI_UNUSED_FN /* nothing */ +#endif + +#ifdef __cplusplus +# ifndef _Bool + typedef bool _Bool; /* semi-hackish: C++ has no _Bool; bool is builtin */ +# endif +#endif + +/********** CPython-specific section **********/ +#ifndef PYPY_VERSION + + +#if PY_MAJOR_VERSION >= 3 +# define PyInt_FromLong PyLong_FromLong +#endif + +#define _cffi_from_c_double PyFloat_FromDouble +#define _cffi_from_c_float PyFloat_FromDouble +#define _cffi_from_c_long PyInt_FromLong +#define _cffi_from_c_ulong PyLong_FromUnsignedLong +#define _cffi_from_c_longlong PyLong_FromLongLong +#define _cffi_from_c_ulonglong PyLong_FromUnsignedLongLong +#define _cffi_from_c__Bool PyBool_FromLong + +#define _cffi_to_c_double PyFloat_AsDouble +#define _cffi_to_c_float PyFloat_AsDouble + +#define _cffi_from_c_int(x, type) \ + (((type)-1) > 0 ? /* unsigned */ \ + (sizeof(type) < sizeof(long) ? \ + PyInt_FromLong((long)x) : \ + sizeof(type) == sizeof(long) ? \ + PyLong_FromUnsignedLong((unsigned long)x) : \ + PyLong_FromUnsignedLongLong((unsigned long long)x)) : \ + (sizeof(type) <= sizeof(long) ? \ + PyInt_FromLong((long)x) : \ + PyLong_FromLongLong((long long)x))) + +#define _cffi_to_c_int(o, type) \ + ((type)( \ + sizeof(type) == 1 ? (((type)-1) > 0 ? (type)_cffi_to_c_u8(o) \ + : (type)_cffi_to_c_i8(o)) : \ + sizeof(type) == 2 ? (((type)-1) > 0 ? (type)_cffi_to_c_u16(o) \ + : (type)_cffi_to_c_i16(o)) : \ + sizeof(type) == 4 ? (((type)-1) > 0 ? (type)_cffi_to_c_u32(o) \ + : (type)_cffi_to_c_i32(o)) : \ + sizeof(type) == 8 ? (((type)-1) > 0 ? (type)_cffi_to_c_u64(o) \ + : (type)_cffi_to_c_i64(o)) : \ + (Py_FatalError("unsupported size for type " #type), (type)0))) + +#define _cffi_to_c_i8 \ + ((int(*)(PyObject *))_cffi_exports[1]) +#define _cffi_to_c_u8 \ + ((int(*)(PyObject *))_cffi_exports[2]) +#define _cffi_to_c_i16 \ + ((int(*)(PyObject *))_cffi_exports[3]) +#define _cffi_to_c_u16 \ + ((int(*)(PyObject *))_cffi_exports[4]) +#define _cffi_to_c_i32 \ + ((int(*)(PyObject *))_cffi_exports[5]) +#define _cffi_to_c_u32 \ + ((unsigned int(*)(PyObject *))_cffi_exports[6]) +#define _cffi_to_c_i64 \ + ((long long(*)(PyObject *))_cffi_exports[7]) +#define _cffi_to_c_u64 \ + ((unsigned long long(*)(PyObject *))_cffi_exports[8]) +#define _cffi_to_c_char \ + ((int(*)(PyObject *))_cffi_exports[9]) +#define _cffi_from_c_pointer \ + ((PyObject *(*)(char *, struct _cffi_ctypedescr *))_cffi_exports[10]) +#define _cffi_to_c_pointer \ + ((char *(*)(PyObject *, struct _cffi_ctypedescr *))_cffi_exports[11]) +#define _cffi_get_struct_layout \ + not used any more +#define _cffi_restore_errno \ + ((void(*)(void))_cffi_exports[13]) +#define _cffi_save_errno \ + ((void(*)(void))_cffi_exports[14]) +#define _cffi_from_c_char \ + ((PyObject *(*)(char))_cffi_exports[15]) +#define _cffi_from_c_deref \ + ((PyObject *(*)(char *, struct _cffi_ctypedescr *))_cffi_exports[16]) +#define _cffi_to_c \ + ((int(*)(char *, struct _cffi_ctypedescr *, PyObject *))_cffi_exports[17]) +#define _cffi_from_c_struct \ + ((PyObject *(*)(char *, struct _cffi_ctypedescr *))_cffi_exports[18]) +#define _cffi_to_c_wchar_t \ + ((_cffi_wchar_t(*)(PyObject *))_cffi_exports[19]) +#define _cffi_from_c_wchar_t \ + ((PyObject *(*)(_cffi_wchar_t))_cffi_exports[20]) +#define _cffi_to_c_long_double \ + ((long double(*)(PyObject *))_cffi_exports[21]) +#define _cffi_to_c__Bool \ + ((_Bool(*)(PyObject *))_cffi_exports[22]) +#define _cffi_prepare_pointer_call_argument \ + ((Py_ssize_t(*)(struct _cffi_ctypedescr *, \ + PyObject *, char **))_cffi_exports[23]) +#define _cffi_convert_array_from_object \ + ((int(*)(char *, struct _cffi_ctypedescr *, PyObject *))_cffi_exports[24]) +#define _CFFI_CPIDX 25 +#define _cffi_call_python \ + ((void(*)(struct _cffi_externpy_s *, char *))_cffi_exports[_CFFI_CPIDX]) +#define _cffi_to_c_wchar3216_t \ + ((int(*)(PyObject *))_cffi_exports[26]) +#define _cffi_from_c_wchar3216_t \ + ((PyObject *(*)(int))_cffi_exports[27]) +#define _CFFI_NUM_EXPORTS 28 + +struct _cffi_ctypedescr; + +static void *_cffi_exports[_CFFI_NUM_EXPORTS]; + +#define _cffi_type(index) ( \ + assert((((uintptr_t)_cffi_types[index]) & 1) == 0), \ + (struct _cffi_ctypedescr *)_cffi_types[index]) + +static PyObject *_cffi_init(const char *module_name, Py_ssize_t version, + const struct _cffi_type_context_s *ctx) +{ + PyObject *module, *o_arg, *new_module; + void *raw[] = { + (void *)module_name, + (void *)version, + (void *)_cffi_exports, + (void *)ctx, + }; + + module = PyImport_ImportModule("_cffi_backend"); + if (module == NULL) + goto failure; + + o_arg = PyLong_FromVoidPtr((void *)raw); + if (o_arg == NULL) + goto failure; + + new_module = PyObject_CallMethod( + module, (char *)"_init_cffi_1_0_external_module", (char *)"O", o_arg); + + Py_DECREF(o_arg); + Py_DECREF(module); + return new_module; + + failure: + Py_XDECREF(module); + return NULL; +} + + +#ifdef HAVE_WCHAR_H +typedef wchar_t _cffi_wchar_t; +#else +typedef uint16_t _cffi_wchar_t; /* same random pick as _cffi_backend.c */ +#endif + +_CFFI_UNUSED_FN static uint16_t _cffi_to_c_char16_t(PyObject *o) +{ + if (sizeof(_cffi_wchar_t) == 2) + return (uint16_t)_cffi_to_c_wchar_t(o); + else + return (uint16_t)_cffi_to_c_wchar3216_t(o); +} + +_CFFI_UNUSED_FN static PyObject *_cffi_from_c_char16_t(uint16_t x) +{ + if (sizeof(_cffi_wchar_t) == 2) + return _cffi_from_c_wchar_t((_cffi_wchar_t)x); + else + return _cffi_from_c_wchar3216_t((int)x); +} + +_CFFI_UNUSED_FN static int _cffi_to_c_char32_t(PyObject *o) +{ + if (sizeof(_cffi_wchar_t) == 4) + return (int)_cffi_to_c_wchar_t(o); + else + return (int)_cffi_to_c_wchar3216_t(o); +} + +_CFFI_UNUSED_FN static PyObject *_cffi_from_c_char32_t(unsigned int x) +{ + if (sizeof(_cffi_wchar_t) == 4) + return _cffi_from_c_wchar_t((_cffi_wchar_t)x); + else + return _cffi_from_c_wchar3216_t((int)x); +} + +union _cffi_union_alignment_u { + unsigned char m_char; + unsigned short m_short; + unsigned int m_int; + unsigned long m_long; + unsigned long long m_longlong; + float m_float; + double m_double; + long double m_longdouble; +}; + +struct _cffi_freeme_s { + struct _cffi_freeme_s *next; + union _cffi_union_alignment_u alignment; +}; + +_CFFI_UNUSED_FN static int +_cffi_convert_array_argument(struct _cffi_ctypedescr *ctptr, PyObject *arg, + char **output_data, Py_ssize_t datasize, + struct _cffi_freeme_s **freeme) +{ + char *p; + if (datasize < 0) + return -1; + + p = *output_data; + if (p == NULL) { + struct _cffi_freeme_s *fp = (struct _cffi_freeme_s *)PyObject_Malloc( + offsetof(struct _cffi_freeme_s, alignment) + (size_t)datasize); + if (fp == NULL) + return -1; + fp->next = *freeme; + *freeme = fp; + p = *output_data = (char *)&fp->alignment; + } + memset((void *)p, 0, (size_t)datasize); + return _cffi_convert_array_from_object(p, ctptr, arg); +} + +_CFFI_UNUSED_FN static void +_cffi_free_array_arguments(struct _cffi_freeme_s *freeme) +{ + do { + void *p = (void *)freeme; + freeme = freeme->next; + PyObject_Free(p); + } while (freeme != NULL); +} + +/********** end CPython-specific section **********/ +#else +_CFFI_UNUSED_FN +static void (*_cffi_call_python_org)(struct _cffi_externpy_s *, char *); +# define _cffi_call_python _cffi_call_python_org +#endif + + +#define _cffi_array_len(array) (sizeof(array) / sizeof((array)[0])) + +#define _cffi_prim_int(size, sign) \ + ((size) == 1 ? ((sign) ? _CFFI_PRIM_INT8 : _CFFI_PRIM_UINT8) : \ + (size) == 2 ? ((sign) ? _CFFI_PRIM_INT16 : _CFFI_PRIM_UINT16) : \ + (size) == 4 ? ((sign) ? _CFFI_PRIM_INT32 : _CFFI_PRIM_UINT32) : \ + (size) == 8 ? ((sign) ? _CFFI_PRIM_INT64 : _CFFI_PRIM_UINT64) : \ + _CFFI__UNKNOWN_PRIM) + +#define _cffi_prim_float(size) \ + ((size) == sizeof(float) ? _CFFI_PRIM_FLOAT : \ + (size) == sizeof(double) ? _CFFI_PRIM_DOUBLE : \ + (size) == sizeof(long double) ? _CFFI__UNKNOWN_LONG_DOUBLE : \ + _CFFI__UNKNOWN_FLOAT_PRIM) + +#define _cffi_check_int(got, got_nonpos, expected) \ + ((got_nonpos) == (expected <= 0) && \ + (got) == (unsigned long long)expected) + +#ifdef MS_WIN32 +# define _cffi_stdcall __stdcall +#else +# define _cffi_stdcall /* nothing */ +#endif + +#ifdef __cplusplus +} +#endif diff --git a/.venv/Lib/site-packages/cffi/_embedding.h b/.venv/Lib/site-packages/cffi/_embedding.h new file mode 100644 index 00000000..1cb66f23 --- /dev/null +++ b/.venv/Lib/site-packages/cffi/_embedding.h @@ -0,0 +1,550 @@ + +/***** Support code for embedding *****/ + +#ifdef __cplusplus +extern "C" { +#endif + + +#if defined(_WIN32) +# define CFFI_DLLEXPORT __declspec(dllexport) +#elif defined(__GNUC__) +# define CFFI_DLLEXPORT __attribute__((visibility("default"))) +#else +# define CFFI_DLLEXPORT /* nothing */ +#endif + + +/* There are two global variables of type _cffi_call_python_fnptr: + + * _cffi_call_python, which we declare just below, is the one called + by ``extern "Python"`` implementations. + + * _cffi_call_python_org, which on CPython is actually part of the + _cffi_exports[] array, is the function pointer copied from + _cffi_backend. If _cffi_start_python() fails, then this is set + to NULL; otherwise, it should never be NULL. + + After initialization is complete, both are equal. However, the + first one remains equal to &_cffi_start_and_call_python until the + very end of initialization, when we are (or should be) sure that + concurrent threads also see a completely initialized world, and + only then is it changed. +*/ +#undef _cffi_call_python +typedef void (*_cffi_call_python_fnptr)(struct _cffi_externpy_s *, char *); +static void _cffi_start_and_call_python(struct _cffi_externpy_s *, char *); +static _cffi_call_python_fnptr _cffi_call_python = &_cffi_start_and_call_python; + + +#ifndef _MSC_VER + /* --- Assuming a GCC not infinitely old --- */ +# define cffi_compare_and_swap(l,o,n) __sync_bool_compare_and_swap(l,o,n) +# define cffi_write_barrier() __sync_synchronize() +# if !defined(__amd64__) && !defined(__x86_64__) && \ + !defined(__i386__) && !defined(__i386) +# define cffi_read_barrier() __sync_synchronize() +# else +# define cffi_read_barrier() (void)0 +# endif +#else + /* --- Windows threads version --- */ +# include +# define cffi_compare_and_swap(l,o,n) \ + (InterlockedCompareExchangePointer(l,n,o) == (o)) +# define cffi_write_barrier() InterlockedCompareExchange(&_cffi_dummy,0,0) +# define cffi_read_barrier() (void)0 +static volatile LONG _cffi_dummy; +#endif + +#ifdef WITH_THREAD +# ifndef _MSC_VER +# include + static pthread_mutex_t _cffi_embed_startup_lock; +# else + static CRITICAL_SECTION _cffi_embed_startup_lock; +# endif + static char _cffi_embed_startup_lock_ready = 0; +#endif + +static void _cffi_acquire_reentrant_mutex(void) +{ + static void *volatile lock = NULL; + + while (!cffi_compare_and_swap(&lock, NULL, (void *)1)) { + /* should ideally do a spin loop instruction here, but + hard to do it portably and doesn't really matter I + think: pthread_mutex_init() should be very fast, and + this is only run at start-up anyway. */ + } + +#ifdef WITH_THREAD + if (!_cffi_embed_startup_lock_ready) { +# ifndef _MSC_VER + pthread_mutexattr_t attr; + pthread_mutexattr_init(&attr); + pthread_mutexattr_settype(&attr, PTHREAD_MUTEX_RECURSIVE); + pthread_mutex_init(&_cffi_embed_startup_lock, &attr); +# else + InitializeCriticalSection(&_cffi_embed_startup_lock); +# endif + _cffi_embed_startup_lock_ready = 1; + } +#endif + + while (!cffi_compare_and_swap(&lock, (void *)1, NULL)) + ; + +#ifndef _MSC_VER + pthread_mutex_lock(&_cffi_embed_startup_lock); +#else + EnterCriticalSection(&_cffi_embed_startup_lock); +#endif +} + +static void _cffi_release_reentrant_mutex(void) +{ +#ifndef _MSC_VER + pthread_mutex_unlock(&_cffi_embed_startup_lock); +#else + LeaveCriticalSection(&_cffi_embed_startup_lock); +#endif +} + + +/********** CPython-specific section **********/ +#ifndef PYPY_VERSION + +#include "_cffi_errors.h" + + +#define _cffi_call_python_org _cffi_exports[_CFFI_CPIDX] + +PyMODINIT_FUNC _CFFI_PYTHON_STARTUP_FUNC(void); /* forward */ + +static void _cffi_py_initialize(void) +{ + /* XXX use initsigs=0, which "skips initialization registration of + signal handlers, which might be useful when Python is + embedded" according to the Python docs. But review and think + if it should be a user-controllable setting. + + XXX we should also give a way to write errors to a buffer + instead of to stderr. + + XXX if importing 'site' fails, CPython (any version) calls + exit(). Should we try to work around this behavior here? + */ + Py_InitializeEx(0); +} + +static int _cffi_initialize_python(void) +{ + /* This initializes Python, imports _cffi_backend, and then the + present .dll/.so is set up as a CPython C extension module. + */ + int result; + PyGILState_STATE state; + PyObject *pycode=NULL, *global_dict=NULL, *x; + PyObject *builtins; + + state = PyGILState_Ensure(); + + /* Call the initxxx() function from the present module. It will + create and initialize us as a CPython extension module, instead + of letting the startup Python code do it---it might reimport + the same .dll/.so and get maybe confused on some platforms. + It might also have troubles locating the .dll/.so again for all + I know. + */ + (void)_CFFI_PYTHON_STARTUP_FUNC(); + if (PyErr_Occurred()) + goto error; + + /* Now run the Python code provided to ffi.embedding_init_code(). + */ + pycode = Py_CompileString(_CFFI_PYTHON_STARTUP_CODE, + "", + Py_file_input); + if (pycode == NULL) + goto error; + global_dict = PyDict_New(); + if (global_dict == NULL) + goto error; + builtins = PyEval_GetBuiltins(); + if (builtins == NULL) + goto error; + if (PyDict_SetItemString(global_dict, "__builtins__", builtins) < 0) + goto error; + x = PyEval_EvalCode( +#if PY_MAJOR_VERSION < 3 + (PyCodeObject *) +#endif + pycode, global_dict, global_dict); + if (x == NULL) + goto error; + Py_DECREF(x); + + /* Done! Now if we've been called from + _cffi_start_and_call_python() in an ``extern "Python"``, we can + only hope that the Python code did correctly set up the + corresponding @ffi.def_extern() function. Otherwise, the + general logic of ``extern "Python"`` functions (inside the + _cffi_backend module) will find that the reference is still + missing and print an error. + */ + result = 0; + done: + Py_XDECREF(pycode); + Py_XDECREF(global_dict); + PyGILState_Release(state); + return result; + + error:; + { + /* Print as much information as potentially useful. + Debugging load-time failures with embedding is not fun + */ + PyObject *ecap; + PyObject *exception, *v, *tb, *f, *modules, *mod; + PyErr_Fetch(&exception, &v, &tb); + ecap = _cffi_start_error_capture(); + f = PySys_GetObject((char *)"stderr"); + if (f != NULL && f != Py_None) { + PyFile_WriteString( + "Failed to initialize the Python-CFFI embedding logic:\n\n", f); + } + + if (exception != NULL) { + PyErr_NormalizeException(&exception, &v, &tb); + PyErr_Display(exception, v, tb); + } + Py_XDECREF(exception); + Py_XDECREF(v); + Py_XDECREF(tb); + + if (f != NULL && f != Py_None) { + PyFile_WriteString("\nFrom: " _CFFI_MODULE_NAME + "\ncompiled with cffi version: 1.16.0" + "\n_cffi_backend module: ", f); + modules = PyImport_GetModuleDict(); + mod = PyDict_GetItemString(modules, "_cffi_backend"); + if (mod == NULL) { + PyFile_WriteString("not loaded", f); + } + else { + v = PyObject_GetAttrString(mod, "__file__"); + PyFile_WriteObject(v, f, 0); + Py_XDECREF(v); + } + PyFile_WriteString("\nsys.path: ", f); + PyFile_WriteObject(PySys_GetObject((char *)"path"), f, 0); + PyFile_WriteString("\n\n", f); + } + _cffi_stop_error_capture(ecap); + } + result = -1; + goto done; +} + +#if PY_VERSION_HEX < 0x03080000 +PyAPI_DATA(char *) _PyParser_TokenNames[]; /* from CPython */ +#endif + +static int _cffi_carefully_make_gil(void) +{ + /* This does the basic initialization of Python. It can be called + completely concurrently from unrelated threads. It assumes + that we don't hold the GIL before (if it exists), and we don't + hold it afterwards. + + (What it really does used to be completely different in Python 2 + and Python 3, with the Python 2 solution avoiding the spin-lock + around the Py_InitializeEx() call. However, after recent changes + to CPython 2.7 (issue #358) it no longer works. So we use the + Python 3 solution everywhere.) + + This initializes Python by calling Py_InitializeEx(). + Important: this must not be called concurrently at all. + So we use a global variable as a simple spin lock. This global + variable must be from 'libpythonX.Y.so', not from this + cffi-based extension module, because it must be shared from + different cffi-based extension modules. + + In Python < 3.8, we choose + _PyParser_TokenNames[0] as a completely arbitrary pointer value + that is never written to. The default is to point to the + string "ENDMARKER". We change it temporarily to point to the + next character in that string. (Yes, I know it's REALLY + obscure.) + + In Python >= 3.8, this string array is no longer writable, so + instead we pick PyCapsuleType.tp_version_tag. We can't change + Python < 3.8 because someone might use a mixture of cffi + embedded modules, some of which were compiled before this file + changed. + + In Python >= 3.12, this stopped working because that particular + tp_version_tag gets modified during interpreter startup. It's + arguably a bad idea before 3.12 too, but again we can't change + that because someone might use a mixture of cffi embedded + modules, and no-one reported a bug so far. In Python >= 3.12 + we go instead for PyCapsuleType.tp_as_buffer, which is supposed + to always be NULL. We write to it temporarily a pointer to + a struct full of NULLs, which is semantically the same. + */ + +#ifdef WITH_THREAD +# if PY_VERSION_HEX < 0x03080000 + char *volatile *lock = (char *volatile *)_PyParser_TokenNames; + char *old_value, *locked_value; + + while (1) { /* spin loop */ + old_value = *lock; + locked_value = old_value + 1; + if (old_value[0] == 'E') { + assert(old_value[1] == 'N'); + if (cffi_compare_and_swap(lock, old_value, locked_value)) + break; + } + else { + assert(old_value[0] == 'N'); + /* should ideally do a spin loop instruction here, but + hard to do it portably and doesn't really matter I + think: PyEval_InitThreads() should be very fast, and + this is only run at start-up anyway. */ + } + } +# else +# if PY_VERSION_HEX < 0x030C0000 + int volatile *lock = (int volatile *)&PyCapsule_Type.tp_version_tag; + int old_value, locked_value = -42; + assert(!(PyCapsule_Type.tp_flags & Py_TPFLAGS_HAVE_VERSION_TAG)); +# else + static struct ebp_s { PyBufferProcs buf; int mark; } empty_buffer_procs; + empty_buffer_procs.mark = -42; + PyBufferProcs *volatile *lock = (PyBufferProcs *volatile *) + &PyCapsule_Type.tp_as_buffer; + PyBufferProcs *old_value, *locked_value = &empty_buffer_procs.buf; +# endif + + while (1) { /* spin loop */ + old_value = *lock; + if (old_value == 0) { + if (cffi_compare_and_swap(lock, old_value, locked_value)) + break; + } + else { +# if PY_VERSION_HEX < 0x030C0000 + assert(old_value == locked_value); +# else + /* The pointer should point to a possibly different + empty_buffer_procs from another C extension module */ + assert(((struct ebp_s *)old_value)->mark == -42); +# endif + /* should ideally do a spin loop instruction here, but + hard to do it portably and doesn't really matter I + think: PyEval_InitThreads() should be very fast, and + this is only run at start-up anyway. */ + } + } +# endif +#endif + + /* call Py_InitializeEx() */ + if (!Py_IsInitialized()) { + _cffi_py_initialize(); +#if PY_VERSION_HEX < 0x03070000 + PyEval_InitThreads(); +#endif + PyEval_SaveThread(); /* release the GIL */ + /* the returned tstate must be the one that has been stored into the + autoTLSkey by _PyGILState_Init() called from Py_Initialize(). */ + } + else { +#if PY_VERSION_HEX < 0x03070000 + /* PyEval_InitThreads() is always a no-op from CPython 3.7 */ + PyGILState_STATE state = PyGILState_Ensure(); + PyEval_InitThreads(); + PyGILState_Release(state); +#endif + } + +#ifdef WITH_THREAD + /* release the lock */ + while (!cffi_compare_and_swap(lock, locked_value, old_value)) + ; +#endif + + return 0; +} + +/********** end CPython-specific section **********/ + + +#else + + +/********** PyPy-specific section **********/ + +PyMODINIT_FUNC _CFFI_PYTHON_STARTUP_FUNC(const void *[]); /* forward */ + +static struct _cffi_pypy_init_s { + const char *name; + void *func; /* function pointer */ + const char *code; +} _cffi_pypy_init = { + _CFFI_MODULE_NAME, + _CFFI_PYTHON_STARTUP_FUNC, + _CFFI_PYTHON_STARTUP_CODE, +}; + +extern int pypy_carefully_make_gil(const char *); +extern int pypy_init_embedded_cffi_module(int, struct _cffi_pypy_init_s *); + +static int _cffi_carefully_make_gil(void) +{ + return pypy_carefully_make_gil(_CFFI_MODULE_NAME); +} + +static int _cffi_initialize_python(void) +{ + return pypy_init_embedded_cffi_module(0xB011, &_cffi_pypy_init); +} + +/********** end PyPy-specific section **********/ + + +#endif + + +#ifdef __GNUC__ +__attribute__((noinline)) +#endif +static _cffi_call_python_fnptr _cffi_start_python(void) +{ + /* Delicate logic to initialize Python. This function can be + called multiple times concurrently, e.g. when the process calls + its first ``extern "Python"`` functions in multiple threads at + once. It can also be called recursively, in which case we must + ignore it. We also have to consider what occurs if several + different cffi-based extensions reach this code in parallel + threads---it is a different copy of the code, then, and we + can't have any shared global variable unless it comes from + 'libpythonX.Y.so'. + + Idea: + + * _cffi_carefully_make_gil(): "carefully" call + PyEval_InitThreads() (possibly with Py_InitializeEx() first). + + * then we use a (local) custom lock to make sure that a call to this + cffi-based extension will wait if another call to the *same* + extension is running the initialization in another thread. + It is reentrant, so that a recursive call will not block, but + only one from a different thread. + + * then we grab the GIL and (Python 2) we call Py_InitializeEx(). + At this point, concurrent calls to Py_InitializeEx() are not + possible: we have the GIL. + + * do the rest of the specific initialization, which may + temporarily release the GIL but not the custom lock. + Only release the custom lock when we are done. + */ + static char called = 0; + + if (_cffi_carefully_make_gil() != 0) + return NULL; + + _cffi_acquire_reentrant_mutex(); + + /* Here the GIL exists, but we don't have it. We're only protected + from concurrency by the reentrant mutex. */ + + /* This file only initializes the embedded module once, the first + time this is called, even if there are subinterpreters. */ + if (!called) { + called = 1; /* invoke _cffi_initialize_python() only once, + but don't set '_cffi_call_python' right now, + otherwise concurrent threads won't call + this function at all (we need them to wait) */ + if (_cffi_initialize_python() == 0) { + /* now initialization is finished. Switch to the fast-path. */ + + /* We would like nobody to see the new value of + '_cffi_call_python' without also seeing the rest of the + data initialized. However, this is not possible. But + the new value of '_cffi_call_python' is the function + 'cffi_call_python()' from _cffi_backend. So: */ + cffi_write_barrier(); + /* ^^^ we put a write barrier here, and a corresponding + read barrier at the start of cffi_call_python(). This + ensures that after that read barrier, we see everything + done here before the write barrier. + */ + + assert(_cffi_call_python_org != NULL); + _cffi_call_python = (_cffi_call_python_fnptr)_cffi_call_python_org; + } + else { + /* initialization failed. Reset this to NULL, even if it was + already set to some other value. Future calls to + _cffi_start_python() are still forced to occur, and will + always return NULL from now on. */ + _cffi_call_python_org = NULL; + } + } + + _cffi_release_reentrant_mutex(); + + return (_cffi_call_python_fnptr)_cffi_call_python_org; +} + +static +void _cffi_start_and_call_python(struct _cffi_externpy_s *externpy, char *args) +{ + _cffi_call_python_fnptr fnptr; + int current_err = errno; +#ifdef _MSC_VER + int current_lasterr = GetLastError(); +#endif + fnptr = _cffi_start_python(); + if (fnptr == NULL) { + fprintf(stderr, "function %s() called, but initialization code " + "failed. Returning 0.\n", externpy->name); + memset(args, 0, externpy->size_of_result); + } +#ifdef _MSC_VER + SetLastError(current_lasterr); +#endif + errno = current_err; + + if (fnptr != NULL) + fnptr(externpy, args); +} + + +/* The cffi_start_python() function makes sure Python is initialized + and our cffi module is set up. It can be called manually from the + user C code. The same effect is obtained automatically from any + dll-exported ``extern "Python"`` function. This function returns + -1 if initialization failed, 0 if all is OK. */ +_CFFI_UNUSED_FN +static int cffi_start_python(void) +{ + if (_cffi_call_python == &_cffi_start_and_call_python) { + if (_cffi_start_python() == NULL) + return -1; + } + cffi_read_barrier(); + return 0; +} + +#undef cffi_compare_and_swap +#undef cffi_write_barrier +#undef cffi_read_barrier + +#ifdef __cplusplus +} +#endif diff --git a/.venv/Lib/site-packages/cffi/_imp_emulation.py b/.venv/Lib/site-packages/cffi/_imp_emulation.py new file mode 100644 index 00000000..136abddd --- /dev/null +++ b/.venv/Lib/site-packages/cffi/_imp_emulation.py @@ -0,0 +1,83 @@ + +try: + # this works on Python < 3.12 + from imp import * + +except ImportError: + # this is a limited emulation for Python >= 3.12. + # Note that this is used only for tests or for the old ffi.verify(). + # This is copied from the source code of Python 3.11. + + from _imp import (acquire_lock, release_lock, + is_builtin, is_frozen) + + from importlib._bootstrap import _load + + from importlib import machinery + import os + import sys + import tokenize + + SEARCH_ERROR = 0 + PY_SOURCE = 1 + PY_COMPILED = 2 + C_EXTENSION = 3 + PY_RESOURCE = 4 + PKG_DIRECTORY = 5 + C_BUILTIN = 6 + PY_FROZEN = 7 + PY_CODERESOURCE = 8 + IMP_HOOK = 9 + + def get_suffixes(): + extensions = [(s, 'rb', C_EXTENSION) + for s in machinery.EXTENSION_SUFFIXES] + source = [(s, 'r', PY_SOURCE) for s in machinery.SOURCE_SUFFIXES] + bytecode = [(s, 'rb', PY_COMPILED) for s in machinery.BYTECODE_SUFFIXES] + return extensions + source + bytecode + + def find_module(name, path=None): + if not isinstance(name, str): + raise TypeError("'name' must be a str, not {}".format(type(name))) + elif not isinstance(path, (type(None), list)): + # Backwards-compatibility + raise RuntimeError("'path' must be None or a list, " + "not {}".format(type(path))) + + if path is None: + if is_builtin(name): + return None, None, ('', '', C_BUILTIN) + elif is_frozen(name): + return None, None, ('', '', PY_FROZEN) + else: + path = sys.path + + for entry in path: + package_directory = os.path.join(entry, name) + for suffix in ['.py', machinery.BYTECODE_SUFFIXES[0]]: + package_file_name = '__init__' + suffix + file_path = os.path.join(package_directory, package_file_name) + if os.path.isfile(file_path): + return None, package_directory, ('', '', PKG_DIRECTORY) + for suffix, mode, type_ in get_suffixes(): + file_name = name + suffix + file_path = os.path.join(entry, file_name) + if os.path.isfile(file_path): + break + else: + continue + break # Break out of outer loop when breaking out of inner loop. + else: + raise ImportError(name, name=name) + + encoding = None + if 'b' not in mode: + with open(file_path, 'rb') as file: + encoding = tokenize.detect_encoding(file.readline)[0] + file = open(file_path, mode, encoding=encoding) + return file, file_path, (suffix, mode, type_) + + def load_dynamic(name, path, file=None): + loader = machinery.ExtensionFileLoader(name, path) + spec = machinery.ModuleSpec(name=name, loader=loader, origin=path) + return _load(spec) diff --git a/.venv/Lib/site-packages/cffi/_shimmed_dist_utils.py b/.venv/Lib/site-packages/cffi/_shimmed_dist_utils.py new file mode 100644 index 00000000..611bf40f --- /dev/null +++ b/.venv/Lib/site-packages/cffi/_shimmed_dist_utils.py @@ -0,0 +1,41 @@ +""" +Temporary shim module to indirect the bits of distutils we need from setuptools/distutils while providing useful +error messages beyond `No module named 'distutils' on Python >= 3.12, or when setuptools' vendored distutils is broken. + +This is a compromise to avoid a hard-dep on setuptools for Python >= 3.12, since many users don't need runtime compilation support from CFFI. +""" +import sys + +try: + # import setuptools first; this is the most robust way to ensure its embedded distutils is available + # (the .pth shim should usually work, but this is even more robust) + import setuptools +except Exception as ex: + if sys.version_info >= (3, 12): + # Python 3.12 has no built-in distutils to fall back on, so any import problem is fatal + raise Exception("This CFFI feature requires setuptools on Python >= 3.12. The setuptools module is missing or non-functional.") from ex + + # silently ignore on older Pythons (support fallback to stdlib distutils where available) +else: + del setuptools + +try: + # bring in just the bits of distutils we need, whether they really came from setuptools or stdlib-embedded distutils + from distutils import log, sysconfig + from distutils.ccompiler import CCompiler + from distutils.command.build_ext import build_ext + from distutils.core import Distribution, Extension + from distutils.dir_util import mkpath + from distutils.errors import DistutilsSetupError, CompileError, LinkError + from distutils.log import set_threshold, set_verbosity + + if sys.platform == 'win32': + from distutils.msvc9compiler import MSVCCompiler +except Exception as ex: + if sys.version_info >= (3, 12): + raise Exception("This CFFI feature requires setuptools on Python >= 3.12. Please install the setuptools package.") from ex + + # anything older, just let the underlying distutils import error fly + raise Exception("This CFFI feature requires distutils. Please install the distutils or setuptools package.") from ex + +del sys diff --git a/.venv/Lib/site-packages/cffi/api.py b/.venv/Lib/site-packages/cffi/api.py new file mode 100644 index 00000000..edeb7928 --- /dev/null +++ b/.venv/Lib/site-packages/cffi/api.py @@ -0,0 +1,965 @@ +import sys, types +from .lock import allocate_lock +from .error import CDefError +from . import model + +try: + callable +except NameError: + # Python 3.1 + from collections import Callable + callable = lambda x: isinstance(x, Callable) + +try: + basestring +except NameError: + # Python 3.x + basestring = str + +_unspecified = object() + + + +class FFI(object): + r''' + The main top-level class that you instantiate once, or once per module. + + Example usage: + + ffi = FFI() + ffi.cdef(""" + int printf(const char *, ...); + """) + + C = ffi.dlopen(None) # standard library + -or- + C = ffi.verify() # use a C compiler: verify the decl above is right + + C.printf("hello, %s!\n", ffi.new("char[]", "world")) + ''' + + def __init__(self, backend=None): + """Create an FFI instance. The 'backend' argument is used to + select a non-default backend, mostly for tests. + """ + if backend is None: + # You need PyPy (>= 2.0 beta), or a CPython (>= 2.6) with + # _cffi_backend.so compiled. + import _cffi_backend as backend + from . import __version__ + if backend.__version__ != __version__: + # bad version! Try to be as explicit as possible. + if hasattr(backend, '__file__'): + # CPython + raise Exception("Version mismatch: this is the 'cffi' package version %s, located in %r. When we import the top-level '_cffi_backend' extension module, we get version %s, located in %r. The two versions should be equal; check your installation." % ( + __version__, __file__, + backend.__version__, backend.__file__)) + else: + # PyPy + raise Exception("Version mismatch: this is the 'cffi' package version %s, located in %r. This interpreter comes with a built-in '_cffi_backend' module, which is version %s. The two versions should be equal; check your installation." % ( + __version__, __file__, backend.__version__)) + # (If you insist you can also try to pass the option + # 'backend=backend_ctypes.CTypesBackend()', but don't + # rely on it! It's probably not going to work well.) + + from . import cparser + self._backend = backend + self._lock = allocate_lock() + self._parser = cparser.Parser() + self._cached_btypes = {} + self._parsed_types = types.ModuleType('parsed_types').__dict__ + self._new_types = types.ModuleType('new_types').__dict__ + self._function_caches = [] + self._libraries = [] + self._cdefsources = [] + self._included_ffis = [] + self._windows_unicode = None + self._init_once_cache = {} + self._cdef_version = None + self._embedding = None + self._typecache = model.get_typecache(backend) + if hasattr(backend, 'set_ffi'): + backend.set_ffi(self) + for name in list(backend.__dict__): + if name.startswith('RTLD_'): + setattr(self, name, getattr(backend, name)) + # + with self._lock: + self.BVoidP = self._get_cached_btype(model.voidp_type) + self.BCharA = self._get_cached_btype(model.char_array_type) + if isinstance(backend, types.ModuleType): + # _cffi_backend: attach these constants to the class + if not hasattr(FFI, 'NULL'): + FFI.NULL = self.cast(self.BVoidP, 0) + FFI.CData, FFI.CType = backend._get_types() + else: + # ctypes backend: attach these constants to the instance + self.NULL = self.cast(self.BVoidP, 0) + self.CData, self.CType = backend._get_types() + self.buffer = backend.buffer + + def cdef(self, csource, override=False, packed=False, pack=None): + """Parse the given C source. This registers all declared functions, + types, and global variables. The functions and global variables can + then be accessed via either 'ffi.dlopen()' or 'ffi.verify()'. + The types can be used in 'ffi.new()' and other functions. + If 'packed' is specified as True, all structs declared inside this + cdef are packed, i.e. laid out without any field alignment at all. + Alternatively, 'pack' can be a small integer, and requests for + alignment greater than that are ignored (pack=1 is equivalent to + packed=True). + """ + self._cdef(csource, override=override, packed=packed, pack=pack) + + def embedding_api(self, csource, packed=False, pack=None): + self._cdef(csource, packed=packed, pack=pack, dllexport=True) + if self._embedding is None: + self._embedding = '' + + def _cdef(self, csource, override=False, **options): + if not isinstance(csource, str): # unicode, on Python 2 + if not isinstance(csource, basestring): + raise TypeError("cdef() argument must be a string") + csource = csource.encode('ascii') + with self._lock: + self._cdef_version = object() + self._parser.parse(csource, override=override, **options) + self._cdefsources.append(csource) + if override: + for cache in self._function_caches: + cache.clear() + finishlist = self._parser._recomplete + if finishlist: + self._parser._recomplete = [] + for tp in finishlist: + tp.finish_backend_type(self, finishlist) + + def dlopen(self, name, flags=0): + """Load and return a dynamic library identified by 'name'. + The standard C library can be loaded by passing None. + Note that functions and types declared by 'ffi.cdef()' are not + linked to a particular library, just like C headers; in the + library we only look for the actual (untyped) symbols. + """ + if not (isinstance(name, basestring) or + name is None or + isinstance(name, self.CData)): + raise TypeError("dlopen(name): name must be a file name, None, " + "or an already-opened 'void *' handle") + with self._lock: + lib, function_cache = _make_ffi_library(self, name, flags) + self._function_caches.append(function_cache) + self._libraries.append(lib) + return lib + + def dlclose(self, lib): + """Close a library obtained with ffi.dlopen(). After this call, + access to functions or variables from the library will fail + (possibly with a segmentation fault). + """ + type(lib).__cffi_close__(lib) + + def _typeof_locked(self, cdecl): + # call me with the lock! + key = cdecl + if key in self._parsed_types: + return self._parsed_types[key] + # + if not isinstance(cdecl, str): # unicode, on Python 2 + cdecl = cdecl.encode('ascii') + # + type = self._parser.parse_type(cdecl) + really_a_function_type = type.is_raw_function + if really_a_function_type: + type = type.as_function_pointer() + btype = self._get_cached_btype(type) + result = btype, really_a_function_type + self._parsed_types[key] = result + return result + + def _typeof(self, cdecl, consider_function_as_funcptr=False): + # string -> ctype object + try: + result = self._parsed_types[cdecl] + except KeyError: + with self._lock: + result = self._typeof_locked(cdecl) + # + btype, really_a_function_type = result + if really_a_function_type and not consider_function_as_funcptr: + raise CDefError("the type %r is a function type, not a " + "pointer-to-function type" % (cdecl,)) + return btype + + def typeof(self, cdecl): + """Parse the C type given as a string and return the + corresponding object. + It can also be used on 'cdata' instance to get its C type. + """ + if isinstance(cdecl, basestring): + return self._typeof(cdecl) + if isinstance(cdecl, self.CData): + return self._backend.typeof(cdecl) + if isinstance(cdecl, types.BuiltinFunctionType): + res = _builtin_function_type(cdecl) + if res is not None: + return res + if (isinstance(cdecl, types.FunctionType) + and hasattr(cdecl, '_cffi_base_type')): + with self._lock: + return self._get_cached_btype(cdecl._cffi_base_type) + raise TypeError(type(cdecl)) + + def sizeof(self, cdecl): + """Return the size in bytes of the argument. It can be a + string naming a C type, or a 'cdata' instance. + """ + if isinstance(cdecl, basestring): + BType = self._typeof(cdecl) + return self._backend.sizeof(BType) + else: + return self._backend.sizeof(cdecl) + + def alignof(self, cdecl): + """Return the natural alignment size in bytes of the C type + given as a string. + """ + if isinstance(cdecl, basestring): + cdecl = self._typeof(cdecl) + return self._backend.alignof(cdecl) + + def offsetof(self, cdecl, *fields_or_indexes): + """Return the offset of the named field inside the given + structure or array, which must be given as a C type name. + You can give several field names in case of nested structures. + You can also give numeric values which correspond to array + items, in case of an array type. + """ + if isinstance(cdecl, basestring): + cdecl = self._typeof(cdecl) + return self._typeoffsetof(cdecl, *fields_or_indexes)[1] + + def new(self, cdecl, init=None): + """Allocate an instance according to the specified C type and + return a pointer to it. The specified C type must be either a + pointer or an array: ``new('X *')`` allocates an X and returns + a pointer to it, whereas ``new('X[n]')`` allocates an array of + n X'es and returns an array referencing it (which works + mostly like a pointer, like in C). You can also use + ``new('X[]', n)`` to allocate an array of a non-constant + length n. + + The memory is initialized following the rules of declaring a + global variable in C: by default it is zero-initialized, but + an explicit initializer can be given which can be used to + fill all or part of the memory. + + When the returned object goes out of scope, the memory + is freed. In other words the returned object has + ownership of the value of type 'cdecl' that it points to. This + means that the raw data can be used as long as this object is + kept alive, but must not be used for a longer time. Be careful + about that when copying the pointer to the memory somewhere + else, e.g. into another structure. + """ + if isinstance(cdecl, basestring): + cdecl = self._typeof(cdecl) + return self._backend.newp(cdecl, init) + + def new_allocator(self, alloc=None, free=None, + should_clear_after_alloc=True): + """Return a new allocator, i.e. a function that behaves like ffi.new() + but uses the provided low-level 'alloc' and 'free' functions. + + 'alloc' is called with the size as argument. If it returns NULL, a + MemoryError is raised. 'free' is called with the result of 'alloc' + as argument. Both can be either Python function or directly C + functions. If 'free' is None, then no free function is called. + If both 'alloc' and 'free' are None, the default is used. + + If 'should_clear_after_alloc' is set to False, then the memory + returned by 'alloc' is assumed to be already cleared (or you are + fine with garbage); otherwise CFFI will clear it. + """ + compiled_ffi = self._backend.FFI() + allocator = compiled_ffi.new_allocator(alloc, free, + should_clear_after_alloc) + def allocate(cdecl, init=None): + if isinstance(cdecl, basestring): + cdecl = self._typeof(cdecl) + return allocator(cdecl, init) + return allocate + + def cast(self, cdecl, source): + """Similar to a C cast: returns an instance of the named C + type initialized with the given 'source'. The source is + casted between integers or pointers of any type. + """ + if isinstance(cdecl, basestring): + cdecl = self._typeof(cdecl) + return self._backend.cast(cdecl, source) + + def string(self, cdata, maxlen=-1): + """Return a Python string (or unicode string) from the 'cdata'. + If 'cdata' is a pointer or array of characters or bytes, returns + the null-terminated string. The returned string extends until + the first null character, or at most 'maxlen' characters. If + 'cdata' is an array then 'maxlen' defaults to its length. + + If 'cdata' is a pointer or array of wchar_t, returns a unicode + string following the same rules. + + If 'cdata' is a single character or byte or a wchar_t, returns + it as a string or unicode string. + + If 'cdata' is an enum, returns the value of the enumerator as a + string, or 'NUMBER' if the value is out of range. + """ + return self._backend.string(cdata, maxlen) + + def unpack(self, cdata, length): + """Unpack an array of C data of the given length, + returning a Python string/unicode/list. + + If 'cdata' is a pointer to 'char', returns a byte string. + It does not stop at the first null. This is equivalent to: + ffi.buffer(cdata, length)[:] + + If 'cdata' is a pointer to 'wchar_t', returns a unicode string. + 'length' is measured in wchar_t's; it is not the size in bytes. + + If 'cdata' is a pointer to anything else, returns a list of + 'length' items. This is a faster equivalent to: + [cdata[i] for i in range(length)] + """ + return self._backend.unpack(cdata, length) + + #def buffer(self, cdata, size=-1): + # """Return a read-write buffer object that references the raw C data + # pointed to by the given 'cdata'. The 'cdata' must be a pointer or + # an array. Can be passed to functions expecting a buffer, or directly + # manipulated with: + # + # buf[:] get a copy of it in a regular string, or + # buf[idx] as a single character + # buf[:] = ... + # buf[idx] = ... change the content + # """ + # note that 'buffer' is a type, set on this instance by __init__ + + def from_buffer(self, cdecl, python_buffer=_unspecified, + require_writable=False): + """Return a cdata of the given type pointing to the data of the + given Python object, which must support the buffer interface. + Note that this is not meant to be used on the built-in types + str or unicode (you can build 'char[]' arrays explicitly) + but only on objects containing large quantities of raw data + in some other format, like 'array.array' or numpy arrays. + + The first argument is optional and default to 'char[]'. + """ + if python_buffer is _unspecified: + cdecl, python_buffer = self.BCharA, cdecl + elif isinstance(cdecl, basestring): + cdecl = self._typeof(cdecl) + return self._backend.from_buffer(cdecl, python_buffer, + require_writable) + + def memmove(self, dest, src, n): + """ffi.memmove(dest, src, n) copies n bytes of memory from src to dest. + + Like the C function memmove(), the memory areas may overlap; + apart from that it behaves like the C function memcpy(). + + 'src' can be any cdata ptr or array, or any Python buffer object. + 'dest' can be any cdata ptr or array, or a writable Python buffer + object. The size to copy, 'n', is always measured in bytes. + + Unlike other methods, this one supports all Python buffer including + byte strings and bytearrays---but it still does not support + non-contiguous buffers. + """ + return self._backend.memmove(dest, src, n) + + def callback(self, cdecl, python_callable=None, error=None, onerror=None): + """Return a callback object or a decorator making such a + callback object. 'cdecl' must name a C function pointer type. + The callback invokes the specified 'python_callable' (which may + be provided either directly or via a decorator). Important: the + callback object must be manually kept alive for as long as the + callback may be invoked from the C level. + """ + def callback_decorator_wrap(python_callable): + if not callable(python_callable): + raise TypeError("the 'python_callable' argument " + "is not callable") + return self._backend.callback(cdecl, python_callable, + error, onerror) + if isinstance(cdecl, basestring): + cdecl = self._typeof(cdecl, consider_function_as_funcptr=True) + if python_callable is None: + return callback_decorator_wrap # decorator mode + else: + return callback_decorator_wrap(python_callable) # direct mode + + def getctype(self, cdecl, replace_with=''): + """Return a string giving the C type 'cdecl', which may be itself + a string or a object. If 'replace_with' is given, it gives + extra text to append (or insert for more complicated C types), like + a variable name, or '*' to get actually the C type 'pointer-to-cdecl'. + """ + if isinstance(cdecl, basestring): + cdecl = self._typeof(cdecl) + replace_with = replace_with.strip() + if (replace_with.startswith('*') + and '&[' in self._backend.getcname(cdecl, '&')): + replace_with = '(%s)' % replace_with + elif replace_with and not replace_with[0] in '[(': + replace_with = ' ' + replace_with + return self._backend.getcname(cdecl, replace_with) + + def gc(self, cdata, destructor, size=0): + """Return a new cdata object that points to the same + data. Later, when this new cdata object is garbage-collected, + 'destructor(old_cdata_object)' will be called. + + The optional 'size' gives an estimate of the size, used to + trigger the garbage collection more eagerly. So far only used + on PyPy. It tells the GC that the returned object keeps alive + roughly 'size' bytes of external memory. + """ + return self._backend.gcp(cdata, destructor, size) + + def _get_cached_btype(self, type): + assert self._lock.acquire(False) is False + # call me with the lock! + try: + BType = self._cached_btypes[type] + except KeyError: + finishlist = [] + BType = type.get_cached_btype(self, finishlist) + for type in finishlist: + type.finish_backend_type(self, finishlist) + return BType + + def verify(self, source='', tmpdir=None, **kwargs): + """Verify that the current ffi signatures compile on this + machine, and return a dynamic library object. The dynamic + library can be used to call functions and access global + variables declared in this 'ffi'. The library is compiled + by the C compiler: it gives you C-level API compatibility + (including calling macros). This is unlike 'ffi.dlopen()', + which requires binary compatibility in the signatures. + """ + from .verifier import Verifier, _caller_dir_pycache + # + # If set_unicode(True) was called, insert the UNICODE and + # _UNICODE macro declarations + if self._windows_unicode: + self._apply_windows_unicode(kwargs) + # + # Set the tmpdir here, and not in Verifier.__init__: it picks + # up the caller's directory, which we want to be the caller of + # ffi.verify(), as opposed to the caller of Veritier(). + tmpdir = tmpdir or _caller_dir_pycache() + # + # Make a Verifier() and use it to load the library. + self.verifier = Verifier(self, source, tmpdir, **kwargs) + lib = self.verifier.load_library() + # + # Save the loaded library for keep-alive purposes, even + # if the caller doesn't keep it alive itself (it should). + self._libraries.append(lib) + return lib + + def _get_errno(self): + return self._backend.get_errno() + def _set_errno(self, errno): + self._backend.set_errno(errno) + errno = property(_get_errno, _set_errno, None, + "the value of 'errno' from/to the C calls") + + def getwinerror(self, code=-1): + return self._backend.getwinerror(code) + + def _pointer_to(self, ctype): + with self._lock: + return model.pointer_cache(self, ctype) + + def addressof(self, cdata, *fields_or_indexes): + """Return the address of a . + If 'fields_or_indexes' are given, returns the address of that + field or array item in the structure or array, recursively in + case of nested structures. + """ + try: + ctype = self._backend.typeof(cdata) + except TypeError: + if '__addressof__' in type(cdata).__dict__: + return type(cdata).__addressof__(cdata, *fields_or_indexes) + raise + if fields_or_indexes: + ctype, offset = self._typeoffsetof(ctype, *fields_or_indexes) + else: + if ctype.kind == "pointer": + raise TypeError("addressof(pointer)") + offset = 0 + ctypeptr = self._pointer_to(ctype) + return self._backend.rawaddressof(ctypeptr, cdata, offset) + + def _typeoffsetof(self, ctype, field_or_index, *fields_or_indexes): + ctype, offset = self._backend.typeoffsetof(ctype, field_or_index) + for field1 in fields_or_indexes: + ctype, offset1 = self._backend.typeoffsetof(ctype, field1, 1) + offset += offset1 + return ctype, offset + + def include(self, ffi_to_include): + """Includes the typedefs, structs, unions and enums defined + in another FFI instance. Usage is similar to a #include in C, + where a part of the program might include types defined in + another part for its own usage. Note that the include() + method has no effect on functions, constants and global + variables, which must anyway be accessed directly from the + lib object returned by the original FFI instance. + """ + if not isinstance(ffi_to_include, FFI): + raise TypeError("ffi.include() expects an argument that is also of" + " type cffi.FFI, not %r" % ( + type(ffi_to_include).__name__,)) + if ffi_to_include is self: + raise ValueError("self.include(self)") + with ffi_to_include._lock: + with self._lock: + self._parser.include(ffi_to_include._parser) + self._cdefsources.append('[') + self._cdefsources.extend(ffi_to_include._cdefsources) + self._cdefsources.append(']') + self._included_ffis.append(ffi_to_include) + + def new_handle(self, x): + return self._backend.newp_handle(self.BVoidP, x) + + def from_handle(self, x): + return self._backend.from_handle(x) + + def release(self, x): + self._backend.release(x) + + def set_unicode(self, enabled_flag): + """Windows: if 'enabled_flag' is True, enable the UNICODE and + _UNICODE defines in C, and declare the types like TCHAR and LPTCSTR + to be (pointers to) wchar_t. If 'enabled_flag' is False, + declare these types to be (pointers to) plain 8-bit characters. + This is mostly for backward compatibility; you usually want True. + """ + if self._windows_unicode is not None: + raise ValueError("set_unicode() can only be called once") + enabled_flag = bool(enabled_flag) + if enabled_flag: + self.cdef("typedef wchar_t TBYTE;" + "typedef wchar_t TCHAR;" + "typedef const wchar_t *LPCTSTR;" + "typedef const wchar_t *PCTSTR;" + "typedef wchar_t *LPTSTR;" + "typedef wchar_t *PTSTR;" + "typedef TBYTE *PTBYTE;" + "typedef TCHAR *PTCHAR;") + else: + self.cdef("typedef char TBYTE;" + "typedef char TCHAR;" + "typedef const char *LPCTSTR;" + "typedef const char *PCTSTR;" + "typedef char *LPTSTR;" + "typedef char *PTSTR;" + "typedef TBYTE *PTBYTE;" + "typedef TCHAR *PTCHAR;") + self._windows_unicode = enabled_flag + + def _apply_windows_unicode(self, kwds): + defmacros = kwds.get('define_macros', ()) + if not isinstance(defmacros, (list, tuple)): + raise TypeError("'define_macros' must be a list or tuple") + defmacros = list(defmacros) + [('UNICODE', '1'), + ('_UNICODE', '1')] + kwds['define_macros'] = defmacros + + def _apply_embedding_fix(self, kwds): + # must include an argument like "-lpython2.7" for the compiler + def ensure(key, value): + lst = kwds.setdefault(key, []) + if value not in lst: + lst.append(value) + # + if '__pypy__' in sys.builtin_module_names: + import os + if sys.platform == "win32": + # we need 'libpypy-c.lib'. Current distributions of + # pypy (>= 4.1) contain it as 'libs/python27.lib'. + pythonlib = "python{0[0]}{0[1]}".format(sys.version_info) + if hasattr(sys, 'prefix'): + ensure('library_dirs', os.path.join(sys.prefix, 'libs')) + else: + # we need 'libpypy-c.{so,dylib}', which should be by + # default located in 'sys.prefix/bin' for installed + # systems. + if sys.version_info < (3,): + pythonlib = "pypy-c" + else: + pythonlib = "pypy3-c" + if hasattr(sys, 'prefix'): + ensure('library_dirs', os.path.join(sys.prefix, 'bin')) + # On uninstalled pypy's, the libpypy-c is typically found in + # .../pypy/goal/. + if hasattr(sys, 'prefix'): + ensure('library_dirs', os.path.join(sys.prefix, 'pypy', 'goal')) + else: + if sys.platform == "win32": + template = "python%d%d" + if hasattr(sys, 'gettotalrefcount'): + template += '_d' + else: + try: + import sysconfig + except ImportError: # 2.6 + from cffi._shimmed_dist_utils import sysconfig + template = "python%d.%d" + if sysconfig.get_config_var('DEBUG_EXT'): + template += sysconfig.get_config_var('DEBUG_EXT') + pythonlib = (template % + (sys.hexversion >> 24, (sys.hexversion >> 16) & 0xff)) + if hasattr(sys, 'abiflags'): + pythonlib += sys.abiflags + ensure('libraries', pythonlib) + if sys.platform == "win32": + ensure('extra_link_args', '/MANIFEST') + + def set_source(self, module_name, source, source_extension='.c', **kwds): + import os + if hasattr(self, '_assigned_source'): + raise ValueError("set_source() cannot be called several times " + "per ffi object") + if not isinstance(module_name, basestring): + raise TypeError("'module_name' must be a string") + if os.sep in module_name or (os.altsep and os.altsep in module_name): + raise ValueError("'module_name' must not contain '/': use a dotted " + "name to make a 'package.module' location") + self._assigned_source = (str(module_name), source, + source_extension, kwds) + + def set_source_pkgconfig(self, module_name, pkgconfig_libs, source, + source_extension='.c', **kwds): + from . import pkgconfig + if not isinstance(pkgconfig_libs, list): + raise TypeError("the pkgconfig_libs argument must be a list " + "of package names") + kwds2 = pkgconfig.flags_from_pkgconfig(pkgconfig_libs) + pkgconfig.merge_flags(kwds, kwds2) + self.set_source(module_name, source, source_extension, **kwds) + + def distutils_extension(self, tmpdir='build', verbose=True): + from cffi._shimmed_dist_utils import mkpath + from .recompiler import recompile + # + if not hasattr(self, '_assigned_source'): + if hasattr(self, 'verifier'): # fallback, 'tmpdir' ignored + return self.verifier.get_extension() + raise ValueError("set_source() must be called before" + " distutils_extension()") + module_name, source, source_extension, kwds = self._assigned_source + if source is None: + raise TypeError("distutils_extension() is only for C extension " + "modules, not for dlopen()-style pure Python " + "modules") + mkpath(tmpdir) + ext, updated = recompile(self, module_name, + source, tmpdir=tmpdir, extradir=tmpdir, + source_extension=source_extension, + call_c_compiler=False, **kwds) + if verbose: + if updated: + sys.stderr.write("regenerated: %r\n" % (ext.sources[0],)) + else: + sys.stderr.write("not modified: %r\n" % (ext.sources[0],)) + return ext + + def emit_c_code(self, filename): + from .recompiler import recompile + # + if not hasattr(self, '_assigned_source'): + raise ValueError("set_source() must be called before emit_c_code()") + module_name, source, source_extension, kwds = self._assigned_source + if source is None: + raise TypeError("emit_c_code() is only for C extension modules, " + "not for dlopen()-style pure Python modules") + recompile(self, module_name, source, + c_file=filename, call_c_compiler=False, **kwds) + + def emit_python_code(self, filename): + from .recompiler import recompile + # + if not hasattr(self, '_assigned_source'): + raise ValueError("set_source() must be called before emit_c_code()") + module_name, source, source_extension, kwds = self._assigned_source + if source is not None: + raise TypeError("emit_python_code() is only for dlopen()-style " + "pure Python modules, not for C extension modules") + recompile(self, module_name, source, + c_file=filename, call_c_compiler=False, **kwds) + + def compile(self, tmpdir='.', verbose=0, target=None, debug=None): + """The 'target' argument gives the final file name of the + compiled DLL. Use '*' to force distutils' choice, suitable for + regular CPython C API modules. Use a file name ending in '.*' + to ask for the system's default extension for dynamic libraries + (.so/.dll/.dylib). + + The default is '*' when building a non-embedded C API extension, + and (module_name + '.*') when building an embedded library. + """ + from .recompiler import recompile + # + if not hasattr(self, '_assigned_source'): + raise ValueError("set_source() must be called before compile()") + module_name, source, source_extension, kwds = self._assigned_source + return recompile(self, module_name, source, tmpdir=tmpdir, + target=target, source_extension=source_extension, + compiler_verbose=verbose, debug=debug, **kwds) + + def init_once(self, func, tag): + # Read _init_once_cache[tag], which is either (False, lock) if + # we're calling the function now in some thread, or (True, result). + # Don't call setdefault() in most cases, to avoid allocating and + # immediately freeing a lock; but still use setdefaut() to avoid + # races. + try: + x = self._init_once_cache[tag] + except KeyError: + x = self._init_once_cache.setdefault(tag, (False, allocate_lock())) + # Common case: we got (True, result), so we return the result. + if x[0]: + return x[1] + # Else, it's a lock. Acquire it to serialize the following tests. + with x[1]: + # Read again from _init_once_cache the current status. + x = self._init_once_cache[tag] + if x[0]: + return x[1] + # Call the function and store the result back. + result = func() + self._init_once_cache[tag] = (True, result) + return result + + def embedding_init_code(self, pysource): + if self._embedding: + raise ValueError("embedding_init_code() can only be called once") + # fix 'pysource' before it gets dumped into the C file: + # - remove empty lines at the beginning, so it starts at "line 1" + # - dedent, if all non-empty lines are indented + # - check for SyntaxErrors + import re + match = re.match(r'\s*\n', pysource) + if match: + pysource = pysource[match.end():] + lines = pysource.splitlines() or [''] + prefix = re.match(r'\s*', lines[0]).group() + for i in range(1, len(lines)): + line = lines[i] + if line.rstrip(): + while not line.startswith(prefix): + prefix = prefix[:-1] + i = len(prefix) + lines = [line[i:]+'\n' for line in lines] + pysource = ''.join(lines) + # + compile(pysource, "cffi_init", "exec") + # + self._embedding = pysource + + def def_extern(self, *args, **kwds): + raise ValueError("ffi.def_extern() is only available on API-mode FFI " + "objects") + + def list_types(self): + """Returns the user type names known to this FFI instance. + This returns a tuple containing three lists of names: + (typedef_names, names_of_structs, names_of_unions) + """ + typedefs = [] + structs = [] + unions = [] + for key in self._parser._declarations: + if key.startswith('typedef '): + typedefs.append(key[8:]) + elif key.startswith('struct '): + structs.append(key[7:]) + elif key.startswith('union '): + unions.append(key[6:]) + typedefs.sort() + structs.sort() + unions.sort() + return (typedefs, structs, unions) + + +def _load_backend_lib(backend, name, flags): + import os + if not isinstance(name, basestring): + if sys.platform != "win32" or name is not None: + return backend.load_library(name, flags) + name = "c" # Windows: load_library(None) fails, but this works + # on Python 2 (backward compatibility hack only) + first_error = None + if '.' in name or '/' in name or os.sep in name: + try: + return backend.load_library(name, flags) + except OSError as e: + first_error = e + import ctypes.util + path = ctypes.util.find_library(name) + if path is None: + if name == "c" and sys.platform == "win32" and sys.version_info >= (3,): + raise OSError("dlopen(None) cannot work on Windows for Python 3 " + "(see http://bugs.python.org/issue23606)") + msg = ("ctypes.util.find_library() did not manage " + "to locate a library called %r" % (name,)) + if first_error is not None: + msg = "%s. Additionally, %s" % (first_error, msg) + raise OSError(msg) + return backend.load_library(path, flags) + +def _make_ffi_library(ffi, libname, flags): + backend = ffi._backend + backendlib = _load_backend_lib(backend, libname, flags) + # + def accessor_function(name): + key = 'function ' + name + tp, _ = ffi._parser._declarations[key] + BType = ffi._get_cached_btype(tp) + value = backendlib.load_function(BType, name) + library.__dict__[name] = value + # + def accessor_variable(name): + key = 'variable ' + name + tp, _ = ffi._parser._declarations[key] + BType = ffi._get_cached_btype(tp) + read_variable = backendlib.read_variable + write_variable = backendlib.write_variable + setattr(FFILibrary, name, property( + lambda self: read_variable(BType, name), + lambda self, value: write_variable(BType, name, value))) + # + def addressof_var(name): + try: + return addr_variables[name] + except KeyError: + with ffi._lock: + if name not in addr_variables: + key = 'variable ' + name + tp, _ = ffi._parser._declarations[key] + BType = ffi._get_cached_btype(tp) + if BType.kind != 'array': + BType = model.pointer_cache(ffi, BType) + p = backendlib.load_function(BType, name) + addr_variables[name] = p + return addr_variables[name] + # + def accessor_constant(name): + raise NotImplementedError("non-integer constant '%s' cannot be " + "accessed from a dlopen() library" % (name,)) + # + def accessor_int_constant(name): + library.__dict__[name] = ffi._parser._int_constants[name] + # + accessors = {} + accessors_version = [False] + addr_variables = {} + # + def update_accessors(): + if accessors_version[0] is ffi._cdef_version: + return + # + for key, (tp, _) in ffi._parser._declarations.items(): + if not isinstance(tp, model.EnumType): + tag, name = key.split(' ', 1) + if tag == 'function': + accessors[name] = accessor_function + elif tag == 'variable': + accessors[name] = accessor_variable + elif tag == 'constant': + accessors[name] = accessor_constant + else: + for i, enumname in enumerate(tp.enumerators): + def accessor_enum(name, tp=tp, i=i): + tp.check_not_partial() + library.__dict__[name] = tp.enumvalues[i] + accessors[enumname] = accessor_enum + for name in ffi._parser._int_constants: + accessors.setdefault(name, accessor_int_constant) + accessors_version[0] = ffi._cdef_version + # + def make_accessor(name): + with ffi._lock: + if name in library.__dict__ or name in FFILibrary.__dict__: + return # added by another thread while waiting for the lock + if name not in accessors: + update_accessors() + if name not in accessors: + raise AttributeError(name) + accessors[name](name) + # + class FFILibrary(object): + def __getattr__(self, name): + make_accessor(name) + return getattr(self, name) + def __setattr__(self, name, value): + try: + property = getattr(self.__class__, name) + except AttributeError: + make_accessor(name) + setattr(self, name, value) + else: + property.__set__(self, value) + def __dir__(self): + with ffi._lock: + update_accessors() + return accessors.keys() + def __addressof__(self, name): + if name in library.__dict__: + return library.__dict__[name] + if name in FFILibrary.__dict__: + return addressof_var(name) + make_accessor(name) + if name in library.__dict__: + return library.__dict__[name] + if name in FFILibrary.__dict__: + return addressof_var(name) + raise AttributeError("cffi library has no function or " + "global variable named '%s'" % (name,)) + def __cffi_close__(self): + backendlib.close_lib() + self.__dict__.clear() + # + if isinstance(libname, basestring): + try: + if not isinstance(libname, str): # unicode, on Python 2 + libname = libname.encode('utf-8') + FFILibrary.__name__ = 'FFILibrary_%s' % libname + except UnicodeError: + pass + library = FFILibrary() + return library, library.__dict__ + +def _builtin_function_type(func): + # a hack to make at least ffi.typeof(builtin_function) work, + # if the builtin function was obtained by 'vengine_cpy'. + import sys + try: + module = sys.modules[func.__module__] + ffi = module._cffi_original_ffi + types_of_builtin_funcs = module._cffi_types_of_builtin_funcs + tp = types_of_builtin_funcs[func] + except (KeyError, AttributeError, TypeError): + return None + else: + with ffi._lock: + return ffi._get_cached_btype(tp) diff --git a/.venv/Lib/site-packages/cffi/backend_ctypes.py b/.venv/Lib/site-packages/cffi/backend_ctypes.py new file mode 100644 index 00000000..e7956a79 --- /dev/null +++ b/.venv/Lib/site-packages/cffi/backend_ctypes.py @@ -0,0 +1,1121 @@ +import ctypes, ctypes.util, operator, sys +from . import model + +if sys.version_info < (3,): + bytechr = chr +else: + unicode = str + long = int + xrange = range + bytechr = lambda num: bytes([num]) + +class CTypesType(type): + pass + +class CTypesData(object): + __metaclass__ = CTypesType + __slots__ = ['__weakref__'] + __name__ = '' + + def __init__(self, *args): + raise TypeError("cannot instantiate %r" % (self.__class__,)) + + @classmethod + def _newp(cls, init): + raise TypeError("expected a pointer or array ctype, got '%s'" + % (cls._get_c_name(),)) + + @staticmethod + def _to_ctypes(value): + raise TypeError + + @classmethod + def _arg_to_ctypes(cls, *value): + try: + ctype = cls._ctype + except AttributeError: + raise TypeError("cannot create an instance of %r" % (cls,)) + if value: + res = cls._to_ctypes(*value) + if not isinstance(res, ctype): + res = cls._ctype(res) + else: + res = cls._ctype() + return res + + @classmethod + def _create_ctype_obj(cls, init): + if init is None: + return cls._arg_to_ctypes() + else: + return cls._arg_to_ctypes(init) + + @staticmethod + def _from_ctypes(ctypes_value): + raise TypeError + + @classmethod + def _get_c_name(cls, replace_with=''): + return cls._reftypename.replace(' &', replace_with) + + @classmethod + def _fix_class(cls): + cls.__name__ = 'CData<%s>' % (cls._get_c_name(),) + cls.__qualname__ = 'CData<%s>' % (cls._get_c_name(),) + cls.__module__ = 'ffi' + + def _get_own_repr(self): + raise NotImplementedError + + def _addr_repr(self, address): + if address == 0: + return 'NULL' + else: + if address < 0: + address += 1 << (8*ctypes.sizeof(ctypes.c_void_p)) + return '0x%x' % address + + def __repr__(self, c_name=None): + own = self._get_own_repr() + return '' % (c_name or self._get_c_name(), own) + + def _convert_to_address(self, BClass): + if BClass is None: + raise TypeError("cannot convert %r to an address" % ( + self._get_c_name(),)) + else: + raise TypeError("cannot convert %r to %r" % ( + self._get_c_name(), BClass._get_c_name())) + + @classmethod + def _get_size(cls): + return ctypes.sizeof(cls._ctype) + + def _get_size_of_instance(self): + return ctypes.sizeof(self._ctype) + + @classmethod + def _cast_from(cls, source): + raise TypeError("cannot cast to %r" % (cls._get_c_name(),)) + + def _cast_to_integer(self): + return self._convert_to_address(None) + + @classmethod + def _alignment(cls): + return ctypes.alignment(cls._ctype) + + def __iter__(self): + raise TypeError("cdata %r does not support iteration" % ( + self._get_c_name()),) + + def _make_cmp(name): + cmpfunc = getattr(operator, name) + def cmp(self, other): + v_is_ptr = not isinstance(self, CTypesGenericPrimitive) + w_is_ptr = (isinstance(other, CTypesData) and + not isinstance(other, CTypesGenericPrimitive)) + if v_is_ptr and w_is_ptr: + return cmpfunc(self._convert_to_address(None), + other._convert_to_address(None)) + elif v_is_ptr or w_is_ptr: + return NotImplemented + else: + if isinstance(self, CTypesGenericPrimitive): + self = self._value + if isinstance(other, CTypesGenericPrimitive): + other = other._value + return cmpfunc(self, other) + cmp.func_name = name + return cmp + + __eq__ = _make_cmp('__eq__') + __ne__ = _make_cmp('__ne__') + __lt__ = _make_cmp('__lt__') + __le__ = _make_cmp('__le__') + __gt__ = _make_cmp('__gt__') + __ge__ = _make_cmp('__ge__') + + def __hash__(self): + return hash(self._convert_to_address(None)) + + def _to_string(self, maxlen): + raise TypeError("string(): %r" % (self,)) + + +class CTypesGenericPrimitive(CTypesData): + __slots__ = [] + + def __hash__(self): + return hash(self._value) + + def _get_own_repr(self): + return repr(self._from_ctypes(self._value)) + + +class CTypesGenericArray(CTypesData): + __slots__ = [] + + @classmethod + def _newp(cls, init): + return cls(init) + + def __iter__(self): + for i in xrange(len(self)): + yield self[i] + + def _get_own_repr(self): + return self._addr_repr(ctypes.addressof(self._blob)) + + +class CTypesGenericPtr(CTypesData): + __slots__ = ['_address', '_as_ctype_ptr'] + _automatic_casts = False + kind = "pointer" + + @classmethod + def _newp(cls, init): + return cls(init) + + @classmethod + def _cast_from(cls, source): + if source is None: + address = 0 + elif isinstance(source, CTypesData): + address = source._cast_to_integer() + elif isinstance(source, (int, long)): + address = source + else: + raise TypeError("bad type for cast to %r: %r" % + (cls, type(source).__name__)) + return cls._new_pointer_at(address) + + @classmethod + def _new_pointer_at(cls, address): + self = cls.__new__(cls) + self._address = address + self._as_ctype_ptr = ctypes.cast(address, cls._ctype) + return self + + def _get_own_repr(self): + try: + return self._addr_repr(self._address) + except AttributeError: + return '???' + + def _cast_to_integer(self): + return self._address + + def __nonzero__(self): + return bool(self._address) + __bool__ = __nonzero__ + + @classmethod + def _to_ctypes(cls, value): + if not isinstance(value, CTypesData): + raise TypeError("unexpected %s object" % type(value).__name__) + address = value._convert_to_address(cls) + return ctypes.cast(address, cls._ctype) + + @classmethod + def _from_ctypes(cls, ctypes_ptr): + address = ctypes.cast(ctypes_ptr, ctypes.c_void_p).value or 0 + return cls._new_pointer_at(address) + + @classmethod + def _initialize(cls, ctypes_ptr, value): + if value: + ctypes_ptr.contents = cls._to_ctypes(value).contents + + def _convert_to_address(self, BClass): + if (BClass in (self.__class__, None) or BClass._automatic_casts + or self._automatic_casts): + return self._address + else: + return CTypesData._convert_to_address(self, BClass) + + +class CTypesBaseStructOrUnion(CTypesData): + __slots__ = ['_blob'] + + @classmethod + def _create_ctype_obj(cls, init): + # may be overridden + raise TypeError("cannot instantiate opaque type %s" % (cls,)) + + def _get_own_repr(self): + return self._addr_repr(ctypes.addressof(self._blob)) + + @classmethod + def _offsetof(cls, fieldname): + return getattr(cls._ctype, fieldname).offset + + def _convert_to_address(self, BClass): + if getattr(BClass, '_BItem', None) is self.__class__: + return ctypes.addressof(self._blob) + else: + return CTypesData._convert_to_address(self, BClass) + + @classmethod + def _from_ctypes(cls, ctypes_struct_or_union): + self = cls.__new__(cls) + self._blob = ctypes_struct_or_union + return self + + @classmethod + def _to_ctypes(cls, value): + return value._blob + + def __repr__(self, c_name=None): + return CTypesData.__repr__(self, c_name or self._get_c_name(' &')) + + +class CTypesBackend(object): + + PRIMITIVE_TYPES = { + 'char': ctypes.c_char, + 'short': ctypes.c_short, + 'int': ctypes.c_int, + 'long': ctypes.c_long, + 'long long': ctypes.c_longlong, + 'signed char': ctypes.c_byte, + 'unsigned char': ctypes.c_ubyte, + 'unsigned short': ctypes.c_ushort, + 'unsigned int': ctypes.c_uint, + 'unsigned long': ctypes.c_ulong, + 'unsigned long long': ctypes.c_ulonglong, + 'float': ctypes.c_float, + 'double': ctypes.c_double, + '_Bool': ctypes.c_bool, + } + + for _name in ['unsigned long long', 'unsigned long', + 'unsigned int', 'unsigned short', 'unsigned char']: + _size = ctypes.sizeof(PRIMITIVE_TYPES[_name]) + PRIMITIVE_TYPES['uint%d_t' % (8*_size)] = PRIMITIVE_TYPES[_name] + if _size == ctypes.sizeof(ctypes.c_void_p): + PRIMITIVE_TYPES['uintptr_t'] = PRIMITIVE_TYPES[_name] + if _size == ctypes.sizeof(ctypes.c_size_t): + PRIMITIVE_TYPES['size_t'] = PRIMITIVE_TYPES[_name] + + for _name in ['long long', 'long', 'int', 'short', 'signed char']: + _size = ctypes.sizeof(PRIMITIVE_TYPES[_name]) + PRIMITIVE_TYPES['int%d_t' % (8*_size)] = PRIMITIVE_TYPES[_name] + if _size == ctypes.sizeof(ctypes.c_void_p): + PRIMITIVE_TYPES['intptr_t'] = PRIMITIVE_TYPES[_name] + PRIMITIVE_TYPES['ptrdiff_t'] = PRIMITIVE_TYPES[_name] + if _size == ctypes.sizeof(ctypes.c_size_t): + PRIMITIVE_TYPES['ssize_t'] = PRIMITIVE_TYPES[_name] + + + def __init__(self): + self.RTLD_LAZY = 0 # not supported anyway by ctypes + self.RTLD_NOW = 0 + self.RTLD_GLOBAL = ctypes.RTLD_GLOBAL + self.RTLD_LOCAL = ctypes.RTLD_LOCAL + + def set_ffi(self, ffi): + self.ffi = ffi + + def _get_types(self): + return CTypesData, CTypesType + + def load_library(self, path, flags=0): + cdll = ctypes.CDLL(path, flags) + return CTypesLibrary(self, cdll) + + def new_void_type(self): + class CTypesVoid(CTypesData): + __slots__ = [] + _reftypename = 'void &' + @staticmethod + def _from_ctypes(novalue): + return None + @staticmethod + def _to_ctypes(novalue): + if novalue is not None: + raise TypeError("None expected, got %s object" % + (type(novalue).__name__,)) + return None + CTypesVoid._fix_class() + return CTypesVoid + + def new_primitive_type(self, name): + if name == 'wchar_t': + raise NotImplementedError(name) + ctype = self.PRIMITIVE_TYPES[name] + if name == 'char': + kind = 'char' + elif name in ('float', 'double'): + kind = 'float' + else: + if name in ('signed char', 'unsigned char'): + kind = 'byte' + elif name == '_Bool': + kind = 'bool' + else: + kind = 'int' + is_signed = (ctype(-1).value == -1) + # + def _cast_source_to_int(source): + if isinstance(source, (int, long, float)): + source = int(source) + elif isinstance(source, CTypesData): + source = source._cast_to_integer() + elif isinstance(source, bytes): + source = ord(source) + elif source is None: + source = 0 + else: + raise TypeError("bad type for cast to %r: %r" % + (CTypesPrimitive, type(source).__name__)) + return source + # + kind1 = kind + class CTypesPrimitive(CTypesGenericPrimitive): + __slots__ = ['_value'] + _ctype = ctype + _reftypename = '%s &' % name + kind = kind1 + + def __init__(self, value): + self._value = value + + @staticmethod + def _create_ctype_obj(init): + if init is None: + return ctype() + return ctype(CTypesPrimitive._to_ctypes(init)) + + if kind == 'int' or kind == 'byte': + @classmethod + def _cast_from(cls, source): + source = _cast_source_to_int(source) + source = ctype(source).value # cast within range + return cls(source) + def __int__(self): + return self._value + + if kind == 'bool': + @classmethod + def _cast_from(cls, source): + if not isinstance(source, (int, long, float)): + source = _cast_source_to_int(source) + return cls(bool(source)) + def __int__(self): + return int(self._value) + + if kind == 'char': + @classmethod + def _cast_from(cls, source): + source = _cast_source_to_int(source) + source = bytechr(source & 0xFF) + return cls(source) + def __int__(self): + return ord(self._value) + + if kind == 'float': + @classmethod + def _cast_from(cls, source): + if isinstance(source, float): + pass + elif isinstance(source, CTypesGenericPrimitive): + if hasattr(source, '__float__'): + source = float(source) + else: + source = int(source) + else: + source = _cast_source_to_int(source) + source = ctype(source).value # fix precision + return cls(source) + def __int__(self): + return int(self._value) + def __float__(self): + return self._value + + _cast_to_integer = __int__ + + if kind == 'int' or kind == 'byte' or kind == 'bool': + @staticmethod + def _to_ctypes(x): + if not isinstance(x, (int, long)): + if isinstance(x, CTypesData): + x = int(x) + else: + raise TypeError("integer expected, got %s" % + type(x).__name__) + if ctype(x).value != x: + if not is_signed and x < 0: + raise OverflowError("%s: negative integer" % name) + else: + raise OverflowError("%s: integer out of bounds" + % name) + return x + + if kind == 'char': + @staticmethod + def _to_ctypes(x): + if isinstance(x, bytes) and len(x) == 1: + return x + if isinstance(x, CTypesPrimitive): # > + return x._value + raise TypeError("character expected, got %s" % + type(x).__name__) + def __nonzero__(self): + return ord(self._value) != 0 + else: + def __nonzero__(self): + return self._value != 0 + __bool__ = __nonzero__ + + if kind == 'float': + @staticmethod + def _to_ctypes(x): + if not isinstance(x, (int, long, float, CTypesData)): + raise TypeError("float expected, got %s" % + type(x).__name__) + return ctype(x).value + + @staticmethod + def _from_ctypes(value): + return getattr(value, 'value', value) + + @staticmethod + def _initialize(blob, init): + blob.value = CTypesPrimitive._to_ctypes(init) + + if kind == 'char': + def _to_string(self, maxlen): + return self._value + if kind == 'byte': + def _to_string(self, maxlen): + return chr(self._value & 0xff) + # + CTypesPrimitive._fix_class() + return CTypesPrimitive + + def new_pointer_type(self, BItem): + getbtype = self.ffi._get_cached_btype + if BItem is getbtype(model.PrimitiveType('char')): + kind = 'charp' + elif BItem in (getbtype(model.PrimitiveType('signed char')), + getbtype(model.PrimitiveType('unsigned char'))): + kind = 'bytep' + elif BItem is getbtype(model.void_type): + kind = 'voidp' + else: + kind = 'generic' + # + class CTypesPtr(CTypesGenericPtr): + __slots__ = ['_own'] + if kind == 'charp': + __slots__ += ['__as_strbuf'] + _BItem = BItem + if hasattr(BItem, '_ctype'): + _ctype = ctypes.POINTER(BItem._ctype) + _bitem_size = ctypes.sizeof(BItem._ctype) + else: + _ctype = ctypes.c_void_p + if issubclass(BItem, CTypesGenericArray): + _reftypename = BItem._get_c_name('(* &)') + else: + _reftypename = BItem._get_c_name(' * &') + + def __init__(self, init): + ctypeobj = BItem._create_ctype_obj(init) + if kind == 'charp': + self.__as_strbuf = ctypes.create_string_buffer( + ctypeobj.value + b'\x00') + self._as_ctype_ptr = ctypes.cast( + self.__as_strbuf, self._ctype) + else: + self._as_ctype_ptr = ctypes.pointer(ctypeobj) + self._address = ctypes.cast(self._as_ctype_ptr, + ctypes.c_void_p).value + self._own = True + + def __add__(self, other): + if isinstance(other, (int, long)): + return self._new_pointer_at(self._address + + other * self._bitem_size) + else: + return NotImplemented + + def __sub__(self, other): + if isinstance(other, (int, long)): + return self._new_pointer_at(self._address - + other * self._bitem_size) + elif type(self) is type(other): + return (self._address - other._address) // self._bitem_size + else: + return NotImplemented + + def __getitem__(self, index): + if getattr(self, '_own', False) and index != 0: + raise IndexError + return BItem._from_ctypes(self._as_ctype_ptr[index]) + + def __setitem__(self, index, value): + self._as_ctype_ptr[index] = BItem._to_ctypes(value) + + if kind == 'charp' or kind == 'voidp': + @classmethod + def _arg_to_ctypes(cls, *value): + if value and isinstance(value[0], bytes): + return ctypes.c_char_p(value[0]) + else: + return super(CTypesPtr, cls)._arg_to_ctypes(*value) + + if kind == 'charp' or kind == 'bytep': + def _to_string(self, maxlen): + if maxlen < 0: + maxlen = sys.maxsize + p = ctypes.cast(self._as_ctype_ptr, + ctypes.POINTER(ctypes.c_char)) + n = 0 + while n < maxlen and p[n] != b'\x00': + n += 1 + return b''.join([p[i] for i in range(n)]) + + def _get_own_repr(self): + if getattr(self, '_own', False): + return 'owning %d bytes' % ( + ctypes.sizeof(self._as_ctype_ptr.contents),) + return super(CTypesPtr, self)._get_own_repr() + # + if (BItem is self.ffi._get_cached_btype(model.void_type) or + BItem is self.ffi._get_cached_btype(model.PrimitiveType('char'))): + CTypesPtr._automatic_casts = True + # + CTypesPtr._fix_class() + return CTypesPtr + + def new_array_type(self, CTypesPtr, length): + if length is None: + brackets = ' &[]' + else: + brackets = ' &[%d]' % length + BItem = CTypesPtr._BItem + getbtype = self.ffi._get_cached_btype + if BItem is getbtype(model.PrimitiveType('char')): + kind = 'char' + elif BItem in (getbtype(model.PrimitiveType('signed char')), + getbtype(model.PrimitiveType('unsigned char'))): + kind = 'byte' + else: + kind = 'generic' + # + class CTypesArray(CTypesGenericArray): + __slots__ = ['_blob', '_own'] + if length is not None: + _ctype = BItem._ctype * length + else: + __slots__.append('_ctype') + _reftypename = BItem._get_c_name(brackets) + _declared_length = length + _CTPtr = CTypesPtr + + def __init__(self, init): + if length is None: + if isinstance(init, (int, long)): + len1 = init + init = None + elif kind == 'char' and isinstance(init, bytes): + len1 = len(init) + 1 # extra null + else: + init = tuple(init) + len1 = len(init) + self._ctype = BItem._ctype * len1 + self._blob = self._ctype() + self._own = True + if init is not None: + self._initialize(self._blob, init) + + @staticmethod + def _initialize(blob, init): + if isinstance(init, bytes): + init = [init[i:i+1] for i in range(len(init))] + else: + if isinstance(init, CTypesGenericArray): + if (len(init) != len(blob) or + not isinstance(init, CTypesArray)): + raise TypeError("length/type mismatch: %s" % (init,)) + init = tuple(init) + if len(init) > len(blob): + raise IndexError("too many initializers") + addr = ctypes.cast(blob, ctypes.c_void_p).value + PTR = ctypes.POINTER(BItem._ctype) + itemsize = ctypes.sizeof(BItem._ctype) + for i, value in enumerate(init): + p = ctypes.cast(addr + i * itemsize, PTR) + BItem._initialize(p.contents, value) + + def __len__(self): + return len(self._blob) + + def __getitem__(self, index): + if not (0 <= index < len(self._blob)): + raise IndexError + return BItem._from_ctypes(self._blob[index]) + + def __setitem__(self, index, value): + if not (0 <= index < len(self._blob)): + raise IndexError + self._blob[index] = BItem._to_ctypes(value) + + if kind == 'char' or kind == 'byte': + def _to_string(self, maxlen): + if maxlen < 0: + maxlen = len(self._blob) + p = ctypes.cast(self._blob, + ctypes.POINTER(ctypes.c_char)) + n = 0 + while n < maxlen and p[n] != b'\x00': + n += 1 + return b''.join([p[i] for i in range(n)]) + + def _get_own_repr(self): + if getattr(self, '_own', False): + return 'owning %d bytes' % (ctypes.sizeof(self._blob),) + return super(CTypesArray, self)._get_own_repr() + + def _convert_to_address(self, BClass): + if BClass in (CTypesPtr, None) or BClass._automatic_casts: + return ctypes.addressof(self._blob) + else: + return CTypesData._convert_to_address(self, BClass) + + @staticmethod + def _from_ctypes(ctypes_array): + self = CTypesArray.__new__(CTypesArray) + self._blob = ctypes_array + return self + + @staticmethod + def _arg_to_ctypes(value): + return CTypesPtr._arg_to_ctypes(value) + + def __add__(self, other): + if isinstance(other, (int, long)): + return CTypesPtr._new_pointer_at( + ctypes.addressof(self._blob) + + other * ctypes.sizeof(BItem._ctype)) + else: + return NotImplemented + + @classmethod + def _cast_from(cls, source): + raise NotImplementedError("casting to %r" % ( + cls._get_c_name(),)) + # + CTypesArray._fix_class() + return CTypesArray + + def _new_struct_or_union(self, kind, name, base_ctypes_class): + # + class struct_or_union(base_ctypes_class): + pass + struct_or_union.__name__ = '%s_%s' % (kind, name) + kind1 = kind + # + class CTypesStructOrUnion(CTypesBaseStructOrUnion): + __slots__ = ['_blob'] + _ctype = struct_or_union + _reftypename = '%s &' % (name,) + _kind = kind = kind1 + # + CTypesStructOrUnion._fix_class() + return CTypesStructOrUnion + + def new_struct_type(self, name): + return self._new_struct_or_union('struct', name, ctypes.Structure) + + def new_union_type(self, name): + return self._new_struct_or_union('union', name, ctypes.Union) + + def complete_struct_or_union(self, CTypesStructOrUnion, fields, tp, + totalsize=-1, totalalignment=-1, sflags=0, + pack=0): + if totalsize >= 0 or totalalignment >= 0: + raise NotImplementedError("the ctypes backend of CFFI does not support " + "structures completed by verify(); please " + "compile and install the _cffi_backend module.") + struct_or_union = CTypesStructOrUnion._ctype + fnames = [fname for (fname, BField, bitsize) in fields] + btypes = [BField for (fname, BField, bitsize) in fields] + bitfields = [bitsize for (fname, BField, bitsize) in fields] + # + bfield_types = {} + cfields = [] + for (fname, BField, bitsize) in fields: + if bitsize < 0: + cfields.append((fname, BField._ctype)) + bfield_types[fname] = BField + else: + cfields.append((fname, BField._ctype, bitsize)) + bfield_types[fname] = Ellipsis + if sflags & 8: + struct_or_union._pack_ = 1 + elif pack: + struct_or_union._pack_ = pack + struct_or_union._fields_ = cfields + CTypesStructOrUnion._bfield_types = bfield_types + # + @staticmethod + def _create_ctype_obj(init): + result = struct_or_union() + if init is not None: + initialize(result, init) + return result + CTypesStructOrUnion._create_ctype_obj = _create_ctype_obj + # + def initialize(blob, init): + if is_union: + if len(init) > 1: + raise ValueError("union initializer: %d items given, but " + "only one supported (use a dict if needed)" + % (len(init),)) + if not isinstance(init, dict): + if isinstance(init, (bytes, unicode)): + raise TypeError("union initializer: got a str") + init = tuple(init) + if len(init) > len(fnames): + raise ValueError("too many values for %s initializer" % + CTypesStructOrUnion._get_c_name()) + init = dict(zip(fnames, init)) + addr = ctypes.addressof(blob) + for fname, value in init.items(): + BField, bitsize = name2fieldtype[fname] + assert bitsize < 0, \ + "not implemented: initializer with bit fields" + offset = CTypesStructOrUnion._offsetof(fname) + PTR = ctypes.POINTER(BField._ctype) + p = ctypes.cast(addr + offset, PTR) + BField._initialize(p.contents, value) + is_union = CTypesStructOrUnion._kind == 'union' + name2fieldtype = dict(zip(fnames, zip(btypes, bitfields))) + # + for fname, BField, bitsize in fields: + if fname == '': + raise NotImplementedError("nested anonymous structs/unions") + if hasattr(CTypesStructOrUnion, fname): + raise ValueError("the field name %r conflicts in " + "the ctypes backend" % fname) + if bitsize < 0: + def getter(self, fname=fname, BField=BField, + offset=CTypesStructOrUnion._offsetof(fname), + PTR=ctypes.POINTER(BField._ctype)): + addr = ctypes.addressof(self._blob) + p = ctypes.cast(addr + offset, PTR) + return BField._from_ctypes(p.contents) + def setter(self, value, fname=fname, BField=BField): + setattr(self._blob, fname, BField._to_ctypes(value)) + # + if issubclass(BField, CTypesGenericArray): + setter = None + if BField._declared_length == 0: + def getter(self, fname=fname, BFieldPtr=BField._CTPtr, + offset=CTypesStructOrUnion._offsetof(fname), + PTR=ctypes.POINTER(BField._ctype)): + addr = ctypes.addressof(self._blob) + p = ctypes.cast(addr + offset, PTR) + return BFieldPtr._from_ctypes(p) + # + else: + def getter(self, fname=fname, BField=BField): + return BField._from_ctypes(getattr(self._blob, fname)) + def setter(self, value, fname=fname, BField=BField): + # xxx obscure workaround + value = BField._to_ctypes(value) + oldvalue = getattr(self._blob, fname) + setattr(self._blob, fname, value) + if value != getattr(self._blob, fname): + setattr(self._blob, fname, oldvalue) + raise OverflowError("value too large for bitfield") + setattr(CTypesStructOrUnion, fname, property(getter, setter)) + # + CTypesPtr = self.ffi._get_cached_btype(model.PointerType(tp)) + for fname in fnames: + if hasattr(CTypesPtr, fname): + raise ValueError("the field name %r conflicts in " + "the ctypes backend" % fname) + def getter(self, fname=fname): + return getattr(self[0], fname) + def setter(self, value, fname=fname): + setattr(self[0], fname, value) + setattr(CTypesPtr, fname, property(getter, setter)) + + def new_function_type(self, BArgs, BResult, has_varargs): + nameargs = [BArg._get_c_name() for BArg in BArgs] + if has_varargs: + nameargs.append('...') + nameargs = ', '.join(nameargs) + # + class CTypesFunctionPtr(CTypesGenericPtr): + __slots__ = ['_own_callback', '_name'] + _ctype = ctypes.CFUNCTYPE(getattr(BResult, '_ctype', None), + *[BArg._ctype for BArg in BArgs], + use_errno=True) + _reftypename = BResult._get_c_name('(* &)(%s)' % (nameargs,)) + + def __init__(self, init, error=None): + # create a callback to the Python callable init() + import traceback + assert not has_varargs, "varargs not supported for callbacks" + if getattr(BResult, '_ctype', None) is not None: + error = BResult._from_ctypes( + BResult._create_ctype_obj(error)) + else: + error = None + def callback(*args): + args2 = [] + for arg, BArg in zip(args, BArgs): + args2.append(BArg._from_ctypes(arg)) + try: + res2 = init(*args2) + res2 = BResult._to_ctypes(res2) + except: + traceback.print_exc() + res2 = error + if issubclass(BResult, CTypesGenericPtr): + if res2: + res2 = ctypes.cast(res2, ctypes.c_void_p).value + # .value: http://bugs.python.org/issue1574593 + else: + res2 = None + #print repr(res2) + return res2 + if issubclass(BResult, CTypesGenericPtr): + # The only pointers callbacks can return are void*s: + # http://bugs.python.org/issue5710 + callback_ctype = ctypes.CFUNCTYPE( + ctypes.c_void_p, + *[BArg._ctype for BArg in BArgs], + use_errno=True) + else: + callback_ctype = CTypesFunctionPtr._ctype + self._as_ctype_ptr = callback_ctype(callback) + self._address = ctypes.cast(self._as_ctype_ptr, + ctypes.c_void_p).value + self._own_callback = init + + @staticmethod + def _initialize(ctypes_ptr, value): + if value: + raise NotImplementedError("ctypes backend: not supported: " + "initializers for function pointers") + + def __repr__(self): + c_name = getattr(self, '_name', None) + if c_name: + i = self._reftypename.index('(* &)') + if self._reftypename[i-1] not in ' )*': + c_name = ' ' + c_name + c_name = self._reftypename.replace('(* &)', c_name) + return CTypesData.__repr__(self, c_name) + + def _get_own_repr(self): + if getattr(self, '_own_callback', None) is not None: + return 'calling %r' % (self._own_callback,) + return super(CTypesFunctionPtr, self)._get_own_repr() + + def __call__(self, *args): + if has_varargs: + assert len(args) >= len(BArgs) + extraargs = args[len(BArgs):] + args = args[:len(BArgs)] + else: + assert len(args) == len(BArgs) + ctypes_args = [] + for arg, BArg in zip(args, BArgs): + ctypes_args.append(BArg._arg_to_ctypes(arg)) + if has_varargs: + for i, arg in enumerate(extraargs): + if arg is None: + ctypes_args.append(ctypes.c_void_p(0)) # NULL + continue + if not isinstance(arg, CTypesData): + raise TypeError( + "argument %d passed in the variadic part " + "needs to be a cdata object (got %s)" % + (1 + len(BArgs) + i, type(arg).__name__)) + ctypes_args.append(arg._arg_to_ctypes(arg)) + result = self._as_ctype_ptr(*ctypes_args) + return BResult._from_ctypes(result) + # + CTypesFunctionPtr._fix_class() + return CTypesFunctionPtr + + def new_enum_type(self, name, enumerators, enumvalues, CTypesInt): + assert isinstance(name, str) + reverse_mapping = dict(zip(reversed(enumvalues), + reversed(enumerators))) + # + class CTypesEnum(CTypesInt): + __slots__ = [] + _reftypename = '%s &' % name + + def _get_own_repr(self): + value = self._value + try: + return '%d: %s' % (value, reverse_mapping[value]) + except KeyError: + return str(value) + + def _to_string(self, maxlen): + value = self._value + try: + return reverse_mapping[value] + except KeyError: + return str(value) + # + CTypesEnum._fix_class() + return CTypesEnum + + def get_errno(self): + return ctypes.get_errno() + + def set_errno(self, value): + ctypes.set_errno(value) + + def string(self, b, maxlen=-1): + return b._to_string(maxlen) + + def buffer(self, bptr, size=-1): + raise NotImplementedError("buffer() with ctypes backend") + + def sizeof(self, cdata_or_BType): + if isinstance(cdata_or_BType, CTypesData): + return cdata_or_BType._get_size_of_instance() + else: + assert issubclass(cdata_or_BType, CTypesData) + return cdata_or_BType._get_size() + + def alignof(self, BType): + assert issubclass(BType, CTypesData) + return BType._alignment() + + def newp(self, BType, source): + if not issubclass(BType, CTypesData): + raise TypeError + return BType._newp(source) + + def cast(self, BType, source): + return BType._cast_from(source) + + def callback(self, BType, source, error, onerror): + assert onerror is None # XXX not implemented + return BType(source, error) + + _weakref_cache_ref = None + + def gcp(self, cdata, destructor, size=0): + if self._weakref_cache_ref is None: + import weakref + class MyRef(weakref.ref): + def __eq__(self, other): + myref = self() + return self is other or ( + myref is not None and myref is other()) + def __ne__(self, other): + return not (self == other) + def __hash__(self): + try: + return self._hash + except AttributeError: + self._hash = hash(self()) + return self._hash + self._weakref_cache_ref = {}, MyRef + weak_cache, MyRef = self._weakref_cache_ref + + if destructor is None: + try: + del weak_cache[MyRef(cdata)] + except KeyError: + raise TypeError("Can remove destructor only on a object " + "previously returned by ffi.gc()") + return None + + def remove(k): + cdata, destructor = weak_cache.pop(k, (None, None)) + if destructor is not None: + destructor(cdata) + + new_cdata = self.cast(self.typeof(cdata), cdata) + assert new_cdata is not cdata + weak_cache[MyRef(new_cdata, remove)] = (cdata, destructor) + return new_cdata + + typeof = type + + def getcname(self, BType, replace_with): + return BType._get_c_name(replace_with) + + def typeoffsetof(self, BType, fieldname, num=0): + if isinstance(fieldname, str): + if num == 0 and issubclass(BType, CTypesGenericPtr): + BType = BType._BItem + if not issubclass(BType, CTypesBaseStructOrUnion): + raise TypeError("expected a struct or union ctype") + BField = BType._bfield_types[fieldname] + if BField is Ellipsis: + raise TypeError("not supported for bitfields") + return (BField, BType._offsetof(fieldname)) + elif isinstance(fieldname, (int, long)): + if issubclass(BType, CTypesGenericArray): + BType = BType._CTPtr + if not issubclass(BType, CTypesGenericPtr): + raise TypeError("expected an array or ptr ctype") + BItem = BType._BItem + offset = BItem._get_size() * fieldname + if offset > sys.maxsize: + raise OverflowError + return (BItem, offset) + else: + raise TypeError(type(fieldname)) + + def rawaddressof(self, BTypePtr, cdata, offset=None): + if isinstance(cdata, CTypesBaseStructOrUnion): + ptr = ctypes.pointer(type(cdata)._to_ctypes(cdata)) + elif isinstance(cdata, CTypesGenericPtr): + if offset is None or not issubclass(type(cdata)._BItem, + CTypesBaseStructOrUnion): + raise TypeError("unexpected cdata type") + ptr = type(cdata)._to_ctypes(cdata) + elif isinstance(cdata, CTypesGenericArray): + ptr = type(cdata)._to_ctypes(cdata) + else: + raise TypeError("expected a ") + if offset: + ptr = ctypes.cast( + ctypes.c_void_p( + ctypes.cast(ptr, ctypes.c_void_p).value + offset), + type(ptr)) + return BTypePtr._from_ctypes(ptr) + + +class CTypesLibrary(object): + + def __init__(self, backend, cdll): + self.backend = backend + self.cdll = cdll + + def load_function(self, BType, name): + c_func = getattr(self.cdll, name) + funcobj = BType._from_ctypes(c_func) + funcobj._name = name + return funcobj + + def read_variable(self, BType, name): + try: + ctypes_obj = BType._ctype.in_dll(self.cdll, name) + except AttributeError as e: + raise NotImplementedError(e) + return BType._from_ctypes(ctypes_obj) + + def write_variable(self, BType, name, value): + new_ctypes_obj = BType._to_ctypes(value) + ctypes_obj = BType._ctype.in_dll(self.cdll, name) + ctypes.memmove(ctypes.addressof(ctypes_obj), + ctypes.addressof(new_ctypes_obj), + ctypes.sizeof(BType._ctype)) diff --git a/.venv/Lib/site-packages/cffi/cffi_opcode.py b/.venv/Lib/site-packages/cffi/cffi_opcode.py new file mode 100644 index 00000000..a0df98d1 --- /dev/null +++ b/.venv/Lib/site-packages/cffi/cffi_opcode.py @@ -0,0 +1,187 @@ +from .error import VerificationError + +class CffiOp(object): + def __init__(self, op, arg): + self.op = op + self.arg = arg + + def as_c_expr(self): + if self.op is None: + assert isinstance(self.arg, str) + return '(_cffi_opcode_t)(%s)' % (self.arg,) + classname = CLASS_NAME[self.op] + return '_CFFI_OP(_CFFI_OP_%s, %s)' % (classname, self.arg) + + def as_python_bytes(self): + if self.op is None and self.arg.isdigit(): + value = int(self.arg) # non-negative: '-' not in self.arg + if value >= 2**31: + raise OverflowError("cannot emit %r: limited to 2**31-1" + % (self.arg,)) + return format_four_bytes(value) + if isinstance(self.arg, str): + raise VerificationError("cannot emit to Python: %r" % (self.arg,)) + return format_four_bytes((self.arg << 8) | self.op) + + def __str__(self): + classname = CLASS_NAME.get(self.op, self.op) + return '(%s %s)' % (classname, self.arg) + +def format_four_bytes(num): + return '\\x%02X\\x%02X\\x%02X\\x%02X' % ( + (num >> 24) & 0xFF, + (num >> 16) & 0xFF, + (num >> 8) & 0xFF, + (num ) & 0xFF) + +OP_PRIMITIVE = 1 +OP_POINTER = 3 +OP_ARRAY = 5 +OP_OPEN_ARRAY = 7 +OP_STRUCT_UNION = 9 +OP_ENUM = 11 +OP_FUNCTION = 13 +OP_FUNCTION_END = 15 +OP_NOOP = 17 +OP_BITFIELD = 19 +OP_TYPENAME = 21 +OP_CPYTHON_BLTN_V = 23 # varargs +OP_CPYTHON_BLTN_N = 25 # noargs +OP_CPYTHON_BLTN_O = 27 # O (i.e. a single arg) +OP_CONSTANT = 29 +OP_CONSTANT_INT = 31 +OP_GLOBAL_VAR = 33 +OP_DLOPEN_FUNC = 35 +OP_DLOPEN_CONST = 37 +OP_GLOBAL_VAR_F = 39 +OP_EXTERN_PYTHON = 41 + +PRIM_VOID = 0 +PRIM_BOOL = 1 +PRIM_CHAR = 2 +PRIM_SCHAR = 3 +PRIM_UCHAR = 4 +PRIM_SHORT = 5 +PRIM_USHORT = 6 +PRIM_INT = 7 +PRIM_UINT = 8 +PRIM_LONG = 9 +PRIM_ULONG = 10 +PRIM_LONGLONG = 11 +PRIM_ULONGLONG = 12 +PRIM_FLOAT = 13 +PRIM_DOUBLE = 14 +PRIM_LONGDOUBLE = 15 + +PRIM_WCHAR = 16 +PRIM_INT8 = 17 +PRIM_UINT8 = 18 +PRIM_INT16 = 19 +PRIM_UINT16 = 20 +PRIM_INT32 = 21 +PRIM_UINT32 = 22 +PRIM_INT64 = 23 +PRIM_UINT64 = 24 +PRIM_INTPTR = 25 +PRIM_UINTPTR = 26 +PRIM_PTRDIFF = 27 +PRIM_SIZE = 28 +PRIM_SSIZE = 29 +PRIM_INT_LEAST8 = 30 +PRIM_UINT_LEAST8 = 31 +PRIM_INT_LEAST16 = 32 +PRIM_UINT_LEAST16 = 33 +PRIM_INT_LEAST32 = 34 +PRIM_UINT_LEAST32 = 35 +PRIM_INT_LEAST64 = 36 +PRIM_UINT_LEAST64 = 37 +PRIM_INT_FAST8 = 38 +PRIM_UINT_FAST8 = 39 +PRIM_INT_FAST16 = 40 +PRIM_UINT_FAST16 = 41 +PRIM_INT_FAST32 = 42 +PRIM_UINT_FAST32 = 43 +PRIM_INT_FAST64 = 44 +PRIM_UINT_FAST64 = 45 +PRIM_INTMAX = 46 +PRIM_UINTMAX = 47 +PRIM_FLOATCOMPLEX = 48 +PRIM_DOUBLECOMPLEX = 49 +PRIM_CHAR16 = 50 +PRIM_CHAR32 = 51 + +_NUM_PRIM = 52 +_UNKNOWN_PRIM = -1 +_UNKNOWN_FLOAT_PRIM = -2 +_UNKNOWN_LONG_DOUBLE = -3 + +_IO_FILE_STRUCT = -1 + +PRIMITIVE_TO_INDEX = { + 'char': PRIM_CHAR, + 'short': PRIM_SHORT, + 'int': PRIM_INT, + 'long': PRIM_LONG, + 'long long': PRIM_LONGLONG, + 'signed char': PRIM_SCHAR, + 'unsigned char': PRIM_UCHAR, + 'unsigned short': PRIM_USHORT, + 'unsigned int': PRIM_UINT, + 'unsigned long': PRIM_ULONG, + 'unsigned long long': PRIM_ULONGLONG, + 'float': PRIM_FLOAT, + 'double': PRIM_DOUBLE, + 'long double': PRIM_LONGDOUBLE, + 'float _Complex': PRIM_FLOATCOMPLEX, + 'double _Complex': PRIM_DOUBLECOMPLEX, + '_Bool': PRIM_BOOL, + 'wchar_t': PRIM_WCHAR, + 'char16_t': PRIM_CHAR16, + 'char32_t': PRIM_CHAR32, + 'int8_t': PRIM_INT8, + 'uint8_t': PRIM_UINT8, + 'int16_t': PRIM_INT16, + 'uint16_t': PRIM_UINT16, + 'int32_t': PRIM_INT32, + 'uint32_t': PRIM_UINT32, + 'int64_t': PRIM_INT64, + 'uint64_t': PRIM_UINT64, + 'intptr_t': PRIM_INTPTR, + 'uintptr_t': PRIM_UINTPTR, + 'ptrdiff_t': PRIM_PTRDIFF, + 'size_t': PRIM_SIZE, + 'ssize_t': PRIM_SSIZE, + 'int_least8_t': PRIM_INT_LEAST8, + 'uint_least8_t': PRIM_UINT_LEAST8, + 'int_least16_t': PRIM_INT_LEAST16, + 'uint_least16_t': PRIM_UINT_LEAST16, + 'int_least32_t': PRIM_INT_LEAST32, + 'uint_least32_t': PRIM_UINT_LEAST32, + 'int_least64_t': PRIM_INT_LEAST64, + 'uint_least64_t': PRIM_UINT_LEAST64, + 'int_fast8_t': PRIM_INT_FAST8, + 'uint_fast8_t': PRIM_UINT_FAST8, + 'int_fast16_t': PRIM_INT_FAST16, + 'uint_fast16_t': PRIM_UINT_FAST16, + 'int_fast32_t': PRIM_INT_FAST32, + 'uint_fast32_t': PRIM_UINT_FAST32, + 'int_fast64_t': PRIM_INT_FAST64, + 'uint_fast64_t': PRIM_UINT_FAST64, + 'intmax_t': PRIM_INTMAX, + 'uintmax_t': PRIM_UINTMAX, + } + +F_UNION = 0x01 +F_CHECK_FIELDS = 0x02 +F_PACKED = 0x04 +F_EXTERNAL = 0x08 +F_OPAQUE = 0x10 + +G_FLAGS = dict([('_CFFI_' + _key, globals()[_key]) + for _key in ['F_UNION', 'F_CHECK_FIELDS', 'F_PACKED', + 'F_EXTERNAL', 'F_OPAQUE']]) + +CLASS_NAME = {} +for _name, _value in list(globals().items()): + if _name.startswith('OP_') and isinstance(_value, int): + CLASS_NAME[_value] = _name[3:] diff --git a/.venv/Lib/site-packages/cffi/commontypes.py b/.venv/Lib/site-packages/cffi/commontypes.py new file mode 100644 index 00000000..8ec97c75 --- /dev/null +++ b/.venv/Lib/site-packages/cffi/commontypes.py @@ -0,0 +1,80 @@ +import sys +from . import model +from .error import FFIError + + +COMMON_TYPES = {} + +try: + # fetch "bool" and all simple Windows types + from _cffi_backend import _get_common_types + _get_common_types(COMMON_TYPES) +except ImportError: + pass + +COMMON_TYPES['FILE'] = model.unknown_type('FILE', '_IO_FILE') +COMMON_TYPES['bool'] = '_Bool' # in case we got ImportError above + +for _type in model.PrimitiveType.ALL_PRIMITIVE_TYPES: + if _type.endswith('_t'): + COMMON_TYPES[_type] = _type +del _type + +_CACHE = {} + +def resolve_common_type(parser, commontype): + try: + return _CACHE[commontype] + except KeyError: + cdecl = COMMON_TYPES.get(commontype, commontype) + if not isinstance(cdecl, str): + result, quals = cdecl, 0 # cdecl is already a BaseType + elif cdecl in model.PrimitiveType.ALL_PRIMITIVE_TYPES: + result, quals = model.PrimitiveType(cdecl), 0 + elif cdecl == 'set-unicode-needed': + raise FFIError("The Windows type %r is only available after " + "you call ffi.set_unicode()" % (commontype,)) + else: + if commontype == cdecl: + raise FFIError( + "Unsupported type: %r. Please look at " + "http://cffi.readthedocs.io/en/latest/cdef.html#ffi-cdef-limitations " + "and file an issue if you think this type should really " + "be supported." % (commontype,)) + result, quals = parser.parse_type_and_quals(cdecl) # recursive + + assert isinstance(result, model.BaseTypeByIdentity) + _CACHE[commontype] = result, quals + return result, quals + + +# ____________________________________________________________ +# extra types for Windows (most of them are in commontypes.c) + + +def win_common_types(): + return { + "UNICODE_STRING": model.StructType( + "_UNICODE_STRING", + ["Length", + "MaximumLength", + "Buffer"], + [model.PrimitiveType("unsigned short"), + model.PrimitiveType("unsigned short"), + model.PointerType(model.PrimitiveType("wchar_t"))], + [-1, -1, -1]), + "PUNICODE_STRING": "UNICODE_STRING *", + "PCUNICODE_STRING": "const UNICODE_STRING *", + + "TBYTE": "set-unicode-needed", + "TCHAR": "set-unicode-needed", + "LPCTSTR": "set-unicode-needed", + "PCTSTR": "set-unicode-needed", + "LPTSTR": "set-unicode-needed", + "PTSTR": "set-unicode-needed", + "PTBYTE": "set-unicode-needed", + "PTCHAR": "set-unicode-needed", + } + +if sys.platform == 'win32': + COMMON_TYPES.update(win_common_types()) diff --git a/.venv/Lib/site-packages/cffi/cparser.py b/.venv/Lib/site-packages/cffi/cparser.py new file mode 100644 index 00000000..74830e91 --- /dev/null +++ b/.venv/Lib/site-packages/cffi/cparser.py @@ -0,0 +1,1006 @@ +from . import model +from .commontypes import COMMON_TYPES, resolve_common_type +from .error import FFIError, CDefError +try: + from . import _pycparser as pycparser +except ImportError: + import pycparser +import weakref, re, sys + +try: + if sys.version_info < (3,): + import thread as _thread + else: + import _thread + lock = _thread.allocate_lock() +except ImportError: + lock = None + +def _workaround_for_static_import_finders(): + # Issue #392: packaging tools like cx_Freeze can not find these + # because pycparser uses exec dynamic import. This is an obscure + # workaround. This function is never called. + import pycparser.yacctab + import pycparser.lextab + +CDEF_SOURCE_STRING = "" +_r_comment = re.compile(r"/\*.*?\*/|//([^\n\\]|\\.)*?$", + re.DOTALL | re.MULTILINE) +_r_define = re.compile(r"^\s*#\s*define\s+([A-Za-z_][A-Za-z_0-9]*)" + r"\b((?:[^\n\\]|\\.)*?)$", + re.DOTALL | re.MULTILINE) +_r_line_directive = re.compile(r"^[ \t]*#[ \t]*(?:line|\d+)\b.*$", re.MULTILINE) +_r_partial_enum = re.compile(r"=\s*\.\.\.\s*[,}]|\.\.\.\s*\}") +_r_enum_dotdotdot = re.compile(r"__dotdotdot\d+__$") +_r_partial_array = re.compile(r"\[\s*\.\.\.\s*\]") +_r_words = re.compile(r"\w+|\S") +_parser_cache = None +_r_int_literal = re.compile(r"-?0?x?[0-9a-f]+[lu]*$", re.IGNORECASE) +_r_stdcall1 = re.compile(r"\b(__stdcall|WINAPI)\b") +_r_stdcall2 = re.compile(r"[(]\s*(__stdcall|WINAPI)\b") +_r_cdecl = re.compile(r"\b__cdecl\b") +_r_extern_python = re.compile(r'\bextern\s*"' + r'(Python|Python\s*\+\s*C|C\s*\+\s*Python)"\s*.') +_r_star_const_space = re.compile( # matches "* const " + r"[*]\s*((const|volatile|restrict)\b\s*)+") +_r_int_dotdotdot = re.compile(r"(\b(int|long|short|signed|unsigned|char)\s*)+" + r"\.\.\.") +_r_float_dotdotdot = re.compile(r"\b(double|float)\s*\.\.\.") + +def _get_parser(): + global _parser_cache + if _parser_cache is None: + _parser_cache = pycparser.CParser() + return _parser_cache + +def _workaround_for_old_pycparser(csource): + # Workaround for a pycparser issue (fixed between pycparser 2.10 and + # 2.14): "char*const***" gives us a wrong syntax tree, the same as + # for "char***(*const)". This means we can't tell the difference + # afterwards. But "char(*const(***))" gives us the right syntax + # tree. The issue only occurs if there are several stars in + # sequence with no parenthesis inbetween, just possibly qualifiers. + # Attempt to fix it by adding some parentheses in the source: each + # time we see "* const" or "* const *", we add an opening + # parenthesis before each star---the hard part is figuring out where + # to close them. + parts = [] + while True: + match = _r_star_const_space.search(csource) + if not match: + break + #print repr(''.join(parts)+csource), '=>', + parts.append(csource[:match.start()]) + parts.append('('); closing = ')' + parts.append(match.group()) # e.g. "* const " + endpos = match.end() + if csource.startswith('*', endpos): + parts.append('('); closing += ')' + level = 0 + i = endpos + while i < len(csource): + c = csource[i] + if c == '(': + level += 1 + elif c == ')': + if level == 0: + break + level -= 1 + elif c in ',;=': + if level == 0: + break + i += 1 + csource = csource[endpos:i] + closing + csource[i:] + #print repr(''.join(parts)+csource) + parts.append(csource) + return ''.join(parts) + +def _preprocess_extern_python(csource): + # input: `extern "Python" int foo(int);` or + # `extern "Python" { int foo(int); }` + # output: + # void __cffi_extern_python_start; + # int foo(int); + # void __cffi_extern_python_stop; + # + # input: `extern "Python+C" int foo(int);` + # output: + # void __cffi_extern_python_plus_c_start; + # int foo(int); + # void __cffi_extern_python_stop; + parts = [] + while True: + match = _r_extern_python.search(csource) + if not match: + break + endpos = match.end() - 1 + #print + #print ''.join(parts)+csource + #print '=>' + parts.append(csource[:match.start()]) + if 'C' in match.group(1): + parts.append('void __cffi_extern_python_plus_c_start; ') + else: + parts.append('void __cffi_extern_python_start; ') + if csource[endpos] == '{': + # grouping variant + closing = csource.find('}', endpos) + if closing < 0: + raise CDefError("'extern \"Python\" {': no '}' found") + if csource.find('{', endpos + 1, closing) >= 0: + raise NotImplementedError("cannot use { } inside a block " + "'extern \"Python\" { ... }'") + parts.append(csource[endpos+1:closing]) + csource = csource[closing+1:] + else: + # non-grouping variant + semicolon = csource.find(';', endpos) + if semicolon < 0: + raise CDefError("'extern \"Python\": no ';' found") + parts.append(csource[endpos:semicolon+1]) + csource = csource[semicolon+1:] + parts.append(' void __cffi_extern_python_stop;') + #print ''.join(parts)+csource + #print + parts.append(csource) + return ''.join(parts) + +def _warn_for_string_literal(csource): + if '"' not in csource: + return + for line in csource.splitlines(): + if '"' in line and not line.lstrip().startswith('#'): + import warnings + warnings.warn("String literal found in cdef() or type source. " + "String literals are ignored here, but you should " + "remove them anyway because some character sequences " + "confuse pre-parsing.") + break + +def _warn_for_non_extern_non_static_global_variable(decl): + if not decl.storage: + import warnings + warnings.warn("Global variable '%s' in cdef(): for consistency " + "with C it should have a storage class specifier " + "(usually 'extern')" % (decl.name,)) + +def _remove_line_directives(csource): + # _r_line_directive matches whole lines, without the final \n, if they + # start with '#line' with some spacing allowed, or '#NUMBER'. This + # function stores them away and replaces them with exactly the string + # '#line@N', where N is the index in the list 'line_directives'. + line_directives = [] + def replace(m): + i = len(line_directives) + line_directives.append(m.group()) + return '#line@%d' % i + csource = _r_line_directive.sub(replace, csource) + return csource, line_directives + +def _put_back_line_directives(csource, line_directives): + def replace(m): + s = m.group() + if not s.startswith('#line@'): + raise AssertionError("unexpected #line directive " + "(should have been processed and removed") + return line_directives[int(s[6:])] + return _r_line_directive.sub(replace, csource) + +def _preprocess(csource): + # First, remove the lines of the form '#line N "filename"' because + # the "filename" part could confuse the rest + csource, line_directives = _remove_line_directives(csource) + # Remove comments. NOTE: this only work because the cdef() section + # should not contain any string literals (except in line directives)! + def replace_keeping_newlines(m): + return ' ' + m.group().count('\n') * '\n' + csource = _r_comment.sub(replace_keeping_newlines, csource) + # Remove the "#define FOO x" lines + macros = {} + for match in _r_define.finditer(csource): + macroname, macrovalue = match.groups() + macrovalue = macrovalue.replace('\\\n', '').strip() + macros[macroname] = macrovalue + csource = _r_define.sub('', csource) + # + if pycparser.__version__ < '2.14': + csource = _workaround_for_old_pycparser(csource) + # + # BIG HACK: replace WINAPI or __stdcall with "volatile const". + # It doesn't make sense for the return type of a function to be + # "volatile volatile const", so we abuse it to detect __stdcall... + # Hack number 2 is that "int(volatile *fptr)();" is not valid C + # syntax, so we place the "volatile" before the opening parenthesis. + csource = _r_stdcall2.sub(' volatile volatile const(', csource) + csource = _r_stdcall1.sub(' volatile volatile const ', csource) + csource = _r_cdecl.sub(' ', csource) + # + # Replace `extern "Python"` with start/end markers + csource = _preprocess_extern_python(csource) + # + # Now there should not be any string literal left; warn if we get one + _warn_for_string_literal(csource) + # + # Replace "[...]" with "[__dotdotdotarray__]" + csource = _r_partial_array.sub('[__dotdotdotarray__]', csource) + # + # Replace "...}" with "__dotdotdotNUM__}". This construction should + # occur only at the end of enums; at the end of structs we have "...;}" + # and at the end of vararg functions "...);". Also replace "=...[,}]" + # with ",__dotdotdotNUM__[,}]": this occurs in the enums too, when + # giving an unknown value. + matches = list(_r_partial_enum.finditer(csource)) + for number, match in enumerate(reversed(matches)): + p = match.start() + if csource[p] == '=': + p2 = csource.find('...', p, match.end()) + assert p2 > p + csource = '%s,__dotdotdot%d__ %s' % (csource[:p], number, + csource[p2+3:]) + else: + assert csource[p:p+3] == '...' + csource = '%s __dotdotdot%d__ %s' % (csource[:p], number, + csource[p+3:]) + # Replace "int ..." or "unsigned long int..." with "__dotdotdotint__" + csource = _r_int_dotdotdot.sub(' __dotdotdotint__ ', csource) + # Replace "float ..." or "double..." with "__dotdotdotfloat__" + csource = _r_float_dotdotdot.sub(' __dotdotdotfloat__ ', csource) + # Replace all remaining "..." with the same name, "__dotdotdot__", + # which is declared with a typedef for the purpose of C parsing. + csource = csource.replace('...', ' __dotdotdot__ ') + # Finally, put back the line directives + csource = _put_back_line_directives(csource, line_directives) + return csource, macros + +def _common_type_names(csource): + # Look in the source for what looks like usages of types from the + # list of common types. A "usage" is approximated here as the + # appearance of the word, minus a "definition" of the type, which + # is the last word in a "typedef" statement. Approximative only + # but should be fine for all the common types. + look_for_words = set(COMMON_TYPES) + look_for_words.add(';') + look_for_words.add(',') + look_for_words.add('(') + look_for_words.add(')') + look_for_words.add('typedef') + words_used = set() + is_typedef = False + paren = 0 + previous_word = '' + for word in _r_words.findall(csource): + if word in look_for_words: + if word == ';': + if is_typedef: + words_used.discard(previous_word) + look_for_words.discard(previous_word) + is_typedef = False + elif word == 'typedef': + is_typedef = True + paren = 0 + elif word == '(': + paren += 1 + elif word == ')': + paren -= 1 + elif word == ',': + if is_typedef and paren == 0: + words_used.discard(previous_word) + look_for_words.discard(previous_word) + else: # word in COMMON_TYPES + words_used.add(word) + previous_word = word + return words_used + + +class Parser(object): + + def __init__(self): + self._declarations = {} + self._included_declarations = set() + self._anonymous_counter = 0 + self._structnode2type = weakref.WeakKeyDictionary() + self._options = {} + self._int_constants = {} + self._recomplete = [] + self._uses_new_feature = None + + def _parse(self, csource): + csource, macros = _preprocess(csource) + # XXX: for more efficiency we would need to poke into the + # internals of CParser... the following registers the + # typedefs, because their presence or absence influences the + # parsing itself (but what they are typedef'ed to plays no role) + ctn = _common_type_names(csource) + typenames = [] + for name in sorted(self._declarations): + if name.startswith('typedef '): + name = name[8:] + typenames.append(name) + ctn.discard(name) + typenames += sorted(ctn) + # + csourcelines = [] + csourcelines.append('# 1 ""') + for typename in typenames: + csourcelines.append('typedef int %s;' % typename) + csourcelines.append('typedef int __dotdotdotint__, __dotdotdotfloat__,' + ' __dotdotdot__;') + # this forces pycparser to consider the following in the file + # called from line 1 + csourcelines.append('# 1 "%s"' % (CDEF_SOURCE_STRING,)) + csourcelines.append(csource) + fullcsource = '\n'.join(csourcelines) + if lock is not None: + lock.acquire() # pycparser is not thread-safe... + try: + ast = _get_parser().parse(fullcsource) + except pycparser.c_parser.ParseError as e: + self.convert_pycparser_error(e, csource) + finally: + if lock is not None: + lock.release() + # csource will be used to find buggy source text + return ast, macros, csource + + def _convert_pycparser_error(self, e, csource): + # xxx look for ":NUM:" at the start of str(e) + # and interpret that as a line number. This will not work if + # the user gives explicit ``# NUM "FILE"`` directives. + line = None + msg = str(e) + match = re.match(r"%s:(\d+):" % (CDEF_SOURCE_STRING,), msg) + if match: + linenum = int(match.group(1), 10) + csourcelines = csource.splitlines() + if 1 <= linenum <= len(csourcelines): + line = csourcelines[linenum-1] + return line + + def convert_pycparser_error(self, e, csource): + line = self._convert_pycparser_error(e, csource) + + msg = str(e) + if line: + msg = 'cannot parse "%s"\n%s' % (line.strip(), msg) + else: + msg = 'parse error\n%s' % (msg,) + raise CDefError(msg) + + def parse(self, csource, override=False, packed=False, pack=None, + dllexport=False): + if packed: + if packed != True: + raise ValueError("'packed' should be False or True; use " + "'pack' to give another value") + if pack: + raise ValueError("cannot give both 'pack' and 'packed'") + pack = 1 + elif pack: + if pack & (pack - 1): + raise ValueError("'pack' must be a power of two, not %r" % + (pack,)) + else: + pack = 0 + prev_options = self._options + try: + self._options = {'override': override, + 'packed': pack, + 'dllexport': dllexport} + self._internal_parse(csource) + finally: + self._options = prev_options + + def _internal_parse(self, csource): + ast, macros, csource = self._parse(csource) + # add the macros + self._process_macros(macros) + # find the first "__dotdotdot__" and use that as a separator + # between the repeated typedefs and the real csource + iterator = iter(ast.ext) + for decl in iterator: + if decl.name == '__dotdotdot__': + break + else: + assert 0 + current_decl = None + # + try: + self._inside_extern_python = '__cffi_extern_python_stop' + for decl in iterator: + current_decl = decl + if isinstance(decl, pycparser.c_ast.Decl): + self._parse_decl(decl) + elif isinstance(decl, pycparser.c_ast.Typedef): + if not decl.name: + raise CDefError("typedef does not declare any name", + decl) + quals = 0 + if (isinstance(decl.type.type, pycparser.c_ast.IdentifierType) and + decl.type.type.names[-1].startswith('__dotdotdot')): + realtype = self._get_unknown_type(decl) + elif (isinstance(decl.type, pycparser.c_ast.PtrDecl) and + isinstance(decl.type.type, pycparser.c_ast.TypeDecl) and + isinstance(decl.type.type.type, + pycparser.c_ast.IdentifierType) and + decl.type.type.type.names[-1].startswith('__dotdotdot')): + realtype = self._get_unknown_ptr_type(decl) + else: + realtype, quals = self._get_type_and_quals( + decl.type, name=decl.name, partial_length_ok=True, + typedef_example="*(%s *)0" % (decl.name,)) + self._declare('typedef ' + decl.name, realtype, quals=quals) + elif decl.__class__.__name__ == 'Pragma': + pass # skip pragma, only in pycparser 2.15 + else: + raise CDefError("unexpected <%s>: this construct is valid " + "C but not valid in cdef()" % + decl.__class__.__name__, decl) + except CDefError as e: + if len(e.args) == 1: + e.args = e.args + (current_decl,) + raise + except FFIError as e: + msg = self._convert_pycparser_error(e, csource) + if msg: + e.args = (e.args[0] + "\n *** Err: %s" % msg,) + raise + + def _add_constants(self, key, val): + if key in self._int_constants: + if self._int_constants[key] == val: + return # ignore identical double declarations + raise FFIError( + "multiple declarations of constant: %s" % (key,)) + self._int_constants[key] = val + + def _add_integer_constant(self, name, int_str): + int_str = int_str.lower().rstrip("ul") + neg = int_str.startswith('-') + if neg: + int_str = int_str[1:] + # "010" is not valid oct in py3 + if (int_str.startswith("0") and int_str != '0' + and not int_str.startswith("0x")): + int_str = "0o" + int_str[1:] + pyvalue = int(int_str, 0) + if neg: + pyvalue = -pyvalue + self._add_constants(name, pyvalue) + self._declare('macro ' + name, pyvalue) + + def _process_macros(self, macros): + for key, value in macros.items(): + value = value.strip() + if _r_int_literal.match(value): + self._add_integer_constant(key, value) + elif value == '...': + self._declare('macro ' + key, value) + else: + raise CDefError( + 'only supports one of the following syntax:\n' + ' #define %s ... (literally dot-dot-dot)\n' + ' #define %s NUMBER (with NUMBER an integer' + ' constant, decimal/hex/octal)\n' + 'got:\n' + ' #define %s %s' + % (key, key, key, value)) + + def _declare_function(self, tp, quals, decl): + tp = self._get_type_pointer(tp, quals) + if self._options.get('dllexport'): + tag = 'dllexport_python ' + elif self._inside_extern_python == '__cffi_extern_python_start': + tag = 'extern_python ' + elif self._inside_extern_python == '__cffi_extern_python_plus_c_start': + tag = 'extern_python_plus_c ' + else: + tag = 'function ' + self._declare(tag + decl.name, tp) + + def _parse_decl(self, decl): + node = decl.type + if isinstance(node, pycparser.c_ast.FuncDecl): + tp, quals = self._get_type_and_quals(node, name=decl.name) + assert isinstance(tp, model.RawFunctionType) + self._declare_function(tp, quals, decl) + else: + if isinstance(node, pycparser.c_ast.Struct): + self._get_struct_union_enum_type('struct', node) + elif isinstance(node, pycparser.c_ast.Union): + self._get_struct_union_enum_type('union', node) + elif isinstance(node, pycparser.c_ast.Enum): + self._get_struct_union_enum_type('enum', node) + elif not decl.name: + raise CDefError("construct does not declare any variable", + decl) + # + if decl.name: + tp, quals = self._get_type_and_quals(node, + partial_length_ok=True) + if tp.is_raw_function: + self._declare_function(tp, quals, decl) + elif (tp.is_integer_type() and + hasattr(decl, 'init') and + hasattr(decl.init, 'value') and + _r_int_literal.match(decl.init.value)): + self._add_integer_constant(decl.name, decl.init.value) + elif (tp.is_integer_type() and + isinstance(decl.init, pycparser.c_ast.UnaryOp) and + decl.init.op == '-' and + hasattr(decl.init.expr, 'value') and + _r_int_literal.match(decl.init.expr.value)): + self._add_integer_constant(decl.name, + '-' + decl.init.expr.value) + elif (tp is model.void_type and + decl.name.startswith('__cffi_extern_python_')): + # hack: `extern "Python"` in the C source is replaced + # with "void __cffi_extern_python_start;" and + # "void __cffi_extern_python_stop;" + self._inside_extern_python = decl.name + else: + if self._inside_extern_python !='__cffi_extern_python_stop': + raise CDefError( + "cannot declare constants or " + "variables with 'extern \"Python\"'") + if (quals & model.Q_CONST) and not tp.is_array_type: + self._declare('constant ' + decl.name, tp, quals=quals) + else: + _warn_for_non_extern_non_static_global_variable(decl) + self._declare('variable ' + decl.name, tp, quals=quals) + + def parse_type(self, cdecl): + return self.parse_type_and_quals(cdecl)[0] + + def parse_type_and_quals(self, cdecl): + ast, macros = self._parse('void __dummy(\n%s\n);' % cdecl)[:2] + assert not macros + exprnode = ast.ext[-1].type.args.params[0] + if isinstance(exprnode, pycparser.c_ast.ID): + raise CDefError("unknown identifier '%s'" % (exprnode.name,)) + return self._get_type_and_quals(exprnode.type) + + def _declare(self, name, obj, included=False, quals=0): + if name in self._declarations: + prevobj, prevquals = self._declarations[name] + if prevobj is obj and prevquals == quals: + return + if not self._options.get('override'): + raise FFIError( + "multiple declarations of %s (for interactive usage, " + "try cdef(xx, override=True))" % (name,)) + assert '__dotdotdot__' not in name.split() + self._declarations[name] = (obj, quals) + if included: + self._included_declarations.add(obj) + + def _extract_quals(self, type): + quals = 0 + if isinstance(type, (pycparser.c_ast.TypeDecl, + pycparser.c_ast.PtrDecl)): + if 'const' in type.quals: + quals |= model.Q_CONST + if 'volatile' in type.quals: + quals |= model.Q_VOLATILE + if 'restrict' in type.quals: + quals |= model.Q_RESTRICT + return quals + + def _get_type_pointer(self, type, quals, declname=None): + if isinstance(type, model.RawFunctionType): + return type.as_function_pointer() + if (isinstance(type, model.StructOrUnionOrEnum) and + type.name.startswith('$') and type.name[1:].isdigit() and + type.forcename is None and declname is not None): + return model.NamedPointerType(type, declname, quals) + return model.PointerType(type, quals) + + def _get_type_and_quals(self, typenode, name=None, partial_length_ok=False, + typedef_example=None): + # first, dereference typedefs, if we have it already parsed, we're good + if (isinstance(typenode, pycparser.c_ast.TypeDecl) and + isinstance(typenode.type, pycparser.c_ast.IdentifierType) and + len(typenode.type.names) == 1 and + ('typedef ' + typenode.type.names[0]) in self._declarations): + tp, quals = self._declarations['typedef ' + typenode.type.names[0]] + quals |= self._extract_quals(typenode) + return tp, quals + # + if isinstance(typenode, pycparser.c_ast.ArrayDecl): + # array type + if typenode.dim is None: + length = None + else: + length = self._parse_constant( + typenode.dim, partial_length_ok=partial_length_ok) + # a hack: in 'typedef int foo_t[...][...];', don't use '...' as + # the length but use directly the C expression that would be + # generated by recompiler.py. This lets the typedef be used in + # many more places within recompiler.py + if typedef_example is not None: + if length == '...': + length = '_cffi_array_len(%s)' % (typedef_example,) + typedef_example = "*" + typedef_example + # + tp, quals = self._get_type_and_quals(typenode.type, + partial_length_ok=partial_length_ok, + typedef_example=typedef_example) + return model.ArrayType(tp, length), quals + # + if isinstance(typenode, pycparser.c_ast.PtrDecl): + # pointer type + itemtype, itemquals = self._get_type_and_quals(typenode.type) + tp = self._get_type_pointer(itemtype, itemquals, declname=name) + quals = self._extract_quals(typenode) + return tp, quals + # + if isinstance(typenode, pycparser.c_ast.TypeDecl): + quals = self._extract_quals(typenode) + type = typenode.type + if isinstance(type, pycparser.c_ast.IdentifierType): + # assume a primitive type. get it from .names, but reduce + # synonyms to a single chosen combination + names = list(type.names) + if names != ['signed', 'char']: # keep this unmodified + prefixes = {} + while names: + name = names[0] + if name in ('short', 'long', 'signed', 'unsigned'): + prefixes[name] = prefixes.get(name, 0) + 1 + del names[0] + else: + break + # ignore the 'signed' prefix below, and reorder the others + newnames = [] + for prefix in ('unsigned', 'short', 'long'): + for i in range(prefixes.get(prefix, 0)): + newnames.append(prefix) + if not names: + names = ['int'] # implicitly + if names == ['int']: # but kill it if 'short' or 'long' + if 'short' in prefixes or 'long' in prefixes: + names = [] + names = newnames + names + ident = ' '.join(names) + if ident == 'void': + return model.void_type, quals + if ident == '__dotdotdot__': + raise FFIError(':%d: bad usage of "..."' % + typenode.coord.line) + tp0, quals0 = resolve_common_type(self, ident) + return tp0, (quals | quals0) + # + if isinstance(type, pycparser.c_ast.Struct): + # 'struct foobar' + tp = self._get_struct_union_enum_type('struct', type, name) + return tp, quals + # + if isinstance(type, pycparser.c_ast.Union): + # 'union foobar' + tp = self._get_struct_union_enum_type('union', type, name) + return tp, quals + # + if isinstance(type, pycparser.c_ast.Enum): + # 'enum foobar' + tp = self._get_struct_union_enum_type('enum', type, name) + return tp, quals + # + if isinstance(typenode, pycparser.c_ast.FuncDecl): + # a function type + return self._parse_function_type(typenode, name), 0 + # + # nested anonymous structs or unions end up here + if isinstance(typenode, pycparser.c_ast.Struct): + return self._get_struct_union_enum_type('struct', typenode, name, + nested=True), 0 + if isinstance(typenode, pycparser.c_ast.Union): + return self._get_struct_union_enum_type('union', typenode, name, + nested=True), 0 + # + raise FFIError(":%d: bad or unsupported type declaration" % + typenode.coord.line) + + def _parse_function_type(self, typenode, funcname=None): + params = list(getattr(typenode.args, 'params', [])) + for i, arg in enumerate(params): + if not hasattr(arg, 'type'): + raise CDefError("%s arg %d: unknown type '%s'" + " (if you meant to use the old C syntax of giving" + " untyped arguments, it is not supported)" + % (funcname or 'in expression', i + 1, + getattr(arg, 'name', '?'))) + ellipsis = ( + len(params) > 0 and + isinstance(params[-1].type, pycparser.c_ast.TypeDecl) and + isinstance(params[-1].type.type, + pycparser.c_ast.IdentifierType) and + params[-1].type.type.names == ['__dotdotdot__']) + if ellipsis: + params.pop() + if not params: + raise CDefError( + "%s: a function with only '(...)' as argument" + " is not correct C" % (funcname or 'in expression')) + args = [self._as_func_arg(*self._get_type_and_quals(argdeclnode.type)) + for argdeclnode in params] + if not ellipsis and args == [model.void_type]: + args = [] + result, quals = self._get_type_and_quals(typenode.type) + # the 'quals' on the result type are ignored. HACK: we absure them + # to detect __stdcall functions: we textually replace "__stdcall" + # with "volatile volatile const" above. + abi = None + if hasattr(typenode.type, 'quals'): # else, probable syntax error anyway + if typenode.type.quals[-3:] == ['volatile', 'volatile', 'const']: + abi = '__stdcall' + return model.RawFunctionType(tuple(args), result, ellipsis, abi) + + def _as_func_arg(self, type, quals): + if isinstance(type, model.ArrayType): + return model.PointerType(type.item, quals) + elif isinstance(type, model.RawFunctionType): + return type.as_function_pointer() + else: + return type + + def _get_struct_union_enum_type(self, kind, type, name=None, nested=False): + # First, a level of caching on the exact 'type' node of the AST. + # This is obscure, but needed because pycparser "unrolls" declarations + # such as "typedef struct { } foo_t, *foo_p" and we end up with + # an AST that is not a tree, but a DAG, with the "type" node of the + # two branches foo_t and foo_p of the trees being the same node. + # It's a bit silly but detecting "DAG-ness" in the AST tree seems + # to be the only way to distinguish this case from two independent + # structs. See test_struct_with_two_usages. + try: + return self._structnode2type[type] + except KeyError: + pass + # + # Note that this must handle parsing "struct foo" any number of + # times and always return the same StructType object. Additionally, + # one of these times (not necessarily the first), the fields of + # the struct can be specified with "struct foo { ...fields... }". + # If no name is given, then we have to create a new anonymous struct + # with no caching; in this case, the fields are either specified + # right now or never. + # + force_name = name + name = type.name + # + # get the type or create it if needed + if name is None: + # 'force_name' is used to guess a more readable name for + # anonymous structs, for the common case "typedef struct { } foo". + if force_name is not None: + explicit_name = '$%s' % force_name + else: + self._anonymous_counter += 1 + explicit_name = '$%d' % self._anonymous_counter + tp = None + else: + explicit_name = name + key = '%s %s' % (kind, name) + tp, _ = self._declarations.get(key, (None, None)) + # + if tp is None: + if kind == 'struct': + tp = model.StructType(explicit_name, None, None, None) + elif kind == 'union': + tp = model.UnionType(explicit_name, None, None, None) + elif kind == 'enum': + if explicit_name == '__dotdotdot__': + raise CDefError("Enums cannot be declared with ...") + tp = self._build_enum_type(explicit_name, type.values) + else: + raise AssertionError("kind = %r" % (kind,)) + if name is not None: + self._declare(key, tp) + else: + if kind == 'enum' and type.values is not None: + raise NotImplementedError( + "enum %s: the '{}' declaration should appear on the first " + "time the enum is mentioned, not later" % explicit_name) + if not tp.forcename: + tp.force_the_name(force_name) + if tp.forcename and '$' in tp.name: + self._declare('anonymous %s' % tp.forcename, tp) + # + self._structnode2type[type] = tp + # + # enums: done here + if kind == 'enum': + return tp + # + # is there a 'type.decls'? If yes, then this is the place in the + # C sources that declare the fields. If no, then just return the + # existing type, possibly still incomplete. + if type.decls is None: + return tp + # + if tp.fldnames is not None: + raise CDefError("duplicate declaration of struct %s" % name) + fldnames = [] + fldtypes = [] + fldbitsize = [] + fldquals = [] + for decl in type.decls: + if (isinstance(decl.type, pycparser.c_ast.IdentifierType) and + ''.join(decl.type.names) == '__dotdotdot__'): + # XXX pycparser is inconsistent: 'names' should be a list + # of strings, but is sometimes just one string. Use + # str.join() as a way to cope with both. + self._make_partial(tp, nested) + continue + if decl.bitsize is None: + bitsize = -1 + else: + bitsize = self._parse_constant(decl.bitsize) + self._partial_length = False + type, fqual = self._get_type_and_quals(decl.type, + partial_length_ok=True) + if self._partial_length: + self._make_partial(tp, nested) + if isinstance(type, model.StructType) and type.partial: + self._make_partial(tp, nested) + fldnames.append(decl.name or '') + fldtypes.append(type) + fldbitsize.append(bitsize) + fldquals.append(fqual) + tp.fldnames = tuple(fldnames) + tp.fldtypes = tuple(fldtypes) + tp.fldbitsize = tuple(fldbitsize) + tp.fldquals = tuple(fldquals) + if fldbitsize != [-1] * len(fldbitsize): + if isinstance(tp, model.StructType) and tp.partial: + raise NotImplementedError("%s: using both bitfields and '...;'" + % (tp,)) + tp.packed = self._options.get('packed') + if tp.completed: # must be re-completed: it is not opaque any more + tp.completed = 0 + self._recomplete.append(tp) + return tp + + def _make_partial(self, tp, nested): + if not isinstance(tp, model.StructOrUnion): + raise CDefError("%s cannot be partial" % (tp,)) + if not tp.has_c_name() and not nested: + raise NotImplementedError("%s is partial but has no C name" %(tp,)) + tp.partial = True + + def _parse_constant(self, exprnode, partial_length_ok=False): + # for now, limited to expressions that are an immediate number + # or positive/negative number + if isinstance(exprnode, pycparser.c_ast.Constant): + s = exprnode.value + if '0' <= s[0] <= '9': + s = s.rstrip('uUlL') + try: + if s.startswith('0'): + return int(s, 8) + else: + return int(s, 10) + except ValueError: + if len(s) > 1: + if s.lower()[0:2] == '0x': + return int(s, 16) + elif s.lower()[0:2] == '0b': + return int(s, 2) + raise CDefError("invalid constant %r" % (s,)) + elif s[0] == "'" and s[-1] == "'" and ( + len(s) == 3 or (len(s) == 4 and s[1] == "\\")): + return ord(s[-2]) + else: + raise CDefError("invalid constant %r" % (s,)) + # + if (isinstance(exprnode, pycparser.c_ast.UnaryOp) and + exprnode.op == '+'): + return self._parse_constant(exprnode.expr) + # + if (isinstance(exprnode, pycparser.c_ast.UnaryOp) and + exprnode.op == '-'): + return -self._parse_constant(exprnode.expr) + # load previously defined int constant + if (isinstance(exprnode, pycparser.c_ast.ID) and + exprnode.name in self._int_constants): + return self._int_constants[exprnode.name] + # + if (isinstance(exprnode, pycparser.c_ast.ID) and + exprnode.name == '__dotdotdotarray__'): + if partial_length_ok: + self._partial_length = True + return '...' + raise FFIError(":%d: unsupported '[...]' here, cannot derive " + "the actual array length in this context" + % exprnode.coord.line) + # + if isinstance(exprnode, pycparser.c_ast.BinaryOp): + left = self._parse_constant(exprnode.left) + right = self._parse_constant(exprnode.right) + if exprnode.op == '+': + return left + right + elif exprnode.op == '-': + return left - right + elif exprnode.op == '*': + return left * right + elif exprnode.op == '/': + return self._c_div(left, right) + elif exprnode.op == '%': + return left - self._c_div(left, right) * right + elif exprnode.op == '<<': + return left << right + elif exprnode.op == '>>': + return left >> right + elif exprnode.op == '&': + return left & right + elif exprnode.op == '|': + return left | right + elif exprnode.op == '^': + return left ^ right + # + raise FFIError(":%d: unsupported expression: expected a " + "simple numeric constant" % exprnode.coord.line) + + def _c_div(self, a, b): + result = a // b + if ((a < 0) ^ (b < 0)) and (a % b) != 0: + result += 1 + return result + + def _build_enum_type(self, explicit_name, decls): + if decls is not None: + partial = False + enumerators = [] + enumvalues = [] + nextenumvalue = 0 + for enum in decls.enumerators: + if _r_enum_dotdotdot.match(enum.name): + partial = True + continue + if enum.value is not None: + nextenumvalue = self._parse_constant(enum.value) + enumerators.append(enum.name) + enumvalues.append(nextenumvalue) + self._add_constants(enum.name, nextenumvalue) + nextenumvalue += 1 + enumerators = tuple(enumerators) + enumvalues = tuple(enumvalues) + tp = model.EnumType(explicit_name, enumerators, enumvalues) + tp.partial = partial + else: # opaque enum + tp = model.EnumType(explicit_name, (), ()) + return tp + + def include(self, other): + for name, (tp, quals) in other._declarations.items(): + if name.startswith('anonymous $enum_$'): + continue # fix for test_anonymous_enum_include + kind = name.split(' ', 1)[0] + if kind in ('struct', 'union', 'enum', 'anonymous', 'typedef'): + self._declare(name, tp, included=True, quals=quals) + for k, v in other._int_constants.items(): + self._add_constants(k, v) + + def _get_unknown_type(self, decl): + typenames = decl.type.type.names + if typenames == ['__dotdotdot__']: + return model.unknown_type(decl.name) + + if typenames == ['__dotdotdotint__']: + if self._uses_new_feature is None: + self._uses_new_feature = "'typedef int... %s'" % decl.name + return model.UnknownIntegerType(decl.name) + + if typenames == ['__dotdotdotfloat__']: + # note: not for 'long double' so far + if self._uses_new_feature is None: + self._uses_new_feature = "'typedef float... %s'" % decl.name + return model.UnknownFloatType(decl.name) + + raise FFIError(':%d: unsupported usage of "..." in typedef' + % decl.coord.line) + + def _get_unknown_ptr_type(self, decl): + if decl.type.type.type.names == ['__dotdotdot__']: + return model.unknown_ptr_type(decl.name) + raise FFIError(':%d: unsupported usage of "..." in typedef' + % decl.coord.line) diff --git a/.venv/Lib/site-packages/cffi/error.py b/.venv/Lib/site-packages/cffi/error.py new file mode 100644 index 00000000..0a27247c --- /dev/null +++ b/.venv/Lib/site-packages/cffi/error.py @@ -0,0 +1,31 @@ + +class FFIError(Exception): + __module__ = 'cffi' + +class CDefError(Exception): + __module__ = 'cffi' + def __str__(self): + try: + current_decl = self.args[1] + filename = current_decl.coord.file + linenum = current_decl.coord.line + prefix = '%s:%d: ' % (filename, linenum) + except (AttributeError, TypeError, IndexError): + prefix = '' + return '%s%s' % (prefix, self.args[0]) + +class VerificationError(Exception): + """ An error raised when verification fails + """ + __module__ = 'cffi' + +class VerificationMissing(Exception): + """ An error raised when incomplete structures are passed into + cdef, but no verification has been done + """ + __module__ = 'cffi' + +class PkgConfigError(Exception): + """ An error raised for missing modules in pkg-config + """ + __module__ = 'cffi' diff --git a/.venv/Lib/site-packages/cffi/ffiplatform.py b/.venv/Lib/site-packages/cffi/ffiplatform.py new file mode 100644 index 00000000..adca28f1 --- /dev/null +++ b/.venv/Lib/site-packages/cffi/ffiplatform.py @@ -0,0 +1,113 @@ +import sys, os +from .error import VerificationError + + +LIST_OF_FILE_NAMES = ['sources', 'include_dirs', 'library_dirs', + 'extra_objects', 'depends'] + +def get_extension(srcfilename, modname, sources=(), **kwds): + from cffi._shimmed_dist_utils import Extension + allsources = [srcfilename] + for src in sources: + allsources.append(os.path.normpath(src)) + return Extension(name=modname, sources=allsources, **kwds) + +def compile(tmpdir, ext, compiler_verbose=0, debug=None): + """Compile a C extension module using distutils.""" + + saved_environ = os.environ.copy() + try: + outputfilename = _build(tmpdir, ext, compiler_verbose, debug) + outputfilename = os.path.abspath(outputfilename) + finally: + # workaround for a distutils bugs where some env vars can + # become longer and longer every time it is used + for key, value in saved_environ.items(): + if os.environ.get(key) != value: + os.environ[key] = value + return outputfilename + +def _build(tmpdir, ext, compiler_verbose=0, debug=None): + # XXX compact but horrible :-( + from cffi._shimmed_dist_utils import Distribution, CompileError, LinkError, set_threshold, set_verbosity + + dist = Distribution({'ext_modules': [ext]}) + dist.parse_config_files() + options = dist.get_option_dict('build_ext') + if debug is None: + debug = sys.flags.debug + options['debug'] = ('ffiplatform', debug) + options['force'] = ('ffiplatform', True) + options['build_lib'] = ('ffiplatform', tmpdir) + options['build_temp'] = ('ffiplatform', tmpdir) + # + try: + old_level = set_threshold(0) or 0 + try: + set_verbosity(compiler_verbose) + dist.run_command('build_ext') + cmd_obj = dist.get_command_obj('build_ext') + [soname] = cmd_obj.get_outputs() + finally: + set_threshold(old_level) + except (CompileError, LinkError) as e: + raise VerificationError('%s: %s' % (e.__class__.__name__, e)) + # + return soname + +try: + from os.path import samefile +except ImportError: + def samefile(f1, f2): + return os.path.abspath(f1) == os.path.abspath(f2) + +def maybe_relative_path(path): + if not os.path.isabs(path): + return path # already relative + dir = path + names = [] + while True: + prevdir = dir + dir, name = os.path.split(prevdir) + if dir == prevdir or not dir: + return path # failed to make it relative + names.append(name) + try: + if samefile(dir, os.curdir): + names.reverse() + return os.path.join(*names) + except OSError: + pass + +# ____________________________________________________________ + +try: + int_or_long = (int, long) + import cStringIO +except NameError: + int_or_long = int # Python 3 + import io as cStringIO + +def _flatten(x, f): + if isinstance(x, str): + f.write('%ds%s' % (len(x), x)) + elif isinstance(x, dict): + keys = sorted(x.keys()) + f.write('%dd' % len(keys)) + for key in keys: + _flatten(key, f) + _flatten(x[key], f) + elif isinstance(x, (list, tuple)): + f.write('%dl' % len(x)) + for value in x: + _flatten(value, f) + elif isinstance(x, int_or_long): + f.write('%di' % (x,)) + else: + raise TypeError( + "the keywords to verify() contains unsupported object %r" % (x,)) + +def flatten(x): + f = cStringIO.StringIO() + _flatten(x, f) + return f.getvalue() diff --git a/.venv/Lib/site-packages/cffi/lock.py b/.venv/Lib/site-packages/cffi/lock.py new file mode 100644 index 00000000..db91b715 --- /dev/null +++ b/.venv/Lib/site-packages/cffi/lock.py @@ -0,0 +1,30 @@ +import sys + +if sys.version_info < (3,): + try: + from thread import allocate_lock + except ImportError: + from dummy_thread import allocate_lock +else: + try: + from _thread import allocate_lock + except ImportError: + from _dummy_thread import allocate_lock + + +##import sys +##l1 = allocate_lock + +##class allocate_lock(object): +## def __init__(self): +## self._real = l1() +## def __enter__(self): +## for i in range(4, 0, -1): +## print sys._getframe(i).f_code +## print +## return self._real.__enter__() +## def __exit__(self, *args): +## return self._real.__exit__(*args) +## def acquire(self, f): +## assert f is False +## return self._real.acquire(f) diff --git a/.venv/Lib/site-packages/cffi/model.py b/.venv/Lib/site-packages/cffi/model.py new file mode 100644 index 00000000..1708f43d --- /dev/null +++ b/.venv/Lib/site-packages/cffi/model.py @@ -0,0 +1,618 @@ +import types +import weakref + +from .lock import allocate_lock +from .error import CDefError, VerificationError, VerificationMissing + +# type qualifiers +Q_CONST = 0x01 +Q_RESTRICT = 0x02 +Q_VOLATILE = 0x04 + +def qualify(quals, replace_with): + if quals & Q_CONST: + replace_with = ' const ' + replace_with.lstrip() + if quals & Q_VOLATILE: + replace_with = ' volatile ' + replace_with.lstrip() + if quals & Q_RESTRICT: + # It seems that __restrict is supported by gcc and msvc. + # If you hit some different compiler, add a #define in + # _cffi_include.h for it (and in its copies, documented there) + replace_with = ' __restrict ' + replace_with.lstrip() + return replace_with + + +class BaseTypeByIdentity(object): + is_array_type = False + is_raw_function = False + + def get_c_name(self, replace_with='', context='a C file', quals=0): + result = self.c_name_with_marker + assert result.count('&') == 1 + # some logic duplication with ffi.getctype()... :-( + replace_with = replace_with.strip() + if replace_with: + if replace_with.startswith('*') and '&[' in result: + replace_with = '(%s)' % replace_with + elif not replace_with[0] in '[(': + replace_with = ' ' + replace_with + replace_with = qualify(quals, replace_with) + result = result.replace('&', replace_with) + if '$' in result: + raise VerificationError( + "cannot generate '%s' in %s: unknown type name" + % (self._get_c_name(), context)) + return result + + def _get_c_name(self): + return self.c_name_with_marker.replace('&', '') + + def has_c_name(self): + return '$' not in self._get_c_name() + + def is_integer_type(self): + return False + + def get_cached_btype(self, ffi, finishlist, can_delay=False): + try: + BType = ffi._cached_btypes[self] + except KeyError: + BType = self.build_backend_type(ffi, finishlist) + BType2 = ffi._cached_btypes.setdefault(self, BType) + assert BType2 is BType + return BType + + def __repr__(self): + return '<%s>' % (self._get_c_name(),) + + def _get_items(self): + return [(name, getattr(self, name)) for name in self._attrs_] + + +class BaseType(BaseTypeByIdentity): + + def __eq__(self, other): + return (self.__class__ == other.__class__ and + self._get_items() == other._get_items()) + + def __ne__(self, other): + return not self == other + + def __hash__(self): + return hash((self.__class__, tuple(self._get_items()))) + + +class VoidType(BaseType): + _attrs_ = () + + def __init__(self): + self.c_name_with_marker = 'void&' + + def build_backend_type(self, ffi, finishlist): + return global_cache(self, ffi, 'new_void_type') + +void_type = VoidType() + + +class BasePrimitiveType(BaseType): + def is_complex_type(self): + return False + + +class PrimitiveType(BasePrimitiveType): + _attrs_ = ('name',) + + ALL_PRIMITIVE_TYPES = { + 'char': 'c', + 'short': 'i', + 'int': 'i', + 'long': 'i', + 'long long': 'i', + 'signed char': 'i', + 'unsigned char': 'i', + 'unsigned short': 'i', + 'unsigned int': 'i', + 'unsigned long': 'i', + 'unsigned long long': 'i', + 'float': 'f', + 'double': 'f', + 'long double': 'f', + 'float _Complex': 'j', + 'double _Complex': 'j', + '_Bool': 'i', + # the following types are not primitive in the C sense + 'wchar_t': 'c', + 'char16_t': 'c', + 'char32_t': 'c', + 'int8_t': 'i', + 'uint8_t': 'i', + 'int16_t': 'i', + 'uint16_t': 'i', + 'int32_t': 'i', + 'uint32_t': 'i', + 'int64_t': 'i', + 'uint64_t': 'i', + 'int_least8_t': 'i', + 'uint_least8_t': 'i', + 'int_least16_t': 'i', + 'uint_least16_t': 'i', + 'int_least32_t': 'i', + 'uint_least32_t': 'i', + 'int_least64_t': 'i', + 'uint_least64_t': 'i', + 'int_fast8_t': 'i', + 'uint_fast8_t': 'i', + 'int_fast16_t': 'i', + 'uint_fast16_t': 'i', + 'int_fast32_t': 'i', + 'uint_fast32_t': 'i', + 'int_fast64_t': 'i', + 'uint_fast64_t': 'i', + 'intptr_t': 'i', + 'uintptr_t': 'i', + 'intmax_t': 'i', + 'uintmax_t': 'i', + 'ptrdiff_t': 'i', + 'size_t': 'i', + 'ssize_t': 'i', + } + + def __init__(self, name): + assert name in self.ALL_PRIMITIVE_TYPES + self.name = name + self.c_name_with_marker = name + '&' + + def is_char_type(self): + return self.ALL_PRIMITIVE_TYPES[self.name] == 'c' + def is_integer_type(self): + return self.ALL_PRIMITIVE_TYPES[self.name] == 'i' + def is_float_type(self): + return self.ALL_PRIMITIVE_TYPES[self.name] == 'f' + def is_complex_type(self): + return self.ALL_PRIMITIVE_TYPES[self.name] == 'j' + + def build_backend_type(self, ffi, finishlist): + return global_cache(self, ffi, 'new_primitive_type', self.name) + + +class UnknownIntegerType(BasePrimitiveType): + _attrs_ = ('name',) + + def __init__(self, name): + self.name = name + self.c_name_with_marker = name + '&' + + def is_integer_type(self): + return True + + def build_backend_type(self, ffi, finishlist): + raise NotImplementedError("integer type '%s' can only be used after " + "compilation" % self.name) + +class UnknownFloatType(BasePrimitiveType): + _attrs_ = ('name', ) + + def __init__(self, name): + self.name = name + self.c_name_with_marker = name + '&' + + def build_backend_type(self, ffi, finishlist): + raise NotImplementedError("float type '%s' can only be used after " + "compilation" % self.name) + + +class BaseFunctionType(BaseType): + _attrs_ = ('args', 'result', 'ellipsis', 'abi') + + def __init__(self, args, result, ellipsis, abi=None): + self.args = args + self.result = result + self.ellipsis = ellipsis + self.abi = abi + # + reprargs = [arg._get_c_name() for arg in self.args] + if self.ellipsis: + reprargs.append('...') + reprargs = reprargs or ['void'] + replace_with = self._base_pattern % (', '.join(reprargs),) + if abi is not None: + replace_with = replace_with[:1] + abi + ' ' + replace_with[1:] + self.c_name_with_marker = ( + self.result.c_name_with_marker.replace('&', replace_with)) + + +class RawFunctionType(BaseFunctionType): + # Corresponds to a C type like 'int(int)', which is the C type of + # a function, but not a pointer-to-function. The backend has no + # notion of such a type; it's used temporarily by parsing. + _base_pattern = '(&)(%s)' + is_raw_function = True + + def build_backend_type(self, ffi, finishlist): + raise CDefError("cannot render the type %r: it is a function " + "type, not a pointer-to-function type" % (self,)) + + def as_function_pointer(self): + return FunctionPtrType(self.args, self.result, self.ellipsis, self.abi) + + +class FunctionPtrType(BaseFunctionType): + _base_pattern = '(*&)(%s)' + + def build_backend_type(self, ffi, finishlist): + result = self.result.get_cached_btype(ffi, finishlist) + args = [] + for tp in self.args: + args.append(tp.get_cached_btype(ffi, finishlist)) + abi_args = () + if self.abi == "__stdcall": + if not self.ellipsis: # __stdcall ignored for variadic funcs + try: + abi_args = (ffi._backend.FFI_STDCALL,) + except AttributeError: + pass + return global_cache(self, ffi, 'new_function_type', + tuple(args), result, self.ellipsis, *abi_args) + + def as_raw_function(self): + return RawFunctionType(self.args, self.result, self.ellipsis, self.abi) + + +class PointerType(BaseType): + _attrs_ = ('totype', 'quals') + + def __init__(self, totype, quals=0): + self.totype = totype + self.quals = quals + extra = " *&" + if totype.is_array_type: + extra = "(%s)" % (extra.lstrip(),) + extra = qualify(quals, extra) + self.c_name_with_marker = totype.c_name_with_marker.replace('&', extra) + + def build_backend_type(self, ffi, finishlist): + BItem = self.totype.get_cached_btype(ffi, finishlist, can_delay=True) + return global_cache(self, ffi, 'new_pointer_type', BItem) + +voidp_type = PointerType(void_type) + +def ConstPointerType(totype): + return PointerType(totype, Q_CONST) + +const_voidp_type = ConstPointerType(void_type) + + +class NamedPointerType(PointerType): + _attrs_ = ('totype', 'name') + + def __init__(self, totype, name, quals=0): + PointerType.__init__(self, totype, quals) + self.name = name + self.c_name_with_marker = name + '&' + + +class ArrayType(BaseType): + _attrs_ = ('item', 'length') + is_array_type = True + + def __init__(self, item, length): + self.item = item + self.length = length + # + if length is None: + brackets = '&[]' + elif length == '...': + brackets = '&[/*...*/]' + else: + brackets = '&[%s]' % length + self.c_name_with_marker = ( + self.item.c_name_with_marker.replace('&', brackets)) + + def length_is_unknown(self): + return isinstance(self.length, str) + + def resolve_length(self, newlength): + return ArrayType(self.item, newlength) + + def build_backend_type(self, ffi, finishlist): + if self.length_is_unknown(): + raise CDefError("cannot render the type %r: unknown length" % + (self,)) + self.item.get_cached_btype(ffi, finishlist) # force the item BType + BPtrItem = PointerType(self.item).get_cached_btype(ffi, finishlist) + return global_cache(self, ffi, 'new_array_type', BPtrItem, self.length) + +char_array_type = ArrayType(PrimitiveType('char'), None) + + +class StructOrUnionOrEnum(BaseTypeByIdentity): + _attrs_ = ('name',) + forcename = None + + def build_c_name_with_marker(self): + name = self.forcename or '%s %s' % (self.kind, self.name) + self.c_name_with_marker = name + '&' + + def force_the_name(self, forcename): + self.forcename = forcename + self.build_c_name_with_marker() + + def get_official_name(self): + assert self.c_name_with_marker.endswith('&') + return self.c_name_with_marker[:-1] + + +class StructOrUnion(StructOrUnionOrEnum): + fixedlayout = None + completed = 0 + partial = False + packed = 0 + + def __init__(self, name, fldnames, fldtypes, fldbitsize, fldquals=None): + self.name = name + self.fldnames = fldnames + self.fldtypes = fldtypes + self.fldbitsize = fldbitsize + self.fldquals = fldquals + self.build_c_name_with_marker() + + def anonymous_struct_fields(self): + if self.fldtypes is not None: + for name, type in zip(self.fldnames, self.fldtypes): + if name == '' and isinstance(type, StructOrUnion): + yield type + + def enumfields(self, expand_anonymous_struct_union=True): + fldquals = self.fldquals + if fldquals is None: + fldquals = (0,) * len(self.fldnames) + for name, type, bitsize, quals in zip(self.fldnames, self.fldtypes, + self.fldbitsize, fldquals): + if (name == '' and isinstance(type, StructOrUnion) + and expand_anonymous_struct_union): + # nested anonymous struct/union + for result in type.enumfields(): + yield result + else: + yield (name, type, bitsize, quals) + + def force_flatten(self): + # force the struct or union to have a declaration that lists + # directly all fields returned by enumfields(), flattening + # nested anonymous structs/unions. + names = [] + types = [] + bitsizes = [] + fldquals = [] + for name, type, bitsize, quals in self.enumfields(): + names.append(name) + types.append(type) + bitsizes.append(bitsize) + fldquals.append(quals) + self.fldnames = tuple(names) + self.fldtypes = tuple(types) + self.fldbitsize = tuple(bitsizes) + self.fldquals = tuple(fldquals) + + def get_cached_btype(self, ffi, finishlist, can_delay=False): + BType = StructOrUnionOrEnum.get_cached_btype(self, ffi, finishlist, + can_delay) + if not can_delay: + self.finish_backend_type(ffi, finishlist) + return BType + + def finish_backend_type(self, ffi, finishlist): + if self.completed: + if self.completed != 2: + raise NotImplementedError("recursive structure declaration " + "for '%s'" % (self.name,)) + return + BType = ffi._cached_btypes[self] + # + self.completed = 1 + # + if self.fldtypes is None: + pass # not completing it: it's an opaque struct + # + elif self.fixedlayout is None: + fldtypes = [tp.get_cached_btype(ffi, finishlist) + for tp in self.fldtypes] + lst = list(zip(self.fldnames, fldtypes, self.fldbitsize)) + extra_flags = () + if self.packed: + if self.packed == 1: + extra_flags = (8,) # SF_PACKED + else: + extra_flags = (0, self.packed) + ffi._backend.complete_struct_or_union(BType, lst, self, + -1, -1, *extra_flags) + # + else: + fldtypes = [] + fieldofs, fieldsize, totalsize, totalalignment = self.fixedlayout + for i in range(len(self.fldnames)): + fsize = fieldsize[i] + ftype = self.fldtypes[i] + # + if isinstance(ftype, ArrayType) and ftype.length_is_unknown(): + # fix the length to match the total size + BItemType = ftype.item.get_cached_btype(ffi, finishlist) + nlen, nrest = divmod(fsize, ffi.sizeof(BItemType)) + if nrest != 0: + self._verification_error( + "field '%s.%s' has a bogus size?" % ( + self.name, self.fldnames[i] or '{}')) + ftype = ftype.resolve_length(nlen) + self.fldtypes = (self.fldtypes[:i] + (ftype,) + + self.fldtypes[i+1:]) + # + BFieldType = ftype.get_cached_btype(ffi, finishlist) + if isinstance(ftype, ArrayType) and ftype.length is None: + assert fsize == 0 + else: + bitemsize = ffi.sizeof(BFieldType) + if bitemsize != fsize: + self._verification_error( + "field '%s.%s' is declared as %d bytes, but is " + "really %d bytes" % (self.name, + self.fldnames[i] or '{}', + bitemsize, fsize)) + fldtypes.append(BFieldType) + # + lst = list(zip(self.fldnames, fldtypes, self.fldbitsize, fieldofs)) + ffi._backend.complete_struct_or_union(BType, lst, self, + totalsize, totalalignment) + self.completed = 2 + + def _verification_error(self, msg): + raise VerificationError(msg) + + def check_not_partial(self): + if self.partial and self.fixedlayout is None: + raise VerificationMissing(self._get_c_name()) + + def build_backend_type(self, ffi, finishlist): + self.check_not_partial() + finishlist.append(self) + # + return global_cache(self, ffi, 'new_%s_type' % self.kind, + self.get_official_name(), key=self) + + +class StructType(StructOrUnion): + kind = 'struct' + + +class UnionType(StructOrUnion): + kind = 'union' + + +class EnumType(StructOrUnionOrEnum): + kind = 'enum' + partial = False + partial_resolved = False + + def __init__(self, name, enumerators, enumvalues, baseinttype=None): + self.name = name + self.enumerators = enumerators + self.enumvalues = enumvalues + self.baseinttype = baseinttype + self.build_c_name_with_marker() + + def force_the_name(self, forcename): + StructOrUnionOrEnum.force_the_name(self, forcename) + if self.forcename is None: + name = self.get_official_name() + self.forcename = '$' + name.replace(' ', '_') + + def check_not_partial(self): + if self.partial and not self.partial_resolved: + raise VerificationMissing(self._get_c_name()) + + def build_backend_type(self, ffi, finishlist): + self.check_not_partial() + base_btype = self.build_baseinttype(ffi, finishlist) + return global_cache(self, ffi, 'new_enum_type', + self.get_official_name(), + self.enumerators, self.enumvalues, + base_btype, key=self) + + def build_baseinttype(self, ffi, finishlist): + if self.baseinttype is not None: + return self.baseinttype.get_cached_btype(ffi, finishlist) + # + if self.enumvalues: + smallest_value = min(self.enumvalues) + largest_value = max(self.enumvalues) + else: + import warnings + try: + # XXX! The goal is to ensure that the warnings.warn() + # will not suppress the warning. We want to get it + # several times if we reach this point several times. + __warningregistry__.clear() + except NameError: + pass + warnings.warn("%r has no values explicitly defined; " + "guessing that it is equivalent to 'unsigned int'" + % self._get_c_name()) + smallest_value = largest_value = 0 + if smallest_value < 0: # needs a signed type + sign = 1 + candidate1 = PrimitiveType("int") + candidate2 = PrimitiveType("long") + else: + sign = 0 + candidate1 = PrimitiveType("unsigned int") + candidate2 = PrimitiveType("unsigned long") + btype1 = candidate1.get_cached_btype(ffi, finishlist) + btype2 = candidate2.get_cached_btype(ffi, finishlist) + size1 = ffi.sizeof(btype1) + size2 = ffi.sizeof(btype2) + if (smallest_value >= ((-1) << (8*size1-1)) and + largest_value < (1 << (8*size1-sign))): + return btype1 + if (smallest_value >= ((-1) << (8*size2-1)) and + largest_value < (1 << (8*size2-sign))): + return btype2 + raise CDefError("%s values don't all fit into either 'long' " + "or 'unsigned long'" % self._get_c_name()) + +def unknown_type(name, structname=None): + if structname is None: + structname = '$%s' % name + tp = StructType(structname, None, None, None) + tp.force_the_name(name) + tp.origin = "unknown_type" + return tp + +def unknown_ptr_type(name, structname=None): + if structname is None: + structname = '$$%s' % name + tp = StructType(structname, None, None, None) + return NamedPointerType(tp, name) + + +global_lock = allocate_lock() +_typecache_cffi_backend = weakref.WeakValueDictionary() + +def get_typecache(backend): + # returns _typecache_cffi_backend if backend is the _cffi_backend + # module, or type(backend).__typecache if backend is an instance of + # CTypesBackend (or some FakeBackend class during tests) + if isinstance(backend, types.ModuleType): + return _typecache_cffi_backend + with global_lock: + if not hasattr(type(backend), '__typecache'): + type(backend).__typecache = weakref.WeakValueDictionary() + return type(backend).__typecache + +def global_cache(srctype, ffi, funcname, *args, **kwds): + key = kwds.pop('key', (funcname, args)) + assert not kwds + try: + return ffi._typecache[key] + except KeyError: + pass + try: + res = getattr(ffi._backend, funcname)(*args) + except NotImplementedError as e: + raise NotImplementedError("%s: %r: %s" % (funcname, srctype, e)) + # note that setdefault() on WeakValueDictionary is not atomic + # and contains a rare bug (http://bugs.python.org/issue19542); + # we have to use a lock and do it ourselves + cache = ffi._typecache + with global_lock: + res1 = cache.get(key) + if res1 is None: + cache[key] = res + return res + else: + return res1 + +def pointer_cache(ffi, BType): + return global_cache('?', ffi, 'new_pointer_type', BType) + +def attach_exception_info(e, name): + if e.args and type(e.args[0]) is str: + e.args = ('%s: %s' % (name, e.args[0]),) + e.args[1:] diff --git a/.venv/Lib/site-packages/cffi/parse_c_type.h b/.venv/Lib/site-packages/cffi/parse_c_type.h new file mode 100644 index 00000000..84e4ef85 --- /dev/null +++ b/.venv/Lib/site-packages/cffi/parse_c_type.h @@ -0,0 +1,181 @@ + +/* This part is from file 'cffi/parse_c_type.h'. It is copied at the + beginning of C sources generated by CFFI's ffi.set_source(). */ + +typedef void *_cffi_opcode_t; + +#define _CFFI_OP(opcode, arg) (_cffi_opcode_t)(opcode | (((uintptr_t)(arg)) << 8)) +#define _CFFI_GETOP(cffi_opcode) ((unsigned char)(uintptr_t)cffi_opcode) +#define _CFFI_GETARG(cffi_opcode) (((intptr_t)cffi_opcode) >> 8) + +#define _CFFI_OP_PRIMITIVE 1 +#define _CFFI_OP_POINTER 3 +#define _CFFI_OP_ARRAY 5 +#define _CFFI_OP_OPEN_ARRAY 7 +#define _CFFI_OP_STRUCT_UNION 9 +#define _CFFI_OP_ENUM 11 +#define _CFFI_OP_FUNCTION 13 +#define _CFFI_OP_FUNCTION_END 15 +#define _CFFI_OP_NOOP 17 +#define _CFFI_OP_BITFIELD 19 +#define _CFFI_OP_TYPENAME 21 +#define _CFFI_OP_CPYTHON_BLTN_V 23 // varargs +#define _CFFI_OP_CPYTHON_BLTN_N 25 // noargs +#define _CFFI_OP_CPYTHON_BLTN_O 27 // O (i.e. a single arg) +#define _CFFI_OP_CONSTANT 29 +#define _CFFI_OP_CONSTANT_INT 31 +#define _CFFI_OP_GLOBAL_VAR 33 +#define _CFFI_OP_DLOPEN_FUNC 35 +#define _CFFI_OP_DLOPEN_CONST 37 +#define _CFFI_OP_GLOBAL_VAR_F 39 +#define _CFFI_OP_EXTERN_PYTHON 41 + +#define _CFFI_PRIM_VOID 0 +#define _CFFI_PRIM_BOOL 1 +#define _CFFI_PRIM_CHAR 2 +#define _CFFI_PRIM_SCHAR 3 +#define _CFFI_PRIM_UCHAR 4 +#define _CFFI_PRIM_SHORT 5 +#define _CFFI_PRIM_USHORT 6 +#define _CFFI_PRIM_INT 7 +#define _CFFI_PRIM_UINT 8 +#define _CFFI_PRIM_LONG 9 +#define _CFFI_PRIM_ULONG 10 +#define _CFFI_PRIM_LONGLONG 11 +#define _CFFI_PRIM_ULONGLONG 12 +#define _CFFI_PRIM_FLOAT 13 +#define _CFFI_PRIM_DOUBLE 14 +#define _CFFI_PRIM_LONGDOUBLE 15 + +#define _CFFI_PRIM_WCHAR 16 +#define _CFFI_PRIM_INT8 17 +#define _CFFI_PRIM_UINT8 18 +#define _CFFI_PRIM_INT16 19 +#define _CFFI_PRIM_UINT16 20 +#define _CFFI_PRIM_INT32 21 +#define _CFFI_PRIM_UINT32 22 +#define _CFFI_PRIM_INT64 23 +#define _CFFI_PRIM_UINT64 24 +#define _CFFI_PRIM_INTPTR 25 +#define _CFFI_PRIM_UINTPTR 26 +#define _CFFI_PRIM_PTRDIFF 27 +#define _CFFI_PRIM_SIZE 28 +#define _CFFI_PRIM_SSIZE 29 +#define _CFFI_PRIM_INT_LEAST8 30 +#define _CFFI_PRIM_UINT_LEAST8 31 +#define _CFFI_PRIM_INT_LEAST16 32 +#define _CFFI_PRIM_UINT_LEAST16 33 +#define _CFFI_PRIM_INT_LEAST32 34 +#define _CFFI_PRIM_UINT_LEAST32 35 +#define _CFFI_PRIM_INT_LEAST64 36 +#define _CFFI_PRIM_UINT_LEAST64 37 +#define _CFFI_PRIM_INT_FAST8 38 +#define _CFFI_PRIM_UINT_FAST8 39 +#define _CFFI_PRIM_INT_FAST16 40 +#define _CFFI_PRIM_UINT_FAST16 41 +#define _CFFI_PRIM_INT_FAST32 42 +#define _CFFI_PRIM_UINT_FAST32 43 +#define _CFFI_PRIM_INT_FAST64 44 +#define _CFFI_PRIM_UINT_FAST64 45 +#define _CFFI_PRIM_INTMAX 46 +#define _CFFI_PRIM_UINTMAX 47 +#define _CFFI_PRIM_FLOATCOMPLEX 48 +#define _CFFI_PRIM_DOUBLECOMPLEX 49 +#define _CFFI_PRIM_CHAR16 50 +#define _CFFI_PRIM_CHAR32 51 + +#define _CFFI__NUM_PRIM 52 +#define _CFFI__UNKNOWN_PRIM (-1) +#define _CFFI__UNKNOWN_FLOAT_PRIM (-2) +#define _CFFI__UNKNOWN_LONG_DOUBLE (-3) + +#define _CFFI__IO_FILE_STRUCT (-1) + + +struct _cffi_global_s { + const char *name; + void *address; + _cffi_opcode_t type_op; + void *size_or_direct_fn; // OP_GLOBAL_VAR: size, or 0 if unknown + // OP_CPYTHON_BLTN_*: addr of direct function +}; + +struct _cffi_getconst_s { + unsigned long long value; + const struct _cffi_type_context_s *ctx; + int gindex; +}; + +struct _cffi_struct_union_s { + const char *name; + int type_index; // -> _cffi_types, on a OP_STRUCT_UNION + int flags; // _CFFI_F_* flags below + size_t size; + int alignment; + int first_field_index; // -> _cffi_fields array + int num_fields; +}; +#define _CFFI_F_UNION 0x01 // is a union, not a struct +#define _CFFI_F_CHECK_FIELDS 0x02 // complain if fields are not in the + // "standard layout" or if some are missing +#define _CFFI_F_PACKED 0x04 // for CHECK_FIELDS, assume a packed struct +#define _CFFI_F_EXTERNAL 0x08 // in some other ffi.include() +#define _CFFI_F_OPAQUE 0x10 // opaque + +struct _cffi_field_s { + const char *name; + size_t field_offset; + size_t field_size; + _cffi_opcode_t field_type_op; +}; + +struct _cffi_enum_s { + const char *name; + int type_index; // -> _cffi_types, on a OP_ENUM + int type_prim; // _CFFI_PRIM_xxx + const char *enumerators; // comma-delimited string +}; + +struct _cffi_typename_s { + const char *name; + int type_index; /* if opaque, points to a possibly artificial + OP_STRUCT which is itself opaque */ +}; + +struct _cffi_type_context_s { + _cffi_opcode_t *types; + const struct _cffi_global_s *globals; + const struct _cffi_field_s *fields; + const struct _cffi_struct_union_s *struct_unions; + const struct _cffi_enum_s *enums; + const struct _cffi_typename_s *typenames; + int num_globals; + int num_struct_unions; + int num_enums; + int num_typenames; + const char *const *includes; + int num_types; + int flags; /* future extension */ +}; + +struct _cffi_parse_info_s { + const struct _cffi_type_context_s *ctx; + _cffi_opcode_t *output; + unsigned int output_size; + size_t error_location; + const char *error_message; +}; + +struct _cffi_externpy_s { + const char *name; + size_t size_of_result; + void *reserved1, *reserved2; +}; + +#ifdef _CFFI_INTERNAL +static int parse_c_type(struct _cffi_parse_info_s *info, const char *input); +static int search_in_globals(const struct _cffi_type_context_s *ctx, + const char *search, size_t search_len); +static int search_in_struct_unions(const struct _cffi_type_context_s *ctx, + const char *search, size_t search_len); +#endif diff --git a/.venv/Lib/site-packages/cffi/pkgconfig.py b/.venv/Lib/site-packages/cffi/pkgconfig.py new file mode 100644 index 00000000..5c93f15a --- /dev/null +++ b/.venv/Lib/site-packages/cffi/pkgconfig.py @@ -0,0 +1,121 @@ +# pkg-config, https://www.freedesktop.org/wiki/Software/pkg-config/ integration for cffi +import sys, os, subprocess + +from .error import PkgConfigError + + +def merge_flags(cfg1, cfg2): + """Merge values from cffi config flags cfg2 to cf1 + + Example: + merge_flags({"libraries": ["one"]}, {"libraries": ["two"]}) + {"libraries": ["one", "two"]} + """ + for key, value in cfg2.items(): + if key not in cfg1: + cfg1[key] = value + else: + if not isinstance(cfg1[key], list): + raise TypeError("cfg1[%r] should be a list of strings" % (key,)) + if not isinstance(value, list): + raise TypeError("cfg2[%r] should be a list of strings" % (key,)) + cfg1[key].extend(value) + return cfg1 + + +def call(libname, flag, encoding=sys.getfilesystemencoding()): + """Calls pkg-config and returns the output if found + """ + a = ["pkg-config", "--print-errors"] + a.append(flag) + a.append(libname) + try: + pc = subprocess.Popen(a, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + except EnvironmentError as e: + raise PkgConfigError("cannot run pkg-config: %s" % (str(e).strip(),)) + + bout, berr = pc.communicate() + if pc.returncode != 0: + try: + berr = berr.decode(encoding) + except Exception: + pass + raise PkgConfigError(berr.strip()) + + if sys.version_info >= (3,) and not isinstance(bout, str): # Python 3.x + try: + bout = bout.decode(encoding) + except UnicodeDecodeError: + raise PkgConfigError("pkg-config %s %s returned bytes that cannot " + "be decoded with encoding %r:\n%r" % + (flag, libname, encoding, bout)) + + if os.altsep != '\\' and '\\' in bout: + raise PkgConfigError("pkg-config %s %s returned an unsupported " + "backslash-escaped output:\n%r" % + (flag, libname, bout)) + return bout + + +def flags_from_pkgconfig(libs): + r"""Return compiler line flags for FFI.set_source based on pkg-config output + + Usage + ... + ffibuilder.set_source("_foo", pkgconfig = ["libfoo", "libbar >= 1.8.3"]) + + If pkg-config is installed on build machine, then arguments include_dirs, + library_dirs, libraries, define_macros, extra_compile_args and + extra_link_args are extended with an output of pkg-config for libfoo and + libbar. + + Raises PkgConfigError in case the pkg-config call fails. + """ + + def get_include_dirs(string): + return [x[2:] for x in string.split() if x.startswith("-I")] + + def get_library_dirs(string): + return [x[2:] for x in string.split() if x.startswith("-L")] + + def get_libraries(string): + return [x[2:] for x in string.split() if x.startswith("-l")] + + # convert -Dfoo=bar to list of tuples [("foo", "bar")] expected by distutils + def get_macros(string): + def _macro(x): + x = x[2:] # drop "-D" + if '=' in x: + return tuple(x.split("=", 1)) # "-Dfoo=bar" => ("foo", "bar") + else: + return (x, None) # "-Dfoo" => ("foo", None) + return [_macro(x) for x in string.split() if x.startswith("-D")] + + def get_other_cflags(string): + return [x for x in string.split() if not x.startswith("-I") and + not x.startswith("-D")] + + def get_other_libs(string): + return [x for x in string.split() if not x.startswith("-L") and + not x.startswith("-l")] + + # return kwargs for given libname + def kwargs(libname): + fse = sys.getfilesystemencoding() + all_cflags = call(libname, "--cflags") + all_libs = call(libname, "--libs") + return { + "include_dirs": get_include_dirs(all_cflags), + "library_dirs": get_library_dirs(all_libs), + "libraries": get_libraries(all_libs), + "define_macros": get_macros(all_cflags), + "extra_compile_args": get_other_cflags(all_cflags), + "extra_link_args": get_other_libs(all_libs), + } + + # merge all arguments together + ret = {} + for libname in libs: + lib_flags = kwargs(libname) + merge_flags(ret, lib_flags) + return ret diff --git a/.venv/Lib/site-packages/cffi/recompiler.py b/.venv/Lib/site-packages/cffi/recompiler.py new file mode 100644 index 00000000..4167bc05 --- /dev/null +++ b/.venv/Lib/site-packages/cffi/recompiler.py @@ -0,0 +1,1581 @@ +import os, sys, io +from . import ffiplatform, model +from .error import VerificationError +from .cffi_opcode import * + +VERSION_BASE = 0x2601 +VERSION_EMBEDDED = 0x2701 +VERSION_CHAR16CHAR32 = 0x2801 + +USE_LIMITED_API = (sys.platform != 'win32' or sys.version_info < (3, 0) or + sys.version_info >= (3, 5)) + + +class GlobalExpr: + def __init__(self, name, address, type_op, size=0, check_value=0): + self.name = name + self.address = address + self.type_op = type_op + self.size = size + self.check_value = check_value + + def as_c_expr(self): + return ' { "%s", (void *)%s, %s, (void *)%s },' % ( + self.name, self.address, self.type_op.as_c_expr(), self.size) + + def as_python_expr(self): + return "b'%s%s',%d" % (self.type_op.as_python_bytes(), self.name, + self.check_value) + +class FieldExpr: + def __init__(self, name, field_offset, field_size, fbitsize, field_type_op): + self.name = name + self.field_offset = field_offset + self.field_size = field_size + self.fbitsize = fbitsize + self.field_type_op = field_type_op + + def as_c_expr(self): + spaces = " " * len(self.name) + return (' { "%s", %s,\n' % (self.name, self.field_offset) + + ' %s %s,\n' % (spaces, self.field_size) + + ' %s %s },' % (spaces, self.field_type_op.as_c_expr())) + + def as_python_expr(self): + raise NotImplementedError + + def as_field_python_expr(self): + if self.field_type_op.op == OP_NOOP: + size_expr = '' + elif self.field_type_op.op == OP_BITFIELD: + size_expr = format_four_bytes(self.fbitsize) + else: + raise NotImplementedError + return "b'%s%s%s'" % (self.field_type_op.as_python_bytes(), + size_expr, + self.name) + +class StructUnionExpr: + def __init__(self, name, type_index, flags, size, alignment, comment, + first_field_index, c_fields): + self.name = name + self.type_index = type_index + self.flags = flags + self.size = size + self.alignment = alignment + self.comment = comment + self.first_field_index = first_field_index + self.c_fields = c_fields + + def as_c_expr(self): + return (' { "%s", %d, %s,' % (self.name, self.type_index, self.flags) + + '\n %s, %s, ' % (self.size, self.alignment) + + '%d, %d ' % (self.first_field_index, len(self.c_fields)) + + ('/* %s */ ' % self.comment if self.comment else '') + + '},') + + def as_python_expr(self): + flags = eval(self.flags, G_FLAGS) + fields_expr = [c_field.as_field_python_expr() + for c_field in self.c_fields] + return "(b'%s%s%s',%s)" % ( + format_four_bytes(self.type_index), + format_four_bytes(flags), + self.name, + ','.join(fields_expr)) + +class EnumExpr: + def __init__(self, name, type_index, size, signed, allenums): + self.name = name + self.type_index = type_index + self.size = size + self.signed = signed + self.allenums = allenums + + def as_c_expr(self): + return (' { "%s", %d, _cffi_prim_int(%s, %s),\n' + ' "%s" },' % (self.name, self.type_index, + self.size, self.signed, self.allenums)) + + def as_python_expr(self): + prim_index = { + (1, 0): PRIM_UINT8, (1, 1): PRIM_INT8, + (2, 0): PRIM_UINT16, (2, 1): PRIM_INT16, + (4, 0): PRIM_UINT32, (4, 1): PRIM_INT32, + (8, 0): PRIM_UINT64, (8, 1): PRIM_INT64, + }[self.size, self.signed] + return "b'%s%s%s\\x00%s'" % (format_four_bytes(self.type_index), + format_four_bytes(prim_index), + self.name, self.allenums) + +class TypenameExpr: + def __init__(self, name, type_index): + self.name = name + self.type_index = type_index + + def as_c_expr(self): + return ' { "%s", %d },' % (self.name, self.type_index) + + def as_python_expr(self): + return "b'%s%s'" % (format_four_bytes(self.type_index), self.name) + + +# ____________________________________________________________ + + +class Recompiler: + _num_externpy = 0 + + def __init__(self, ffi, module_name, target_is_python=False): + self.ffi = ffi + self.module_name = module_name + self.target_is_python = target_is_python + self._version = VERSION_BASE + + def needs_version(self, ver): + self._version = max(self._version, ver) + + def collect_type_table(self): + self._typesdict = {} + self._generate("collecttype") + # + all_decls = sorted(self._typesdict, key=str) + # + # prepare all FUNCTION bytecode sequences first + self.cffi_types = [] + for tp in all_decls: + if tp.is_raw_function: + assert self._typesdict[tp] is None + self._typesdict[tp] = len(self.cffi_types) + self.cffi_types.append(tp) # placeholder + for tp1 in tp.args: + assert isinstance(tp1, (model.VoidType, + model.BasePrimitiveType, + model.PointerType, + model.StructOrUnionOrEnum, + model.FunctionPtrType)) + if self._typesdict[tp1] is None: + self._typesdict[tp1] = len(self.cffi_types) + self.cffi_types.append(tp1) # placeholder + self.cffi_types.append('END') # placeholder + # + # prepare all OTHER bytecode sequences + for tp in all_decls: + if not tp.is_raw_function and self._typesdict[tp] is None: + self._typesdict[tp] = len(self.cffi_types) + self.cffi_types.append(tp) # placeholder + if tp.is_array_type and tp.length is not None: + self.cffi_types.append('LEN') # placeholder + assert None not in self._typesdict.values() + # + # collect all structs and unions and enums + self._struct_unions = {} + self._enums = {} + for tp in all_decls: + if isinstance(tp, model.StructOrUnion): + self._struct_unions[tp] = None + elif isinstance(tp, model.EnumType): + self._enums[tp] = None + for i, tp in enumerate(sorted(self._struct_unions, + key=lambda tp: tp.name)): + self._struct_unions[tp] = i + for i, tp in enumerate(sorted(self._enums, + key=lambda tp: tp.name)): + self._enums[tp] = i + # + # emit all bytecode sequences now + for tp in all_decls: + method = getattr(self, '_emit_bytecode_' + tp.__class__.__name__) + method(tp, self._typesdict[tp]) + # + # consistency check + for op in self.cffi_types: + assert isinstance(op, CffiOp) + self.cffi_types = tuple(self.cffi_types) # don't change any more + + def _enum_fields(self, tp): + # When producing C, expand all anonymous struct/union fields. + # That's necessary to have C code checking the offsets of the + # individual fields contained in them. When producing Python, + # don't do it and instead write it like it is, with the + # corresponding fields having an empty name. Empty names are + # recognized at runtime when we import the generated Python + # file. + expand_anonymous_struct_union = not self.target_is_python + return tp.enumfields(expand_anonymous_struct_union) + + def _do_collect_type(self, tp): + if not isinstance(tp, model.BaseTypeByIdentity): + if isinstance(tp, tuple): + for x in tp: + self._do_collect_type(x) + return + if tp not in self._typesdict: + self._typesdict[tp] = None + if isinstance(tp, model.FunctionPtrType): + self._do_collect_type(tp.as_raw_function()) + elif isinstance(tp, model.StructOrUnion): + if tp.fldtypes is not None and ( + tp not in self.ffi._parser._included_declarations): + for name1, tp1, _, _ in self._enum_fields(tp): + self._do_collect_type(self._field_type(tp, name1, tp1)) + else: + for _, x in tp._get_items(): + self._do_collect_type(x) + + def _generate(self, step_name): + lst = self.ffi._parser._declarations.items() + for name, (tp, quals) in sorted(lst): + kind, realname = name.split(' ', 1) + try: + method = getattr(self, '_generate_cpy_%s_%s' % (kind, + step_name)) + except AttributeError: + raise VerificationError( + "not implemented in recompile(): %r" % name) + try: + self._current_quals = quals + method(tp, realname) + except Exception as e: + model.attach_exception_info(e, name) + raise + + # ---------- + + ALL_STEPS = ["global", "field", "struct_union", "enum", "typename"] + + def collect_step_tables(self): + # collect the declarations for '_cffi_globals', '_cffi_typenames', etc. + self._lsts = {} + for step_name in self.ALL_STEPS: + self._lsts[step_name] = [] + self._seen_struct_unions = set() + self._generate("ctx") + self._add_missing_struct_unions() + # + for step_name in self.ALL_STEPS: + lst = self._lsts[step_name] + if step_name != "field": + lst.sort(key=lambda entry: entry.name) + self._lsts[step_name] = tuple(lst) # don't change any more + # + # check for a possible internal inconsistency: _cffi_struct_unions + # should have been generated with exactly self._struct_unions + lst = self._lsts["struct_union"] + for tp, i in self._struct_unions.items(): + assert i < len(lst) + assert lst[i].name == tp.name + assert len(lst) == len(self._struct_unions) + # same with enums + lst = self._lsts["enum"] + for tp, i in self._enums.items(): + assert i < len(lst) + assert lst[i].name == tp.name + assert len(lst) == len(self._enums) + + # ---------- + + def _prnt(self, what=''): + self._f.write(what + '\n') + + def write_source_to_f(self, f, preamble): + if self.target_is_python: + assert preamble is None + self.write_py_source_to_f(f) + else: + assert preamble is not None + self.write_c_source_to_f(f, preamble) + + def _rel_readlines(self, filename): + g = open(os.path.join(os.path.dirname(__file__), filename), 'r') + lines = g.readlines() + g.close() + return lines + + def write_c_source_to_f(self, f, preamble): + self._f = f + prnt = self._prnt + if self.ffi._embedding is not None: + prnt('#define _CFFI_USE_EMBEDDING') + if not USE_LIMITED_API: + prnt('#define _CFFI_NO_LIMITED_API') + # + # first the '#include' (actually done by inlining the file's content) + lines = self._rel_readlines('_cffi_include.h') + i = lines.index('#include "parse_c_type.h"\n') + lines[i:i+1] = self._rel_readlines('parse_c_type.h') + prnt(''.join(lines)) + # + # if we have ffi._embedding != None, we give it here as a macro + # and include an extra file + base_module_name = self.module_name.split('.')[-1] + if self.ffi._embedding is not None: + prnt('#define _CFFI_MODULE_NAME "%s"' % (self.module_name,)) + prnt('static const char _CFFI_PYTHON_STARTUP_CODE[] = {') + self._print_string_literal_in_array(self.ffi._embedding) + prnt('0 };') + prnt('#ifdef PYPY_VERSION') + prnt('# define _CFFI_PYTHON_STARTUP_FUNC _cffi_pypyinit_%s' % ( + base_module_name,)) + prnt('#elif PY_MAJOR_VERSION >= 3') + prnt('# define _CFFI_PYTHON_STARTUP_FUNC PyInit_%s' % ( + base_module_name,)) + prnt('#else') + prnt('# define _CFFI_PYTHON_STARTUP_FUNC init%s' % ( + base_module_name,)) + prnt('#endif') + lines = self._rel_readlines('_embedding.h') + i = lines.index('#include "_cffi_errors.h"\n') + lines[i:i+1] = self._rel_readlines('_cffi_errors.h') + prnt(''.join(lines)) + self.needs_version(VERSION_EMBEDDED) + # + # then paste the C source given by the user, verbatim. + prnt('/************************************************************/') + prnt() + prnt(preamble) + prnt() + prnt('/************************************************************/') + prnt() + # + # the declaration of '_cffi_types' + prnt('static void *_cffi_types[] = {') + typeindex2type = dict([(i, tp) for (tp, i) in self._typesdict.items()]) + for i, op in enumerate(self.cffi_types): + comment = '' + if i in typeindex2type: + comment = ' // ' + typeindex2type[i]._get_c_name() + prnt('/* %2d */ %s,%s' % (i, op.as_c_expr(), comment)) + if not self.cffi_types: + prnt(' 0') + prnt('};') + prnt() + # + # call generate_cpy_xxx_decl(), for every xxx found from + # ffi._parser._declarations. This generates all the functions. + self._seen_constants = set() + self._generate("decl") + # + # the declaration of '_cffi_globals' and '_cffi_typenames' + nums = {} + for step_name in self.ALL_STEPS: + lst = self._lsts[step_name] + nums[step_name] = len(lst) + if nums[step_name] > 0: + prnt('static const struct _cffi_%s_s _cffi_%ss[] = {' % ( + step_name, step_name)) + for entry in lst: + prnt(entry.as_c_expr()) + prnt('};') + prnt() + # + # the declaration of '_cffi_includes' + if self.ffi._included_ffis: + prnt('static const char * const _cffi_includes[] = {') + for ffi_to_include in self.ffi._included_ffis: + try: + included_module_name, included_source = ( + ffi_to_include._assigned_source[:2]) + except AttributeError: + raise VerificationError( + "ffi object %r includes %r, but the latter has not " + "been prepared with set_source()" % ( + self.ffi, ffi_to_include,)) + if included_source is None: + raise VerificationError( + "not implemented yet: ffi.include() of a Python-based " + "ffi inside a C-based ffi") + prnt(' "%s",' % (included_module_name,)) + prnt(' NULL') + prnt('};') + prnt() + # + # the declaration of '_cffi_type_context' + prnt('static const struct _cffi_type_context_s _cffi_type_context = {') + prnt(' _cffi_types,') + for step_name in self.ALL_STEPS: + if nums[step_name] > 0: + prnt(' _cffi_%ss,' % step_name) + else: + prnt(' NULL, /* no %ss */' % step_name) + for step_name in self.ALL_STEPS: + if step_name != "field": + prnt(' %d, /* num_%ss */' % (nums[step_name], step_name)) + if self.ffi._included_ffis: + prnt(' _cffi_includes,') + else: + prnt(' NULL, /* no includes */') + prnt(' %d, /* num_types */' % (len(self.cffi_types),)) + flags = 0 + if self._num_externpy > 0 or self.ffi._embedding is not None: + flags |= 1 # set to mean that we use extern "Python" + prnt(' %d, /* flags */' % flags) + prnt('};') + prnt() + # + # the init function + prnt('#ifdef __GNUC__') + prnt('# pragma GCC visibility push(default) /* for -fvisibility= */') + prnt('#endif') + prnt() + prnt('#ifdef PYPY_VERSION') + prnt('PyMODINIT_FUNC') + prnt('_cffi_pypyinit_%s(const void *p[])' % (base_module_name,)) + prnt('{') + if flags & 1: + prnt(' if (((intptr_t)p[0]) >= 0x0A03) {') + prnt(' _cffi_call_python_org = ' + '(void(*)(struct _cffi_externpy_s *, char *))p[1];') + prnt(' }') + prnt(' p[0] = (const void *)0x%x;' % self._version) + prnt(' p[1] = &_cffi_type_context;') + prnt('#if PY_MAJOR_VERSION >= 3') + prnt(' return NULL;') + prnt('#endif') + prnt('}') + # on Windows, distutils insists on putting init_cffi_xyz in + # 'export_symbols', so instead of fighting it, just give up and + # give it one + prnt('# ifdef _MSC_VER') + prnt(' PyMODINIT_FUNC') + prnt('# if PY_MAJOR_VERSION >= 3') + prnt(' PyInit_%s(void) { return NULL; }' % (base_module_name,)) + prnt('# else') + prnt(' init%s(void) { }' % (base_module_name,)) + prnt('# endif') + prnt('# endif') + prnt('#elif PY_MAJOR_VERSION >= 3') + prnt('PyMODINIT_FUNC') + prnt('PyInit_%s(void)' % (base_module_name,)) + prnt('{') + prnt(' return _cffi_init("%s", 0x%x, &_cffi_type_context);' % ( + self.module_name, self._version)) + prnt('}') + prnt('#else') + prnt('PyMODINIT_FUNC') + prnt('init%s(void)' % (base_module_name,)) + prnt('{') + prnt(' _cffi_init("%s", 0x%x, &_cffi_type_context);' % ( + self.module_name, self._version)) + prnt('}') + prnt('#endif') + prnt() + prnt('#ifdef __GNUC__') + prnt('# pragma GCC visibility pop') + prnt('#endif') + self._version = None + + def _to_py(self, x): + if isinstance(x, str): + return "b'%s'" % (x,) + if isinstance(x, (list, tuple)): + rep = [self._to_py(item) for item in x] + if len(rep) == 1: + rep.append('') + return "(%s)" % (','.join(rep),) + return x.as_python_expr() # Py2: unicode unexpected; Py3: bytes unexp. + + def write_py_source_to_f(self, f): + self._f = f + prnt = self._prnt + # + # header + prnt("# auto-generated file") + prnt("import _cffi_backend") + # + # the 'import' of the included ffis + num_includes = len(self.ffi._included_ffis or ()) + for i in range(num_includes): + ffi_to_include = self.ffi._included_ffis[i] + try: + included_module_name, included_source = ( + ffi_to_include._assigned_source[:2]) + except AttributeError: + raise VerificationError( + "ffi object %r includes %r, but the latter has not " + "been prepared with set_source()" % ( + self.ffi, ffi_to_include,)) + if included_source is not None: + raise VerificationError( + "not implemented yet: ffi.include() of a C-based " + "ffi inside a Python-based ffi") + prnt('from %s import ffi as _ffi%d' % (included_module_name, i)) + prnt() + prnt("ffi = _cffi_backend.FFI('%s'," % (self.module_name,)) + prnt(" _version = 0x%x," % (self._version,)) + self._version = None + # + # the '_types' keyword argument + self.cffi_types = tuple(self.cffi_types) # don't change any more + types_lst = [op.as_python_bytes() for op in self.cffi_types] + prnt(' _types = %s,' % (self._to_py(''.join(types_lst)),)) + typeindex2type = dict([(i, tp) for (tp, i) in self._typesdict.items()]) + # + # the keyword arguments from ALL_STEPS + for step_name in self.ALL_STEPS: + lst = self._lsts[step_name] + if len(lst) > 0 and step_name != "field": + prnt(' _%ss = %s,' % (step_name, self._to_py(lst))) + # + # the '_includes' keyword argument + if num_includes > 0: + prnt(' _includes = (%s,),' % ( + ', '.join(['_ffi%d' % i for i in range(num_includes)]),)) + # + # the footer + prnt(')') + + # ---------- + + def _gettypenum(self, type): + # a KeyError here is a bug. please report it! :-) + return self._typesdict[type] + + def _convert_funcarg_to_c(self, tp, fromvar, tovar, errcode): + extraarg = '' + if isinstance(tp, model.BasePrimitiveType) and not tp.is_complex_type(): + if tp.is_integer_type() and tp.name != '_Bool': + converter = '_cffi_to_c_int' + extraarg = ', %s' % tp.name + elif isinstance(tp, model.UnknownFloatType): + # don't check with is_float_type(): it may be a 'long + # double' here, and _cffi_to_c_double would loose precision + converter = '(%s)_cffi_to_c_double' % (tp.get_c_name(''),) + else: + cname = tp.get_c_name('') + converter = '(%s)_cffi_to_c_%s' % (cname, + tp.name.replace(' ', '_')) + if cname in ('char16_t', 'char32_t'): + self.needs_version(VERSION_CHAR16CHAR32) + errvalue = '-1' + # + elif isinstance(tp, model.PointerType): + self._convert_funcarg_to_c_ptr_or_array(tp, fromvar, + tovar, errcode) + return + # + elif (isinstance(tp, model.StructOrUnionOrEnum) or + isinstance(tp, model.BasePrimitiveType)): + # a struct (not a struct pointer) as a function argument; + # or, a complex (the same code works) + self._prnt(' if (_cffi_to_c((char *)&%s, _cffi_type(%d), %s) < 0)' + % (tovar, self._gettypenum(tp), fromvar)) + self._prnt(' %s;' % errcode) + return + # + elif isinstance(tp, model.FunctionPtrType): + converter = '(%s)_cffi_to_c_pointer' % tp.get_c_name('') + extraarg = ', _cffi_type(%d)' % self._gettypenum(tp) + errvalue = 'NULL' + # + else: + raise NotImplementedError(tp) + # + self._prnt(' %s = %s(%s%s);' % (tovar, converter, fromvar, extraarg)) + self._prnt(' if (%s == (%s)%s && PyErr_Occurred())' % ( + tovar, tp.get_c_name(''), errvalue)) + self._prnt(' %s;' % errcode) + + def _extra_local_variables(self, tp, localvars, freelines): + if isinstance(tp, model.PointerType): + localvars.add('Py_ssize_t datasize') + localvars.add('struct _cffi_freeme_s *large_args_free = NULL') + freelines.add('if (large_args_free != NULL)' + ' _cffi_free_array_arguments(large_args_free);') + + def _convert_funcarg_to_c_ptr_or_array(self, tp, fromvar, tovar, errcode): + self._prnt(' datasize = _cffi_prepare_pointer_call_argument(') + self._prnt(' _cffi_type(%d), %s, (char **)&%s);' % ( + self._gettypenum(tp), fromvar, tovar)) + self._prnt(' if (datasize != 0) {') + self._prnt(' %s = ((size_t)datasize) <= 640 ? ' + '(%s)alloca((size_t)datasize) : NULL;' % ( + tovar, tp.get_c_name(''))) + self._prnt(' if (_cffi_convert_array_argument(_cffi_type(%d), %s, ' + '(char **)&%s,' % (self._gettypenum(tp), fromvar, tovar)) + self._prnt(' datasize, &large_args_free) < 0)') + self._prnt(' %s;' % errcode) + self._prnt(' }') + + def _convert_expr_from_c(self, tp, var, context): + if isinstance(tp, model.BasePrimitiveType): + if tp.is_integer_type() and tp.name != '_Bool': + return '_cffi_from_c_int(%s, %s)' % (var, tp.name) + elif isinstance(tp, model.UnknownFloatType): + return '_cffi_from_c_double(%s)' % (var,) + elif tp.name != 'long double' and not tp.is_complex_type(): + cname = tp.name.replace(' ', '_') + if cname in ('char16_t', 'char32_t'): + self.needs_version(VERSION_CHAR16CHAR32) + return '_cffi_from_c_%s(%s)' % (cname, var) + else: + return '_cffi_from_c_deref((char *)&%s, _cffi_type(%d))' % ( + var, self._gettypenum(tp)) + elif isinstance(tp, (model.PointerType, model.FunctionPtrType)): + return '_cffi_from_c_pointer((char *)%s, _cffi_type(%d))' % ( + var, self._gettypenum(tp)) + elif isinstance(tp, model.ArrayType): + return '_cffi_from_c_pointer((char *)%s, _cffi_type(%d))' % ( + var, self._gettypenum(model.PointerType(tp.item))) + elif isinstance(tp, model.StructOrUnion): + if tp.fldnames is None: + raise TypeError("'%s' is used as %s, but is opaque" % ( + tp._get_c_name(), context)) + return '_cffi_from_c_struct((char *)&%s, _cffi_type(%d))' % ( + var, self._gettypenum(tp)) + elif isinstance(tp, model.EnumType): + return '_cffi_from_c_deref((char *)&%s, _cffi_type(%d))' % ( + var, self._gettypenum(tp)) + else: + raise NotImplementedError(tp) + + # ---------- + # typedefs + + def _typedef_type(self, tp, name): + return self._global_type(tp, "(*(%s *)0)" % (name,)) + + def _generate_cpy_typedef_collecttype(self, tp, name): + self._do_collect_type(self._typedef_type(tp, name)) + + def _generate_cpy_typedef_decl(self, tp, name): + pass + + def _typedef_ctx(self, tp, name): + type_index = self._typesdict[tp] + self._lsts["typename"].append(TypenameExpr(name, type_index)) + + def _generate_cpy_typedef_ctx(self, tp, name): + tp = self._typedef_type(tp, name) + self._typedef_ctx(tp, name) + if getattr(tp, "origin", None) == "unknown_type": + self._struct_ctx(tp, tp.name, approxname=None) + elif isinstance(tp, model.NamedPointerType): + self._struct_ctx(tp.totype, tp.totype.name, approxname=tp.name, + named_ptr=tp) + + # ---------- + # function declarations + + def _generate_cpy_function_collecttype(self, tp, name): + self._do_collect_type(tp.as_raw_function()) + if tp.ellipsis and not self.target_is_python: + self._do_collect_type(tp) + + def _generate_cpy_function_decl(self, tp, name): + assert not self.target_is_python + assert isinstance(tp, model.FunctionPtrType) + if tp.ellipsis: + # cannot support vararg functions better than this: check for its + # exact type (including the fixed arguments), and build it as a + # constant function pointer (no CPython wrapper) + self._generate_cpy_constant_decl(tp, name) + return + prnt = self._prnt + numargs = len(tp.args) + if numargs == 0: + argname = 'noarg' + elif numargs == 1: + argname = 'arg0' + else: + argname = 'args' + # + # ------------------------------ + # the 'd' version of the function, only for addressof(lib, 'func') + arguments = [] + call_arguments = [] + context = 'argument of %s' % name + for i, type in enumerate(tp.args): + arguments.append(type.get_c_name(' x%d' % i, context)) + call_arguments.append('x%d' % i) + repr_arguments = ', '.join(arguments) + repr_arguments = repr_arguments or 'void' + if tp.abi: + abi = tp.abi + ' ' + else: + abi = '' + name_and_arguments = '%s_cffi_d_%s(%s)' % (abi, name, repr_arguments) + prnt('static %s' % (tp.result.get_c_name(name_and_arguments),)) + prnt('{') + call_arguments = ', '.join(call_arguments) + result_code = 'return ' + if isinstance(tp.result, model.VoidType): + result_code = '' + prnt(' %s%s(%s);' % (result_code, name, call_arguments)) + prnt('}') + # + prnt('#ifndef PYPY_VERSION') # ------------------------------ + # + prnt('static PyObject *') + prnt('_cffi_f_%s(PyObject *self, PyObject *%s)' % (name, argname)) + prnt('{') + # + context = 'argument of %s' % name + for i, type in enumerate(tp.args): + arg = type.get_c_name(' x%d' % i, context) + prnt(' %s;' % arg) + # + localvars = set() + freelines = set() + for type in tp.args: + self._extra_local_variables(type, localvars, freelines) + for decl in sorted(localvars): + prnt(' %s;' % (decl,)) + # + if not isinstance(tp.result, model.VoidType): + result_code = 'result = ' + context = 'result of %s' % name + result_decl = ' %s;' % tp.result.get_c_name(' result', context) + prnt(result_decl) + prnt(' PyObject *pyresult;') + else: + result_decl = None + result_code = '' + # + if len(tp.args) > 1: + rng = range(len(tp.args)) + for i in rng: + prnt(' PyObject *arg%d;' % i) + prnt() + prnt(' if (!PyArg_UnpackTuple(args, "%s", %d, %d, %s))' % ( + name, len(rng), len(rng), + ', '.join(['&arg%d' % i for i in rng]))) + prnt(' return NULL;') + prnt() + # + for i, type in enumerate(tp.args): + self._convert_funcarg_to_c(type, 'arg%d' % i, 'x%d' % i, + 'return NULL') + prnt() + # + prnt(' Py_BEGIN_ALLOW_THREADS') + prnt(' _cffi_restore_errno();') + call_arguments = ['x%d' % i for i in range(len(tp.args))] + call_arguments = ', '.join(call_arguments) + prnt(' { %s%s(%s); }' % (result_code, name, call_arguments)) + prnt(' _cffi_save_errno();') + prnt(' Py_END_ALLOW_THREADS') + prnt() + # + prnt(' (void)self; /* unused */') + if numargs == 0: + prnt(' (void)noarg; /* unused */') + if result_code: + prnt(' pyresult = %s;' % + self._convert_expr_from_c(tp.result, 'result', 'result type')) + for freeline in freelines: + prnt(' ' + freeline) + prnt(' return pyresult;') + else: + for freeline in freelines: + prnt(' ' + freeline) + prnt(' Py_INCREF(Py_None);') + prnt(' return Py_None;') + prnt('}') + # + prnt('#else') # ------------------------------ + # + # the PyPy version: need to replace struct/union arguments with + # pointers, and if the result is a struct/union, insert a first + # arg that is a pointer to the result. We also do that for + # complex args and return type. + def need_indirection(type): + return (isinstance(type, model.StructOrUnion) or + (isinstance(type, model.PrimitiveType) and + type.is_complex_type())) + difference = False + arguments = [] + call_arguments = [] + context = 'argument of %s' % name + for i, type in enumerate(tp.args): + indirection = '' + if need_indirection(type): + indirection = '*' + difference = True + arg = type.get_c_name(' %sx%d' % (indirection, i), context) + arguments.append(arg) + call_arguments.append('%sx%d' % (indirection, i)) + tp_result = tp.result + if need_indirection(tp_result): + context = 'result of %s' % name + arg = tp_result.get_c_name(' *result', context) + arguments.insert(0, arg) + tp_result = model.void_type + result_decl = None + result_code = '*result = ' + difference = True + if difference: + repr_arguments = ', '.join(arguments) + repr_arguments = repr_arguments or 'void' + name_and_arguments = '%s_cffi_f_%s(%s)' % (abi, name, + repr_arguments) + prnt('static %s' % (tp_result.get_c_name(name_and_arguments),)) + prnt('{') + if result_decl: + prnt(result_decl) + call_arguments = ', '.join(call_arguments) + prnt(' { %s%s(%s); }' % (result_code, name, call_arguments)) + if result_decl: + prnt(' return result;') + prnt('}') + else: + prnt('# define _cffi_f_%s _cffi_d_%s' % (name, name)) + # + prnt('#endif') # ------------------------------ + prnt() + + def _generate_cpy_function_ctx(self, tp, name): + if tp.ellipsis and not self.target_is_python: + self._generate_cpy_constant_ctx(tp, name) + return + type_index = self._typesdict[tp.as_raw_function()] + numargs = len(tp.args) + if self.target_is_python: + meth_kind = OP_DLOPEN_FUNC + elif numargs == 0: + meth_kind = OP_CPYTHON_BLTN_N # 'METH_NOARGS' + elif numargs == 1: + meth_kind = OP_CPYTHON_BLTN_O # 'METH_O' + else: + meth_kind = OP_CPYTHON_BLTN_V # 'METH_VARARGS' + self._lsts["global"].append( + GlobalExpr(name, '_cffi_f_%s' % name, + CffiOp(meth_kind, type_index), + size='_cffi_d_%s' % name)) + + # ---------- + # named structs or unions + + def _field_type(self, tp_struct, field_name, tp_field): + if isinstance(tp_field, model.ArrayType): + actual_length = tp_field.length + if actual_length == '...': + ptr_struct_name = tp_struct.get_c_name('*') + actual_length = '_cffi_array_len(((%s)0)->%s)' % ( + ptr_struct_name, field_name) + tp_item = self._field_type(tp_struct, '%s[0]' % field_name, + tp_field.item) + tp_field = model.ArrayType(tp_item, actual_length) + return tp_field + + def _struct_collecttype(self, tp): + self._do_collect_type(tp) + if self.target_is_python: + # also requires nested anon struct/unions in ABI mode, recursively + for fldtype in tp.anonymous_struct_fields(): + self._struct_collecttype(fldtype) + + def _struct_decl(self, tp, cname, approxname): + if tp.fldtypes is None: + return + prnt = self._prnt + checkfuncname = '_cffi_checkfld_%s' % (approxname,) + prnt('_CFFI_UNUSED_FN') + prnt('static void %s(%s *p)' % (checkfuncname, cname)) + prnt('{') + prnt(' /* only to generate compile-time warnings or errors */') + prnt(' (void)p;') + for fname, ftype, fbitsize, fqual in self._enum_fields(tp): + try: + if ftype.is_integer_type() or fbitsize >= 0: + # accept all integers, but complain on float or double + if fname != '': + prnt(" (void)((p->%s) | 0); /* check that '%s.%s' is " + "an integer */" % (fname, cname, fname)) + continue + # only accept exactly the type declared, except that '[]' + # is interpreted as a '*' and so will match any array length. + # (It would also match '*', but that's harder to detect...) + while (isinstance(ftype, model.ArrayType) + and (ftype.length is None or ftype.length == '...')): + ftype = ftype.item + fname = fname + '[0]' + prnt(' { %s = &p->%s; (void)tmp; }' % ( + ftype.get_c_name('*tmp', 'field %r'%fname, quals=fqual), + fname)) + except VerificationError as e: + prnt(' /* %s */' % str(e)) # cannot verify it, ignore + prnt('}') + prnt('struct _cffi_align_%s { char x; %s y; };' % (approxname, cname)) + prnt() + + def _struct_ctx(self, tp, cname, approxname, named_ptr=None): + type_index = self._typesdict[tp] + reason_for_not_expanding = None + flags = [] + if isinstance(tp, model.UnionType): + flags.append("_CFFI_F_UNION") + if tp.fldtypes is None: + flags.append("_CFFI_F_OPAQUE") + reason_for_not_expanding = "opaque" + if (tp not in self.ffi._parser._included_declarations and + (named_ptr is None or + named_ptr not in self.ffi._parser._included_declarations)): + if tp.fldtypes is None: + pass # opaque + elif tp.partial or any(tp.anonymous_struct_fields()): + pass # field layout obtained silently from the C compiler + else: + flags.append("_CFFI_F_CHECK_FIELDS") + if tp.packed: + if tp.packed > 1: + raise NotImplementedError( + "%r is declared with 'pack=%r'; only 0 or 1 are " + "supported in API mode (try to use \"...;\", which " + "does not require a 'pack' declaration)" % + (tp, tp.packed)) + flags.append("_CFFI_F_PACKED") + else: + flags.append("_CFFI_F_EXTERNAL") + reason_for_not_expanding = "external" + flags = '|'.join(flags) or '0' + c_fields = [] + if reason_for_not_expanding is None: + enumfields = list(self._enum_fields(tp)) + for fldname, fldtype, fbitsize, fqual in enumfields: + fldtype = self._field_type(tp, fldname, fldtype) + self._check_not_opaque(fldtype, + "field '%s.%s'" % (tp.name, fldname)) + # cname is None for _add_missing_struct_unions() only + op = OP_NOOP + if fbitsize >= 0: + op = OP_BITFIELD + size = '%d /* bits */' % fbitsize + elif cname is None or ( + isinstance(fldtype, model.ArrayType) and + fldtype.length is None): + size = '(size_t)-1' + else: + size = 'sizeof(((%s)0)->%s)' % ( + tp.get_c_name('*') if named_ptr is None + else named_ptr.name, + fldname) + if cname is None or fbitsize >= 0: + offset = '(size_t)-1' + elif named_ptr is not None: + offset = '((char *)&((%s)0)->%s) - (char *)0' % ( + named_ptr.name, fldname) + else: + offset = 'offsetof(%s, %s)' % (tp.get_c_name(''), fldname) + c_fields.append( + FieldExpr(fldname, offset, size, fbitsize, + CffiOp(op, self._typesdict[fldtype]))) + first_field_index = len(self._lsts["field"]) + self._lsts["field"].extend(c_fields) + # + if cname is None: # unknown name, for _add_missing_struct_unions + size = '(size_t)-2' + align = -2 + comment = "unnamed" + else: + if named_ptr is not None: + size = 'sizeof(*(%s)0)' % (named_ptr.name,) + align = '-1 /* unknown alignment */' + else: + size = 'sizeof(%s)' % (cname,) + align = 'offsetof(struct _cffi_align_%s, y)' % (approxname,) + comment = None + else: + size = '(size_t)-1' + align = -1 + first_field_index = -1 + comment = reason_for_not_expanding + self._lsts["struct_union"].append( + StructUnionExpr(tp.name, type_index, flags, size, align, comment, + first_field_index, c_fields)) + self._seen_struct_unions.add(tp) + + def _check_not_opaque(self, tp, location): + while isinstance(tp, model.ArrayType): + tp = tp.item + if isinstance(tp, model.StructOrUnion) and tp.fldtypes is None: + raise TypeError( + "%s is of an opaque type (not declared in cdef())" % location) + + def _add_missing_struct_unions(self): + # not very nice, but some struct declarations might be missing + # because they don't have any known C name. Check that they are + # not partial (we can't complete or verify them!) and emit them + # anonymously. + lst = list(self._struct_unions.items()) + lst.sort(key=lambda tp_order: tp_order[1]) + for tp, order in lst: + if tp not in self._seen_struct_unions: + if tp.partial: + raise NotImplementedError("internal inconsistency: %r is " + "partial but was not seen at " + "this point" % (tp,)) + if tp.name.startswith('$') and tp.name[1:].isdigit(): + approxname = tp.name[1:] + elif tp.name == '_IO_FILE' and tp.forcename == 'FILE': + approxname = 'FILE' + self._typedef_ctx(tp, 'FILE') + else: + raise NotImplementedError("internal inconsistency: %r" % + (tp,)) + self._struct_ctx(tp, None, approxname) + + def _generate_cpy_struct_collecttype(self, tp, name): + self._struct_collecttype(tp) + _generate_cpy_union_collecttype = _generate_cpy_struct_collecttype + + def _struct_names(self, tp): + cname = tp.get_c_name('') + if ' ' in cname: + return cname, cname.replace(' ', '_') + else: + return cname, '_' + cname + + def _generate_cpy_struct_decl(self, tp, name): + self._struct_decl(tp, *self._struct_names(tp)) + _generate_cpy_union_decl = _generate_cpy_struct_decl + + def _generate_cpy_struct_ctx(self, tp, name): + self._struct_ctx(tp, *self._struct_names(tp)) + _generate_cpy_union_ctx = _generate_cpy_struct_ctx + + # ---------- + # 'anonymous' declarations. These are produced for anonymous structs + # or unions; the 'name' is obtained by a typedef. + + def _generate_cpy_anonymous_collecttype(self, tp, name): + if isinstance(tp, model.EnumType): + self._generate_cpy_enum_collecttype(tp, name) + else: + self._struct_collecttype(tp) + + def _generate_cpy_anonymous_decl(self, tp, name): + if isinstance(tp, model.EnumType): + self._generate_cpy_enum_decl(tp) + else: + self._struct_decl(tp, name, 'typedef_' + name) + + def _generate_cpy_anonymous_ctx(self, tp, name): + if isinstance(tp, model.EnumType): + self._enum_ctx(tp, name) + else: + self._struct_ctx(tp, name, 'typedef_' + name) + + # ---------- + # constants, declared with "static const ..." + + def _generate_cpy_const(self, is_int, name, tp=None, category='const', + check_value=None): + if (category, name) in self._seen_constants: + raise VerificationError( + "duplicate declaration of %s '%s'" % (category, name)) + self._seen_constants.add((category, name)) + # + prnt = self._prnt + funcname = '_cffi_%s_%s' % (category, name) + if is_int: + prnt('static int %s(unsigned long long *o)' % funcname) + prnt('{') + prnt(' int n = (%s) <= 0;' % (name,)) + prnt(' *o = (unsigned long long)((%s) | 0);' + ' /* check that %s is an integer */' % (name, name)) + if check_value is not None: + if check_value > 0: + check_value = '%dU' % (check_value,) + prnt(' if (!_cffi_check_int(*o, n, %s))' % (check_value,)) + prnt(' n |= 2;') + prnt(' return n;') + prnt('}') + else: + assert check_value is None + prnt('static void %s(char *o)' % funcname) + prnt('{') + prnt(' *(%s)o = %s;' % (tp.get_c_name('*'), name)) + prnt('}') + prnt() + + def _generate_cpy_constant_collecttype(self, tp, name): + is_int = tp.is_integer_type() + if not is_int or self.target_is_python: + self._do_collect_type(tp) + + def _generate_cpy_constant_decl(self, tp, name): + is_int = tp.is_integer_type() + self._generate_cpy_const(is_int, name, tp) + + def _generate_cpy_constant_ctx(self, tp, name): + if not self.target_is_python and tp.is_integer_type(): + type_op = CffiOp(OP_CONSTANT_INT, -1) + else: + if self.target_is_python: + const_kind = OP_DLOPEN_CONST + else: + const_kind = OP_CONSTANT + type_index = self._typesdict[tp] + type_op = CffiOp(const_kind, type_index) + self._lsts["global"].append( + GlobalExpr(name, '_cffi_const_%s' % name, type_op)) + + # ---------- + # enums + + def _generate_cpy_enum_collecttype(self, tp, name): + self._do_collect_type(tp) + + def _generate_cpy_enum_decl(self, tp, name=None): + for enumerator in tp.enumerators: + self._generate_cpy_const(True, enumerator) + + def _enum_ctx(self, tp, cname): + type_index = self._typesdict[tp] + type_op = CffiOp(OP_ENUM, -1) + if self.target_is_python: + tp.check_not_partial() + for enumerator, enumvalue in zip(tp.enumerators, tp.enumvalues): + self._lsts["global"].append( + GlobalExpr(enumerator, '_cffi_const_%s' % enumerator, type_op, + check_value=enumvalue)) + # + if cname is not None and '$' not in cname and not self.target_is_python: + size = "sizeof(%s)" % cname + signed = "((%s)-1) <= 0" % cname + else: + basetp = tp.build_baseinttype(self.ffi, []) + size = self.ffi.sizeof(basetp) + signed = int(int(self.ffi.cast(basetp, -1)) < 0) + allenums = ",".join(tp.enumerators) + self._lsts["enum"].append( + EnumExpr(tp.name, type_index, size, signed, allenums)) + + def _generate_cpy_enum_ctx(self, tp, name): + self._enum_ctx(tp, tp._get_c_name()) + + # ---------- + # macros: for now only for integers + + def _generate_cpy_macro_collecttype(self, tp, name): + pass + + def _generate_cpy_macro_decl(self, tp, name): + if tp == '...': + check_value = None + else: + check_value = tp # an integer + self._generate_cpy_const(True, name, check_value=check_value) + + def _generate_cpy_macro_ctx(self, tp, name): + if tp == '...': + if self.target_is_python: + raise VerificationError( + "cannot use the syntax '...' in '#define %s ...' when " + "using the ABI mode" % (name,)) + check_value = None + else: + check_value = tp # an integer + type_op = CffiOp(OP_CONSTANT_INT, -1) + self._lsts["global"].append( + GlobalExpr(name, '_cffi_const_%s' % name, type_op, + check_value=check_value)) + + # ---------- + # global variables + + def _global_type(self, tp, global_name): + if isinstance(tp, model.ArrayType): + actual_length = tp.length + if actual_length == '...': + actual_length = '_cffi_array_len(%s)' % (global_name,) + tp_item = self._global_type(tp.item, '%s[0]' % global_name) + tp = model.ArrayType(tp_item, actual_length) + return tp + + def _generate_cpy_variable_collecttype(self, tp, name): + self._do_collect_type(self._global_type(tp, name)) + + def _generate_cpy_variable_decl(self, tp, name): + prnt = self._prnt + tp = self._global_type(tp, name) + if isinstance(tp, model.ArrayType) and tp.length is None: + tp = tp.item + ampersand = '' + else: + ampersand = '&' + # This code assumes that casts from "tp *" to "void *" is a + # no-op, i.e. a function that returns a "tp *" can be called + # as if it returned a "void *". This should be generally true + # on any modern machine. The only exception to that rule (on + # uncommon architectures, and as far as I can tell) might be + # if 'tp' were a function type, but that is not possible here. + # (If 'tp' is a function _pointer_ type, then casts from "fn_t + # **" to "void *" are again no-ops, as far as I can tell.) + decl = '*_cffi_var_%s(void)' % (name,) + prnt('static ' + tp.get_c_name(decl, quals=self._current_quals)) + prnt('{') + prnt(' return %s(%s);' % (ampersand, name)) + prnt('}') + prnt() + + def _generate_cpy_variable_ctx(self, tp, name): + tp = self._global_type(tp, name) + type_index = self._typesdict[tp] + if self.target_is_python: + op = OP_GLOBAL_VAR + else: + op = OP_GLOBAL_VAR_F + self._lsts["global"].append( + GlobalExpr(name, '_cffi_var_%s' % name, CffiOp(op, type_index))) + + # ---------- + # extern "Python" + + def _generate_cpy_extern_python_collecttype(self, tp, name): + assert isinstance(tp, model.FunctionPtrType) + self._do_collect_type(tp) + _generate_cpy_dllexport_python_collecttype = \ + _generate_cpy_extern_python_plus_c_collecttype = \ + _generate_cpy_extern_python_collecttype + + def _extern_python_decl(self, tp, name, tag_and_space): + prnt = self._prnt + if isinstance(tp.result, model.VoidType): + size_of_result = '0' + else: + context = 'result of %s' % name + size_of_result = '(int)sizeof(%s)' % ( + tp.result.get_c_name('', context),) + prnt('static struct _cffi_externpy_s _cffi_externpy__%s =' % name) + prnt(' { "%s.%s", %s, 0, 0 };' % ( + self.module_name, name, size_of_result)) + prnt() + # + arguments = [] + context = 'argument of %s' % name + for i, type in enumerate(tp.args): + arg = type.get_c_name(' a%d' % i, context) + arguments.append(arg) + # + repr_arguments = ', '.join(arguments) + repr_arguments = repr_arguments or 'void' + name_and_arguments = '%s(%s)' % (name, repr_arguments) + if tp.abi == "__stdcall": + name_and_arguments = '_cffi_stdcall ' + name_and_arguments + # + def may_need_128_bits(tp): + return (isinstance(tp, model.PrimitiveType) and + tp.name == 'long double') + # + size_of_a = max(len(tp.args)*8, 8) + if may_need_128_bits(tp.result): + size_of_a = max(size_of_a, 16) + if isinstance(tp.result, model.StructOrUnion): + size_of_a = 'sizeof(%s) > %d ? sizeof(%s) : %d' % ( + tp.result.get_c_name(''), size_of_a, + tp.result.get_c_name(''), size_of_a) + prnt('%s%s' % (tag_and_space, tp.result.get_c_name(name_and_arguments))) + prnt('{') + prnt(' char a[%s];' % size_of_a) + prnt(' char *p = a;') + for i, type in enumerate(tp.args): + arg = 'a%d' % i + if (isinstance(type, model.StructOrUnion) or + may_need_128_bits(type)): + arg = '&' + arg + type = model.PointerType(type) + prnt(' *(%s)(p + %d) = %s;' % (type.get_c_name('*'), i*8, arg)) + prnt(' _cffi_call_python(&_cffi_externpy__%s, p);' % name) + if not isinstance(tp.result, model.VoidType): + prnt(' return *(%s)p;' % (tp.result.get_c_name('*'),)) + prnt('}') + prnt() + self._num_externpy += 1 + + def _generate_cpy_extern_python_decl(self, tp, name): + self._extern_python_decl(tp, name, 'static ') + + def _generate_cpy_dllexport_python_decl(self, tp, name): + self._extern_python_decl(tp, name, 'CFFI_DLLEXPORT ') + + def _generate_cpy_extern_python_plus_c_decl(self, tp, name): + self._extern_python_decl(tp, name, '') + + def _generate_cpy_extern_python_ctx(self, tp, name): + if self.target_is_python: + raise VerificationError( + "cannot use 'extern \"Python\"' in the ABI mode") + if tp.ellipsis: + raise NotImplementedError("a vararg function is extern \"Python\"") + type_index = self._typesdict[tp] + type_op = CffiOp(OP_EXTERN_PYTHON, type_index) + self._lsts["global"].append( + GlobalExpr(name, '&_cffi_externpy__%s' % name, type_op, name)) + + _generate_cpy_dllexport_python_ctx = \ + _generate_cpy_extern_python_plus_c_ctx = \ + _generate_cpy_extern_python_ctx + + def _print_string_literal_in_array(self, s): + prnt = self._prnt + prnt('// # NB. this is not a string because of a size limit in MSVC') + if not isinstance(s, bytes): # unicode + s = s.encode('utf-8') # -> bytes + else: + s.decode('utf-8') # got bytes, check for valid utf-8 + try: + s.decode('ascii') + except UnicodeDecodeError: + s = b'# -*- encoding: utf8 -*-\n' + s + for line in s.splitlines(True): + comment = line + if type('//') is bytes: # python2 + line = map(ord, line) # make a list of integers + else: # python3 + # type(line) is bytes, which enumerates like a list of integers + comment = ascii(comment)[1:-1] + prnt(('// ' + comment).rstrip()) + printed_line = '' + for c in line: + if len(printed_line) >= 76: + prnt(printed_line) + printed_line = '' + printed_line += '%d,' % (c,) + prnt(printed_line) + + # ---------- + # emitting the opcodes for individual types + + def _emit_bytecode_VoidType(self, tp, index): + self.cffi_types[index] = CffiOp(OP_PRIMITIVE, PRIM_VOID) + + def _emit_bytecode_PrimitiveType(self, tp, index): + prim_index = PRIMITIVE_TO_INDEX[tp.name] + self.cffi_types[index] = CffiOp(OP_PRIMITIVE, prim_index) + + def _emit_bytecode_UnknownIntegerType(self, tp, index): + s = ('_cffi_prim_int(sizeof(%s), (\n' + ' ((%s)-1) | 0 /* check that %s is an integer type */\n' + ' ) <= 0)' % (tp.name, tp.name, tp.name)) + self.cffi_types[index] = CffiOp(OP_PRIMITIVE, s) + + def _emit_bytecode_UnknownFloatType(self, tp, index): + s = ('_cffi_prim_float(sizeof(%s) *\n' + ' (((%s)1) / 2) * 2 /* integer => 0, float => 1 */\n' + ' )' % (tp.name, tp.name)) + self.cffi_types[index] = CffiOp(OP_PRIMITIVE, s) + + def _emit_bytecode_RawFunctionType(self, tp, index): + self.cffi_types[index] = CffiOp(OP_FUNCTION, self._typesdict[tp.result]) + index += 1 + for tp1 in tp.args: + realindex = self._typesdict[tp1] + if index != realindex: + if isinstance(tp1, model.PrimitiveType): + self._emit_bytecode_PrimitiveType(tp1, index) + else: + self.cffi_types[index] = CffiOp(OP_NOOP, realindex) + index += 1 + flags = int(tp.ellipsis) + if tp.abi is not None: + if tp.abi == '__stdcall': + flags |= 2 + else: + raise NotImplementedError("abi=%r" % (tp.abi,)) + self.cffi_types[index] = CffiOp(OP_FUNCTION_END, flags) + + def _emit_bytecode_PointerType(self, tp, index): + self.cffi_types[index] = CffiOp(OP_POINTER, self._typesdict[tp.totype]) + + _emit_bytecode_ConstPointerType = _emit_bytecode_PointerType + _emit_bytecode_NamedPointerType = _emit_bytecode_PointerType + + def _emit_bytecode_FunctionPtrType(self, tp, index): + raw = tp.as_raw_function() + self.cffi_types[index] = CffiOp(OP_POINTER, self._typesdict[raw]) + + def _emit_bytecode_ArrayType(self, tp, index): + item_index = self._typesdict[tp.item] + if tp.length is None: + self.cffi_types[index] = CffiOp(OP_OPEN_ARRAY, item_index) + elif tp.length == '...': + raise VerificationError( + "type %s badly placed: the '...' array length can only be " + "used on global arrays or on fields of structures" % ( + str(tp).replace('/*...*/', '...'),)) + else: + assert self.cffi_types[index + 1] == 'LEN' + self.cffi_types[index] = CffiOp(OP_ARRAY, item_index) + self.cffi_types[index + 1] = CffiOp(None, str(tp.length)) + + def _emit_bytecode_StructType(self, tp, index): + struct_index = self._struct_unions[tp] + self.cffi_types[index] = CffiOp(OP_STRUCT_UNION, struct_index) + _emit_bytecode_UnionType = _emit_bytecode_StructType + + def _emit_bytecode_EnumType(self, tp, index): + enum_index = self._enums[tp] + self.cffi_types[index] = CffiOp(OP_ENUM, enum_index) + + +if sys.version_info >= (3,): + NativeIO = io.StringIO +else: + class NativeIO(io.BytesIO): + def write(self, s): + if isinstance(s, unicode): + s = s.encode('ascii') + super(NativeIO, self).write(s) + +def _make_c_or_py_source(ffi, module_name, preamble, target_file, verbose): + if verbose: + print("generating %s" % (target_file,)) + recompiler = Recompiler(ffi, module_name, + target_is_python=(preamble is None)) + recompiler.collect_type_table() + recompiler.collect_step_tables() + f = NativeIO() + recompiler.write_source_to_f(f, preamble) + output = f.getvalue() + try: + with open(target_file, 'r') as f1: + if f1.read(len(output) + 1) != output: + raise IOError + if verbose: + print("(already up-to-date)") + return False # already up-to-date + except IOError: + tmp_file = '%s.~%d' % (target_file, os.getpid()) + with open(tmp_file, 'w') as f1: + f1.write(output) + try: + os.rename(tmp_file, target_file) + except OSError: + os.unlink(target_file) + os.rename(tmp_file, target_file) + return True + +def make_c_source(ffi, module_name, preamble, target_c_file, verbose=False): + assert preamble is not None + return _make_c_or_py_source(ffi, module_name, preamble, target_c_file, + verbose) + +def make_py_source(ffi, module_name, target_py_file, verbose=False): + return _make_c_or_py_source(ffi, module_name, None, target_py_file, + verbose) + +def _modname_to_file(outputdir, modname, extension): + parts = modname.split('.') + try: + os.makedirs(os.path.join(outputdir, *parts[:-1])) + except OSError: + pass + parts[-1] += extension + return os.path.join(outputdir, *parts), parts + + +# Aaargh. Distutils is not tested at all for the purpose of compiling +# DLLs that are not extension modules. Here are some hacks to work +# around that, in the _patch_for_*() functions... + +def _patch_meth(patchlist, cls, name, new_meth): + old = getattr(cls, name) + patchlist.append((cls, name, old)) + setattr(cls, name, new_meth) + return old + +def _unpatch_meths(patchlist): + for cls, name, old_meth in reversed(patchlist): + setattr(cls, name, old_meth) + +def _patch_for_embedding(patchlist): + if sys.platform == 'win32': + # we must not remove the manifest when building for embedding! + from cffi._shimmed_dist_utils import MSVCCompiler + _patch_meth(patchlist, MSVCCompiler, '_remove_visual_c_ref', + lambda self, manifest_file: manifest_file) + + if sys.platform == 'darwin': + # we must not make a '-bundle', but a '-dynamiclib' instead + from cffi._shimmed_dist_utils import CCompiler + def my_link_shared_object(self, *args, **kwds): + if '-bundle' in self.linker_so: + self.linker_so = list(self.linker_so) + i = self.linker_so.index('-bundle') + self.linker_so[i] = '-dynamiclib' + return old_link_shared_object(self, *args, **kwds) + old_link_shared_object = _patch_meth(patchlist, CCompiler, + 'link_shared_object', + my_link_shared_object) + +def _patch_for_target(patchlist, target): + from cffi._shimmed_dist_utils import build_ext + # if 'target' is different from '*', we need to patch some internal + # method to just return this 'target' value, instead of having it + # built from module_name + if target.endswith('.*'): + target = target[:-2] + if sys.platform == 'win32': + target += '.dll' + elif sys.platform == 'darwin': + target += '.dylib' + else: + target += '.so' + _patch_meth(patchlist, build_ext, 'get_ext_filename', + lambda self, ext_name: target) + + +def recompile(ffi, module_name, preamble, tmpdir='.', call_c_compiler=True, + c_file=None, source_extension='.c', extradir=None, + compiler_verbose=1, target=None, debug=None, **kwds): + if not isinstance(module_name, str): + module_name = module_name.encode('ascii') + if ffi._windows_unicode: + ffi._apply_windows_unicode(kwds) + if preamble is not None: + embedding = (ffi._embedding is not None) + if embedding: + ffi._apply_embedding_fix(kwds) + if c_file is None: + c_file, parts = _modname_to_file(tmpdir, module_name, + source_extension) + if extradir: + parts = [extradir] + parts + ext_c_file = os.path.join(*parts) + else: + ext_c_file = c_file + # + if target is None: + if embedding: + target = '%s.*' % module_name + else: + target = '*' + # + ext = ffiplatform.get_extension(ext_c_file, module_name, **kwds) + updated = make_c_source(ffi, module_name, preamble, c_file, + verbose=compiler_verbose) + if call_c_compiler: + patchlist = [] + cwd = os.getcwd() + try: + if embedding: + _patch_for_embedding(patchlist) + if target != '*': + _patch_for_target(patchlist, target) + if compiler_verbose: + if tmpdir == '.': + msg = 'the current directory is' + else: + msg = 'setting the current directory to' + print('%s %r' % (msg, os.path.abspath(tmpdir))) + os.chdir(tmpdir) + outputfilename = ffiplatform.compile('.', ext, + compiler_verbose, debug) + finally: + os.chdir(cwd) + _unpatch_meths(patchlist) + return outputfilename + else: + return ext, updated + else: + if c_file is None: + c_file, _ = _modname_to_file(tmpdir, module_name, '.py') + updated = make_py_source(ffi, module_name, c_file, + verbose=compiler_verbose) + if call_c_compiler: + return c_file + else: + return None, updated + diff --git a/.venv/Lib/site-packages/cffi/setuptools_ext.py b/.venv/Lib/site-packages/cffi/setuptools_ext.py new file mode 100644 index 00000000..681b49d7 --- /dev/null +++ b/.venv/Lib/site-packages/cffi/setuptools_ext.py @@ -0,0 +1,216 @@ +import os +import sys + +try: + basestring +except NameError: + # Python 3.x + basestring = str + +def error(msg): + from cffi._shimmed_dist_utils import DistutilsSetupError + raise DistutilsSetupError(msg) + + +def execfile(filename, glob): + # We use execfile() (here rewritten for Python 3) instead of + # __import__() to load the build script. The problem with + # a normal import is that in some packages, the intermediate + # __init__.py files may already try to import the file that + # we are generating. + with open(filename) as f: + src = f.read() + src += '\n' # Python 2.6 compatibility + code = compile(src, filename, 'exec') + exec(code, glob, glob) + + +def add_cffi_module(dist, mod_spec): + from cffi.api import FFI + + if not isinstance(mod_spec, basestring): + error("argument to 'cffi_modules=...' must be a str or a list of str," + " not %r" % (type(mod_spec).__name__,)) + mod_spec = str(mod_spec) + try: + build_file_name, ffi_var_name = mod_spec.split(':') + except ValueError: + error("%r must be of the form 'path/build.py:ffi_variable'" % + (mod_spec,)) + if not os.path.exists(build_file_name): + ext = '' + rewritten = build_file_name.replace('.', '/') + '.py' + if os.path.exists(rewritten): + ext = ' (rewrite cffi_modules to [%r])' % ( + rewritten + ':' + ffi_var_name,) + error("%r does not name an existing file%s" % (build_file_name, ext)) + + mod_vars = {'__name__': '__cffi__', '__file__': build_file_name} + execfile(build_file_name, mod_vars) + + try: + ffi = mod_vars[ffi_var_name] + except KeyError: + error("%r: object %r not found in module" % (mod_spec, + ffi_var_name)) + if not isinstance(ffi, FFI): + ffi = ffi() # maybe it's a function instead of directly an ffi + if not isinstance(ffi, FFI): + error("%r is not an FFI instance (got %r)" % (mod_spec, + type(ffi).__name__)) + if not hasattr(ffi, '_assigned_source'): + error("%r: the set_source() method was not called" % (mod_spec,)) + module_name, source, source_extension, kwds = ffi._assigned_source + if ffi._windows_unicode: + kwds = kwds.copy() + ffi._apply_windows_unicode(kwds) + + if source is None: + _add_py_module(dist, ffi, module_name) + else: + _add_c_module(dist, ffi, module_name, source, source_extension, kwds) + +def _set_py_limited_api(Extension, kwds): + """ + Add py_limited_api to kwds if setuptools >= 26 is in use. + Do not alter the setting if it already exists. + Setuptools takes care of ignoring the flag on Python 2 and PyPy. + + CPython itself should ignore the flag in a debugging version + (by not listing .abi3.so in the extensions it supports), but + it doesn't so far, creating troubles. That's why we check + for "not hasattr(sys, 'gettotalrefcount')" (the 2.7 compatible equivalent + of 'd' not in sys.abiflags). (http://bugs.python.org/issue28401) + + On Windows, with CPython <= 3.4, it's better not to use py_limited_api + because virtualenv *still* doesn't copy PYTHON3.DLL on these versions. + Recently (2020) we started shipping only >= 3.5 wheels, though. So + we'll give it another try and set py_limited_api on Windows >= 3.5. + """ + from cffi import recompiler + + if ('py_limited_api' not in kwds and not hasattr(sys, 'gettotalrefcount') + and recompiler.USE_LIMITED_API): + import setuptools + try: + setuptools_major_version = int(setuptools.__version__.partition('.')[0]) + if setuptools_major_version >= 26: + kwds['py_limited_api'] = True + except ValueError: # certain development versions of setuptools + # If we don't know the version number of setuptools, we + # try to set 'py_limited_api' anyway. At worst, we get a + # warning. + kwds['py_limited_api'] = True + return kwds + +def _add_c_module(dist, ffi, module_name, source, source_extension, kwds): + # We are a setuptools extension. Need this build_ext for py_limited_api. + from setuptools.command.build_ext import build_ext + from cffi._shimmed_dist_utils import Extension, log, mkpath + from cffi import recompiler + + allsources = ['$PLACEHOLDER'] + allsources.extend(kwds.pop('sources', [])) + kwds = _set_py_limited_api(Extension, kwds) + ext = Extension(name=module_name, sources=allsources, **kwds) + + def make_mod(tmpdir, pre_run=None): + c_file = os.path.join(tmpdir, module_name + source_extension) + log.info("generating cffi module %r" % c_file) + mkpath(tmpdir) + # a setuptools-only, API-only hook: called with the "ext" and "ffi" + # arguments just before we turn the ffi into C code. To use it, + # subclass the 'distutils.command.build_ext.build_ext' class and + # add a method 'def pre_run(self, ext, ffi)'. + if pre_run is not None: + pre_run(ext, ffi) + updated = recompiler.make_c_source(ffi, module_name, source, c_file) + if not updated: + log.info("already up-to-date") + return c_file + + if dist.ext_modules is None: + dist.ext_modules = [] + dist.ext_modules.append(ext) + + base_class = dist.cmdclass.get('build_ext', build_ext) + class build_ext_make_mod(base_class): + def run(self): + if ext.sources[0] == '$PLACEHOLDER': + pre_run = getattr(self, 'pre_run', None) + ext.sources[0] = make_mod(self.build_temp, pre_run) + base_class.run(self) + dist.cmdclass['build_ext'] = build_ext_make_mod + # NB. multiple runs here will create multiple 'build_ext_make_mod' + # classes. Even in this case the 'build_ext' command should be + # run once; but just in case, the logic above does nothing if + # called again. + + +def _add_py_module(dist, ffi, module_name): + from setuptools.command.build_py import build_py + from setuptools.command.build_ext import build_ext + from cffi._shimmed_dist_utils import log, mkpath + from cffi import recompiler + + def generate_mod(py_file): + log.info("generating cffi module %r" % py_file) + mkpath(os.path.dirname(py_file)) + updated = recompiler.make_py_source(ffi, module_name, py_file) + if not updated: + log.info("already up-to-date") + + base_class = dist.cmdclass.get('build_py', build_py) + class build_py_make_mod(base_class): + def run(self): + base_class.run(self) + module_path = module_name.split('.') + module_path[-1] += '.py' + generate_mod(os.path.join(self.build_lib, *module_path)) + def get_source_files(self): + # This is called from 'setup.py sdist' only. Exclude + # the generate .py module in this case. + saved_py_modules = self.py_modules + try: + if saved_py_modules: + self.py_modules = [m for m in saved_py_modules + if m != module_name] + return base_class.get_source_files(self) + finally: + self.py_modules = saved_py_modules + dist.cmdclass['build_py'] = build_py_make_mod + + # distutils and setuptools have no notion I could find of a + # generated python module. If we don't add module_name to + # dist.py_modules, then things mostly work but there are some + # combination of options (--root and --record) that will miss + # the module. So we add it here, which gives a few apparently + # harmless warnings about not finding the file outside the + # build directory. + # Then we need to hack more in get_source_files(); see above. + if dist.py_modules is None: + dist.py_modules = [] + dist.py_modules.append(module_name) + + # the following is only for "build_ext -i" + base_class_2 = dist.cmdclass.get('build_ext', build_ext) + class build_ext_make_mod(base_class_2): + def run(self): + base_class_2.run(self) + if self.inplace: + # from get_ext_fullpath() in distutils/command/build_ext.py + module_path = module_name.split('.') + package = '.'.join(module_path[:-1]) + build_py = self.get_finalized_command('build_py') + package_dir = build_py.get_package_dir(package) + file_name = module_path[-1] + '.py' + generate_mod(os.path.join(package_dir, file_name)) + dist.cmdclass['build_ext'] = build_ext_make_mod + +def cffi_modules(dist, attr, value): + assert attr == 'cffi_modules' + if isinstance(value, basestring): + value = [value] + + for cffi_module in value: + add_cffi_module(dist, cffi_module) diff --git a/.venv/Lib/site-packages/cffi/vengine_cpy.py b/.venv/Lib/site-packages/cffi/vengine_cpy.py new file mode 100644 index 00000000..49727d36 --- /dev/null +++ b/.venv/Lib/site-packages/cffi/vengine_cpy.py @@ -0,0 +1,1077 @@ +# +# DEPRECATED: implementation for ffi.verify() +# +import sys +from . import model +from .error import VerificationError +from . import _imp_emulation as imp + + +class VCPythonEngine(object): + _class_key = 'x' + _gen_python_module = True + + def __init__(self, verifier): + self.verifier = verifier + self.ffi = verifier.ffi + self._struct_pending_verification = {} + self._types_of_builtin_functions = {} + + def patch_extension_kwds(self, kwds): + pass + + def find_module(self, module_name, path, so_suffixes): + try: + f, filename, descr = imp.find_module(module_name, path) + except ImportError: + return None + if f is not None: + f.close() + # Note that after a setuptools installation, there are both .py + # and .so files with the same basename. The code here relies on + # imp.find_module() locating the .so in priority. + if descr[0] not in so_suffixes: + return None + return filename + + def collect_types(self): + self._typesdict = {} + self._generate("collecttype") + + def _prnt(self, what=''): + self._f.write(what + '\n') + + def _gettypenum(self, type): + # a KeyError here is a bug. please report it! :-) + return self._typesdict[type] + + def _do_collect_type(self, tp): + if ((not isinstance(tp, model.PrimitiveType) + or tp.name == 'long double') + and tp not in self._typesdict): + num = len(self._typesdict) + self._typesdict[tp] = num + + def write_source_to_f(self): + self.collect_types() + # + # The new module will have a _cffi_setup() function that receives + # objects from the ffi world, and that calls some setup code in + # the module. This setup code is split in several independent + # functions, e.g. one per constant. The functions are "chained" + # by ending in a tail call to each other. + # + # This is further split in two chained lists, depending on if we + # can do it at import-time or if we must wait for _cffi_setup() to + # provide us with the objects. This is needed because we + # need the values of the enum constants in order to build the + # that we may have to pass to _cffi_setup(). + # + # The following two 'chained_list_constants' items contains + # the head of these two chained lists, as a string that gives the + # call to do, if any. + self._chained_list_constants = ['((void)lib,0)', '((void)lib,0)'] + # + prnt = self._prnt + # first paste some standard set of lines that are mostly '#define' + prnt(cffimod_header) + prnt() + # then paste the C source given by the user, verbatim. + prnt(self.verifier.preamble) + prnt() + # + # call generate_cpy_xxx_decl(), for every xxx found from + # ffi._parser._declarations. This generates all the functions. + self._generate("decl") + # + # implement the function _cffi_setup_custom() as calling the + # head of the chained list. + self._generate_setup_custom() + prnt() + # + # produce the method table, including the entries for the + # generated Python->C function wrappers, which are done + # by generate_cpy_function_method(). + prnt('static PyMethodDef _cffi_methods[] = {') + self._generate("method") + prnt(' {"_cffi_setup", _cffi_setup, METH_VARARGS, NULL},') + prnt(' {NULL, NULL, 0, NULL} /* Sentinel */') + prnt('};') + prnt() + # + # standard init. + modname = self.verifier.get_module_name() + constants = self._chained_list_constants[False] + prnt('#if PY_MAJOR_VERSION >= 3') + prnt() + prnt('static struct PyModuleDef _cffi_module_def = {') + prnt(' PyModuleDef_HEAD_INIT,') + prnt(' "%s",' % modname) + prnt(' NULL,') + prnt(' -1,') + prnt(' _cffi_methods,') + prnt(' NULL, NULL, NULL, NULL') + prnt('};') + prnt() + prnt('PyMODINIT_FUNC') + prnt('PyInit_%s(void)' % modname) + prnt('{') + prnt(' PyObject *lib;') + prnt(' lib = PyModule_Create(&_cffi_module_def);') + prnt(' if (lib == NULL)') + prnt(' return NULL;') + prnt(' if (%s < 0 || _cffi_init() < 0) {' % (constants,)) + prnt(' Py_DECREF(lib);') + prnt(' return NULL;') + prnt(' }') + prnt(' return lib;') + prnt('}') + prnt() + prnt('#else') + prnt() + prnt('PyMODINIT_FUNC') + prnt('init%s(void)' % modname) + prnt('{') + prnt(' PyObject *lib;') + prnt(' lib = Py_InitModule("%s", _cffi_methods);' % modname) + prnt(' if (lib == NULL)') + prnt(' return;') + prnt(' if (%s < 0 || _cffi_init() < 0)' % (constants,)) + prnt(' return;') + prnt(' return;') + prnt('}') + prnt() + prnt('#endif') + + def load_library(self, flags=None): + # XXX review all usages of 'self' here! + # import it as a new extension module + imp.acquire_lock() + try: + if hasattr(sys, "getdlopenflags"): + previous_flags = sys.getdlopenflags() + try: + if hasattr(sys, "setdlopenflags") and flags is not None: + sys.setdlopenflags(flags) + module = imp.load_dynamic(self.verifier.get_module_name(), + self.verifier.modulefilename) + except ImportError as e: + error = "importing %r: %s" % (self.verifier.modulefilename, e) + raise VerificationError(error) + finally: + if hasattr(sys, "setdlopenflags"): + sys.setdlopenflags(previous_flags) + finally: + imp.release_lock() + # + # call loading_cpy_struct() to get the struct layout inferred by + # the C compiler + self._load(module, 'loading') + # + # the C code will need the objects. Collect them in + # order in a list. + revmapping = dict([(value, key) + for (key, value) in self._typesdict.items()]) + lst = [revmapping[i] for i in range(len(revmapping))] + lst = list(map(self.ffi._get_cached_btype, lst)) + # + # build the FFILibrary class and instance and call _cffi_setup(). + # this will set up some fields like '_cffi_types', and only then + # it will invoke the chained list of functions that will really + # build (notably) the constant objects, as if they are + # pointers, and store them as attributes on the 'library' object. + class FFILibrary(object): + _cffi_python_module = module + _cffi_ffi = self.ffi + _cffi_dir = [] + def __dir__(self): + return FFILibrary._cffi_dir + list(self.__dict__) + library = FFILibrary() + if module._cffi_setup(lst, VerificationError, library): + import warnings + warnings.warn("reimporting %r might overwrite older definitions" + % (self.verifier.get_module_name())) + # + # finally, call the loaded_cpy_xxx() functions. This will perform + # the final adjustments, like copying the Python->C wrapper + # functions from the module to the 'library' object, and setting + # up the FFILibrary class with properties for the global C variables. + self._load(module, 'loaded', library=library) + module._cffi_original_ffi = self.ffi + module._cffi_types_of_builtin_funcs = self._types_of_builtin_functions + return library + + def _get_declarations(self): + lst = [(key, tp) for (key, (tp, qual)) in + self.ffi._parser._declarations.items()] + lst.sort() + return lst + + def _generate(self, step_name): + for name, tp in self._get_declarations(): + kind, realname = name.split(' ', 1) + try: + method = getattr(self, '_generate_cpy_%s_%s' % (kind, + step_name)) + except AttributeError: + raise VerificationError( + "not implemented in verify(): %r" % name) + try: + method(tp, realname) + except Exception as e: + model.attach_exception_info(e, name) + raise + + def _load(self, module, step_name, **kwds): + for name, tp in self._get_declarations(): + kind, realname = name.split(' ', 1) + method = getattr(self, '_%s_cpy_%s' % (step_name, kind)) + try: + method(tp, realname, module, **kwds) + except Exception as e: + model.attach_exception_info(e, name) + raise + + def _generate_nothing(self, tp, name): + pass + + def _loaded_noop(self, tp, name, module, **kwds): + pass + + # ---------- + + def _convert_funcarg_to_c(self, tp, fromvar, tovar, errcode): + extraarg = '' + if isinstance(tp, model.PrimitiveType): + if tp.is_integer_type() and tp.name != '_Bool': + converter = '_cffi_to_c_int' + extraarg = ', %s' % tp.name + else: + converter = '(%s)_cffi_to_c_%s' % (tp.get_c_name(''), + tp.name.replace(' ', '_')) + errvalue = '-1' + # + elif isinstance(tp, model.PointerType): + self._convert_funcarg_to_c_ptr_or_array(tp, fromvar, + tovar, errcode) + return + # + elif isinstance(tp, (model.StructOrUnion, model.EnumType)): + # a struct (not a struct pointer) as a function argument + self._prnt(' if (_cffi_to_c((char *)&%s, _cffi_type(%d), %s) < 0)' + % (tovar, self._gettypenum(tp), fromvar)) + self._prnt(' %s;' % errcode) + return + # + elif isinstance(tp, model.FunctionPtrType): + converter = '(%s)_cffi_to_c_pointer' % tp.get_c_name('') + extraarg = ', _cffi_type(%d)' % self._gettypenum(tp) + errvalue = 'NULL' + # + else: + raise NotImplementedError(tp) + # + self._prnt(' %s = %s(%s%s);' % (tovar, converter, fromvar, extraarg)) + self._prnt(' if (%s == (%s)%s && PyErr_Occurred())' % ( + tovar, tp.get_c_name(''), errvalue)) + self._prnt(' %s;' % errcode) + + def _extra_local_variables(self, tp, localvars, freelines): + if isinstance(tp, model.PointerType): + localvars.add('Py_ssize_t datasize') + localvars.add('struct _cffi_freeme_s *large_args_free = NULL') + freelines.add('if (large_args_free != NULL)' + ' _cffi_free_array_arguments(large_args_free);') + + def _convert_funcarg_to_c_ptr_or_array(self, tp, fromvar, tovar, errcode): + self._prnt(' datasize = _cffi_prepare_pointer_call_argument(') + self._prnt(' _cffi_type(%d), %s, (char **)&%s);' % ( + self._gettypenum(tp), fromvar, tovar)) + self._prnt(' if (datasize != 0) {') + self._prnt(' %s = ((size_t)datasize) <= 640 ? ' + 'alloca((size_t)datasize) : NULL;' % (tovar,)) + self._prnt(' if (_cffi_convert_array_argument(_cffi_type(%d), %s, ' + '(char **)&%s,' % (self._gettypenum(tp), fromvar, tovar)) + self._prnt(' datasize, &large_args_free) < 0)') + self._prnt(' %s;' % errcode) + self._prnt(' }') + + def _convert_expr_from_c(self, tp, var, context): + if isinstance(tp, model.PrimitiveType): + if tp.is_integer_type() and tp.name != '_Bool': + return '_cffi_from_c_int(%s, %s)' % (var, tp.name) + elif tp.name != 'long double': + return '_cffi_from_c_%s(%s)' % (tp.name.replace(' ', '_'), var) + else: + return '_cffi_from_c_deref((char *)&%s, _cffi_type(%d))' % ( + var, self._gettypenum(tp)) + elif isinstance(tp, (model.PointerType, model.FunctionPtrType)): + return '_cffi_from_c_pointer((char *)%s, _cffi_type(%d))' % ( + var, self._gettypenum(tp)) + elif isinstance(tp, model.ArrayType): + return '_cffi_from_c_pointer((char *)%s, _cffi_type(%d))' % ( + var, self._gettypenum(model.PointerType(tp.item))) + elif isinstance(tp, model.StructOrUnion): + if tp.fldnames is None: + raise TypeError("'%s' is used as %s, but is opaque" % ( + tp._get_c_name(), context)) + return '_cffi_from_c_struct((char *)&%s, _cffi_type(%d))' % ( + var, self._gettypenum(tp)) + elif isinstance(tp, model.EnumType): + return '_cffi_from_c_deref((char *)&%s, _cffi_type(%d))' % ( + var, self._gettypenum(tp)) + else: + raise NotImplementedError(tp) + + # ---------- + # typedefs: generates no code so far + + _generate_cpy_typedef_collecttype = _generate_nothing + _generate_cpy_typedef_decl = _generate_nothing + _generate_cpy_typedef_method = _generate_nothing + _loading_cpy_typedef = _loaded_noop + _loaded_cpy_typedef = _loaded_noop + + # ---------- + # function declarations + + def _generate_cpy_function_collecttype(self, tp, name): + assert isinstance(tp, model.FunctionPtrType) + if tp.ellipsis: + self._do_collect_type(tp) + else: + # don't call _do_collect_type(tp) in this common case, + # otherwise test_autofilled_struct_as_argument fails + for type in tp.args: + self._do_collect_type(type) + self._do_collect_type(tp.result) + + def _generate_cpy_function_decl(self, tp, name): + assert isinstance(tp, model.FunctionPtrType) + if tp.ellipsis: + # cannot support vararg functions better than this: check for its + # exact type (including the fixed arguments), and build it as a + # constant function pointer (no CPython wrapper) + self._generate_cpy_const(False, name, tp) + return + prnt = self._prnt + numargs = len(tp.args) + if numargs == 0: + argname = 'noarg' + elif numargs == 1: + argname = 'arg0' + else: + argname = 'args' + prnt('static PyObject *') + prnt('_cffi_f_%s(PyObject *self, PyObject *%s)' % (name, argname)) + prnt('{') + # + context = 'argument of %s' % name + for i, type in enumerate(tp.args): + prnt(' %s;' % type.get_c_name(' x%d' % i, context)) + # + localvars = set() + freelines = set() + for type in tp.args: + self._extra_local_variables(type, localvars, freelines) + for decl in sorted(localvars): + prnt(' %s;' % (decl,)) + # + if not isinstance(tp.result, model.VoidType): + result_code = 'result = ' + context = 'result of %s' % name + prnt(' %s;' % tp.result.get_c_name(' result', context)) + prnt(' PyObject *pyresult;') + else: + result_code = '' + # + if len(tp.args) > 1: + rng = range(len(tp.args)) + for i in rng: + prnt(' PyObject *arg%d;' % i) + prnt() + prnt(' if (!PyArg_ParseTuple(args, "%s:%s", %s))' % ( + 'O' * numargs, name, ', '.join(['&arg%d' % i for i in rng]))) + prnt(' return NULL;') + prnt() + # + for i, type in enumerate(tp.args): + self._convert_funcarg_to_c(type, 'arg%d' % i, 'x%d' % i, + 'return NULL') + prnt() + # + prnt(' Py_BEGIN_ALLOW_THREADS') + prnt(' _cffi_restore_errno();') + prnt(' { %s%s(%s); }' % ( + result_code, name, + ', '.join(['x%d' % i for i in range(len(tp.args))]))) + prnt(' _cffi_save_errno();') + prnt(' Py_END_ALLOW_THREADS') + prnt() + # + prnt(' (void)self; /* unused */') + if numargs == 0: + prnt(' (void)noarg; /* unused */') + if result_code: + prnt(' pyresult = %s;' % + self._convert_expr_from_c(tp.result, 'result', 'result type')) + for freeline in freelines: + prnt(' ' + freeline) + prnt(' return pyresult;') + else: + for freeline in freelines: + prnt(' ' + freeline) + prnt(' Py_INCREF(Py_None);') + prnt(' return Py_None;') + prnt('}') + prnt() + + def _generate_cpy_function_method(self, tp, name): + if tp.ellipsis: + return + numargs = len(tp.args) + if numargs == 0: + meth = 'METH_NOARGS' + elif numargs == 1: + meth = 'METH_O' + else: + meth = 'METH_VARARGS' + self._prnt(' {"%s", _cffi_f_%s, %s, NULL},' % (name, name, meth)) + + _loading_cpy_function = _loaded_noop + + def _loaded_cpy_function(self, tp, name, module, library): + if tp.ellipsis: + return + func = getattr(module, name) + setattr(library, name, func) + self._types_of_builtin_functions[func] = tp + + # ---------- + # named structs + + _generate_cpy_struct_collecttype = _generate_nothing + def _generate_cpy_struct_decl(self, tp, name): + assert name == tp.name + self._generate_struct_or_union_decl(tp, 'struct', name) + def _generate_cpy_struct_method(self, tp, name): + self._generate_struct_or_union_method(tp, 'struct', name) + def _loading_cpy_struct(self, tp, name, module): + self._loading_struct_or_union(tp, 'struct', name, module) + def _loaded_cpy_struct(self, tp, name, module, **kwds): + self._loaded_struct_or_union(tp) + + _generate_cpy_union_collecttype = _generate_nothing + def _generate_cpy_union_decl(self, tp, name): + assert name == tp.name + self._generate_struct_or_union_decl(tp, 'union', name) + def _generate_cpy_union_method(self, tp, name): + self._generate_struct_or_union_method(tp, 'union', name) + def _loading_cpy_union(self, tp, name, module): + self._loading_struct_or_union(tp, 'union', name, module) + def _loaded_cpy_union(self, tp, name, module, **kwds): + self._loaded_struct_or_union(tp) + + def _generate_struct_or_union_decl(self, tp, prefix, name): + if tp.fldnames is None: + return # nothing to do with opaque structs + checkfuncname = '_cffi_check_%s_%s' % (prefix, name) + layoutfuncname = '_cffi_layout_%s_%s' % (prefix, name) + cname = ('%s %s' % (prefix, name)).strip() + # + prnt = self._prnt + prnt('static void %s(%s *p)' % (checkfuncname, cname)) + prnt('{') + prnt(' /* only to generate compile-time warnings or errors */') + prnt(' (void)p;') + for fname, ftype, fbitsize, fqual in tp.enumfields(): + if (isinstance(ftype, model.PrimitiveType) + and ftype.is_integer_type()) or fbitsize >= 0: + # accept all integers, but complain on float or double + prnt(' (void)((p->%s) << 1);' % fname) + else: + # only accept exactly the type declared. + try: + prnt(' { %s = &p->%s; (void)tmp; }' % ( + ftype.get_c_name('*tmp', 'field %r'%fname, quals=fqual), + fname)) + except VerificationError as e: + prnt(' /* %s */' % str(e)) # cannot verify it, ignore + prnt('}') + prnt('static PyObject *') + prnt('%s(PyObject *self, PyObject *noarg)' % (layoutfuncname,)) + prnt('{') + prnt(' struct _cffi_aligncheck { char x; %s y; };' % cname) + prnt(' static Py_ssize_t nums[] = {') + prnt(' sizeof(%s),' % cname) + prnt(' offsetof(struct _cffi_aligncheck, y),') + for fname, ftype, fbitsize, fqual in tp.enumfields(): + if fbitsize >= 0: + continue # xxx ignore fbitsize for now + prnt(' offsetof(%s, %s),' % (cname, fname)) + if isinstance(ftype, model.ArrayType) and ftype.length is None: + prnt(' 0, /* %s */' % ftype._get_c_name()) + else: + prnt(' sizeof(((%s *)0)->%s),' % (cname, fname)) + prnt(' -1') + prnt(' };') + prnt(' (void)self; /* unused */') + prnt(' (void)noarg; /* unused */') + prnt(' return _cffi_get_struct_layout(nums);') + prnt(' /* the next line is not executed, but compiled */') + prnt(' %s(0);' % (checkfuncname,)) + prnt('}') + prnt() + + def _generate_struct_or_union_method(self, tp, prefix, name): + if tp.fldnames is None: + return # nothing to do with opaque structs + layoutfuncname = '_cffi_layout_%s_%s' % (prefix, name) + self._prnt(' {"%s", %s, METH_NOARGS, NULL},' % (layoutfuncname, + layoutfuncname)) + + def _loading_struct_or_union(self, tp, prefix, name, module): + if tp.fldnames is None: + return # nothing to do with opaque structs + layoutfuncname = '_cffi_layout_%s_%s' % (prefix, name) + # + function = getattr(module, layoutfuncname) + layout = function() + if isinstance(tp, model.StructOrUnion) and tp.partial: + # use the function()'s sizes and offsets to guide the + # layout of the struct + totalsize = layout[0] + totalalignment = layout[1] + fieldofs = layout[2::2] + fieldsize = layout[3::2] + tp.force_flatten() + assert len(fieldofs) == len(fieldsize) == len(tp.fldnames) + tp.fixedlayout = fieldofs, fieldsize, totalsize, totalalignment + else: + cname = ('%s %s' % (prefix, name)).strip() + self._struct_pending_verification[tp] = layout, cname + + def _loaded_struct_or_union(self, tp): + if tp.fldnames is None: + return # nothing to do with opaque structs + self.ffi._get_cached_btype(tp) # force 'fixedlayout' to be considered + + if tp in self._struct_pending_verification: + # check that the layout sizes and offsets match the real ones + def check(realvalue, expectedvalue, msg): + if realvalue != expectedvalue: + raise VerificationError( + "%s (we have %d, but C compiler says %d)" + % (msg, expectedvalue, realvalue)) + ffi = self.ffi + BStruct = ffi._get_cached_btype(tp) + layout, cname = self._struct_pending_verification.pop(tp) + check(layout[0], ffi.sizeof(BStruct), "wrong total size") + check(layout[1], ffi.alignof(BStruct), "wrong total alignment") + i = 2 + for fname, ftype, fbitsize, fqual in tp.enumfields(): + if fbitsize >= 0: + continue # xxx ignore fbitsize for now + check(layout[i], ffi.offsetof(BStruct, fname), + "wrong offset for field %r" % (fname,)) + if layout[i+1] != 0: + BField = ffi._get_cached_btype(ftype) + check(layout[i+1], ffi.sizeof(BField), + "wrong size for field %r" % (fname,)) + i += 2 + assert i == len(layout) + + # ---------- + # 'anonymous' declarations. These are produced for anonymous structs + # or unions; the 'name' is obtained by a typedef. + + _generate_cpy_anonymous_collecttype = _generate_nothing + + def _generate_cpy_anonymous_decl(self, tp, name): + if isinstance(tp, model.EnumType): + self._generate_cpy_enum_decl(tp, name, '') + else: + self._generate_struct_or_union_decl(tp, '', name) + + def _generate_cpy_anonymous_method(self, tp, name): + if not isinstance(tp, model.EnumType): + self._generate_struct_or_union_method(tp, '', name) + + def _loading_cpy_anonymous(self, tp, name, module): + if isinstance(tp, model.EnumType): + self._loading_cpy_enum(tp, name, module) + else: + self._loading_struct_or_union(tp, '', name, module) + + def _loaded_cpy_anonymous(self, tp, name, module, **kwds): + if isinstance(tp, model.EnumType): + self._loaded_cpy_enum(tp, name, module, **kwds) + else: + self._loaded_struct_or_union(tp) + + # ---------- + # constants, likely declared with '#define' + + def _generate_cpy_const(self, is_int, name, tp=None, category='const', + vartp=None, delayed=True, size_too=False, + check_value=None): + prnt = self._prnt + funcname = '_cffi_%s_%s' % (category, name) + prnt('static int %s(PyObject *lib)' % funcname) + prnt('{') + prnt(' PyObject *o;') + prnt(' int res;') + if not is_int: + prnt(' %s;' % (vartp or tp).get_c_name(' i', name)) + else: + assert category == 'const' + # + if check_value is not None: + self._check_int_constant_value(name, check_value) + # + if not is_int: + if category == 'var': + realexpr = '&' + name + else: + realexpr = name + prnt(' i = (%s);' % (realexpr,)) + prnt(' o = %s;' % (self._convert_expr_from_c(tp, 'i', + 'variable type'),)) + assert delayed + else: + prnt(' o = _cffi_from_c_int_const(%s);' % name) + prnt(' if (o == NULL)') + prnt(' return -1;') + if size_too: + prnt(' {') + prnt(' PyObject *o1 = o;') + prnt(' o = Py_BuildValue("On", o1, (Py_ssize_t)sizeof(%s));' + % (name,)) + prnt(' Py_DECREF(o1);') + prnt(' if (o == NULL)') + prnt(' return -1;') + prnt(' }') + prnt(' res = PyObject_SetAttrString(lib, "%s", o);' % name) + prnt(' Py_DECREF(o);') + prnt(' if (res < 0)') + prnt(' return -1;') + prnt(' return %s;' % self._chained_list_constants[delayed]) + self._chained_list_constants[delayed] = funcname + '(lib)' + prnt('}') + prnt() + + def _generate_cpy_constant_collecttype(self, tp, name): + is_int = isinstance(tp, model.PrimitiveType) and tp.is_integer_type() + if not is_int: + self._do_collect_type(tp) + + def _generate_cpy_constant_decl(self, tp, name): + is_int = isinstance(tp, model.PrimitiveType) and tp.is_integer_type() + self._generate_cpy_const(is_int, name, tp) + + _generate_cpy_constant_method = _generate_nothing + _loading_cpy_constant = _loaded_noop + _loaded_cpy_constant = _loaded_noop + + # ---------- + # enums + + def _check_int_constant_value(self, name, value, err_prefix=''): + prnt = self._prnt + if value <= 0: + prnt(' if ((%s) > 0 || (long)(%s) != %dL) {' % ( + name, name, value)) + else: + prnt(' if ((%s) <= 0 || (unsigned long)(%s) != %dUL) {' % ( + name, name, value)) + prnt(' char buf[64];') + prnt(' if ((%s) <= 0)' % name) + prnt(' snprintf(buf, 63, "%%ld", (long)(%s));' % name) + prnt(' else') + prnt(' snprintf(buf, 63, "%%lu", (unsigned long)(%s));' % + name) + prnt(' PyErr_Format(_cffi_VerificationError,') + prnt(' "%s%s has the real value %s, not %s",') + prnt(' "%s", "%s", buf, "%d");' % ( + err_prefix, name, value)) + prnt(' return -1;') + prnt(' }') + + def _enum_funcname(self, prefix, name): + # "$enum_$1" => "___D_enum____D_1" + name = name.replace('$', '___D_') + return '_cffi_e_%s_%s' % (prefix, name) + + def _generate_cpy_enum_decl(self, tp, name, prefix='enum'): + if tp.partial: + for enumerator in tp.enumerators: + self._generate_cpy_const(True, enumerator, delayed=False) + return + # + funcname = self._enum_funcname(prefix, name) + prnt = self._prnt + prnt('static int %s(PyObject *lib)' % funcname) + prnt('{') + for enumerator, enumvalue in zip(tp.enumerators, tp.enumvalues): + self._check_int_constant_value(enumerator, enumvalue, + "enum %s: " % name) + prnt(' return %s;' % self._chained_list_constants[True]) + self._chained_list_constants[True] = funcname + '(lib)' + prnt('}') + prnt() + + _generate_cpy_enum_collecttype = _generate_nothing + _generate_cpy_enum_method = _generate_nothing + + def _loading_cpy_enum(self, tp, name, module): + if tp.partial: + enumvalues = [getattr(module, enumerator) + for enumerator in tp.enumerators] + tp.enumvalues = tuple(enumvalues) + tp.partial_resolved = True + + def _loaded_cpy_enum(self, tp, name, module, library): + for enumerator, enumvalue in zip(tp.enumerators, tp.enumvalues): + setattr(library, enumerator, enumvalue) + + # ---------- + # macros: for now only for integers + + def _generate_cpy_macro_decl(self, tp, name): + if tp == '...': + check_value = None + else: + check_value = tp # an integer + self._generate_cpy_const(True, name, check_value=check_value) + + _generate_cpy_macro_collecttype = _generate_nothing + _generate_cpy_macro_method = _generate_nothing + _loading_cpy_macro = _loaded_noop + _loaded_cpy_macro = _loaded_noop + + # ---------- + # global variables + + def _generate_cpy_variable_collecttype(self, tp, name): + if isinstance(tp, model.ArrayType): + tp_ptr = model.PointerType(tp.item) + else: + tp_ptr = model.PointerType(tp) + self._do_collect_type(tp_ptr) + + def _generate_cpy_variable_decl(self, tp, name): + if isinstance(tp, model.ArrayType): + tp_ptr = model.PointerType(tp.item) + self._generate_cpy_const(False, name, tp, vartp=tp_ptr, + size_too = tp.length_is_unknown()) + else: + tp_ptr = model.PointerType(tp) + self._generate_cpy_const(False, name, tp_ptr, category='var') + + _generate_cpy_variable_method = _generate_nothing + _loading_cpy_variable = _loaded_noop + + def _loaded_cpy_variable(self, tp, name, module, library): + value = getattr(library, name) + if isinstance(tp, model.ArrayType): # int a[5] is "constant" in the + # sense that "a=..." is forbidden + if tp.length_is_unknown(): + assert isinstance(value, tuple) + (value, size) = value + BItemType = self.ffi._get_cached_btype(tp.item) + length, rest = divmod(size, self.ffi.sizeof(BItemType)) + if rest != 0: + raise VerificationError( + "bad size: %r does not seem to be an array of %s" % + (name, tp.item)) + tp = tp.resolve_length(length) + # 'value' is a which we have to replace with + # a if the N is actually known + if tp.length is not None: + BArray = self.ffi._get_cached_btype(tp) + value = self.ffi.cast(BArray, value) + setattr(library, name, value) + return + # remove ptr= from the library instance, and replace + # it by a property on the class, which reads/writes into ptr[0]. + ptr = value + delattr(library, name) + def getter(library): + return ptr[0] + def setter(library, value): + ptr[0] = value + setattr(type(library), name, property(getter, setter)) + type(library)._cffi_dir.append(name) + + # ---------- + + def _generate_setup_custom(self): + prnt = self._prnt + prnt('static int _cffi_setup_custom(PyObject *lib)') + prnt('{') + prnt(' return %s;' % self._chained_list_constants[True]) + prnt('}') + +cffimod_header = r''' +#include +#include + +/* this block of #ifs should be kept exactly identical between + c/_cffi_backend.c, cffi/vengine_cpy.py, cffi/vengine_gen.py + and cffi/_cffi_include.h */ +#if defined(_MSC_VER) +# include /* for alloca() */ +# if _MSC_VER < 1600 /* MSVC < 2010 */ + typedef __int8 int8_t; + typedef __int16 int16_t; + typedef __int32 int32_t; + typedef __int64 int64_t; + typedef unsigned __int8 uint8_t; + typedef unsigned __int16 uint16_t; + typedef unsigned __int32 uint32_t; + typedef unsigned __int64 uint64_t; + typedef __int8 int_least8_t; + typedef __int16 int_least16_t; + typedef __int32 int_least32_t; + typedef __int64 int_least64_t; + typedef unsigned __int8 uint_least8_t; + typedef unsigned __int16 uint_least16_t; + typedef unsigned __int32 uint_least32_t; + typedef unsigned __int64 uint_least64_t; + typedef __int8 int_fast8_t; + typedef __int16 int_fast16_t; + typedef __int32 int_fast32_t; + typedef __int64 int_fast64_t; + typedef unsigned __int8 uint_fast8_t; + typedef unsigned __int16 uint_fast16_t; + typedef unsigned __int32 uint_fast32_t; + typedef unsigned __int64 uint_fast64_t; + typedef __int64 intmax_t; + typedef unsigned __int64 uintmax_t; +# else +# include +# endif +# if _MSC_VER < 1800 /* MSVC < 2013 */ +# ifndef __cplusplus + typedef unsigned char _Bool; +# endif +# endif +#else +# include +# if (defined (__SVR4) && defined (__sun)) || defined(_AIX) || defined(__hpux) +# include +# endif +#endif + +#if PY_MAJOR_VERSION < 3 +# undef PyCapsule_CheckExact +# undef PyCapsule_GetPointer +# define PyCapsule_CheckExact(capsule) (PyCObject_Check(capsule)) +# define PyCapsule_GetPointer(capsule, name) \ + (PyCObject_AsVoidPtr(capsule)) +#endif + +#if PY_MAJOR_VERSION >= 3 +# define PyInt_FromLong PyLong_FromLong +#endif + +#define _cffi_from_c_double PyFloat_FromDouble +#define _cffi_from_c_float PyFloat_FromDouble +#define _cffi_from_c_long PyInt_FromLong +#define _cffi_from_c_ulong PyLong_FromUnsignedLong +#define _cffi_from_c_longlong PyLong_FromLongLong +#define _cffi_from_c_ulonglong PyLong_FromUnsignedLongLong +#define _cffi_from_c__Bool PyBool_FromLong + +#define _cffi_to_c_double PyFloat_AsDouble +#define _cffi_to_c_float PyFloat_AsDouble + +#define _cffi_from_c_int_const(x) \ + (((x) > 0) ? \ + ((unsigned long long)(x) <= (unsigned long long)LONG_MAX) ? \ + PyInt_FromLong((long)(x)) : \ + PyLong_FromUnsignedLongLong((unsigned long long)(x)) : \ + ((long long)(x) >= (long long)LONG_MIN) ? \ + PyInt_FromLong((long)(x)) : \ + PyLong_FromLongLong((long long)(x))) + +#define _cffi_from_c_int(x, type) \ + (((type)-1) > 0 ? /* unsigned */ \ + (sizeof(type) < sizeof(long) ? \ + PyInt_FromLong((long)x) : \ + sizeof(type) == sizeof(long) ? \ + PyLong_FromUnsignedLong((unsigned long)x) : \ + PyLong_FromUnsignedLongLong((unsigned long long)x)) : \ + (sizeof(type) <= sizeof(long) ? \ + PyInt_FromLong((long)x) : \ + PyLong_FromLongLong((long long)x))) + +#define _cffi_to_c_int(o, type) \ + ((type)( \ + sizeof(type) == 1 ? (((type)-1) > 0 ? (type)_cffi_to_c_u8(o) \ + : (type)_cffi_to_c_i8(o)) : \ + sizeof(type) == 2 ? (((type)-1) > 0 ? (type)_cffi_to_c_u16(o) \ + : (type)_cffi_to_c_i16(o)) : \ + sizeof(type) == 4 ? (((type)-1) > 0 ? (type)_cffi_to_c_u32(o) \ + : (type)_cffi_to_c_i32(o)) : \ + sizeof(type) == 8 ? (((type)-1) > 0 ? (type)_cffi_to_c_u64(o) \ + : (type)_cffi_to_c_i64(o)) : \ + (Py_FatalError("unsupported size for type " #type), (type)0))) + +#define _cffi_to_c_i8 \ + ((int(*)(PyObject *))_cffi_exports[1]) +#define _cffi_to_c_u8 \ + ((int(*)(PyObject *))_cffi_exports[2]) +#define _cffi_to_c_i16 \ + ((int(*)(PyObject *))_cffi_exports[3]) +#define _cffi_to_c_u16 \ + ((int(*)(PyObject *))_cffi_exports[4]) +#define _cffi_to_c_i32 \ + ((int(*)(PyObject *))_cffi_exports[5]) +#define _cffi_to_c_u32 \ + ((unsigned int(*)(PyObject *))_cffi_exports[6]) +#define _cffi_to_c_i64 \ + ((long long(*)(PyObject *))_cffi_exports[7]) +#define _cffi_to_c_u64 \ + ((unsigned long long(*)(PyObject *))_cffi_exports[8]) +#define _cffi_to_c_char \ + ((int(*)(PyObject *))_cffi_exports[9]) +#define _cffi_from_c_pointer \ + ((PyObject *(*)(char *, CTypeDescrObject *))_cffi_exports[10]) +#define _cffi_to_c_pointer \ + ((char *(*)(PyObject *, CTypeDescrObject *))_cffi_exports[11]) +#define _cffi_get_struct_layout \ + ((PyObject *(*)(Py_ssize_t[]))_cffi_exports[12]) +#define _cffi_restore_errno \ + ((void(*)(void))_cffi_exports[13]) +#define _cffi_save_errno \ + ((void(*)(void))_cffi_exports[14]) +#define _cffi_from_c_char \ + ((PyObject *(*)(char))_cffi_exports[15]) +#define _cffi_from_c_deref \ + ((PyObject *(*)(char *, CTypeDescrObject *))_cffi_exports[16]) +#define _cffi_to_c \ + ((int(*)(char *, CTypeDescrObject *, PyObject *))_cffi_exports[17]) +#define _cffi_from_c_struct \ + ((PyObject *(*)(char *, CTypeDescrObject *))_cffi_exports[18]) +#define _cffi_to_c_wchar_t \ + ((wchar_t(*)(PyObject *))_cffi_exports[19]) +#define _cffi_from_c_wchar_t \ + ((PyObject *(*)(wchar_t))_cffi_exports[20]) +#define _cffi_to_c_long_double \ + ((long double(*)(PyObject *))_cffi_exports[21]) +#define _cffi_to_c__Bool \ + ((_Bool(*)(PyObject *))_cffi_exports[22]) +#define _cffi_prepare_pointer_call_argument \ + ((Py_ssize_t(*)(CTypeDescrObject *, PyObject *, char **))_cffi_exports[23]) +#define _cffi_convert_array_from_object \ + ((int(*)(char *, CTypeDescrObject *, PyObject *))_cffi_exports[24]) +#define _CFFI_NUM_EXPORTS 25 + +typedef struct _ctypedescr CTypeDescrObject; + +static void *_cffi_exports[_CFFI_NUM_EXPORTS]; +static PyObject *_cffi_types, *_cffi_VerificationError; + +static int _cffi_setup_custom(PyObject *lib); /* forward */ + +static PyObject *_cffi_setup(PyObject *self, PyObject *args) +{ + PyObject *library; + int was_alive = (_cffi_types != NULL); + (void)self; /* unused */ + if (!PyArg_ParseTuple(args, "OOO", &_cffi_types, &_cffi_VerificationError, + &library)) + return NULL; + Py_INCREF(_cffi_types); + Py_INCREF(_cffi_VerificationError); + if (_cffi_setup_custom(library) < 0) + return NULL; + return PyBool_FromLong(was_alive); +} + +union _cffi_union_alignment_u { + unsigned char m_char; + unsigned short m_short; + unsigned int m_int; + unsigned long m_long; + unsigned long long m_longlong; + float m_float; + double m_double; + long double m_longdouble; +}; + +struct _cffi_freeme_s { + struct _cffi_freeme_s *next; + union _cffi_union_alignment_u alignment; +}; + +#ifdef __GNUC__ + __attribute__((unused)) +#endif +static int _cffi_convert_array_argument(CTypeDescrObject *ctptr, PyObject *arg, + char **output_data, Py_ssize_t datasize, + struct _cffi_freeme_s **freeme) +{ + char *p; + if (datasize < 0) + return -1; + + p = *output_data; + if (p == NULL) { + struct _cffi_freeme_s *fp = (struct _cffi_freeme_s *)PyObject_Malloc( + offsetof(struct _cffi_freeme_s, alignment) + (size_t)datasize); + if (fp == NULL) + return -1; + fp->next = *freeme; + *freeme = fp; + p = *output_data = (char *)&fp->alignment; + } + memset((void *)p, 0, (size_t)datasize); + return _cffi_convert_array_from_object(p, ctptr, arg); +} + +#ifdef __GNUC__ + __attribute__((unused)) +#endif +static void _cffi_free_array_arguments(struct _cffi_freeme_s *freeme) +{ + do { + void *p = (void *)freeme; + freeme = freeme->next; + PyObject_Free(p); + } while (freeme != NULL); +} + +static int _cffi_init(void) +{ + PyObject *module, *c_api_object = NULL; + + module = PyImport_ImportModule("_cffi_backend"); + if (module == NULL) + goto failure; + + c_api_object = PyObject_GetAttrString(module, "_C_API"); + if (c_api_object == NULL) + goto failure; + if (!PyCapsule_CheckExact(c_api_object)) { + PyErr_SetNone(PyExc_ImportError); + goto failure; + } + memcpy(_cffi_exports, PyCapsule_GetPointer(c_api_object, "cffi"), + _CFFI_NUM_EXPORTS * sizeof(void *)); + + Py_DECREF(module); + Py_DECREF(c_api_object); + return 0; + + failure: + Py_XDECREF(module); + Py_XDECREF(c_api_object); + return -1; +} + +#define _cffi_type(num) ((CTypeDescrObject *)PyList_GET_ITEM(_cffi_types, num)) + +/**********/ +''' diff --git a/.venv/Lib/site-packages/cffi/vengine_gen.py b/.venv/Lib/site-packages/cffi/vengine_gen.py new file mode 100644 index 00000000..26421526 --- /dev/null +++ b/.venv/Lib/site-packages/cffi/vengine_gen.py @@ -0,0 +1,675 @@ +# +# DEPRECATED: implementation for ffi.verify() +# +import sys, os +import types + +from . import model +from .error import VerificationError + + +class VGenericEngine(object): + _class_key = 'g' + _gen_python_module = False + + def __init__(self, verifier): + self.verifier = verifier + self.ffi = verifier.ffi + self.export_symbols = [] + self._struct_pending_verification = {} + + def patch_extension_kwds(self, kwds): + # add 'export_symbols' to the dictionary. Note that we add the + # list before filling it. When we fill it, it will thus also show + # up in kwds['export_symbols']. + kwds.setdefault('export_symbols', self.export_symbols) + + def find_module(self, module_name, path, so_suffixes): + for so_suffix in so_suffixes: + basename = module_name + so_suffix + if path is None: + path = sys.path + for dirname in path: + filename = os.path.join(dirname, basename) + if os.path.isfile(filename): + return filename + + def collect_types(self): + pass # not needed in the generic engine + + def _prnt(self, what=''): + self._f.write(what + '\n') + + def write_source_to_f(self): + prnt = self._prnt + # first paste some standard set of lines that are mostly '#include' + prnt(cffimod_header) + # then paste the C source given by the user, verbatim. + prnt(self.verifier.preamble) + # + # call generate_gen_xxx_decl(), for every xxx found from + # ffi._parser._declarations. This generates all the functions. + self._generate('decl') + # + # on Windows, distutils insists on putting init_cffi_xyz in + # 'export_symbols', so instead of fighting it, just give up and + # give it one + if sys.platform == 'win32': + if sys.version_info >= (3,): + prefix = 'PyInit_' + else: + prefix = 'init' + modname = self.verifier.get_module_name() + prnt("void %s%s(void) { }\n" % (prefix, modname)) + + def load_library(self, flags=0): + # import it with the CFFI backend + backend = self.ffi._backend + # needs to make a path that contains '/', on Posix + filename = os.path.join(os.curdir, self.verifier.modulefilename) + module = backend.load_library(filename, flags) + # + # call loading_gen_struct() to get the struct layout inferred by + # the C compiler + self._load(module, 'loading') + + # build the FFILibrary class and instance, this is a module subclass + # because modules are expected to have usually-constant-attributes and + # in PyPy this means the JIT is able to treat attributes as constant, + # which we want. + class FFILibrary(types.ModuleType): + _cffi_generic_module = module + _cffi_ffi = self.ffi + _cffi_dir = [] + def __dir__(self): + return FFILibrary._cffi_dir + library = FFILibrary("") + # + # finally, call the loaded_gen_xxx() functions. This will set + # up the 'library' object. + self._load(module, 'loaded', library=library) + return library + + def _get_declarations(self): + lst = [(key, tp) for (key, (tp, qual)) in + self.ffi._parser._declarations.items()] + lst.sort() + return lst + + def _generate(self, step_name): + for name, tp in self._get_declarations(): + kind, realname = name.split(' ', 1) + try: + method = getattr(self, '_generate_gen_%s_%s' % (kind, + step_name)) + except AttributeError: + raise VerificationError( + "not implemented in verify(): %r" % name) + try: + method(tp, realname) + except Exception as e: + model.attach_exception_info(e, name) + raise + + def _load(self, module, step_name, **kwds): + for name, tp in self._get_declarations(): + kind, realname = name.split(' ', 1) + method = getattr(self, '_%s_gen_%s' % (step_name, kind)) + try: + method(tp, realname, module, **kwds) + except Exception as e: + model.attach_exception_info(e, name) + raise + + def _generate_nothing(self, tp, name): + pass + + def _loaded_noop(self, tp, name, module, **kwds): + pass + + # ---------- + # typedefs: generates no code so far + + _generate_gen_typedef_decl = _generate_nothing + _loading_gen_typedef = _loaded_noop + _loaded_gen_typedef = _loaded_noop + + # ---------- + # function declarations + + def _generate_gen_function_decl(self, tp, name): + assert isinstance(tp, model.FunctionPtrType) + if tp.ellipsis: + # cannot support vararg functions better than this: check for its + # exact type (including the fixed arguments), and build it as a + # constant function pointer (no _cffi_f_%s wrapper) + self._generate_gen_const(False, name, tp) + return + prnt = self._prnt + numargs = len(tp.args) + argnames = [] + for i, type in enumerate(tp.args): + indirection = '' + if isinstance(type, model.StructOrUnion): + indirection = '*' + argnames.append('%sx%d' % (indirection, i)) + context = 'argument of %s' % name + arglist = [type.get_c_name(' %s' % arg, context) + for type, arg in zip(tp.args, argnames)] + tpresult = tp.result + if isinstance(tpresult, model.StructOrUnion): + arglist.insert(0, tpresult.get_c_name(' *r', context)) + tpresult = model.void_type + arglist = ', '.join(arglist) or 'void' + wrappername = '_cffi_f_%s' % name + self.export_symbols.append(wrappername) + if tp.abi: + abi = tp.abi + ' ' + else: + abi = '' + funcdecl = ' %s%s(%s)' % (abi, wrappername, arglist) + context = 'result of %s' % name + prnt(tpresult.get_c_name(funcdecl, context)) + prnt('{') + # + if isinstance(tp.result, model.StructOrUnion): + result_code = '*r = ' + elif not isinstance(tp.result, model.VoidType): + result_code = 'return ' + else: + result_code = '' + prnt(' %s%s(%s);' % (result_code, name, ', '.join(argnames))) + prnt('}') + prnt() + + _loading_gen_function = _loaded_noop + + def _loaded_gen_function(self, tp, name, module, library): + assert isinstance(tp, model.FunctionPtrType) + if tp.ellipsis: + newfunction = self._load_constant(False, tp, name, module) + else: + indirections = [] + base_tp = tp + if (any(isinstance(typ, model.StructOrUnion) for typ in tp.args) + or isinstance(tp.result, model.StructOrUnion)): + indirect_args = [] + for i, typ in enumerate(tp.args): + if isinstance(typ, model.StructOrUnion): + typ = model.PointerType(typ) + indirections.append((i, typ)) + indirect_args.append(typ) + indirect_result = tp.result + if isinstance(indirect_result, model.StructOrUnion): + if indirect_result.fldtypes is None: + raise TypeError("'%s' is used as result type, " + "but is opaque" % ( + indirect_result._get_c_name(),)) + indirect_result = model.PointerType(indirect_result) + indirect_args.insert(0, indirect_result) + indirections.insert(0, ("result", indirect_result)) + indirect_result = model.void_type + tp = model.FunctionPtrType(tuple(indirect_args), + indirect_result, tp.ellipsis) + BFunc = self.ffi._get_cached_btype(tp) + wrappername = '_cffi_f_%s' % name + newfunction = module.load_function(BFunc, wrappername) + for i, typ in indirections: + newfunction = self._make_struct_wrapper(newfunction, i, typ, + base_tp) + setattr(library, name, newfunction) + type(library)._cffi_dir.append(name) + + def _make_struct_wrapper(self, oldfunc, i, tp, base_tp): + backend = self.ffi._backend + BType = self.ffi._get_cached_btype(tp) + if i == "result": + ffi = self.ffi + def newfunc(*args): + res = ffi.new(BType) + oldfunc(res, *args) + return res[0] + else: + def newfunc(*args): + args = args[:i] + (backend.newp(BType, args[i]),) + args[i+1:] + return oldfunc(*args) + newfunc._cffi_base_type = base_tp + return newfunc + + # ---------- + # named structs + + def _generate_gen_struct_decl(self, tp, name): + assert name == tp.name + self._generate_struct_or_union_decl(tp, 'struct', name) + + def _loading_gen_struct(self, tp, name, module): + self._loading_struct_or_union(tp, 'struct', name, module) + + def _loaded_gen_struct(self, tp, name, module, **kwds): + self._loaded_struct_or_union(tp) + + def _generate_gen_union_decl(self, tp, name): + assert name == tp.name + self._generate_struct_or_union_decl(tp, 'union', name) + + def _loading_gen_union(self, tp, name, module): + self._loading_struct_or_union(tp, 'union', name, module) + + def _loaded_gen_union(self, tp, name, module, **kwds): + self._loaded_struct_or_union(tp) + + def _generate_struct_or_union_decl(self, tp, prefix, name): + if tp.fldnames is None: + return # nothing to do with opaque structs + checkfuncname = '_cffi_check_%s_%s' % (prefix, name) + layoutfuncname = '_cffi_layout_%s_%s' % (prefix, name) + cname = ('%s %s' % (prefix, name)).strip() + # + prnt = self._prnt + prnt('static void %s(%s *p)' % (checkfuncname, cname)) + prnt('{') + prnt(' /* only to generate compile-time warnings or errors */') + prnt(' (void)p;') + for fname, ftype, fbitsize, fqual in tp.enumfields(): + if (isinstance(ftype, model.PrimitiveType) + and ftype.is_integer_type()) or fbitsize >= 0: + # accept all integers, but complain on float or double + prnt(' (void)((p->%s) << 1);' % fname) + else: + # only accept exactly the type declared. + try: + prnt(' { %s = &p->%s; (void)tmp; }' % ( + ftype.get_c_name('*tmp', 'field %r'%fname, quals=fqual), + fname)) + except VerificationError as e: + prnt(' /* %s */' % str(e)) # cannot verify it, ignore + prnt('}') + self.export_symbols.append(layoutfuncname) + prnt('intptr_t %s(intptr_t i)' % (layoutfuncname,)) + prnt('{') + prnt(' struct _cffi_aligncheck { char x; %s y; };' % cname) + prnt(' static intptr_t nums[] = {') + prnt(' sizeof(%s),' % cname) + prnt(' offsetof(struct _cffi_aligncheck, y),') + for fname, ftype, fbitsize, fqual in tp.enumfields(): + if fbitsize >= 0: + continue # xxx ignore fbitsize for now + prnt(' offsetof(%s, %s),' % (cname, fname)) + if isinstance(ftype, model.ArrayType) and ftype.length is None: + prnt(' 0, /* %s */' % ftype._get_c_name()) + else: + prnt(' sizeof(((%s *)0)->%s),' % (cname, fname)) + prnt(' -1') + prnt(' };') + prnt(' return nums[i];') + prnt(' /* the next line is not executed, but compiled */') + prnt(' %s(0);' % (checkfuncname,)) + prnt('}') + prnt() + + def _loading_struct_or_union(self, tp, prefix, name, module): + if tp.fldnames is None: + return # nothing to do with opaque structs + layoutfuncname = '_cffi_layout_%s_%s' % (prefix, name) + # + BFunc = self.ffi._typeof_locked("intptr_t(*)(intptr_t)")[0] + function = module.load_function(BFunc, layoutfuncname) + layout = [] + num = 0 + while True: + x = function(num) + if x < 0: break + layout.append(x) + num += 1 + if isinstance(tp, model.StructOrUnion) and tp.partial: + # use the function()'s sizes and offsets to guide the + # layout of the struct + totalsize = layout[0] + totalalignment = layout[1] + fieldofs = layout[2::2] + fieldsize = layout[3::2] + tp.force_flatten() + assert len(fieldofs) == len(fieldsize) == len(tp.fldnames) + tp.fixedlayout = fieldofs, fieldsize, totalsize, totalalignment + else: + cname = ('%s %s' % (prefix, name)).strip() + self._struct_pending_verification[tp] = layout, cname + + def _loaded_struct_or_union(self, tp): + if tp.fldnames is None: + return # nothing to do with opaque structs + self.ffi._get_cached_btype(tp) # force 'fixedlayout' to be considered + + if tp in self._struct_pending_verification: + # check that the layout sizes and offsets match the real ones + def check(realvalue, expectedvalue, msg): + if realvalue != expectedvalue: + raise VerificationError( + "%s (we have %d, but C compiler says %d)" + % (msg, expectedvalue, realvalue)) + ffi = self.ffi + BStruct = ffi._get_cached_btype(tp) + layout, cname = self._struct_pending_verification.pop(tp) + check(layout[0], ffi.sizeof(BStruct), "wrong total size") + check(layout[1], ffi.alignof(BStruct), "wrong total alignment") + i = 2 + for fname, ftype, fbitsize, fqual in tp.enumfields(): + if fbitsize >= 0: + continue # xxx ignore fbitsize for now + check(layout[i], ffi.offsetof(BStruct, fname), + "wrong offset for field %r" % (fname,)) + if layout[i+1] != 0: + BField = ffi._get_cached_btype(ftype) + check(layout[i+1], ffi.sizeof(BField), + "wrong size for field %r" % (fname,)) + i += 2 + assert i == len(layout) + + # ---------- + # 'anonymous' declarations. These are produced for anonymous structs + # or unions; the 'name' is obtained by a typedef. + + def _generate_gen_anonymous_decl(self, tp, name): + if isinstance(tp, model.EnumType): + self._generate_gen_enum_decl(tp, name, '') + else: + self._generate_struct_or_union_decl(tp, '', name) + + def _loading_gen_anonymous(self, tp, name, module): + if isinstance(tp, model.EnumType): + self._loading_gen_enum(tp, name, module, '') + else: + self._loading_struct_or_union(tp, '', name, module) + + def _loaded_gen_anonymous(self, tp, name, module, **kwds): + if isinstance(tp, model.EnumType): + self._loaded_gen_enum(tp, name, module, **kwds) + else: + self._loaded_struct_or_union(tp) + + # ---------- + # constants, likely declared with '#define' + + def _generate_gen_const(self, is_int, name, tp=None, category='const', + check_value=None): + prnt = self._prnt + funcname = '_cffi_%s_%s' % (category, name) + self.export_symbols.append(funcname) + if check_value is not None: + assert is_int + assert category == 'const' + prnt('int %s(char *out_error)' % funcname) + prnt('{') + self._check_int_constant_value(name, check_value) + prnt(' return 0;') + prnt('}') + elif is_int: + assert category == 'const' + prnt('int %s(long long *out_value)' % funcname) + prnt('{') + prnt(' *out_value = (long long)(%s);' % (name,)) + prnt(' return (%s) <= 0;' % (name,)) + prnt('}') + else: + assert tp is not None + assert check_value is None + if category == 'var': + ampersand = '&' + else: + ampersand = '' + extra = '' + if category == 'const' and isinstance(tp, model.StructOrUnion): + extra = 'const *' + ampersand = '&' + prnt(tp.get_c_name(' %s%s(void)' % (extra, funcname), name)) + prnt('{') + prnt(' return (%s%s);' % (ampersand, name)) + prnt('}') + prnt() + + def _generate_gen_constant_decl(self, tp, name): + is_int = isinstance(tp, model.PrimitiveType) and tp.is_integer_type() + self._generate_gen_const(is_int, name, tp) + + _loading_gen_constant = _loaded_noop + + def _load_constant(self, is_int, tp, name, module, check_value=None): + funcname = '_cffi_const_%s' % name + if check_value is not None: + assert is_int + self._load_known_int_constant(module, funcname) + value = check_value + elif is_int: + BType = self.ffi._typeof_locked("long long*")[0] + BFunc = self.ffi._typeof_locked("int(*)(long long*)")[0] + function = module.load_function(BFunc, funcname) + p = self.ffi.new(BType) + negative = function(p) + value = int(p[0]) + if value < 0 and not negative: + BLongLong = self.ffi._typeof_locked("long long")[0] + value += (1 << (8*self.ffi.sizeof(BLongLong))) + else: + assert check_value is None + fntypeextra = '(*)(void)' + if isinstance(tp, model.StructOrUnion): + fntypeextra = '*' + fntypeextra + BFunc = self.ffi._typeof_locked(tp.get_c_name(fntypeextra, name))[0] + function = module.load_function(BFunc, funcname) + value = function() + if isinstance(tp, model.StructOrUnion): + value = value[0] + return value + + def _loaded_gen_constant(self, tp, name, module, library): + is_int = isinstance(tp, model.PrimitiveType) and tp.is_integer_type() + value = self._load_constant(is_int, tp, name, module) + setattr(library, name, value) + type(library)._cffi_dir.append(name) + + # ---------- + # enums + + def _check_int_constant_value(self, name, value): + prnt = self._prnt + if value <= 0: + prnt(' if ((%s) > 0 || (long)(%s) != %dL) {' % ( + name, name, value)) + else: + prnt(' if ((%s) <= 0 || (unsigned long)(%s) != %dUL) {' % ( + name, name, value)) + prnt(' char buf[64];') + prnt(' if ((%s) <= 0)' % name) + prnt(' sprintf(buf, "%%ld", (long)(%s));' % name) + prnt(' else') + prnt(' sprintf(buf, "%%lu", (unsigned long)(%s));' % + name) + prnt(' sprintf(out_error, "%s has the real value %s, not %s",') + prnt(' "%s", buf, "%d");' % (name[:100], value)) + prnt(' return -1;') + prnt(' }') + + def _load_known_int_constant(self, module, funcname): + BType = self.ffi._typeof_locked("char[]")[0] + BFunc = self.ffi._typeof_locked("int(*)(char*)")[0] + function = module.load_function(BFunc, funcname) + p = self.ffi.new(BType, 256) + if function(p) < 0: + error = self.ffi.string(p) + if sys.version_info >= (3,): + error = str(error, 'utf-8') + raise VerificationError(error) + + def _enum_funcname(self, prefix, name): + # "$enum_$1" => "___D_enum____D_1" + name = name.replace('$', '___D_') + return '_cffi_e_%s_%s' % (prefix, name) + + def _generate_gen_enum_decl(self, tp, name, prefix='enum'): + if tp.partial: + for enumerator in tp.enumerators: + self._generate_gen_const(True, enumerator) + return + # + funcname = self._enum_funcname(prefix, name) + self.export_symbols.append(funcname) + prnt = self._prnt + prnt('int %s(char *out_error)' % funcname) + prnt('{') + for enumerator, enumvalue in zip(tp.enumerators, tp.enumvalues): + self._check_int_constant_value(enumerator, enumvalue) + prnt(' return 0;') + prnt('}') + prnt() + + def _loading_gen_enum(self, tp, name, module, prefix='enum'): + if tp.partial: + enumvalues = [self._load_constant(True, tp, enumerator, module) + for enumerator in tp.enumerators] + tp.enumvalues = tuple(enumvalues) + tp.partial_resolved = True + else: + funcname = self._enum_funcname(prefix, name) + self._load_known_int_constant(module, funcname) + + def _loaded_gen_enum(self, tp, name, module, library): + for enumerator, enumvalue in zip(tp.enumerators, tp.enumvalues): + setattr(library, enumerator, enumvalue) + type(library)._cffi_dir.append(enumerator) + + # ---------- + # macros: for now only for integers + + def _generate_gen_macro_decl(self, tp, name): + if tp == '...': + check_value = None + else: + check_value = tp # an integer + self._generate_gen_const(True, name, check_value=check_value) + + _loading_gen_macro = _loaded_noop + + def _loaded_gen_macro(self, tp, name, module, library): + if tp == '...': + check_value = None + else: + check_value = tp # an integer + value = self._load_constant(True, tp, name, module, + check_value=check_value) + setattr(library, name, value) + type(library)._cffi_dir.append(name) + + # ---------- + # global variables + + def _generate_gen_variable_decl(self, tp, name): + if isinstance(tp, model.ArrayType): + if tp.length_is_unknown(): + prnt = self._prnt + funcname = '_cffi_sizeof_%s' % (name,) + self.export_symbols.append(funcname) + prnt("size_t %s(void)" % funcname) + prnt("{") + prnt(" return sizeof(%s);" % (name,)) + prnt("}") + tp_ptr = model.PointerType(tp.item) + self._generate_gen_const(False, name, tp_ptr) + else: + tp_ptr = model.PointerType(tp) + self._generate_gen_const(False, name, tp_ptr, category='var') + + _loading_gen_variable = _loaded_noop + + def _loaded_gen_variable(self, tp, name, module, library): + if isinstance(tp, model.ArrayType): # int a[5] is "constant" in the + # sense that "a=..." is forbidden + if tp.length_is_unknown(): + funcname = '_cffi_sizeof_%s' % (name,) + BFunc = self.ffi._typeof_locked('size_t(*)(void)')[0] + function = module.load_function(BFunc, funcname) + size = function() + BItemType = self.ffi._get_cached_btype(tp.item) + length, rest = divmod(size, self.ffi.sizeof(BItemType)) + if rest != 0: + raise VerificationError( + "bad size: %r does not seem to be an array of %s" % + (name, tp.item)) + tp = tp.resolve_length(length) + tp_ptr = model.PointerType(tp.item) + value = self._load_constant(False, tp_ptr, name, module) + # 'value' is a which we have to replace with + # a if the N is actually known + if tp.length is not None: + BArray = self.ffi._get_cached_btype(tp) + value = self.ffi.cast(BArray, value) + setattr(library, name, value) + type(library)._cffi_dir.append(name) + return + # remove ptr= from the library instance, and replace + # it by a property on the class, which reads/writes into ptr[0]. + funcname = '_cffi_var_%s' % name + BFunc = self.ffi._typeof_locked(tp.get_c_name('*(*)(void)', name))[0] + function = module.load_function(BFunc, funcname) + ptr = function() + def getter(library): + return ptr[0] + def setter(library, value): + ptr[0] = value + setattr(type(library), name, property(getter, setter)) + type(library)._cffi_dir.append(name) + +cffimod_header = r''' +#include +#include +#include +#include +#include /* XXX for ssize_t on some platforms */ + +/* this block of #ifs should be kept exactly identical between + c/_cffi_backend.c, cffi/vengine_cpy.py, cffi/vengine_gen.py + and cffi/_cffi_include.h */ +#if defined(_MSC_VER) +# include /* for alloca() */ +# if _MSC_VER < 1600 /* MSVC < 2010 */ + typedef __int8 int8_t; + typedef __int16 int16_t; + typedef __int32 int32_t; + typedef __int64 int64_t; + typedef unsigned __int8 uint8_t; + typedef unsigned __int16 uint16_t; + typedef unsigned __int32 uint32_t; + typedef unsigned __int64 uint64_t; + typedef __int8 int_least8_t; + typedef __int16 int_least16_t; + typedef __int32 int_least32_t; + typedef __int64 int_least64_t; + typedef unsigned __int8 uint_least8_t; + typedef unsigned __int16 uint_least16_t; + typedef unsigned __int32 uint_least32_t; + typedef unsigned __int64 uint_least64_t; + typedef __int8 int_fast8_t; + typedef __int16 int_fast16_t; + typedef __int32 int_fast32_t; + typedef __int64 int_fast64_t; + typedef unsigned __int8 uint_fast8_t; + typedef unsigned __int16 uint_fast16_t; + typedef unsigned __int32 uint_fast32_t; + typedef unsigned __int64 uint_fast64_t; + typedef __int64 intmax_t; + typedef unsigned __int64 uintmax_t; +# else +# include +# endif +# if _MSC_VER < 1800 /* MSVC < 2013 */ +# ifndef __cplusplus + typedef unsigned char _Bool; +# endif +# endif +#else +# include +# if (defined (__SVR4) && defined (__sun)) || defined(_AIX) || defined(__hpux) +# include +# endif +#endif +''' diff --git a/.venv/Lib/site-packages/cffi/verifier.py b/.venv/Lib/site-packages/cffi/verifier.py new file mode 100644 index 00000000..e392a2b7 --- /dev/null +++ b/.venv/Lib/site-packages/cffi/verifier.py @@ -0,0 +1,306 @@ +# +# DEPRECATED: implementation for ffi.verify() +# +import sys, os, binascii, shutil, io +from . import __version_verifier_modules__ +from . import ffiplatform +from .error import VerificationError + +if sys.version_info >= (3, 3): + import importlib.machinery + def _extension_suffixes(): + return importlib.machinery.EXTENSION_SUFFIXES[:] +else: + import imp + def _extension_suffixes(): + return [suffix for suffix, _, type in imp.get_suffixes() + if type == imp.C_EXTENSION] + + +if sys.version_info >= (3,): + NativeIO = io.StringIO +else: + class NativeIO(io.BytesIO): + def write(self, s): + if isinstance(s, unicode): + s = s.encode('ascii') + super(NativeIO, self).write(s) + + +class Verifier(object): + + def __init__(self, ffi, preamble, tmpdir=None, modulename=None, + ext_package=None, tag='', force_generic_engine=False, + source_extension='.c', flags=None, relative_to=None, **kwds): + if ffi._parser._uses_new_feature: + raise VerificationError( + "feature not supported with ffi.verify(), but only " + "with ffi.set_source(): %s" % (ffi._parser._uses_new_feature,)) + self.ffi = ffi + self.preamble = preamble + if not modulename: + flattened_kwds = ffiplatform.flatten(kwds) + vengine_class = _locate_engine_class(ffi, force_generic_engine) + self._vengine = vengine_class(self) + self._vengine.patch_extension_kwds(kwds) + self.flags = flags + self.kwds = self.make_relative_to(kwds, relative_to) + # + if modulename: + if tag: + raise TypeError("can't specify both 'modulename' and 'tag'") + else: + key = '\x00'.join(['%d.%d' % sys.version_info[:2], + __version_verifier_modules__, + preamble, flattened_kwds] + + ffi._cdefsources) + if sys.version_info >= (3,): + key = key.encode('utf-8') + k1 = hex(binascii.crc32(key[0::2]) & 0xffffffff) + k1 = k1.lstrip('0x').rstrip('L') + k2 = hex(binascii.crc32(key[1::2]) & 0xffffffff) + k2 = k2.lstrip('0').rstrip('L') + modulename = '_cffi_%s_%s%s%s' % (tag, self._vengine._class_key, + k1, k2) + suffix = _get_so_suffixes()[0] + self.tmpdir = tmpdir or _caller_dir_pycache() + self.sourcefilename = os.path.join(self.tmpdir, modulename + source_extension) + self.modulefilename = os.path.join(self.tmpdir, modulename + suffix) + self.ext_package = ext_package + self._has_source = False + self._has_module = False + + def write_source(self, file=None): + """Write the C source code. It is produced in 'self.sourcefilename', + which can be tweaked beforehand.""" + with self.ffi._lock: + if self._has_source and file is None: + raise VerificationError( + "source code already written") + self._write_source(file) + + def compile_module(self): + """Write the C source code (if not done already) and compile it. + This produces a dynamic link library in 'self.modulefilename'.""" + with self.ffi._lock: + if self._has_module: + raise VerificationError("module already compiled") + if not self._has_source: + self._write_source() + self._compile_module() + + def load_library(self): + """Get a C module from this Verifier instance. + Returns an instance of a FFILibrary class that behaves like the + objects returned by ffi.dlopen(), but that delegates all + operations to the C module. If necessary, the C code is written + and compiled first. + """ + with self.ffi._lock: + if not self._has_module: + self._locate_module() + if not self._has_module: + if not self._has_source: + self._write_source() + self._compile_module() + return self._load_library() + + def get_module_name(self): + basename = os.path.basename(self.modulefilename) + # kill both the .so extension and the other .'s, as introduced + # by Python 3: 'basename.cpython-33m.so' + basename = basename.split('.', 1)[0] + # and the _d added in Python 2 debug builds --- but try to be + # conservative and not kill a legitimate _d + if basename.endswith('_d') and hasattr(sys, 'gettotalrefcount'): + basename = basename[:-2] + return basename + + def get_extension(self): + if not self._has_source: + with self.ffi._lock: + if not self._has_source: + self._write_source() + sourcename = ffiplatform.maybe_relative_path(self.sourcefilename) + modname = self.get_module_name() + return ffiplatform.get_extension(sourcename, modname, **self.kwds) + + def generates_python_module(self): + return self._vengine._gen_python_module + + def make_relative_to(self, kwds, relative_to): + if relative_to and os.path.dirname(relative_to): + dirname = os.path.dirname(relative_to) + kwds = kwds.copy() + for key in ffiplatform.LIST_OF_FILE_NAMES: + if key in kwds: + lst = kwds[key] + if not isinstance(lst, (list, tuple)): + raise TypeError("keyword '%s' should be a list or tuple" + % (key,)) + lst = [os.path.join(dirname, fn) for fn in lst] + kwds[key] = lst + return kwds + + # ---------- + + def _locate_module(self): + if not os.path.isfile(self.modulefilename): + if self.ext_package: + try: + pkg = __import__(self.ext_package, None, None, ['__doc__']) + except ImportError: + return # cannot import the package itself, give up + # (e.g. it might be called differently before installation) + path = pkg.__path__ + else: + path = None + filename = self._vengine.find_module(self.get_module_name(), path, + _get_so_suffixes()) + if filename is None: + return + self.modulefilename = filename + self._vengine.collect_types() + self._has_module = True + + def _write_source_to(self, file): + self._vengine._f = file + try: + self._vengine.write_source_to_f() + finally: + del self._vengine._f + + def _write_source(self, file=None): + if file is not None: + self._write_source_to(file) + else: + # Write our source file to an in memory file. + f = NativeIO() + self._write_source_to(f) + source_data = f.getvalue() + + # Determine if this matches the current file + if os.path.exists(self.sourcefilename): + with open(self.sourcefilename, "r") as fp: + needs_written = not (fp.read() == source_data) + else: + needs_written = True + + # Actually write the file out if it doesn't match + if needs_written: + _ensure_dir(self.sourcefilename) + with open(self.sourcefilename, "w") as fp: + fp.write(source_data) + + # Set this flag + self._has_source = True + + def _compile_module(self): + # compile this C source + tmpdir = os.path.dirname(self.sourcefilename) + outputfilename = ffiplatform.compile(tmpdir, self.get_extension()) + try: + same = ffiplatform.samefile(outputfilename, self.modulefilename) + except OSError: + same = False + if not same: + _ensure_dir(self.modulefilename) + shutil.move(outputfilename, self.modulefilename) + self._has_module = True + + def _load_library(self): + assert self._has_module + if self.flags is not None: + return self._vengine.load_library(self.flags) + else: + return self._vengine.load_library() + +# ____________________________________________________________ + +_FORCE_GENERIC_ENGINE = False # for tests + +def _locate_engine_class(ffi, force_generic_engine): + if _FORCE_GENERIC_ENGINE: + force_generic_engine = True + if not force_generic_engine: + if '__pypy__' in sys.builtin_module_names: + force_generic_engine = True + else: + try: + import _cffi_backend + except ImportError: + _cffi_backend = '?' + if ffi._backend is not _cffi_backend: + force_generic_engine = True + if force_generic_engine: + from . import vengine_gen + return vengine_gen.VGenericEngine + else: + from . import vengine_cpy + return vengine_cpy.VCPythonEngine + +# ____________________________________________________________ + +_TMPDIR = None + +def _caller_dir_pycache(): + if _TMPDIR: + return _TMPDIR + result = os.environ.get('CFFI_TMPDIR') + if result: + return result + filename = sys._getframe(2).f_code.co_filename + return os.path.abspath(os.path.join(os.path.dirname(filename), + '__pycache__')) + +def set_tmpdir(dirname): + """Set the temporary directory to use instead of __pycache__.""" + global _TMPDIR + _TMPDIR = dirname + +def cleanup_tmpdir(tmpdir=None, keep_so=False): + """Clean up the temporary directory by removing all files in it + called `_cffi_*.{c,so}` as well as the `build` subdirectory.""" + tmpdir = tmpdir or _caller_dir_pycache() + try: + filelist = os.listdir(tmpdir) + except OSError: + return + if keep_so: + suffix = '.c' # only remove .c files + else: + suffix = _get_so_suffixes()[0].lower() + for fn in filelist: + if fn.lower().startswith('_cffi_') and ( + fn.lower().endswith(suffix) or fn.lower().endswith('.c')): + try: + os.unlink(os.path.join(tmpdir, fn)) + except OSError: + pass + clean_dir = [os.path.join(tmpdir, 'build')] + for dir in clean_dir: + try: + for fn in os.listdir(dir): + fn = os.path.join(dir, fn) + if os.path.isdir(fn): + clean_dir.append(fn) + else: + os.unlink(fn) + except OSError: + pass + +def _get_so_suffixes(): + suffixes = _extension_suffixes() + if not suffixes: + # bah, no C_EXTENSION available. Occurs on pypy without cpyext + if sys.platform == 'win32': + suffixes = [".pyd"] + else: + suffixes = [".so"] + + return suffixes + +def _ensure_dir(filename): + dirname = os.path.dirname(filename) + if dirname and not os.path.isdir(dirname): + os.makedirs(dirname) diff --git a/.venv/Lib/site-packages/cryptography-42.0.1.dist-info/INSTALLER b/.venv/Lib/site-packages/cryptography-42.0.1.dist-info/INSTALLER new file mode 100644 index 00000000..a1b589e3 --- /dev/null +++ b/.venv/Lib/site-packages/cryptography-42.0.1.dist-info/INSTALLER @@ -0,0 +1 @@ +pip diff --git a/.venv/Lib/site-packages/cryptography-42.0.1.dist-info/LICENSE b/.venv/Lib/site-packages/cryptography-42.0.1.dist-info/LICENSE new file mode 100644 index 00000000..b11f379e --- /dev/null +++ b/.venv/Lib/site-packages/cryptography-42.0.1.dist-info/LICENSE @@ -0,0 +1,3 @@ +This software is made available under the terms of *either* of the licenses +found in LICENSE.APACHE or LICENSE.BSD. Contributions to cryptography are made +under the terms of *both* these licenses. diff --git a/.venv/Lib/site-packages/cryptography-42.0.1.dist-info/LICENSE.APACHE b/.venv/Lib/site-packages/cryptography-42.0.1.dist-info/LICENSE.APACHE new file mode 100644 index 00000000..62589edd --- /dev/null +++ b/.venv/Lib/site-packages/cryptography-42.0.1.dist-info/LICENSE.APACHE @@ -0,0 +1,202 @@ + + Apache License + Version 2.0, January 2004 + https://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/.venv/Lib/site-packages/cryptography-42.0.1.dist-info/LICENSE.BSD b/.venv/Lib/site-packages/cryptography-42.0.1.dist-info/LICENSE.BSD new file mode 100644 index 00000000..ec1a29d3 --- /dev/null +++ b/.venv/Lib/site-packages/cryptography-42.0.1.dist-info/LICENSE.BSD @@ -0,0 +1,27 @@ +Copyright (c) Individual contributors. +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + + 1. Redistributions of source code must retain the above copyright notice, + this list of conditions and the following disclaimer. + + 2. Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + + 3. Neither the name of PyCA Cryptography nor the names of its contributors + may be used to endorse or promote products derived from this software + without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR +ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON +ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/.venv/Lib/site-packages/cryptography-42.0.1.dist-info/METADATA b/.venv/Lib/site-packages/cryptography-42.0.1.dist-info/METADATA new file mode 100644 index 00000000..2c043b13 --- /dev/null +++ b/.venv/Lib/site-packages/cryptography-42.0.1.dist-info/METADATA @@ -0,0 +1,135 @@ +Metadata-Version: 2.1 +Name: cryptography +Version: 42.0.1 +Summary: cryptography is a package which provides cryptographic recipes and primitives to Python developers. +Author-email: The Python Cryptographic Authority and individual contributors +License: Apache-2.0 OR BSD-3-Clause +Project-URL: homepage, https://github.com/pyca/cryptography +Project-URL: documentation, https://cryptography.io/ +Project-URL: source, https://github.com/pyca/cryptography/ +Project-URL: issues, https://github.com/pyca/cryptography/issues +Project-URL: changelog, https://cryptography.io/en/latest/changelog/ +Classifier: Development Status :: 5 - Production/Stable +Classifier: Intended Audience :: Developers +Classifier: License :: OSI Approved :: Apache Software License +Classifier: License :: OSI Approved :: BSD License +Classifier: Natural Language :: English +Classifier: Operating System :: MacOS :: MacOS X +Classifier: Operating System :: POSIX +Classifier: Operating System :: POSIX :: BSD +Classifier: Operating System :: POSIX :: Linux +Classifier: Operating System :: Microsoft :: Windows +Classifier: Programming Language :: Python +Classifier: Programming Language :: Python :: 3 +Classifier: Programming Language :: Python :: 3 :: Only +Classifier: Programming Language :: Python :: 3.7 +Classifier: Programming Language :: Python :: 3.8 +Classifier: Programming Language :: Python :: 3.9 +Classifier: Programming Language :: Python :: 3.10 +Classifier: Programming Language :: Python :: 3.11 +Classifier: Programming Language :: Python :: 3.12 +Classifier: Programming Language :: Python :: Implementation :: CPython +Classifier: Programming Language :: Python :: Implementation :: PyPy +Classifier: Topic :: Security :: Cryptography +Requires-Python: >=3.7 +Description-Content-Type: text/x-rst +License-File: LICENSE +License-File: LICENSE.APACHE +License-File: LICENSE.BSD +Requires-Dist: cffi >=1.12 ; platform_python_implementation != "PyPy" +Provides-Extra: docs +Requires-Dist: sphinx >=5.3.0 ; extra == 'docs' +Requires-Dist: sphinx-rtd-theme >=1.1.1 ; extra == 'docs' +Provides-Extra: docstest +Requires-Dist: pyenchant >=1.6.11 ; extra == 'docstest' +Requires-Dist: readme-renderer ; extra == 'docstest' +Requires-Dist: sphinxcontrib-spelling >=4.0.1 ; extra == 'docstest' +Provides-Extra: nox +Requires-Dist: nox ; extra == 'nox' +Provides-Extra: pep8test +Requires-Dist: ruff ; extra == 'pep8test' +Requires-Dist: mypy ; extra == 'pep8test' +Requires-Dist: check-sdist ; extra == 'pep8test' +Requires-Dist: click ; extra == 'pep8test' +Provides-Extra: sdist +Requires-Dist: build ; extra == 'sdist' +Provides-Extra: ssh +Requires-Dist: bcrypt >=3.1.5 ; extra == 'ssh' +Provides-Extra: test +Requires-Dist: pytest >=6.2.0 ; extra == 'test' +Requires-Dist: pytest-benchmark ; extra == 'test' +Requires-Dist: pytest-cov ; extra == 'test' +Requires-Dist: pytest-xdist ; extra == 'test' +Requires-Dist: pretend ; extra == 'test' +Requires-Dist: certifi ; extra == 'test' +Provides-Extra: test-randomorder +Requires-Dist: pytest-randomly ; extra == 'test-randomorder' + +pyca/cryptography +================= + +.. image:: https://img.shields.io/pypi/v/cryptography.svg + :target: https://pypi.org/project/cryptography/ + :alt: Latest Version + +.. image:: https://readthedocs.org/projects/cryptography/badge/?version=latest + :target: https://cryptography.io + :alt: Latest Docs + +.. image:: https://github.com/pyca/cryptography/workflows/CI/badge.svg?branch=main + :target: https://github.com/pyca/cryptography/actions?query=workflow%3ACI+branch%3Amain + + +``cryptography`` is a package which provides cryptographic recipes and +primitives to Python developers. Our goal is for it to be your "cryptographic +standard library". It supports Python 3.7+ and PyPy3 7.3.11+. + +``cryptography`` includes both high level recipes and low level interfaces to +common cryptographic algorithms such as symmetric ciphers, message digests, and +key derivation functions. For example, to encrypt something with +``cryptography``'s high level symmetric encryption recipe: + +.. code-block:: pycon + + >>> from cryptography.fernet import Fernet + >>> # Put this somewhere safe! + >>> key = Fernet.generate_key() + >>> f = Fernet(key) + >>> token = f.encrypt(b"A really secret message. Not for prying eyes.") + >>> token + b'...' + >>> f.decrypt(token) + b'A really secret message. Not for prying eyes.' + +You can find more information in the `documentation`_. + +You can install ``cryptography`` with: + +.. code-block:: console + + $ pip install cryptography + +For full details see `the installation documentation`_. + +Discussion +~~~~~~~~~~ + +If you run into bugs, you can file them in our `issue tracker`_. + +We maintain a `cryptography-dev`_ mailing list for development discussion. + +You can also join ``#pyca`` on ``irc.libera.chat`` to ask questions or get +involved. + +Security +~~~~~~~~ + +Need to report a security issue? Please consult our `security reporting`_ +documentation. + + +.. _`documentation`: https://cryptography.io/ +.. _`the installation documentation`: https://cryptography.io/en/latest/installation/ +.. _`issue tracker`: https://github.com/pyca/cryptography/issues +.. _`cryptography-dev`: https://mail.python.org/mailman/listinfo/cryptography-dev +.. _`security reporting`: https://cryptography.io/en/latest/security/ diff --git a/.venv/Lib/site-packages/cryptography-42.0.1.dist-info/RECORD b/.venv/Lib/site-packages/cryptography-42.0.1.dist-info/RECORD new file mode 100644 index 00000000..d09eaade --- /dev/null +++ b/.venv/Lib/site-packages/cryptography-42.0.1.dist-info/RECORD @@ -0,0 +1,171 @@ +cryptography-42.0.1.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4 +cryptography-42.0.1.dist-info/LICENSE,sha256=Pgx8CRqUi4JTO6mP18u0BDLW8amsv4X1ki0vmak65rs,197 +cryptography-42.0.1.dist-info/LICENSE.APACHE,sha256=qsc7MUj20dcRHbyjIJn2jSbGRMaBOuHk8F9leaomY_4,11360 +cryptography-42.0.1.dist-info/LICENSE.BSD,sha256=YCxMdILeZHndLpeTzaJ15eY9dz2s0eymiSMqtwCPtPs,1532 +cryptography-42.0.1.dist-info/METADATA,sha256=N9jI_s4pSMGbz0DaqjAr87UJCMGFRTqtPYcUoNitxJ8,5430 +cryptography-42.0.1.dist-info/RECORD,, +cryptography-42.0.1.dist-info/WHEEL,sha256=ZzJfItdlTwUbeh2SvWRPbrqgDfW_djikghnwfRmqFIQ,100 +cryptography-42.0.1.dist-info/top_level.txt,sha256=KNaT-Sn2K4uxNaEbe6mYdDn3qWDMlp4y-MtWfB73nJc,13 +cryptography/__about__.py,sha256=BWH-wXuR_WW0RqjRc0cp-C7I3e1ux0Aq-H2l7weXJb0,445 +cryptography/__init__.py,sha256=iVPlBlXWTJyiFeRedxcbMPhyHB34viOM10d72vGnWuE,364 +cryptography/__pycache__/__about__.cpython-311.pyc,, +cryptography/__pycache__/__init__.cpython-311.pyc,, +cryptography/__pycache__/exceptions.cpython-311.pyc,, +cryptography/__pycache__/fernet.cpython-311.pyc,, +cryptography/__pycache__/utils.cpython-311.pyc,, +cryptography/exceptions.py,sha256=835EWILc2fwxw-gyFMriciC2SqhViETB10LBSytnDIc,1087 +cryptography/fernet.py,sha256=aPj82w-Z_1GBXUtWRUsZdVbMwRo5Mbjj0wkA9wG4rkw,6696 +cryptography/hazmat/__init__.py,sha256=5IwrLWrVp0AjEr_4FdWG_V057NSJGY_W4egNNsuct0g,455 +cryptography/hazmat/__pycache__/__init__.cpython-311.pyc,, +cryptography/hazmat/__pycache__/_oid.cpython-311.pyc,, +cryptography/hazmat/_oid.py,sha256=0DhT6N-ziZzlQp05iPKOsy5wdPMayiKdrSg_yZfWLzc,14460 +cryptography/hazmat/backends/__init__.py,sha256=O5jvKFQdZnXhKeqJ-HtulaEL9Ni7mr1mDzZY5kHlYhI,361 +cryptography/hazmat/backends/__pycache__/__init__.cpython-311.pyc,, +cryptography/hazmat/backends/openssl/__init__.py,sha256=p3jmJfnCag9iE5sdMrN6VvVEu55u46xaS_IjoI0SrmA,305 +cryptography/hazmat/backends/openssl/__pycache__/__init__.cpython-311.pyc,, +cryptography/hazmat/backends/openssl/__pycache__/aead.cpython-311.pyc,, +cryptography/hazmat/backends/openssl/__pycache__/backend.cpython-311.pyc,, +cryptography/hazmat/backends/openssl/__pycache__/ciphers.cpython-311.pyc,, +cryptography/hazmat/backends/openssl/__pycache__/decode_asn1.cpython-311.pyc,, +cryptography/hazmat/backends/openssl/aead.py,sha256=UBNLqkicUo2ve7q-q8R49IgVOYlDMmSPtbPUK2qdMbM,8176 +cryptography/hazmat/backends/openssl/backend.py,sha256=dqdL5le6MnRjSuWjxRnyzvi8gIa_5rsWdB_9lrpeltg,32606 +cryptography/hazmat/backends/openssl/ciphers.py,sha256=MwBbBauaUjNiaja25oZKt7vI9bRGXfF5lK1p-8AQ67U,10353 +cryptography/hazmat/backends/openssl/decode_asn1.py,sha256=kz6gys8wuJhrx4QyU6enYx7UatNHr0LB3TI1jH3oQ54,1148 +cryptography/hazmat/bindings/__init__.py,sha256=s9oKCQ2ycFdXoERdS1imafueSkBsL9kvbyfghaauZ9Y,180 +cryptography/hazmat/bindings/__pycache__/__init__.cpython-311.pyc,, +cryptography/hazmat/bindings/_rust.pyd,sha256=urLvC6p9z6aHShxTSNn-6VEEn5POzV0dFwFUqrhKYvI,7211520 +cryptography/hazmat/bindings/_rust/__init__.pyi,sha256=djseHBlzUqDJ7JUc2J51OT_7CLm_Lz0EyVQ55o3udUI,495 +cryptography/hazmat/bindings/_rust/_openssl.pyi,sha256=mpNJLuYLbCVrd5i33FBTmWwL_55Dw7JPkSLlSX9Q7oI,230 +cryptography/hazmat/bindings/_rust/asn1.pyi,sha256=8w-f89ls0pb7BAbt1E0Pvkd59NGtTFItLtFK8ZJGbkk,556 +cryptography/hazmat/bindings/_rust/exceptions.pyi,sha256=exXr2xw_0pB1kk93cYbM3MohbzoUkjOms1ZMUi0uQZE,640 +cryptography/hazmat/bindings/_rust/ocsp.pyi,sha256=qUA2x7lwbG_Z7wJ_wUxsBFJ71arjoX-nnkZAw4nVDeQ,860 +cryptography/hazmat/bindings/_rust/openssl/__init__.pyi,sha256=SwBmmK_wzQbHK_Y5Q3lQIIk3NPFciNv6IjXVSBLx89Q,1067 +cryptography/hazmat/bindings/_rust/openssl/aead.pyi,sha256=ZNsO1H8Q9ixQO9Db7qtkboWKM5fycWY_ZeyGXb3scHg,1737 +cryptography/hazmat/bindings/_rust/openssl/cmac.pyi,sha256=nPH0X57RYpsAkRowVpjQiHE566ThUTx7YXrsadmrmHk,564 +cryptography/hazmat/bindings/_rust/openssl/dh.pyi,sha256=Z3TC-G04-THtSdAOPLM1h2G7ml5bda1ElZUcn5wpuhk,1564 +cryptography/hazmat/bindings/_rust/openssl/dsa.pyi,sha256=qBtkgj2albt2qFcnZ9UDrhzoNhCVO7HTby5VSf1EXMI,1299 +cryptography/hazmat/bindings/_rust/openssl/ec.pyi,sha256=zJy0pRa5n-_p2dm45PxECB_-B6SVZyNKfjxFDpPqT38,1691 +cryptography/hazmat/bindings/_rust/openssl/ed25519.pyi,sha256=OJsrblS2nHptZctva-pAKFL5q8yPEAkhmjPZpJ6TA94,493 +cryptography/hazmat/bindings/_rust/openssl/ed448.pyi,sha256=SkPHK2HdbYN02TVQEUOgW3iTdiEY7HBE4DijpdkAzmk,475 +cryptography/hazmat/bindings/_rust/openssl/hashes.pyi,sha256=J8HoN0GdtPcjRAfNHr5Elva_nkmQfq63L75_z9dd8Uc,573 +cryptography/hazmat/bindings/_rust/openssl/hmac.pyi,sha256=ZmLJ73pmxcZFC1XosWEiXMRYtvJJor3ZLdCQOJu85Cw,662 +cryptography/hazmat/bindings/_rust/openssl/kdf.pyi,sha256=wPS5c7NLspM2632II0I4iH1RSxZvSRtBOVqmpyQATfk,544 +cryptography/hazmat/bindings/_rust/openssl/keys.pyi,sha256=9nFfZ0USUxHtPvqJmvWewz27so3qlQxxTEt2d904msI,980 +cryptography/hazmat/bindings/_rust/openssl/poly1305.pyi,sha256=9iogF7Q4i81IkOS-IMXp6HvxFF_3cNy_ucrAjVQnn14,540 +cryptography/hazmat/bindings/_rust/openssl/rsa.pyi,sha256=2OQCNSXkxgc-3uw1xiCCloIQTV6p9_kK79Yu0rhZgPc,1364 +cryptography/hazmat/bindings/_rust/openssl/x25519.pyi,sha256=2BKdbrddM_9SMUpdvHKGhb9MNjURCarPxccbUDzHeoA,484 +cryptography/hazmat/bindings/_rust/openssl/x448.pyi,sha256=AoRMWNvCJTiH5L-lkIkCdPlrPLUdJvvfXpIvf1GmxpM,466 +cryptography/hazmat/bindings/_rust/pkcs7.pyi,sha256=WfJXBDgmsOg1ui1U3wclgL-xpmbcFNq6lt6fY6yxy8w,619 +cryptography/hazmat/bindings/_rust/x509.pyi,sha256=KqsM2W3tg4MpzxjI4eL9Jbsm7pQwvJ4_-xDE7wA1x3w,3001 +cryptography/hazmat/bindings/openssl/__init__.py,sha256=s9oKCQ2ycFdXoERdS1imafueSkBsL9kvbyfghaauZ9Y,180 +cryptography/hazmat/bindings/openssl/__pycache__/__init__.cpython-311.pyc,, +cryptography/hazmat/bindings/openssl/__pycache__/_conditional.cpython-311.pyc,, +cryptography/hazmat/bindings/openssl/__pycache__/binding.cpython-311.pyc,, +cryptography/hazmat/bindings/openssl/_conditional.py,sha256=rqgTeJjw9y83ICW5hd3bowvFWVO49-gRC9QF-636Vhg,6481 +cryptography/hazmat/bindings/openssl/binding.py,sha256=G4Nh4jXcIYiFyPJhwnJT4TGTyx8m8gY2REG7xgU1eaA,6531 +cryptography/hazmat/primitives/__init__.py,sha256=s9oKCQ2ycFdXoERdS1imafueSkBsL9kvbyfghaauZ9Y,180 +cryptography/hazmat/primitives/__pycache__/__init__.cpython-311.pyc,, +cryptography/hazmat/primitives/__pycache__/_asymmetric.cpython-311.pyc,, +cryptography/hazmat/primitives/__pycache__/_cipheralgorithm.cpython-311.pyc,, +cryptography/hazmat/primitives/__pycache__/_serialization.cpython-311.pyc,, +cryptography/hazmat/primitives/__pycache__/cmac.cpython-311.pyc,, +cryptography/hazmat/primitives/__pycache__/constant_time.cpython-311.pyc,, +cryptography/hazmat/primitives/__pycache__/hashes.cpython-311.pyc,, +cryptography/hazmat/primitives/__pycache__/hmac.cpython-311.pyc,, +cryptography/hazmat/primitives/__pycache__/keywrap.cpython-311.pyc,, +cryptography/hazmat/primitives/__pycache__/padding.cpython-311.pyc,, +cryptography/hazmat/primitives/__pycache__/poly1305.cpython-311.pyc,, +cryptography/hazmat/primitives/_asymmetric.py,sha256=RhgcouUB6HTiFDBrR1LxqkMjpUxIiNvQ1r_zJjRG6qQ,532 +cryptography/hazmat/primitives/_cipheralgorithm.py,sha256=u7ryLG_HivCXn-ulKM-h_eVWMzlobeg0K45Udflk7Gg,1072 +cryptography/hazmat/primitives/_serialization.py,sha256=qrozc8fw2WZSbjk3DAlSl3ResxpauwJ74ZgGoUL-mj0,5142 +cryptography/hazmat/primitives/asymmetric/__init__.py,sha256=s9oKCQ2ycFdXoERdS1imafueSkBsL9kvbyfghaauZ9Y,180 +cryptography/hazmat/primitives/asymmetric/__pycache__/__init__.cpython-311.pyc,, +cryptography/hazmat/primitives/asymmetric/__pycache__/dh.cpython-311.pyc,, +cryptography/hazmat/primitives/asymmetric/__pycache__/dsa.cpython-311.pyc,, +cryptography/hazmat/primitives/asymmetric/__pycache__/ec.cpython-311.pyc,, +cryptography/hazmat/primitives/asymmetric/__pycache__/ed25519.cpython-311.pyc,, +cryptography/hazmat/primitives/asymmetric/__pycache__/ed448.cpython-311.pyc,, +cryptography/hazmat/primitives/asymmetric/__pycache__/padding.cpython-311.pyc,, +cryptography/hazmat/primitives/asymmetric/__pycache__/rsa.cpython-311.pyc,, +cryptography/hazmat/primitives/asymmetric/__pycache__/types.cpython-311.pyc,, +cryptography/hazmat/primitives/asymmetric/__pycache__/utils.cpython-311.pyc,, +cryptography/hazmat/primitives/asymmetric/__pycache__/x25519.cpython-311.pyc,, +cryptography/hazmat/primitives/asymmetric/__pycache__/x448.cpython-311.pyc,, +cryptography/hazmat/primitives/asymmetric/dh.py,sha256=OOCjMClH1Bf14Sy7jAdwzEeCxFPb8XUe2qePbExvXwc,3420 +cryptography/hazmat/primitives/asymmetric/dsa.py,sha256=xBwdf0pZOgvqjUKcO7Q0L3NxwalYj0SJDUqThemhSmI,3945 +cryptography/hazmat/primitives/asymmetric/ec.py,sha256=W6nLb4Oho3BI3OsTR_nUI4WRHCbikTrqVOjQQYjV5vs,9704 +cryptography/hazmat/primitives/asymmetric/ed25519.py,sha256=kl63fg7myuMjNTmMoVFeH6iVr0x5FkjNmggxIRTloJk,3423 +cryptography/hazmat/primitives/asymmetric/ed448.py,sha256=2UzEDzzfkPn83UFVFlMZfIMbAixxY09WmQyrwinWTn8,3456 +cryptography/hazmat/primitives/asymmetric/padding.py,sha256=eZcvUqVLbe3u48SunLdeniaPlV4-k6pwBl67OW4jSy8,2885 +cryptography/hazmat/primitives/asymmetric/rsa.py,sha256=HToE4M5VJbGZS_2SbJ11kIGhtQ8D3GozW59sWEzrfZ4,6799 +cryptography/hazmat/primitives/asymmetric/types.py,sha256=LnsOJym-wmPUJ7Knu_7bCNU3kIiELCd6krOaW_JU08I,2996 +cryptography/hazmat/primitives/asymmetric/utils.py,sha256=DPTs6T4F-UhwzFQTh-1fSEpQzazH2jf2xpIro3ItF4o,790 +cryptography/hazmat/primitives/asymmetric/x25519.py,sha256=VGYuRdIYuVBtizpFdNWd2bTrT10JRa1admQdBr08xz8,3341 +cryptography/hazmat/primitives/asymmetric/x448.py,sha256=GKKJBqYLr03VewMF18bXIM941aaWcZIQ4rC02GLLEmw,3374 +cryptography/hazmat/primitives/ciphers/__init__.py,sha256=kAyb9NSczqTrCWj0HEoVp3Cxo7AHW8ibPFQz-ZHsOtA,680 +cryptography/hazmat/primitives/ciphers/__pycache__/__init__.cpython-311.pyc,, +cryptography/hazmat/primitives/ciphers/__pycache__/aead.cpython-311.pyc,, +cryptography/hazmat/primitives/ciphers/__pycache__/algorithms.cpython-311.pyc,, +cryptography/hazmat/primitives/ciphers/__pycache__/base.cpython-311.pyc,, +cryptography/hazmat/primitives/ciphers/__pycache__/modes.cpython-311.pyc,, +cryptography/hazmat/primitives/ciphers/aead.py,sha256=V6UKsIPNZQh0cfd8hpXx3ZzztQ-JQ9ChBMMN1ZTZXJ0,5540 +cryptography/hazmat/primitives/ciphers/algorithms.py,sha256=rNsvAJZIft8o0yan5Z62hJ-xoEM_Y6BYBkFs4jnnR2s,5120 +cryptography/hazmat/primitives/ciphers/base.py,sha256=4VktSqxhRjigjNQ3m2BiQQDo-1bYqCxXpddphJukoMI,8445 +cryptography/hazmat/primitives/ciphers/modes.py,sha256=Kw1419ZCUBNbbxd7BctwPp6i8rwnOvvifdXokrx_bYM,8317 +cryptography/hazmat/primitives/cmac.py,sha256=sz_s6H_cYnOvx-VNWdIKhRhe3Ymp8z8J0D3CBqOX3gg,338 +cryptography/hazmat/primitives/constant_time.py,sha256=xdunWT0nf8OvKdcqUhhlFKayGp4_PgVJRU2W1wLSr_A,422 +cryptography/hazmat/primitives/hashes.py,sha256=HCFCsR8p7OEWt1YA7oRbqgKHXOuZnrspkVrniU_B2uU,5091 +cryptography/hazmat/primitives/hmac.py,sha256=RpB3z9z5skirCQrm7zQbtnp9pLMnAjrlTUvKqF5aDDc,423 +cryptography/hazmat/primitives/kdf/__init__.py,sha256=4XibZnrYq4hh5xBjWiIXzaYW6FKx8hPbVaa_cB9zS64,750 +cryptography/hazmat/primitives/kdf/__pycache__/__init__.cpython-311.pyc,, +cryptography/hazmat/primitives/kdf/__pycache__/concatkdf.cpython-311.pyc,, +cryptography/hazmat/primitives/kdf/__pycache__/hkdf.cpython-311.pyc,, +cryptography/hazmat/primitives/kdf/__pycache__/kbkdf.cpython-311.pyc,, +cryptography/hazmat/primitives/kdf/__pycache__/pbkdf2.cpython-311.pyc,, +cryptography/hazmat/primitives/kdf/__pycache__/scrypt.cpython-311.pyc,, +cryptography/hazmat/primitives/kdf/__pycache__/x963kdf.cpython-311.pyc,, +cryptography/hazmat/primitives/kdf/concatkdf.py,sha256=bcn4NGXse-EsFl7nlU83e5ilop7TSHcX-CJJS107W80,3686 +cryptography/hazmat/primitives/kdf/hkdf.py,sha256=uhN5L87w4JvtAqQcPh_Ji2TPSc18IDThpaYJiHOWy3A,3015 +cryptography/hazmat/primitives/kdf/kbkdf.py,sha256=C3koAdtF_fwyvbhQA88AYbi3YOrUZ_7eaIM4DkWrfyM,9072 +cryptography/hazmat/primitives/kdf/pbkdf2.py,sha256=1CCH9Q5gXUpnZd3c8d8bCXgpJ3s2hZZGBnuG7FH1waM,2012 +cryptography/hazmat/primitives/kdf/scrypt.py,sha256=4QONhjxA_ZtuQtQ7QV3FnbB8ftrFnM52B4HPfV7hFys,2354 +cryptography/hazmat/primitives/kdf/x963kdf.py,sha256=wCpWmwQjZ2vAu2rlk3R_PX0nINl8WGXYBmlyMOC5iPw,1992 +cryptography/hazmat/primitives/keywrap.py,sha256=kHqtc56YvpTNEi6q1ifoHKXmY4SWqllBv-eBfqMpvuE,5650 +cryptography/hazmat/primitives/padding.py,sha256=g4qonAgYADkMArKt2MXD1XlnGd4ET_Rf5YDADwb_v8Q,6148 +cryptography/hazmat/primitives/poly1305.py,sha256=P5EPQV-RB_FJPahpg01u0Ts4S_PnAmsroxIGXbGeRRo,355 +cryptography/hazmat/primitives/serialization/__init__.py,sha256=6ZlL3EicEzoGdMOat86w8y_XICCnlHdCjFI97rMxRDg,1653 +cryptography/hazmat/primitives/serialization/__pycache__/__init__.cpython-311.pyc,, +cryptography/hazmat/primitives/serialization/__pycache__/base.cpython-311.pyc,, +cryptography/hazmat/primitives/serialization/__pycache__/pkcs12.cpython-311.pyc,, +cryptography/hazmat/primitives/serialization/__pycache__/pkcs7.cpython-311.pyc,, +cryptography/hazmat/primitives/serialization/__pycache__/ssh.cpython-311.pyc,, +cryptography/hazmat/primitives/serialization/base.py,sha256=ikq5MJIwp_oUnjiaBco_PmQwOTYuGi-XkYUYHKy8Vo0,615 +cryptography/hazmat/primitives/serialization/pkcs12.py,sha256=jtMcM-At_GZFRD5oSlOGHOE1OcosroWIvmkzrEsv75Q,6599 +cryptography/hazmat/primitives/serialization/pkcs7.py,sha256=uaWAdWggcM087zL1ltQc5fFhpXFFbBNn_2cyQK8toZ4,7488 +cryptography/hazmat/primitives/serialization/ssh.py,sha256=7JjL4ZWcOliyAOJdnlnWi_0nNlLtOrAoj6AqWHdrLNg,50051 +cryptography/hazmat/primitives/twofactor/__init__.py,sha256=tmMZGB-g4IU1r7lIFqASU019zr0uPp_wEBYcwdDCKCA,258 +cryptography/hazmat/primitives/twofactor/__pycache__/__init__.cpython-311.pyc,, +cryptography/hazmat/primitives/twofactor/__pycache__/hotp.cpython-311.pyc,, +cryptography/hazmat/primitives/twofactor/__pycache__/totp.cpython-311.pyc,, +cryptography/hazmat/primitives/twofactor/hotp.py,sha256=l1YdRMIhfPIuHKkA66keBDHhNbnBAlh6-O44P-OHIK8,2976 +cryptography/hazmat/primitives/twofactor/totp.py,sha256=v0y0xKwtYrP83ypOo5Ofd441RJLOkaFfjmp554jo5F0,1450 +cryptography/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +cryptography/utils.py,sha256=8fNXSfKvDgaji9M_m4lVXHFTVdIDP32GhlXzUBYDBHE,4033 +cryptography/x509/__init__.py,sha256=zaKuAaluw0p-lQm4RGK3_NBAG9V_UW6nhv_1m_ppugI,7924 +cryptography/x509/__pycache__/__init__.cpython-311.pyc,, +cryptography/x509/__pycache__/base.cpython-311.pyc,, +cryptography/x509/__pycache__/certificate_transparency.cpython-311.pyc,, +cryptography/x509/__pycache__/extensions.cpython-311.pyc,, +cryptography/x509/__pycache__/general_name.cpython-311.pyc,, +cryptography/x509/__pycache__/name.cpython-311.pyc,, +cryptography/x509/__pycache__/ocsp.cpython-311.pyc,, +cryptography/x509/__pycache__/oid.cpython-311.pyc,, +cryptography/x509/__pycache__/verification.cpython-311.pyc,, +cryptography/x509/base.py,sha256=U2ZTy4BMQKiQ7YwncAnfKffRv7KSzWaMvbbgMlO8blk,36933 +cryptography/x509/certificate_transparency.py,sha256=6HvzAD0dlSQVxy6tnDhGj0-pisp1MaJ9bxQNRr92inI,2261 +cryptography/x509/extensions.py,sha256=YU9R9IGt2tFl3zM7T2LI3dzQvKyvMhZxT2JgqCrZ3SE,66345 +cryptography/x509/general_name.py,sha256=sP_rV11Qlpsk4x3XXGJY_Mv0Q_s9dtjeLckHsjpLQoQ,7836 +cryptography/x509/name.py,sha256=85k7lJRtXnWTsVfsJXHNiWnDrsbW0OJ54np2opaBV28,14609 +cryptography/x509/ocsp.py,sha256=7Na0PAyA6nSyApTGd-QZ9Nfw2uyUS_PDVQx5XUw1xmU,18126 +cryptography/x509/oid.py,sha256=fFosjGsnIB_w_0YrzZv1ggkSVwZl7xmY0zofKZNZkDA,829 +cryptography/x509/verification.py,sha256=mPg6AUQDxK5wgGerP_hkFWD1Wj6l7lAt2IxpizZzekA,668 diff --git a/.venv/Lib/site-packages/cryptography-42.0.1.dist-info/WHEEL b/.venv/Lib/site-packages/cryptography-42.0.1.dist-info/WHEEL new file mode 100644 index 00000000..96dd4533 --- /dev/null +++ b/.venv/Lib/site-packages/cryptography-42.0.1.dist-info/WHEEL @@ -0,0 +1,5 @@ +Wheel-Version: 1.0 +Generator: bdist_wheel (0.42.0) +Root-Is-Purelib: false +Tag: cp39-abi3-win_amd64 + diff --git a/.venv/Lib/site-packages/cryptography-42.0.1.dist-info/top_level.txt b/.venv/Lib/site-packages/cryptography-42.0.1.dist-info/top_level.txt new file mode 100644 index 00000000..0d38bc5e --- /dev/null +++ b/.venv/Lib/site-packages/cryptography-42.0.1.dist-info/top_level.txt @@ -0,0 +1 @@ +cryptography diff --git a/.venv/Lib/site-packages/cryptography/__about__.py b/.venv/Lib/site-packages/cryptography/__about__.py new file mode 100644 index 00000000..35d8510f --- /dev/null +++ b/.venv/Lib/site-packages/cryptography/__about__.py @@ -0,0 +1,17 @@ +# This file is dual licensed under the terms of the Apache License, Version +# 2.0, and the BSD License. See the LICENSE file in the root of this repository +# for complete details. + +from __future__ import annotations + +__all__ = [ + "__version__", + "__author__", + "__copyright__", +] + +__version__ = "42.0.1" + + +__author__ = "The Python Cryptographic Authority and individual contributors" +__copyright__ = f"Copyright 2013-2024 {__author__}" diff --git a/.venv/Lib/site-packages/cryptography/__init__.py b/.venv/Lib/site-packages/cryptography/__init__.py new file mode 100644 index 00000000..86b9a257 --- /dev/null +++ b/.venv/Lib/site-packages/cryptography/__init__.py @@ -0,0 +1,13 @@ +# This file is dual licensed under the terms of the Apache License, Version +# 2.0, and the BSD License. See the LICENSE file in the root of this repository +# for complete details. + +from __future__ import annotations + +from cryptography.__about__ import __author__, __copyright__, __version__ + +__all__ = [ + "__version__", + "__author__", + "__copyright__", +] diff --git a/.venv/Lib/site-packages/cryptography/__pycache__/__about__.cpython-311.pyc b/.venv/Lib/site-packages/cryptography/__pycache__/__about__.cpython-311.pyc new file mode 100644 index 00000000..f2d832b6 Binary files /dev/null and b/.venv/Lib/site-packages/cryptography/__pycache__/__about__.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/cryptography/__pycache__/__init__.cpython-311.pyc b/.venv/Lib/site-packages/cryptography/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 00000000..6a6fc420 Binary files /dev/null and b/.venv/Lib/site-packages/cryptography/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/cryptography/__pycache__/exceptions.cpython-311.pyc b/.venv/Lib/site-packages/cryptography/__pycache__/exceptions.cpython-311.pyc new file mode 100644 index 00000000..82b5aa8c Binary files /dev/null and b/.venv/Lib/site-packages/cryptography/__pycache__/exceptions.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/cryptography/__pycache__/fernet.cpython-311.pyc b/.venv/Lib/site-packages/cryptography/__pycache__/fernet.cpython-311.pyc new file mode 100644 index 00000000..fdf75341 Binary files /dev/null and b/.venv/Lib/site-packages/cryptography/__pycache__/fernet.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/cryptography/__pycache__/utils.cpython-311.pyc b/.venv/Lib/site-packages/cryptography/__pycache__/utils.cpython-311.pyc new file mode 100644 index 00000000..039c026f Binary files /dev/null and b/.venv/Lib/site-packages/cryptography/__pycache__/utils.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/cryptography/exceptions.py b/.venv/Lib/site-packages/cryptography/exceptions.py new file mode 100644 index 00000000..fe125ea9 --- /dev/null +++ b/.venv/Lib/site-packages/cryptography/exceptions.py @@ -0,0 +1,52 @@ +# This file is dual licensed under the terms of the Apache License, Version +# 2.0, and the BSD License. See the LICENSE file in the root of this repository +# for complete details. + +from __future__ import annotations + +import typing + +from cryptography.hazmat.bindings._rust import exceptions as rust_exceptions + +if typing.TYPE_CHECKING: + from cryptography.hazmat.bindings._rust import openssl as rust_openssl + +_Reasons = rust_exceptions._Reasons + + +class UnsupportedAlgorithm(Exception): + def __init__(self, message: str, reason: _Reasons | None = None) -> None: + super().__init__(message) + self._reason = reason + + +class AlreadyFinalized(Exception): + pass + + +class AlreadyUpdated(Exception): + pass + + +class NotYetFinalized(Exception): + pass + + +class InvalidTag(Exception): + pass + + +class InvalidSignature(Exception): + pass + + +class InternalError(Exception): + def __init__( + self, msg: str, err_code: list[rust_openssl.OpenSSLError] + ) -> None: + super().__init__(msg) + self.err_code = err_code + + +class InvalidKey(Exception): + pass diff --git a/.venv/Lib/site-packages/cryptography/fernet.py b/.venv/Lib/site-packages/cryptography/fernet.py new file mode 100644 index 00000000..35ce1131 --- /dev/null +++ b/.venv/Lib/site-packages/cryptography/fernet.py @@ -0,0 +1,215 @@ +# This file is dual licensed under the terms of the Apache License, Version +# 2.0, and the BSD License. See the LICENSE file in the root of this repository +# for complete details. + +from __future__ import annotations + +import base64 +import binascii +import os +import time +import typing + +from cryptography import utils +from cryptography.exceptions import InvalidSignature +from cryptography.hazmat.primitives import hashes, padding +from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes +from cryptography.hazmat.primitives.hmac import HMAC + + +class InvalidToken(Exception): + pass + + +_MAX_CLOCK_SKEW = 60 + + +class Fernet: + def __init__( + self, + key: bytes | str, + backend: typing.Any = None, + ) -> None: + try: + key = base64.urlsafe_b64decode(key) + except binascii.Error as exc: + raise ValueError( + "Fernet key must be 32 url-safe base64-encoded bytes." + ) from exc + if len(key) != 32: + raise ValueError( + "Fernet key must be 32 url-safe base64-encoded bytes." + ) + + self._signing_key = key[:16] + self._encryption_key = key[16:] + + @classmethod + def generate_key(cls) -> bytes: + return base64.urlsafe_b64encode(os.urandom(32)) + + def encrypt(self, data: bytes) -> bytes: + return self.encrypt_at_time(data, int(time.time())) + + def encrypt_at_time(self, data: bytes, current_time: int) -> bytes: + iv = os.urandom(16) + return self._encrypt_from_parts(data, current_time, iv) + + def _encrypt_from_parts( + self, data: bytes, current_time: int, iv: bytes + ) -> bytes: + utils._check_bytes("data", data) + + padder = padding.PKCS7(algorithms.AES.block_size).padder() + padded_data = padder.update(data) + padder.finalize() + encryptor = Cipher( + algorithms.AES(self._encryption_key), + modes.CBC(iv), + ).encryptor() + ciphertext = encryptor.update(padded_data) + encryptor.finalize() + + basic_parts = ( + b"\x80" + + current_time.to_bytes(length=8, byteorder="big") + + iv + + ciphertext + ) + + h = HMAC(self._signing_key, hashes.SHA256()) + h.update(basic_parts) + hmac = h.finalize() + return base64.urlsafe_b64encode(basic_parts + hmac) + + def decrypt(self, token: bytes | str, ttl: int | None = None) -> bytes: + timestamp, data = Fernet._get_unverified_token_data(token) + if ttl is None: + time_info = None + else: + time_info = (ttl, int(time.time())) + return self._decrypt_data(data, timestamp, time_info) + + def decrypt_at_time( + self, token: bytes | str, ttl: int, current_time: int + ) -> bytes: + if ttl is None: + raise ValueError( + "decrypt_at_time() can only be used with a non-None ttl" + ) + timestamp, data = Fernet._get_unverified_token_data(token) + return self._decrypt_data(data, timestamp, (ttl, current_time)) + + def extract_timestamp(self, token: bytes | str) -> int: + timestamp, data = Fernet._get_unverified_token_data(token) + # Verify the token was not tampered with. + self._verify_signature(data) + return timestamp + + @staticmethod + def _get_unverified_token_data(token: bytes | str) -> tuple[int, bytes]: + if not isinstance(token, (str, bytes)): + raise TypeError("token must be bytes or str") + + try: + data = base64.urlsafe_b64decode(token) + except (TypeError, binascii.Error): + raise InvalidToken + + if not data or data[0] != 0x80: + raise InvalidToken + + if len(data) < 9: + raise InvalidToken + + timestamp = int.from_bytes(data[1:9], byteorder="big") + return timestamp, data + + def _verify_signature(self, data: bytes) -> None: + h = HMAC(self._signing_key, hashes.SHA256()) + h.update(data[:-32]) + try: + h.verify(data[-32:]) + except InvalidSignature: + raise InvalidToken + + def _decrypt_data( + self, + data: bytes, + timestamp: int, + time_info: tuple[int, int] | None, + ) -> bytes: + if time_info is not None: + ttl, current_time = time_info + if timestamp + ttl < current_time: + raise InvalidToken + + if current_time + _MAX_CLOCK_SKEW < timestamp: + raise InvalidToken + + self._verify_signature(data) + + iv = data[9:25] + ciphertext = data[25:-32] + decryptor = Cipher( + algorithms.AES(self._encryption_key), modes.CBC(iv) + ).decryptor() + plaintext_padded = decryptor.update(ciphertext) + try: + plaintext_padded += decryptor.finalize() + except ValueError: + raise InvalidToken + unpadder = padding.PKCS7(algorithms.AES.block_size).unpadder() + + unpadded = unpadder.update(plaintext_padded) + try: + unpadded += unpadder.finalize() + except ValueError: + raise InvalidToken + return unpadded + + +class MultiFernet: + def __init__(self, fernets: typing.Iterable[Fernet]): + fernets = list(fernets) + if not fernets: + raise ValueError( + "MultiFernet requires at least one Fernet instance" + ) + self._fernets = fernets + + def encrypt(self, msg: bytes) -> bytes: + return self.encrypt_at_time(msg, int(time.time())) + + def encrypt_at_time(self, msg: bytes, current_time: int) -> bytes: + return self._fernets[0].encrypt_at_time(msg, current_time) + + def rotate(self, msg: bytes | str) -> bytes: + timestamp, data = Fernet._get_unverified_token_data(msg) + for f in self._fernets: + try: + p = f._decrypt_data(data, timestamp, None) + break + except InvalidToken: + pass + else: + raise InvalidToken + + iv = os.urandom(16) + return self._fernets[0]._encrypt_from_parts(p, timestamp, iv) + + def decrypt(self, msg: bytes | str, ttl: int | None = None) -> bytes: + for f in self._fernets: + try: + return f.decrypt(msg, ttl) + except InvalidToken: + pass + raise InvalidToken + + def decrypt_at_time( + self, msg: bytes | str, ttl: int, current_time: int + ) -> bytes: + for f in self._fernets: + try: + return f.decrypt_at_time(msg, ttl, current_time) + except InvalidToken: + pass + raise InvalidToken diff --git a/.venv/Lib/site-packages/cryptography/hazmat/__init__.py b/.venv/Lib/site-packages/cryptography/hazmat/__init__.py new file mode 100644 index 00000000..b9f11870 --- /dev/null +++ b/.venv/Lib/site-packages/cryptography/hazmat/__init__.py @@ -0,0 +1,13 @@ +# This file is dual licensed under the terms of the Apache License, Version +# 2.0, and the BSD License. See the LICENSE file in the root of this repository +# for complete details. + +from __future__ import annotations + +""" +Hazardous Materials + +This is a "Hazardous Materials" module. You should ONLY use it if you're +100% absolutely sure that you know what you're doing because this module +is full of land mines, dragons, and dinosaurs with laser guns. +""" diff --git a/.venv/Lib/site-packages/cryptography/hazmat/__pycache__/__init__.cpython-311.pyc b/.venv/Lib/site-packages/cryptography/hazmat/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 00000000..b0f7424e Binary files /dev/null and b/.venv/Lib/site-packages/cryptography/hazmat/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/cryptography/hazmat/__pycache__/_oid.cpython-311.pyc b/.venv/Lib/site-packages/cryptography/hazmat/__pycache__/_oid.cpython-311.pyc new file mode 100644 index 00000000..e3962eb2 Binary files /dev/null and b/.venv/Lib/site-packages/cryptography/hazmat/__pycache__/_oid.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/cryptography/hazmat/_oid.py b/.venv/Lib/site-packages/cryptography/hazmat/_oid.py new file mode 100644 index 00000000..c5d062c1 --- /dev/null +++ b/.venv/Lib/site-packages/cryptography/hazmat/_oid.py @@ -0,0 +1,296 @@ +# This file is dual licensed under the terms of the Apache License, Version +# 2.0, and the BSD License. See the LICENSE file in the root of this repository +# for complete details. + +from __future__ import annotations + +from cryptography.hazmat.bindings._rust import ( + ObjectIdentifier as ObjectIdentifier, +) +from cryptography.hazmat.primitives import hashes + + +class ExtensionOID: + SUBJECT_DIRECTORY_ATTRIBUTES = ObjectIdentifier("2.5.29.9") + SUBJECT_KEY_IDENTIFIER = ObjectIdentifier("2.5.29.14") + KEY_USAGE = ObjectIdentifier("2.5.29.15") + SUBJECT_ALTERNATIVE_NAME = ObjectIdentifier("2.5.29.17") + ISSUER_ALTERNATIVE_NAME = ObjectIdentifier("2.5.29.18") + BASIC_CONSTRAINTS = ObjectIdentifier("2.5.29.19") + NAME_CONSTRAINTS = ObjectIdentifier("2.5.29.30") + CRL_DISTRIBUTION_POINTS = ObjectIdentifier("2.5.29.31") + CERTIFICATE_POLICIES = ObjectIdentifier("2.5.29.32") + POLICY_MAPPINGS = ObjectIdentifier("2.5.29.33") + AUTHORITY_KEY_IDENTIFIER = ObjectIdentifier("2.5.29.35") + POLICY_CONSTRAINTS = ObjectIdentifier("2.5.29.36") + EXTENDED_KEY_USAGE = ObjectIdentifier("2.5.29.37") + FRESHEST_CRL = ObjectIdentifier("2.5.29.46") + INHIBIT_ANY_POLICY = ObjectIdentifier("2.5.29.54") + ISSUING_DISTRIBUTION_POINT = ObjectIdentifier("2.5.29.28") + AUTHORITY_INFORMATION_ACCESS = ObjectIdentifier("1.3.6.1.5.5.7.1.1") + SUBJECT_INFORMATION_ACCESS = ObjectIdentifier("1.3.6.1.5.5.7.1.11") + OCSP_NO_CHECK = ObjectIdentifier("1.3.6.1.5.5.7.48.1.5") + TLS_FEATURE = ObjectIdentifier("1.3.6.1.5.5.7.1.24") + CRL_NUMBER = ObjectIdentifier("2.5.29.20") + DELTA_CRL_INDICATOR = ObjectIdentifier("2.5.29.27") + PRECERT_SIGNED_CERTIFICATE_TIMESTAMPS = ObjectIdentifier( + "1.3.6.1.4.1.11129.2.4.2" + ) + PRECERT_POISON = ObjectIdentifier("1.3.6.1.4.1.11129.2.4.3") + SIGNED_CERTIFICATE_TIMESTAMPS = ObjectIdentifier("1.3.6.1.4.1.11129.2.4.5") + MS_CERTIFICATE_TEMPLATE = ObjectIdentifier("1.3.6.1.4.1.311.21.7") + + +class OCSPExtensionOID: + NONCE = ObjectIdentifier("1.3.6.1.5.5.7.48.1.2") + ACCEPTABLE_RESPONSES = ObjectIdentifier("1.3.6.1.5.5.7.48.1.4") + + +class CRLEntryExtensionOID: + CERTIFICATE_ISSUER = ObjectIdentifier("2.5.29.29") + CRL_REASON = ObjectIdentifier("2.5.29.21") + INVALIDITY_DATE = ObjectIdentifier("2.5.29.24") + + +class NameOID: + COMMON_NAME = ObjectIdentifier("2.5.4.3") + COUNTRY_NAME = ObjectIdentifier("2.5.4.6") + LOCALITY_NAME = ObjectIdentifier("2.5.4.7") + STATE_OR_PROVINCE_NAME = ObjectIdentifier("2.5.4.8") + STREET_ADDRESS = ObjectIdentifier("2.5.4.9") + ORGANIZATION_IDENTIFIER = ObjectIdentifier("2.5.4.97") + ORGANIZATION_NAME = ObjectIdentifier("2.5.4.10") + ORGANIZATIONAL_UNIT_NAME = ObjectIdentifier("2.5.4.11") + SERIAL_NUMBER = ObjectIdentifier("2.5.4.5") + SURNAME = ObjectIdentifier("2.5.4.4") + GIVEN_NAME = ObjectIdentifier("2.5.4.42") + TITLE = ObjectIdentifier("2.5.4.12") + INITIALS = ObjectIdentifier("2.5.4.43") + GENERATION_QUALIFIER = ObjectIdentifier("2.5.4.44") + X500_UNIQUE_IDENTIFIER = ObjectIdentifier("2.5.4.45") + DN_QUALIFIER = ObjectIdentifier("2.5.4.46") + PSEUDONYM = ObjectIdentifier("2.5.4.65") + USER_ID = ObjectIdentifier("0.9.2342.19200300.100.1.1") + DOMAIN_COMPONENT = ObjectIdentifier("0.9.2342.19200300.100.1.25") + EMAIL_ADDRESS = ObjectIdentifier("1.2.840.113549.1.9.1") + JURISDICTION_COUNTRY_NAME = ObjectIdentifier("1.3.6.1.4.1.311.60.2.1.3") + JURISDICTION_LOCALITY_NAME = ObjectIdentifier("1.3.6.1.4.1.311.60.2.1.1") + JURISDICTION_STATE_OR_PROVINCE_NAME = ObjectIdentifier( + "1.3.6.1.4.1.311.60.2.1.2" + ) + BUSINESS_CATEGORY = ObjectIdentifier("2.5.4.15") + POSTAL_ADDRESS = ObjectIdentifier("2.5.4.16") + POSTAL_CODE = ObjectIdentifier("2.5.4.17") + INN = ObjectIdentifier("1.2.643.3.131.1.1") + OGRN = ObjectIdentifier("1.2.643.100.1") + SNILS = ObjectIdentifier("1.2.643.100.3") + UNSTRUCTURED_NAME = ObjectIdentifier("1.2.840.113549.1.9.2") + + +class SignatureAlgorithmOID: + RSA_WITH_MD5 = ObjectIdentifier("1.2.840.113549.1.1.4") + RSA_WITH_SHA1 = ObjectIdentifier("1.2.840.113549.1.1.5") + # This is an alternate OID for RSA with SHA1 that is occasionally seen + _RSA_WITH_SHA1 = ObjectIdentifier("1.3.14.3.2.29") + RSA_WITH_SHA224 = ObjectIdentifier("1.2.840.113549.1.1.14") + RSA_WITH_SHA256 = ObjectIdentifier("1.2.840.113549.1.1.11") + RSA_WITH_SHA384 = ObjectIdentifier("1.2.840.113549.1.1.12") + RSA_WITH_SHA512 = ObjectIdentifier("1.2.840.113549.1.1.13") + RSA_WITH_SHA3_224 = ObjectIdentifier("2.16.840.1.101.3.4.3.13") + RSA_WITH_SHA3_256 = ObjectIdentifier("2.16.840.1.101.3.4.3.14") + RSA_WITH_SHA3_384 = ObjectIdentifier("2.16.840.1.101.3.4.3.15") + RSA_WITH_SHA3_512 = ObjectIdentifier("2.16.840.1.101.3.4.3.16") + RSASSA_PSS = ObjectIdentifier("1.2.840.113549.1.1.10") + ECDSA_WITH_SHA1 = ObjectIdentifier("1.2.840.10045.4.1") + ECDSA_WITH_SHA224 = ObjectIdentifier("1.2.840.10045.4.3.1") + ECDSA_WITH_SHA256 = ObjectIdentifier("1.2.840.10045.4.3.2") + ECDSA_WITH_SHA384 = ObjectIdentifier("1.2.840.10045.4.3.3") + ECDSA_WITH_SHA512 = ObjectIdentifier("1.2.840.10045.4.3.4") + ECDSA_WITH_SHA3_224 = ObjectIdentifier("2.16.840.1.101.3.4.3.9") + ECDSA_WITH_SHA3_256 = ObjectIdentifier("2.16.840.1.101.3.4.3.10") + ECDSA_WITH_SHA3_384 = ObjectIdentifier("2.16.840.1.101.3.4.3.11") + ECDSA_WITH_SHA3_512 = ObjectIdentifier("2.16.840.1.101.3.4.3.12") + DSA_WITH_SHA1 = ObjectIdentifier("1.2.840.10040.4.3") + DSA_WITH_SHA224 = ObjectIdentifier("2.16.840.1.101.3.4.3.1") + DSA_WITH_SHA256 = ObjectIdentifier("2.16.840.1.101.3.4.3.2") + DSA_WITH_SHA384 = ObjectIdentifier("2.16.840.1.101.3.4.3.3") + DSA_WITH_SHA512 = ObjectIdentifier("2.16.840.1.101.3.4.3.4") + ED25519 = ObjectIdentifier("1.3.101.112") + ED448 = ObjectIdentifier("1.3.101.113") + GOSTR3411_94_WITH_3410_2001 = ObjectIdentifier("1.2.643.2.2.3") + GOSTR3410_2012_WITH_3411_2012_256 = ObjectIdentifier("1.2.643.7.1.1.3.2") + GOSTR3410_2012_WITH_3411_2012_512 = ObjectIdentifier("1.2.643.7.1.1.3.3") + + +_SIG_OIDS_TO_HASH: dict[ObjectIdentifier, hashes.HashAlgorithm | None] = { + SignatureAlgorithmOID.RSA_WITH_MD5: hashes.MD5(), + SignatureAlgorithmOID.RSA_WITH_SHA1: hashes.SHA1(), + SignatureAlgorithmOID._RSA_WITH_SHA1: hashes.SHA1(), + SignatureAlgorithmOID.RSA_WITH_SHA224: hashes.SHA224(), + SignatureAlgorithmOID.RSA_WITH_SHA256: hashes.SHA256(), + SignatureAlgorithmOID.RSA_WITH_SHA384: hashes.SHA384(), + SignatureAlgorithmOID.RSA_WITH_SHA512: hashes.SHA512(), + SignatureAlgorithmOID.RSA_WITH_SHA3_224: hashes.SHA3_224(), + SignatureAlgorithmOID.RSA_WITH_SHA3_256: hashes.SHA3_256(), + SignatureAlgorithmOID.RSA_WITH_SHA3_384: hashes.SHA3_384(), + SignatureAlgorithmOID.RSA_WITH_SHA3_512: hashes.SHA3_512(), + SignatureAlgorithmOID.ECDSA_WITH_SHA1: hashes.SHA1(), + SignatureAlgorithmOID.ECDSA_WITH_SHA224: hashes.SHA224(), + SignatureAlgorithmOID.ECDSA_WITH_SHA256: hashes.SHA256(), + SignatureAlgorithmOID.ECDSA_WITH_SHA384: hashes.SHA384(), + SignatureAlgorithmOID.ECDSA_WITH_SHA512: hashes.SHA512(), + SignatureAlgorithmOID.ECDSA_WITH_SHA3_224: hashes.SHA3_224(), + SignatureAlgorithmOID.ECDSA_WITH_SHA3_256: hashes.SHA3_256(), + SignatureAlgorithmOID.ECDSA_WITH_SHA3_384: hashes.SHA3_384(), + SignatureAlgorithmOID.ECDSA_WITH_SHA3_512: hashes.SHA3_512(), + SignatureAlgorithmOID.DSA_WITH_SHA1: hashes.SHA1(), + SignatureAlgorithmOID.DSA_WITH_SHA224: hashes.SHA224(), + SignatureAlgorithmOID.DSA_WITH_SHA256: hashes.SHA256(), + SignatureAlgorithmOID.ED25519: None, + SignatureAlgorithmOID.ED448: None, + SignatureAlgorithmOID.GOSTR3411_94_WITH_3410_2001: None, + SignatureAlgorithmOID.GOSTR3410_2012_WITH_3411_2012_256: None, + SignatureAlgorithmOID.GOSTR3410_2012_WITH_3411_2012_512: None, +} + + +class ExtendedKeyUsageOID: + SERVER_AUTH = ObjectIdentifier("1.3.6.1.5.5.7.3.1") + CLIENT_AUTH = ObjectIdentifier("1.3.6.1.5.5.7.3.2") + CODE_SIGNING = ObjectIdentifier("1.3.6.1.5.5.7.3.3") + EMAIL_PROTECTION = ObjectIdentifier("1.3.6.1.5.5.7.3.4") + TIME_STAMPING = ObjectIdentifier("1.3.6.1.5.5.7.3.8") + OCSP_SIGNING = ObjectIdentifier("1.3.6.1.5.5.7.3.9") + ANY_EXTENDED_KEY_USAGE = ObjectIdentifier("2.5.29.37.0") + SMARTCARD_LOGON = ObjectIdentifier("1.3.6.1.4.1.311.20.2.2") + KERBEROS_PKINIT_KDC = ObjectIdentifier("1.3.6.1.5.2.3.5") + IPSEC_IKE = ObjectIdentifier("1.3.6.1.5.5.7.3.17") + CERTIFICATE_TRANSPARENCY = ObjectIdentifier("1.3.6.1.4.1.11129.2.4.4") + + +class AuthorityInformationAccessOID: + CA_ISSUERS = ObjectIdentifier("1.3.6.1.5.5.7.48.2") + OCSP = ObjectIdentifier("1.3.6.1.5.5.7.48.1") + + +class SubjectInformationAccessOID: + CA_REPOSITORY = ObjectIdentifier("1.3.6.1.5.5.7.48.5") + + +class CertificatePoliciesOID: + CPS_QUALIFIER = ObjectIdentifier("1.3.6.1.5.5.7.2.1") + CPS_USER_NOTICE = ObjectIdentifier("1.3.6.1.5.5.7.2.2") + ANY_POLICY = ObjectIdentifier("2.5.29.32.0") + + +class AttributeOID: + CHALLENGE_PASSWORD = ObjectIdentifier("1.2.840.113549.1.9.7") + UNSTRUCTURED_NAME = ObjectIdentifier("1.2.840.113549.1.9.2") + + +_OID_NAMES = { + NameOID.COMMON_NAME: "commonName", + NameOID.COUNTRY_NAME: "countryName", + NameOID.LOCALITY_NAME: "localityName", + NameOID.STATE_OR_PROVINCE_NAME: "stateOrProvinceName", + NameOID.STREET_ADDRESS: "streetAddress", + NameOID.ORGANIZATION_NAME: "organizationName", + NameOID.ORGANIZATIONAL_UNIT_NAME: "organizationalUnitName", + NameOID.SERIAL_NUMBER: "serialNumber", + NameOID.SURNAME: "surname", + NameOID.GIVEN_NAME: "givenName", + NameOID.TITLE: "title", + NameOID.GENERATION_QUALIFIER: "generationQualifier", + NameOID.X500_UNIQUE_IDENTIFIER: "x500UniqueIdentifier", + NameOID.DN_QUALIFIER: "dnQualifier", + NameOID.PSEUDONYM: "pseudonym", + NameOID.USER_ID: "userID", + NameOID.DOMAIN_COMPONENT: "domainComponent", + NameOID.EMAIL_ADDRESS: "emailAddress", + NameOID.JURISDICTION_COUNTRY_NAME: "jurisdictionCountryName", + NameOID.JURISDICTION_LOCALITY_NAME: "jurisdictionLocalityName", + NameOID.JURISDICTION_STATE_OR_PROVINCE_NAME: ( + "jurisdictionStateOrProvinceName" + ), + NameOID.BUSINESS_CATEGORY: "businessCategory", + NameOID.POSTAL_ADDRESS: "postalAddress", + NameOID.POSTAL_CODE: "postalCode", + NameOID.INN: "INN", + NameOID.OGRN: "OGRN", + NameOID.SNILS: "SNILS", + NameOID.UNSTRUCTURED_NAME: "unstructuredName", + SignatureAlgorithmOID.RSA_WITH_MD5: "md5WithRSAEncryption", + SignatureAlgorithmOID.RSA_WITH_SHA1: "sha1WithRSAEncryption", + SignatureAlgorithmOID.RSA_WITH_SHA224: "sha224WithRSAEncryption", + SignatureAlgorithmOID.RSA_WITH_SHA256: "sha256WithRSAEncryption", + SignatureAlgorithmOID.RSA_WITH_SHA384: "sha384WithRSAEncryption", + SignatureAlgorithmOID.RSA_WITH_SHA512: "sha512WithRSAEncryption", + SignatureAlgorithmOID.RSASSA_PSS: "RSASSA-PSS", + SignatureAlgorithmOID.ECDSA_WITH_SHA1: "ecdsa-with-SHA1", + SignatureAlgorithmOID.ECDSA_WITH_SHA224: "ecdsa-with-SHA224", + SignatureAlgorithmOID.ECDSA_WITH_SHA256: "ecdsa-with-SHA256", + SignatureAlgorithmOID.ECDSA_WITH_SHA384: "ecdsa-with-SHA384", + SignatureAlgorithmOID.ECDSA_WITH_SHA512: "ecdsa-with-SHA512", + SignatureAlgorithmOID.DSA_WITH_SHA1: "dsa-with-sha1", + SignatureAlgorithmOID.DSA_WITH_SHA224: "dsa-with-sha224", + SignatureAlgorithmOID.DSA_WITH_SHA256: "dsa-with-sha256", + SignatureAlgorithmOID.ED25519: "ed25519", + SignatureAlgorithmOID.ED448: "ed448", + SignatureAlgorithmOID.GOSTR3411_94_WITH_3410_2001: ( + "GOST R 34.11-94 with GOST R 34.10-2001" + ), + SignatureAlgorithmOID.GOSTR3410_2012_WITH_3411_2012_256: ( + "GOST R 34.10-2012 with GOST R 34.11-2012 (256 bit)" + ), + SignatureAlgorithmOID.GOSTR3410_2012_WITH_3411_2012_512: ( + "GOST R 34.10-2012 with GOST R 34.11-2012 (512 bit)" + ), + ExtendedKeyUsageOID.SERVER_AUTH: "serverAuth", + ExtendedKeyUsageOID.CLIENT_AUTH: "clientAuth", + ExtendedKeyUsageOID.CODE_SIGNING: "codeSigning", + ExtendedKeyUsageOID.EMAIL_PROTECTION: "emailProtection", + ExtendedKeyUsageOID.TIME_STAMPING: "timeStamping", + ExtendedKeyUsageOID.OCSP_SIGNING: "OCSPSigning", + ExtendedKeyUsageOID.SMARTCARD_LOGON: "msSmartcardLogin", + ExtendedKeyUsageOID.KERBEROS_PKINIT_KDC: "pkInitKDC", + ExtensionOID.SUBJECT_DIRECTORY_ATTRIBUTES: "subjectDirectoryAttributes", + ExtensionOID.SUBJECT_KEY_IDENTIFIER: "subjectKeyIdentifier", + ExtensionOID.KEY_USAGE: "keyUsage", + ExtensionOID.SUBJECT_ALTERNATIVE_NAME: "subjectAltName", + ExtensionOID.ISSUER_ALTERNATIVE_NAME: "issuerAltName", + ExtensionOID.BASIC_CONSTRAINTS: "basicConstraints", + ExtensionOID.PRECERT_SIGNED_CERTIFICATE_TIMESTAMPS: ( + "signedCertificateTimestampList" + ), + ExtensionOID.SIGNED_CERTIFICATE_TIMESTAMPS: ( + "signedCertificateTimestampList" + ), + ExtensionOID.PRECERT_POISON: "ctPoison", + ExtensionOID.MS_CERTIFICATE_TEMPLATE: "msCertificateTemplate", + CRLEntryExtensionOID.CRL_REASON: "cRLReason", + CRLEntryExtensionOID.INVALIDITY_DATE: "invalidityDate", + CRLEntryExtensionOID.CERTIFICATE_ISSUER: "certificateIssuer", + ExtensionOID.NAME_CONSTRAINTS: "nameConstraints", + ExtensionOID.CRL_DISTRIBUTION_POINTS: "cRLDistributionPoints", + ExtensionOID.CERTIFICATE_POLICIES: "certificatePolicies", + ExtensionOID.POLICY_MAPPINGS: "policyMappings", + ExtensionOID.AUTHORITY_KEY_IDENTIFIER: "authorityKeyIdentifier", + ExtensionOID.POLICY_CONSTRAINTS: "policyConstraints", + ExtensionOID.EXTENDED_KEY_USAGE: "extendedKeyUsage", + ExtensionOID.FRESHEST_CRL: "freshestCRL", + ExtensionOID.INHIBIT_ANY_POLICY: "inhibitAnyPolicy", + ExtensionOID.ISSUING_DISTRIBUTION_POINT: "issuingDistributionPoint", + ExtensionOID.AUTHORITY_INFORMATION_ACCESS: "authorityInfoAccess", + ExtensionOID.SUBJECT_INFORMATION_ACCESS: "subjectInfoAccess", + ExtensionOID.OCSP_NO_CHECK: "OCSPNoCheck", + ExtensionOID.CRL_NUMBER: "cRLNumber", + ExtensionOID.DELTA_CRL_INDICATOR: "deltaCRLIndicator", + ExtensionOID.TLS_FEATURE: "TLSFeature", + AuthorityInformationAccessOID.OCSP: "OCSP", + AuthorityInformationAccessOID.CA_ISSUERS: "caIssuers", + SubjectInformationAccessOID.CA_REPOSITORY: "caRepository", + CertificatePoliciesOID.CPS_QUALIFIER: "id-qt-cps", + CertificatePoliciesOID.CPS_USER_NOTICE: "id-qt-unotice", + OCSPExtensionOID.NONCE: "OCSPNonce", + AttributeOID.CHALLENGE_PASSWORD: "challengePassword", +} diff --git a/.venv/Lib/site-packages/cryptography/hazmat/backends/__init__.py b/.venv/Lib/site-packages/cryptography/hazmat/backends/__init__.py new file mode 100644 index 00000000..b4400aa0 --- /dev/null +++ b/.venv/Lib/site-packages/cryptography/hazmat/backends/__init__.py @@ -0,0 +1,13 @@ +# This file is dual licensed under the terms of the Apache License, Version +# 2.0, and the BSD License. See the LICENSE file in the root of this repository +# for complete details. + +from __future__ import annotations + +from typing import Any + + +def default_backend() -> Any: + from cryptography.hazmat.backends.openssl.backend import backend + + return backend diff --git a/.venv/Lib/site-packages/cryptography/hazmat/backends/__pycache__/__init__.cpython-311.pyc b/.venv/Lib/site-packages/cryptography/hazmat/backends/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 00000000..1233a821 Binary files /dev/null and b/.venv/Lib/site-packages/cryptography/hazmat/backends/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/cryptography/hazmat/backends/openssl/__init__.py b/.venv/Lib/site-packages/cryptography/hazmat/backends/openssl/__init__.py new file mode 100644 index 00000000..51b04476 --- /dev/null +++ b/.venv/Lib/site-packages/cryptography/hazmat/backends/openssl/__init__.py @@ -0,0 +1,9 @@ +# This file is dual licensed under the terms of the Apache License, Version +# 2.0, and the BSD License. See the LICENSE file in the root of this repository +# for complete details. + +from __future__ import annotations + +from cryptography.hazmat.backends.openssl.backend import backend + +__all__ = ["backend"] diff --git a/.venv/Lib/site-packages/cryptography/hazmat/backends/openssl/__pycache__/__init__.cpython-311.pyc b/.venv/Lib/site-packages/cryptography/hazmat/backends/openssl/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 00000000..21309476 Binary files /dev/null and b/.venv/Lib/site-packages/cryptography/hazmat/backends/openssl/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/cryptography/hazmat/backends/openssl/__pycache__/aead.cpython-311.pyc b/.venv/Lib/site-packages/cryptography/hazmat/backends/openssl/__pycache__/aead.cpython-311.pyc new file mode 100644 index 00000000..5c9847a0 Binary files /dev/null and b/.venv/Lib/site-packages/cryptography/hazmat/backends/openssl/__pycache__/aead.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/cryptography/hazmat/backends/openssl/__pycache__/backend.cpython-311.pyc b/.venv/Lib/site-packages/cryptography/hazmat/backends/openssl/__pycache__/backend.cpython-311.pyc new file mode 100644 index 00000000..eff54f3d Binary files /dev/null and b/.venv/Lib/site-packages/cryptography/hazmat/backends/openssl/__pycache__/backend.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/cryptography/hazmat/backends/openssl/__pycache__/ciphers.cpython-311.pyc b/.venv/Lib/site-packages/cryptography/hazmat/backends/openssl/__pycache__/ciphers.cpython-311.pyc new file mode 100644 index 00000000..887c2e3a Binary files /dev/null and b/.venv/Lib/site-packages/cryptography/hazmat/backends/openssl/__pycache__/ciphers.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/cryptography/hazmat/backends/openssl/__pycache__/decode_asn1.cpython-311.pyc b/.venv/Lib/site-packages/cryptography/hazmat/backends/openssl/__pycache__/decode_asn1.cpython-311.pyc new file mode 100644 index 00000000..2c43b3a5 Binary files /dev/null and b/.venv/Lib/site-packages/cryptography/hazmat/backends/openssl/__pycache__/decode_asn1.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/cryptography/hazmat/backends/openssl/aead.py b/.venv/Lib/site-packages/cryptography/hazmat/backends/openssl/aead.py new file mode 100644 index 00000000..f1d99010 --- /dev/null +++ b/.venv/Lib/site-packages/cryptography/hazmat/backends/openssl/aead.py @@ -0,0 +1,272 @@ +# This file is dual licensed under the terms of the Apache License, Version +# 2.0, and the BSD License. See the LICENSE file in the root of this repository +# for complete details. + +from __future__ import annotations + +import typing + +from cryptography.exceptions import InvalidTag + +if typing.TYPE_CHECKING: + from cryptography.hazmat.backends.openssl.backend import Backend + from cryptography.hazmat.primitives.ciphers.aead import ( + AESCCM, + AESGCM, + ) + + _AEADTypes = typing.Union[AESCCM, AESGCM] + + +def _aead_cipher_supported(backend: Backend, cipher: _AEADTypes) -> bool: + cipher_name = _evp_cipher_cipher_name(cipher) + + return backend._lib.EVP_get_cipherbyname(cipher_name) != backend._ffi.NULL + + +def _encrypt( + backend: Backend, + cipher: _AEADTypes, + nonce: bytes, + data: bytes, + associated_data: list[bytes], + tag_length: int, +) -> bytes: + return _evp_cipher_encrypt( + backend, cipher, nonce, data, associated_data, tag_length + ) + + +def _decrypt( + backend: Backend, + cipher: _AEADTypes, + nonce: bytes, + data: bytes, + associated_data: list[bytes], + tag_length: int, +) -> bytes: + return _evp_cipher_decrypt( + backend, cipher, nonce, data, associated_data, tag_length + ) + + +_ENCRYPT = 1 +_DECRYPT = 0 + + +def _evp_cipher_cipher_name(cipher: _AEADTypes) -> bytes: + from cryptography.hazmat.primitives.ciphers.aead import ( + AESCCM, + AESGCM, + ) + + if isinstance(cipher, AESCCM): + return f"aes-{len(cipher._key) * 8}-ccm".encode("ascii") + else: + assert isinstance(cipher, AESGCM) + return f"aes-{len(cipher._key) * 8}-gcm".encode("ascii") + + +def _evp_cipher(cipher_name: bytes, backend: Backend): + evp_cipher = backend._lib.EVP_get_cipherbyname(cipher_name) + backend.openssl_assert(evp_cipher != backend._ffi.NULL) + return evp_cipher + + +def _evp_cipher_aead_setup( + backend: Backend, + cipher_name: bytes, + key: bytes, + nonce: bytes, + tag: bytes | None, + tag_len: int, + operation: int, +): + evp_cipher = _evp_cipher(cipher_name, backend) + ctx = backend._lib.EVP_CIPHER_CTX_new() + ctx = backend._ffi.gc(ctx, backend._lib.EVP_CIPHER_CTX_free) + res = backend._lib.EVP_CipherInit_ex( + ctx, + evp_cipher, + backend._ffi.NULL, + backend._ffi.NULL, + backend._ffi.NULL, + int(operation == _ENCRYPT), + ) + backend.openssl_assert(res != 0) + # CCM requires the IVLEN to be set before calling SET_TAG on decrypt + res = backend._lib.EVP_CIPHER_CTX_ctrl( + ctx, + backend._lib.EVP_CTRL_AEAD_SET_IVLEN, + len(nonce), + backend._ffi.NULL, + ) + backend.openssl_assert(res != 0) + if operation == _DECRYPT: + assert tag is not None + _evp_cipher_set_tag(backend, ctx, tag) + elif cipher_name.endswith(b"-ccm"): + res = backend._lib.EVP_CIPHER_CTX_ctrl( + ctx, + backend._lib.EVP_CTRL_AEAD_SET_TAG, + tag_len, + backend._ffi.NULL, + ) + backend.openssl_assert(res != 0) + + nonce_ptr = backend._ffi.from_buffer(nonce) + key_ptr = backend._ffi.from_buffer(key) + res = backend._lib.EVP_CipherInit_ex( + ctx, + backend._ffi.NULL, + backend._ffi.NULL, + key_ptr, + nonce_ptr, + int(operation == _ENCRYPT), + ) + backend.openssl_assert(res != 0) + return ctx + + +def _evp_cipher_set_tag(backend, ctx, tag: bytes) -> None: + tag_ptr = backend._ffi.from_buffer(tag) + res = backend._lib.EVP_CIPHER_CTX_ctrl( + ctx, backend._lib.EVP_CTRL_AEAD_SET_TAG, len(tag), tag_ptr + ) + backend.openssl_assert(res != 0) + + +def _evp_cipher_set_length(backend: Backend, ctx, data_len: int) -> None: + intptr = backend._ffi.new("int *") + res = backend._lib.EVP_CipherUpdate( + ctx, backend._ffi.NULL, intptr, backend._ffi.NULL, data_len + ) + backend.openssl_assert(res != 0) + + +def _evp_cipher_process_aad( + backend: Backend, ctx, associated_data: bytes +) -> None: + outlen = backend._ffi.new("int *") + a_data_ptr = backend._ffi.from_buffer(associated_data) + res = backend._lib.EVP_CipherUpdate( + ctx, backend._ffi.NULL, outlen, a_data_ptr, len(associated_data) + ) + backend.openssl_assert(res != 0) + + +def _evp_cipher_process_data(backend: Backend, ctx, data: bytes) -> bytes: + outlen = backend._ffi.new("int *") + buf = backend._ffi.new("unsigned char[]", len(data)) + data_ptr = backend._ffi.from_buffer(data) + res = backend._lib.EVP_CipherUpdate(ctx, buf, outlen, data_ptr, len(data)) + backend.openssl_assert(res != 0) + return backend._ffi.buffer(buf, outlen[0])[:] + + +def _evp_cipher_encrypt( + backend: Backend, + cipher: _AEADTypes, + nonce: bytes, + data: bytes, + associated_data: list[bytes], + tag_length: int, +) -> bytes: + from cryptography.hazmat.primitives.ciphers.aead import AESCCM + + cipher_name = _evp_cipher_cipher_name(cipher) + ctx = _evp_cipher_aead_setup( + backend, + cipher_name, + cipher._key, + nonce, + None, + tag_length, + _ENCRYPT, + ) + + # CCM requires us to pass the length of the data before processing + # anything. + # However calling this with any other AEAD results in an error + if isinstance(cipher, AESCCM): + _evp_cipher_set_length(backend, ctx, len(data)) + + for ad in associated_data: + _evp_cipher_process_aad(backend, ctx, ad) + processed_data = _evp_cipher_process_data(backend, ctx, data) + outlen = backend._ffi.new("int *") + # All AEADs we support besides OCB are streaming so they return nothing + # in finalization. OCB can return up to (16 byte block - 1) bytes so + # we need a buffer here too. + buf = backend._ffi.new("unsigned char[]", 16) + res = backend._lib.EVP_CipherFinal_ex(ctx, buf, outlen) + backend.openssl_assert(res != 0) + processed_data += backend._ffi.buffer(buf, outlen[0])[:] + tag_buf = backend._ffi.new("unsigned char[]", tag_length) + res = backend._lib.EVP_CIPHER_CTX_ctrl( + ctx, backend._lib.EVP_CTRL_AEAD_GET_TAG, tag_length, tag_buf + ) + backend.openssl_assert(res != 0) + tag = backend._ffi.buffer(tag_buf)[:] + + return processed_data + tag + + +def _evp_cipher_decrypt( + backend: Backend, + cipher: _AEADTypes, + nonce: bytes, + data: bytes, + associated_data: list[bytes], + tag_length: int, +) -> bytes: + from cryptography.hazmat.primitives.ciphers.aead import AESCCM + + if len(data) < tag_length: + raise InvalidTag + + tag = data[-tag_length:] + data = data[:-tag_length] + cipher_name = _evp_cipher_cipher_name(cipher) + ctx = _evp_cipher_aead_setup( + backend, + cipher_name, + cipher._key, + nonce, + tag, + tag_length, + _DECRYPT, + ) + + # CCM requires us to pass the length of the data before processing + # anything. + # However calling this with any other AEAD results in an error + if isinstance(cipher, AESCCM): + _evp_cipher_set_length(backend, ctx, len(data)) + + for ad in associated_data: + _evp_cipher_process_aad(backend, ctx, ad) + # CCM has a different error path if the tag doesn't match. Errors are + # raised in Update and Final is irrelevant. + if isinstance(cipher, AESCCM): + outlen = backend._ffi.new("int *") + buf = backend._ffi.new("unsigned char[]", len(data)) + d_ptr = backend._ffi.from_buffer(data) + res = backend._lib.EVP_CipherUpdate(ctx, buf, outlen, d_ptr, len(data)) + if res != 1: + backend._consume_errors() + raise InvalidTag + + processed_data = backend._ffi.buffer(buf, outlen[0])[:] + else: + processed_data = _evp_cipher_process_data(backend, ctx, data) + outlen = backend._ffi.new("int *") + # OCB can return up to 15 bytes (16 byte block - 1) in finalization + buf = backend._ffi.new("unsigned char[]", 16) + res = backend._lib.EVP_CipherFinal_ex(ctx, buf, outlen) + processed_data += backend._ffi.buffer(buf, outlen[0])[:] + if res == 0: + backend._consume_errors() + raise InvalidTag + + return processed_data diff --git a/.venv/Lib/site-packages/cryptography/hazmat/backends/openssl/backend.py b/.venv/Lib/site-packages/cryptography/hazmat/backends/openssl/backend.py new file mode 100644 index 00000000..5d9eb276 --- /dev/null +++ b/.venv/Lib/site-packages/cryptography/hazmat/backends/openssl/backend.py @@ -0,0 +1,897 @@ +# This file is dual licensed under the terms of the Apache License, Version +# 2.0, and the BSD License. See the LICENSE file in the root of this repository +# for complete details. + +from __future__ import annotations + +import collections +import contextlib +import itertools +import typing + +from cryptography import utils, x509 +from cryptography.exceptions import UnsupportedAlgorithm +from cryptography.hazmat.backends.openssl import aead +from cryptography.hazmat.backends.openssl.ciphers import _CipherContext +from cryptography.hazmat.bindings._rust import openssl as rust_openssl +from cryptography.hazmat.bindings.openssl import binding +from cryptography.hazmat.primitives import hashes, serialization +from cryptography.hazmat.primitives._asymmetric import AsymmetricPadding +from cryptography.hazmat.primitives.asymmetric import ec +from cryptography.hazmat.primitives.asymmetric import utils as asym_utils +from cryptography.hazmat.primitives.asymmetric.padding import ( + MGF1, + OAEP, + PSS, + PKCS1v15, +) +from cryptography.hazmat.primitives.asymmetric.types import ( + PrivateKeyTypes, +) +from cryptography.hazmat.primitives.ciphers import ( + CipherAlgorithm, +) +from cryptography.hazmat.primitives.ciphers.algorithms import ( + AES, + AES128, + AES256, + ARC4, + SM4, + Camellia, + ChaCha20, + TripleDES, + _BlowfishInternal, + _CAST5Internal, + _IDEAInternal, + _SEEDInternal, +) +from cryptography.hazmat.primitives.ciphers.modes import ( + CBC, + CFB, + CFB8, + CTR, + ECB, + GCM, + OFB, + XTS, + Mode, +) +from cryptography.hazmat.primitives.serialization.pkcs12 import ( + PBES, + PKCS12Certificate, + PKCS12KeyAndCertificates, + PKCS12PrivateKeyTypes, + _PKCS12CATypes, +) + +_MemoryBIO = collections.namedtuple("_MemoryBIO", ["bio", "char_ptr"]) + + +# Not actually supported, just used as a marker for some serialization tests. +class _RC2: + pass + + +class Backend: + """ + OpenSSL API binding interfaces. + """ + + name = "openssl" + + # FIPS has opinions about acceptable algorithms and key sizes, but the + # disallowed algorithms are still present in OpenSSL. They just error if + # you try to use them. To avoid that we allowlist the algorithms in + # FIPS 140-3. This isn't ideal, but FIPS 140-3 is trash so here we are. + _fips_aead: typing.ClassVar[set[bytes]] = { + b"aes-128-ccm", + b"aes-192-ccm", + b"aes-256-ccm", + b"aes-128-gcm", + b"aes-192-gcm", + b"aes-256-gcm", + } + # TripleDES encryption is disallowed/deprecated throughout 2023 in + # FIPS 140-3. To keep it simple we denylist any use of TripleDES (TDEA). + _fips_ciphers = (AES,) + # Sometimes SHA1 is still permissible. That logic is contained + # within the various *_supported methods. + _fips_hashes = ( + hashes.SHA224, + hashes.SHA256, + hashes.SHA384, + hashes.SHA512, + hashes.SHA512_224, + hashes.SHA512_256, + hashes.SHA3_224, + hashes.SHA3_256, + hashes.SHA3_384, + hashes.SHA3_512, + hashes.SHAKE128, + hashes.SHAKE256, + ) + _fips_ecdh_curves = ( + ec.SECP224R1, + ec.SECP256R1, + ec.SECP384R1, + ec.SECP521R1, + ) + _fips_rsa_min_key_size = 2048 + _fips_rsa_min_public_exponent = 65537 + _fips_dsa_min_modulus = 1 << 2048 + _fips_dh_min_key_size = 2048 + _fips_dh_min_modulus = 1 << _fips_dh_min_key_size + + def __init__(self) -> None: + self._binding = binding.Binding() + self._ffi = self._binding.ffi + self._lib = self._binding.lib + self._fips_enabled = rust_openssl.is_fips_enabled() + + self._cipher_registry: dict[ + tuple[type[CipherAlgorithm], type[Mode]], + typing.Callable, + ] = {} + self._register_default_ciphers() + + def __repr__(self) -> str: + return "".format( + self.openssl_version_text(), + self._fips_enabled, + self._binding._legacy_provider_loaded, + ) + + def openssl_assert( + self, + ok: bool, + errors: list[rust_openssl.OpenSSLError] | None = None, + ) -> None: + return binding._openssl_assert(ok, errors=errors) + + def _enable_fips(self) -> None: + # This function enables FIPS mode for OpenSSL 3.0.0 on installs that + # have the FIPS provider installed properly. + self._binding._enable_fips() + assert rust_openssl.is_fips_enabled() + self._fips_enabled = rust_openssl.is_fips_enabled() + + def openssl_version_text(self) -> str: + """ + Friendly string name of the loaded OpenSSL library. This is not + necessarily the same version as it was compiled against. + + Example: OpenSSL 1.1.1d 10 Sep 2019 + """ + return self._ffi.string( + self._lib.OpenSSL_version(self._lib.OPENSSL_VERSION) + ).decode("ascii") + + def openssl_version_number(self) -> int: + return self._lib.OpenSSL_version_num() + + def _evp_md_from_algorithm(self, algorithm: hashes.HashAlgorithm): + if algorithm.name in ("blake2b", "blake2s"): + alg = f"{algorithm.name}{algorithm.digest_size * 8}".encode( + "ascii" + ) + else: + alg = algorithm.name.encode("ascii") + + evp_md = self._lib.EVP_get_digestbyname(alg) + return evp_md + + def _evp_md_non_null_from_algorithm(self, algorithm: hashes.HashAlgorithm): + evp_md = self._evp_md_from_algorithm(algorithm) + self.openssl_assert(evp_md != self._ffi.NULL) + return evp_md + + def hash_supported(self, algorithm: hashes.HashAlgorithm) -> bool: + if self._fips_enabled and not isinstance(algorithm, self._fips_hashes): + return False + + evp_md = self._evp_md_from_algorithm(algorithm) + return evp_md != self._ffi.NULL + + def signature_hash_supported( + self, algorithm: hashes.HashAlgorithm + ) -> bool: + # Dedicated check for hashing algorithm use in message digest for + # signatures, e.g. RSA PKCS#1 v1.5 SHA1 (sha1WithRSAEncryption). + if self._fips_enabled and isinstance(algorithm, hashes.SHA1): + return False + return self.hash_supported(algorithm) + + def scrypt_supported(self) -> bool: + if self._fips_enabled: + return False + else: + return self._lib.Cryptography_HAS_SCRYPT == 1 + + def hmac_supported(self, algorithm: hashes.HashAlgorithm) -> bool: + # FIPS mode still allows SHA1 for HMAC + if self._fips_enabled and isinstance(algorithm, hashes.SHA1): + return True + + return self.hash_supported(algorithm) + + def cipher_supported(self, cipher: CipherAlgorithm, mode: Mode) -> bool: + if self._fips_enabled: + # FIPS mode requires AES. TripleDES is disallowed/deprecated in + # FIPS 140-3. + if not isinstance(cipher, self._fips_ciphers): + return False + + try: + adapter = self._cipher_registry[type(cipher), type(mode)] + except KeyError: + return False + evp_cipher = adapter(self, cipher, mode) + return self._ffi.NULL != evp_cipher + + def register_cipher_adapter(self, cipher_cls, mode_cls, adapter) -> None: + if (cipher_cls, mode_cls) in self._cipher_registry: + raise ValueError( + f"Duplicate registration for: {cipher_cls} {mode_cls}." + ) + self._cipher_registry[cipher_cls, mode_cls] = adapter + + def _register_default_ciphers(self) -> None: + for cipher_cls in [AES, AES128, AES256]: + for mode_cls in [CBC, CTR, ECB, OFB, CFB, CFB8, GCM]: + self.register_cipher_adapter( + cipher_cls, + mode_cls, + GetCipherByName( + "{cipher.name}-{cipher.key_size}-{mode.name}" + ), + ) + for mode_cls in [CBC, CTR, ECB, OFB, CFB]: + self.register_cipher_adapter( + Camellia, + mode_cls, + GetCipherByName("{cipher.name}-{cipher.key_size}-{mode.name}"), + ) + for mode_cls in [CBC, CFB, CFB8, OFB]: + self.register_cipher_adapter( + TripleDES, mode_cls, GetCipherByName("des-ede3-{mode.name}") + ) + self.register_cipher_adapter( + TripleDES, ECB, GetCipherByName("des-ede3") + ) + # ChaCha20 uses the Long Name "chacha20" in OpenSSL, but in LibreSSL + # it uses "chacha" + self.register_cipher_adapter( + ChaCha20, + type(None), + GetCipherByName( + "chacha" if self._lib.CRYPTOGRAPHY_IS_LIBRESSL else "chacha20" + ), + ) + self.register_cipher_adapter(AES, XTS, _get_xts_cipher) + for mode_cls in [ECB, CBC, OFB, CFB, CTR, GCM]: + self.register_cipher_adapter( + SM4, mode_cls, GetCipherByName("sm4-{mode.name}") + ) + # Don't register legacy ciphers if they're unavailable. Hypothetically + # this wouldn't be necessary because we test availability by seeing if + # we get an EVP_CIPHER * in the _CipherContext __init__, but OpenSSL 3 + # will return a valid pointer even though the cipher is unavailable. + if ( + self._binding._legacy_provider_loaded + or not self._lib.CRYPTOGRAPHY_OPENSSL_300_OR_GREATER + ): + for mode_cls in [CBC, CFB, OFB, ECB]: + self.register_cipher_adapter( + _BlowfishInternal, + mode_cls, + GetCipherByName("bf-{mode.name}"), + ) + for mode_cls in [CBC, CFB, OFB, ECB]: + self.register_cipher_adapter( + _SEEDInternal, + mode_cls, + GetCipherByName("seed-{mode.name}"), + ) + for cipher_cls, mode_cls in itertools.product( + [_CAST5Internal, _IDEAInternal], + [CBC, OFB, CFB, ECB], + ): + self.register_cipher_adapter( + cipher_cls, + mode_cls, + GetCipherByName("{cipher.name}-{mode.name}"), + ) + self.register_cipher_adapter( + ARC4, type(None), GetCipherByName("rc4") + ) + # We don't actually support RC2, this is just used by some tests. + self.register_cipher_adapter( + _RC2, type(None), GetCipherByName("rc2") + ) + + def create_symmetric_encryption_ctx( + self, cipher: CipherAlgorithm, mode: Mode + ) -> _CipherContext: + return _CipherContext(self, cipher, mode, _CipherContext._ENCRYPT) + + def create_symmetric_decryption_ctx( + self, cipher: CipherAlgorithm, mode: Mode + ) -> _CipherContext: + return _CipherContext(self, cipher, mode, _CipherContext._DECRYPT) + + def pbkdf2_hmac_supported(self, algorithm: hashes.HashAlgorithm) -> bool: + return self.hmac_supported(algorithm) + + def _consume_errors(self) -> list[rust_openssl.OpenSSLError]: + return rust_openssl.capture_error_stack() + + def generate_rsa_parameters_supported( + self, public_exponent: int, key_size: int + ) -> bool: + return ( + public_exponent >= 3 + and public_exponent & 1 != 0 + and key_size >= 512 + ) + + def _bytes_to_bio(self, data: bytes) -> _MemoryBIO: + """ + Return a _MemoryBIO namedtuple of (BIO, char*). + + The char* is the storage for the BIO and it must stay alive until the + BIO is finished with. + """ + data_ptr = self._ffi.from_buffer(data) + bio = self._lib.BIO_new_mem_buf(data_ptr, len(data)) + self.openssl_assert(bio != self._ffi.NULL) + + return _MemoryBIO(self._ffi.gc(bio, self._lib.BIO_free), data_ptr) + + def _create_mem_bio_gc(self): + """ + Creates an empty memory BIO. + """ + bio_method = self._lib.BIO_s_mem() + self.openssl_assert(bio_method != self._ffi.NULL) + bio = self._lib.BIO_new(bio_method) + self.openssl_assert(bio != self._ffi.NULL) + bio = self._ffi.gc(bio, self._lib.BIO_free) + return bio + + def _read_mem_bio(self, bio) -> bytes: + """ + Reads a memory BIO. This only works on memory BIOs. + """ + buf = self._ffi.new("char **") + buf_len = self._lib.BIO_get_mem_data(bio, buf) + self.openssl_assert(buf_len > 0) + self.openssl_assert(buf[0] != self._ffi.NULL) + bio_data = self._ffi.buffer(buf[0], buf_len)[:] + return bio_data + + def _oaep_hash_supported(self, algorithm: hashes.HashAlgorithm) -> bool: + if self._fips_enabled and isinstance(algorithm, hashes.SHA1): + return False + + return isinstance( + algorithm, + ( + hashes.SHA1, + hashes.SHA224, + hashes.SHA256, + hashes.SHA384, + hashes.SHA512, + ), + ) + + def rsa_padding_supported(self, padding: AsymmetricPadding) -> bool: + if isinstance(padding, PKCS1v15): + return True + elif isinstance(padding, PSS) and isinstance(padding._mgf, MGF1): + # SHA1 is permissible in MGF1 in FIPS even when SHA1 is blocked + # as signature algorithm. + if self._fips_enabled and isinstance( + padding._mgf._algorithm, hashes.SHA1 + ): + return True + else: + return self.hash_supported(padding._mgf._algorithm) + elif isinstance(padding, OAEP) and isinstance(padding._mgf, MGF1): + return self._oaep_hash_supported( + padding._mgf._algorithm + ) and self._oaep_hash_supported(padding._algorithm) + else: + return False + + def rsa_encryption_supported(self, padding: AsymmetricPadding) -> bool: + if self._fips_enabled and isinstance(padding, PKCS1v15): + return False + else: + return self.rsa_padding_supported(padding) + + def dsa_supported(self) -> bool: + return ( + not self._lib.CRYPTOGRAPHY_IS_BORINGSSL and not self._fips_enabled + ) + + def dsa_hash_supported(self, algorithm: hashes.HashAlgorithm) -> bool: + if not self.dsa_supported(): + return False + return self.signature_hash_supported(algorithm) + + def cmac_algorithm_supported(self, algorithm) -> bool: + return self.cipher_supported( + algorithm, CBC(b"\x00" * algorithm.block_size) + ) + + def _cert2ossl(self, cert: x509.Certificate) -> typing.Any: + data = cert.public_bytes(serialization.Encoding.DER) + mem_bio = self._bytes_to_bio(data) + x509 = self._lib.d2i_X509_bio(mem_bio.bio, self._ffi.NULL) + self.openssl_assert(x509 != self._ffi.NULL) + x509 = self._ffi.gc(x509, self._lib.X509_free) + return x509 + + def _ossl2cert(self, x509_ptr: typing.Any) -> x509.Certificate: + bio = self._create_mem_bio_gc() + res = self._lib.i2d_X509_bio(bio, x509_ptr) + self.openssl_assert(res == 1) + return x509.load_der_x509_certificate(self._read_mem_bio(bio)) + + def _key2ossl(self, key: PKCS12PrivateKeyTypes) -> typing.Any: + data = key.private_bytes( + serialization.Encoding.DER, + serialization.PrivateFormat.PKCS8, + serialization.NoEncryption(), + ) + mem_bio = self._bytes_to_bio(data) + + evp_pkey = self._lib.d2i_PrivateKey_bio( + mem_bio.bio, + self._ffi.NULL, + ) + self.openssl_assert(evp_pkey != self._ffi.NULL) + return self._ffi.gc(evp_pkey, self._lib.EVP_PKEY_free) + + def _handle_key_loading_error( + self, errors: list[rust_openssl.OpenSSLError] + ) -> typing.NoReturn: + if not errors: + raise ValueError( + "Could not deserialize key data. The data may be in an " + "incorrect format or it may be encrypted with an unsupported " + "algorithm." + ) + + elif ( + errors[0]._lib_reason_match( + self._lib.ERR_LIB_EVP, self._lib.EVP_R_BAD_DECRYPT + ) + or errors[0]._lib_reason_match( + self._lib.ERR_LIB_PKCS12, + self._lib.PKCS12_R_PKCS12_CIPHERFINAL_ERROR, + ) + or ( + self._lib.Cryptography_HAS_PROVIDERS + and errors[0]._lib_reason_match( + self._lib.ERR_LIB_PROV, + self._lib.PROV_R_BAD_DECRYPT, + ) + ) + ): + raise ValueError("Bad decrypt. Incorrect password?") + + elif any( + error._lib_reason_match( + self._lib.ERR_LIB_EVP, + self._lib.EVP_R_UNSUPPORTED_PRIVATE_KEY_ALGORITHM, + ) + for error in errors + ): + raise ValueError("Unsupported public key algorithm.") + + else: + raise ValueError( + "Could not deserialize key data. The data may be in an " + "incorrect format, it may be encrypted with an unsupported " + "algorithm, or it may be an unsupported key type (e.g. EC " + "curves with explicit parameters).", + errors, + ) + + def elliptic_curve_supported(self, curve: ec.EllipticCurve) -> bool: + if self._fips_enabled and not isinstance( + curve, self._fips_ecdh_curves + ): + return False + + return rust_openssl.ec.curve_supported(curve) + + def elliptic_curve_signature_algorithm_supported( + self, + signature_algorithm: ec.EllipticCurveSignatureAlgorithm, + curve: ec.EllipticCurve, + ) -> bool: + # We only support ECDSA right now. + if not isinstance(signature_algorithm, ec.ECDSA): + return False + + return self.elliptic_curve_supported(curve) and ( + isinstance(signature_algorithm.algorithm, asym_utils.Prehashed) + or self.hash_supported(signature_algorithm.algorithm) + ) + + def elliptic_curve_exchange_algorithm_supported( + self, algorithm: ec.ECDH, curve: ec.EllipticCurve + ) -> bool: + return self.elliptic_curve_supported(curve) and isinstance( + algorithm, ec.ECDH + ) + + def dh_supported(self) -> bool: + return not self._lib.CRYPTOGRAPHY_IS_BORINGSSL + + def dh_x942_serialization_supported(self) -> bool: + return self._lib.Cryptography_HAS_EVP_PKEY_DHX == 1 + + def x25519_supported(self) -> bool: + # Beginning with OpenSSL 3.2.0, X25519 is considered FIPS. + if ( + self._fips_enabled + and not self._lib.CRYPTOGRAPHY_OPENSSL_320_OR_GREATER + ): + return False + return True + + def x448_supported(self) -> bool: + # Beginning with OpenSSL 3.2.0, X448 is considered FIPS. + if ( + self._fips_enabled + and not self._lib.CRYPTOGRAPHY_OPENSSL_320_OR_GREATER + ): + return False + return ( + not self._lib.CRYPTOGRAPHY_IS_LIBRESSL + and not self._lib.CRYPTOGRAPHY_IS_BORINGSSL + ) + + def ed25519_supported(self) -> bool: + if self._fips_enabled: + return False + return True + + def ed448_supported(self) -> bool: + if self._fips_enabled: + return False + return ( + not self._lib.CRYPTOGRAPHY_IS_LIBRESSL + and not self._lib.CRYPTOGRAPHY_IS_BORINGSSL + ) + + def aead_cipher_supported(self, cipher) -> bool: + return aead._aead_cipher_supported(self, cipher) + + def _zero_data(self, data, length: int) -> None: + # We clear things this way because at the moment we're not + # sure of a better way that can guarantee it overwrites the + # memory of a bytearray and doesn't just replace the underlying char *. + for i in range(length): + data[i] = 0 + + @contextlib.contextmanager + def _zeroed_null_terminated_buf(self, data): + """ + This method takes bytes, which can be a bytestring or a mutable + buffer like a bytearray, and yields a null-terminated version of that + data. This is required because PKCS12_parse doesn't take a length with + its password char * and ffi.from_buffer doesn't provide null + termination. So, to support zeroing the data via bytearray we + need to build this ridiculous construct that copies the memory, but + zeroes it after use. + """ + if data is None: + yield self._ffi.NULL + else: + data_len = len(data) + buf = self._ffi.new("char[]", data_len + 1) + self._ffi.memmove(buf, data, data_len) + try: + yield buf + finally: + # Cast to a uint8_t * so we can assign by integer + self._zero_data(self._ffi.cast("uint8_t *", buf), data_len) + + def load_key_and_certificates_from_pkcs12( + self, data: bytes, password: bytes | None + ) -> tuple[ + PrivateKeyTypes | None, + x509.Certificate | None, + list[x509.Certificate], + ]: + pkcs12 = self.load_pkcs12(data, password) + return ( + pkcs12.key, + pkcs12.cert.certificate if pkcs12.cert else None, + [cert.certificate for cert in pkcs12.additional_certs], + ) + + def load_pkcs12( + self, data: bytes, password: bytes | None + ) -> PKCS12KeyAndCertificates: + if password is not None: + utils._check_byteslike("password", password) + + bio = self._bytes_to_bio(data) + p12 = self._lib.d2i_PKCS12_bio(bio.bio, self._ffi.NULL) + if p12 == self._ffi.NULL: + self._consume_errors() + raise ValueError("Could not deserialize PKCS12 data") + + p12 = self._ffi.gc(p12, self._lib.PKCS12_free) + evp_pkey_ptr = self._ffi.new("EVP_PKEY **") + x509_ptr = self._ffi.new("X509 **") + sk_x509_ptr = self._ffi.new("Cryptography_STACK_OF_X509 **") + with self._zeroed_null_terminated_buf(password) as password_buf: + res = self._lib.PKCS12_parse( + p12, password_buf, evp_pkey_ptr, x509_ptr, sk_x509_ptr + ) + if res == 0: + self._consume_errors() + raise ValueError("Invalid password or PKCS12 data") + + cert = None + key = None + additional_certificates = [] + + if evp_pkey_ptr[0] != self._ffi.NULL: + evp_pkey = self._ffi.gc(evp_pkey_ptr[0], self._lib.EVP_PKEY_free) + # We don't support turning off RSA key validation when loading + # PKCS12 keys + key = rust_openssl.keys.private_key_from_ptr( + int(self._ffi.cast("uintptr_t", evp_pkey)), + unsafe_skip_rsa_key_validation=False, + ) + + if x509_ptr[0] != self._ffi.NULL: + x509 = self._ffi.gc(x509_ptr[0], self._lib.X509_free) + cert_obj = self._ossl2cert(x509) + name = None + maybe_name = self._lib.X509_alias_get0(x509, self._ffi.NULL) + if maybe_name != self._ffi.NULL: + name = self._ffi.string(maybe_name) + cert = PKCS12Certificate(cert_obj, name) + + if sk_x509_ptr[0] != self._ffi.NULL: + sk_x509 = self._ffi.gc(sk_x509_ptr[0], self._lib.sk_X509_free) + num = self._lib.sk_X509_num(sk_x509_ptr[0]) + + # In OpenSSL < 3.0.0 PKCS12 parsing reverses the order of the + # certificates. + indices: typing.Iterable[int] + if ( + self._lib.CRYPTOGRAPHY_OPENSSL_300_OR_GREATER + or self._lib.CRYPTOGRAPHY_IS_BORINGSSL + ): + indices = range(num) + else: + indices = reversed(range(num)) + + for i in indices: + x509 = self._lib.sk_X509_value(sk_x509, i) + self.openssl_assert(x509 != self._ffi.NULL) + x509 = self._ffi.gc(x509, self._lib.X509_free) + addl_cert = self._ossl2cert(x509) + addl_name = None + maybe_name = self._lib.X509_alias_get0(x509, self._ffi.NULL) + if maybe_name != self._ffi.NULL: + addl_name = self._ffi.string(maybe_name) + additional_certificates.append( + PKCS12Certificate(addl_cert, addl_name) + ) + + return PKCS12KeyAndCertificates(key, cert, additional_certificates) + + def serialize_key_and_certificates_to_pkcs12( + self, + name: bytes | None, + key: PKCS12PrivateKeyTypes | None, + cert: x509.Certificate | None, + cas: list[_PKCS12CATypes] | None, + encryption_algorithm: serialization.KeySerializationEncryption, + ) -> bytes: + password = None + if name is not None: + utils._check_bytes("name", name) + + if isinstance(encryption_algorithm, serialization.NoEncryption): + nid_cert = -1 + nid_key = -1 + pkcs12_iter = 0 + mac_iter = 0 + mac_alg = self._ffi.NULL + elif isinstance( + encryption_algorithm, serialization.BestAvailableEncryption + ): + # PKCS12 encryption is hopeless trash and can never be fixed. + # OpenSSL 3 supports PBESv2, but Libre and Boring do not, so + # we use PBESv1 with 3DES on the older paths. + if self._lib.CRYPTOGRAPHY_OPENSSL_300_OR_GREATER: + nid_cert = self._lib.NID_aes_256_cbc + nid_key = self._lib.NID_aes_256_cbc + else: + nid_cert = self._lib.NID_pbe_WithSHA1And3_Key_TripleDES_CBC + nid_key = self._lib.NID_pbe_WithSHA1And3_Key_TripleDES_CBC + # At least we can set this higher than OpenSSL's default + pkcs12_iter = 20000 + # mac_iter chosen for compatibility reasons, see: + # https://www.openssl.org/docs/man1.1.1/man3/PKCS12_create.html + # Did we mention how lousy PKCS12 encryption is? + mac_iter = 1 + # MAC algorithm can only be set on OpenSSL 3.0.0+ + mac_alg = self._ffi.NULL + password = encryption_algorithm.password + elif ( + isinstance( + encryption_algorithm, serialization._KeySerializationEncryption + ) + and encryption_algorithm._format + is serialization.PrivateFormat.PKCS12 + ): + # Default to OpenSSL's defaults. Behavior will vary based on the + # version of OpenSSL cryptography is compiled against. + nid_cert = 0 + nid_key = 0 + # Use the default iters we use in best available + pkcs12_iter = 20000 + # See the Best Available comment for why this is 1 + mac_iter = 1 + password = encryption_algorithm.password + keycertalg = encryption_algorithm._key_cert_algorithm + if keycertalg is PBES.PBESv1SHA1And3KeyTripleDESCBC: + nid_cert = self._lib.NID_pbe_WithSHA1And3_Key_TripleDES_CBC + nid_key = self._lib.NID_pbe_WithSHA1And3_Key_TripleDES_CBC + elif keycertalg is PBES.PBESv2SHA256AndAES256CBC: + if not self._lib.CRYPTOGRAPHY_OPENSSL_300_OR_GREATER: + raise UnsupportedAlgorithm( + "PBESv2 is not supported by this version of OpenSSL" + ) + nid_cert = self._lib.NID_aes_256_cbc + nid_key = self._lib.NID_aes_256_cbc + else: + assert keycertalg is None + # We use OpenSSL's defaults + + if encryption_algorithm._hmac_hash is not None: + if not self._lib.Cryptography_HAS_PKCS12_SET_MAC: + raise UnsupportedAlgorithm( + "Setting MAC algorithm is not supported by this " + "version of OpenSSL." + ) + mac_alg = self._evp_md_non_null_from_algorithm( + encryption_algorithm._hmac_hash + ) + self.openssl_assert(mac_alg != self._ffi.NULL) + else: + mac_alg = self._ffi.NULL + + if encryption_algorithm._kdf_rounds is not None: + pkcs12_iter = encryption_algorithm._kdf_rounds + + else: + raise ValueError("Unsupported key encryption type") + + if cas is None or len(cas) == 0: + sk_x509 = self._ffi.NULL + else: + sk_x509 = self._lib.sk_X509_new_null() + sk_x509 = self._ffi.gc(sk_x509, self._lib.sk_X509_free) + + # This list is to keep the x509 values alive until end of function + ossl_cas = [] + for ca in cas: + if isinstance(ca, PKCS12Certificate): + ca_alias = ca.friendly_name + ossl_ca = self._cert2ossl(ca.certificate) + if ca_alias is None: + res = self._lib.X509_alias_set1( + ossl_ca, self._ffi.NULL, -1 + ) + else: + res = self._lib.X509_alias_set1( + ossl_ca, ca_alias, len(ca_alias) + ) + self.openssl_assert(res == 1) + else: + ossl_ca = self._cert2ossl(ca) + ossl_cas.append(ossl_ca) + res = self._lib.sk_X509_push(sk_x509, ossl_ca) + backend.openssl_assert(res >= 1) + + with self._zeroed_null_terminated_buf(password) as password_buf: + with self._zeroed_null_terminated_buf(name) as name_buf: + ossl_cert = self._cert2ossl(cert) if cert else self._ffi.NULL + ossl_pkey = ( + self._key2ossl(key) if key is not None else self._ffi.NULL + ) + + p12 = self._lib.PKCS12_create( + password_buf, + name_buf, + ossl_pkey, + ossl_cert, + sk_x509, + nid_key, + nid_cert, + pkcs12_iter, + mac_iter, + 0, + ) + + if ( + self._lib.Cryptography_HAS_PKCS12_SET_MAC + and mac_alg != self._ffi.NULL + ): + self._lib.PKCS12_set_mac( + p12, + password_buf, + -1, + self._ffi.NULL, + 0, + mac_iter, + mac_alg, + ) + + self.openssl_assert(p12 != self._ffi.NULL) + p12 = self._ffi.gc(p12, self._lib.PKCS12_free) + + bio = self._create_mem_bio_gc() + res = self._lib.i2d_PKCS12_bio(bio, p12) + self.openssl_assert(res > 0) + return self._read_mem_bio(bio) + + def poly1305_supported(self) -> bool: + if self._fips_enabled: + return False + elif ( + self._lib.CRYPTOGRAPHY_IS_BORINGSSL + or self._lib.CRYPTOGRAPHY_IS_LIBRESSL + ): + return True + else: + return self._lib.Cryptography_HAS_POLY1305 == 1 + + def pkcs7_supported(self) -> bool: + return not self._lib.CRYPTOGRAPHY_IS_BORINGSSL + + +class GetCipherByName: + def __init__(self, fmt: str): + self._fmt = fmt + + def __call__(self, backend: Backend, cipher: CipherAlgorithm, mode: Mode): + cipher_name = self._fmt.format(cipher=cipher, mode=mode).lower() + evp_cipher = backend._lib.EVP_get_cipherbyname( + cipher_name.encode("ascii") + ) + + # try EVP_CIPHER_fetch if present + if ( + evp_cipher == backend._ffi.NULL + and backend._lib.Cryptography_HAS_300_EVP_CIPHER + ): + evp_cipher = backend._lib.EVP_CIPHER_fetch( + backend._ffi.NULL, + cipher_name.encode("ascii"), + backend._ffi.NULL, + ) + + backend._consume_errors() + return evp_cipher + + +def _get_xts_cipher(backend: Backend, cipher: AES, mode): + cipher_name = f"aes-{cipher.key_size // 2}-xts" + return backend._lib.EVP_get_cipherbyname(cipher_name.encode("ascii")) + + +backend = Backend() diff --git a/.venv/Lib/site-packages/cryptography/hazmat/backends/openssl/ciphers.py b/.venv/Lib/site-packages/cryptography/hazmat/backends/openssl/ciphers.py new file mode 100644 index 00000000..3916b1a5 --- /dev/null +++ b/.venv/Lib/site-packages/cryptography/hazmat/backends/openssl/ciphers.py @@ -0,0 +1,282 @@ +# This file is dual licensed under the terms of the Apache License, Version +# 2.0, and the BSD License. See the LICENSE file in the root of this repository +# for complete details. + +from __future__ import annotations + +import typing + +from cryptography.exceptions import InvalidTag, UnsupportedAlgorithm, _Reasons +from cryptography.hazmat.primitives import ciphers +from cryptography.hazmat.primitives.ciphers import algorithms, modes + +if typing.TYPE_CHECKING: + from cryptography.hazmat.backends.openssl.backend import Backend + + +class _CipherContext: + _ENCRYPT = 1 + _DECRYPT = 0 + _MAX_CHUNK_SIZE = 2**29 + + def __init__(self, backend: Backend, cipher, mode, operation: int) -> None: + self._backend = backend + self._cipher = cipher + self._mode = mode + self._operation = operation + self._tag: bytes | None = None + + if isinstance(self._cipher, ciphers.BlockCipherAlgorithm): + self._block_size_bytes = self._cipher.block_size // 8 + else: + self._block_size_bytes = 1 + + ctx = self._backend._lib.EVP_CIPHER_CTX_new() + ctx = self._backend._ffi.gc( + ctx, self._backend._lib.EVP_CIPHER_CTX_free + ) + + registry = self._backend._cipher_registry + try: + adapter = registry[type(cipher), type(mode)] + except KeyError: + raise UnsupportedAlgorithm( + "cipher {} in {} mode is not supported " + "by this backend.".format( + cipher.name, mode.name if mode else mode + ), + _Reasons.UNSUPPORTED_CIPHER, + ) + + evp_cipher = adapter(self._backend, cipher, mode) + if evp_cipher == self._backend._ffi.NULL: + msg = f"cipher {cipher.name} " + if mode is not None: + msg += f"in {mode.name} mode " + msg += ( + "is not supported by this backend (Your version of OpenSSL " + "may be too old. Current version: {}.)" + ).format(self._backend.openssl_version_text()) + raise UnsupportedAlgorithm(msg, _Reasons.UNSUPPORTED_CIPHER) + + if isinstance(mode, modes.ModeWithInitializationVector): + iv_nonce = self._backend._ffi.from_buffer( + mode.initialization_vector + ) + elif isinstance(mode, modes.ModeWithTweak): + iv_nonce = self._backend._ffi.from_buffer(mode.tweak) + elif isinstance(mode, modes.ModeWithNonce): + iv_nonce = self._backend._ffi.from_buffer(mode.nonce) + elif isinstance(cipher, algorithms.ChaCha20): + iv_nonce = self._backend._ffi.from_buffer(cipher.nonce) + else: + iv_nonce = self._backend._ffi.NULL + # begin init with cipher and operation type + res = self._backend._lib.EVP_CipherInit_ex( + ctx, + evp_cipher, + self._backend._ffi.NULL, + self._backend._ffi.NULL, + self._backend._ffi.NULL, + operation, + ) + self._backend.openssl_assert(res != 0) + # set the key length to handle variable key ciphers + res = self._backend._lib.EVP_CIPHER_CTX_set_key_length( + ctx, len(cipher.key) + ) + self._backend.openssl_assert(res != 0) + if isinstance(mode, modes.GCM): + res = self._backend._lib.EVP_CIPHER_CTX_ctrl( + ctx, + self._backend._lib.EVP_CTRL_AEAD_SET_IVLEN, + len(iv_nonce), + self._backend._ffi.NULL, + ) + self._backend.openssl_assert(res != 0) + if mode.tag is not None: + res = self._backend._lib.EVP_CIPHER_CTX_ctrl( + ctx, + self._backend._lib.EVP_CTRL_AEAD_SET_TAG, + len(mode.tag), + mode.tag, + ) + self._backend.openssl_assert(res != 0) + self._tag = mode.tag + + # pass key/iv + res = self._backend._lib.EVP_CipherInit_ex( + ctx, + self._backend._ffi.NULL, + self._backend._ffi.NULL, + self._backend._ffi.from_buffer(cipher.key), + iv_nonce, + operation, + ) + + # Check for XTS mode duplicate keys error + errors = self._backend._consume_errors() + lib = self._backend._lib + if res == 0 and ( + ( + not lib.CRYPTOGRAPHY_IS_LIBRESSL + and errors[0]._lib_reason_match( + lib.ERR_LIB_EVP, lib.EVP_R_XTS_DUPLICATED_KEYS + ) + ) + or ( + lib.Cryptography_HAS_PROVIDERS + and errors[0]._lib_reason_match( + lib.ERR_LIB_PROV, lib.PROV_R_XTS_DUPLICATED_KEYS + ) + ) + ): + raise ValueError("In XTS mode duplicated keys are not allowed") + + self._backend.openssl_assert(res != 0, errors=errors) + + # We purposely disable padding here as it's handled higher up in the + # API. + self._backend._lib.EVP_CIPHER_CTX_set_padding(ctx, 0) + self._ctx = ctx + + def update(self, data: bytes) -> bytes: + buf = bytearray(len(data) + self._block_size_bytes - 1) + n = self.update_into(data, buf) + return bytes(buf[:n]) + + def update_into(self, data: bytes, buf: bytes) -> int: + total_data_len = len(data) + if len(buf) < (total_data_len + self._block_size_bytes - 1): + raise ValueError( + "buffer must be at least {} bytes for this payload".format( + len(data) + self._block_size_bytes - 1 + ) + ) + + data_processed = 0 + total_out = 0 + outlen = self._backend._ffi.new("int *") + baseoutbuf = self._backend._ffi.from_buffer(buf, require_writable=True) + baseinbuf = self._backend._ffi.from_buffer(data) + + while data_processed != total_data_len: + outbuf = baseoutbuf + total_out + inbuf = baseinbuf + data_processed + inlen = min(self._MAX_CHUNK_SIZE, total_data_len - data_processed) + + res = self._backend._lib.EVP_CipherUpdate( + self._ctx, outbuf, outlen, inbuf, inlen + ) + if res == 0 and isinstance(self._mode, modes.XTS): + self._backend._consume_errors() + raise ValueError( + "In XTS mode you must supply at least a full block in the " + "first update call. For AES this is 16 bytes." + ) + else: + self._backend.openssl_assert(res != 0) + data_processed += inlen + total_out += outlen[0] + + return total_out + + def finalize(self) -> bytes: + if ( + self._operation == self._DECRYPT + and isinstance(self._mode, modes.ModeWithAuthenticationTag) + and self.tag is None + ): + raise ValueError( + "Authentication tag must be provided when decrypting." + ) + + buf = self._backend._ffi.new("unsigned char[]", self._block_size_bytes) + outlen = self._backend._ffi.new("int *") + res = self._backend._lib.EVP_CipherFinal_ex(self._ctx, buf, outlen) + if res == 0: + errors = self._backend._consume_errors() + + if not errors and isinstance(self._mode, modes.GCM): + raise InvalidTag + + lib = self._backend._lib + self._backend.openssl_assert( + errors[0]._lib_reason_match( + lib.ERR_LIB_EVP, + lib.EVP_R_DATA_NOT_MULTIPLE_OF_BLOCK_LENGTH, + ) + or ( + lib.Cryptography_HAS_PROVIDERS + and errors[0]._lib_reason_match( + lib.ERR_LIB_PROV, + lib.PROV_R_WRONG_FINAL_BLOCK_LENGTH, + ) + ) + or ( + lib.CRYPTOGRAPHY_IS_BORINGSSL + and errors[0].reason + == lib.CIPHER_R_DATA_NOT_MULTIPLE_OF_BLOCK_LENGTH + ), + errors=errors, + ) + raise ValueError( + "The length of the provided data is not a multiple of " + "the block length." + ) + + if ( + isinstance(self._mode, modes.GCM) + and self._operation == self._ENCRYPT + ): + tag_buf = self._backend._ffi.new( + "unsigned char[]", self._block_size_bytes + ) + res = self._backend._lib.EVP_CIPHER_CTX_ctrl( + self._ctx, + self._backend._lib.EVP_CTRL_AEAD_GET_TAG, + self._block_size_bytes, + tag_buf, + ) + self._backend.openssl_assert(res != 0) + self._tag = self._backend._ffi.buffer(tag_buf)[:] + + res = self._backend._lib.EVP_CIPHER_CTX_reset(self._ctx) + self._backend.openssl_assert(res == 1) + return self._backend._ffi.buffer(buf)[: outlen[0]] + + def finalize_with_tag(self, tag: bytes) -> bytes: + tag_len = len(tag) + if tag_len < self._mode._min_tag_length: + raise ValueError( + "Authentication tag must be {} bytes or longer.".format( + self._mode._min_tag_length + ) + ) + elif tag_len > self._block_size_bytes: + raise ValueError( + "Authentication tag cannot be more than {} bytes.".format( + self._block_size_bytes + ) + ) + res = self._backend._lib.EVP_CIPHER_CTX_ctrl( + self._ctx, self._backend._lib.EVP_CTRL_AEAD_SET_TAG, len(tag), tag + ) + self._backend.openssl_assert(res != 0) + self._tag = tag + return self.finalize() + + def authenticate_additional_data(self, data: bytes) -> None: + outlen = self._backend._ffi.new("int *") + res = self._backend._lib.EVP_CipherUpdate( + self._ctx, + self._backend._ffi.NULL, + outlen, + self._backend._ffi.from_buffer(data), + len(data), + ) + self._backend.openssl_assert(res != 0) + + @property + def tag(self) -> bytes | None: + return self._tag diff --git a/.venv/Lib/site-packages/cryptography/hazmat/backends/openssl/decode_asn1.py b/.venv/Lib/site-packages/cryptography/hazmat/backends/openssl/decode_asn1.py new file mode 100644 index 00000000..bf123b62 --- /dev/null +++ b/.venv/Lib/site-packages/cryptography/hazmat/backends/openssl/decode_asn1.py @@ -0,0 +1,32 @@ +# This file is dual licensed under the terms of the Apache License, Version +# 2.0, and the BSD License. See the LICENSE file in the root of this repository +# for complete details. + +from __future__ import annotations + +from cryptography import x509 + +# CRLReason ::= ENUMERATED { +# unspecified (0), +# keyCompromise (1), +# cACompromise (2), +# affiliationChanged (3), +# superseded (4), +# cessationOfOperation (5), +# certificateHold (6), +# -- value 7 is not used +# removeFromCRL (8), +# privilegeWithdrawn (9), +# aACompromise (10) } +_CRL_ENTRY_REASON_ENUM_TO_CODE = { + x509.ReasonFlags.unspecified: 0, + x509.ReasonFlags.key_compromise: 1, + x509.ReasonFlags.ca_compromise: 2, + x509.ReasonFlags.affiliation_changed: 3, + x509.ReasonFlags.superseded: 4, + x509.ReasonFlags.cessation_of_operation: 5, + x509.ReasonFlags.certificate_hold: 6, + x509.ReasonFlags.remove_from_crl: 8, + x509.ReasonFlags.privilege_withdrawn: 9, + x509.ReasonFlags.aa_compromise: 10, +} diff --git a/.venv/Lib/site-packages/cryptography/hazmat/bindings/__init__.py b/.venv/Lib/site-packages/cryptography/hazmat/bindings/__init__.py new file mode 100644 index 00000000..b5093362 --- /dev/null +++ b/.venv/Lib/site-packages/cryptography/hazmat/bindings/__init__.py @@ -0,0 +1,3 @@ +# This file is dual licensed under the terms of the Apache License, Version +# 2.0, and the BSD License. See the LICENSE file in the root of this repository +# for complete details. diff --git a/.venv/Lib/site-packages/cryptography/hazmat/bindings/__pycache__/__init__.cpython-311.pyc b/.venv/Lib/site-packages/cryptography/hazmat/bindings/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 00000000..be334345 Binary files /dev/null and b/.venv/Lib/site-packages/cryptography/hazmat/bindings/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/cryptography/hazmat/bindings/_rust.pyd b/.venv/Lib/site-packages/cryptography/hazmat/bindings/_rust.pyd new file mode 100644 index 00000000..97e43779 Binary files /dev/null and b/.venv/Lib/site-packages/cryptography/hazmat/bindings/_rust.pyd differ diff --git a/.venv/Lib/site-packages/cryptography/hazmat/bindings/_rust/__init__.pyi b/.venv/Lib/site-packages/cryptography/hazmat/bindings/_rust/__init__.pyi new file mode 100644 index 00000000..18a6fb87 --- /dev/null +++ b/.venv/Lib/site-packages/cryptography/hazmat/bindings/_rust/__init__.pyi @@ -0,0 +1,17 @@ +# This file is dual licensed under the terms of the Apache License, Version +# 2.0, and the BSD License. See the LICENSE file in the root of this repository +# for complete details. + +import typing + +def check_pkcs7_padding(data: bytes) -> bool: ... +def check_ansix923_padding(data: bytes) -> bool: ... + +class ObjectIdentifier: + def __init__(self, val: str) -> None: ... + @property + def dotted_string(self) -> str: ... + @property + def _name(self) -> str: ... + +T = typing.TypeVar("T") diff --git a/.venv/Lib/site-packages/cryptography/hazmat/bindings/_rust/_openssl.pyi b/.venv/Lib/site-packages/cryptography/hazmat/bindings/_rust/_openssl.pyi new file mode 100644 index 00000000..80100082 --- /dev/null +++ b/.venv/Lib/site-packages/cryptography/hazmat/bindings/_rust/_openssl.pyi @@ -0,0 +1,8 @@ +# This file is dual licensed under the terms of the Apache License, Version +# 2.0, and the BSD License. See the LICENSE file in the root of this repository +# for complete details. + +import typing + +lib = typing.Any +ffi = typing.Any diff --git a/.venv/Lib/site-packages/cryptography/hazmat/bindings/_rust/asn1.pyi b/.venv/Lib/site-packages/cryptography/hazmat/bindings/_rust/asn1.pyi new file mode 100644 index 00000000..35652c6a --- /dev/null +++ b/.venv/Lib/site-packages/cryptography/hazmat/bindings/_rust/asn1.pyi @@ -0,0 +1,14 @@ +# This file is dual licensed under the terms of the Apache License, Version +# 2.0, and the BSD License. See the LICENSE file in the root of this repository +# for complete details. + +class TestCertificate: + not_after_tag: int + not_before_tag: int + issuer_value_tags: list[int] + subject_value_tags: list[int] + +def decode_dss_signature(signature: bytes) -> tuple[int, int]: ... +def encode_dss_signature(r: int, s: int) -> bytes: ... +def parse_spki_for_data(data: bytes) -> bytes: ... +def test_parse_certificate(data: bytes) -> TestCertificate: ... diff --git a/.venv/Lib/site-packages/cryptography/hazmat/bindings/_rust/exceptions.pyi b/.venv/Lib/site-packages/cryptography/hazmat/bindings/_rust/exceptions.pyi new file mode 100644 index 00000000..09f46b1e --- /dev/null +++ b/.venv/Lib/site-packages/cryptography/hazmat/bindings/_rust/exceptions.pyi @@ -0,0 +1,17 @@ +# This file is dual licensed under the terms of the Apache License, Version +# 2.0, and the BSD License. See the LICENSE file in the root of this repository +# for complete details. + +class _Reasons: + BACKEND_MISSING_INTERFACE: _Reasons + UNSUPPORTED_HASH: _Reasons + UNSUPPORTED_CIPHER: _Reasons + UNSUPPORTED_PADDING: _Reasons + UNSUPPORTED_MGF: _Reasons + UNSUPPORTED_PUBLIC_KEY_ALGORITHM: _Reasons + UNSUPPORTED_ELLIPTIC_CURVE: _Reasons + UNSUPPORTED_SERIALIZATION: _Reasons + UNSUPPORTED_X509: _Reasons + UNSUPPORTED_EXCHANGE_ALGORITHM: _Reasons + UNSUPPORTED_DIFFIE_HELLMAN: _Reasons + UNSUPPORTED_MAC: _Reasons diff --git a/.venv/Lib/site-packages/cryptography/hazmat/bindings/_rust/ocsp.pyi b/.venv/Lib/site-packages/cryptography/hazmat/bindings/_rust/ocsp.pyi new file mode 100644 index 00000000..b15628f8 --- /dev/null +++ b/.venv/Lib/site-packages/cryptography/hazmat/bindings/_rust/ocsp.pyi @@ -0,0 +1,23 @@ +# This file is dual licensed under the terms of the Apache License, Version +# 2.0, and the BSD License. See the LICENSE file in the root of this repository +# for complete details. + +from cryptography.hazmat.primitives import hashes +from cryptography.hazmat.primitives.asymmetric.types import PrivateKeyTypes +from cryptography.x509.ocsp import ( + OCSPRequest, + OCSPRequestBuilder, + OCSPResponse, + OCSPResponseBuilder, + OCSPResponseStatus, +) + +def load_der_ocsp_request(data: bytes) -> OCSPRequest: ... +def load_der_ocsp_response(data: bytes) -> OCSPResponse: ... +def create_ocsp_request(builder: OCSPRequestBuilder) -> OCSPRequest: ... +def create_ocsp_response( + status: OCSPResponseStatus, + builder: OCSPResponseBuilder | None, + private_key: PrivateKeyTypes | None, + hash_algorithm: hashes.HashAlgorithm | None, +) -> OCSPResponse: ... diff --git a/.venv/Lib/site-packages/cryptography/hazmat/bindings/_rust/openssl/__init__.pyi b/.venv/Lib/site-packages/cryptography/hazmat/bindings/_rust/openssl/__init__.pyi new file mode 100644 index 00000000..9cdb4d6a --- /dev/null +++ b/.venv/Lib/site-packages/cryptography/hazmat/bindings/_rust/openssl/__init__.pyi @@ -0,0 +1,57 @@ +# This file is dual licensed under the terms of the Apache License, Version +# 2.0, and the BSD License. See the LICENSE file in the root of this repository +# for complete details. + +import typing + +from cryptography.hazmat.bindings._rust.openssl import ( + aead, + cmac, + dh, + dsa, + ec, + ed448, + ed25519, + hashes, + hmac, + kdf, + keys, + poly1305, + rsa, + x448, + x25519, +) + +__all__ = [ + "openssl_version", + "raise_openssl_error", + "aead", + "cmac", + "dh", + "dsa", + "ec", + "hashes", + "hmac", + "kdf", + "keys", + "ed448", + "ed25519", + "rsa", + "poly1305", + "x448", + "x25519", +] + +def openssl_version() -> int: ... +def raise_openssl_error() -> typing.NoReturn: ... +def capture_error_stack() -> list[OpenSSLError]: ... +def is_fips_enabled() -> bool: ... + +class OpenSSLError: + @property + def lib(self) -> int: ... + @property + def reason(self) -> int: ... + @property + def reason_text(self) -> bytes: ... + def _lib_reason_match(self, lib: int, reason: int) -> bool: ... diff --git a/.venv/Lib/site-packages/cryptography/hazmat/bindings/_rust/openssl/aead.pyi b/.venv/Lib/site-packages/cryptography/hazmat/bindings/_rust/openssl/aead.pyi new file mode 100644 index 00000000..81e801e3 --- /dev/null +++ b/.venv/Lib/site-packages/cryptography/hazmat/bindings/_rust/openssl/aead.pyi @@ -0,0 +1,69 @@ +# This file is dual licensed under the terms of the Apache License, Version +# 2.0, and the BSD License. See the LICENSE file in the root of this repository +# for complete details. + +class ChaCha20Poly1305: + def __init__(self, key: bytes) -> None: ... + @staticmethod + def generate_key() -> bytes: ... + def encrypt( + self, + nonce: bytes, + data: bytes, + associated_data: bytes | None, + ) -> bytes: ... + def decrypt( + self, + nonce: bytes, + data: bytes, + associated_data: bytes | None, + ) -> bytes: ... + +class AESSIV: + def __init__(self, key: bytes) -> None: ... + @staticmethod + def generate_key(key_size: int) -> bytes: ... + def encrypt( + self, + data: bytes, + associated_data: list[bytes] | None, + ) -> bytes: ... + def decrypt( + self, + data: bytes, + associated_data: list[bytes] | None, + ) -> bytes: ... + +class AESOCB3: + def __init__(self, key: bytes) -> None: ... + @staticmethod + def generate_key(key_size: int) -> bytes: ... + def encrypt( + self, + nonce: bytes, + data: bytes, + associated_data: bytes | None, + ) -> bytes: ... + def decrypt( + self, + nonce: bytes, + data: bytes, + associated_data: bytes | None, + ) -> bytes: ... + +class AESGCMSIV: + def __init__(self, key: bytes) -> None: ... + @staticmethod + def generate_key(key_size: int) -> bytes: ... + def encrypt( + self, + nonce: bytes, + data: bytes, + associated_data: bytes | None, + ) -> bytes: ... + def decrypt( + self, + nonce: bytes, + data: bytes, + associated_data: bytes | None, + ) -> bytes: ... diff --git a/.venv/Lib/site-packages/cryptography/hazmat/bindings/_rust/openssl/cmac.pyi b/.venv/Lib/site-packages/cryptography/hazmat/bindings/_rust/openssl/cmac.pyi new file mode 100644 index 00000000..9c03508b --- /dev/null +++ b/.venv/Lib/site-packages/cryptography/hazmat/bindings/_rust/openssl/cmac.pyi @@ -0,0 +1,18 @@ +# This file is dual licensed under the terms of the Apache License, Version +# 2.0, and the BSD License. See the LICENSE file in the root of this repository +# for complete details. + +import typing + +from cryptography.hazmat.primitives import ciphers + +class CMAC: + def __init__( + self, + algorithm: ciphers.BlockCipherAlgorithm, + backend: typing.Any = None, + ) -> None: ... + def update(self, data: bytes) -> None: ... + def finalize(self) -> bytes: ... + def verify(self, signature: bytes) -> None: ... + def copy(self) -> CMAC: ... diff --git a/.venv/Lib/site-packages/cryptography/hazmat/bindings/_rust/openssl/dh.pyi b/.venv/Lib/site-packages/cryptography/hazmat/bindings/_rust/openssl/dh.pyi new file mode 100644 index 00000000..08733d74 --- /dev/null +++ b/.venv/Lib/site-packages/cryptography/hazmat/bindings/_rust/openssl/dh.pyi @@ -0,0 +1,51 @@ +# This file is dual licensed under the terms of the Apache License, Version +# 2.0, and the BSD License. See the LICENSE file in the root of this repository +# for complete details. + +import typing + +from cryptography.hazmat.primitives.asymmetric import dh + +MIN_MODULUS_SIZE: int + +class DHPrivateKey: ... +class DHPublicKey: ... +class DHParameters: ... + +class DHPrivateNumbers: + def __init__(self, x: int, public_numbers: DHPublicNumbers) -> None: ... + def private_key(self, backend: typing.Any = None) -> dh.DHPrivateKey: ... + @property + def x(self) -> int: ... + @property + def public_numbers(self) -> DHPublicNumbers: ... + +class DHPublicNumbers: + def __init__( + self, y: int, parameter_numbers: DHParameterNumbers + ) -> None: ... + def public_key(self, backend: typing.Any = None) -> dh.DHPublicKey: ... + @property + def y(self) -> int: ... + @property + def parameter_numbers(self) -> DHParameterNumbers: ... + +class DHParameterNumbers: + def __init__(self, p: int, g: int, q: int | None = None) -> None: ... + def parameters(self, backend: typing.Any = None) -> dh.DHParameters: ... + @property + def p(self) -> int: ... + @property + def g(self) -> int: ... + @property + def q(self) -> int | None: ... + +def generate_parameters( + generator: int, key_size: int, backend: typing.Any = None +) -> dh.DHParameters: ... +def from_pem_parameters( + data: bytes, backend: typing.Any = None +) -> dh.DHParameters: ... +def from_der_parameters( + data: bytes, backend: typing.Any = None +) -> dh.DHParameters: ... diff --git a/.venv/Lib/site-packages/cryptography/hazmat/bindings/_rust/openssl/dsa.pyi b/.venv/Lib/site-packages/cryptography/hazmat/bindings/_rust/openssl/dsa.pyi new file mode 100644 index 00000000..0922a4c4 --- /dev/null +++ b/.venv/Lib/site-packages/cryptography/hazmat/bindings/_rust/openssl/dsa.pyi @@ -0,0 +1,41 @@ +# This file is dual licensed under the terms of the Apache License, Version +# 2.0, and the BSD License. See the LICENSE file in the root of this repository +# for complete details. + +import typing + +from cryptography.hazmat.primitives.asymmetric import dsa + +class DSAPrivateKey: ... +class DSAPublicKey: ... +class DSAParameters: ... + +class DSAPrivateNumbers: + def __init__(self, x: int, public_numbers: DSAPublicNumbers) -> None: ... + @property + def x(self) -> int: ... + @property + def public_numbers(self) -> DSAPublicNumbers: ... + def private_key(self, backend: typing.Any = None) -> dsa.DSAPrivateKey: ... + +class DSAPublicNumbers: + def __init__( + self, y: int, parameter_numbers: DSAParameterNumbers + ) -> None: ... + @property + def y(self) -> int: ... + @property + def parameter_numbers(self) -> DSAParameterNumbers: ... + def public_key(self, backend: typing.Any = None) -> dsa.DSAPublicKey: ... + +class DSAParameterNumbers: + def __init__(self, p: int, q: int, g: int) -> None: ... + @property + def p(self) -> int: ... + @property + def q(self) -> int: ... + @property + def g(self) -> int: ... + def parameters(self, backend: typing.Any = None) -> dsa.DSAParameters: ... + +def generate_parameters(key_size: int) -> dsa.DSAParameters: ... diff --git a/.venv/Lib/site-packages/cryptography/hazmat/bindings/_rust/openssl/ec.pyi b/.venv/Lib/site-packages/cryptography/hazmat/bindings/_rust/openssl/ec.pyi new file mode 100644 index 00000000..5c3b7bf6 --- /dev/null +++ b/.venv/Lib/site-packages/cryptography/hazmat/bindings/_rust/openssl/ec.pyi @@ -0,0 +1,52 @@ +# This file is dual licensed under the terms of the Apache License, Version +# 2.0, and the BSD License. See the LICENSE file in the root of this repository +# for complete details. + +import typing + +from cryptography.hazmat.primitives.asymmetric import ec + +class ECPrivateKey: ... +class ECPublicKey: ... + +class EllipticCurvePrivateNumbers: + def __init__( + self, private_value: int, public_numbers: EllipticCurvePublicNumbers + ) -> None: ... + def private_key( + self, backend: typing.Any = None + ) -> ec.EllipticCurvePrivateKey: ... + @property + def private_value(self) -> int: ... + @property + def public_numbers(self) -> EllipticCurvePublicNumbers: ... + +class EllipticCurvePublicNumbers: + def __init__(self, x: int, y: int, curve: ec.EllipticCurve) -> None: ... + def public_key( + self, backend: typing.Any = None + ) -> ec.EllipticCurvePublicKey: ... + @property + def x(self) -> int: ... + @property + def y(self) -> int: ... + @property + def curve(self) -> ec.EllipticCurve: ... + def __eq__(self, other: object) -> bool: ... + +def curve_supported(curve: ec.EllipticCurve) -> bool: ... +def generate_private_key( + curve: ec.EllipticCurve, backend: typing.Any = None +) -> ec.EllipticCurvePrivateKey: ... +def from_private_numbers( + numbers: ec.EllipticCurvePrivateNumbers, +) -> ec.EllipticCurvePrivateKey: ... +def from_public_numbers( + numbers: ec.EllipticCurvePublicNumbers, +) -> ec.EllipticCurvePublicKey: ... +def from_public_bytes( + curve: ec.EllipticCurve, data: bytes +) -> ec.EllipticCurvePublicKey: ... +def derive_private_key( + private_value: int, curve: ec.EllipticCurve +) -> ec.EllipticCurvePrivateKey: ... diff --git a/.venv/Lib/site-packages/cryptography/hazmat/bindings/_rust/openssl/ed25519.pyi b/.venv/Lib/site-packages/cryptography/hazmat/bindings/_rust/openssl/ed25519.pyi new file mode 100644 index 00000000..5233f9a1 --- /dev/null +++ b/.venv/Lib/site-packages/cryptography/hazmat/bindings/_rust/openssl/ed25519.pyi @@ -0,0 +1,12 @@ +# This file is dual licensed under the terms of the Apache License, Version +# 2.0, and the BSD License. See the LICENSE file in the root of this repository +# for complete details. + +from cryptography.hazmat.primitives.asymmetric import ed25519 + +class Ed25519PrivateKey: ... +class Ed25519PublicKey: ... + +def generate_key() -> ed25519.Ed25519PrivateKey: ... +def from_private_bytes(data: bytes) -> ed25519.Ed25519PrivateKey: ... +def from_public_bytes(data: bytes) -> ed25519.Ed25519PublicKey: ... diff --git a/.venv/Lib/site-packages/cryptography/hazmat/bindings/_rust/openssl/ed448.pyi b/.venv/Lib/site-packages/cryptography/hazmat/bindings/_rust/openssl/ed448.pyi new file mode 100644 index 00000000..7a065203 --- /dev/null +++ b/.venv/Lib/site-packages/cryptography/hazmat/bindings/_rust/openssl/ed448.pyi @@ -0,0 +1,12 @@ +# This file is dual licensed under the terms of the Apache License, Version +# 2.0, and the BSD License. See the LICENSE file in the root of this repository +# for complete details. + +from cryptography.hazmat.primitives.asymmetric import ed448 + +class Ed448PrivateKey: ... +class Ed448PublicKey: ... + +def generate_key() -> ed448.Ed448PrivateKey: ... +def from_private_bytes(data: bytes) -> ed448.Ed448PrivateKey: ... +def from_public_bytes(data: bytes) -> ed448.Ed448PublicKey: ... diff --git a/.venv/Lib/site-packages/cryptography/hazmat/bindings/_rust/openssl/hashes.pyi b/.venv/Lib/site-packages/cryptography/hazmat/bindings/_rust/openssl/hashes.pyi new file mode 100644 index 00000000..ca5f42a0 --- /dev/null +++ b/.venv/Lib/site-packages/cryptography/hazmat/bindings/_rust/openssl/hashes.pyi @@ -0,0 +1,17 @@ +# This file is dual licensed under the terms of the Apache License, Version +# 2.0, and the BSD License. See the LICENSE file in the root of this repository +# for complete details. + +import typing + +from cryptography.hazmat.primitives import hashes + +class Hash(hashes.HashContext): + def __init__( + self, algorithm: hashes.HashAlgorithm, backend: typing.Any = None + ) -> None: ... + @property + def algorithm(self) -> hashes.HashAlgorithm: ... + def update(self, data: bytes) -> None: ... + def finalize(self) -> bytes: ... + def copy(self) -> Hash: ... diff --git a/.venv/Lib/site-packages/cryptography/hazmat/bindings/_rust/openssl/hmac.pyi b/.venv/Lib/site-packages/cryptography/hazmat/bindings/_rust/openssl/hmac.pyi new file mode 100644 index 00000000..e38d9b54 --- /dev/null +++ b/.venv/Lib/site-packages/cryptography/hazmat/bindings/_rust/openssl/hmac.pyi @@ -0,0 +1,21 @@ +# This file is dual licensed under the terms of the Apache License, Version +# 2.0, and the BSD License. See the LICENSE file in the root of this repository +# for complete details. + +import typing + +from cryptography.hazmat.primitives import hashes + +class HMAC(hashes.HashContext): + def __init__( + self, + key: bytes, + algorithm: hashes.HashAlgorithm, + backend: typing.Any = None, + ) -> None: ... + @property + def algorithm(self) -> hashes.HashAlgorithm: ... + def update(self, data: bytes) -> None: ... + def finalize(self) -> bytes: ... + def verify(self, signature: bytes) -> None: ... + def copy(self) -> HMAC: ... diff --git a/.venv/Lib/site-packages/cryptography/hazmat/bindings/_rust/openssl/kdf.pyi b/.venv/Lib/site-packages/cryptography/hazmat/bindings/_rust/openssl/kdf.pyi new file mode 100644 index 00000000..034a8fed --- /dev/null +++ b/.venv/Lib/site-packages/cryptography/hazmat/bindings/_rust/openssl/kdf.pyi @@ -0,0 +1,22 @@ +# This file is dual licensed under the terms of the Apache License, Version +# 2.0, and the BSD License. See the LICENSE file in the root of this repository +# for complete details. + +from cryptography.hazmat.primitives.hashes import HashAlgorithm + +def derive_pbkdf2_hmac( + key_material: bytes, + algorithm: HashAlgorithm, + salt: bytes, + iterations: int, + length: int, +) -> bytes: ... +def derive_scrypt( + key_material: bytes, + salt: bytes, + n: int, + r: int, + p: int, + max_mem: int, + length: int, +) -> bytes: ... diff --git a/.venv/Lib/site-packages/cryptography/hazmat/bindings/_rust/openssl/keys.pyi b/.venv/Lib/site-packages/cryptography/hazmat/bindings/_rust/openssl/keys.pyi new file mode 100644 index 00000000..e312d51d --- /dev/null +++ b/.venv/Lib/site-packages/cryptography/hazmat/bindings/_rust/openssl/keys.pyi @@ -0,0 +1,37 @@ +# This file is dual licensed under the terms of the Apache License, Version +# 2.0, and the BSD License. See the LICENSE file in the root of this repository +# for complete details. + +import typing + +from cryptography.hazmat.primitives.asymmetric.types import ( + PrivateKeyTypes, + PublicKeyTypes, +) + +def private_key_from_ptr( + ptr: int, + unsafe_skip_rsa_key_validation: bool, +) -> PrivateKeyTypes: ... +def load_der_private_key( + data: bytes, + password: bytes | None, + backend: typing.Any = None, + *, + unsafe_skip_rsa_key_validation: bool = False, +) -> PrivateKeyTypes: ... +def load_pem_private_key( + data: bytes, + password: bytes | None, + backend: typing.Any = None, + *, + unsafe_skip_rsa_key_validation: bool = False, +) -> PrivateKeyTypes: ... +def load_der_public_key( + data: bytes, + backend: typing.Any = None, +) -> PublicKeyTypes: ... +def load_pem_public_key( + data: bytes, + backend: typing.Any = None, +) -> PublicKeyTypes: ... diff --git a/.venv/Lib/site-packages/cryptography/hazmat/bindings/_rust/openssl/poly1305.pyi b/.venv/Lib/site-packages/cryptography/hazmat/bindings/_rust/openssl/poly1305.pyi new file mode 100644 index 00000000..2e9b0a9e --- /dev/null +++ b/.venv/Lib/site-packages/cryptography/hazmat/bindings/_rust/openssl/poly1305.pyi @@ -0,0 +1,13 @@ +# This file is dual licensed under the terms of the Apache License, Version +# 2.0, and the BSD License. See the LICENSE file in the root of this repository +# for complete details. + +class Poly1305: + def __init__(self, key: bytes) -> None: ... + @staticmethod + def generate_tag(key: bytes, data: bytes) -> bytes: ... + @staticmethod + def verify_tag(key: bytes, data: bytes, tag: bytes) -> None: ... + def update(self, data: bytes) -> None: ... + def finalize(self) -> bytes: ... + def verify(self, tag: bytes) -> None: ... diff --git a/.venv/Lib/site-packages/cryptography/hazmat/bindings/_rust/openssl/rsa.pyi b/.venv/Lib/site-packages/cryptography/hazmat/bindings/_rust/openssl/rsa.pyi new file mode 100644 index 00000000..ef7752dd --- /dev/null +++ b/.venv/Lib/site-packages/cryptography/hazmat/bindings/_rust/openssl/rsa.pyi @@ -0,0 +1,55 @@ +# This file is dual licensed under the terms of the Apache License, Version +# 2.0, and the BSD License. See the LICENSE file in the root of this repository +# for complete details. + +import typing + +from cryptography.hazmat.primitives.asymmetric import rsa + +class RSAPrivateKey: ... +class RSAPublicKey: ... + +class RSAPrivateNumbers: + def __init__( + self, + p: int, + q: int, + d: int, + dmp1: int, + dmq1: int, + iqmp: int, + public_numbers: RSAPublicNumbers, + ) -> None: ... + @property + def p(self) -> int: ... + @property + def q(self) -> int: ... + @property + def d(self) -> int: ... + @property + def dmp1(self) -> int: ... + @property + def dmq1(self) -> int: ... + @property + def iqmp(self) -> int: ... + @property + def public_numbers(self) -> RSAPublicNumbers: ... + def private_key( + self, + backend: typing.Any = None, + *, + unsafe_skip_rsa_key_validation: bool = False, + ) -> rsa.RSAPrivateKey: ... + +class RSAPublicNumbers: + def __init__(self, e: int, n: int) -> None: ... + @property + def n(self) -> int: ... + @property + def e(self) -> int: ... + def public_key(self, backend: typing.Any = None) -> rsa.RSAPublicKey: ... + +def generate_private_key( + public_exponent: int, + key_size: int, +) -> rsa.RSAPrivateKey: ... diff --git a/.venv/Lib/site-packages/cryptography/hazmat/bindings/_rust/openssl/x25519.pyi b/.venv/Lib/site-packages/cryptography/hazmat/bindings/_rust/openssl/x25519.pyi new file mode 100644 index 00000000..da0f3ec5 --- /dev/null +++ b/.venv/Lib/site-packages/cryptography/hazmat/bindings/_rust/openssl/x25519.pyi @@ -0,0 +1,12 @@ +# This file is dual licensed under the terms of the Apache License, Version +# 2.0, and the BSD License. See the LICENSE file in the root of this repository +# for complete details. + +from cryptography.hazmat.primitives.asymmetric import x25519 + +class X25519PrivateKey: ... +class X25519PublicKey: ... + +def generate_key() -> x25519.X25519PrivateKey: ... +def from_private_bytes(data: bytes) -> x25519.X25519PrivateKey: ... +def from_public_bytes(data: bytes) -> x25519.X25519PublicKey: ... diff --git a/.venv/Lib/site-packages/cryptography/hazmat/bindings/_rust/openssl/x448.pyi b/.venv/Lib/site-packages/cryptography/hazmat/bindings/_rust/openssl/x448.pyi new file mode 100644 index 00000000..e51cfebe --- /dev/null +++ b/.venv/Lib/site-packages/cryptography/hazmat/bindings/_rust/openssl/x448.pyi @@ -0,0 +1,12 @@ +# This file is dual licensed under the terms of the Apache License, Version +# 2.0, and the BSD License. See the LICENSE file in the root of this repository +# for complete details. + +from cryptography.hazmat.primitives.asymmetric import x448 + +class X448PrivateKey: ... +class X448PublicKey: ... + +def generate_key() -> x448.X448PrivateKey: ... +def from_private_bytes(data: bytes) -> x448.X448PrivateKey: ... +def from_public_bytes(data: bytes) -> x448.X448PublicKey: ... diff --git a/.venv/Lib/site-packages/cryptography/hazmat/bindings/_rust/pkcs7.pyi b/.venv/Lib/site-packages/cryptography/hazmat/bindings/_rust/pkcs7.pyi new file mode 100644 index 00000000..a8497824 --- /dev/null +++ b/.venv/Lib/site-packages/cryptography/hazmat/bindings/_rust/pkcs7.pyi @@ -0,0 +1,21 @@ +import typing + +from cryptography import x509 +from cryptography.hazmat.primitives import serialization +from cryptography.hazmat.primitives.serialization import pkcs7 + +def serialize_certificates( + certs: list[x509.Certificate], + encoding: serialization.Encoding, +) -> bytes: ... +def sign_and_serialize( + builder: pkcs7.PKCS7SignatureBuilder, + encoding: serialization.Encoding, + options: typing.Iterable[pkcs7.PKCS7Options], +) -> bytes: ... +def load_pem_pkcs7_certificates( + data: bytes, +) -> list[x509.Certificate]: ... +def load_der_pkcs7_certificates( + data: bytes, +) -> list[x509.Certificate]: ... diff --git a/.venv/Lib/site-packages/cryptography/hazmat/bindings/_rust/x509.pyi b/.venv/Lib/site-packages/cryptography/hazmat/bindings/_rust/x509.pyi new file mode 100644 index 00000000..418184f8 --- /dev/null +++ b/.venv/Lib/site-packages/cryptography/hazmat/bindings/_rust/x509.pyi @@ -0,0 +1,88 @@ +# This file is dual licensed under the terms of the Apache License, Version +# 2.0, and the BSD License. See the LICENSE file in the root of this repository +# for complete details. + +import datetime +import typing + +from cryptography import x509 +from cryptography.hazmat.primitives import hashes +from cryptography.hazmat.primitives.asymmetric.padding import PSS, PKCS1v15 +from cryptography.hazmat.primitives.asymmetric.types import PrivateKeyTypes + +def load_pem_x509_certificate( + data: bytes, backend: typing.Any = None +) -> x509.Certificate: ... +def load_der_x509_certificate( + data: bytes, backend: typing.Any = None +) -> x509.Certificate: ... +def load_pem_x509_certificates( + data: bytes, +) -> list[x509.Certificate]: ... +def load_pem_x509_crl( + data: bytes, backend: typing.Any = None +) -> x509.CertificateRevocationList: ... +def load_der_x509_crl( + data: bytes, backend: typing.Any = None +) -> x509.CertificateRevocationList: ... +def load_pem_x509_csr( + data: bytes, backend: typing.Any = None +) -> x509.CertificateSigningRequest: ... +def load_der_x509_csr( + data: bytes, backend: typing.Any = None +) -> x509.CertificateSigningRequest: ... +def encode_name_bytes(name: x509.Name) -> bytes: ... +def encode_extension_value(extension: x509.ExtensionType) -> bytes: ... +def create_x509_certificate( + builder: x509.CertificateBuilder, + private_key: PrivateKeyTypes, + hash_algorithm: hashes.HashAlgorithm | None, + rsa_padding: PKCS1v15 | PSS | None, +) -> x509.Certificate: ... +def create_x509_csr( + builder: x509.CertificateSigningRequestBuilder, + private_key: PrivateKeyTypes, + hash_algorithm: hashes.HashAlgorithm | None, + rsa_padding: PKCS1v15 | PSS | None, +) -> x509.CertificateSigningRequest: ... +def create_x509_crl( + builder: x509.CertificateRevocationListBuilder, + private_key: PrivateKeyTypes, + hash_algorithm: hashes.HashAlgorithm | None, + rsa_padding: PKCS1v15 | PSS | None, +) -> x509.CertificateRevocationList: ... + +class Sct: ... +class Certificate: ... +class RevokedCertificate: ... +class CertificateRevocationList: ... +class CertificateSigningRequest: ... + +class PolicyBuilder: + def time(self, new_time: datetime.datetime) -> PolicyBuilder: ... + def store(self, new_store: Store) -> PolicyBuilder: ... + def max_chain_depth(self, new_max_chain_depth: int) -> PolicyBuilder: ... + def build_server_verifier( + self, subject: x509.verification.Subject + ) -> ServerVerifier: ... + +class ServerVerifier: + @property + def subject(self) -> x509.verification.Subject: ... + @property + def validation_time(self) -> datetime.datetime: ... + @property + def store(self) -> Store: ... + @property + def max_chain_depth(self) -> int: ... + def verify( + self, + leaf: x509.Certificate, + intermediates: list[x509.Certificate], + ) -> list[x509.Certificate]: ... + +class Store: + def __init__(self, certs: list[x509.Certificate]) -> None: ... + +class VerificationError(Exception): + pass diff --git a/.venv/Lib/site-packages/cryptography/hazmat/bindings/openssl/__init__.py b/.venv/Lib/site-packages/cryptography/hazmat/bindings/openssl/__init__.py new file mode 100644 index 00000000..b5093362 --- /dev/null +++ b/.venv/Lib/site-packages/cryptography/hazmat/bindings/openssl/__init__.py @@ -0,0 +1,3 @@ +# This file is dual licensed under the terms of the Apache License, Version +# 2.0, and the BSD License. See the LICENSE file in the root of this repository +# for complete details. diff --git a/.venv/Lib/site-packages/cryptography/hazmat/bindings/openssl/__pycache__/__init__.cpython-311.pyc b/.venv/Lib/site-packages/cryptography/hazmat/bindings/openssl/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 00000000..d24749ae Binary files /dev/null and b/.venv/Lib/site-packages/cryptography/hazmat/bindings/openssl/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/cryptography/hazmat/bindings/openssl/__pycache__/_conditional.cpython-311.pyc b/.venv/Lib/site-packages/cryptography/hazmat/bindings/openssl/__pycache__/_conditional.cpython-311.pyc new file mode 100644 index 00000000..a60e57b9 Binary files /dev/null and b/.venv/Lib/site-packages/cryptography/hazmat/bindings/openssl/__pycache__/_conditional.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/cryptography/hazmat/bindings/openssl/__pycache__/binding.cpython-311.pyc b/.venv/Lib/site-packages/cryptography/hazmat/bindings/openssl/__pycache__/binding.cpython-311.pyc new file mode 100644 index 00000000..0c4620c8 Binary files /dev/null and b/.venv/Lib/site-packages/cryptography/hazmat/bindings/openssl/__pycache__/binding.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/cryptography/hazmat/bindings/openssl/_conditional.py b/.venv/Lib/site-packages/cryptography/hazmat/bindings/openssl/_conditional.py new file mode 100644 index 00000000..30cc3bfa --- /dev/null +++ b/.venv/Lib/site-packages/cryptography/hazmat/bindings/openssl/_conditional.py @@ -0,0 +1,233 @@ +# This file is dual licensed under the terms of the Apache License, Version +# 2.0, and the BSD License. See the LICENSE file in the root of this repository +# for complete details. + +from __future__ import annotations + + +def cryptography_has_set_cert_cb() -> list[str]: + return [ + "SSL_CTX_set_cert_cb", + "SSL_set_cert_cb", + ] + + +def cryptography_has_ssl_st() -> list[str]: + return [ + "SSL_ST_BEFORE", + "SSL_ST_OK", + "SSL_ST_INIT", + "SSL_ST_RENEGOTIATE", + ] + + +def cryptography_has_tls_st() -> list[str]: + return [ + "TLS_ST_BEFORE", + "TLS_ST_OK", + ] + + +def cryptography_has_mem_functions() -> list[str]: + return [ + "Cryptography_CRYPTO_set_mem_functions", + ] + + +def cryptography_has_ed448() -> list[str]: + return [ + "EVP_PKEY_ED448", + ] + + +def cryptography_has_ssl_sigalgs() -> list[str]: + return [ + "SSL_CTX_set1_sigalgs_list", + ] + + +def cryptography_has_psk() -> list[str]: + return [ + "SSL_CTX_use_psk_identity_hint", + "SSL_CTX_set_psk_server_callback", + "SSL_CTX_set_psk_client_callback", + ] + + +def cryptography_has_psk_tlsv13() -> list[str]: + return [ + "SSL_CTX_set_psk_find_session_callback", + "SSL_CTX_set_psk_use_session_callback", + "Cryptography_SSL_SESSION_new", + "SSL_CIPHER_find", + "SSL_SESSION_set1_master_key", + "SSL_SESSION_set_cipher", + "SSL_SESSION_set_protocol_version", + ] + + +def cryptography_has_custom_ext() -> list[str]: + return [ + "SSL_CTX_add_client_custom_ext", + "SSL_CTX_add_server_custom_ext", + "SSL_extension_supported", + ] + + +def cryptography_has_tlsv13_functions() -> list[str]: + return [ + "SSL_VERIFY_POST_HANDSHAKE", + "SSL_CTX_set_ciphersuites", + "SSL_verify_client_post_handshake", + "SSL_CTX_set_post_handshake_auth", + "SSL_set_post_handshake_auth", + "SSL_SESSION_get_max_early_data", + "SSL_write_early_data", + "SSL_read_early_data", + "SSL_CTX_set_max_early_data", + ] + + +def cryptography_has_engine() -> list[str]: + return [ + "ENGINE_by_id", + "ENGINE_init", + "ENGINE_finish", + "ENGINE_get_default_RAND", + "ENGINE_set_default_RAND", + "ENGINE_unregister_RAND", + "ENGINE_ctrl_cmd", + "ENGINE_free", + "ENGINE_get_name", + "ENGINE_ctrl_cmd_string", + "ENGINE_load_builtin_engines", + "ENGINE_load_private_key", + "ENGINE_load_public_key", + "SSL_CTX_set_client_cert_engine", + ] + + +def cryptography_has_verified_chain() -> list[str]: + return [ + "SSL_get0_verified_chain", + ] + + +def cryptography_has_srtp() -> list[str]: + return [ + "SSL_CTX_set_tlsext_use_srtp", + "SSL_set_tlsext_use_srtp", + "SSL_get_selected_srtp_profile", + ] + + +def cryptography_has_providers() -> list[str]: + return [ + "OSSL_PROVIDER_load", + "OSSL_PROVIDER_unload", + "ERR_LIB_PROV", + "PROV_R_WRONG_FINAL_BLOCK_LENGTH", + "PROV_R_BAD_DECRYPT", + ] + + +def cryptography_has_op_no_renegotiation() -> list[str]: + return [ + "SSL_OP_NO_RENEGOTIATION", + ] + + +def cryptography_has_dtls_get_data_mtu() -> list[str]: + return [ + "DTLS_get_data_mtu", + ] + + +def cryptography_has_300_fips() -> list[str]: + return [ + "EVP_default_properties_enable_fips", + ] + + +def cryptography_has_ssl_cookie() -> list[str]: + return [ + "SSL_OP_COOKIE_EXCHANGE", + "DTLSv1_listen", + "SSL_CTX_set_cookie_generate_cb", + "SSL_CTX_set_cookie_verify_cb", + ] + + +def cryptography_has_pkcs7_funcs() -> list[str]: + return [ + "PKCS7_verify", + "SMIME_read_PKCS7", + ] + + +def cryptography_has_prime_checks() -> list[str]: + return [ + "BN_prime_checks_for_size", + ] + + +def cryptography_has_300_evp_cipher() -> list[str]: + return ["EVP_CIPHER_fetch", "EVP_CIPHER_free"] + + +def cryptography_has_unexpected_eof_while_reading() -> list[str]: + return ["SSL_R_UNEXPECTED_EOF_WHILE_READING"] + + +def cryptography_has_pkcs12_set_mac() -> list[str]: + return ["PKCS12_set_mac"] + + +def cryptography_has_ssl_op_ignore_unexpected_eof() -> list[str]: + return [ + "SSL_OP_IGNORE_UNEXPECTED_EOF", + ] + + +def cryptography_has_get_extms_support() -> list[str]: + return ["SSL_get_extms_support"] + + +# This is a mapping of +# {condition: function-returning-names-dependent-on-that-condition} so we can +# loop over them and delete unsupported names at runtime. It will be removed +# when cffi supports #if in cdef. We use functions instead of just a dict of +# lists so we can use coverage to measure which are used. +CONDITIONAL_NAMES = { + "Cryptography_HAS_SET_CERT_CB": cryptography_has_set_cert_cb, + "Cryptography_HAS_SSL_ST": cryptography_has_ssl_st, + "Cryptography_HAS_TLS_ST": cryptography_has_tls_st, + "Cryptography_HAS_MEM_FUNCTIONS": cryptography_has_mem_functions, + "Cryptography_HAS_ED448": cryptography_has_ed448, + "Cryptography_HAS_SIGALGS": cryptography_has_ssl_sigalgs, + "Cryptography_HAS_PSK": cryptography_has_psk, + "Cryptography_HAS_PSK_TLSv1_3": cryptography_has_psk_tlsv13, + "Cryptography_HAS_CUSTOM_EXT": cryptography_has_custom_ext, + "Cryptography_HAS_TLSv1_3_FUNCTIONS": cryptography_has_tlsv13_functions, + "Cryptography_HAS_ENGINE": cryptography_has_engine, + "Cryptography_HAS_VERIFIED_CHAIN": cryptography_has_verified_chain, + "Cryptography_HAS_SRTP": cryptography_has_srtp, + "Cryptography_HAS_PROVIDERS": cryptography_has_providers, + "Cryptography_HAS_OP_NO_RENEGOTIATION": ( + cryptography_has_op_no_renegotiation + ), + "Cryptography_HAS_DTLS_GET_DATA_MTU": cryptography_has_dtls_get_data_mtu, + "Cryptography_HAS_300_FIPS": cryptography_has_300_fips, + "Cryptography_HAS_SSL_COOKIE": cryptography_has_ssl_cookie, + "Cryptography_HAS_PKCS7_FUNCS": cryptography_has_pkcs7_funcs, + "Cryptography_HAS_PRIME_CHECKS": cryptography_has_prime_checks, + "Cryptography_HAS_300_EVP_CIPHER": cryptography_has_300_evp_cipher, + "Cryptography_HAS_UNEXPECTED_EOF_WHILE_READING": ( + cryptography_has_unexpected_eof_while_reading + ), + "Cryptography_HAS_PKCS12_SET_MAC": cryptography_has_pkcs12_set_mac, + "Cryptography_HAS_SSL_OP_IGNORE_UNEXPECTED_EOF": ( + cryptography_has_ssl_op_ignore_unexpected_eof + ), + "Cryptography_HAS_GET_EXTMS_SUPPORT": cryptography_has_get_extms_support, +} diff --git a/.venv/Lib/site-packages/cryptography/hazmat/bindings/openssl/binding.py b/.venv/Lib/site-packages/cryptography/hazmat/bindings/openssl/binding.py new file mode 100644 index 00000000..40814f2a --- /dev/null +++ b/.venv/Lib/site-packages/cryptography/hazmat/bindings/openssl/binding.py @@ -0,0 +1,175 @@ +# This file is dual licensed under the terms of the Apache License, Version +# 2.0, and the BSD License. See the LICENSE file in the root of this repository +# for complete details. + +from __future__ import annotations + +import os +import sys +import threading +import types +import typing +import warnings + +import cryptography +from cryptography.exceptions import InternalError +from cryptography.hazmat.bindings._rust import _openssl, openssl +from cryptography.hazmat.bindings.openssl._conditional import CONDITIONAL_NAMES + + +def _openssl_assert( + ok: bool, + errors: list[openssl.OpenSSLError] | None = None, +) -> None: + if not ok: + if errors is None: + errors = openssl.capture_error_stack() + + raise InternalError( + "Unknown OpenSSL error. This error is commonly encountered when " + "another library is not cleaning up the OpenSSL error stack. If " + "you are using cryptography with another library that uses " + "OpenSSL try disabling it before reporting a bug. Otherwise " + "please file an issue at https://github.com/pyca/cryptography/" + "issues with information on how to reproduce " + f"this. ({errors!r})", + errors, + ) + + +def _legacy_provider_error(loaded: bool) -> None: + if not loaded: + raise RuntimeError( + "OpenSSL 3.0's legacy provider failed to load. This is a fatal " + "error by default, but cryptography supports running without " + "legacy algorithms by setting the environment variable " + "CRYPTOGRAPHY_OPENSSL_NO_LEGACY. If you did not expect this error," + " you have likely made a mistake with your OpenSSL configuration." + ) + + +def build_conditional_library( + lib: typing.Any, + conditional_names: dict[str, typing.Callable[[], list[str]]], +) -> typing.Any: + conditional_lib = types.ModuleType("lib") + conditional_lib._original_lib = lib # type: ignore[attr-defined] + excluded_names = set() + for condition, names_cb in conditional_names.items(): + if not getattr(lib, condition): + excluded_names.update(names_cb()) + + for attr in dir(lib): + if attr not in excluded_names: + setattr(conditional_lib, attr, getattr(lib, attr)) + + return conditional_lib + + +class Binding: + """ + OpenSSL API wrapper. + """ + + lib: typing.ClassVar = None + ffi = _openssl.ffi + _lib_loaded = False + _init_lock = threading.Lock() + _legacy_provider: typing.Any = ffi.NULL + _legacy_provider_loaded = False + _default_provider: typing.Any = ffi.NULL + + def __init__(self) -> None: + self._ensure_ffi_initialized() + + def _enable_fips(self) -> None: + # This function enables FIPS mode for OpenSSL 3.0.0 on installs that + # have the FIPS provider installed properly. + _openssl_assert(self.lib.CRYPTOGRAPHY_OPENSSL_300_OR_GREATER) + self._base_provider = self.lib.OSSL_PROVIDER_load( + self.ffi.NULL, b"base" + ) + _openssl_assert(self._base_provider != self.ffi.NULL) + self.lib._fips_provider = self.lib.OSSL_PROVIDER_load( + self.ffi.NULL, b"fips" + ) + _openssl_assert(self.lib._fips_provider != self.ffi.NULL) + + res = self.lib.EVP_default_properties_enable_fips(self.ffi.NULL, 1) + _openssl_assert(res == 1) + + @classmethod + def _ensure_ffi_initialized(cls) -> None: + with cls._init_lock: + if not cls._lib_loaded: + cls.lib = build_conditional_library( + _openssl.lib, CONDITIONAL_NAMES + ) + cls._lib_loaded = True + # As of OpenSSL 3.0.0 we must register a legacy cipher provider + # to get RC2 (needed for junk asymmetric private key + # serialization), RC4, Blowfish, IDEA, SEED, etc. These things + # are ugly legacy, but we aren't going to get rid of them + # any time soon. + if cls.lib.CRYPTOGRAPHY_OPENSSL_300_OR_GREATER: + if not os.environ.get("CRYPTOGRAPHY_OPENSSL_NO_LEGACY"): + cls._legacy_provider = cls.lib.OSSL_PROVIDER_load( + cls.ffi.NULL, b"legacy" + ) + cls._legacy_provider_loaded = ( + cls._legacy_provider != cls.ffi.NULL + ) + _legacy_provider_error(cls._legacy_provider_loaded) + + cls._default_provider = cls.lib.OSSL_PROVIDER_load( + cls.ffi.NULL, b"default" + ) + _openssl_assert(cls._default_provider != cls.ffi.NULL) + + @classmethod + def init_static_locks(cls) -> None: + cls._ensure_ffi_initialized() + + +def _verify_package_version(version: str) -> None: + # Occasionally we run into situations where the version of the Python + # package does not match the version of the shared object that is loaded. + # This may occur in environments where multiple versions of cryptography + # are installed and available in the python path. To avoid errors cropping + # up later this code checks that the currently imported package and the + # shared object that were loaded have the same version and raise an + # ImportError if they do not + so_package_version = _openssl.ffi.string( + _openssl.lib.CRYPTOGRAPHY_PACKAGE_VERSION + ) + if version.encode("ascii") != so_package_version: + raise ImportError( + "The version of cryptography does not match the loaded " + "shared object. This can happen if you have multiple copies of " + "cryptography installed in your Python path. Please try creating " + "a new virtual environment to resolve this issue. " + "Loaded python version: {}, shared object version: {}".format( + version, so_package_version + ) + ) + + _openssl_assert( + _openssl.lib.OpenSSL_version_num() == openssl.openssl_version(), + ) + + +_verify_package_version(cryptography.__version__) + +Binding.init_static_locks() + +if ( + sys.platform == "win32" + and os.environ.get("PROCESSOR_ARCHITEW6432") is not None +): + warnings.warn( + "You are using cryptography on a 32-bit Python on a 64-bit Windows " + "Operating System. Cryptography will be significantly faster if you " + "switch to using a 64-bit Python.", + UserWarning, + stacklevel=2, + ) diff --git a/.venv/Lib/site-packages/cryptography/hazmat/primitives/__init__.py b/.venv/Lib/site-packages/cryptography/hazmat/primitives/__init__.py new file mode 100644 index 00000000..b5093362 --- /dev/null +++ b/.venv/Lib/site-packages/cryptography/hazmat/primitives/__init__.py @@ -0,0 +1,3 @@ +# This file is dual licensed under the terms of the Apache License, Version +# 2.0, and the BSD License. See the LICENSE file in the root of this repository +# for complete details. diff --git a/.venv/Lib/site-packages/cryptography/hazmat/primitives/__pycache__/__init__.cpython-311.pyc b/.venv/Lib/site-packages/cryptography/hazmat/primitives/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 00000000..4b78eb95 Binary files /dev/null and b/.venv/Lib/site-packages/cryptography/hazmat/primitives/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/cryptography/hazmat/primitives/__pycache__/_asymmetric.cpython-311.pyc b/.venv/Lib/site-packages/cryptography/hazmat/primitives/__pycache__/_asymmetric.cpython-311.pyc new file mode 100644 index 00000000..66ba9242 Binary files /dev/null and b/.venv/Lib/site-packages/cryptography/hazmat/primitives/__pycache__/_asymmetric.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/cryptography/hazmat/primitives/__pycache__/_cipheralgorithm.cpython-311.pyc b/.venv/Lib/site-packages/cryptography/hazmat/primitives/__pycache__/_cipheralgorithm.cpython-311.pyc new file mode 100644 index 00000000..9d8e820c Binary files /dev/null and b/.venv/Lib/site-packages/cryptography/hazmat/primitives/__pycache__/_cipheralgorithm.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/cryptography/hazmat/primitives/__pycache__/_serialization.cpython-311.pyc b/.venv/Lib/site-packages/cryptography/hazmat/primitives/__pycache__/_serialization.cpython-311.pyc new file mode 100644 index 00000000..3724b63d Binary files /dev/null and b/.venv/Lib/site-packages/cryptography/hazmat/primitives/__pycache__/_serialization.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/cryptography/hazmat/primitives/__pycache__/cmac.cpython-311.pyc b/.venv/Lib/site-packages/cryptography/hazmat/primitives/__pycache__/cmac.cpython-311.pyc new file mode 100644 index 00000000..7cc769e9 Binary files /dev/null and b/.venv/Lib/site-packages/cryptography/hazmat/primitives/__pycache__/cmac.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/cryptography/hazmat/primitives/__pycache__/constant_time.cpython-311.pyc b/.venv/Lib/site-packages/cryptography/hazmat/primitives/__pycache__/constant_time.cpython-311.pyc new file mode 100644 index 00000000..ca56697e Binary files /dev/null and b/.venv/Lib/site-packages/cryptography/hazmat/primitives/__pycache__/constant_time.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/cryptography/hazmat/primitives/__pycache__/hashes.cpython-311.pyc b/.venv/Lib/site-packages/cryptography/hazmat/primitives/__pycache__/hashes.cpython-311.pyc new file mode 100644 index 00000000..6ecc78df Binary files /dev/null and b/.venv/Lib/site-packages/cryptography/hazmat/primitives/__pycache__/hashes.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/cryptography/hazmat/primitives/__pycache__/hmac.cpython-311.pyc b/.venv/Lib/site-packages/cryptography/hazmat/primitives/__pycache__/hmac.cpython-311.pyc new file mode 100644 index 00000000..bf9eb42e Binary files /dev/null and b/.venv/Lib/site-packages/cryptography/hazmat/primitives/__pycache__/hmac.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/cryptography/hazmat/primitives/__pycache__/keywrap.cpython-311.pyc b/.venv/Lib/site-packages/cryptography/hazmat/primitives/__pycache__/keywrap.cpython-311.pyc new file mode 100644 index 00000000..b8cc21f6 Binary files /dev/null and b/.venv/Lib/site-packages/cryptography/hazmat/primitives/__pycache__/keywrap.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/cryptography/hazmat/primitives/__pycache__/padding.cpython-311.pyc b/.venv/Lib/site-packages/cryptography/hazmat/primitives/__pycache__/padding.cpython-311.pyc new file mode 100644 index 00000000..772cfbeb Binary files /dev/null and b/.venv/Lib/site-packages/cryptography/hazmat/primitives/__pycache__/padding.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/cryptography/hazmat/primitives/__pycache__/poly1305.cpython-311.pyc b/.venv/Lib/site-packages/cryptography/hazmat/primitives/__pycache__/poly1305.cpython-311.pyc new file mode 100644 index 00000000..378ef337 Binary files /dev/null and b/.venv/Lib/site-packages/cryptography/hazmat/primitives/__pycache__/poly1305.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/cryptography/hazmat/primitives/_asymmetric.py b/.venv/Lib/site-packages/cryptography/hazmat/primitives/_asymmetric.py new file mode 100644 index 00000000..ea55ffdf --- /dev/null +++ b/.venv/Lib/site-packages/cryptography/hazmat/primitives/_asymmetric.py @@ -0,0 +1,19 @@ +# This file is dual licensed under the terms of the Apache License, Version +# 2.0, and the BSD License. See the LICENSE file in the root of this repository +# for complete details. + +from __future__ import annotations + +import abc + +# This exists to break an import cycle. It is normally accessible from the +# asymmetric padding module. + + +class AsymmetricPadding(metaclass=abc.ABCMeta): + @property + @abc.abstractmethod + def name(self) -> str: + """ + A string naming this padding (e.g. "PSS", "PKCS1"). + """ diff --git a/.venv/Lib/site-packages/cryptography/hazmat/primitives/_cipheralgorithm.py b/.venv/Lib/site-packages/cryptography/hazmat/primitives/_cipheralgorithm.py new file mode 100644 index 00000000..9d7f5bc7 --- /dev/null +++ b/.venv/Lib/site-packages/cryptography/hazmat/primitives/_cipheralgorithm.py @@ -0,0 +1,44 @@ +# This file is dual licensed under the terms of the Apache License, Version +# 2.0, and the BSD License. See the LICENSE file in the root of this repository +# for complete details. + +from __future__ import annotations + +import abc + +# This exists to break an import cycle. It is normally accessible from the +# ciphers module. + + +class CipherAlgorithm(metaclass=abc.ABCMeta): + @property + @abc.abstractmethod + def name(self) -> str: + """ + A string naming this mode (e.g. "AES", "Camellia"). + """ + + @property + @abc.abstractmethod + def key_sizes(self) -> frozenset[int]: + """ + Valid key sizes for this algorithm in bits + """ + + @property + @abc.abstractmethod + def key_size(self) -> int: + """ + The size of the key being used as an integer in bits (e.g. 128, 256). + """ + + +class BlockCipherAlgorithm(CipherAlgorithm): + key: bytes + + @property + @abc.abstractmethod + def block_size(self) -> int: + """ + The size of a block as an integer in bits (e.g. 64, 128). + """ diff --git a/.venv/Lib/site-packages/cryptography/hazmat/primitives/_serialization.py b/.venv/Lib/site-packages/cryptography/hazmat/primitives/_serialization.py new file mode 100644 index 00000000..46157721 --- /dev/null +++ b/.venv/Lib/site-packages/cryptography/hazmat/primitives/_serialization.py @@ -0,0 +1,169 @@ +# This file is dual licensed under the terms of the Apache License, Version +# 2.0, and the BSD License. See the LICENSE file in the root of this repository +# for complete details. + +from __future__ import annotations + +import abc + +from cryptography import utils +from cryptography.hazmat.primitives.hashes import HashAlgorithm + +# This exists to break an import cycle. These classes are normally accessible +# from the serialization module. + + +class PBES(utils.Enum): + PBESv1SHA1And3KeyTripleDESCBC = "PBESv1 using SHA1 and 3-Key TripleDES" + PBESv2SHA256AndAES256CBC = "PBESv2 using SHA256 PBKDF2 and AES256 CBC" + + +class Encoding(utils.Enum): + PEM = "PEM" + DER = "DER" + OpenSSH = "OpenSSH" + Raw = "Raw" + X962 = "ANSI X9.62" + SMIME = "S/MIME" + + +class PrivateFormat(utils.Enum): + PKCS8 = "PKCS8" + TraditionalOpenSSL = "TraditionalOpenSSL" + Raw = "Raw" + OpenSSH = "OpenSSH" + PKCS12 = "PKCS12" + + def encryption_builder(self) -> KeySerializationEncryptionBuilder: + if self not in (PrivateFormat.OpenSSH, PrivateFormat.PKCS12): + raise ValueError( + "encryption_builder only supported with PrivateFormat.OpenSSH" + " and PrivateFormat.PKCS12" + ) + return KeySerializationEncryptionBuilder(self) + + +class PublicFormat(utils.Enum): + SubjectPublicKeyInfo = "X.509 subjectPublicKeyInfo with PKCS#1" + PKCS1 = "Raw PKCS#1" + OpenSSH = "OpenSSH" + Raw = "Raw" + CompressedPoint = "X9.62 Compressed Point" + UncompressedPoint = "X9.62 Uncompressed Point" + + +class ParameterFormat(utils.Enum): + PKCS3 = "PKCS3" + + +class KeySerializationEncryption(metaclass=abc.ABCMeta): + pass + + +class BestAvailableEncryption(KeySerializationEncryption): + def __init__(self, password: bytes): + if not isinstance(password, bytes) or len(password) == 0: + raise ValueError("Password must be 1 or more bytes.") + + self.password = password + + +class NoEncryption(KeySerializationEncryption): + pass + + +class KeySerializationEncryptionBuilder: + def __init__( + self, + format: PrivateFormat, + *, + _kdf_rounds: int | None = None, + _hmac_hash: HashAlgorithm | None = None, + _key_cert_algorithm: PBES | None = None, + ) -> None: + self._format = format + + self._kdf_rounds = _kdf_rounds + self._hmac_hash = _hmac_hash + self._key_cert_algorithm = _key_cert_algorithm + + def kdf_rounds(self, rounds: int) -> KeySerializationEncryptionBuilder: + if self._kdf_rounds is not None: + raise ValueError("kdf_rounds already set") + + if not isinstance(rounds, int): + raise TypeError("kdf_rounds must be an integer") + + if rounds < 1: + raise ValueError("kdf_rounds must be a positive integer") + + return KeySerializationEncryptionBuilder( + self._format, + _kdf_rounds=rounds, + _hmac_hash=self._hmac_hash, + _key_cert_algorithm=self._key_cert_algorithm, + ) + + def hmac_hash( + self, algorithm: HashAlgorithm + ) -> KeySerializationEncryptionBuilder: + if self._format is not PrivateFormat.PKCS12: + raise TypeError( + "hmac_hash only supported with PrivateFormat.PKCS12" + ) + + if self._hmac_hash is not None: + raise ValueError("hmac_hash already set") + return KeySerializationEncryptionBuilder( + self._format, + _kdf_rounds=self._kdf_rounds, + _hmac_hash=algorithm, + _key_cert_algorithm=self._key_cert_algorithm, + ) + + def key_cert_algorithm( + self, algorithm: PBES + ) -> KeySerializationEncryptionBuilder: + if self._format is not PrivateFormat.PKCS12: + raise TypeError( + "key_cert_algorithm only supported with " + "PrivateFormat.PKCS12" + ) + if self._key_cert_algorithm is not None: + raise ValueError("key_cert_algorithm already set") + return KeySerializationEncryptionBuilder( + self._format, + _kdf_rounds=self._kdf_rounds, + _hmac_hash=self._hmac_hash, + _key_cert_algorithm=algorithm, + ) + + def build(self, password: bytes) -> KeySerializationEncryption: + if not isinstance(password, bytes) or len(password) == 0: + raise ValueError("Password must be 1 or more bytes.") + + return _KeySerializationEncryption( + self._format, + password, + kdf_rounds=self._kdf_rounds, + hmac_hash=self._hmac_hash, + key_cert_algorithm=self._key_cert_algorithm, + ) + + +class _KeySerializationEncryption(KeySerializationEncryption): + def __init__( + self, + format: PrivateFormat, + password: bytes, + *, + kdf_rounds: int | None, + hmac_hash: HashAlgorithm | None, + key_cert_algorithm: PBES | None, + ): + self._format = format + self.password = password + + self._kdf_rounds = kdf_rounds + self._hmac_hash = hmac_hash + self._key_cert_algorithm = key_cert_algorithm diff --git a/.venv/Lib/site-packages/cryptography/hazmat/primitives/asymmetric/__init__.py b/.venv/Lib/site-packages/cryptography/hazmat/primitives/asymmetric/__init__.py new file mode 100644 index 00000000..b5093362 --- /dev/null +++ b/.venv/Lib/site-packages/cryptography/hazmat/primitives/asymmetric/__init__.py @@ -0,0 +1,3 @@ +# This file is dual licensed under the terms of the Apache License, Version +# 2.0, and the BSD License. See the LICENSE file in the root of this repository +# for complete details. diff --git a/.venv/Lib/site-packages/cryptography/hazmat/primitives/asymmetric/__pycache__/__init__.cpython-311.pyc b/.venv/Lib/site-packages/cryptography/hazmat/primitives/asymmetric/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 00000000..da3ec667 Binary files /dev/null and b/.venv/Lib/site-packages/cryptography/hazmat/primitives/asymmetric/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/cryptography/hazmat/primitives/asymmetric/__pycache__/dh.cpython-311.pyc b/.venv/Lib/site-packages/cryptography/hazmat/primitives/asymmetric/__pycache__/dh.cpython-311.pyc new file mode 100644 index 00000000..5693c802 Binary files /dev/null and b/.venv/Lib/site-packages/cryptography/hazmat/primitives/asymmetric/__pycache__/dh.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/cryptography/hazmat/primitives/asymmetric/__pycache__/dsa.cpython-311.pyc b/.venv/Lib/site-packages/cryptography/hazmat/primitives/asymmetric/__pycache__/dsa.cpython-311.pyc new file mode 100644 index 00000000..cb30a671 Binary files /dev/null and b/.venv/Lib/site-packages/cryptography/hazmat/primitives/asymmetric/__pycache__/dsa.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/cryptography/hazmat/primitives/asymmetric/__pycache__/ec.cpython-311.pyc b/.venv/Lib/site-packages/cryptography/hazmat/primitives/asymmetric/__pycache__/ec.cpython-311.pyc new file mode 100644 index 00000000..e1113514 Binary files /dev/null and b/.venv/Lib/site-packages/cryptography/hazmat/primitives/asymmetric/__pycache__/ec.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/cryptography/hazmat/primitives/asymmetric/__pycache__/ed25519.cpython-311.pyc b/.venv/Lib/site-packages/cryptography/hazmat/primitives/asymmetric/__pycache__/ed25519.cpython-311.pyc new file mode 100644 index 00000000..d0d3e0a3 Binary files /dev/null and b/.venv/Lib/site-packages/cryptography/hazmat/primitives/asymmetric/__pycache__/ed25519.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/cryptography/hazmat/primitives/asymmetric/__pycache__/ed448.cpython-311.pyc b/.venv/Lib/site-packages/cryptography/hazmat/primitives/asymmetric/__pycache__/ed448.cpython-311.pyc new file mode 100644 index 00000000..71975c84 Binary files /dev/null and b/.venv/Lib/site-packages/cryptography/hazmat/primitives/asymmetric/__pycache__/ed448.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/cryptography/hazmat/primitives/asymmetric/__pycache__/padding.cpython-311.pyc b/.venv/Lib/site-packages/cryptography/hazmat/primitives/asymmetric/__pycache__/padding.cpython-311.pyc new file mode 100644 index 00000000..682e0d94 Binary files /dev/null and b/.venv/Lib/site-packages/cryptography/hazmat/primitives/asymmetric/__pycache__/padding.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/cryptography/hazmat/primitives/asymmetric/__pycache__/rsa.cpython-311.pyc b/.venv/Lib/site-packages/cryptography/hazmat/primitives/asymmetric/__pycache__/rsa.cpython-311.pyc new file mode 100644 index 00000000..cbb6e214 Binary files /dev/null and b/.venv/Lib/site-packages/cryptography/hazmat/primitives/asymmetric/__pycache__/rsa.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/cryptography/hazmat/primitives/asymmetric/__pycache__/types.cpython-311.pyc b/.venv/Lib/site-packages/cryptography/hazmat/primitives/asymmetric/__pycache__/types.cpython-311.pyc new file mode 100644 index 00000000..a941d5bc Binary files /dev/null and b/.venv/Lib/site-packages/cryptography/hazmat/primitives/asymmetric/__pycache__/types.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/cryptography/hazmat/primitives/asymmetric/__pycache__/utils.cpython-311.pyc b/.venv/Lib/site-packages/cryptography/hazmat/primitives/asymmetric/__pycache__/utils.cpython-311.pyc new file mode 100644 index 00000000..595f90b4 Binary files /dev/null and b/.venv/Lib/site-packages/cryptography/hazmat/primitives/asymmetric/__pycache__/utils.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/cryptography/hazmat/primitives/asymmetric/__pycache__/x25519.cpython-311.pyc b/.venv/Lib/site-packages/cryptography/hazmat/primitives/asymmetric/__pycache__/x25519.cpython-311.pyc new file mode 100644 index 00000000..db83a638 Binary files /dev/null and b/.venv/Lib/site-packages/cryptography/hazmat/primitives/asymmetric/__pycache__/x25519.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/cryptography/hazmat/primitives/asymmetric/__pycache__/x448.cpython-311.pyc b/.venv/Lib/site-packages/cryptography/hazmat/primitives/asymmetric/__pycache__/x448.cpython-311.pyc new file mode 100644 index 00000000..5508d115 Binary files /dev/null and b/.venv/Lib/site-packages/cryptography/hazmat/primitives/asymmetric/__pycache__/x448.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/cryptography/hazmat/primitives/asymmetric/dh.py b/.venv/Lib/site-packages/cryptography/hazmat/primitives/asymmetric/dh.py new file mode 100644 index 00000000..31c9748a --- /dev/null +++ b/.venv/Lib/site-packages/cryptography/hazmat/primitives/asymmetric/dh.py @@ -0,0 +1,135 @@ +# This file is dual licensed under the terms of the Apache License, Version +# 2.0, and the BSD License. See the LICENSE file in the root of this repository +# for complete details. + +from __future__ import annotations + +import abc + +from cryptography.hazmat.bindings._rust import openssl as rust_openssl +from cryptography.hazmat.primitives import _serialization + +generate_parameters = rust_openssl.dh.generate_parameters + + +DHPrivateNumbers = rust_openssl.dh.DHPrivateNumbers +DHPublicNumbers = rust_openssl.dh.DHPublicNumbers +DHParameterNumbers = rust_openssl.dh.DHParameterNumbers + + +class DHParameters(metaclass=abc.ABCMeta): + @abc.abstractmethod + def generate_private_key(self) -> DHPrivateKey: + """ + Generates and returns a DHPrivateKey. + """ + + @abc.abstractmethod + def parameter_bytes( + self, + encoding: _serialization.Encoding, + format: _serialization.ParameterFormat, + ) -> bytes: + """ + Returns the parameters serialized as bytes. + """ + + @abc.abstractmethod + def parameter_numbers(self) -> DHParameterNumbers: + """ + Returns a DHParameterNumbers. + """ + + +DHParametersWithSerialization = DHParameters +DHParameters.register(rust_openssl.dh.DHParameters) + + +class DHPublicKey(metaclass=abc.ABCMeta): + @property + @abc.abstractmethod + def key_size(self) -> int: + """ + The bit length of the prime modulus. + """ + + @abc.abstractmethod + def parameters(self) -> DHParameters: + """ + The DHParameters object associated with this public key. + """ + + @abc.abstractmethod + def public_numbers(self) -> DHPublicNumbers: + """ + Returns a DHPublicNumbers. + """ + + @abc.abstractmethod + def public_bytes( + self, + encoding: _serialization.Encoding, + format: _serialization.PublicFormat, + ) -> bytes: + """ + Returns the key serialized as bytes. + """ + + @abc.abstractmethod + def __eq__(self, other: object) -> bool: + """ + Checks equality. + """ + + +DHPublicKeyWithSerialization = DHPublicKey +DHPublicKey.register(rust_openssl.dh.DHPublicKey) + + +class DHPrivateKey(metaclass=abc.ABCMeta): + @property + @abc.abstractmethod + def key_size(self) -> int: + """ + The bit length of the prime modulus. + """ + + @abc.abstractmethod + def public_key(self) -> DHPublicKey: + """ + The DHPublicKey associated with this private key. + """ + + @abc.abstractmethod + def parameters(self) -> DHParameters: + """ + The DHParameters object associated with this private key. + """ + + @abc.abstractmethod + def exchange(self, peer_public_key: DHPublicKey) -> bytes: + """ + Given peer's DHPublicKey, carry out the key exchange and + return shared key as bytes. + """ + + @abc.abstractmethod + def private_numbers(self) -> DHPrivateNumbers: + """ + Returns a DHPrivateNumbers. + """ + + @abc.abstractmethod + def private_bytes( + self, + encoding: _serialization.Encoding, + format: _serialization.PrivateFormat, + encryption_algorithm: _serialization.KeySerializationEncryption, + ) -> bytes: + """ + Returns the key serialized as bytes. + """ + + +DHPrivateKeyWithSerialization = DHPrivateKey +DHPrivateKey.register(rust_openssl.dh.DHPrivateKey) diff --git a/.venv/Lib/site-packages/cryptography/hazmat/primitives/asymmetric/dsa.py b/.venv/Lib/site-packages/cryptography/hazmat/primitives/asymmetric/dsa.py new file mode 100644 index 00000000..6dd34c0e --- /dev/null +++ b/.venv/Lib/site-packages/cryptography/hazmat/primitives/asymmetric/dsa.py @@ -0,0 +1,154 @@ +# This file is dual licensed under the terms of the Apache License, Version +# 2.0, and the BSD License. See the LICENSE file in the root of this repository +# for complete details. + +from __future__ import annotations + +import abc +import typing + +from cryptography.hazmat.bindings._rust import openssl as rust_openssl +from cryptography.hazmat.primitives import _serialization, hashes +from cryptography.hazmat.primitives.asymmetric import utils as asym_utils + + +class DSAParameters(metaclass=abc.ABCMeta): + @abc.abstractmethod + def generate_private_key(self) -> DSAPrivateKey: + """ + Generates and returns a DSAPrivateKey. + """ + + @abc.abstractmethod + def parameter_numbers(self) -> DSAParameterNumbers: + """ + Returns a DSAParameterNumbers. + """ + + +DSAParametersWithNumbers = DSAParameters +DSAParameters.register(rust_openssl.dsa.DSAParameters) + + +class DSAPrivateKey(metaclass=abc.ABCMeta): + @property + @abc.abstractmethod + def key_size(self) -> int: + """ + The bit length of the prime modulus. + """ + + @abc.abstractmethod + def public_key(self) -> DSAPublicKey: + """ + The DSAPublicKey associated with this private key. + """ + + @abc.abstractmethod + def parameters(self) -> DSAParameters: + """ + The DSAParameters object associated with this private key. + """ + + @abc.abstractmethod + def sign( + self, + data: bytes, + algorithm: asym_utils.Prehashed | hashes.HashAlgorithm, + ) -> bytes: + """ + Signs the data + """ + + @abc.abstractmethod + def private_numbers(self) -> DSAPrivateNumbers: + """ + Returns a DSAPrivateNumbers. + """ + + @abc.abstractmethod + def private_bytes( + self, + encoding: _serialization.Encoding, + format: _serialization.PrivateFormat, + encryption_algorithm: _serialization.KeySerializationEncryption, + ) -> bytes: + """ + Returns the key serialized as bytes. + """ + + +DSAPrivateKeyWithSerialization = DSAPrivateKey +DSAPrivateKey.register(rust_openssl.dsa.DSAPrivateKey) + + +class DSAPublicKey(metaclass=abc.ABCMeta): + @property + @abc.abstractmethod + def key_size(self) -> int: + """ + The bit length of the prime modulus. + """ + + @abc.abstractmethod + def parameters(self) -> DSAParameters: + """ + The DSAParameters object associated with this public key. + """ + + @abc.abstractmethod + def public_numbers(self) -> DSAPublicNumbers: + """ + Returns a DSAPublicNumbers. + """ + + @abc.abstractmethod + def public_bytes( + self, + encoding: _serialization.Encoding, + format: _serialization.PublicFormat, + ) -> bytes: + """ + Returns the key serialized as bytes. + """ + + @abc.abstractmethod + def verify( + self, + signature: bytes, + data: bytes, + algorithm: asym_utils.Prehashed | hashes.HashAlgorithm, + ) -> None: + """ + Verifies the signature of the data. + """ + + @abc.abstractmethod + def __eq__(self, other: object) -> bool: + """ + Checks equality. + """ + + +DSAPublicKeyWithSerialization = DSAPublicKey +DSAPublicKey.register(rust_openssl.dsa.DSAPublicKey) + +DSAPrivateNumbers = rust_openssl.dsa.DSAPrivateNumbers +DSAPublicNumbers = rust_openssl.dsa.DSAPublicNumbers +DSAParameterNumbers = rust_openssl.dsa.DSAParameterNumbers + + +def generate_parameters( + key_size: int, backend: typing.Any = None +) -> DSAParameters: + if key_size not in (1024, 2048, 3072, 4096): + raise ValueError("Key size must be 1024, 2048, 3072, or 4096 bits.") + + return rust_openssl.dsa.generate_parameters(key_size) + + +def generate_private_key( + key_size: int, backend: typing.Any = None +) -> DSAPrivateKey: + parameters = generate_parameters(key_size) + return parameters.generate_private_key() diff --git a/.venv/Lib/site-packages/cryptography/hazmat/primitives/asymmetric/ec.py b/.venv/Lib/site-packages/cryptography/hazmat/primitives/asymmetric/ec.py new file mode 100644 index 00000000..b612b401 --- /dev/null +++ b/.venv/Lib/site-packages/cryptography/hazmat/primitives/asymmetric/ec.py @@ -0,0 +1,383 @@ +# This file is dual licensed under the terms of the Apache License, Version +# 2.0, and the BSD License. See the LICENSE file in the root of this repository +# for complete details. + +from __future__ import annotations + +import abc +import typing + +from cryptography import utils +from cryptography.hazmat._oid import ObjectIdentifier +from cryptography.hazmat.bindings._rust import openssl as rust_openssl +from cryptography.hazmat.primitives import _serialization, hashes +from cryptography.hazmat.primitives.asymmetric import utils as asym_utils + + +class EllipticCurveOID: + SECP192R1 = ObjectIdentifier("1.2.840.10045.3.1.1") + SECP224R1 = ObjectIdentifier("1.3.132.0.33") + SECP256K1 = ObjectIdentifier("1.3.132.0.10") + SECP256R1 = ObjectIdentifier("1.2.840.10045.3.1.7") + SECP384R1 = ObjectIdentifier("1.3.132.0.34") + SECP521R1 = ObjectIdentifier("1.3.132.0.35") + BRAINPOOLP256R1 = ObjectIdentifier("1.3.36.3.3.2.8.1.1.7") + BRAINPOOLP384R1 = ObjectIdentifier("1.3.36.3.3.2.8.1.1.11") + BRAINPOOLP512R1 = ObjectIdentifier("1.3.36.3.3.2.8.1.1.13") + SECT163K1 = ObjectIdentifier("1.3.132.0.1") + SECT163R2 = ObjectIdentifier("1.3.132.0.15") + SECT233K1 = ObjectIdentifier("1.3.132.0.26") + SECT233R1 = ObjectIdentifier("1.3.132.0.27") + SECT283K1 = ObjectIdentifier("1.3.132.0.16") + SECT283R1 = ObjectIdentifier("1.3.132.0.17") + SECT409K1 = ObjectIdentifier("1.3.132.0.36") + SECT409R1 = ObjectIdentifier("1.3.132.0.37") + SECT571K1 = ObjectIdentifier("1.3.132.0.38") + SECT571R1 = ObjectIdentifier("1.3.132.0.39") + + +class EllipticCurve(metaclass=abc.ABCMeta): + @property + @abc.abstractmethod + def name(self) -> str: + """ + The name of the curve. e.g. secp256r1. + """ + + @property + @abc.abstractmethod + def key_size(self) -> int: + """ + Bit size of a secret scalar for the curve. + """ + + +class EllipticCurveSignatureAlgorithm(metaclass=abc.ABCMeta): + @property + @abc.abstractmethod + def algorithm( + self, + ) -> asym_utils.Prehashed | hashes.HashAlgorithm: + """ + The digest algorithm used with this signature. + """ + + +class EllipticCurvePrivateKey(metaclass=abc.ABCMeta): + @abc.abstractmethod + def exchange( + self, algorithm: ECDH, peer_public_key: EllipticCurvePublicKey + ) -> bytes: + """ + Performs a key exchange operation using the provided algorithm with the + provided peer's public key. + """ + + @abc.abstractmethod + def public_key(self) -> EllipticCurvePublicKey: + """ + The EllipticCurvePublicKey for this private key. + """ + + @property + @abc.abstractmethod + def curve(self) -> EllipticCurve: + """ + The EllipticCurve that this key is on. + """ + + @property + @abc.abstractmethod + def key_size(self) -> int: + """ + Bit size of a secret scalar for the curve. + """ + + @abc.abstractmethod + def sign( + self, + data: bytes, + signature_algorithm: EllipticCurveSignatureAlgorithm, + ) -> bytes: + """ + Signs the data + """ + + @abc.abstractmethod + def private_numbers(self) -> EllipticCurvePrivateNumbers: + """ + Returns an EllipticCurvePrivateNumbers. + """ + + @abc.abstractmethod + def private_bytes( + self, + encoding: _serialization.Encoding, + format: _serialization.PrivateFormat, + encryption_algorithm: _serialization.KeySerializationEncryption, + ) -> bytes: + """ + Returns the key serialized as bytes. + """ + + +EllipticCurvePrivateKeyWithSerialization = EllipticCurvePrivateKey +EllipticCurvePrivateKey.register(rust_openssl.ec.ECPrivateKey) + + +class EllipticCurvePublicKey(metaclass=abc.ABCMeta): + @property + @abc.abstractmethod + def curve(self) -> EllipticCurve: + """ + The EllipticCurve that this key is on. + """ + + @property + @abc.abstractmethod + def key_size(self) -> int: + """ + Bit size of a secret scalar for the curve. + """ + + @abc.abstractmethod + def public_numbers(self) -> EllipticCurvePublicNumbers: + """ + Returns an EllipticCurvePublicNumbers. + """ + + @abc.abstractmethod + def public_bytes( + self, + encoding: _serialization.Encoding, + format: _serialization.PublicFormat, + ) -> bytes: + """ + Returns the key serialized as bytes. + """ + + @abc.abstractmethod + def verify( + self, + signature: bytes, + data: bytes, + signature_algorithm: EllipticCurveSignatureAlgorithm, + ) -> None: + """ + Verifies the signature of the data. + """ + + @classmethod + def from_encoded_point( + cls, curve: EllipticCurve, data: bytes + ) -> EllipticCurvePublicKey: + utils._check_bytes("data", data) + + if len(data) == 0: + raise ValueError("data must not be an empty byte string") + + if data[0] not in [0x02, 0x03, 0x04]: + raise ValueError("Unsupported elliptic curve point type") + + return rust_openssl.ec.from_public_bytes(curve, data) + + @abc.abstractmethod + def __eq__(self, other: object) -> bool: + """ + Checks equality. + """ + + +EllipticCurvePublicKeyWithSerialization = EllipticCurvePublicKey +EllipticCurvePublicKey.register(rust_openssl.ec.ECPublicKey) + +EllipticCurvePrivateNumbers = rust_openssl.ec.EllipticCurvePrivateNumbers +EllipticCurvePublicNumbers = rust_openssl.ec.EllipticCurvePublicNumbers + + +class SECT571R1(EllipticCurve): + name = "sect571r1" + key_size = 570 + + +class SECT409R1(EllipticCurve): + name = "sect409r1" + key_size = 409 + + +class SECT283R1(EllipticCurve): + name = "sect283r1" + key_size = 283 + + +class SECT233R1(EllipticCurve): + name = "sect233r1" + key_size = 233 + + +class SECT163R2(EllipticCurve): + name = "sect163r2" + key_size = 163 + + +class SECT571K1(EllipticCurve): + name = "sect571k1" + key_size = 571 + + +class SECT409K1(EllipticCurve): + name = "sect409k1" + key_size = 409 + + +class SECT283K1(EllipticCurve): + name = "sect283k1" + key_size = 283 + + +class SECT233K1(EllipticCurve): + name = "sect233k1" + key_size = 233 + + +class SECT163K1(EllipticCurve): + name = "sect163k1" + key_size = 163 + + +class SECP521R1(EllipticCurve): + name = "secp521r1" + key_size = 521 + + +class SECP384R1(EllipticCurve): + name = "secp384r1" + key_size = 384 + + +class SECP256R1(EllipticCurve): + name = "secp256r1" + key_size = 256 + + +class SECP256K1(EllipticCurve): + name = "secp256k1" + key_size = 256 + + +class SECP224R1(EllipticCurve): + name = "secp224r1" + key_size = 224 + + +class SECP192R1(EllipticCurve): + name = "secp192r1" + key_size = 192 + + +class BrainpoolP256R1(EllipticCurve): + name = "brainpoolP256r1" + key_size = 256 + + +class BrainpoolP384R1(EllipticCurve): + name = "brainpoolP384r1" + key_size = 384 + + +class BrainpoolP512R1(EllipticCurve): + name = "brainpoolP512r1" + key_size = 512 + + +_CURVE_TYPES: dict[str, EllipticCurve] = { + "prime192v1": SECP192R1(), + "prime256v1": SECP256R1(), + "secp192r1": SECP192R1(), + "secp224r1": SECP224R1(), + "secp256r1": SECP256R1(), + "secp384r1": SECP384R1(), + "secp521r1": SECP521R1(), + "secp256k1": SECP256K1(), + "sect163k1": SECT163K1(), + "sect233k1": SECT233K1(), + "sect283k1": SECT283K1(), + "sect409k1": SECT409K1(), + "sect571k1": SECT571K1(), + "sect163r2": SECT163R2(), + "sect233r1": SECT233R1(), + "sect283r1": SECT283R1(), + "sect409r1": SECT409R1(), + "sect571r1": SECT571R1(), + "brainpoolP256r1": BrainpoolP256R1(), + "brainpoolP384r1": BrainpoolP384R1(), + "brainpoolP512r1": BrainpoolP512R1(), +} + + +class ECDSA(EllipticCurveSignatureAlgorithm): + def __init__( + self, + algorithm: asym_utils.Prehashed | hashes.HashAlgorithm, + ): + self._algorithm = algorithm + + @property + def algorithm( + self, + ) -> asym_utils.Prehashed | hashes.HashAlgorithm: + return self._algorithm + + +generate_private_key = rust_openssl.ec.generate_private_key + + +def derive_private_key( + private_value: int, + curve: EllipticCurve, + backend: typing.Any = None, +) -> EllipticCurvePrivateKey: + if not isinstance(private_value, int): + raise TypeError("private_value must be an integer type.") + + if private_value <= 0: + raise ValueError("private_value must be a positive integer.") + + return rust_openssl.ec.derive_private_key(private_value, curve) + + +class ECDH: + pass + + +_OID_TO_CURVE = { + EllipticCurveOID.SECP192R1: SECP192R1, + EllipticCurveOID.SECP224R1: SECP224R1, + EllipticCurveOID.SECP256K1: SECP256K1, + EllipticCurveOID.SECP256R1: SECP256R1, + EllipticCurveOID.SECP384R1: SECP384R1, + EllipticCurveOID.SECP521R1: SECP521R1, + EllipticCurveOID.BRAINPOOLP256R1: BrainpoolP256R1, + EllipticCurveOID.BRAINPOOLP384R1: BrainpoolP384R1, + EllipticCurveOID.BRAINPOOLP512R1: BrainpoolP512R1, + EllipticCurveOID.SECT163K1: SECT163K1, + EllipticCurveOID.SECT163R2: SECT163R2, + EllipticCurveOID.SECT233K1: SECT233K1, + EllipticCurveOID.SECT233R1: SECT233R1, + EllipticCurveOID.SECT283K1: SECT283K1, + EllipticCurveOID.SECT283R1: SECT283R1, + EllipticCurveOID.SECT409K1: SECT409K1, + EllipticCurveOID.SECT409R1: SECT409R1, + EllipticCurveOID.SECT571K1: SECT571K1, + EllipticCurveOID.SECT571R1: SECT571R1, +} + + +def get_curve_for_oid(oid: ObjectIdentifier) -> type[EllipticCurve]: + try: + return _OID_TO_CURVE[oid] + except KeyError: + raise LookupError( + "The provided object identifier has no matching elliptic " + "curve class" + ) diff --git a/.venv/Lib/site-packages/cryptography/hazmat/primitives/asymmetric/ed25519.py b/.venv/Lib/site-packages/cryptography/hazmat/primitives/asymmetric/ed25519.py new file mode 100644 index 00000000..3a26185d --- /dev/null +++ b/.venv/Lib/site-packages/cryptography/hazmat/primitives/asymmetric/ed25519.py @@ -0,0 +1,116 @@ +# This file is dual licensed under the terms of the Apache License, Version +# 2.0, and the BSD License. See the LICENSE file in the root of this repository +# for complete details. + +from __future__ import annotations + +import abc + +from cryptography.exceptions import UnsupportedAlgorithm, _Reasons +from cryptography.hazmat.bindings._rust import openssl as rust_openssl +from cryptography.hazmat.primitives import _serialization + + +class Ed25519PublicKey(metaclass=abc.ABCMeta): + @classmethod + def from_public_bytes(cls, data: bytes) -> Ed25519PublicKey: + from cryptography.hazmat.backends.openssl.backend import backend + + if not backend.ed25519_supported(): + raise UnsupportedAlgorithm( + "ed25519 is not supported by this version of OpenSSL.", + _Reasons.UNSUPPORTED_PUBLIC_KEY_ALGORITHM, + ) + + return rust_openssl.ed25519.from_public_bytes(data) + + @abc.abstractmethod + def public_bytes( + self, + encoding: _serialization.Encoding, + format: _serialization.PublicFormat, + ) -> bytes: + """ + The serialized bytes of the public key. + """ + + @abc.abstractmethod + def public_bytes_raw(self) -> bytes: + """ + The raw bytes of the public key. + Equivalent to public_bytes(Raw, Raw). + """ + + @abc.abstractmethod + def verify(self, signature: bytes, data: bytes) -> None: + """ + Verify the signature. + """ + + @abc.abstractmethod + def __eq__(self, other: object) -> bool: + """ + Checks equality. + """ + + +Ed25519PublicKey.register(rust_openssl.ed25519.Ed25519PublicKey) + + +class Ed25519PrivateKey(metaclass=abc.ABCMeta): + @classmethod + def generate(cls) -> Ed25519PrivateKey: + from cryptography.hazmat.backends.openssl.backend import backend + + if not backend.ed25519_supported(): + raise UnsupportedAlgorithm( + "ed25519 is not supported by this version of OpenSSL.", + _Reasons.UNSUPPORTED_PUBLIC_KEY_ALGORITHM, + ) + + return rust_openssl.ed25519.generate_key() + + @classmethod + def from_private_bytes(cls, data: bytes) -> Ed25519PrivateKey: + from cryptography.hazmat.backends.openssl.backend import backend + + if not backend.ed25519_supported(): + raise UnsupportedAlgorithm( + "ed25519 is not supported by this version of OpenSSL.", + _Reasons.UNSUPPORTED_PUBLIC_KEY_ALGORITHM, + ) + + return rust_openssl.ed25519.from_private_bytes(data) + + @abc.abstractmethod + def public_key(self) -> Ed25519PublicKey: + """ + The Ed25519PublicKey derived from the private key. + """ + + @abc.abstractmethod + def private_bytes( + self, + encoding: _serialization.Encoding, + format: _serialization.PrivateFormat, + encryption_algorithm: _serialization.KeySerializationEncryption, + ) -> bytes: + """ + The serialized bytes of the private key. + """ + + @abc.abstractmethod + def private_bytes_raw(self) -> bytes: + """ + The raw bytes of the private key. + Equivalent to private_bytes(Raw, Raw, NoEncryption()). + """ + + @abc.abstractmethod + def sign(self, data: bytes) -> bytes: + """ + Signs the data. + """ + + +Ed25519PrivateKey.register(rust_openssl.ed25519.Ed25519PrivateKey) diff --git a/.venv/Lib/site-packages/cryptography/hazmat/primitives/asymmetric/ed448.py b/.venv/Lib/site-packages/cryptography/hazmat/primitives/asymmetric/ed448.py new file mode 100644 index 00000000..78c82c4a --- /dev/null +++ b/.venv/Lib/site-packages/cryptography/hazmat/primitives/asymmetric/ed448.py @@ -0,0 +1,118 @@ +# This file is dual licensed under the terms of the Apache License, Version +# 2.0, and the BSD License. See the LICENSE file in the root of this repository +# for complete details. + +from __future__ import annotations + +import abc + +from cryptography.exceptions import UnsupportedAlgorithm, _Reasons +from cryptography.hazmat.bindings._rust import openssl as rust_openssl +from cryptography.hazmat.primitives import _serialization + + +class Ed448PublicKey(metaclass=abc.ABCMeta): + @classmethod + def from_public_bytes(cls, data: bytes) -> Ed448PublicKey: + from cryptography.hazmat.backends.openssl.backend import backend + + if not backend.ed448_supported(): + raise UnsupportedAlgorithm( + "ed448 is not supported by this version of OpenSSL.", + _Reasons.UNSUPPORTED_PUBLIC_KEY_ALGORITHM, + ) + + return rust_openssl.ed448.from_public_bytes(data) + + @abc.abstractmethod + def public_bytes( + self, + encoding: _serialization.Encoding, + format: _serialization.PublicFormat, + ) -> bytes: + """ + The serialized bytes of the public key. + """ + + @abc.abstractmethod + def public_bytes_raw(self) -> bytes: + """ + The raw bytes of the public key. + Equivalent to public_bytes(Raw, Raw). + """ + + @abc.abstractmethod + def verify(self, signature: bytes, data: bytes) -> None: + """ + Verify the signature. + """ + + @abc.abstractmethod + def __eq__(self, other: object) -> bool: + """ + Checks equality. + """ + + +if hasattr(rust_openssl, "ed448"): + Ed448PublicKey.register(rust_openssl.ed448.Ed448PublicKey) + + +class Ed448PrivateKey(metaclass=abc.ABCMeta): + @classmethod + def generate(cls) -> Ed448PrivateKey: + from cryptography.hazmat.backends.openssl.backend import backend + + if not backend.ed448_supported(): + raise UnsupportedAlgorithm( + "ed448 is not supported by this version of OpenSSL.", + _Reasons.UNSUPPORTED_PUBLIC_KEY_ALGORITHM, + ) + + return rust_openssl.ed448.generate_key() + + @classmethod + def from_private_bytes(cls, data: bytes) -> Ed448PrivateKey: + from cryptography.hazmat.backends.openssl.backend import backend + + if not backend.ed448_supported(): + raise UnsupportedAlgorithm( + "ed448 is not supported by this version of OpenSSL.", + _Reasons.UNSUPPORTED_PUBLIC_KEY_ALGORITHM, + ) + + return rust_openssl.ed448.from_private_bytes(data) + + @abc.abstractmethod + def public_key(self) -> Ed448PublicKey: + """ + The Ed448PublicKey derived from the private key. + """ + + @abc.abstractmethod + def sign(self, data: bytes) -> bytes: + """ + Signs the data. + """ + + @abc.abstractmethod + def private_bytes( + self, + encoding: _serialization.Encoding, + format: _serialization.PrivateFormat, + encryption_algorithm: _serialization.KeySerializationEncryption, + ) -> bytes: + """ + The serialized bytes of the private key. + """ + + @abc.abstractmethod + def private_bytes_raw(self) -> bytes: + """ + The raw bytes of the private key. + Equivalent to private_bytes(Raw, Raw, NoEncryption()). + """ + + +if hasattr(rust_openssl, "x448"): + Ed448PrivateKey.register(rust_openssl.ed448.Ed448PrivateKey) diff --git a/.venv/Lib/site-packages/cryptography/hazmat/primitives/asymmetric/padding.py b/.venv/Lib/site-packages/cryptography/hazmat/primitives/asymmetric/padding.py new file mode 100644 index 00000000..b4babf44 --- /dev/null +++ b/.venv/Lib/site-packages/cryptography/hazmat/primitives/asymmetric/padding.py @@ -0,0 +1,113 @@ +# This file is dual licensed under the terms of the Apache License, Version +# 2.0, and the BSD License. See the LICENSE file in the root of this repository +# for complete details. + +from __future__ import annotations + +import abc + +from cryptography.hazmat.primitives import hashes +from cryptography.hazmat.primitives._asymmetric import ( + AsymmetricPadding as AsymmetricPadding, +) +from cryptography.hazmat.primitives.asymmetric import rsa + + +class PKCS1v15(AsymmetricPadding): + name = "EMSA-PKCS1-v1_5" + + +class _MaxLength: + "Sentinel value for `MAX_LENGTH`." + + +class _Auto: + "Sentinel value for `AUTO`." + + +class _DigestLength: + "Sentinel value for `DIGEST_LENGTH`." + + +class PSS(AsymmetricPadding): + MAX_LENGTH = _MaxLength() + AUTO = _Auto() + DIGEST_LENGTH = _DigestLength() + name = "EMSA-PSS" + _salt_length: int | _MaxLength | _Auto | _DigestLength + + def __init__( + self, + mgf: MGF, + salt_length: int | _MaxLength | _Auto | _DigestLength, + ) -> None: + self._mgf = mgf + + if not isinstance( + salt_length, (int, _MaxLength, _Auto, _DigestLength) + ): + raise TypeError( + "salt_length must be an integer, MAX_LENGTH, " + "DIGEST_LENGTH, or AUTO" + ) + + if isinstance(salt_length, int) and salt_length < 0: + raise ValueError("salt_length must be zero or greater.") + + self._salt_length = salt_length + + @property + def mgf(self) -> MGF: + return self._mgf + + +class OAEP(AsymmetricPadding): + name = "EME-OAEP" + + def __init__( + self, + mgf: MGF, + algorithm: hashes.HashAlgorithm, + label: bytes | None, + ): + if not isinstance(algorithm, hashes.HashAlgorithm): + raise TypeError("Expected instance of hashes.HashAlgorithm.") + + self._mgf = mgf + self._algorithm = algorithm + self._label = label + + @property + def algorithm(self) -> hashes.HashAlgorithm: + return self._algorithm + + @property + def mgf(self) -> MGF: + return self._mgf + + +class MGF(metaclass=abc.ABCMeta): + _algorithm: hashes.HashAlgorithm + + +class MGF1(MGF): + MAX_LENGTH = _MaxLength() + + def __init__(self, algorithm: hashes.HashAlgorithm): + if not isinstance(algorithm, hashes.HashAlgorithm): + raise TypeError("Expected instance of hashes.HashAlgorithm.") + + self._algorithm = algorithm + + +def calculate_max_pss_salt_length( + key: rsa.RSAPrivateKey | rsa.RSAPublicKey, + hash_algorithm: hashes.HashAlgorithm, +) -> int: + if not isinstance(key, (rsa.RSAPrivateKey, rsa.RSAPublicKey)): + raise TypeError("key must be an RSA public or private key") + # bit length - 1 per RFC 3447 + emlen = (key.key_size + 6) // 8 + salt_length = emlen - hash_algorithm.digest_size - 2 + assert salt_length >= 0 + return salt_length diff --git a/.venv/Lib/site-packages/cryptography/hazmat/primitives/asymmetric/rsa.py b/.venv/Lib/site-packages/cryptography/hazmat/primitives/asymmetric/rsa.py new file mode 100644 index 00000000..6420434d --- /dev/null +++ b/.venv/Lib/site-packages/cryptography/hazmat/primitives/asymmetric/rsa.py @@ -0,0 +1,239 @@ +# This file is dual licensed under the terms of the Apache License, Version +# 2.0, and the BSD License. See the LICENSE file in the root of this repository +# for complete details. + +from __future__ import annotations + +import abc +import typing +from math import gcd + +from cryptography.hazmat.bindings._rust import openssl as rust_openssl +from cryptography.hazmat.primitives import _serialization, hashes +from cryptography.hazmat.primitives._asymmetric import AsymmetricPadding +from cryptography.hazmat.primitives.asymmetric import utils as asym_utils + + +class RSAPrivateKey(metaclass=abc.ABCMeta): + @abc.abstractmethod + def decrypt(self, ciphertext: bytes, padding: AsymmetricPadding) -> bytes: + """ + Decrypts the provided ciphertext. + """ + + @property + @abc.abstractmethod + def key_size(self) -> int: + """ + The bit length of the public modulus. + """ + + @abc.abstractmethod + def public_key(self) -> RSAPublicKey: + """ + The RSAPublicKey associated with this private key. + """ + + @abc.abstractmethod + def sign( + self, + data: bytes, + padding: AsymmetricPadding, + algorithm: asym_utils.Prehashed | hashes.HashAlgorithm, + ) -> bytes: + """ + Signs the data. + """ + + @abc.abstractmethod + def private_numbers(self) -> RSAPrivateNumbers: + """ + Returns an RSAPrivateNumbers. + """ + + @abc.abstractmethod + def private_bytes( + self, + encoding: _serialization.Encoding, + format: _serialization.PrivateFormat, + encryption_algorithm: _serialization.KeySerializationEncryption, + ) -> bytes: + """ + Returns the key serialized as bytes. + """ + + +RSAPrivateKeyWithSerialization = RSAPrivateKey +RSAPrivateKey.register(rust_openssl.rsa.RSAPrivateKey) + + +class RSAPublicKey(metaclass=abc.ABCMeta): + @abc.abstractmethod + def encrypt(self, plaintext: bytes, padding: AsymmetricPadding) -> bytes: + """ + Encrypts the given plaintext. + """ + + @property + @abc.abstractmethod + def key_size(self) -> int: + """ + The bit length of the public modulus. + """ + + @abc.abstractmethod + def public_numbers(self) -> RSAPublicNumbers: + """ + Returns an RSAPublicNumbers + """ + + @abc.abstractmethod + def public_bytes( + self, + encoding: _serialization.Encoding, + format: _serialization.PublicFormat, + ) -> bytes: + """ + Returns the key serialized as bytes. + """ + + @abc.abstractmethod + def verify( + self, + signature: bytes, + data: bytes, + padding: AsymmetricPadding, + algorithm: asym_utils.Prehashed | hashes.HashAlgorithm, + ) -> None: + """ + Verifies the signature of the data. + """ + + @abc.abstractmethod + def recover_data_from_signature( + self, + signature: bytes, + padding: AsymmetricPadding, + algorithm: hashes.HashAlgorithm | None, + ) -> bytes: + """ + Recovers the original data from the signature. + """ + + @abc.abstractmethod + def __eq__(self, other: object) -> bool: + """ + Checks equality. + """ + + +RSAPublicKeyWithSerialization = RSAPublicKey +RSAPublicKey.register(rust_openssl.rsa.RSAPublicKey) + +RSAPrivateNumbers = rust_openssl.rsa.RSAPrivateNumbers +RSAPublicNumbers = rust_openssl.rsa.RSAPublicNumbers + + +def generate_private_key( + public_exponent: int, + key_size: int, + backend: typing.Any = None, +) -> RSAPrivateKey: + _verify_rsa_parameters(public_exponent, key_size) + return rust_openssl.rsa.generate_private_key(public_exponent, key_size) + + +def _verify_rsa_parameters(public_exponent: int, key_size: int) -> None: + if public_exponent not in (3, 65537): + raise ValueError( + "public_exponent must be either 3 (for legacy compatibility) or " + "65537. Almost everyone should choose 65537 here!" + ) + + if key_size < 512: + raise ValueError("key_size must be at least 512-bits.") + + +def _modinv(e: int, m: int) -> int: + """ + Modular Multiplicative Inverse. Returns x such that: (x*e) mod m == 1 + """ + x1, x2 = 1, 0 + a, b = e, m + while b > 0: + q, r = divmod(a, b) + xn = x1 - q * x2 + a, b, x1, x2 = b, r, x2, xn + return x1 % m + + +def rsa_crt_iqmp(p: int, q: int) -> int: + """ + Compute the CRT (q ** -1) % p value from RSA primes p and q. + """ + return _modinv(q, p) + + +def rsa_crt_dmp1(private_exponent: int, p: int) -> int: + """ + Compute the CRT private_exponent % (p - 1) value from the RSA + private_exponent (d) and p. + """ + return private_exponent % (p - 1) + + +def rsa_crt_dmq1(private_exponent: int, q: int) -> int: + """ + Compute the CRT private_exponent % (q - 1) value from the RSA + private_exponent (d) and q. + """ + return private_exponent % (q - 1) + + +# Controls the number of iterations rsa_recover_prime_factors will perform +# to obtain the prime factors. Each iteration increments by 2 so the actual +# maximum attempts is half this number. +_MAX_RECOVERY_ATTEMPTS = 1000 + + +def rsa_recover_prime_factors(n: int, e: int, d: int) -> tuple[int, int]: + """ + Compute factors p and q from the private exponent d. We assume that n has + no more than two factors. This function is adapted from code in PyCrypto. + """ + # See 8.2.2(i) in Handbook of Applied Cryptography. + ktot = d * e - 1 + # The quantity d*e-1 is a multiple of phi(n), even, + # and can be represented as t*2^s. + t = ktot + while t % 2 == 0: + t = t // 2 + # Cycle through all multiplicative inverses in Zn. + # The algorithm is non-deterministic, but there is a 50% chance + # any candidate a leads to successful factoring. + # See "Digitalized Signatures and Public Key Functions as Intractable + # as Factorization", M. Rabin, 1979 + spotted = False + a = 2 + while not spotted and a < _MAX_RECOVERY_ATTEMPTS: + k = t + # Cycle through all values a^{t*2^i}=a^k + while k < ktot: + cand = pow(a, k, n) + # Check if a^k is a non-trivial root of unity (mod n) + if cand != 1 and cand != (n - 1) and pow(cand, 2, n) == 1: + # We have found a number such that (cand-1)(cand+1)=0 (mod n). + # Either of the terms divides n. + p = gcd(cand + 1, n) + spotted = True + break + k *= 2 + # This value was not any good... let's try another! + a += 2 + if not spotted: + raise ValueError("Unable to compute factors p and q from exponent d.") + # Found ! + q, r = divmod(n, p) + assert r == 0 + p, q = sorted((p, q), reverse=True) + return (p, q) diff --git a/.venv/Lib/site-packages/cryptography/hazmat/primitives/asymmetric/types.py b/.venv/Lib/site-packages/cryptography/hazmat/primitives/asymmetric/types.py new file mode 100644 index 00000000..1fe4eaf5 --- /dev/null +++ b/.venv/Lib/site-packages/cryptography/hazmat/primitives/asymmetric/types.py @@ -0,0 +1,111 @@ +# This file is dual licensed under the terms of the Apache License, Version +# 2.0, and the BSD License. See the LICENSE file in the root of this repository +# for complete details. + +from __future__ import annotations + +import typing + +from cryptography import utils +from cryptography.hazmat.primitives.asymmetric import ( + dh, + dsa, + ec, + ed448, + ed25519, + rsa, + x448, + x25519, +) + +# Every asymmetric key type +PublicKeyTypes = typing.Union[ + dh.DHPublicKey, + dsa.DSAPublicKey, + rsa.RSAPublicKey, + ec.EllipticCurvePublicKey, + ed25519.Ed25519PublicKey, + ed448.Ed448PublicKey, + x25519.X25519PublicKey, + x448.X448PublicKey, +] +PUBLIC_KEY_TYPES = PublicKeyTypes +utils.deprecated( + PUBLIC_KEY_TYPES, + __name__, + "Use PublicKeyTypes instead", + utils.DeprecatedIn40, + name="PUBLIC_KEY_TYPES", +) +# Every asymmetric key type +PrivateKeyTypes = typing.Union[ + dh.DHPrivateKey, + ed25519.Ed25519PrivateKey, + ed448.Ed448PrivateKey, + rsa.RSAPrivateKey, + dsa.DSAPrivateKey, + ec.EllipticCurvePrivateKey, + x25519.X25519PrivateKey, + x448.X448PrivateKey, +] +PRIVATE_KEY_TYPES = PrivateKeyTypes +utils.deprecated( + PRIVATE_KEY_TYPES, + __name__, + "Use PrivateKeyTypes instead", + utils.DeprecatedIn40, + name="PRIVATE_KEY_TYPES", +) +# Just the key types we allow to be used for x509 signing. This mirrors +# the certificate public key types +CertificateIssuerPrivateKeyTypes = typing.Union[ + ed25519.Ed25519PrivateKey, + ed448.Ed448PrivateKey, + rsa.RSAPrivateKey, + dsa.DSAPrivateKey, + ec.EllipticCurvePrivateKey, +] +CERTIFICATE_PRIVATE_KEY_TYPES = CertificateIssuerPrivateKeyTypes +utils.deprecated( + CERTIFICATE_PRIVATE_KEY_TYPES, + __name__, + "Use CertificateIssuerPrivateKeyTypes instead", + utils.DeprecatedIn40, + name="CERTIFICATE_PRIVATE_KEY_TYPES", +) +# Just the key types we allow to be used for x509 signing. This mirrors +# the certificate private key types +CertificateIssuerPublicKeyTypes = typing.Union[ + dsa.DSAPublicKey, + rsa.RSAPublicKey, + ec.EllipticCurvePublicKey, + ed25519.Ed25519PublicKey, + ed448.Ed448PublicKey, +] +CERTIFICATE_ISSUER_PUBLIC_KEY_TYPES = CertificateIssuerPublicKeyTypes +utils.deprecated( + CERTIFICATE_ISSUER_PUBLIC_KEY_TYPES, + __name__, + "Use CertificateIssuerPublicKeyTypes instead", + utils.DeprecatedIn40, + name="CERTIFICATE_ISSUER_PUBLIC_KEY_TYPES", +) +# This type removes DHPublicKey. x448/x25519 can be a public key +# but cannot be used in signing so they are allowed here. +CertificatePublicKeyTypes = typing.Union[ + dsa.DSAPublicKey, + rsa.RSAPublicKey, + ec.EllipticCurvePublicKey, + ed25519.Ed25519PublicKey, + ed448.Ed448PublicKey, + x25519.X25519PublicKey, + x448.X448PublicKey, +] +CERTIFICATE_PUBLIC_KEY_TYPES = CertificatePublicKeyTypes +utils.deprecated( + CERTIFICATE_PUBLIC_KEY_TYPES, + __name__, + "Use CertificatePublicKeyTypes instead", + utils.DeprecatedIn40, + name="CERTIFICATE_PUBLIC_KEY_TYPES", +) diff --git a/.venv/Lib/site-packages/cryptography/hazmat/primitives/asymmetric/utils.py b/.venv/Lib/site-packages/cryptography/hazmat/primitives/asymmetric/utils.py new file mode 100644 index 00000000..826b9567 --- /dev/null +++ b/.venv/Lib/site-packages/cryptography/hazmat/primitives/asymmetric/utils.py @@ -0,0 +1,24 @@ +# This file is dual licensed under the terms of the Apache License, Version +# 2.0, and the BSD License. See the LICENSE file in the root of this repository +# for complete details. + +from __future__ import annotations + +from cryptography.hazmat.bindings._rust import asn1 +from cryptography.hazmat.primitives import hashes + +decode_dss_signature = asn1.decode_dss_signature +encode_dss_signature = asn1.encode_dss_signature + + +class Prehashed: + def __init__(self, algorithm: hashes.HashAlgorithm): + if not isinstance(algorithm, hashes.HashAlgorithm): + raise TypeError("Expected instance of HashAlgorithm.") + + self._algorithm = algorithm + self._digest_size = algorithm.digest_size + + @property + def digest_size(self) -> int: + return self._digest_size diff --git a/.venv/Lib/site-packages/cryptography/hazmat/primitives/asymmetric/x25519.py b/.venv/Lib/site-packages/cryptography/hazmat/primitives/asymmetric/x25519.py new file mode 100644 index 00000000..0cfa36e3 --- /dev/null +++ b/.venv/Lib/site-packages/cryptography/hazmat/primitives/asymmetric/x25519.py @@ -0,0 +1,109 @@ +# This file is dual licensed under the terms of the Apache License, Version +# 2.0, and the BSD License. See the LICENSE file in the root of this repository +# for complete details. + +from __future__ import annotations + +import abc + +from cryptography.exceptions import UnsupportedAlgorithm, _Reasons +from cryptography.hazmat.bindings._rust import openssl as rust_openssl +from cryptography.hazmat.primitives import _serialization + + +class X25519PublicKey(metaclass=abc.ABCMeta): + @classmethod + def from_public_bytes(cls, data: bytes) -> X25519PublicKey: + from cryptography.hazmat.backends.openssl.backend import backend + + if not backend.x25519_supported(): + raise UnsupportedAlgorithm( + "X25519 is not supported by this version of OpenSSL.", + _Reasons.UNSUPPORTED_EXCHANGE_ALGORITHM, + ) + + return rust_openssl.x25519.from_public_bytes(data) + + @abc.abstractmethod + def public_bytes( + self, + encoding: _serialization.Encoding, + format: _serialization.PublicFormat, + ) -> bytes: + """ + The serialized bytes of the public key. + """ + + @abc.abstractmethod + def public_bytes_raw(self) -> bytes: + """ + The raw bytes of the public key. + Equivalent to public_bytes(Raw, Raw). + """ + + @abc.abstractmethod + def __eq__(self, other: object) -> bool: + """ + Checks equality. + """ + + +X25519PublicKey.register(rust_openssl.x25519.X25519PublicKey) + + +class X25519PrivateKey(metaclass=abc.ABCMeta): + @classmethod + def generate(cls) -> X25519PrivateKey: + from cryptography.hazmat.backends.openssl.backend import backend + + if not backend.x25519_supported(): + raise UnsupportedAlgorithm( + "X25519 is not supported by this version of OpenSSL.", + _Reasons.UNSUPPORTED_EXCHANGE_ALGORITHM, + ) + return rust_openssl.x25519.generate_key() + + @classmethod + def from_private_bytes(cls, data: bytes) -> X25519PrivateKey: + from cryptography.hazmat.backends.openssl.backend import backend + + if not backend.x25519_supported(): + raise UnsupportedAlgorithm( + "X25519 is not supported by this version of OpenSSL.", + _Reasons.UNSUPPORTED_EXCHANGE_ALGORITHM, + ) + + return rust_openssl.x25519.from_private_bytes(data) + + @abc.abstractmethod + def public_key(self) -> X25519PublicKey: + """ + Returns the public key associated with this private key + """ + + @abc.abstractmethod + def private_bytes( + self, + encoding: _serialization.Encoding, + format: _serialization.PrivateFormat, + encryption_algorithm: _serialization.KeySerializationEncryption, + ) -> bytes: + """ + The serialized bytes of the private key. + """ + + @abc.abstractmethod + def private_bytes_raw(self) -> bytes: + """ + The raw bytes of the private key. + Equivalent to private_bytes(Raw, Raw, NoEncryption()). + """ + + @abc.abstractmethod + def exchange(self, peer_public_key: X25519PublicKey) -> bytes: + """ + Performs a key exchange operation using the provided peer's public key. + """ + + +X25519PrivateKey.register(rust_openssl.x25519.X25519PrivateKey) diff --git a/.venv/Lib/site-packages/cryptography/hazmat/primitives/asymmetric/x448.py b/.venv/Lib/site-packages/cryptography/hazmat/primitives/asymmetric/x448.py new file mode 100644 index 00000000..86086ab4 --- /dev/null +++ b/.venv/Lib/site-packages/cryptography/hazmat/primitives/asymmetric/x448.py @@ -0,0 +1,112 @@ +# This file is dual licensed under the terms of the Apache License, Version +# 2.0, and the BSD License. See the LICENSE file in the root of this repository +# for complete details. + +from __future__ import annotations + +import abc + +from cryptography.exceptions import UnsupportedAlgorithm, _Reasons +from cryptography.hazmat.bindings._rust import openssl as rust_openssl +from cryptography.hazmat.primitives import _serialization + + +class X448PublicKey(metaclass=abc.ABCMeta): + @classmethod + def from_public_bytes(cls, data: bytes) -> X448PublicKey: + from cryptography.hazmat.backends.openssl.backend import backend + + if not backend.x448_supported(): + raise UnsupportedAlgorithm( + "X448 is not supported by this version of OpenSSL.", + _Reasons.UNSUPPORTED_EXCHANGE_ALGORITHM, + ) + + return rust_openssl.x448.from_public_bytes(data) + + @abc.abstractmethod + def public_bytes( + self, + encoding: _serialization.Encoding, + format: _serialization.PublicFormat, + ) -> bytes: + """ + The serialized bytes of the public key. + """ + + @abc.abstractmethod + def public_bytes_raw(self) -> bytes: + """ + The raw bytes of the public key. + Equivalent to public_bytes(Raw, Raw). + """ + + @abc.abstractmethod + def __eq__(self, other: object) -> bool: + """ + Checks equality. + """ + + +if hasattr(rust_openssl, "x448"): + X448PublicKey.register(rust_openssl.x448.X448PublicKey) + + +class X448PrivateKey(metaclass=abc.ABCMeta): + @classmethod + def generate(cls) -> X448PrivateKey: + from cryptography.hazmat.backends.openssl.backend import backend + + if not backend.x448_supported(): + raise UnsupportedAlgorithm( + "X448 is not supported by this version of OpenSSL.", + _Reasons.UNSUPPORTED_EXCHANGE_ALGORITHM, + ) + + return rust_openssl.x448.generate_key() + + @classmethod + def from_private_bytes(cls, data: bytes) -> X448PrivateKey: + from cryptography.hazmat.backends.openssl.backend import backend + + if not backend.x448_supported(): + raise UnsupportedAlgorithm( + "X448 is not supported by this version of OpenSSL.", + _Reasons.UNSUPPORTED_EXCHANGE_ALGORITHM, + ) + + return rust_openssl.x448.from_private_bytes(data) + + @abc.abstractmethod + def public_key(self) -> X448PublicKey: + """ + Returns the public key associated with this private key + """ + + @abc.abstractmethod + def private_bytes( + self, + encoding: _serialization.Encoding, + format: _serialization.PrivateFormat, + encryption_algorithm: _serialization.KeySerializationEncryption, + ) -> bytes: + """ + The serialized bytes of the private key. + """ + + @abc.abstractmethod + def private_bytes_raw(self) -> bytes: + """ + The raw bytes of the private key. + Equivalent to private_bytes(Raw, Raw, NoEncryption()). + """ + + @abc.abstractmethod + def exchange(self, peer_public_key: X448PublicKey) -> bytes: + """ + Performs a key exchange operation using the provided peer's public key. + """ + + +if hasattr(rust_openssl, "x448"): + X448PrivateKey.register(rust_openssl.x448.X448PrivateKey) diff --git a/.venv/Lib/site-packages/cryptography/hazmat/primitives/ciphers/__init__.py b/.venv/Lib/site-packages/cryptography/hazmat/primitives/ciphers/__init__.py new file mode 100644 index 00000000..cc88fbf2 --- /dev/null +++ b/.venv/Lib/site-packages/cryptography/hazmat/primitives/ciphers/__init__.py @@ -0,0 +1,27 @@ +# This file is dual licensed under the terms of the Apache License, Version +# 2.0, and the BSD License. See the LICENSE file in the root of this repository +# for complete details. + +from __future__ import annotations + +from cryptography.hazmat.primitives._cipheralgorithm import ( + BlockCipherAlgorithm, + CipherAlgorithm, +) +from cryptography.hazmat.primitives.ciphers.base import ( + AEADCipherContext, + AEADDecryptionContext, + AEADEncryptionContext, + Cipher, + CipherContext, +) + +__all__ = [ + "Cipher", + "CipherAlgorithm", + "BlockCipherAlgorithm", + "CipherContext", + "AEADCipherContext", + "AEADDecryptionContext", + "AEADEncryptionContext", +] diff --git a/.venv/Lib/site-packages/cryptography/hazmat/primitives/ciphers/__pycache__/__init__.cpython-311.pyc b/.venv/Lib/site-packages/cryptography/hazmat/primitives/ciphers/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 00000000..edac06b5 Binary files /dev/null and b/.venv/Lib/site-packages/cryptography/hazmat/primitives/ciphers/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/cryptography/hazmat/primitives/ciphers/__pycache__/aead.cpython-311.pyc b/.venv/Lib/site-packages/cryptography/hazmat/primitives/ciphers/__pycache__/aead.cpython-311.pyc new file mode 100644 index 00000000..c97a2e97 Binary files /dev/null and b/.venv/Lib/site-packages/cryptography/hazmat/primitives/ciphers/__pycache__/aead.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/cryptography/hazmat/primitives/ciphers/__pycache__/algorithms.cpython-311.pyc b/.venv/Lib/site-packages/cryptography/hazmat/primitives/ciphers/__pycache__/algorithms.cpython-311.pyc new file mode 100644 index 00000000..ab81c9e7 Binary files /dev/null and b/.venv/Lib/site-packages/cryptography/hazmat/primitives/ciphers/__pycache__/algorithms.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/cryptography/hazmat/primitives/ciphers/__pycache__/base.cpython-311.pyc b/.venv/Lib/site-packages/cryptography/hazmat/primitives/ciphers/__pycache__/base.cpython-311.pyc new file mode 100644 index 00000000..ab7a6863 Binary files /dev/null and b/.venv/Lib/site-packages/cryptography/hazmat/primitives/ciphers/__pycache__/base.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/cryptography/hazmat/primitives/ciphers/__pycache__/modes.cpython-311.pyc b/.venv/Lib/site-packages/cryptography/hazmat/primitives/ciphers/__pycache__/modes.cpython-311.pyc new file mode 100644 index 00000000..b2c209c1 Binary files /dev/null and b/.venv/Lib/site-packages/cryptography/hazmat/primitives/ciphers/__pycache__/modes.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/cryptography/hazmat/primitives/ciphers/aead.py b/.venv/Lib/site-packages/cryptography/hazmat/primitives/ciphers/aead.py new file mode 100644 index 00000000..40f1b9b7 --- /dev/null +++ b/.venv/Lib/site-packages/cryptography/hazmat/primitives/ciphers/aead.py @@ -0,0 +1,174 @@ +# This file is dual licensed under the terms of the Apache License, Version +# 2.0, and the BSD License. See the LICENSE file in the root of this repository +# for complete details. + +from __future__ import annotations + +import os + +from cryptography import exceptions, utils +from cryptography.hazmat.backends.openssl import aead +from cryptography.hazmat.backends.openssl.backend import backend +from cryptography.hazmat.bindings._rust import openssl as rust_openssl + +__all__ = [ + "ChaCha20Poly1305", + "AESCCM", + "AESGCM", + "AESGCMSIV", + "AESOCB3", + "AESSIV", +] + +ChaCha20Poly1305 = rust_openssl.aead.ChaCha20Poly1305 +AESSIV = rust_openssl.aead.AESSIV +AESOCB3 = rust_openssl.aead.AESOCB3 +AESGCMSIV = rust_openssl.aead.AESGCMSIV + + +class AESCCM: + _MAX_SIZE = 2**31 - 1 + + def __init__(self, key: bytes, tag_length: int = 16): + utils._check_byteslike("key", key) + if len(key) not in (16, 24, 32): + raise ValueError("AESCCM key must be 128, 192, or 256 bits.") + + self._key = key + if not isinstance(tag_length, int): + raise TypeError("tag_length must be an integer") + + if tag_length not in (4, 6, 8, 10, 12, 14, 16): + raise ValueError("Invalid tag_length") + + self._tag_length = tag_length + + if not backend.aead_cipher_supported(self): + raise exceptions.UnsupportedAlgorithm( + "AESCCM is not supported by this version of OpenSSL", + exceptions._Reasons.UNSUPPORTED_CIPHER, + ) + + @classmethod + def generate_key(cls, bit_length: int) -> bytes: + if not isinstance(bit_length, int): + raise TypeError("bit_length must be an integer") + + if bit_length not in (128, 192, 256): + raise ValueError("bit_length must be 128, 192, or 256") + + return os.urandom(bit_length // 8) + + def encrypt( + self, + nonce: bytes, + data: bytes, + associated_data: bytes | None, + ) -> bytes: + if associated_data is None: + associated_data = b"" + + if len(data) > self._MAX_SIZE or len(associated_data) > self._MAX_SIZE: + # This is OverflowError to match what cffi would raise + raise OverflowError( + "Data or associated data too long. Max 2**31 - 1 bytes" + ) + + self._check_params(nonce, data, associated_data) + self._validate_lengths(nonce, len(data)) + return aead._encrypt( + backend, self, nonce, data, [associated_data], self._tag_length + ) + + def decrypt( + self, + nonce: bytes, + data: bytes, + associated_data: bytes | None, + ) -> bytes: + if associated_data is None: + associated_data = b"" + + self._check_params(nonce, data, associated_data) + return aead._decrypt( + backend, self, nonce, data, [associated_data], self._tag_length + ) + + def _validate_lengths(self, nonce: bytes, data_len: int) -> None: + # For information about computing this, see + # https://tools.ietf.org/html/rfc3610#section-2.1 + l_val = 15 - len(nonce) + if 2 ** (8 * l_val) < data_len: + raise ValueError("Data too long for nonce") + + def _check_params( + self, nonce: bytes, data: bytes, associated_data: bytes + ) -> None: + utils._check_byteslike("nonce", nonce) + utils._check_byteslike("data", data) + utils._check_byteslike("associated_data", associated_data) + if not 7 <= len(nonce) <= 13: + raise ValueError("Nonce must be between 7 and 13 bytes") + + +class AESGCM: + _MAX_SIZE = 2**31 - 1 + + def __init__(self, key: bytes): + utils._check_byteslike("key", key) + if len(key) not in (16, 24, 32): + raise ValueError("AESGCM key must be 128, 192, or 256 bits.") + + self._key = key + + @classmethod + def generate_key(cls, bit_length: int) -> bytes: + if not isinstance(bit_length, int): + raise TypeError("bit_length must be an integer") + + if bit_length not in (128, 192, 256): + raise ValueError("bit_length must be 128, 192, or 256") + + return os.urandom(bit_length // 8) + + def encrypt( + self, + nonce: bytes, + data: bytes, + associated_data: bytes | None, + ) -> bytes: + if associated_data is None: + associated_data = b"" + + if len(data) > self._MAX_SIZE or len(associated_data) > self._MAX_SIZE: + # This is OverflowError to match what cffi would raise + raise OverflowError( + "Data or associated data too long. Max 2**31 - 1 bytes" + ) + + self._check_params(nonce, data, associated_data) + return aead._encrypt(backend, self, nonce, data, [associated_data], 16) + + def decrypt( + self, + nonce: bytes, + data: bytes, + associated_data: bytes | None, + ) -> bytes: + if associated_data is None: + associated_data = b"" + + self._check_params(nonce, data, associated_data) + return aead._decrypt(backend, self, nonce, data, [associated_data], 16) + + def _check_params( + self, + nonce: bytes, + data: bytes, + associated_data: bytes, + ) -> None: + utils._check_byteslike("nonce", nonce) + utils._check_byteslike("data", data) + utils._check_byteslike("associated_data", associated_data) + if len(nonce) < 8 or len(nonce) > 128: + raise ValueError("Nonce must be between 8 and 128 bytes") diff --git a/.venv/Lib/site-packages/cryptography/hazmat/primitives/ciphers/algorithms.py b/.venv/Lib/site-packages/cryptography/hazmat/primitives/ciphers/algorithms.py new file mode 100644 index 00000000..000bdcba --- /dev/null +++ b/.venv/Lib/site-packages/cryptography/hazmat/primitives/ciphers/algorithms.py @@ -0,0 +1,226 @@ +# This file is dual licensed under the terms of the Apache License, Version +# 2.0, and the BSD License. See the LICENSE file in the root of this repository +# for complete details. + +from __future__ import annotations + +from cryptography import utils +from cryptography.hazmat.primitives.ciphers import ( + BlockCipherAlgorithm, + CipherAlgorithm, +) + + +def _verify_key_size(algorithm: CipherAlgorithm, key: bytes) -> bytes: + # Verify that the key is instance of bytes + utils._check_byteslike("key", key) + + # Verify that the key size matches the expected key size + if len(key) * 8 not in algorithm.key_sizes: + raise ValueError( + f"Invalid key size ({len(key) * 8}) for {algorithm.name}." + ) + return key + + +class AES(BlockCipherAlgorithm): + name = "AES" + block_size = 128 + # 512 added to support AES-256-XTS, which uses 512-bit keys + key_sizes = frozenset([128, 192, 256, 512]) + + def __init__(self, key: bytes): + self.key = _verify_key_size(self, key) + + @property + def key_size(self) -> int: + return len(self.key) * 8 + + +class AES128(BlockCipherAlgorithm): + name = "AES" + block_size = 128 + key_sizes = frozenset([128]) + key_size = 128 + + def __init__(self, key: bytes): + self.key = _verify_key_size(self, key) + + +class AES256(BlockCipherAlgorithm): + name = "AES" + block_size = 128 + key_sizes = frozenset([256]) + key_size = 256 + + def __init__(self, key: bytes): + self.key = _verify_key_size(self, key) + + +class Camellia(BlockCipherAlgorithm): + name = "camellia" + block_size = 128 + key_sizes = frozenset([128, 192, 256]) + + def __init__(self, key: bytes): + self.key = _verify_key_size(self, key) + + @property + def key_size(self) -> int: + return len(self.key) * 8 + + +class TripleDES(BlockCipherAlgorithm): + name = "3DES" + block_size = 64 + key_sizes = frozenset([64, 128, 192]) + + def __init__(self, key: bytes): + if len(key) == 8: + key += key + key + elif len(key) == 16: + key += key[:8] + self.key = _verify_key_size(self, key) + + @property + def key_size(self) -> int: + return len(self.key) * 8 + + +class Blowfish(BlockCipherAlgorithm): + name = "Blowfish" + block_size = 64 + key_sizes = frozenset(range(32, 449, 8)) + + def __init__(self, key: bytes): + self.key = _verify_key_size(self, key) + + @property + def key_size(self) -> int: + return len(self.key) * 8 + + +_BlowfishInternal = Blowfish +utils.deprecated( + Blowfish, + __name__, + "Blowfish has been deprecated and will be removed in a future release", + utils.DeprecatedIn37, + name="Blowfish", +) + + +class CAST5(BlockCipherAlgorithm): + name = "CAST5" + block_size = 64 + key_sizes = frozenset(range(40, 129, 8)) + + def __init__(self, key: bytes): + self.key = _verify_key_size(self, key) + + @property + def key_size(self) -> int: + return len(self.key) * 8 + + +_CAST5Internal = CAST5 +utils.deprecated( + CAST5, + __name__, + "CAST5 has been deprecated and will be removed in a future release", + utils.DeprecatedIn37, + name="CAST5", +) + + +class ARC4(CipherAlgorithm): + name = "RC4" + key_sizes = frozenset([40, 56, 64, 80, 128, 160, 192, 256]) + + def __init__(self, key: bytes): + self.key = _verify_key_size(self, key) + + @property + def key_size(self) -> int: + return len(self.key) * 8 + + +class IDEA(BlockCipherAlgorithm): + name = "IDEA" + block_size = 64 + key_sizes = frozenset([128]) + + def __init__(self, key: bytes): + self.key = _verify_key_size(self, key) + + @property + def key_size(self) -> int: + return len(self.key) * 8 + + +_IDEAInternal = IDEA +utils.deprecated( + IDEA, + __name__, + "IDEA has been deprecated and will be removed in a future release", + utils.DeprecatedIn37, + name="IDEA", +) + + +class SEED(BlockCipherAlgorithm): + name = "SEED" + block_size = 128 + key_sizes = frozenset([128]) + + def __init__(self, key: bytes): + self.key = _verify_key_size(self, key) + + @property + def key_size(self) -> int: + return len(self.key) * 8 + + +_SEEDInternal = SEED +utils.deprecated( + SEED, + __name__, + "SEED has been deprecated and will be removed in a future release", + utils.DeprecatedIn37, + name="SEED", +) + + +class ChaCha20(CipherAlgorithm): + name = "ChaCha20" + key_sizes = frozenset([256]) + + def __init__(self, key: bytes, nonce: bytes): + self.key = _verify_key_size(self, key) + utils._check_byteslike("nonce", nonce) + + if len(nonce) != 16: + raise ValueError("nonce must be 128-bits (16 bytes)") + + self._nonce = nonce + + @property + def nonce(self) -> bytes: + return self._nonce + + @property + def key_size(self) -> int: + return len(self.key) * 8 + + +class SM4(BlockCipherAlgorithm): + name = "SM4" + block_size = 128 + key_sizes = frozenset([128]) + + def __init__(self, key: bytes): + self.key = _verify_key_size(self, key) + + @property + def key_size(self) -> int: + return len(self.key) * 8 diff --git a/.venv/Lib/site-packages/cryptography/hazmat/primitives/ciphers/base.py b/.venv/Lib/site-packages/cryptography/hazmat/primitives/ciphers/base.py new file mode 100644 index 00000000..2082df66 --- /dev/null +++ b/.venv/Lib/site-packages/cryptography/hazmat/primitives/ciphers/base.py @@ -0,0 +1,272 @@ +# This file is dual licensed under the terms of the Apache License, Version +# 2.0, and the BSD License. See the LICENSE file in the root of this repository +# for complete details. + +from __future__ import annotations + +import abc +import typing + +from cryptography.exceptions import ( + AlreadyFinalized, + AlreadyUpdated, + NotYetFinalized, +) +from cryptography.hazmat.primitives._cipheralgorithm import CipherAlgorithm +from cryptography.hazmat.primitives.ciphers import modes + +if typing.TYPE_CHECKING: + from cryptography.hazmat.backends.openssl.ciphers import ( + _CipherContext as _BackendCipherContext, + ) + + +class CipherContext(metaclass=abc.ABCMeta): + @abc.abstractmethod + def update(self, data: bytes) -> bytes: + """ + Processes the provided bytes through the cipher and returns the results + as bytes. + """ + + @abc.abstractmethod + def update_into(self, data: bytes, buf: bytes) -> int: + """ + Processes the provided bytes and writes the resulting data into the + provided buffer. Returns the number of bytes written. + """ + + @abc.abstractmethod + def finalize(self) -> bytes: + """ + Returns the results of processing the final block as bytes. + """ + + +class AEADCipherContext(CipherContext, metaclass=abc.ABCMeta): + @abc.abstractmethod + def authenticate_additional_data(self, data: bytes) -> None: + """ + Authenticates the provided bytes. + """ + + +class AEADDecryptionContext(AEADCipherContext, metaclass=abc.ABCMeta): + @abc.abstractmethod + def finalize_with_tag(self, tag: bytes) -> bytes: + """ + Returns the results of processing the final block as bytes and allows + delayed passing of the authentication tag. + """ + + +class AEADEncryptionContext(AEADCipherContext, metaclass=abc.ABCMeta): + @property + @abc.abstractmethod + def tag(self) -> bytes: + """ + Returns tag bytes. This is only available after encryption is + finalized. + """ + + +Mode = typing.TypeVar( + "Mode", bound=typing.Optional[modes.Mode], covariant=True +) + + +class Cipher(typing.Generic[Mode]): + def __init__( + self, + algorithm: CipherAlgorithm, + mode: Mode, + backend: typing.Any = None, + ) -> None: + if not isinstance(algorithm, CipherAlgorithm): + raise TypeError("Expected interface of CipherAlgorithm.") + + if mode is not None: + # mypy needs this assert to narrow the type from our generic + # type. Maybe it won't some time in the future. + assert isinstance(mode, modes.Mode) + mode.validate_for_algorithm(algorithm) + + self.algorithm = algorithm + self.mode = mode + + @typing.overload + def encryptor( + self: Cipher[modes.ModeWithAuthenticationTag], + ) -> AEADEncryptionContext: + ... + + @typing.overload + def encryptor( + self: _CIPHER_TYPE, + ) -> CipherContext: + ... + + def encryptor(self): + if isinstance(self.mode, modes.ModeWithAuthenticationTag): + if self.mode.tag is not None: + raise ValueError( + "Authentication tag must be None when encrypting." + ) + from cryptography.hazmat.backends.openssl.backend import backend + + ctx = backend.create_symmetric_encryption_ctx( + self.algorithm, self.mode + ) + return self._wrap_ctx(ctx, encrypt=True) + + @typing.overload + def decryptor( + self: Cipher[modes.ModeWithAuthenticationTag], + ) -> AEADDecryptionContext: + ... + + @typing.overload + def decryptor( + self: _CIPHER_TYPE, + ) -> CipherContext: + ... + + def decryptor(self): + from cryptography.hazmat.backends.openssl.backend import backend + + ctx = backend.create_symmetric_decryption_ctx( + self.algorithm, self.mode + ) + return self._wrap_ctx(ctx, encrypt=False) + + def _wrap_ctx( + self, ctx: _BackendCipherContext, encrypt: bool + ) -> AEADEncryptionContext | AEADDecryptionContext | CipherContext: + if isinstance(self.mode, modes.ModeWithAuthenticationTag): + if encrypt: + return _AEADEncryptionContext(ctx) + else: + return _AEADDecryptionContext(ctx) + else: + return _CipherContext(ctx) + + +_CIPHER_TYPE = Cipher[ + typing.Union[ + modes.ModeWithNonce, + modes.ModeWithTweak, + None, + modes.ECB, + modes.ModeWithInitializationVector, + ] +] + + +class _CipherContext(CipherContext): + _ctx: _BackendCipherContext | None + + def __init__(self, ctx: _BackendCipherContext) -> None: + self._ctx = ctx + + def update(self, data: bytes) -> bytes: + if self._ctx is None: + raise AlreadyFinalized("Context was already finalized.") + return self._ctx.update(data) + + def update_into(self, data: bytes, buf: bytes) -> int: + if self._ctx is None: + raise AlreadyFinalized("Context was already finalized.") + return self._ctx.update_into(data, buf) + + def finalize(self) -> bytes: + if self._ctx is None: + raise AlreadyFinalized("Context was already finalized.") + data = self._ctx.finalize() + self._ctx = None + return data + + +class _AEADCipherContext(AEADCipherContext): + _ctx: _BackendCipherContext | None + _tag: bytes | None + + def __init__(self, ctx: _BackendCipherContext) -> None: + self._ctx = ctx + self._bytes_processed = 0 + self._aad_bytes_processed = 0 + self._tag = None + self._updated = False + + def _check_limit(self, data_size: int) -> None: + if self._ctx is None: + raise AlreadyFinalized("Context was already finalized.") + self._updated = True + self._bytes_processed += data_size + if self._bytes_processed > self._ctx._mode._MAX_ENCRYPTED_BYTES: + raise ValueError( + "{} has a maximum encrypted byte limit of {}".format( + self._ctx._mode.name, self._ctx._mode._MAX_ENCRYPTED_BYTES + ) + ) + + def update(self, data: bytes) -> bytes: + self._check_limit(len(data)) + # mypy needs this assert even though _check_limit already checked + assert self._ctx is not None + return self._ctx.update(data) + + def update_into(self, data: bytes, buf: bytes) -> int: + self._check_limit(len(data)) + # mypy needs this assert even though _check_limit already checked + assert self._ctx is not None + return self._ctx.update_into(data, buf) + + def finalize(self) -> bytes: + if self._ctx is None: + raise AlreadyFinalized("Context was already finalized.") + data = self._ctx.finalize() + self._tag = self._ctx.tag + self._ctx = None + return data + + def authenticate_additional_data(self, data: bytes) -> None: + if self._ctx is None: + raise AlreadyFinalized("Context was already finalized.") + if self._updated: + raise AlreadyUpdated("Update has been called on this context.") + + self._aad_bytes_processed += len(data) + if self._aad_bytes_processed > self._ctx._mode._MAX_AAD_BYTES: + raise ValueError( + "{} has a maximum AAD byte limit of {}".format( + self._ctx._mode.name, self._ctx._mode._MAX_AAD_BYTES + ) + ) + + self._ctx.authenticate_additional_data(data) + + +class _AEADDecryptionContext(_AEADCipherContext, AEADDecryptionContext): + def finalize_with_tag(self, tag: bytes) -> bytes: + if self._ctx is None: + raise AlreadyFinalized("Context was already finalized.") + if self._ctx._tag is not None: + raise ValueError( + "tag provided both in mode and in call with finalize_with_tag:" + " tag should only be provided once" + ) + data = self._ctx.finalize_with_tag(tag) + self._tag = self._ctx.tag + self._ctx = None + return data + + +class _AEADEncryptionContext(_AEADCipherContext, AEADEncryptionContext): + @property + def tag(self) -> bytes: + if self._ctx is not None: + raise NotYetFinalized( + "You must finalize encryption before " "getting the tag." + ) + assert self._tag is not None + return self._tag diff --git a/.venv/Lib/site-packages/cryptography/hazmat/primitives/ciphers/modes.py b/.venv/Lib/site-packages/cryptography/hazmat/primitives/ciphers/modes.py new file mode 100644 index 00000000..712ccd3f --- /dev/null +++ b/.venv/Lib/site-packages/cryptography/hazmat/primitives/ciphers/modes.py @@ -0,0 +1,273 @@ +# This file is dual licensed under the terms of the Apache License, Version +# 2.0, and the BSD License. See the LICENSE file in the root of this repository +# for complete details. + +from __future__ import annotations + +import abc + +from cryptography import utils +from cryptography.exceptions import UnsupportedAlgorithm, _Reasons +from cryptography.hazmat.primitives._cipheralgorithm import ( + BlockCipherAlgorithm, + CipherAlgorithm, +) +from cryptography.hazmat.primitives.ciphers import algorithms + + +class Mode(metaclass=abc.ABCMeta): + @property + @abc.abstractmethod + def name(self) -> str: + """ + A string naming this mode (e.g. "ECB", "CBC"). + """ + + @abc.abstractmethod + def validate_for_algorithm(self, algorithm: CipherAlgorithm) -> None: + """ + Checks that all the necessary invariants of this (mode, algorithm) + combination are met. + """ + + +class ModeWithInitializationVector(Mode, metaclass=abc.ABCMeta): + @property + @abc.abstractmethod + def initialization_vector(self) -> bytes: + """ + The value of the initialization vector for this mode as bytes. + """ + + +class ModeWithTweak(Mode, metaclass=abc.ABCMeta): + @property + @abc.abstractmethod + def tweak(self) -> bytes: + """ + The value of the tweak for this mode as bytes. + """ + + +class ModeWithNonce(Mode, metaclass=abc.ABCMeta): + @property + @abc.abstractmethod + def nonce(self) -> bytes: + """ + The value of the nonce for this mode as bytes. + """ + + +class ModeWithAuthenticationTag(Mode, metaclass=abc.ABCMeta): + @property + @abc.abstractmethod + def tag(self) -> bytes | None: + """ + The value of the tag supplied to the constructor of this mode. + """ + + +def _check_aes_key_length(self: Mode, algorithm: CipherAlgorithm) -> None: + if algorithm.key_size > 256 and algorithm.name == "AES": + raise ValueError( + "Only 128, 192, and 256 bit keys are allowed for this AES mode" + ) + + +def _check_iv_length( + self: ModeWithInitializationVector, algorithm: BlockCipherAlgorithm +) -> None: + if len(self.initialization_vector) * 8 != algorithm.block_size: + raise ValueError( + "Invalid IV size ({}) for {}.".format( + len(self.initialization_vector), self.name + ) + ) + + +def _check_nonce_length( + nonce: bytes, name: str, algorithm: CipherAlgorithm +) -> None: + if not isinstance(algorithm, BlockCipherAlgorithm): + raise UnsupportedAlgorithm( + f"{name} requires a block cipher algorithm", + _Reasons.UNSUPPORTED_CIPHER, + ) + if len(nonce) * 8 != algorithm.block_size: + raise ValueError(f"Invalid nonce size ({len(nonce)}) for {name}.") + + +def _check_iv_and_key_length( + self: ModeWithInitializationVector, algorithm: CipherAlgorithm +) -> None: + if not isinstance(algorithm, BlockCipherAlgorithm): + raise UnsupportedAlgorithm( + f"{self} requires a block cipher algorithm", + _Reasons.UNSUPPORTED_CIPHER, + ) + _check_aes_key_length(self, algorithm) + _check_iv_length(self, algorithm) + + +class CBC(ModeWithInitializationVector): + name = "CBC" + + def __init__(self, initialization_vector: bytes): + utils._check_byteslike("initialization_vector", initialization_vector) + self._initialization_vector = initialization_vector + + @property + def initialization_vector(self) -> bytes: + return self._initialization_vector + + validate_for_algorithm = _check_iv_and_key_length + + +class XTS(ModeWithTweak): + name = "XTS" + + def __init__(self, tweak: bytes): + utils._check_byteslike("tweak", tweak) + + if len(tweak) != 16: + raise ValueError("tweak must be 128-bits (16 bytes)") + + self._tweak = tweak + + @property + def tweak(self) -> bytes: + return self._tweak + + def validate_for_algorithm(self, algorithm: CipherAlgorithm) -> None: + if isinstance(algorithm, (algorithms.AES128, algorithms.AES256)): + raise TypeError( + "The AES128 and AES256 classes do not support XTS, please use " + "the standard AES class instead." + ) + + if algorithm.key_size not in (256, 512): + raise ValueError( + "The XTS specification requires a 256-bit key for AES-128-XTS" + " and 512-bit key for AES-256-XTS" + ) + + +class ECB(Mode): + name = "ECB" + + validate_for_algorithm = _check_aes_key_length + + +class OFB(ModeWithInitializationVector): + name = "OFB" + + def __init__(self, initialization_vector: bytes): + utils._check_byteslike("initialization_vector", initialization_vector) + self._initialization_vector = initialization_vector + + @property + def initialization_vector(self) -> bytes: + return self._initialization_vector + + validate_for_algorithm = _check_iv_and_key_length + + +class CFB(ModeWithInitializationVector): + name = "CFB" + + def __init__(self, initialization_vector: bytes): + utils._check_byteslike("initialization_vector", initialization_vector) + self._initialization_vector = initialization_vector + + @property + def initialization_vector(self) -> bytes: + return self._initialization_vector + + validate_for_algorithm = _check_iv_and_key_length + + +class CFB8(ModeWithInitializationVector): + name = "CFB8" + + def __init__(self, initialization_vector: bytes): + utils._check_byteslike("initialization_vector", initialization_vector) + self._initialization_vector = initialization_vector + + @property + def initialization_vector(self) -> bytes: + return self._initialization_vector + + validate_for_algorithm = _check_iv_and_key_length + + +class CTR(ModeWithNonce): + name = "CTR" + + def __init__(self, nonce: bytes): + utils._check_byteslike("nonce", nonce) + self._nonce = nonce + + @property + def nonce(self) -> bytes: + return self._nonce + + def validate_for_algorithm(self, algorithm: CipherAlgorithm) -> None: + _check_aes_key_length(self, algorithm) + _check_nonce_length(self.nonce, self.name, algorithm) + + +class GCM(ModeWithInitializationVector, ModeWithAuthenticationTag): + name = "GCM" + _MAX_ENCRYPTED_BYTES = (2**39 - 256) // 8 + _MAX_AAD_BYTES = (2**64) // 8 + + def __init__( + self, + initialization_vector: bytes, + tag: bytes | None = None, + min_tag_length: int = 16, + ): + # OpenSSL 3.0.0 constrains GCM IVs to [64, 1024] bits inclusive + # This is a sane limit anyway so we'll enforce it here. + utils._check_byteslike("initialization_vector", initialization_vector) + if len(initialization_vector) < 8 or len(initialization_vector) > 128: + raise ValueError( + "initialization_vector must be between 8 and 128 bytes (64 " + "and 1024 bits)." + ) + self._initialization_vector = initialization_vector + if tag is not None: + utils._check_bytes("tag", tag) + if min_tag_length < 4: + raise ValueError("min_tag_length must be >= 4") + if len(tag) < min_tag_length: + raise ValueError( + "Authentication tag must be {} bytes or longer.".format( + min_tag_length + ) + ) + self._tag = tag + self._min_tag_length = min_tag_length + + @property + def tag(self) -> bytes | None: + return self._tag + + @property + def initialization_vector(self) -> bytes: + return self._initialization_vector + + def validate_for_algorithm(self, algorithm: CipherAlgorithm) -> None: + _check_aes_key_length(self, algorithm) + if not isinstance(algorithm, BlockCipherAlgorithm): + raise UnsupportedAlgorithm( + "GCM requires a block cipher algorithm", + _Reasons.UNSUPPORTED_CIPHER, + ) + block_size_bytes = algorithm.block_size // 8 + if self._tag is not None and len(self._tag) > block_size_bytes: + raise ValueError( + "Authentication tag cannot be more than {} bytes.".format( + block_size_bytes + ) + ) diff --git a/.venv/Lib/site-packages/cryptography/hazmat/primitives/cmac.py b/.venv/Lib/site-packages/cryptography/hazmat/primitives/cmac.py new file mode 100644 index 00000000..2c67ce22 --- /dev/null +++ b/.venv/Lib/site-packages/cryptography/hazmat/primitives/cmac.py @@ -0,0 +1,10 @@ +# This file is dual licensed under the terms of the Apache License, Version +# 2.0, and the BSD License. See the LICENSE file in the root of this repository +# for complete details. + +from __future__ import annotations + +from cryptography.hazmat.bindings._rust import openssl as rust_openssl + +__all__ = ["CMAC"] +CMAC = rust_openssl.cmac.CMAC diff --git a/.venv/Lib/site-packages/cryptography/hazmat/primitives/constant_time.py b/.venv/Lib/site-packages/cryptography/hazmat/primitives/constant_time.py new file mode 100644 index 00000000..3975c714 --- /dev/null +++ b/.venv/Lib/site-packages/cryptography/hazmat/primitives/constant_time.py @@ -0,0 +1,14 @@ +# This file is dual licensed under the terms of the Apache License, Version +# 2.0, and the BSD License. See the LICENSE file in the root of this repository +# for complete details. + +from __future__ import annotations + +import hmac + + +def bytes_eq(a: bytes, b: bytes) -> bool: + if not isinstance(a, bytes) or not isinstance(b, bytes): + raise TypeError("a and b must be bytes.") + + return hmac.compare_digest(a, b) diff --git a/.venv/Lib/site-packages/cryptography/hazmat/primitives/hashes.py b/.venv/Lib/site-packages/cryptography/hazmat/primitives/hashes.py new file mode 100644 index 00000000..c5be0c8e --- /dev/null +++ b/.venv/Lib/site-packages/cryptography/hazmat/primitives/hashes.py @@ -0,0 +1,242 @@ +# This file is dual licensed under the terms of the Apache License, Version +# 2.0, and the BSD License. See the LICENSE file in the root of this repository +# for complete details. + +from __future__ import annotations + +import abc + +from cryptography.hazmat.bindings._rust import openssl as rust_openssl + +__all__ = [ + "HashAlgorithm", + "HashContext", + "Hash", + "ExtendableOutputFunction", + "SHA1", + "SHA512_224", + "SHA512_256", + "SHA224", + "SHA256", + "SHA384", + "SHA512", + "SHA3_224", + "SHA3_256", + "SHA3_384", + "SHA3_512", + "SHAKE128", + "SHAKE256", + "MD5", + "BLAKE2b", + "BLAKE2s", + "SM3", +] + + +class HashAlgorithm(metaclass=abc.ABCMeta): + @property + @abc.abstractmethod + def name(self) -> str: + """ + A string naming this algorithm (e.g. "sha256", "md5"). + """ + + @property + @abc.abstractmethod + def digest_size(self) -> int: + """ + The size of the resulting digest in bytes. + """ + + @property + @abc.abstractmethod + def block_size(self) -> int | None: + """ + The internal block size of the hash function, or None if the hash + function does not use blocks internally (e.g. SHA3). + """ + + +class HashContext(metaclass=abc.ABCMeta): + @property + @abc.abstractmethod + def algorithm(self) -> HashAlgorithm: + """ + A HashAlgorithm that will be used by this context. + """ + + @abc.abstractmethod + def update(self, data: bytes) -> None: + """ + Processes the provided bytes through the hash. + """ + + @abc.abstractmethod + def finalize(self) -> bytes: + """ + Finalizes the hash context and returns the hash digest as bytes. + """ + + @abc.abstractmethod + def copy(self) -> HashContext: + """ + Return a HashContext that is a copy of the current context. + """ + + +Hash = rust_openssl.hashes.Hash +HashContext.register(Hash) + + +class ExtendableOutputFunction(metaclass=abc.ABCMeta): + """ + An interface for extendable output functions. + """ + + +class SHA1(HashAlgorithm): + name = "sha1" + digest_size = 20 + block_size = 64 + + +class SHA512_224(HashAlgorithm): # noqa: N801 + name = "sha512-224" + digest_size = 28 + block_size = 128 + + +class SHA512_256(HashAlgorithm): # noqa: N801 + name = "sha512-256" + digest_size = 32 + block_size = 128 + + +class SHA224(HashAlgorithm): + name = "sha224" + digest_size = 28 + block_size = 64 + + +class SHA256(HashAlgorithm): + name = "sha256" + digest_size = 32 + block_size = 64 + + +class SHA384(HashAlgorithm): + name = "sha384" + digest_size = 48 + block_size = 128 + + +class SHA512(HashAlgorithm): + name = "sha512" + digest_size = 64 + block_size = 128 + + +class SHA3_224(HashAlgorithm): # noqa: N801 + name = "sha3-224" + digest_size = 28 + block_size = None + + +class SHA3_256(HashAlgorithm): # noqa: N801 + name = "sha3-256" + digest_size = 32 + block_size = None + + +class SHA3_384(HashAlgorithm): # noqa: N801 + name = "sha3-384" + digest_size = 48 + block_size = None + + +class SHA3_512(HashAlgorithm): # noqa: N801 + name = "sha3-512" + digest_size = 64 + block_size = None + + +class SHAKE128(HashAlgorithm, ExtendableOutputFunction): + name = "shake128" + block_size = None + + def __init__(self, digest_size: int): + if not isinstance(digest_size, int): + raise TypeError("digest_size must be an integer") + + if digest_size < 1: + raise ValueError("digest_size must be a positive integer") + + self._digest_size = digest_size + + @property + def digest_size(self) -> int: + return self._digest_size + + +class SHAKE256(HashAlgorithm, ExtendableOutputFunction): + name = "shake256" + block_size = None + + def __init__(self, digest_size: int): + if not isinstance(digest_size, int): + raise TypeError("digest_size must be an integer") + + if digest_size < 1: + raise ValueError("digest_size must be a positive integer") + + self._digest_size = digest_size + + @property + def digest_size(self) -> int: + return self._digest_size + + +class MD5(HashAlgorithm): + name = "md5" + digest_size = 16 + block_size = 64 + + +class BLAKE2b(HashAlgorithm): + name = "blake2b" + _max_digest_size = 64 + _min_digest_size = 1 + block_size = 128 + + def __init__(self, digest_size: int): + if digest_size != 64: + raise ValueError("Digest size must be 64") + + self._digest_size = digest_size + + @property + def digest_size(self) -> int: + return self._digest_size + + +class BLAKE2s(HashAlgorithm): + name = "blake2s" + block_size = 64 + _max_digest_size = 32 + _min_digest_size = 1 + + def __init__(self, digest_size: int): + if digest_size != 32: + raise ValueError("Digest size must be 32") + + self._digest_size = digest_size + + @property + def digest_size(self) -> int: + return self._digest_size + + +class SM3(HashAlgorithm): + name = "sm3" + digest_size = 32 + block_size = 64 diff --git a/.venv/Lib/site-packages/cryptography/hazmat/primitives/hmac.py b/.venv/Lib/site-packages/cryptography/hazmat/primitives/hmac.py new file mode 100644 index 00000000..a9442d59 --- /dev/null +++ b/.venv/Lib/site-packages/cryptography/hazmat/primitives/hmac.py @@ -0,0 +1,13 @@ +# This file is dual licensed under the terms of the Apache License, Version +# 2.0, and the BSD License. See the LICENSE file in the root of this repository +# for complete details. + +from __future__ import annotations + +from cryptography.hazmat.bindings._rust import openssl as rust_openssl +from cryptography.hazmat.primitives import hashes + +__all__ = ["HMAC"] + +HMAC = rust_openssl.hmac.HMAC +hashes.HashContext.register(HMAC) diff --git a/.venv/Lib/site-packages/cryptography/hazmat/primitives/kdf/__init__.py b/.venv/Lib/site-packages/cryptography/hazmat/primitives/kdf/__init__.py new file mode 100644 index 00000000..79bb459f --- /dev/null +++ b/.venv/Lib/site-packages/cryptography/hazmat/primitives/kdf/__init__.py @@ -0,0 +1,23 @@ +# This file is dual licensed under the terms of the Apache License, Version +# 2.0, and the BSD License. See the LICENSE file in the root of this repository +# for complete details. + +from __future__ import annotations + +import abc + + +class KeyDerivationFunction(metaclass=abc.ABCMeta): + @abc.abstractmethod + def derive(self, key_material: bytes) -> bytes: + """ + Deterministically generates and returns a new key based on the existing + key material. + """ + + @abc.abstractmethod + def verify(self, key_material: bytes, expected_key: bytes) -> None: + """ + Checks whether the key generated by the key material matches the + expected derived key. Raises an exception if they do not match. + """ diff --git a/.venv/Lib/site-packages/cryptography/hazmat/primitives/kdf/__pycache__/__init__.cpython-311.pyc b/.venv/Lib/site-packages/cryptography/hazmat/primitives/kdf/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 00000000..1358ce1d Binary files /dev/null and b/.venv/Lib/site-packages/cryptography/hazmat/primitives/kdf/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/cryptography/hazmat/primitives/kdf/__pycache__/concatkdf.cpython-311.pyc b/.venv/Lib/site-packages/cryptography/hazmat/primitives/kdf/__pycache__/concatkdf.cpython-311.pyc new file mode 100644 index 00000000..9c2540f3 Binary files /dev/null and b/.venv/Lib/site-packages/cryptography/hazmat/primitives/kdf/__pycache__/concatkdf.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/cryptography/hazmat/primitives/kdf/__pycache__/hkdf.cpython-311.pyc b/.venv/Lib/site-packages/cryptography/hazmat/primitives/kdf/__pycache__/hkdf.cpython-311.pyc new file mode 100644 index 00000000..9a56831f Binary files /dev/null and b/.venv/Lib/site-packages/cryptography/hazmat/primitives/kdf/__pycache__/hkdf.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/cryptography/hazmat/primitives/kdf/__pycache__/kbkdf.cpython-311.pyc b/.venv/Lib/site-packages/cryptography/hazmat/primitives/kdf/__pycache__/kbkdf.cpython-311.pyc new file mode 100644 index 00000000..0c13596c Binary files /dev/null and b/.venv/Lib/site-packages/cryptography/hazmat/primitives/kdf/__pycache__/kbkdf.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/cryptography/hazmat/primitives/kdf/__pycache__/pbkdf2.cpython-311.pyc b/.venv/Lib/site-packages/cryptography/hazmat/primitives/kdf/__pycache__/pbkdf2.cpython-311.pyc new file mode 100644 index 00000000..d8319a19 Binary files /dev/null and b/.venv/Lib/site-packages/cryptography/hazmat/primitives/kdf/__pycache__/pbkdf2.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/cryptography/hazmat/primitives/kdf/__pycache__/scrypt.cpython-311.pyc b/.venv/Lib/site-packages/cryptography/hazmat/primitives/kdf/__pycache__/scrypt.cpython-311.pyc new file mode 100644 index 00000000..26cff800 Binary files /dev/null and b/.venv/Lib/site-packages/cryptography/hazmat/primitives/kdf/__pycache__/scrypt.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/cryptography/hazmat/primitives/kdf/__pycache__/x963kdf.cpython-311.pyc b/.venv/Lib/site-packages/cryptography/hazmat/primitives/kdf/__pycache__/x963kdf.cpython-311.pyc new file mode 100644 index 00000000..6bc21210 Binary files /dev/null and b/.venv/Lib/site-packages/cryptography/hazmat/primitives/kdf/__pycache__/x963kdf.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/cryptography/hazmat/primitives/kdf/concatkdf.py b/.venv/Lib/site-packages/cryptography/hazmat/primitives/kdf/concatkdf.py new file mode 100644 index 00000000..96d9d4c0 --- /dev/null +++ b/.venv/Lib/site-packages/cryptography/hazmat/primitives/kdf/concatkdf.py @@ -0,0 +1,124 @@ +# This file is dual licensed under the terms of the Apache License, Version +# 2.0, and the BSD License. See the LICENSE file in the root of this repository +# for complete details. + +from __future__ import annotations + +import typing + +from cryptography import utils +from cryptography.exceptions import AlreadyFinalized, InvalidKey +from cryptography.hazmat.primitives import constant_time, hashes, hmac +from cryptography.hazmat.primitives.kdf import KeyDerivationFunction + + +def _int_to_u32be(n: int) -> bytes: + return n.to_bytes(length=4, byteorder="big") + + +def _common_args_checks( + algorithm: hashes.HashAlgorithm, + length: int, + otherinfo: bytes | None, +) -> None: + max_length = algorithm.digest_size * (2**32 - 1) + if length > max_length: + raise ValueError(f"Cannot derive keys larger than {max_length} bits.") + if otherinfo is not None: + utils._check_bytes("otherinfo", otherinfo) + + +def _concatkdf_derive( + key_material: bytes, + length: int, + auxfn: typing.Callable[[], hashes.HashContext], + otherinfo: bytes, +) -> bytes: + utils._check_byteslike("key_material", key_material) + output = [b""] + outlen = 0 + counter = 1 + + while length > outlen: + h = auxfn() + h.update(_int_to_u32be(counter)) + h.update(key_material) + h.update(otherinfo) + output.append(h.finalize()) + outlen += len(output[-1]) + counter += 1 + + return b"".join(output)[:length] + + +class ConcatKDFHash(KeyDerivationFunction): + def __init__( + self, + algorithm: hashes.HashAlgorithm, + length: int, + otherinfo: bytes | None, + backend: typing.Any = None, + ): + _common_args_checks(algorithm, length, otherinfo) + self._algorithm = algorithm + self._length = length + self._otherinfo: bytes = otherinfo if otherinfo is not None else b"" + + self._used = False + + def _hash(self) -> hashes.Hash: + return hashes.Hash(self._algorithm) + + def derive(self, key_material: bytes) -> bytes: + if self._used: + raise AlreadyFinalized + self._used = True + return _concatkdf_derive( + key_material, self._length, self._hash, self._otherinfo + ) + + def verify(self, key_material: bytes, expected_key: bytes) -> None: + if not constant_time.bytes_eq(self.derive(key_material), expected_key): + raise InvalidKey + + +class ConcatKDFHMAC(KeyDerivationFunction): + def __init__( + self, + algorithm: hashes.HashAlgorithm, + length: int, + salt: bytes | None, + otherinfo: bytes | None, + backend: typing.Any = None, + ): + _common_args_checks(algorithm, length, otherinfo) + self._algorithm = algorithm + self._length = length + self._otherinfo: bytes = otherinfo if otherinfo is not None else b"" + + if algorithm.block_size is None: + raise TypeError(f"{algorithm.name} is unsupported for ConcatKDF") + + if salt is None: + salt = b"\x00" * algorithm.block_size + else: + utils._check_bytes("salt", salt) + + self._salt = salt + + self._used = False + + def _hmac(self) -> hmac.HMAC: + return hmac.HMAC(self._salt, self._algorithm) + + def derive(self, key_material: bytes) -> bytes: + if self._used: + raise AlreadyFinalized + self._used = True + return _concatkdf_derive( + key_material, self._length, self._hmac, self._otherinfo + ) + + def verify(self, key_material: bytes, expected_key: bytes) -> None: + if not constant_time.bytes_eq(self.derive(key_material), expected_key): + raise InvalidKey diff --git a/.venv/Lib/site-packages/cryptography/hazmat/primitives/kdf/hkdf.py b/.venv/Lib/site-packages/cryptography/hazmat/primitives/kdf/hkdf.py new file mode 100644 index 00000000..ee562d2f --- /dev/null +++ b/.venv/Lib/site-packages/cryptography/hazmat/primitives/kdf/hkdf.py @@ -0,0 +1,101 @@ +# This file is dual licensed under the terms of the Apache License, Version +# 2.0, and the BSD License. See the LICENSE file in the root of this repository +# for complete details. + +from __future__ import annotations + +import typing + +from cryptography import utils +from cryptography.exceptions import AlreadyFinalized, InvalidKey +from cryptography.hazmat.primitives import constant_time, hashes, hmac +from cryptography.hazmat.primitives.kdf import KeyDerivationFunction + + +class HKDF(KeyDerivationFunction): + def __init__( + self, + algorithm: hashes.HashAlgorithm, + length: int, + salt: bytes | None, + info: bytes | None, + backend: typing.Any = None, + ): + self._algorithm = algorithm + + if salt is None: + salt = b"\x00" * self._algorithm.digest_size + else: + utils._check_bytes("salt", salt) + + self._salt = salt + + self._hkdf_expand = HKDFExpand(self._algorithm, length, info) + + def _extract(self, key_material: bytes) -> bytes: + h = hmac.HMAC(self._salt, self._algorithm) + h.update(key_material) + return h.finalize() + + def derive(self, key_material: bytes) -> bytes: + utils._check_byteslike("key_material", key_material) + return self._hkdf_expand.derive(self._extract(key_material)) + + def verify(self, key_material: bytes, expected_key: bytes) -> None: + if not constant_time.bytes_eq(self.derive(key_material), expected_key): + raise InvalidKey + + +class HKDFExpand(KeyDerivationFunction): + def __init__( + self, + algorithm: hashes.HashAlgorithm, + length: int, + info: bytes | None, + backend: typing.Any = None, + ): + self._algorithm = algorithm + + max_length = 255 * algorithm.digest_size + + if length > max_length: + raise ValueError( + f"Cannot derive keys larger than {max_length} octets." + ) + + self._length = length + + if info is None: + info = b"" + else: + utils._check_bytes("info", info) + + self._info = info + + self._used = False + + def _expand(self, key_material: bytes) -> bytes: + output = [b""] + counter = 1 + + while self._algorithm.digest_size * (len(output) - 1) < self._length: + h = hmac.HMAC(key_material, self._algorithm) + h.update(output[-1]) + h.update(self._info) + h.update(bytes([counter])) + output.append(h.finalize()) + counter += 1 + + return b"".join(output)[: self._length] + + def derive(self, key_material: bytes) -> bytes: + utils._check_byteslike("key_material", key_material) + if self._used: + raise AlreadyFinalized + + self._used = True + return self._expand(key_material) + + def verify(self, key_material: bytes, expected_key: bytes) -> None: + if not constant_time.bytes_eq(self.derive(key_material), expected_key): + raise InvalidKey diff --git a/.venv/Lib/site-packages/cryptography/hazmat/primitives/kdf/kbkdf.py b/.venv/Lib/site-packages/cryptography/hazmat/primitives/kdf/kbkdf.py new file mode 100644 index 00000000..2f41db92 --- /dev/null +++ b/.venv/Lib/site-packages/cryptography/hazmat/primitives/kdf/kbkdf.py @@ -0,0 +1,299 @@ +# This file is dual licensed under the terms of the Apache License, Version +# 2.0, and the BSD License. See the LICENSE file in the root of this repository +# for complete details. + +from __future__ import annotations + +import typing + +from cryptography import utils +from cryptography.exceptions import ( + AlreadyFinalized, + InvalidKey, + UnsupportedAlgorithm, + _Reasons, +) +from cryptography.hazmat.primitives import ( + ciphers, + cmac, + constant_time, + hashes, + hmac, +) +from cryptography.hazmat.primitives.kdf import KeyDerivationFunction + + +class Mode(utils.Enum): + CounterMode = "ctr" + + +class CounterLocation(utils.Enum): + BeforeFixed = "before_fixed" + AfterFixed = "after_fixed" + MiddleFixed = "middle_fixed" + + +class _KBKDFDeriver: + def __init__( + self, + prf: typing.Callable, + mode: Mode, + length: int, + rlen: int, + llen: int | None, + location: CounterLocation, + break_location: int | None, + label: bytes | None, + context: bytes | None, + fixed: bytes | None, + ): + assert callable(prf) + + if not isinstance(mode, Mode): + raise TypeError("mode must be of type Mode") + + if not isinstance(location, CounterLocation): + raise TypeError("location must be of type CounterLocation") + + if break_location is None and location is CounterLocation.MiddleFixed: + raise ValueError("Please specify a break_location") + + if ( + break_location is not None + and location != CounterLocation.MiddleFixed + ): + raise ValueError( + "break_location is ignored when location is not" + " CounterLocation.MiddleFixed" + ) + + if break_location is not None and not isinstance(break_location, int): + raise TypeError("break_location must be an integer") + + if break_location is not None and break_location < 0: + raise ValueError("break_location must be a positive integer") + + if (label or context) and fixed: + raise ValueError( + "When supplying fixed data, " "label and context are ignored." + ) + + if rlen is None or not self._valid_byte_length(rlen): + raise ValueError("rlen must be between 1 and 4") + + if llen is None and fixed is None: + raise ValueError("Please specify an llen") + + if llen is not None and not isinstance(llen, int): + raise TypeError("llen must be an integer") + + if label is None: + label = b"" + + if context is None: + context = b"" + + utils._check_bytes("label", label) + utils._check_bytes("context", context) + self._prf = prf + self._mode = mode + self._length = length + self._rlen = rlen + self._llen = llen + self._location = location + self._break_location = break_location + self._label = label + self._context = context + self._used = False + self._fixed_data = fixed + + @staticmethod + def _valid_byte_length(value: int) -> bool: + if not isinstance(value, int): + raise TypeError("value must be of type int") + + value_bin = utils.int_to_bytes(1, value) + if not 1 <= len(value_bin) <= 4: + return False + return True + + def derive(self, key_material: bytes, prf_output_size: int) -> bytes: + if self._used: + raise AlreadyFinalized + + utils._check_byteslike("key_material", key_material) + self._used = True + + # inverse floor division (equivalent to ceiling) + rounds = -(-self._length // prf_output_size) + + output = [b""] + + # For counter mode, the number of iterations shall not be + # larger than 2^r-1, where r <= 32 is the binary length of the counter + # This ensures that the counter values used as an input to the + # PRF will not repeat during a particular call to the KDF function. + r_bin = utils.int_to_bytes(1, self._rlen) + if rounds > pow(2, len(r_bin) * 8) - 1: + raise ValueError("There are too many iterations.") + + fixed = self._generate_fixed_input() + + if self._location == CounterLocation.BeforeFixed: + data_before_ctr = b"" + data_after_ctr = fixed + elif self._location == CounterLocation.AfterFixed: + data_before_ctr = fixed + data_after_ctr = b"" + else: + if isinstance( + self._break_location, int + ) and self._break_location > len(fixed): + raise ValueError("break_location offset > len(fixed)") + data_before_ctr = fixed[: self._break_location] + data_after_ctr = fixed[self._break_location :] + + for i in range(1, rounds + 1): + h = self._prf(key_material) + + counter = utils.int_to_bytes(i, self._rlen) + input_data = data_before_ctr + counter + data_after_ctr + + h.update(input_data) + + output.append(h.finalize()) + + return b"".join(output)[: self._length] + + def _generate_fixed_input(self) -> bytes: + if self._fixed_data and isinstance(self._fixed_data, bytes): + return self._fixed_data + + l_val = utils.int_to_bytes(self._length * 8, self._llen) + + return b"".join([self._label, b"\x00", self._context, l_val]) + + +class KBKDFHMAC(KeyDerivationFunction): + def __init__( + self, + algorithm: hashes.HashAlgorithm, + mode: Mode, + length: int, + rlen: int, + llen: int | None, + location: CounterLocation, + label: bytes | None, + context: bytes | None, + fixed: bytes | None, + backend: typing.Any = None, + *, + break_location: int | None = None, + ): + if not isinstance(algorithm, hashes.HashAlgorithm): + raise UnsupportedAlgorithm( + "Algorithm supplied is not a supported hash algorithm.", + _Reasons.UNSUPPORTED_HASH, + ) + + from cryptography.hazmat.backends.openssl.backend import ( + backend as ossl, + ) + + if not ossl.hmac_supported(algorithm): + raise UnsupportedAlgorithm( + "Algorithm supplied is not a supported hmac algorithm.", + _Reasons.UNSUPPORTED_HASH, + ) + + self._algorithm = algorithm + + self._deriver = _KBKDFDeriver( + self._prf, + mode, + length, + rlen, + llen, + location, + break_location, + label, + context, + fixed, + ) + + def _prf(self, key_material: bytes) -> hmac.HMAC: + return hmac.HMAC(key_material, self._algorithm) + + def derive(self, key_material: bytes) -> bytes: + return self._deriver.derive(key_material, self._algorithm.digest_size) + + def verify(self, key_material: bytes, expected_key: bytes) -> None: + if not constant_time.bytes_eq(self.derive(key_material), expected_key): + raise InvalidKey + + +class KBKDFCMAC(KeyDerivationFunction): + def __init__( + self, + algorithm, + mode: Mode, + length: int, + rlen: int, + llen: int | None, + location: CounterLocation, + label: bytes | None, + context: bytes | None, + fixed: bytes | None, + backend: typing.Any = None, + *, + break_location: int | None = None, + ): + if not issubclass( + algorithm, ciphers.BlockCipherAlgorithm + ) or not issubclass(algorithm, ciphers.CipherAlgorithm): + raise UnsupportedAlgorithm( + "Algorithm supplied is not a supported cipher algorithm.", + _Reasons.UNSUPPORTED_CIPHER, + ) + + self._algorithm = algorithm + self._cipher: ciphers.BlockCipherAlgorithm | None = None + + self._deriver = _KBKDFDeriver( + self._prf, + mode, + length, + rlen, + llen, + location, + break_location, + label, + context, + fixed, + ) + + def _prf(self, _: bytes) -> cmac.CMAC: + assert self._cipher is not None + + return cmac.CMAC(self._cipher) + + def derive(self, key_material: bytes) -> bytes: + self._cipher = self._algorithm(key_material) + + assert self._cipher is not None + + from cryptography.hazmat.backends.openssl.backend import ( + backend as ossl, + ) + + if not ossl.cmac_algorithm_supported(self._cipher): + raise UnsupportedAlgorithm( + "Algorithm supplied is not a supported cipher algorithm.", + _Reasons.UNSUPPORTED_CIPHER, + ) + + return self._deriver.derive(key_material, self._cipher.block_size // 8) + + def verify(self, key_material: bytes, expected_key: bytes) -> None: + if not constant_time.bytes_eq(self.derive(key_material), expected_key): + raise InvalidKey diff --git a/.venv/Lib/site-packages/cryptography/hazmat/primitives/kdf/pbkdf2.py b/.venv/Lib/site-packages/cryptography/hazmat/primitives/kdf/pbkdf2.py new file mode 100644 index 00000000..623e1ca7 --- /dev/null +++ b/.venv/Lib/site-packages/cryptography/hazmat/primitives/kdf/pbkdf2.py @@ -0,0 +1,64 @@ +# This file is dual licensed under the terms of the Apache License, Version +# 2.0, and the BSD License. See the LICENSE file in the root of this repository +# for complete details. + +from __future__ import annotations + +import typing + +from cryptography import utils +from cryptography.exceptions import ( + AlreadyFinalized, + InvalidKey, + UnsupportedAlgorithm, + _Reasons, +) +from cryptography.hazmat.bindings._rust import openssl as rust_openssl +from cryptography.hazmat.primitives import constant_time, hashes +from cryptography.hazmat.primitives.kdf import KeyDerivationFunction + + +class PBKDF2HMAC(KeyDerivationFunction): + def __init__( + self, + algorithm: hashes.HashAlgorithm, + length: int, + salt: bytes, + iterations: int, + backend: typing.Any = None, + ): + from cryptography.hazmat.backends.openssl.backend import ( + backend as ossl, + ) + + if not ossl.pbkdf2_hmac_supported(algorithm): + raise UnsupportedAlgorithm( + "{} is not supported for PBKDF2 by this backend.".format( + algorithm.name + ), + _Reasons.UNSUPPORTED_HASH, + ) + self._used = False + self._algorithm = algorithm + self._length = length + utils._check_bytes("salt", salt) + self._salt = salt + self._iterations = iterations + + def derive(self, key_material: bytes) -> bytes: + if self._used: + raise AlreadyFinalized("PBKDF2 instances can only be used once.") + self._used = True + + return rust_openssl.kdf.derive_pbkdf2_hmac( + key_material, + self._algorithm, + self._salt, + self._iterations, + self._length, + ) + + def verify(self, key_material: bytes, expected_key: bytes) -> None: + derived_key = self.derive(key_material) + if not constant_time.bytes_eq(derived_key, expected_key): + raise InvalidKey("Keys do not match.") diff --git a/.venv/Lib/site-packages/cryptography/hazmat/primitives/kdf/scrypt.py b/.venv/Lib/site-packages/cryptography/hazmat/primitives/kdf/scrypt.py new file mode 100644 index 00000000..05a4f675 --- /dev/null +++ b/.venv/Lib/site-packages/cryptography/hazmat/primitives/kdf/scrypt.py @@ -0,0 +1,80 @@ +# This file is dual licensed under the terms of the Apache License, Version +# 2.0, and the BSD License. See the LICENSE file in the root of this repository +# for complete details. + +from __future__ import annotations + +import sys +import typing + +from cryptography import utils +from cryptography.exceptions import ( + AlreadyFinalized, + InvalidKey, + UnsupportedAlgorithm, +) +from cryptography.hazmat.bindings._rust import openssl as rust_openssl +from cryptography.hazmat.primitives import constant_time +from cryptography.hazmat.primitives.kdf import KeyDerivationFunction + +# This is used by the scrypt tests to skip tests that require more memory +# than the MEM_LIMIT +_MEM_LIMIT = sys.maxsize // 2 + + +class Scrypt(KeyDerivationFunction): + def __init__( + self, + salt: bytes, + length: int, + n: int, + r: int, + p: int, + backend: typing.Any = None, + ): + from cryptography.hazmat.backends.openssl.backend import ( + backend as ossl, + ) + + if not ossl.scrypt_supported(): + raise UnsupportedAlgorithm( + "This version of OpenSSL does not support scrypt" + ) + self._length = length + utils._check_bytes("salt", salt) + if n < 2 or (n & (n - 1)) != 0: + raise ValueError("n must be greater than 1 and be a power of 2.") + + if r < 1: + raise ValueError("r must be greater than or equal to 1.") + + if p < 1: + raise ValueError("p must be greater than or equal to 1.") + + self._used = False + self._salt = salt + self._n = n + self._r = r + self._p = p + + def derive(self, key_material: bytes) -> bytes: + if self._used: + raise AlreadyFinalized("Scrypt instances can only be used once.") + self._used = True + + utils._check_byteslike("key_material", key_material) + + return rust_openssl.kdf.derive_scrypt( + key_material, + self._salt, + self._n, + self._r, + self._p, + _MEM_LIMIT, + self._length, + ) + + def verify(self, key_material: bytes, expected_key: bytes) -> None: + derived_key = self.derive(key_material) + if not constant_time.bytes_eq(derived_key, expected_key): + raise InvalidKey("Keys do not match.") diff --git a/.venv/Lib/site-packages/cryptography/hazmat/primitives/kdf/x963kdf.py b/.venv/Lib/site-packages/cryptography/hazmat/primitives/kdf/x963kdf.py new file mode 100644 index 00000000..6e38366a --- /dev/null +++ b/.venv/Lib/site-packages/cryptography/hazmat/primitives/kdf/x963kdf.py @@ -0,0 +1,61 @@ +# This file is dual licensed under the terms of the Apache License, Version +# 2.0, and the BSD License. See the LICENSE file in the root of this repository +# for complete details. + +from __future__ import annotations + +import typing + +from cryptography import utils +from cryptography.exceptions import AlreadyFinalized, InvalidKey +from cryptography.hazmat.primitives import constant_time, hashes +from cryptography.hazmat.primitives.kdf import KeyDerivationFunction + + +def _int_to_u32be(n: int) -> bytes: + return n.to_bytes(length=4, byteorder="big") + + +class X963KDF(KeyDerivationFunction): + def __init__( + self, + algorithm: hashes.HashAlgorithm, + length: int, + sharedinfo: bytes | None, + backend: typing.Any = None, + ): + max_len = algorithm.digest_size * (2**32 - 1) + if length > max_len: + raise ValueError(f"Cannot derive keys larger than {max_len} bits.") + if sharedinfo is not None: + utils._check_bytes("sharedinfo", sharedinfo) + + self._algorithm = algorithm + self._length = length + self._sharedinfo = sharedinfo + self._used = False + + def derive(self, key_material: bytes) -> bytes: + if self._used: + raise AlreadyFinalized + self._used = True + utils._check_byteslike("key_material", key_material) + output = [b""] + outlen = 0 + counter = 1 + + while self._length > outlen: + h = hashes.Hash(self._algorithm) + h.update(key_material) + h.update(_int_to_u32be(counter)) + if self._sharedinfo is not None: + h.update(self._sharedinfo) + output.append(h.finalize()) + outlen += len(output[-1]) + counter += 1 + + return b"".join(output)[: self._length] + + def verify(self, key_material: bytes, expected_key: bytes) -> None: + if not constant_time.bytes_eq(self.derive(key_material), expected_key): + raise InvalidKey diff --git a/.venv/Lib/site-packages/cryptography/hazmat/primitives/keywrap.py b/.venv/Lib/site-packages/cryptography/hazmat/primitives/keywrap.py new file mode 100644 index 00000000..3ee152b7 --- /dev/null +++ b/.venv/Lib/site-packages/cryptography/hazmat/primitives/keywrap.py @@ -0,0 +1,177 @@ +# This file is dual licensed under the terms of the Apache License, Version +# 2.0, and the BSD License. See the LICENSE file in the root of this repository +# for complete details. + +from __future__ import annotations + +import typing + +from cryptography.hazmat.primitives.ciphers import Cipher +from cryptography.hazmat.primitives.ciphers.algorithms import AES +from cryptography.hazmat.primitives.ciphers.modes import ECB +from cryptography.hazmat.primitives.constant_time import bytes_eq + + +def _wrap_core( + wrapping_key: bytes, + a: bytes, + r: list[bytes], +) -> bytes: + # RFC 3394 Key Wrap - 2.2.1 (index method) + encryptor = Cipher(AES(wrapping_key), ECB()).encryptor() + n = len(r) + for j in range(6): + for i in range(n): + # every encryption operation is a discrete 16 byte chunk (because + # AES has a 128-bit block size) and since we're using ECB it is + # safe to reuse the encryptor for the entire operation + b = encryptor.update(a + r[i]) + a = ( + int.from_bytes(b[:8], byteorder="big") ^ ((n * j) + i + 1) + ).to_bytes(length=8, byteorder="big") + r[i] = b[-8:] + + assert encryptor.finalize() == b"" + + return a + b"".join(r) + + +def aes_key_wrap( + wrapping_key: bytes, + key_to_wrap: bytes, + backend: typing.Any = None, +) -> bytes: + if len(wrapping_key) not in [16, 24, 32]: + raise ValueError("The wrapping key must be a valid AES key length") + + if len(key_to_wrap) < 16: + raise ValueError("The key to wrap must be at least 16 bytes") + + if len(key_to_wrap) % 8 != 0: + raise ValueError("The key to wrap must be a multiple of 8 bytes") + + a = b"\xa6\xa6\xa6\xa6\xa6\xa6\xa6\xa6" + r = [key_to_wrap[i : i + 8] for i in range(0, len(key_to_wrap), 8)] + return _wrap_core(wrapping_key, a, r) + + +def _unwrap_core( + wrapping_key: bytes, + a: bytes, + r: list[bytes], +) -> tuple[bytes, list[bytes]]: + # Implement RFC 3394 Key Unwrap - 2.2.2 (index method) + decryptor = Cipher(AES(wrapping_key), ECB()).decryptor() + n = len(r) + for j in reversed(range(6)): + for i in reversed(range(n)): + atr = ( + int.from_bytes(a, byteorder="big") ^ ((n * j) + i + 1) + ).to_bytes(length=8, byteorder="big") + r[i] + # every decryption operation is a discrete 16 byte chunk so + # it is safe to reuse the decryptor for the entire operation + b = decryptor.update(atr) + a = b[:8] + r[i] = b[-8:] + + assert decryptor.finalize() == b"" + return a, r + + +def aes_key_wrap_with_padding( + wrapping_key: bytes, + key_to_wrap: bytes, + backend: typing.Any = None, +) -> bytes: + if len(wrapping_key) not in [16, 24, 32]: + raise ValueError("The wrapping key must be a valid AES key length") + + aiv = b"\xA6\x59\x59\xA6" + len(key_to_wrap).to_bytes( + length=4, byteorder="big" + ) + # pad the key to wrap if necessary + pad = (8 - (len(key_to_wrap) % 8)) % 8 + key_to_wrap = key_to_wrap + b"\x00" * pad + if len(key_to_wrap) == 8: + # RFC 5649 - 4.1 - exactly 8 octets after padding + encryptor = Cipher(AES(wrapping_key), ECB()).encryptor() + b = encryptor.update(aiv + key_to_wrap) + assert encryptor.finalize() == b"" + return b + else: + r = [key_to_wrap[i : i + 8] for i in range(0, len(key_to_wrap), 8)] + return _wrap_core(wrapping_key, aiv, r) + + +def aes_key_unwrap_with_padding( + wrapping_key: bytes, + wrapped_key: bytes, + backend: typing.Any = None, +) -> bytes: + if len(wrapped_key) < 16: + raise InvalidUnwrap("Must be at least 16 bytes") + + if len(wrapping_key) not in [16, 24, 32]: + raise ValueError("The wrapping key must be a valid AES key length") + + if len(wrapped_key) == 16: + # RFC 5649 - 4.2 - exactly two 64-bit blocks + decryptor = Cipher(AES(wrapping_key), ECB()).decryptor() + out = decryptor.update(wrapped_key) + assert decryptor.finalize() == b"" + a = out[:8] + data = out[8:] + n = 1 + else: + r = [wrapped_key[i : i + 8] for i in range(0, len(wrapped_key), 8)] + encrypted_aiv = r.pop(0) + n = len(r) + a, r = _unwrap_core(wrapping_key, encrypted_aiv, r) + data = b"".join(r) + + # 1) Check that MSB(32,A) = A65959A6. + # 2) Check that 8*(n-1) < LSB(32,A) <= 8*n. If so, let + # MLI = LSB(32,A). + # 3) Let b = (8*n)-MLI, and then check that the rightmost b octets of + # the output data are zero. + mli = int.from_bytes(a[4:], byteorder="big") + b = (8 * n) - mli + if ( + not bytes_eq(a[:4], b"\xa6\x59\x59\xa6") + or not 8 * (n - 1) < mli <= 8 * n + or (b != 0 and not bytes_eq(data[-b:], b"\x00" * b)) + ): + raise InvalidUnwrap() + + if b == 0: + return data + else: + return data[:-b] + + +def aes_key_unwrap( + wrapping_key: bytes, + wrapped_key: bytes, + backend: typing.Any = None, +) -> bytes: + if len(wrapped_key) < 24: + raise InvalidUnwrap("Must be at least 24 bytes") + + if len(wrapped_key) % 8 != 0: + raise InvalidUnwrap("The wrapped key must be a multiple of 8 bytes") + + if len(wrapping_key) not in [16, 24, 32]: + raise ValueError("The wrapping key must be a valid AES key length") + + aiv = b"\xa6\xa6\xa6\xa6\xa6\xa6\xa6\xa6" + r = [wrapped_key[i : i + 8] for i in range(0, len(wrapped_key), 8)] + a = r.pop(0) + a, r = _unwrap_core(wrapping_key, a, r) + if not bytes_eq(a, aiv): + raise InvalidUnwrap() + + return b"".join(r) + + +class InvalidUnwrap(Exception): + pass diff --git a/.venv/Lib/site-packages/cryptography/hazmat/primitives/padding.py b/.venv/Lib/site-packages/cryptography/hazmat/primitives/padding.py new file mode 100644 index 00000000..baceaf38 --- /dev/null +++ b/.venv/Lib/site-packages/cryptography/hazmat/primitives/padding.py @@ -0,0 +1,225 @@ +# This file is dual licensed under the terms of the Apache License, Version +# 2.0, and the BSD License. See the LICENSE file in the root of this repository +# for complete details. + +from __future__ import annotations + +import abc +import typing + +from cryptography import utils +from cryptography.exceptions import AlreadyFinalized +from cryptography.hazmat.bindings._rust import ( + check_ansix923_padding, + check_pkcs7_padding, +) + + +class PaddingContext(metaclass=abc.ABCMeta): + @abc.abstractmethod + def update(self, data: bytes) -> bytes: + """ + Pads the provided bytes and returns any available data as bytes. + """ + + @abc.abstractmethod + def finalize(self) -> bytes: + """ + Finalize the padding, returns bytes. + """ + + +def _byte_padding_check(block_size: int) -> None: + if not (0 <= block_size <= 2040): + raise ValueError("block_size must be in range(0, 2041).") + + if block_size % 8 != 0: + raise ValueError("block_size must be a multiple of 8.") + + +def _byte_padding_update( + buffer_: bytes | None, data: bytes, block_size: int +) -> tuple[bytes, bytes]: + if buffer_ is None: + raise AlreadyFinalized("Context was already finalized.") + + utils._check_byteslike("data", data) + + buffer_ += bytes(data) + + finished_blocks = len(buffer_) // (block_size // 8) + + result = buffer_[: finished_blocks * (block_size // 8)] + buffer_ = buffer_[finished_blocks * (block_size // 8) :] + + return buffer_, result + + +def _byte_padding_pad( + buffer_: bytes | None, + block_size: int, + paddingfn: typing.Callable[[int], bytes], +) -> bytes: + if buffer_ is None: + raise AlreadyFinalized("Context was already finalized.") + + pad_size = block_size // 8 - len(buffer_) + return buffer_ + paddingfn(pad_size) + + +def _byte_unpadding_update( + buffer_: bytes | None, data: bytes, block_size: int +) -> tuple[bytes, bytes]: + if buffer_ is None: + raise AlreadyFinalized("Context was already finalized.") + + utils._check_byteslike("data", data) + + buffer_ += bytes(data) + + finished_blocks = max(len(buffer_) // (block_size // 8) - 1, 0) + + result = buffer_[: finished_blocks * (block_size // 8)] + buffer_ = buffer_[finished_blocks * (block_size // 8) :] + + return buffer_, result + + +def _byte_unpadding_check( + buffer_: bytes | None, + block_size: int, + checkfn: typing.Callable[[bytes], int], +) -> bytes: + if buffer_ is None: + raise AlreadyFinalized("Context was already finalized.") + + if len(buffer_) != block_size // 8: + raise ValueError("Invalid padding bytes.") + + valid = checkfn(buffer_) + + if not valid: + raise ValueError("Invalid padding bytes.") + + pad_size = buffer_[-1] + return buffer_[:-pad_size] + + +class PKCS7: + def __init__(self, block_size: int): + _byte_padding_check(block_size) + self.block_size = block_size + + def padder(self) -> PaddingContext: + return _PKCS7PaddingContext(self.block_size) + + def unpadder(self) -> PaddingContext: + return _PKCS7UnpaddingContext(self.block_size) + + +class _PKCS7PaddingContext(PaddingContext): + _buffer: bytes | None + + def __init__(self, block_size: int): + self.block_size = block_size + # TODO: more copies than necessary, we should use zero-buffer (#193) + self._buffer = b"" + + def update(self, data: bytes) -> bytes: + self._buffer, result = _byte_padding_update( + self._buffer, data, self.block_size + ) + return result + + def _padding(self, size: int) -> bytes: + return bytes([size]) * size + + def finalize(self) -> bytes: + result = _byte_padding_pad( + self._buffer, self.block_size, self._padding + ) + self._buffer = None + return result + + +class _PKCS7UnpaddingContext(PaddingContext): + _buffer: bytes | None + + def __init__(self, block_size: int): + self.block_size = block_size + # TODO: more copies than necessary, we should use zero-buffer (#193) + self._buffer = b"" + + def update(self, data: bytes) -> bytes: + self._buffer, result = _byte_unpadding_update( + self._buffer, data, self.block_size + ) + return result + + def finalize(self) -> bytes: + result = _byte_unpadding_check( + self._buffer, self.block_size, check_pkcs7_padding + ) + self._buffer = None + return result + + +class ANSIX923: + def __init__(self, block_size: int): + _byte_padding_check(block_size) + self.block_size = block_size + + def padder(self) -> PaddingContext: + return _ANSIX923PaddingContext(self.block_size) + + def unpadder(self) -> PaddingContext: + return _ANSIX923UnpaddingContext(self.block_size) + + +class _ANSIX923PaddingContext(PaddingContext): + _buffer: bytes | None + + def __init__(self, block_size: int): + self.block_size = block_size + # TODO: more copies than necessary, we should use zero-buffer (#193) + self._buffer = b"" + + def update(self, data: bytes) -> bytes: + self._buffer, result = _byte_padding_update( + self._buffer, data, self.block_size + ) + return result + + def _padding(self, size: int) -> bytes: + return bytes([0]) * (size - 1) + bytes([size]) + + def finalize(self) -> bytes: + result = _byte_padding_pad( + self._buffer, self.block_size, self._padding + ) + self._buffer = None + return result + + +class _ANSIX923UnpaddingContext(PaddingContext): + _buffer: bytes | None + + def __init__(self, block_size: int): + self.block_size = block_size + # TODO: more copies than necessary, we should use zero-buffer (#193) + self._buffer = b"" + + def update(self, data: bytes) -> bytes: + self._buffer, result = _byte_unpadding_update( + self._buffer, data, self.block_size + ) + return result + + def finalize(self) -> bytes: + result = _byte_unpadding_check( + self._buffer, + self.block_size, + check_ansix923_padding, + ) + self._buffer = None + return result diff --git a/.venv/Lib/site-packages/cryptography/hazmat/primitives/poly1305.py b/.venv/Lib/site-packages/cryptography/hazmat/primitives/poly1305.py new file mode 100644 index 00000000..7f5a77a5 --- /dev/null +++ b/.venv/Lib/site-packages/cryptography/hazmat/primitives/poly1305.py @@ -0,0 +1,11 @@ +# This file is dual licensed under the terms of the Apache License, Version +# 2.0, and the BSD License. See the LICENSE file in the root of this repository +# for complete details. + +from __future__ import annotations + +from cryptography.hazmat.bindings._rust import openssl as rust_openssl + +__all__ = ["Poly1305"] + +Poly1305 = rust_openssl.poly1305.Poly1305 diff --git a/.venv/Lib/site-packages/cryptography/hazmat/primitives/serialization/__init__.py b/.venv/Lib/site-packages/cryptography/hazmat/primitives/serialization/__init__.py new file mode 100644 index 00000000..b6c9a5cd --- /dev/null +++ b/.venv/Lib/site-packages/cryptography/hazmat/primitives/serialization/__init__.py @@ -0,0 +1,63 @@ +# This file is dual licensed under the terms of the Apache License, Version +# 2.0, and the BSD License. See the LICENSE file in the root of this repository +# for complete details. + +from __future__ import annotations + +from cryptography.hazmat.primitives._serialization import ( + BestAvailableEncryption, + Encoding, + KeySerializationEncryption, + NoEncryption, + ParameterFormat, + PrivateFormat, + PublicFormat, + _KeySerializationEncryption, +) +from cryptography.hazmat.primitives.serialization.base import ( + load_der_parameters, + load_der_private_key, + load_der_public_key, + load_pem_parameters, + load_pem_private_key, + load_pem_public_key, +) +from cryptography.hazmat.primitives.serialization.ssh import ( + SSHCertificate, + SSHCertificateBuilder, + SSHCertificateType, + SSHCertPrivateKeyTypes, + SSHCertPublicKeyTypes, + SSHPrivateKeyTypes, + SSHPublicKeyTypes, + load_ssh_private_key, + load_ssh_public_identity, + load_ssh_public_key, +) + +__all__ = [ + "load_der_parameters", + "load_der_private_key", + "load_der_public_key", + "load_pem_parameters", + "load_pem_private_key", + "load_pem_public_key", + "load_ssh_private_key", + "load_ssh_public_identity", + "load_ssh_public_key", + "Encoding", + "PrivateFormat", + "PublicFormat", + "ParameterFormat", + "KeySerializationEncryption", + "BestAvailableEncryption", + "NoEncryption", + "_KeySerializationEncryption", + "SSHCertificateBuilder", + "SSHCertificate", + "SSHCertificateType", + "SSHCertPublicKeyTypes", + "SSHCertPrivateKeyTypes", + "SSHPrivateKeyTypes", + "SSHPublicKeyTypes", +] diff --git a/.venv/Lib/site-packages/cryptography/hazmat/primitives/serialization/__pycache__/__init__.cpython-311.pyc b/.venv/Lib/site-packages/cryptography/hazmat/primitives/serialization/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 00000000..100e8090 Binary files /dev/null and b/.venv/Lib/site-packages/cryptography/hazmat/primitives/serialization/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/cryptography/hazmat/primitives/serialization/__pycache__/base.cpython-311.pyc b/.venv/Lib/site-packages/cryptography/hazmat/primitives/serialization/__pycache__/base.cpython-311.pyc new file mode 100644 index 00000000..e8c94182 Binary files /dev/null and b/.venv/Lib/site-packages/cryptography/hazmat/primitives/serialization/__pycache__/base.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/cryptography/hazmat/primitives/serialization/__pycache__/pkcs12.cpython-311.pyc b/.venv/Lib/site-packages/cryptography/hazmat/primitives/serialization/__pycache__/pkcs12.cpython-311.pyc new file mode 100644 index 00000000..93d099c3 Binary files /dev/null and b/.venv/Lib/site-packages/cryptography/hazmat/primitives/serialization/__pycache__/pkcs12.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/cryptography/hazmat/primitives/serialization/__pycache__/pkcs7.cpython-311.pyc b/.venv/Lib/site-packages/cryptography/hazmat/primitives/serialization/__pycache__/pkcs7.cpython-311.pyc new file mode 100644 index 00000000..59226b0a Binary files /dev/null and b/.venv/Lib/site-packages/cryptography/hazmat/primitives/serialization/__pycache__/pkcs7.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/cryptography/hazmat/primitives/serialization/__pycache__/ssh.cpython-311.pyc b/.venv/Lib/site-packages/cryptography/hazmat/primitives/serialization/__pycache__/ssh.cpython-311.pyc new file mode 100644 index 00000000..3eaae811 Binary files /dev/null and b/.venv/Lib/site-packages/cryptography/hazmat/primitives/serialization/__pycache__/ssh.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/cryptography/hazmat/primitives/serialization/base.py b/.venv/Lib/site-packages/cryptography/hazmat/primitives/serialization/base.py new file mode 100644 index 00000000..e7c998b7 --- /dev/null +++ b/.venv/Lib/site-packages/cryptography/hazmat/primitives/serialization/base.py @@ -0,0 +1,14 @@ +# This file is dual licensed under the terms of the Apache License, Version +# 2.0, and the BSD License. See the LICENSE file in the root of this repository +# for complete details. + +from cryptography.hazmat.bindings._rust import openssl as rust_openssl + +load_pem_private_key = rust_openssl.keys.load_pem_private_key +load_der_private_key = rust_openssl.keys.load_der_private_key + +load_pem_public_key = rust_openssl.keys.load_pem_public_key +load_der_public_key = rust_openssl.keys.load_der_public_key + +load_pem_parameters = rust_openssl.dh.from_pem_parameters +load_der_parameters = rust_openssl.dh.from_der_parameters diff --git a/.venv/Lib/site-packages/cryptography/hazmat/primitives/serialization/pkcs12.py b/.venv/Lib/site-packages/cryptography/hazmat/primitives/serialization/pkcs12.py new file mode 100644 index 00000000..006a248b --- /dev/null +++ b/.venv/Lib/site-packages/cryptography/hazmat/primitives/serialization/pkcs12.py @@ -0,0 +1,229 @@ +# This file is dual licensed under the terms of the Apache License, Version +# 2.0, and the BSD License. See the LICENSE file in the root of this repository +# for complete details. + +from __future__ import annotations + +import typing + +from cryptography import x509 +from cryptography.hazmat.primitives import serialization +from cryptography.hazmat.primitives._serialization import PBES as PBES +from cryptography.hazmat.primitives.asymmetric import ( + dsa, + ec, + ed448, + ed25519, + rsa, +) +from cryptography.hazmat.primitives.asymmetric.types import PrivateKeyTypes + +__all__ = [ + "PBES", + "PKCS12PrivateKeyTypes", + "PKCS12Certificate", + "PKCS12KeyAndCertificates", + "load_key_and_certificates", + "load_pkcs12", + "serialize_key_and_certificates", +] + +PKCS12PrivateKeyTypes = typing.Union[ + rsa.RSAPrivateKey, + dsa.DSAPrivateKey, + ec.EllipticCurvePrivateKey, + ed25519.Ed25519PrivateKey, + ed448.Ed448PrivateKey, +] + + +class PKCS12Certificate: + def __init__( + self, + cert: x509.Certificate, + friendly_name: bytes | None, + ): + if not isinstance(cert, x509.Certificate): + raise TypeError("Expecting x509.Certificate object") + if friendly_name is not None and not isinstance(friendly_name, bytes): + raise TypeError("friendly_name must be bytes or None") + self._cert = cert + self._friendly_name = friendly_name + + @property + def friendly_name(self) -> bytes | None: + return self._friendly_name + + @property + def certificate(self) -> x509.Certificate: + return self._cert + + def __eq__(self, other: object) -> bool: + if not isinstance(other, PKCS12Certificate): + return NotImplemented + + return ( + self.certificate == other.certificate + and self.friendly_name == other.friendly_name + ) + + def __hash__(self) -> int: + return hash((self.certificate, self.friendly_name)) + + def __repr__(self) -> str: + return "".format( + self.certificate, self.friendly_name + ) + + +class PKCS12KeyAndCertificates: + def __init__( + self, + key: PrivateKeyTypes | None, + cert: PKCS12Certificate | None, + additional_certs: list[PKCS12Certificate], + ): + if key is not None and not isinstance( + key, + ( + rsa.RSAPrivateKey, + dsa.DSAPrivateKey, + ec.EllipticCurvePrivateKey, + ed25519.Ed25519PrivateKey, + ed448.Ed448PrivateKey, + ), + ): + raise TypeError( + "Key must be RSA, DSA, EllipticCurve, ED25519, or ED448" + " private key, or None." + ) + if cert is not None and not isinstance(cert, PKCS12Certificate): + raise TypeError("cert must be a PKCS12Certificate object or None") + if not all( + isinstance(add_cert, PKCS12Certificate) + for add_cert in additional_certs + ): + raise TypeError( + "all values in additional_certs must be PKCS12Certificate" + " objects" + ) + self._key = key + self._cert = cert + self._additional_certs = additional_certs + + @property + def key(self) -> PrivateKeyTypes | None: + return self._key + + @property + def cert(self) -> PKCS12Certificate | None: + return self._cert + + @property + def additional_certs(self) -> list[PKCS12Certificate]: + return self._additional_certs + + def __eq__(self, other: object) -> bool: + if not isinstance(other, PKCS12KeyAndCertificates): + return NotImplemented + + return ( + self.key == other.key + and self.cert == other.cert + and self.additional_certs == other.additional_certs + ) + + def __hash__(self) -> int: + return hash((self.key, self.cert, tuple(self.additional_certs))) + + def __repr__(self) -> str: + fmt = ( + "" + ) + return fmt.format(self.key, self.cert, self.additional_certs) + + +def load_key_and_certificates( + data: bytes, + password: bytes | None, + backend: typing.Any = None, +) -> tuple[ + PrivateKeyTypes | None, + x509.Certificate | None, + list[x509.Certificate], +]: + from cryptography.hazmat.backends.openssl.backend import backend as ossl + + return ossl.load_key_and_certificates_from_pkcs12(data, password) + + +def load_pkcs12( + data: bytes, + password: bytes | None, + backend: typing.Any = None, +) -> PKCS12KeyAndCertificates: + from cryptography.hazmat.backends.openssl.backend import backend as ossl + + return ossl.load_pkcs12(data, password) + + +_PKCS12CATypes = typing.Union[ + x509.Certificate, + PKCS12Certificate, +] + + +def serialize_key_and_certificates( + name: bytes | None, + key: PKCS12PrivateKeyTypes | None, + cert: x509.Certificate | None, + cas: typing.Iterable[_PKCS12CATypes] | None, + encryption_algorithm: serialization.KeySerializationEncryption, +) -> bytes: + if key is not None and not isinstance( + key, + ( + rsa.RSAPrivateKey, + dsa.DSAPrivateKey, + ec.EllipticCurvePrivateKey, + ed25519.Ed25519PrivateKey, + ed448.Ed448PrivateKey, + ), + ): + raise TypeError( + "Key must be RSA, DSA, EllipticCurve, ED25519, or ED448" + " private key, or None." + ) + if cert is not None and not isinstance(cert, x509.Certificate): + raise TypeError("cert must be a certificate or None") + + if cas is not None: + cas = list(cas) + if not all( + isinstance( + val, + ( + x509.Certificate, + PKCS12Certificate, + ), + ) + for val in cas + ): + raise TypeError("all values in cas must be certificates") + + if not isinstance( + encryption_algorithm, serialization.KeySerializationEncryption + ): + raise TypeError( + "Key encryption algorithm must be a " + "KeySerializationEncryption instance" + ) + + if key is None and cert is None and not cas: + raise ValueError("You must supply at least one of key, cert, or cas") + + from cryptography.hazmat.backends.openssl.backend import backend + + return backend.serialize_key_and_certificates_to_pkcs12( + name, key, cert, cas, encryption_algorithm + ) diff --git a/.venv/Lib/site-packages/cryptography/hazmat/primitives/serialization/pkcs7.py b/.venv/Lib/site-packages/cryptography/hazmat/primitives/serialization/pkcs7.py new file mode 100644 index 00000000..bae35c5f --- /dev/null +++ b/.venv/Lib/site-packages/cryptography/hazmat/primitives/serialization/pkcs7.py @@ -0,0 +1,233 @@ +# This file is dual licensed under the terms of the Apache License, Version +# 2.0, and the BSD License. See the LICENSE file in the root of this repository +# for complete details. + +from __future__ import annotations + +import email.base64mime +import email.generator +import email.message +import email.policy +import io +import typing + +from cryptography import utils, x509 +from cryptography.hazmat.bindings._rust import pkcs7 as rust_pkcs7 +from cryptography.hazmat.primitives import hashes, serialization +from cryptography.hazmat.primitives.asymmetric import ec, padding, rsa +from cryptography.utils import _check_byteslike + +load_pem_pkcs7_certificates = rust_pkcs7.load_pem_pkcs7_certificates + +load_der_pkcs7_certificates = rust_pkcs7.load_der_pkcs7_certificates + +serialize_certificates = rust_pkcs7.serialize_certificates + +PKCS7HashTypes = typing.Union[ + hashes.SHA224, + hashes.SHA256, + hashes.SHA384, + hashes.SHA512, +] + +PKCS7PrivateKeyTypes = typing.Union[ + rsa.RSAPrivateKey, ec.EllipticCurvePrivateKey +] + + +class PKCS7Options(utils.Enum): + Text = "Add text/plain MIME type" + Binary = "Don't translate input data into canonical MIME format" + DetachedSignature = "Don't embed data in the PKCS7 structure" + NoCapabilities = "Don't embed SMIME capabilities" + NoAttributes = "Don't embed authenticatedAttributes" + NoCerts = "Don't embed signer certificate" + + +class PKCS7SignatureBuilder: + def __init__( + self, + data: bytes | None = None, + signers: list[ + tuple[ + x509.Certificate, + PKCS7PrivateKeyTypes, + PKCS7HashTypes, + padding.PSS | padding.PKCS1v15 | None, + ] + ] = [], + additional_certs: list[x509.Certificate] = [], + ): + self._data = data + self._signers = signers + self._additional_certs = additional_certs + + def set_data(self, data: bytes) -> PKCS7SignatureBuilder: + _check_byteslike("data", data) + if self._data is not None: + raise ValueError("data may only be set once") + + return PKCS7SignatureBuilder(data, self._signers) + + def add_signer( + self, + certificate: x509.Certificate, + private_key: PKCS7PrivateKeyTypes, + hash_algorithm: PKCS7HashTypes, + *, + rsa_padding: padding.PSS | padding.PKCS1v15 | None = None, + ) -> PKCS7SignatureBuilder: + if not isinstance( + hash_algorithm, + ( + hashes.SHA224, + hashes.SHA256, + hashes.SHA384, + hashes.SHA512, + ), + ): + raise TypeError( + "hash_algorithm must be one of hashes.SHA224, " + "SHA256, SHA384, or SHA512" + ) + if not isinstance(certificate, x509.Certificate): + raise TypeError("certificate must be a x509.Certificate") + + if not isinstance( + private_key, (rsa.RSAPrivateKey, ec.EllipticCurvePrivateKey) + ): + raise TypeError("Only RSA & EC keys are supported at this time.") + + if rsa_padding is not None: + if not isinstance(rsa_padding, (padding.PSS, padding.PKCS1v15)): + raise TypeError("Padding must be PSS or PKCS1v15") + if not isinstance(private_key, rsa.RSAPrivateKey): + raise TypeError("Padding is only supported for RSA keys") + + return PKCS7SignatureBuilder( + self._data, + [ + *self._signers, + (certificate, private_key, hash_algorithm, rsa_padding), + ], + ) + + def add_certificate( + self, certificate: x509.Certificate + ) -> PKCS7SignatureBuilder: + if not isinstance(certificate, x509.Certificate): + raise TypeError("certificate must be a x509.Certificate") + + return PKCS7SignatureBuilder( + self._data, self._signers, [*self._additional_certs, certificate] + ) + + def sign( + self, + encoding: serialization.Encoding, + options: typing.Iterable[PKCS7Options], + backend: typing.Any = None, + ) -> bytes: + if len(self._signers) == 0: + raise ValueError("Must have at least one signer") + if self._data is None: + raise ValueError("You must add data to sign") + options = list(options) + if not all(isinstance(x, PKCS7Options) for x in options): + raise ValueError("options must be from the PKCS7Options enum") + if encoding not in ( + serialization.Encoding.PEM, + serialization.Encoding.DER, + serialization.Encoding.SMIME, + ): + raise ValueError( + "Must be PEM, DER, or SMIME from the Encoding enum" + ) + + # Text is a meaningless option unless it is accompanied by + # DetachedSignature + if ( + PKCS7Options.Text in options + and PKCS7Options.DetachedSignature not in options + ): + raise ValueError( + "When passing the Text option you must also pass " + "DetachedSignature" + ) + + if PKCS7Options.Text in options and encoding in ( + serialization.Encoding.DER, + serialization.Encoding.PEM, + ): + raise ValueError( + "The Text option is only available for SMIME serialization" + ) + + # No attributes implies no capabilities so we'll error if you try to + # pass both. + if ( + PKCS7Options.NoAttributes in options + and PKCS7Options.NoCapabilities in options + ): + raise ValueError( + "NoAttributes is a superset of NoCapabilities. Do not pass " + "both values." + ) + + return rust_pkcs7.sign_and_serialize(self, encoding, options) + + +def _smime_encode( + data: bytes, signature: bytes, micalg: str, text_mode: bool +) -> bytes: + # This function works pretty hard to replicate what OpenSSL does + # precisely. For good and for ill. + + m = email.message.Message() + m.add_header("MIME-Version", "1.0") + m.add_header( + "Content-Type", + "multipart/signed", + protocol="application/x-pkcs7-signature", + micalg=micalg, + ) + + m.preamble = "This is an S/MIME signed message\n" + + msg_part = OpenSSLMimePart() + msg_part.set_payload(data) + if text_mode: + msg_part.add_header("Content-Type", "text/plain") + m.attach(msg_part) + + sig_part = email.message.MIMEPart() + sig_part.add_header( + "Content-Type", "application/x-pkcs7-signature", name="smime.p7s" + ) + sig_part.add_header("Content-Transfer-Encoding", "base64") + sig_part.add_header( + "Content-Disposition", "attachment", filename="smime.p7s" + ) + sig_part.set_payload( + email.base64mime.body_encode(signature, maxlinelen=65) + ) + del sig_part["MIME-Version"] + m.attach(sig_part) + + fp = io.BytesIO() + g = email.generator.BytesGenerator( + fp, + maxheaderlen=0, + mangle_from_=False, + policy=m.policy.clone(linesep="\r\n"), + ) + g.flatten(m) + return fp.getvalue() + + +class OpenSSLMimePart(email.message.MIMEPart): + # A MIMEPart subclass that replicates OpenSSL's behavior of not including + # a newline if there are no headers. + def _write_headers(self, generator) -> None: + if list(self.raw_items()): + generator._write_headers(self) diff --git a/.venv/Lib/site-packages/cryptography/hazmat/primitives/serialization/ssh.py b/.venv/Lib/site-packages/cryptography/hazmat/primitives/serialization/ssh.py new file mode 100644 index 00000000..f33edd55 --- /dev/null +++ b/.venv/Lib/site-packages/cryptography/hazmat/primitives/serialization/ssh.py @@ -0,0 +1,1507 @@ +# This file is dual licensed under the terms of the Apache License, Version +# 2.0, and the BSD License. See the LICENSE file in the root of this repository +# for complete details. + +from __future__ import annotations + +import binascii +import enum +import os +import re +import typing +import warnings +from base64 import encodebytes as _base64_encode +from dataclasses import dataclass + +from cryptography import utils +from cryptography.exceptions import UnsupportedAlgorithm +from cryptography.hazmat.primitives import hashes +from cryptography.hazmat.primitives.asymmetric import ( + dsa, + ec, + ed25519, + padding, + rsa, +) +from cryptography.hazmat.primitives.asymmetric import utils as asym_utils +from cryptography.hazmat.primitives.ciphers import ( + AEADDecryptionContext, + Cipher, + algorithms, + modes, +) +from cryptography.hazmat.primitives.serialization import ( + Encoding, + KeySerializationEncryption, + NoEncryption, + PrivateFormat, + PublicFormat, + _KeySerializationEncryption, +) + +try: + from bcrypt import kdf as _bcrypt_kdf + + _bcrypt_supported = True +except ImportError: + _bcrypt_supported = False + + def _bcrypt_kdf( + password: bytes, + salt: bytes, + desired_key_bytes: int, + rounds: int, + ignore_few_rounds: bool = False, + ) -> bytes: + raise UnsupportedAlgorithm("Need bcrypt module") + + +_SSH_ED25519 = b"ssh-ed25519" +_SSH_RSA = b"ssh-rsa" +_SSH_DSA = b"ssh-dss" +_ECDSA_NISTP256 = b"ecdsa-sha2-nistp256" +_ECDSA_NISTP384 = b"ecdsa-sha2-nistp384" +_ECDSA_NISTP521 = b"ecdsa-sha2-nistp521" +_CERT_SUFFIX = b"-cert-v01@openssh.com" + +# These are not key types, only algorithms, so they cannot appear +# as a public key type +_SSH_RSA_SHA256 = b"rsa-sha2-256" +_SSH_RSA_SHA512 = b"rsa-sha2-512" + +_SSH_PUBKEY_RC = re.compile(rb"\A(\S+)[ \t]+(\S+)") +_SK_MAGIC = b"openssh-key-v1\0" +_SK_START = b"-----BEGIN OPENSSH PRIVATE KEY-----" +_SK_END = b"-----END OPENSSH PRIVATE KEY-----" +_BCRYPT = b"bcrypt" +_NONE = b"none" +_DEFAULT_CIPHER = b"aes256-ctr" +_DEFAULT_ROUNDS = 16 + +# re is only way to work on bytes-like data +_PEM_RC = re.compile(_SK_START + b"(.*?)" + _SK_END, re.DOTALL) + +# padding for max blocksize +_PADDING = memoryview(bytearray(range(1, 1 + 16))) + + +@dataclass +class _SSHCipher: + alg: type[algorithms.AES] + key_len: int + mode: type[modes.CTR] | type[modes.CBC] | type[modes.GCM] + block_len: int + iv_len: int + tag_len: int | None + is_aead: bool + + +# ciphers that are actually used in key wrapping +_SSH_CIPHERS: dict[bytes, _SSHCipher] = { + b"aes256-ctr": _SSHCipher( + alg=algorithms.AES, + key_len=32, + mode=modes.CTR, + block_len=16, + iv_len=16, + tag_len=None, + is_aead=False, + ), + b"aes256-cbc": _SSHCipher( + alg=algorithms.AES, + key_len=32, + mode=modes.CBC, + block_len=16, + iv_len=16, + tag_len=None, + is_aead=False, + ), + b"aes256-gcm@openssh.com": _SSHCipher( + alg=algorithms.AES, + key_len=32, + mode=modes.GCM, + block_len=16, + iv_len=12, + tag_len=16, + is_aead=True, + ), +} + +# map local curve name to key type +_ECDSA_KEY_TYPE = { + "secp256r1": _ECDSA_NISTP256, + "secp384r1": _ECDSA_NISTP384, + "secp521r1": _ECDSA_NISTP521, +} + + +def _get_ssh_key_type(key: SSHPrivateKeyTypes | SSHPublicKeyTypes) -> bytes: + if isinstance(key, ec.EllipticCurvePrivateKey): + key_type = _ecdsa_key_type(key.public_key()) + elif isinstance(key, ec.EllipticCurvePublicKey): + key_type = _ecdsa_key_type(key) + elif isinstance(key, (rsa.RSAPrivateKey, rsa.RSAPublicKey)): + key_type = _SSH_RSA + elif isinstance(key, (dsa.DSAPrivateKey, dsa.DSAPublicKey)): + key_type = _SSH_DSA + elif isinstance( + key, (ed25519.Ed25519PrivateKey, ed25519.Ed25519PublicKey) + ): + key_type = _SSH_ED25519 + else: + raise ValueError("Unsupported key type") + + return key_type + + +def _ecdsa_key_type(public_key: ec.EllipticCurvePublicKey) -> bytes: + """Return SSH key_type and curve_name for private key.""" + curve = public_key.curve + if curve.name not in _ECDSA_KEY_TYPE: + raise ValueError( + f"Unsupported curve for ssh private key: {curve.name!r}" + ) + return _ECDSA_KEY_TYPE[curve.name] + + +def _ssh_pem_encode( + data: bytes, + prefix: bytes = _SK_START + b"\n", + suffix: bytes = _SK_END + b"\n", +) -> bytes: + return b"".join([prefix, _base64_encode(data), suffix]) + + +def _check_block_size(data: bytes, block_len: int) -> None: + """Require data to be full blocks""" + if not data or len(data) % block_len != 0: + raise ValueError("Corrupt data: missing padding") + + +def _check_empty(data: bytes) -> None: + """All data should have been parsed.""" + if data: + raise ValueError("Corrupt data: unparsed data") + + +def _init_cipher( + ciphername: bytes, + password: bytes | None, + salt: bytes, + rounds: int, +) -> Cipher[modes.CBC | modes.CTR | modes.GCM]: + """Generate key + iv and return cipher.""" + if not password: + raise ValueError("Key is password-protected.") + + ciph = _SSH_CIPHERS[ciphername] + seed = _bcrypt_kdf( + password, salt, ciph.key_len + ciph.iv_len, rounds, True + ) + return Cipher( + ciph.alg(seed[: ciph.key_len]), + ciph.mode(seed[ciph.key_len :]), + ) + + +def _get_u32(data: memoryview) -> tuple[int, memoryview]: + """Uint32""" + if len(data) < 4: + raise ValueError("Invalid data") + return int.from_bytes(data[:4], byteorder="big"), data[4:] + + +def _get_u64(data: memoryview) -> tuple[int, memoryview]: + """Uint64""" + if len(data) < 8: + raise ValueError("Invalid data") + return int.from_bytes(data[:8], byteorder="big"), data[8:] + + +def _get_sshstr(data: memoryview) -> tuple[memoryview, memoryview]: + """Bytes with u32 length prefix""" + n, data = _get_u32(data) + if n > len(data): + raise ValueError("Invalid data") + return data[:n], data[n:] + + +def _get_mpint(data: memoryview) -> tuple[int, memoryview]: + """Big integer.""" + val, data = _get_sshstr(data) + if val and val[0] > 0x7F: + raise ValueError("Invalid data") + return int.from_bytes(val, "big"), data + + +def _to_mpint(val: int) -> bytes: + """Storage format for signed bigint.""" + if val < 0: + raise ValueError("negative mpint not allowed") + if not val: + return b"" + nbytes = (val.bit_length() + 8) // 8 + return utils.int_to_bytes(val, nbytes) + + +class _FragList: + """Build recursive structure without data copy.""" + + flist: list[bytes] + + def __init__(self, init: list[bytes] | None = None) -> None: + self.flist = [] + if init: + self.flist.extend(init) + + def put_raw(self, val: bytes) -> None: + """Add plain bytes""" + self.flist.append(val) + + def put_u32(self, val: int) -> None: + """Big-endian uint32""" + self.flist.append(val.to_bytes(length=4, byteorder="big")) + + def put_u64(self, val: int) -> None: + """Big-endian uint64""" + self.flist.append(val.to_bytes(length=8, byteorder="big")) + + def put_sshstr(self, val: bytes | _FragList) -> None: + """Bytes prefixed with u32 length""" + if isinstance(val, (bytes, memoryview, bytearray)): + self.put_u32(len(val)) + self.flist.append(val) + else: + self.put_u32(val.size()) + self.flist.extend(val.flist) + + def put_mpint(self, val: int) -> None: + """Big-endian bigint prefixed with u32 length""" + self.put_sshstr(_to_mpint(val)) + + def size(self) -> int: + """Current number of bytes""" + return sum(map(len, self.flist)) + + def render(self, dstbuf: memoryview, pos: int = 0) -> int: + """Write into bytearray""" + for frag in self.flist: + flen = len(frag) + start, pos = pos, pos + flen + dstbuf[start:pos] = frag + return pos + + def tobytes(self) -> bytes: + """Return as bytes""" + buf = memoryview(bytearray(self.size())) + self.render(buf) + return buf.tobytes() + + +class _SSHFormatRSA: + """Format for RSA keys. + + Public: + mpint e, n + Private: + mpint n, e, d, iqmp, p, q + """ + + def get_public(self, data: memoryview): + """RSA public fields""" + e, data = _get_mpint(data) + n, data = _get_mpint(data) + return (e, n), data + + def load_public( + self, data: memoryview + ) -> tuple[rsa.RSAPublicKey, memoryview]: + """Make RSA public key from data.""" + (e, n), data = self.get_public(data) + public_numbers = rsa.RSAPublicNumbers(e, n) + public_key = public_numbers.public_key() + return public_key, data + + def load_private( + self, data: memoryview, pubfields + ) -> tuple[rsa.RSAPrivateKey, memoryview]: + """Make RSA private key from data.""" + n, data = _get_mpint(data) + e, data = _get_mpint(data) + d, data = _get_mpint(data) + iqmp, data = _get_mpint(data) + p, data = _get_mpint(data) + q, data = _get_mpint(data) + + if (e, n) != pubfields: + raise ValueError("Corrupt data: rsa field mismatch") + dmp1 = rsa.rsa_crt_dmp1(d, p) + dmq1 = rsa.rsa_crt_dmq1(d, q) + public_numbers = rsa.RSAPublicNumbers(e, n) + private_numbers = rsa.RSAPrivateNumbers( + p, q, d, dmp1, dmq1, iqmp, public_numbers + ) + private_key = private_numbers.private_key() + return private_key, data + + def encode_public( + self, public_key: rsa.RSAPublicKey, f_pub: _FragList + ) -> None: + """Write RSA public key""" + pubn = public_key.public_numbers() + f_pub.put_mpint(pubn.e) + f_pub.put_mpint(pubn.n) + + def encode_private( + self, private_key: rsa.RSAPrivateKey, f_priv: _FragList + ) -> None: + """Write RSA private key""" + private_numbers = private_key.private_numbers() + public_numbers = private_numbers.public_numbers + + f_priv.put_mpint(public_numbers.n) + f_priv.put_mpint(public_numbers.e) + + f_priv.put_mpint(private_numbers.d) + f_priv.put_mpint(private_numbers.iqmp) + f_priv.put_mpint(private_numbers.p) + f_priv.put_mpint(private_numbers.q) + + +class _SSHFormatDSA: + """Format for DSA keys. + + Public: + mpint p, q, g, y + Private: + mpint p, q, g, y, x + """ + + def get_public(self, data: memoryview) -> tuple[tuple, memoryview]: + """DSA public fields""" + p, data = _get_mpint(data) + q, data = _get_mpint(data) + g, data = _get_mpint(data) + y, data = _get_mpint(data) + return (p, q, g, y), data + + def load_public( + self, data: memoryview + ) -> tuple[dsa.DSAPublicKey, memoryview]: + """Make DSA public key from data.""" + (p, q, g, y), data = self.get_public(data) + parameter_numbers = dsa.DSAParameterNumbers(p, q, g) + public_numbers = dsa.DSAPublicNumbers(y, parameter_numbers) + self._validate(public_numbers) + public_key = public_numbers.public_key() + return public_key, data + + def load_private( + self, data: memoryview, pubfields + ) -> tuple[dsa.DSAPrivateKey, memoryview]: + """Make DSA private key from data.""" + (p, q, g, y), data = self.get_public(data) + x, data = _get_mpint(data) + + if (p, q, g, y) != pubfields: + raise ValueError("Corrupt data: dsa field mismatch") + parameter_numbers = dsa.DSAParameterNumbers(p, q, g) + public_numbers = dsa.DSAPublicNumbers(y, parameter_numbers) + self._validate(public_numbers) + private_numbers = dsa.DSAPrivateNumbers(x, public_numbers) + private_key = private_numbers.private_key() + return private_key, data + + def encode_public( + self, public_key: dsa.DSAPublicKey, f_pub: _FragList + ) -> None: + """Write DSA public key""" + public_numbers = public_key.public_numbers() + parameter_numbers = public_numbers.parameter_numbers + self._validate(public_numbers) + + f_pub.put_mpint(parameter_numbers.p) + f_pub.put_mpint(parameter_numbers.q) + f_pub.put_mpint(parameter_numbers.g) + f_pub.put_mpint(public_numbers.y) + + def encode_private( + self, private_key: dsa.DSAPrivateKey, f_priv: _FragList + ) -> None: + """Write DSA private key""" + self.encode_public(private_key.public_key(), f_priv) + f_priv.put_mpint(private_key.private_numbers().x) + + def _validate(self, public_numbers: dsa.DSAPublicNumbers) -> None: + parameter_numbers = public_numbers.parameter_numbers + if parameter_numbers.p.bit_length() != 1024: + raise ValueError("SSH supports only 1024 bit DSA keys") + + +class _SSHFormatECDSA: + """Format for ECDSA keys. + + Public: + str curve + bytes point + Private: + str curve + bytes point + mpint secret + """ + + def __init__(self, ssh_curve_name: bytes, curve: ec.EllipticCurve): + self.ssh_curve_name = ssh_curve_name + self.curve = curve + + def get_public(self, data: memoryview) -> tuple[tuple, memoryview]: + """ECDSA public fields""" + curve, data = _get_sshstr(data) + point, data = _get_sshstr(data) + if curve != self.ssh_curve_name: + raise ValueError("Curve name mismatch") + if point[0] != 4: + raise NotImplementedError("Need uncompressed point") + return (curve, point), data + + def load_public( + self, data: memoryview + ) -> tuple[ec.EllipticCurvePublicKey, memoryview]: + """Make ECDSA public key from data.""" + (_, point), data = self.get_public(data) + public_key = ec.EllipticCurvePublicKey.from_encoded_point( + self.curve, point.tobytes() + ) + return public_key, data + + def load_private( + self, data: memoryview, pubfields + ) -> tuple[ec.EllipticCurvePrivateKey, memoryview]: + """Make ECDSA private key from data.""" + (curve_name, point), data = self.get_public(data) + secret, data = _get_mpint(data) + + if (curve_name, point) != pubfields: + raise ValueError("Corrupt data: ecdsa field mismatch") + private_key = ec.derive_private_key(secret, self.curve) + return private_key, data + + def encode_public( + self, public_key: ec.EllipticCurvePublicKey, f_pub: _FragList + ) -> None: + """Write ECDSA public key""" + point = public_key.public_bytes( + Encoding.X962, PublicFormat.UncompressedPoint + ) + f_pub.put_sshstr(self.ssh_curve_name) + f_pub.put_sshstr(point) + + def encode_private( + self, private_key: ec.EllipticCurvePrivateKey, f_priv: _FragList + ) -> None: + """Write ECDSA private key""" + public_key = private_key.public_key() + private_numbers = private_key.private_numbers() + + self.encode_public(public_key, f_priv) + f_priv.put_mpint(private_numbers.private_value) + + +class _SSHFormatEd25519: + """Format for Ed25519 keys. + + Public: + bytes point + Private: + bytes point + bytes secret_and_point + """ + + def get_public(self, data: memoryview) -> tuple[tuple, memoryview]: + """Ed25519 public fields""" + point, data = _get_sshstr(data) + return (point,), data + + def load_public( + self, data: memoryview + ) -> tuple[ed25519.Ed25519PublicKey, memoryview]: + """Make Ed25519 public key from data.""" + (point,), data = self.get_public(data) + public_key = ed25519.Ed25519PublicKey.from_public_bytes( + point.tobytes() + ) + return public_key, data + + def load_private( + self, data: memoryview, pubfields + ) -> tuple[ed25519.Ed25519PrivateKey, memoryview]: + """Make Ed25519 private key from data.""" + (point,), data = self.get_public(data) + keypair, data = _get_sshstr(data) + + secret = keypair[:32] + point2 = keypair[32:] + if point != point2 or (point,) != pubfields: + raise ValueError("Corrupt data: ed25519 field mismatch") + private_key = ed25519.Ed25519PrivateKey.from_private_bytes(secret) + return private_key, data + + def encode_public( + self, public_key: ed25519.Ed25519PublicKey, f_pub: _FragList + ) -> None: + """Write Ed25519 public key""" + raw_public_key = public_key.public_bytes( + Encoding.Raw, PublicFormat.Raw + ) + f_pub.put_sshstr(raw_public_key) + + def encode_private( + self, private_key: ed25519.Ed25519PrivateKey, f_priv: _FragList + ) -> None: + """Write Ed25519 private key""" + public_key = private_key.public_key() + raw_private_key = private_key.private_bytes( + Encoding.Raw, PrivateFormat.Raw, NoEncryption() + ) + raw_public_key = public_key.public_bytes( + Encoding.Raw, PublicFormat.Raw + ) + f_keypair = _FragList([raw_private_key, raw_public_key]) + + self.encode_public(public_key, f_priv) + f_priv.put_sshstr(f_keypair) + + +_KEY_FORMATS = { + _SSH_RSA: _SSHFormatRSA(), + _SSH_DSA: _SSHFormatDSA(), + _SSH_ED25519: _SSHFormatEd25519(), + _ECDSA_NISTP256: _SSHFormatECDSA(b"nistp256", ec.SECP256R1()), + _ECDSA_NISTP384: _SSHFormatECDSA(b"nistp384", ec.SECP384R1()), + _ECDSA_NISTP521: _SSHFormatECDSA(b"nistp521", ec.SECP521R1()), +} + + +def _lookup_kformat(key_type: bytes): + """Return valid format or throw error""" + if not isinstance(key_type, bytes): + key_type = memoryview(key_type).tobytes() + if key_type in _KEY_FORMATS: + return _KEY_FORMATS[key_type] + raise UnsupportedAlgorithm(f"Unsupported key type: {key_type!r}") + + +SSHPrivateKeyTypes = typing.Union[ + ec.EllipticCurvePrivateKey, + rsa.RSAPrivateKey, + dsa.DSAPrivateKey, + ed25519.Ed25519PrivateKey, +] + + +def load_ssh_private_key( + data: bytes, + password: bytes | None, + backend: typing.Any = None, +) -> SSHPrivateKeyTypes: + """Load private key from OpenSSH custom encoding.""" + utils._check_byteslike("data", data) + if password is not None: + utils._check_bytes("password", password) + + m = _PEM_RC.search(data) + if not m: + raise ValueError("Not OpenSSH private key format") + p1 = m.start(1) + p2 = m.end(1) + data = binascii.a2b_base64(memoryview(data)[p1:p2]) + if not data.startswith(_SK_MAGIC): + raise ValueError("Not OpenSSH private key format") + data = memoryview(data)[len(_SK_MAGIC) :] + + # parse header + ciphername, data = _get_sshstr(data) + kdfname, data = _get_sshstr(data) + kdfoptions, data = _get_sshstr(data) + nkeys, data = _get_u32(data) + if nkeys != 1: + raise ValueError("Only one key supported") + + # load public key data + pubdata, data = _get_sshstr(data) + pub_key_type, pubdata = _get_sshstr(pubdata) + kformat = _lookup_kformat(pub_key_type) + pubfields, pubdata = kformat.get_public(pubdata) + _check_empty(pubdata) + + if (ciphername, kdfname) != (_NONE, _NONE): + ciphername_bytes = ciphername.tobytes() + if ciphername_bytes not in _SSH_CIPHERS: + raise UnsupportedAlgorithm( + f"Unsupported cipher: {ciphername_bytes!r}" + ) + if kdfname != _BCRYPT: + raise UnsupportedAlgorithm(f"Unsupported KDF: {kdfname!r}") + blklen = _SSH_CIPHERS[ciphername_bytes].block_len + tag_len = _SSH_CIPHERS[ciphername_bytes].tag_len + # load secret data + edata, data = _get_sshstr(data) + # see https://bugzilla.mindrot.org/show_bug.cgi?id=3553 for + # information about how OpenSSH handles AEAD tags + if _SSH_CIPHERS[ciphername_bytes].is_aead: + tag = bytes(data) + if len(tag) != tag_len: + raise ValueError("Corrupt data: invalid tag length for cipher") + else: + _check_empty(data) + _check_block_size(edata, blklen) + salt, kbuf = _get_sshstr(kdfoptions) + rounds, kbuf = _get_u32(kbuf) + _check_empty(kbuf) + ciph = _init_cipher(ciphername_bytes, password, salt.tobytes(), rounds) + dec = ciph.decryptor() + edata = memoryview(dec.update(edata)) + if _SSH_CIPHERS[ciphername_bytes].is_aead: + assert isinstance(dec, AEADDecryptionContext) + _check_empty(dec.finalize_with_tag(tag)) + else: + # _check_block_size requires data to be a full block so there + # should be no output from finalize + _check_empty(dec.finalize()) + else: + # load secret data + edata, data = _get_sshstr(data) + _check_empty(data) + blklen = 8 + _check_block_size(edata, blklen) + ck1, edata = _get_u32(edata) + ck2, edata = _get_u32(edata) + if ck1 != ck2: + raise ValueError("Corrupt data: broken checksum") + + # load per-key struct + key_type, edata = _get_sshstr(edata) + if key_type != pub_key_type: + raise ValueError("Corrupt data: key type mismatch") + private_key, edata = kformat.load_private(edata, pubfields) + # We don't use the comment + _, edata = _get_sshstr(edata) + + # yes, SSH does padding check *after* all other parsing is done. + # need to follow as it writes zero-byte padding too. + if edata != _PADDING[: len(edata)]: + raise ValueError("Corrupt data: invalid padding") + + if isinstance(private_key, dsa.DSAPrivateKey): + warnings.warn( + "SSH DSA keys are deprecated and will be removed in a future " + "release.", + utils.DeprecatedIn40, + stacklevel=2, + ) + + return private_key + + +def _serialize_ssh_private_key( + private_key: SSHPrivateKeyTypes, + password: bytes, + encryption_algorithm: KeySerializationEncryption, +) -> bytes: + """Serialize private key with OpenSSH custom encoding.""" + utils._check_bytes("password", password) + if isinstance(private_key, dsa.DSAPrivateKey): + warnings.warn( + "SSH DSA key support is deprecated and will be " + "removed in a future release", + utils.DeprecatedIn40, + stacklevel=4, + ) + + key_type = _get_ssh_key_type(private_key) + kformat = _lookup_kformat(key_type) + + # setup parameters + f_kdfoptions = _FragList() + if password: + ciphername = _DEFAULT_CIPHER + blklen = _SSH_CIPHERS[ciphername].block_len + kdfname = _BCRYPT + rounds = _DEFAULT_ROUNDS + if ( + isinstance(encryption_algorithm, _KeySerializationEncryption) + and encryption_algorithm._kdf_rounds is not None + ): + rounds = encryption_algorithm._kdf_rounds + salt = os.urandom(16) + f_kdfoptions.put_sshstr(salt) + f_kdfoptions.put_u32(rounds) + ciph = _init_cipher(ciphername, password, salt, rounds) + else: + ciphername = kdfname = _NONE + blklen = 8 + ciph = None + nkeys = 1 + checkval = os.urandom(4) + comment = b"" + + # encode public and private parts together + f_public_key = _FragList() + f_public_key.put_sshstr(key_type) + kformat.encode_public(private_key.public_key(), f_public_key) + + f_secrets = _FragList([checkval, checkval]) + f_secrets.put_sshstr(key_type) + kformat.encode_private(private_key, f_secrets) + f_secrets.put_sshstr(comment) + f_secrets.put_raw(_PADDING[: blklen - (f_secrets.size() % blklen)]) + + # top-level structure + f_main = _FragList() + f_main.put_raw(_SK_MAGIC) + f_main.put_sshstr(ciphername) + f_main.put_sshstr(kdfname) + f_main.put_sshstr(f_kdfoptions) + f_main.put_u32(nkeys) + f_main.put_sshstr(f_public_key) + f_main.put_sshstr(f_secrets) + + # copy result info bytearray + slen = f_secrets.size() + mlen = f_main.size() + buf = memoryview(bytearray(mlen + blklen)) + f_main.render(buf) + ofs = mlen - slen + + # encrypt in-place + if ciph is not None: + ciph.encryptor().update_into(buf[ofs:mlen], buf[ofs:]) + + return _ssh_pem_encode(buf[:mlen]) + + +SSHPublicKeyTypes = typing.Union[ + ec.EllipticCurvePublicKey, + rsa.RSAPublicKey, + dsa.DSAPublicKey, + ed25519.Ed25519PublicKey, +] + +SSHCertPublicKeyTypes = typing.Union[ + ec.EllipticCurvePublicKey, + rsa.RSAPublicKey, + ed25519.Ed25519PublicKey, +] + + +class SSHCertificateType(enum.Enum): + USER = 1 + HOST = 2 + + +class SSHCertificate: + def __init__( + self, + _nonce: memoryview, + _public_key: SSHPublicKeyTypes, + _serial: int, + _cctype: int, + _key_id: memoryview, + _valid_principals: list[bytes], + _valid_after: int, + _valid_before: int, + _critical_options: dict[bytes, bytes], + _extensions: dict[bytes, bytes], + _sig_type: memoryview, + _sig_key: memoryview, + _inner_sig_type: memoryview, + _signature: memoryview, + _tbs_cert_body: memoryview, + _cert_key_type: bytes, + _cert_body: memoryview, + ): + self._nonce = _nonce + self._public_key = _public_key + self._serial = _serial + try: + self._type = SSHCertificateType(_cctype) + except ValueError: + raise ValueError("Invalid certificate type") + self._key_id = _key_id + self._valid_principals = _valid_principals + self._valid_after = _valid_after + self._valid_before = _valid_before + self._critical_options = _critical_options + self._extensions = _extensions + self._sig_type = _sig_type + self._sig_key = _sig_key + self._inner_sig_type = _inner_sig_type + self._signature = _signature + self._cert_key_type = _cert_key_type + self._cert_body = _cert_body + self._tbs_cert_body = _tbs_cert_body + + @property + def nonce(self) -> bytes: + return bytes(self._nonce) + + def public_key(self) -> SSHCertPublicKeyTypes: + # make mypy happy until we remove DSA support entirely and + # the underlying union won't have a disallowed type + return typing.cast(SSHCertPublicKeyTypes, self._public_key) + + @property + def serial(self) -> int: + return self._serial + + @property + def type(self) -> SSHCertificateType: + return self._type + + @property + def key_id(self) -> bytes: + return bytes(self._key_id) + + @property + def valid_principals(self) -> list[bytes]: + return self._valid_principals + + @property + def valid_before(self) -> int: + return self._valid_before + + @property + def valid_after(self) -> int: + return self._valid_after + + @property + def critical_options(self) -> dict[bytes, bytes]: + return self._critical_options + + @property + def extensions(self) -> dict[bytes, bytes]: + return self._extensions + + def signature_key(self) -> SSHCertPublicKeyTypes: + sigformat = _lookup_kformat(self._sig_type) + signature_key, sigkey_rest = sigformat.load_public(self._sig_key) + _check_empty(sigkey_rest) + return signature_key + + def public_bytes(self) -> bytes: + return ( + bytes(self._cert_key_type) + + b" " + + binascii.b2a_base64(bytes(self._cert_body), newline=False) + ) + + def verify_cert_signature(self) -> None: + signature_key = self.signature_key() + if isinstance(signature_key, ed25519.Ed25519PublicKey): + signature_key.verify( + bytes(self._signature), bytes(self._tbs_cert_body) + ) + elif isinstance(signature_key, ec.EllipticCurvePublicKey): + # The signature is encoded as a pair of big-endian integers + r, data = _get_mpint(self._signature) + s, data = _get_mpint(data) + _check_empty(data) + computed_sig = asym_utils.encode_dss_signature(r, s) + hash_alg = _get_ec_hash_alg(signature_key.curve) + signature_key.verify( + computed_sig, bytes(self._tbs_cert_body), ec.ECDSA(hash_alg) + ) + else: + assert isinstance(signature_key, rsa.RSAPublicKey) + if self._inner_sig_type == _SSH_RSA: + hash_alg = hashes.SHA1() + elif self._inner_sig_type == _SSH_RSA_SHA256: + hash_alg = hashes.SHA256() + else: + assert self._inner_sig_type == _SSH_RSA_SHA512 + hash_alg = hashes.SHA512() + signature_key.verify( + bytes(self._signature), + bytes(self._tbs_cert_body), + padding.PKCS1v15(), + hash_alg, + ) + + +def _get_ec_hash_alg(curve: ec.EllipticCurve) -> hashes.HashAlgorithm: + if isinstance(curve, ec.SECP256R1): + return hashes.SHA256() + elif isinstance(curve, ec.SECP384R1): + return hashes.SHA384() + else: + assert isinstance(curve, ec.SECP521R1) + return hashes.SHA512() + + +def _load_ssh_public_identity( + data: bytes, + _legacy_dsa_allowed=False, +) -> SSHCertificate | SSHPublicKeyTypes: + utils._check_byteslike("data", data) + + m = _SSH_PUBKEY_RC.match(data) + if not m: + raise ValueError("Invalid line format") + key_type = orig_key_type = m.group(1) + key_body = m.group(2) + with_cert = False + if key_type.endswith(_CERT_SUFFIX): + with_cert = True + key_type = key_type[: -len(_CERT_SUFFIX)] + if key_type == _SSH_DSA and not _legacy_dsa_allowed: + raise UnsupportedAlgorithm( + "DSA keys aren't supported in SSH certificates" + ) + kformat = _lookup_kformat(key_type) + + try: + rest = memoryview(binascii.a2b_base64(key_body)) + except (TypeError, binascii.Error): + raise ValueError("Invalid format") + + if with_cert: + cert_body = rest + inner_key_type, rest = _get_sshstr(rest) + if inner_key_type != orig_key_type: + raise ValueError("Invalid key format") + if with_cert: + nonce, rest = _get_sshstr(rest) + public_key, rest = kformat.load_public(rest) + if with_cert: + serial, rest = _get_u64(rest) + cctype, rest = _get_u32(rest) + key_id, rest = _get_sshstr(rest) + principals, rest = _get_sshstr(rest) + valid_principals = [] + while principals: + principal, principals = _get_sshstr(principals) + valid_principals.append(bytes(principal)) + valid_after, rest = _get_u64(rest) + valid_before, rest = _get_u64(rest) + crit_options, rest = _get_sshstr(rest) + critical_options = _parse_exts_opts(crit_options) + exts, rest = _get_sshstr(rest) + extensions = _parse_exts_opts(exts) + # Get the reserved field, which is unused. + _, rest = _get_sshstr(rest) + sig_key_raw, rest = _get_sshstr(rest) + sig_type, sig_key = _get_sshstr(sig_key_raw) + if sig_type == _SSH_DSA and not _legacy_dsa_allowed: + raise UnsupportedAlgorithm( + "DSA signatures aren't supported in SSH certificates" + ) + # Get the entire cert body and subtract the signature + tbs_cert_body = cert_body[: -len(rest)] + signature_raw, rest = _get_sshstr(rest) + _check_empty(rest) + inner_sig_type, sig_rest = _get_sshstr(signature_raw) + # RSA certs can have multiple algorithm types + if ( + sig_type == _SSH_RSA + and inner_sig_type + not in [_SSH_RSA_SHA256, _SSH_RSA_SHA512, _SSH_RSA] + ) or (sig_type != _SSH_RSA and inner_sig_type != sig_type): + raise ValueError("Signature key type does not match") + signature, sig_rest = _get_sshstr(sig_rest) + _check_empty(sig_rest) + return SSHCertificate( + nonce, + public_key, + serial, + cctype, + key_id, + valid_principals, + valid_after, + valid_before, + critical_options, + extensions, + sig_type, + sig_key, + inner_sig_type, + signature, + tbs_cert_body, + orig_key_type, + cert_body, + ) + else: + _check_empty(rest) + return public_key + + +def load_ssh_public_identity( + data: bytes, +) -> SSHCertificate | SSHPublicKeyTypes: + return _load_ssh_public_identity(data) + + +def _parse_exts_opts(exts_opts: memoryview) -> dict[bytes, bytes]: + result: dict[bytes, bytes] = {} + last_name = None + while exts_opts: + name, exts_opts = _get_sshstr(exts_opts) + bname: bytes = bytes(name) + if bname in result: + raise ValueError("Duplicate name") + if last_name is not None and bname < last_name: + raise ValueError("Fields not lexically sorted") + value, exts_opts = _get_sshstr(exts_opts) + if len(value) > 0: + value, extra = _get_sshstr(value) + if len(extra) > 0: + raise ValueError("Unexpected extra data after value") + result[bname] = bytes(value) + last_name = bname + return result + + +def load_ssh_public_key( + data: bytes, backend: typing.Any = None +) -> SSHPublicKeyTypes: + cert_or_key = _load_ssh_public_identity(data, _legacy_dsa_allowed=True) + public_key: SSHPublicKeyTypes + if isinstance(cert_or_key, SSHCertificate): + public_key = cert_or_key.public_key() + else: + public_key = cert_or_key + + if isinstance(public_key, dsa.DSAPublicKey): + warnings.warn( + "SSH DSA keys are deprecated and will be removed in a future " + "release.", + utils.DeprecatedIn40, + stacklevel=2, + ) + return public_key + + +def serialize_ssh_public_key(public_key: SSHPublicKeyTypes) -> bytes: + """One-line public key format for OpenSSH""" + if isinstance(public_key, dsa.DSAPublicKey): + warnings.warn( + "SSH DSA key support is deprecated and will be " + "removed in a future release", + utils.DeprecatedIn40, + stacklevel=4, + ) + key_type = _get_ssh_key_type(public_key) + kformat = _lookup_kformat(key_type) + + f_pub = _FragList() + f_pub.put_sshstr(key_type) + kformat.encode_public(public_key, f_pub) + + pub = binascii.b2a_base64(f_pub.tobytes()).strip() + return b"".join([key_type, b" ", pub]) + + +SSHCertPrivateKeyTypes = typing.Union[ + ec.EllipticCurvePrivateKey, + rsa.RSAPrivateKey, + ed25519.Ed25519PrivateKey, +] + + +# This is an undocumented limit enforced in the openssh codebase for sshd and +# ssh-keygen, but it is undefined in the ssh certificates spec. +_SSHKEY_CERT_MAX_PRINCIPALS = 256 + + +class SSHCertificateBuilder: + def __init__( + self, + _public_key: SSHCertPublicKeyTypes | None = None, + _serial: int | None = None, + _type: SSHCertificateType | None = None, + _key_id: bytes | None = None, + _valid_principals: list[bytes] = [], + _valid_for_all_principals: bool = False, + _valid_before: int | None = None, + _valid_after: int | None = None, + _critical_options: list[tuple[bytes, bytes]] = [], + _extensions: list[tuple[bytes, bytes]] = [], + ): + self._public_key = _public_key + self._serial = _serial + self._type = _type + self._key_id = _key_id + self._valid_principals = _valid_principals + self._valid_for_all_principals = _valid_for_all_principals + self._valid_before = _valid_before + self._valid_after = _valid_after + self._critical_options = _critical_options + self._extensions = _extensions + + def public_key( + self, public_key: SSHCertPublicKeyTypes + ) -> SSHCertificateBuilder: + if not isinstance( + public_key, + ( + ec.EllipticCurvePublicKey, + rsa.RSAPublicKey, + ed25519.Ed25519PublicKey, + ), + ): + raise TypeError("Unsupported key type") + if self._public_key is not None: + raise ValueError("public_key already set") + + return SSHCertificateBuilder( + _public_key=public_key, + _serial=self._serial, + _type=self._type, + _key_id=self._key_id, + _valid_principals=self._valid_principals, + _valid_for_all_principals=self._valid_for_all_principals, + _valid_before=self._valid_before, + _valid_after=self._valid_after, + _critical_options=self._critical_options, + _extensions=self._extensions, + ) + + def serial(self, serial: int) -> SSHCertificateBuilder: + if not isinstance(serial, int): + raise TypeError("serial must be an integer") + if not 0 <= serial < 2**64: + raise ValueError("serial must be between 0 and 2**64") + if self._serial is not None: + raise ValueError("serial already set") + + return SSHCertificateBuilder( + _public_key=self._public_key, + _serial=serial, + _type=self._type, + _key_id=self._key_id, + _valid_principals=self._valid_principals, + _valid_for_all_principals=self._valid_for_all_principals, + _valid_before=self._valid_before, + _valid_after=self._valid_after, + _critical_options=self._critical_options, + _extensions=self._extensions, + ) + + def type(self, type: SSHCertificateType) -> SSHCertificateBuilder: + if not isinstance(type, SSHCertificateType): + raise TypeError("type must be an SSHCertificateType") + if self._type is not None: + raise ValueError("type already set") + + return SSHCertificateBuilder( + _public_key=self._public_key, + _serial=self._serial, + _type=type, + _key_id=self._key_id, + _valid_principals=self._valid_principals, + _valid_for_all_principals=self._valid_for_all_principals, + _valid_before=self._valid_before, + _valid_after=self._valid_after, + _critical_options=self._critical_options, + _extensions=self._extensions, + ) + + def key_id(self, key_id: bytes) -> SSHCertificateBuilder: + if not isinstance(key_id, bytes): + raise TypeError("key_id must be bytes") + if self._key_id is not None: + raise ValueError("key_id already set") + + return SSHCertificateBuilder( + _public_key=self._public_key, + _serial=self._serial, + _type=self._type, + _key_id=key_id, + _valid_principals=self._valid_principals, + _valid_for_all_principals=self._valid_for_all_principals, + _valid_before=self._valid_before, + _valid_after=self._valid_after, + _critical_options=self._critical_options, + _extensions=self._extensions, + ) + + def valid_principals( + self, valid_principals: list[bytes] + ) -> SSHCertificateBuilder: + if self._valid_for_all_principals: + raise ValueError( + "Principals can't be set because the cert is valid " + "for all principals" + ) + if ( + not all(isinstance(x, bytes) for x in valid_principals) + or not valid_principals + ): + raise TypeError( + "principals must be a list of bytes and can't be empty" + ) + if self._valid_principals: + raise ValueError("valid_principals already set") + + if len(valid_principals) > _SSHKEY_CERT_MAX_PRINCIPALS: + raise ValueError( + "Reached or exceeded the maximum number of valid_principals" + ) + + return SSHCertificateBuilder( + _public_key=self._public_key, + _serial=self._serial, + _type=self._type, + _key_id=self._key_id, + _valid_principals=valid_principals, + _valid_for_all_principals=self._valid_for_all_principals, + _valid_before=self._valid_before, + _valid_after=self._valid_after, + _critical_options=self._critical_options, + _extensions=self._extensions, + ) + + def valid_for_all_principals(self): + if self._valid_principals: + raise ValueError( + "valid_principals already set, can't set " + "valid_for_all_principals" + ) + if self._valid_for_all_principals: + raise ValueError("valid_for_all_principals already set") + + return SSHCertificateBuilder( + _public_key=self._public_key, + _serial=self._serial, + _type=self._type, + _key_id=self._key_id, + _valid_principals=self._valid_principals, + _valid_for_all_principals=True, + _valid_before=self._valid_before, + _valid_after=self._valid_after, + _critical_options=self._critical_options, + _extensions=self._extensions, + ) + + def valid_before(self, valid_before: int | float) -> SSHCertificateBuilder: + if not isinstance(valid_before, (int, float)): + raise TypeError("valid_before must be an int or float") + valid_before = int(valid_before) + if valid_before < 0 or valid_before >= 2**64: + raise ValueError("valid_before must [0, 2**64)") + if self._valid_before is not None: + raise ValueError("valid_before already set") + + return SSHCertificateBuilder( + _public_key=self._public_key, + _serial=self._serial, + _type=self._type, + _key_id=self._key_id, + _valid_principals=self._valid_principals, + _valid_for_all_principals=self._valid_for_all_principals, + _valid_before=valid_before, + _valid_after=self._valid_after, + _critical_options=self._critical_options, + _extensions=self._extensions, + ) + + def valid_after(self, valid_after: int | float) -> SSHCertificateBuilder: + if not isinstance(valid_after, (int, float)): + raise TypeError("valid_after must be an int or float") + valid_after = int(valid_after) + if valid_after < 0 or valid_after >= 2**64: + raise ValueError("valid_after must [0, 2**64)") + if self._valid_after is not None: + raise ValueError("valid_after already set") + + return SSHCertificateBuilder( + _public_key=self._public_key, + _serial=self._serial, + _type=self._type, + _key_id=self._key_id, + _valid_principals=self._valid_principals, + _valid_for_all_principals=self._valid_for_all_principals, + _valid_before=self._valid_before, + _valid_after=valid_after, + _critical_options=self._critical_options, + _extensions=self._extensions, + ) + + def add_critical_option( + self, name: bytes, value: bytes + ) -> SSHCertificateBuilder: + if not isinstance(name, bytes) or not isinstance(value, bytes): + raise TypeError("name and value must be bytes") + # This is O(n**2) + if name in [name for name, _ in self._critical_options]: + raise ValueError("Duplicate critical option name") + + return SSHCertificateBuilder( + _public_key=self._public_key, + _serial=self._serial, + _type=self._type, + _key_id=self._key_id, + _valid_principals=self._valid_principals, + _valid_for_all_principals=self._valid_for_all_principals, + _valid_before=self._valid_before, + _valid_after=self._valid_after, + _critical_options=[*self._critical_options, (name, value)], + _extensions=self._extensions, + ) + + def add_extension( + self, name: bytes, value: bytes + ) -> SSHCertificateBuilder: + if not isinstance(name, bytes) or not isinstance(value, bytes): + raise TypeError("name and value must be bytes") + # This is O(n**2) + if name in [name for name, _ in self._extensions]: + raise ValueError("Duplicate extension name") + + return SSHCertificateBuilder( + _public_key=self._public_key, + _serial=self._serial, + _type=self._type, + _key_id=self._key_id, + _valid_principals=self._valid_principals, + _valid_for_all_principals=self._valid_for_all_principals, + _valid_before=self._valid_before, + _valid_after=self._valid_after, + _critical_options=self._critical_options, + _extensions=[*self._extensions, (name, value)], + ) + + def sign(self, private_key: SSHCertPrivateKeyTypes) -> SSHCertificate: + if not isinstance( + private_key, + ( + ec.EllipticCurvePrivateKey, + rsa.RSAPrivateKey, + ed25519.Ed25519PrivateKey, + ), + ): + raise TypeError("Unsupported private key type") + + if self._public_key is None: + raise ValueError("public_key must be set") + + # Not required + serial = 0 if self._serial is None else self._serial + + if self._type is None: + raise ValueError("type must be set") + + # Not required + key_id = b"" if self._key_id is None else self._key_id + + # A zero length list is valid, but means the certificate + # is valid for any principal of the specified type. We require + # the user to explicitly set valid_for_all_principals to get + # that behavior. + if not self._valid_principals and not self._valid_for_all_principals: + raise ValueError( + "valid_principals must be set if valid_for_all_principals " + "is False" + ) + + if self._valid_before is None: + raise ValueError("valid_before must be set") + + if self._valid_after is None: + raise ValueError("valid_after must be set") + + if self._valid_after > self._valid_before: + raise ValueError("valid_after must be earlier than valid_before") + + # lexically sort our byte strings + self._critical_options.sort(key=lambda x: x[0]) + self._extensions.sort(key=lambda x: x[0]) + + key_type = _get_ssh_key_type(self._public_key) + cert_prefix = key_type + _CERT_SUFFIX + + # Marshal the bytes to be signed + nonce = os.urandom(32) + kformat = _lookup_kformat(key_type) + f = _FragList() + f.put_sshstr(cert_prefix) + f.put_sshstr(nonce) + kformat.encode_public(self._public_key, f) + f.put_u64(serial) + f.put_u32(self._type.value) + f.put_sshstr(key_id) + fprincipals = _FragList() + for p in self._valid_principals: + fprincipals.put_sshstr(p) + f.put_sshstr(fprincipals.tobytes()) + f.put_u64(self._valid_after) + f.put_u64(self._valid_before) + fcrit = _FragList() + for name, value in self._critical_options: + fcrit.put_sshstr(name) + if len(value) > 0: + foptval = _FragList() + foptval.put_sshstr(value) + fcrit.put_sshstr(foptval.tobytes()) + else: + fcrit.put_sshstr(value) + f.put_sshstr(fcrit.tobytes()) + fext = _FragList() + for name, value in self._extensions: + fext.put_sshstr(name) + if len(value) > 0: + fextval = _FragList() + fextval.put_sshstr(value) + fext.put_sshstr(fextval.tobytes()) + else: + fext.put_sshstr(value) + f.put_sshstr(fext.tobytes()) + f.put_sshstr(b"") # RESERVED FIELD + # encode CA public key + ca_type = _get_ssh_key_type(private_key) + caformat = _lookup_kformat(ca_type) + caf = _FragList() + caf.put_sshstr(ca_type) + caformat.encode_public(private_key.public_key(), caf) + f.put_sshstr(caf.tobytes()) + # Sigs according to the rules defined for the CA's public key + # (RFC4253 section 6.6 for ssh-rsa, RFC5656 for ECDSA, + # and RFC8032 for Ed25519). + if isinstance(private_key, ed25519.Ed25519PrivateKey): + signature = private_key.sign(f.tobytes()) + fsig = _FragList() + fsig.put_sshstr(ca_type) + fsig.put_sshstr(signature) + f.put_sshstr(fsig.tobytes()) + elif isinstance(private_key, ec.EllipticCurvePrivateKey): + hash_alg = _get_ec_hash_alg(private_key.curve) + signature = private_key.sign(f.tobytes(), ec.ECDSA(hash_alg)) + r, s = asym_utils.decode_dss_signature(signature) + fsig = _FragList() + fsig.put_sshstr(ca_type) + fsigblob = _FragList() + fsigblob.put_mpint(r) + fsigblob.put_mpint(s) + fsig.put_sshstr(fsigblob.tobytes()) + f.put_sshstr(fsig.tobytes()) + + else: + assert isinstance(private_key, rsa.RSAPrivateKey) + # Just like Golang, we're going to use SHA512 for RSA + # https://cs.opensource.google/go/x/crypto/+/refs/tags/ + # v0.4.0:ssh/certs.go;l=445 + # RFC 8332 defines SHA256 and 512 as options + fsig = _FragList() + fsig.put_sshstr(_SSH_RSA_SHA512) + signature = private_key.sign( + f.tobytes(), padding.PKCS1v15(), hashes.SHA512() + ) + fsig.put_sshstr(signature) + f.put_sshstr(fsig.tobytes()) + + cert_data = binascii.b2a_base64(f.tobytes()).strip() + # load_ssh_public_identity returns a union, but this is + # guaranteed to be an SSHCertificate, so we cast to make + # mypy happy. + return typing.cast( + SSHCertificate, + load_ssh_public_identity(b"".join([cert_prefix, b" ", cert_data])), + ) diff --git a/.venv/Lib/site-packages/cryptography/hazmat/primitives/twofactor/__init__.py b/.venv/Lib/site-packages/cryptography/hazmat/primitives/twofactor/__init__.py new file mode 100644 index 00000000..c1af4230 --- /dev/null +++ b/.venv/Lib/site-packages/cryptography/hazmat/primitives/twofactor/__init__.py @@ -0,0 +1,9 @@ +# This file is dual licensed under the terms of the Apache License, Version +# 2.0, and the BSD License. See the LICENSE file in the root of this repository +# for complete details. + +from __future__ import annotations + + +class InvalidToken(Exception): + pass diff --git a/.venv/Lib/site-packages/cryptography/hazmat/primitives/twofactor/__pycache__/__init__.cpython-311.pyc b/.venv/Lib/site-packages/cryptography/hazmat/primitives/twofactor/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 00000000..65416fde Binary files /dev/null and b/.venv/Lib/site-packages/cryptography/hazmat/primitives/twofactor/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/cryptography/hazmat/primitives/twofactor/__pycache__/hotp.cpython-311.pyc b/.venv/Lib/site-packages/cryptography/hazmat/primitives/twofactor/__pycache__/hotp.cpython-311.pyc new file mode 100644 index 00000000..66684e1f Binary files /dev/null and b/.venv/Lib/site-packages/cryptography/hazmat/primitives/twofactor/__pycache__/hotp.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/cryptography/hazmat/primitives/twofactor/__pycache__/totp.cpython-311.pyc b/.venv/Lib/site-packages/cryptography/hazmat/primitives/twofactor/__pycache__/totp.cpython-311.pyc new file mode 100644 index 00000000..d709bbdf Binary files /dev/null and b/.venv/Lib/site-packages/cryptography/hazmat/primitives/twofactor/__pycache__/totp.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/cryptography/hazmat/primitives/twofactor/hotp.py b/.venv/Lib/site-packages/cryptography/hazmat/primitives/twofactor/hotp.py new file mode 100644 index 00000000..af5ab6ef --- /dev/null +++ b/.venv/Lib/site-packages/cryptography/hazmat/primitives/twofactor/hotp.py @@ -0,0 +1,92 @@ +# This file is dual licensed under the terms of the Apache License, Version +# 2.0, and the BSD License. See the LICENSE file in the root of this repository +# for complete details. + +from __future__ import annotations + +import base64 +import typing +from urllib.parse import quote, urlencode + +from cryptography.hazmat.primitives import constant_time, hmac +from cryptography.hazmat.primitives.hashes import SHA1, SHA256, SHA512 +from cryptography.hazmat.primitives.twofactor import InvalidToken + +HOTPHashTypes = typing.Union[SHA1, SHA256, SHA512] + + +def _generate_uri( + hotp: HOTP, + type_name: str, + account_name: str, + issuer: str | None, + extra_parameters: list[tuple[str, int]], +) -> str: + parameters = [ + ("digits", hotp._length), + ("secret", base64.b32encode(hotp._key)), + ("algorithm", hotp._algorithm.name.upper()), + ] + + if issuer is not None: + parameters.append(("issuer", issuer)) + + parameters.extend(extra_parameters) + + label = ( + f"{quote(issuer)}:{quote(account_name)}" + if issuer + else quote(account_name) + ) + return f"otpauth://{type_name}/{label}?{urlencode(parameters)}" + + +class HOTP: + def __init__( + self, + key: bytes, + length: int, + algorithm: HOTPHashTypes, + backend: typing.Any = None, + enforce_key_length: bool = True, + ) -> None: + if len(key) < 16 and enforce_key_length is True: + raise ValueError("Key length has to be at least 128 bits.") + + if not isinstance(length, int): + raise TypeError("Length parameter must be an integer type.") + + if length < 6 or length > 8: + raise ValueError("Length of HOTP has to be between 6 and 8.") + + if not isinstance(algorithm, (SHA1, SHA256, SHA512)): + raise TypeError("Algorithm must be SHA1, SHA256 or SHA512.") + + self._key = key + self._length = length + self._algorithm = algorithm + + def generate(self, counter: int) -> bytes: + truncated_value = self._dynamic_truncate(counter) + hotp = truncated_value % (10**self._length) + return "{0:0{1}}".format(hotp, self._length).encode() + + def verify(self, hotp: bytes, counter: int) -> None: + if not constant_time.bytes_eq(self.generate(counter), hotp): + raise InvalidToken("Supplied HOTP value does not match.") + + def _dynamic_truncate(self, counter: int) -> int: + ctx = hmac.HMAC(self._key, self._algorithm) + ctx.update(counter.to_bytes(length=8, byteorder="big")) + hmac_value = ctx.finalize() + + offset = hmac_value[len(hmac_value) - 1] & 0b1111 + p = hmac_value[offset : offset + 4] + return int.from_bytes(p, byteorder="big") & 0x7FFFFFFF + + def get_provisioning_uri( + self, account_name: str, counter: int, issuer: str | None + ) -> str: + return _generate_uri( + self, "hotp", account_name, issuer, [("counter", int(counter))] + ) diff --git a/.venv/Lib/site-packages/cryptography/hazmat/primitives/twofactor/totp.py b/.venv/Lib/site-packages/cryptography/hazmat/primitives/twofactor/totp.py new file mode 100644 index 00000000..68a50774 --- /dev/null +++ b/.venv/Lib/site-packages/cryptography/hazmat/primitives/twofactor/totp.py @@ -0,0 +1,50 @@ +# This file is dual licensed under the terms of the Apache License, Version +# 2.0, and the BSD License. See the LICENSE file in the root of this repository +# for complete details. + +from __future__ import annotations + +import typing + +from cryptography.hazmat.primitives import constant_time +from cryptography.hazmat.primitives.twofactor import InvalidToken +from cryptography.hazmat.primitives.twofactor.hotp import ( + HOTP, + HOTPHashTypes, + _generate_uri, +) + + +class TOTP: + def __init__( + self, + key: bytes, + length: int, + algorithm: HOTPHashTypes, + time_step: int, + backend: typing.Any = None, + enforce_key_length: bool = True, + ): + self._time_step = time_step + self._hotp = HOTP( + key, length, algorithm, enforce_key_length=enforce_key_length + ) + + def generate(self, time: int | float) -> bytes: + counter = int(time / self._time_step) + return self._hotp.generate(counter) + + def verify(self, totp: bytes, time: int) -> None: + if not constant_time.bytes_eq(self.generate(time), totp): + raise InvalidToken("Supplied TOTP value does not match.") + + def get_provisioning_uri( + self, account_name: str, issuer: str | None + ) -> str: + return _generate_uri( + self._hotp, + "totp", + account_name, + issuer, + [("period", int(self._time_step))], + ) diff --git a/.venv/Lib/site-packages/cryptography/py.typed b/.venv/Lib/site-packages/cryptography/py.typed new file mode 100644 index 00000000..e69de29b diff --git a/.venv/Lib/site-packages/cryptography/utils.py b/.venv/Lib/site-packages/cryptography/utils.py new file mode 100644 index 00000000..a0ec7a3c --- /dev/null +++ b/.venv/Lib/site-packages/cryptography/utils.py @@ -0,0 +1,131 @@ +# This file is dual licensed under the terms of the Apache License, Version +# 2.0, and the BSD License. See the LICENSE file in the root of this repository +# for complete details. + +from __future__ import annotations + +import enum +import sys +import types +import typing +import warnings + + +# We use a UserWarning subclass, instead of DeprecationWarning, because CPython +# decided deprecation warnings should be invisible by default. +class CryptographyDeprecationWarning(UserWarning): + pass + + +# Several APIs were deprecated with no specific end-of-life date because of the +# ubiquity of their use. They should not be removed until we agree on when that +# cycle ends. +DeprecatedIn36 = CryptographyDeprecationWarning +DeprecatedIn37 = CryptographyDeprecationWarning +DeprecatedIn40 = CryptographyDeprecationWarning +DeprecatedIn41 = CryptographyDeprecationWarning +DeprecatedIn42 = CryptographyDeprecationWarning + + +def _check_bytes(name: str, value: bytes) -> None: + if not isinstance(value, bytes): + raise TypeError(f"{name} must be bytes") + + +def _check_byteslike(name: str, value: bytes) -> None: + try: + memoryview(value) + except TypeError: + raise TypeError(f"{name} must be bytes-like") + + +def int_to_bytes(integer: int, length: int | None = None) -> bytes: + return integer.to_bytes( + length or (integer.bit_length() + 7) // 8 or 1, "big" + ) + + +def _extract_buffer_length(obj: typing.Any) -> tuple[typing.Any, int]: + from cryptography.hazmat.bindings._rust import _openssl + + buf = _openssl.ffi.from_buffer(obj) + return buf, int(_openssl.ffi.cast("uintptr_t", buf)) + + +class InterfaceNotImplemented(Exception): + pass + + +class _DeprecatedValue: + def __init__(self, value: object, message: str, warning_class): + self.value = value + self.message = message + self.warning_class = warning_class + + +class _ModuleWithDeprecations(types.ModuleType): + def __init__(self, module: types.ModuleType): + super().__init__(module.__name__) + self.__dict__["_module"] = module + + def __getattr__(self, attr: str) -> object: + obj = getattr(self._module, attr) + if isinstance(obj, _DeprecatedValue): + warnings.warn(obj.message, obj.warning_class, stacklevel=2) + obj = obj.value + return obj + + def __setattr__(self, attr: str, value: object) -> None: + setattr(self._module, attr, value) + + def __delattr__(self, attr: str) -> None: + obj = getattr(self._module, attr) + if isinstance(obj, _DeprecatedValue): + warnings.warn(obj.message, obj.warning_class, stacklevel=2) + + delattr(self._module, attr) + + def __dir__(self) -> typing.Sequence[str]: + return ["_module", *dir(self._module)] + + +def deprecated( + value: object, + module_name: str, + message: str, + warning_class: type[Warning], + name: str | None = None, +) -> _DeprecatedValue: + module = sys.modules[module_name] + if not isinstance(module, _ModuleWithDeprecations): + sys.modules[module_name] = module = _ModuleWithDeprecations(module) + dv = _DeprecatedValue(value, message, warning_class) + # Maintain backwards compatibility with `name is None` for pyOpenSSL. + if name is not None: + setattr(module, name, dv) + return dv + + +def cached_property(func: typing.Callable) -> property: + cached_name = f"_cached_{func}" + sentinel = object() + + def inner(instance: object): + cache = getattr(instance, cached_name, sentinel) + if cache is not sentinel: + return cache + result = func(instance) + setattr(instance, cached_name, result) + return result + + return property(inner) + + +# Python 3.10 changed representation of enums. We use well-defined object +# representation and string representation from Python 3.9. +class Enum(enum.Enum): + def __repr__(self) -> str: + return f"<{self.__class__.__name__}.{self._name_}: {self._value_!r}>" + + def __str__(self) -> str: + return f"{self.__class__.__name__}.{self._name_}" diff --git a/.venv/Lib/site-packages/cryptography/x509/__init__.py b/.venv/Lib/site-packages/cryptography/x509/__init__.py new file mode 100644 index 00000000..931618aa --- /dev/null +++ b/.venv/Lib/site-packages/cryptography/x509/__init__.py @@ -0,0 +1,257 @@ +# This file is dual licensed under the terms of the Apache License, Version +# 2.0, and the BSD License. See the LICENSE file in the root of this repository +# for complete details. + +from __future__ import annotations + +from cryptography.x509 import certificate_transparency, verification +from cryptography.x509.base import ( + Attribute, + AttributeNotFound, + Attributes, + Certificate, + CertificateBuilder, + CertificateRevocationList, + CertificateRevocationListBuilder, + CertificateSigningRequest, + CertificateSigningRequestBuilder, + InvalidVersion, + RevokedCertificate, + RevokedCertificateBuilder, + Version, + load_der_x509_certificate, + load_der_x509_crl, + load_der_x509_csr, + load_pem_x509_certificate, + load_pem_x509_certificates, + load_pem_x509_crl, + load_pem_x509_csr, + random_serial_number, +) +from cryptography.x509.extensions import ( + AccessDescription, + AuthorityInformationAccess, + AuthorityKeyIdentifier, + BasicConstraints, + CertificateIssuer, + CertificatePolicies, + CRLDistributionPoints, + CRLNumber, + CRLReason, + DeltaCRLIndicator, + DistributionPoint, + DuplicateExtension, + ExtendedKeyUsage, + Extension, + ExtensionNotFound, + Extensions, + ExtensionType, + FreshestCRL, + GeneralNames, + InhibitAnyPolicy, + InvalidityDate, + IssuerAlternativeName, + IssuingDistributionPoint, + KeyUsage, + MSCertificateTemplate, + NameConstraints, + NoticeReference, + OCSPAcceptableResponses, + OCSPNoCheck, + OCSPNonce, + PolicyConstraints, + PolicyInformation, + PrecertificateSignedCertificateTimestamps, + PrecertPoison, + ReasonFlags, + SignedCertificateTimestamps, + SubjectAlternativeName, + SubjectInformationAccess, + SubjectKeyIdentifier, + TLSFeature, + TLSFeatureType, + UnrecognizedExtension, + UserNotice, +) +from cryptography.x509.general_name import ( + DirectoryName, + DNSName, + GeneralName, + IPAddress, + OtherName, + RegisteredID, + RFC822Name, + UniformResourceIdentifier, + UnsupportedGeneralNameType, +) +from cryptography.x509.name import ( + Name, + NameAttribute, + RelativeDistinguishedName, +) +from cryptography.x509.oid import ( + AuthorityInformationAccessOID, + CertificatePoliciesOID, + CRLEntryExtensionOID, + ExtendedKeyUsageOID, + ExtensionOID, + NameOID, + ObjectIdentifier, + SignatureAlgorithmOID, +) + +OID_AUTHORITY_INFORMATION_ACCESS = ExtensionOID.AUTHORITY_INFORMATION_ACCESS +OID_AUTHORITY_KEY_IDENTIFIER = ExtensionOID.AUTHORITY_KEY_IDENTIFIER +OID_BASIC_CONSTRAINTS = ExtensionOID.BASIC_CONSTRAINTS +OID_CERTIFICATE_POLICIES = ExtensionOID.CERTIFICATE_POLICIES +OID_CRL_DISTRIBUTION_POINTS = ExtensionOID.CRL_DISTRIBUTION_POINTS +OID_EXTENDED_KEY_USAGE = ExtensionOID.EXTENDED_KEY_USAGE +OID_FRESHEST_CRL = ExtensionOID.FRESHEST_CRL +OID_INHIBIT_ANY_POLICY = ExtensionOID.INHIBIT_ANY_POLICY +OID_ISSUER_ALTERNATIVE_NAME = ExtensionOID.ISSUER_ALTERNATIVE_NAME +OID_KEY_USAGE = ExtensionOID.KEY_USAGE +OID_NAME_CONSTRAINTS = ExtensionOID.NAME_CONSTRAINTS +OID_OCSP_NO_CHECK = ExtensionOID.OCSP_NO_CHECK +OID_POLICY_CONSTRAINTS = ExtensionOID.POLICY_CONSTRAINTS +OID_POLICY_MAPPINGS = ExtensionOID.POLICY_MAPPINGS +OID_SUBJECT_ALTERNATIVE_NAME = ExtensionOID.SUBJECT_ALTERNATIVE_NAME +OID_SUBJECT_DIRECTORY_ATTRIBUTES = ExtensionOID.SUBJECT_DIRECTORY_ATTRIBUTES +OID_SUBJECT_INFORMATION_ACCESS = ExtensionOID.SUBJECT_INFORMATION_ACCESS +OID_SUBJECT_KEY_IDENTIFIER = ExtensionOID.SUBJECT_KEY_IDENTIFIER + +OID_DSA_WITH_SHA1 = SignatureAlgorithmOID.DSA_WITH_SHA1 +OID_DSA_WITH_SHA224 = SignatureAlgorithmOID.DSA_WITH_SHA224 +OID_DSA_WITH_SHA256 = SignatureAlgorithmOID.DSA_WITH_SHA256 +OID_ECDSA_WITH_SHA1 = SignatureAlgorithmOID.ECDSA_WITH_SHA1 +OID_ECDSA_WITH_SHA224 = SignatureAlgorithmOID.ECDSA_WITH_SHA224 +OID_ECDSA_WITH_SHA256 = SignatureAlgorithmOID.ECDSA_WITH_SHA256 +OID_ECDSA_WITH_SHA384 = SignatureAlgorithmOID.ECDSA_WITH_SHA384 +OID_ECDSA_WITH_SHA512 = SignatureAlgorithmOID.ECDSA_WITH_SHA512 +OID_RSA_WITH_MD5 = SignatureAlgorithmOID.RSA_WITH_MD5 +OID_RSA_WITH_SHA1 = SignatureAlgorithmOID.RSA_WITH_SHA1 +OID_RSA_WITH_SHA224 = SignatureAlgorithmOID.RSA_WITH_SHA224 +OID_RSA_WITH_SHA256 = SignatureAlgorithmOID.RSA_WITH_SHA256 +OID_RSA_WITH_SHA384 = SignatureAlgorithmOID.RSA_WITH_SHA384 +OID_RSA_WITH_SHA512 = SignatureAlgorithmOID.RSA_WITH_SHA512 +OID_RSASSA_PSS = SignatureAlgorithmOID.RSASSA_PSS + +OID_COMMON_NAME = NameOID.COMMON_NAME +OID_COUNTRY_NAME = NameOID.COUNTRY_NAME +OID_DOMAIN_COMPONENT = NameOID.DOMAIN_COMPONENT +OID_DN_QUALIFIER = NameOID.DN_QUALIFIER +OID_EMAIL_ADDRESS = NameOID.EMAIL_ADDRESS +OID_GENERATION_QUALIFIER = NameOID.GENERATION_QUALIFIER +OID_GIVEN_NAME = NameOID.GIVEN_NAME +OID_LOCALITY_NAME = NameOID.LOCALITY_NAME +OID_ORGANIZATIONAL_UNIT_NAME = NameOID.ORGANIZATIONAL_UNIT_NAME +OID_ORGANIZATION_NAME = NameOID.ORGANIZATION_NAME +OID_PSEUDONYM = NameOID.PSEUDONYM +OID_SERIAL_NUMBER = NameOID.SERIAL_NUMBER +OID_STATE_OR_PROVINCE_NAME = NameOID.STATE_OR_PROVINCE_NAME +OID_SURNAME = NameOID.SURNAME +OID_TITLE = NameOID.TITLE + +OID_CLIENT_AUTH = ExtendedKeyUsageOID.CLIENT_AUTH +OID_CODE_SIGNING = ExtendedKeyUsageOID.CODE_SIGNING +OID_EMAIL_PROTECTION = ExtendedKeyUsageOID.EMAIL_PROTECTION +OID_OCSP_SIGNING = ExtendedKeyUsageOID.OCSP_SIGNING +OID_SERVER_AUTH = ExtendedKeyUsageOID.SERVER_AUTH +OID_TIME_STAMPING = ExtendedKeyUsageOID.TIME_STAMPING + +OID_ANY_POLICY = CertificatePoliciesOID.ANY_POLICY +OID_CPS_QUALIFIER = CertificatePoliciesOID.CPS_QUALIFIER +OID_CPS_USER_NOTICE = CertificatePoliciesOID.CPS_USER_NOTICE + +OID_CERTIFICATE_ISSUER = CRLEntryExtensionOID.CERTIFICATE_ISSUER +OID_CRL_REASON = CRLEntryExtensionOID.CRL_REASON +OID_INVALIDITY_DATE = CRLEntryExtensionOID.INVALIDITY_DATE + +OID_CA_ISSUERS = AuthorityInformationAccessOID.CA_ISSUERS +OID_OCSP = AuthorityInformationAccessOID.OCSP + +__all__ = [ + "certificate_transparency", + "verification", + "load_pem_x509_certificate", + "load_pem_x509_certificates", + "load_der_x509_certificate", + "load_pem_x509_csr", + "load_der_x509_csr", + "load_pem_x509_crl", + "load_der_x509_crl", + "random_serial_number", + "verification", + "Attribute", + "AttributeNotFound", + "Attributes", + "InvalidVersion", + "DeltaCRLIndicator", + "DuplicateExtension", + "ExtensionNotFound", + "UnsupportedGeneralNameType", + "NameAttribute", + "Name", + "RelativeDistinguishedName", + "ObjectIdentifier", + "ExtensionType", + "Extensions", + "Extension", + "ExtendedKeyUsage", + "FreshestCRL", + "IssuingDistributionPoint", + "TLSFeature", + "TLSFeatureType", + "OCSPAcceptableResponses", + "OCSPNoCheck", + "BasicConstraints", + "CRLNumber", + "KeyUsage", + "AuthorityInformationAccess", + "SubjectInformationAccess", + "AccessDescription", + "CertificatePolicies", + "PolicyInformation", + "UserNotice", + "NoticeReference", + "SubjectKeyIdentifier", + "NameConstraints", + "CRLDistributionPoints", + "DistributionPoint", + "ReasonFlags", + "InhibitAnyPolicy", + "SubjectAlternativeName", + "IssuerAlternativeName", + "AuthorityKeyIdentifier", + "GeneralNames", + "GeneralName", + "RFC822Name", + "DNSName", + "UniformResourceIdentifier", + "RegisteredID", + "DirectoryName", + "IPAddress", + "OtherName", + "Certificate", + "CertificateRevocationList", + "CertificateRevocationListBuilder", + "CertificateSigningRequest", + "RevokedCertificate", + "RevokedCertificateBuilder", + "CertificateSigningRequestBuilder", + "CertificateBuilder", + "Version", + "OID_CA_ISSUERS", + "OID_OCSP", + "CertificateIssuer", + "CRLReason", + "InvalidityDate", + "UnrecognizedExtension", + "PolicyConstraints", + "PrecertificateSignedCertificateTimestamps", + "PrecertPoison", + "OCSPNonce", + "SignedCertificateTimestamps", + "SignatureAlgorithmOID", + "NameOID", + "MSCertificateTemplate", +] diff --git a/.venv/Lib/site-packages/cryptography/x509/__pycache__/__init__.cpython-311.pyc b/.venv/Lib/site-packages/cryptography/x509/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 00000000..df2d67e9 Binary files /dev/null and b/.venv/Lib/site-packages/cryptography/x509/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/cryptography/x509/__pycache__/base.cpython-311.pyc b/.venv/Lib/site-packages/cryptography/x509/__pycache__/base.cpython-311.pyc new file mode 100644 index 00000000..6e4d41ed Binary files /dev/null and b/.venv/Lib/site-packages/cryptography/x509/__pycache__/base.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/cryptography/x509/__pycache__/certificate_transparency.cpython-311.pyc b/.venv/Lib/site-packages/cryptography/x509/__pycache__/certificate_transparency.cpython-311.pyc new file mode 100644 index 00000000..6bd6a326 Binary files /dev/null and b/.venv/Lib/site-packages/cryptography/x509/__pycache__/certificate_transparency.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/cryptography/x509/__pycache__/extensions.cpython-311.pyc b/.venv/Lib/site-packages/cryptography/x509/__pycache__/extensions.cpython-311.pyc new file mode 100644 index 00000000..fe56c7de Binary files /dev/null and b/.venv/Lib/site-packages/cryptography/x509/__pycache__/extensions.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/cryptography/x509/__pycache__/general_name.cpython-311.pyc b/.venv/Lib/site-packages/cryptography/x509/__pycache__/general_name.cpython-311.pyc new file mode 100644 index 00000000..3fed4e1b Binary files /dev/null and b/.venv/Lib/site-packages/cryptography/x509/__pycache__/general_name.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/cryptography/x509/__pycache__/name.cpython-311.pyc b/.venv/Lib/site-packages/cryptography/x509/__pycache__/name.cpython-311.pyc new file mode 100644 index 00000000..528b371b Binary files /dev/null and b/.venv/Lib/site-packages/cryptography/x509/__pycache__/name.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/cryptography/x509/__pycache__/ocsp.cpython-311.pyc b/.venv/Lib/site-packages/cryptography/x509/__pycache__/ocsp.cpython-311.pyc new file mode 100644 index 00000000..65f2833d Binary files /dev/null and b/.venv/Lib/site-packages/cryptography/x509/__pycache__/ocsp.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/cryptography/x509/__pycache__/oid.cpython-311.pyc b/.venv/Lib/site-packages/cryptography/x509/__pycache__/oid.cpython-311.pyc new file mode 100644 index 00000000..612e0637 Binary files /dev/null and b/.venv/Lib/site-packages/cryptography/x509/__pycache__/oid.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/cryptography/x509/__pycache__/verification.cpython-311.pyc b/.venv/Lib/site-packages/cryptography/x509/__pycache__/verification.cpython-311.pyc new file mode 100644 index 00000000..c5066299 Binary files /dev/null and b/.venv/Lib/site-packages/cryptography/x509/__pycache__/verification.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/cryptography/x509/base.py b/.venv/Lib/site-packages/cryptography/x509/base.py new file mode 100644 index 00000000..89a75a23 --- /dev/null +++ b/.venv/Lib/site-packages/cryptography/x509/base.py @@ -0,0 +1,1221 @@ +# This file is dual licensed under the terms of the Apache License, Version +# 2.0, and the BSD License. See the LICENSE file in the root of this repository +# for complete details. + +from __future__ import annotations + +import abc +import datetime +import os +import typing +import warnings + +from cryptography import utils +from cryptography.hazmat.bindings._rust import x509 as rust_x509 +from cryptography.hazmat.primitives import hashes, serialization +from cryptography.hazmat.primitives.asymmetric import ( + dsa, + ec, + ed448, + ed25519, + padding, + rsa, + x448, + x25519, +) +from cryptography.hazmat.primitives.asymmetric.types import ( + CertificateIssuerPrivateKeyTypes, + CertificateIssuerPublicKeyTypes, + CertificatePublicKeyTypes, +) +from cryptography.x509.extensions import ( + Extension, + Extensions, + ExtensionType, + _make_sequence_methods, +) +from cryptography.x509.name import Name, _ASN1Type +from cryptography.x509.oid import ObjectIdentifier + +_EARLIEST_UTC_TIME = datetime.datetime(1950, 1, 1) + +# This must be kept in sync with sign.rs's list of allowable types in +# identify_hash_type +_AllowedHashTypes = typing.Union[ + hashes.SHA224, + hashes.SHA256, + hashes.SHA384, + hashes.SHA512, + hashes.SHA3_224, + hashes.SHA3_256, + hashes.SHA3_384, + hashes.SHA3_512, +] + + +class AttributeNotFound(Exception): + def __init__(self, msg: str, oid: ObjectIdentifier) -> None: + super().__init__(msg) + self.oid = oid + + +def _reject_duplicate_extension( + extension: Extension[ExtensionType], + extensions: list[Extension[ExtensionType]], +) -> None: + # This is quadratic in the number of extensions + for e in extensions: + if e.oid == extension.oid: + raise ValueError("This extension has already been set.") + + +def _reject_duplicate_attribute( + oid: ObjectIdentifier, + attributes: list[tuple[ObjectIdentifier, bytes, int | None]], +) -> None: + # This is quadratic in the number of attributes + for attr_oid, _, _ in attributes: + if attr_oid == oid: + raise ValueError("This attribute has already been set.") + + +def _convert_to_naive_utc_time(time: datetime.datetime) -> datetime.datetime: + """Normalizes a datetime to a naive datetime in UTC. + + time -- datetime to normalize. Assumed to be in UTC if not timezone + aware. + """ + if time.tzinfo is not None: + offset = time.utcoffset() + offset = offset if offset else datetime.timedelta() + return time.replace(tzinfo=None) - offset + else: + return time + + +class Attribute: + def __init__( + self, + oid: ObjectIdentifier, + value: bytes, + _type: int = _ASN1Type.UTF8String.value, + ) -> None: + self._oid = oid + self._value = value + self._type = _type + + @property + def oid(self) -> ObjectIdentifier: + return self._oid + + @property + def value(self) -> bytes: + return self._value + + def __repr__(self) -> str: + return f"" + + def __eq__(self, other: object) -> bool: + if not isinstance(other, Attribute): + return NotImplemented + + return ( + self.oid == other.oid + and self.value == other.value + and self._type == other._type + ) + + def __hash__(self) -> int: + return hash((self.oid, self.value, self._type)) + + +class Attributes: + def __init__( + self, + attributes: typing.Iterable[Attribute], + ) -> None: + self._attributes = list(attributes) + + __len__, __iter__, __getitem__ = _make_sequence_methods("_attributes") + + def __repr__(self) -> str: + return f"" + + def get_attribute_for_oid(self, oid: ObjectIdentifier) -> Attribute: + for attr in self: + if attr.oid == oid: + return attr + + raise AttributeNotFound(f"No {oid} attribute was found", oid) + + +class Version(utils.Enum): + v1 = 0 + v3 = 2 + + +class InvalidVersion(Exception): + def __init__(self, msg: str, parsed_version: int) -> None: + super().__init__(msg) + self.parsed_version = parsed_version + + +class Certificate(metaclass=abc.ABCMeta): + @abc.abstractmethod + def fingerprint(self, algorithm: hashes.HashAlgorithm) -> bytes: + """ + Returns bytes using digest passed. + """ + + @property + @abc.abstractmethod + def serial_number(self) -> int: + """ + Returns certificate serial number + """ + + @property + @abc.abstractmethod + def version(self) -> Version: + """ + Returns the certificate version + """ + + @abc.abstractmethod + def public_key(self) -> CertificatePublicKeyTypes: + """ + Returns the public key + """ + + @property + @abc.abstractmethod + def not_valid_before(self) -> datetime.datetime: + """ + Not before time (represented as UTC datetime) + """ + + @property + @abc.abstractmethod + def not_valid_before_utc(self) -> datetime.datetime: + """ + Not before time (represented as a non-naive UTC datetime) + """ + + @property + @abc.abstractmethod + def not_valid_after(self) -> datetime.datetime: + """ + Not after time (represented as UTC datetime) + """ + + @property + @abc.abstractmethod + def not_valid_after_utc(self) -> datetime.datetime: + """ + Not after time (represented as a non-naive UTC datetime) + """ + + @property + @abc.abstractmethod + def issuer(self) -> Name: + """ + Returns the issuer name object. + """ + + @property + @abc.abstractmethod + def subject(self) -> Name: + """ + Returns the subject name object. + """ + + @property + @abc.abstractmethod + def signature_hash_algorithm( + self, + ) -> hashes.HashAlgorithm | None: + """ + Returns a HashAlgorithm corresponding to the type of the digest signed + in the certificate. + """ + + @property + @abc.abstractmethod + def signature_algorithm_oid(self) -> ObjectIdentifier: + """ + Returns the ObjectIdentifier of the signature algorithm. + """ + + @property + @abc.abstractmethod + def signature_algorithm_parameters( + self, + ) -> None | padding.PSS | padding.PKCS1v15 | ec.ECDSA: + """ + Returns the signature algorithm parameters. + """ + + @property + @abc.abstractmethod + def extensions(self) -> Extensions: + """ + Returns an Extensions object. + """ + + @property + @abc.abstractmethod + def signature(self) -> bytes: + """ + Returns the signature bytes. + """ + + @property + @abc.abstractmethod + def tbs_certificate_bytes(self) -> bytes: + """ + Returns the tbsCertificate payload bytes as defined in RFC 5280. + """ + + @property + @abc.abstractmethod + def tbs_precertificate_bytes(self) -> bytes: + """ + Returns the tbsCertificate payload bytes with the SCT list extension + stripped. + """ + + @abc.abstractmethod + def __eq__(self, other: object) -> bool: + """ + Checks equality. + """ + + @abc.abstractmethod + def __hash__(self) -> int: + """ + Computes a hash. + """ + + @abc.abstractmethod + def public_bytes(self, encoding: serialization.Encoding) -> bytes: + """ + Serializes the certificate to PEM or DER format. + """ + + @abc.abstractmethod + def verify_directly_issued_by(self, issuer: Certificate) -> None: + """ + This method verifies that certificate issuer name matches the + issuer subject name and that the certificate is signed by the + issuer's private key. No other validation is performed. + """ + + +# Runtime isinstance checks need this since the rust class is not a subclass. +Certificate.register(rust_x509.Certificate) + + +class RevokedCertificate(metaclass=abc.ABCMeta): + @property + @abc.abstractmethod + def serial_number(self) -> int: + """ + Returns the serial number of the revoked certificate. + """ + + @property + @abc.abstractmethod + def revocation_date(self) -> datetime.datetime: + """ + Returns the date of when this certificate was revoked. + """ + + @property + @abc.abstractmethod + def revocation_date_utc(self) -> datetime.datetime: + """ + Returns the date of when this certificate was revoked as a non-naive + UTC datetime. + """ + + @property + @abc.abstractmethod + def extensions(self) -> Extensions: + """ + Returns an Extensions object containing a list of Revoked extensions. + """ + + +# Runtime isinstance checks need this since the rust class is not a subclass. +RevokedCertificate.register(rust_x509.RevokedCertificate) + + +class _RawRevokedCertificate(RevokedCertificate): + def __init__( + self, + serial_number: int, + revocation_date: datetime.datetime, + extensions: Extensions, + ): + self._serial_number = serial_number + self._revocation_date = revocation_date + self._extensions = extensions + + @property + def serial_number(self) -> int: + return self._serial_number + + @property + def revocation_date(self) -> datetime.datetime: + warnings.warn( + "Properties that return a naïve datetime object have been " + "deprecated. Please switch to revocation_date_utc.", + utils.DeprecatedIn42, + stacklevel=2, + ) + return self._revocation_date + + @property + def revocation_date_utc(self) -> datetime.datetime: + return self._revocation_date.replace(tzinfo=datetime.timezone.utc) + + @property + def extensions(self) -> Extensions: + return self._extensions + + +class CertificateRevocationList(metaclass=abc.ABCMeta): + @abc.abstractmethod + def public_bytes(self, encoding: serialization.Encoding) -> bytes: + """ + Serializes the CRL to PEM or DER format. + """ + + @abc.abstractmethod + def fingerprint(self, algorithm: hashes.HashAlgorithm) -> bytes: + """ + Returns bytes using digest passed. + """ + + @abc.abstractmethod + def get_revoked_certificate_by_serial_number( + self, serial_number: int + ) -> RevokedCertificate | None: + """ + Returns an instance of RevokedCertificate or None if the serial_number + is not in the CRL. + """ + + @property + @abc.abstractmethod + def signature_hash_algorithm( + self, + ) -> hashes.HashAlgorithm | None: + """ + Returns a HashAlgorithm corresponding to the type of the digest signed + in the certificate. + """ + + @property + @abc.abstractmethod + def signature_algorithm_oid(self) -> ObjectIdentifier: + """ + Returns the ObjectIdentifier of the signature algorithm. + """ + + @property + @abc.abstractmethod + def signature_algorithm_parameters( + self, + ) -> None | padding.PSS | padding.PKCS1v15 | ec.ECDSA: + """ + Returns the signature algorithm parameters. + """ + + @property + @abc.abstractmethod + def issuer(self) -> Name: + """ + Returns the X509Name with the issuer of this CRL. + """ + + @property + @abc.abstractmethod + def next_update(self) -> datetime.datetime | None: + """ + Returns the date of next update for this CRL. + """ + + @property + @abc.abstractmethod + def next_update_utc(self) -> datetime.datetime | None: + """ + Returns the date of next update for this CRL as a non-naive UTC + datetime. + """ + + @property + @abc.abstractmethod + def last_update(self) -> datetime.datetime: + """ + Returns the date of last update for this CRL. + """ + + @property + @abc.abstractmethod + def last_update_utc(self) -> datetime.datetime: + """ + Returns the date of last update for this CRL as a non-naive UTC + datetime. + """ + + @property + @abc.abstractmethod + def extensions(self) -> Extensions: + """ + Returns an Extensions object containing a list of CRL extensions. + """ + + @property + @abc.abstractmethod + def signature(self) -> bytes: + """ + Returns the signature bytes. + """ + + @property + @abc.abstractmethod + def tbs_certlist_bytes(self) -> bytes: + """ + Returns the tbsCertList payload bytes as defined in RFC 5280. + """ + + @abc.abstractmethod + def __eq__(self, other: object) -> bool: + """ + Checks equality. + """ + + @abc.abstractmethod + def __len__(self) -> int: + """ + Number of revoked certificates in the CRL. + """ + + @typing.overload + def __getitem__(self, idx: int) -> RevokedCertificate: + ... + + @typing.overload + def __getitem__(self, idx: slice) -> list[RevokedCertificate]: + ... + + @abc.abstractmethod + def __getitem__( + self, idx: int | slice + ) -> RevokedCertificate | list[RevokedCertificate]: + """ + Returns a revoked certificate (or slice of revoked certificates). + """ + + @abc.abstractmethod + def __iter__(self) -> typing.Iterator[RevokedCertificate]: + """ + Iterator over the revoked certificates + """ + + @abc.abstractmethod + def is_signature_valid( + self, public_key: CertificateIssuerPublicKeyTypes + ) -> bool: + """ + Verifies signature of revocation list against given public key. + """ + + +CertificateRevocationList.register(rust_x509.CertificateRevocationList) + + +class CertificateSigningRequest(metaclass=abc.ABCMeta): + @abc.abstractmethod + def __eq__(self, other: object) -> bool: + """ + Checks equality. + """ + + @abc.abstractmethod + def __hash__(self) -> int: + """ + Computes a hash. + """ + + @abc.abstractmethod + def public_key(self) -> CertificatePublicKeyTypes: + """ + Returns the public key + """ + + @property + @abc.abstractmethod + def subject(self) -> Name: + """ + Returns the subject name object. + """ + + @property + @abc.abstractmethod + def signature_hash_algorithm( + self, + ) -> hashes.HashAlgorithm | None: + """ + Returns a HashAlgorithm corresponding to the type of the digest signed + in the certificate. + """ + + @property + @abc.abstractmethod + def signature_algorithm_oid(self) -> ObjectIdentifier: + """ + Returns the ObjectIdentifier of the signature algorithm. + """ + + @property + @abc.abstractmethod + def signature_algorithm_parameters( + self, + ) -> None | padding.PSS | padding.PKCS1v15 | ec.ECDSA: + """ + Returns the signature algorithm parameters. + """ + + @property + @abc.abstractmethod + def extensions(self) -> Extensions: + """ + Returns the extensions in the signing request. + """ + + @property + @abc.abstractmethod + def attributes(self) -> Attributes: + """ + Returns an Attributes object. + """ + + @abc.abstractmethod + def public_bytes(self, encoding: serialization.Encoding) -> bytes: + """ + Encodes the request to PEM or DER format. + """ + + @property + @abc.abstractmethod + def signature(self) -> bytes: + """ + Returns the signature bytes. + """ + + @property + @abc.abstractmethod + def tbs_certrequest_bytes(self) -> bytes: + """ + Returns the PKCS#10 CertificationRequestInfo bytes as defined in RFC + 2986. + """ + + @property + @abc.abstractmethod + def is_signature_valid(self) -> bool: + """ + Verifies signature of signing request. + """ + + @abc.abstractmethod + def get_attribute_for_oid(self, oid: ObjectIdentifier) -> bytes: + """ + Get the attribute value for a given OID. + """ + + +# Runtime isinstance checks need this since the rust class is not a subclass. +CertificateSigningRequest.register(rust_x509.CertificateSigningRequest) + + +load_pem_x509_certificate = rust_x509.load_pem_x509_certificate +load_der_x509_certificate = rust_x509.load_der_x509_certificate + +load_pem_x509_certificates = rust_x509.load_pem_x509_certificates + +load_pem_x509_csr = rust_x509.load_pem_x509_csr +load_der_x509_csr = rust_x509.load_der_x509_csr + +load_pem_x509_crl = rust_x509.load_pem_x509_crl +load_der_x509_crl = rust_x509.load_der_x509_crl + + +class CertificateSigningRequestBuilder: + def __init__( + self, + subject_name: Name | None = None, + extensions: list[Extension[ExtensionType]] = [], + attributes: list[tuple[ObjectIdentifier, bytes, int | None]] = [], + ): + """ + Creates an empty X.509 certificate request (v1). + """ + self._subject_name = subject_name + self._extensions = extensions + self._attributes = attributes + + def subject_name(self, name: Name) -> CertificateSigningRequestBuilder: + """ + Sets the certificate requestor's distinguished name. + """ + if not isinstance(name, Name): + raise TypeError("Expecting x509.Name object.") + if self._subject_name is not None: + raise ValueError("The subject name may only be set once.") + return CertificateSigningRequestBuilder( + name, self._extensions, self._attributes + ) + + def add_extension( + self, extval: ExtensionType, critical: bool + ) -> CertificateSigningRequestBuilder: + """ + Adds an X.509 extension to the certificate request. + """ + if not isinstance(extval, ExtensionType): + raise TypeError("extension must be an ExtensionType") + + extension = Extension(extval.oid, critical, extval) + _reject_duplicate_extension(extension, self._extensions) + + return CertificateSigningRequestBuilder( + self._subject_name, + [*self._extensions, extension], + self._attributes, + ) + + def add_attribute( + self, + oid: ObjectIdentifier, + value: bytes, + *, + _tag: _ASN1Type | None = None, + ) -> CertificateSigningRequestBuilder: + """ + Adds an X.509 attribute with an OID and associated value. + """ + if not isinstance(oid, ObjectIdentifier): + raise TypeError("oid must be an ObjectIdentifier") + + if not isinstance(value, bytes): + raise TypeError("value must be bytes") + + if _tag is not None and not isinstance(_tag, _ASN1Type): + raise TypeError("tag must be _ASN1Type") + + _reject_duplicate_attribute(oid, self._attributes) + + if _tag is not None: + tag = _tag.value + else: + tag = None + + return CertificateSigningRequestBuilder( + self._subject_name, + self._extensions, + [*self._attributes, (oid, value, tag)], + ) + + def sign( + self, + private_key: CertificateIssuerPrivateKeyTypes, + algorithm: _AllowedHashTypes | None, + backend: typing.Any = None, + *, + rsa_padding: padding.PSS | padding.PKCS1v15 | None = None, + ) -> CertificateSigningRequest: + """ + Signs the request using the requestor's private key. + """ + if self._subject_name is None: + raise ValueError("A CertificateSigningRequest must have a subject") + + if rsa_padding is not None: + if not isinstance(rsa_padding, (padding.PSS, padding.PKCS1v15)): + raise TypeError("Padding must be PSS or PKCS1v15") + if not isinstance(private_key, rsa.RSAPrivateKey): + raise TypeError("Padding is only supported for RSA keys") + + return rust_x509.create_x509_csr( + self, private_key, algorithm, rsa_padding + ) + + +class CertificateBuilder: + _extensions: list[Extension[ExtensionType]] + + def __init__( + self, + issuer_name: Name | None = None, + subject_name: Name | None = None, + public_key: CertificatePublicKeyTypes | None = None, + serial_number: int | None = None, + not_valid_before: datetime.datetime | None = None, + not_valid_after: datetime.datetime | None = None, + extensions: list[Extension[ExtensionType]] = [], + ) -> None: + self._version = Version.v3 + self._issuer_name = issuer_name + self._subject_name = subject_name + self._public_key = public_key + self._serial_number = serial_number + self._not_valid_before = not_valid_before + self._not_valid_after = not_valid_after + self._extensions = extensions + + def issuer_name(self, name: Name) -> CertificateBuilder: + """ + Sets the CA's distinguished name. + """ + if not isinstance(name, Name): + raise TypeError("Expecting x509.Name object.") + if self._issuer_name is not None: + raise ValueError("The issuer name may only be set once.") + return CertificateBuilder( + name, + self._subject_name, + self._public_key, + self._serial_number, + self._not_valid_before, + self._not_valid_after, + self._extensions, + ) + + def subject_name(self, name: Name) -> CertificateBuilder: + """ + Sets the requestor's distinguished name. + """ + if not isinstance(name, Name): + raise TypeError("Expecting x509.Name object.") + if self._subject_name is not None: + raise ValueError("The subject name may only be set once.") + return CertificateBuilder( + self._issuer_name, + name, + self._public_key, + self._serial_number, + self._not_valid_before, + self._not_valid_after, + self._extensions, + ) + + def public_key( + self, + key: CertificatePublicKeyTypes, + ) -> CertificateBuilder: + """ + Sets the requestor's public key (as found in the signing request). + """ + if not isinstance( + key, + ( + dsa.DSAPublicKey, + rsa.RSAPublicKey, + ec.EllipticCurvePublicKey, + ed25519.Ed25519PublicKey, + ed448.Ed448PublicKey, + x25519.X25519PublicKey, + x448.X448PublicKey, + ), + ): + raise TypeError( + "Expecting one of DSAPublicKey, RSAPublicKey," + " EllipticCurvePublicKey, Ed25519PublicKey," + " Ed448PublicKey, X25519PublicKey, or " + "X448PublicKey." + ) + if self._public_key is not None: + raise ValueError("The public key may only be set once.") + return CertificateBuilder( + self._issuer_name, + self._subject_name, + key, + self._serial_number, + self._not_valid_before, + self._not_valid_after, + self._extensions, + ) + + def serial_number(self, number: int) -> CertificateBuilder: + """ + Sets the certificate serial number. + """ + if not isinstance(number, int): + raise TypeError("Serial number must be of integral type.") + if self._serial_number is not None: + raise ValueError("The serial number may only be set once.") + if number <= 0: + raise ValueError("The serial number should be positive.") + + # ASN.1 integers are always signed, so most significant bit must be + # zero. + if number.bit_length() >= 160: # As defined in RFC 5280 + raise ValueError( + "The serial number should not be more than 159 " "bits." + ) + return CertificateBuilder( + self._issuer_name, + self._subject_name, + self._public_key, + number, + self._not_valid_before, + self._not_valid_after, + self._extensions, + ) + + def not_valid_before(self, time: datetime.datetime) -> CertificateBuilder: + """ + Sets the certificate activation time. + """ + if not isinstance(time, datetime.datetime): + raise TypeError("Expecting datetime object.") + if self._not_valid_before is not None: + raise ValueError("The not valid before may only be set once.") + time = _convert_to_naive_utc_time(time) + if time < _EARLIEST_UTC_TIME: + raise ValueError( + "The not valid before date must be on or after" + " 1950 January 1)." + ) + if self._not_valid_after is not None and time > self._not_valid_after: + raise ValueError( + "The not valid before date must be before the not valid after " + "date." + ) + return CertificateBuilder( + self._issuer_name, + self._subject_name, + self._public_key, + self._serial_number, + time, + self._not_valid_after, + self._extensions, + ) + + def not_valid_after(self, time: datetime.datetime) -> CertificateBuilder: + """ + Sets the certificate expiration time. + """ + if not isinstance(time, datetime.datetime): + raise TypeError("Expecting datetime object.") + if self._not_valid_after is not None: + raise ValueError("The not valid after may only be set once.") + time = _convert_to_naive_utc_time(time) + if time < _EARLIEST_UTC_TIME: + raise ValueError( + "The not valid after date must be on or after" + " 1950 January 1." + ) + if ( + self._not_valid_before is not None + and time < self._not_valid_before + ): + raise ValueError( + "The not valid after date must be after the not valid before " + "date." + ) + return CertificateBuilder( + self._issuer_name, + self._subject_name, + self._public_key, + self._serial_number, + self._not_valid_before, + time, + self._extensions, + ) + + def add_extension( + self, extval: ExtensionType, critical: bool + ) -> CertificateBuilder: + """ + Adds an X.509 extension to the certificate. + """ + if not isinstance(extval, ExtensionType): + raise TypeError("extension must be an ExtensionType") + + extension = Extension(extval.oid, critical, extval) + _reject_duplicate_extension(extension, self._extensions) + + return CertificateBuilder( + self._issuer_name, + self._subject_name, + self._public_key, + self._serial_number, + self._not_valid_before, + self._not_valid_after, + [*self._extensions, extension], + ) + + def sign( + self, + private_key: CertificateIssuerPrivateKeyTypes, + algorithm: _AllowedHashTypes | None, + backend: typing.Any = None, + *, + rsa_padding: padding.PSS | padding.PKCS1v15 | None = None, + ) -> Certificate: + """ + Signs the certificate using the CA's private key. + """ + if self._subject_name is None: + raise ValueError("A certificate must have a subject name") + + if self._issuer_name is None: + raise ValueError("A certificate must have an issuer name") + + if self._serial_number is None: + raise ValueError("A certificate must have a serial number") + + if self._not_valid_before is None: + raise ValueError("A certificate must have a not valid before time") + + if self._not_valid_after is None: + raise ValueError("A certificate must have a not valid after time") + + if self._public_key is None: + raise ValueError("A certificate must have a public key") + + if rsa_padding is not None: + if not isinstance(rsa_padding, (padding.PSS, padding.PKCS1v15)): + raise TypeError("Padding must be PSS or PKCS1v15") + if not isinstance(private_key, rsa.RSAPrivateKey): + raise TypeError("Padding is only supported for RSA keys") + + return rust_x509.create_x509_certificate( + self, private_key, algorithm, rsa_padding + ) + + +class CertificateRevocationListBuilder: + _extensions: list[Extension[ExtensionType]] + _revoked_certificates: list[RevokedCertificate] + + def __init__( + self, + issuer_name: Name | None = None, + last_update: datetime.datetime | None = None, + next_update: datetime.datetime | None = None, + extensions: list[Extension[ExtensionType]] = [], + revoked_certificates: list[RevokedCertificate] = [], + ): + self._issuer_name = issuer_name + self._last_update = last_update + self._next_update = next_update + self._extensions = extensions + self._revoked_certificates = revoked_certificates + + def issuer_name( + self, issuer_name: Name + ) -> CertificateRevocationListBuilder: + if not isinstance(issuer_name, Name): + raise TypeError("Expecting x509.Name object.") + if self._issuer_name is not None: + raise ValueError("The issuer name may only be set once.") + return CertificateRevocationListBuilder( + issuer_name, + self._last_update, + self._next_update, + self._extensions, + self._revoked_certificates, + ) + + def last_update( + self, last_update: datetime.datetime + ) -> CertificateRevocationListBuilder: + if not isinstance(last_update, datetime.datetime): + raise TypeError("Expecting datetime object.") + if self._last_update is not None: + raise ValueError("Last update may only be set once.") + last_update = _convert_to_naive_utc_time(last_update) + if last_update < _EARLIEST_UTC_TIME: + raise ValueError( + "The last update date must be on or after" " 1950 January 1." + ) + if self._next_update is not None and last_update > self._next_update: + raise ValueError( + "The last update date must be before the next update date." + ) + return CertificateRevocationListBuilder( + self._issuer_name, + last_update, + self._next_update, + self._extensions, + self._revoked_certificates, + ) + + def next_update( + self, next_update: datetime.datetime + ) -> CertificateRevocationListBuilder: + if not isinstance(next_update, datetime.datetime): + raise TypeError("Expecting datetime object.") + if self._next_update is not None: + raise ValueError("Last update may only be set once.") + next_update = _convert_to_naive_utc_time(next_update) + if next_update < _EARLIEST_UTC_TIME: + raise ValueError( + "The last update date must be on or after" " 1950 January 1." + ) + if self._last_update is not None and next_update < self._last_update: + raise ValueError( + "The next update date must be after the last update date." + ) + return CertificateRevocationListBuilder( + self._issuer_name, + self._last_update, + next_update, + self._extensions, + self._revoked_certificates, + ) + + def add_extension( + self, extval: ExtensionType, critical: bool + ) -> CertificateRevocationListBuilder: + """ + Adds an X.509 extension to the certificate revocation list. + """ + if not isinstance(extval, ExtensionType): + raise TypeError("extension must be an ExtensionType") + + extension = Extension(extval.oid, critical, extval) + _reject_duplicate_extension(extension, self._extensions) + return CertificateRevocationListBuilder( + self._issuer_name, + self._last_update, + self._next_update, + [*self._extensions, extension], + self._revoked_certificates, + ) + + def add_revoked_certificate( + self, revoked_certificate: RevokedCertificate + ) -> CertificateRevocationListBuilder: + """ + Adds a revoked certificate to the CRL. + """ + if not isinstance(revoked_certificate, RevokedCertificate): + raise TypeError("Must be an instance of RevokedCertificate") + + return CertificateRevocationListBuilder( + self._issuer_name, + self._last_update, + self._next_update, + self._extensions, + [*self._revoked_certificates, revoked_certificate], + ) + + def sign( + self, + private_key: CertificateIssuerPrivateKeyTypes, + algorithm: _AllowedHashTypes | None, + backend: typing.Any = None, + *, + rsa_padding: padding.PSS | padding.PKCS1v15 | None = None, + ) -> CertificateRevocationList: + if self._issuer_name is None: + raise ValueError("A CRL must have an issuer name") + + if self._last_update is None: + raise ValueError("A CRL must have a last update time") + + if self._next_update is None: + raise ValueError("A CRL must have a next update time") + + if rsa_padding is not None: + if not isinstance(rsa_padding, (padding.PSS, padding.PKCS1v15)): + raise TypeError("Padding must be PSS or PKCS1v15") + if not isinstance(private_key, rsa.RSAPrivateKey): + raise TypeError("Padding is only supported for RSA keys") + + return rust_x509.create_x509_crl( + self, private_key, algorithm, rsa_padding + ) + + +class RevokedCertificateBuilder: + def __init__( + self, + serial_number: int | None = None, + revocation_date: datetime.datetime | None = None, + extensions: list[Extension[ExtensionType]] = [], + ): + self._serial_number = serial_number + self._revocation_date = revocation_date + self._extensions = extensions + + def serial_number(self, number: int) -> RevokedCertificateBuilder: + if not isinstance(number, int): + raise TypeError("Serial number must be of integral type.") + if self._serial_number is not None: + raise ValueError("The serial number may only be set once.") + if number <= 0: + raise ValueError("The serial number should be positive") + + # ASN.1 integers are always signed, so most significant bit must be + # zero. + if number.bit_length() >= 160: # As defined in RFC 5280 + raise ValueError( + "The serial number should not be more than 159 " "bits." + ) + return RevokedCertificateBuilder( + number, self._revocation_date, self._extensions + ) + + def revocation_date( + self, time: datetime.datetime + ) -> RevokedCertificateBuilder: + if not isinstance(time, datetime.datetime): + raise TypeError("Expecting datetime object.") + if self._revocation_date is not None: + raise ValueError("The revocation date may only be set once.") + time = _convert_to_naive_utc_time(time) + if time < _EARLIEST_UTC_TIME: + raise ValueError( + "The revocation date must be on or after" " 1950 January 1." + ) + return RevokedCertificateBuilder( + self._serial_number, time, self._extensions + ) + + def add_extension( + self, extval: ExtensionType, critical: bool + ) -> RevokedCertificateBuilder: + if not isinstance(extval, ExtensionType): + raise TypeError("extension must be an ExtensionType") + + extension = Extension(extval.oid, critical, extval) + _reject_duplicate_extension(extension, self._extensions) + return RevokedCertificateBuilder( + self._serial_number, + self._revocation_date, + [*self._extensions, extension], + ) + + def build(self, backend: typing.Any = None) -> RevokedCertificate: + if self._serial_number is None: + raise ValueError("A revoked certificate must have a serial number") + if self._revocation_date is None: + raise ValueError( + "A revoked certificate must have a revocation date" + ) + return _RawRevokedCertificate( + self._serial_number, + self._revocation_date, + Extensions(self._extensions), + ) + + +def random_serial_number() -> int: + return int.from_bytes(os.urandom(20), "big") >> 1 diff --git a/.venv/Lib/site-packages/cryptography/x509/certificate_transparency.py b/.venv/Lib/site-packages/cryptography/x509/certificate_transparency.py new file mode 100644 index 00000000..73647ee7 --- /dev/null +++ b/.venv/Lib/site-packages/cryptography/x509/certificate_transparency.py @@ -0,0 +1,97 @@ +# This file is dual licensed under the terms of the Apache License, Version +# 2.0, and the BSD License. See the LICENSE file in the root of this repository +# for complete details. + +from __future__ import annotations + +import abc +import datetime + +from cryptography import utils +from cryptography.hazmat.bindings._rust import x509 as rust_x509 +from cryptography.hazmat.primitives.hashes import HashAlgorithm + + +class LogEntryType(utils.Enum): + X509_CERTIFICATE = 0 + PRE_CERTIFICATE = 1 + + +class Version(utils.Enum): + v1 = 0 + + +class SignatureAlgorithm(utils.Enum): + """ + Signature algorithms that are valid for SCTs. + + These are exactly the same as SignatureAlgorithm in RFC 5246 (TLS 1.2). + + See: + """ + + ANONYMOUS = 0 + RSA = 1 + DSA = 2 + ECDSA = 3 + + +class SignedCertificateTimestamp(metaclass=abc.ABCMeta): + @property + @abc.abstractmethod + def version(self) -> Version: + """ + Returns the SCT version. + """ + + @property + @abc.abstractmethod + def log_id(self) -> bytes: + """ + Returns an identifier indicating which log this SCT is for. + """ + + @property + @abc.abstractmethod + def timestamp(self) -> datetime.datetime: + """ + Returns the timestamp for this SCT. + """ + + @property + @abc.abstractmethod + def entry_type(self) -> LogEntryType: + """ + Returns whether this is an SCT for a certificate or pre-certificate. + """ + + @property + @abc.abstractmethod + def signature_hash_algorithm(self) -> HashAlgorithm: + """ + Returns the hash algorithm used for the SCT's signature. + """ + + @property + @abc.abstractmethod + def signature_algorithm(self) -> SignatureAlgorithm: + """ + Returns the signing algorithm used for the SCT's signature. + """ + + @property + @abc.abstractmethod + def signature(self) -> bytes: + """ + Returns the signature for this SCT. + """ + + @property + @abc.abstractmethod + def extension_bytes(self) -> bytes: + """ + Returns the raw bytes of any extensions for this SCT. + """ + + +SignedCertificateTimestamp.register(rust_x509.Sct) diff --git a/.venv/Lib/site-packages/cryptography/x509/extensions.py b/.venv/Lib/site-packages/cryptography/x509/extensions.py new file mode 100644 index 00000000..c61c1f48 --- /dev/null +++ b/.venv/Lib/site-packages/cryptography/x509/extensions.py @@ -0,0 +1,2175 @@ +# This file is dual licensed under the terms of the Apache License, Version +# 2.0, and the BSD License. See the LICENSE file in the root of this repository +# for complete details. + +from __future__ import annotations + +import abc +import datetime +import hashlib +import ipaddress +import typing + +from cryptography import utils +from cryptography.hazmat.bindings._rust import asn1 +from cryptography.hazmat.bindings._rust import x509 as rust_x509 +from cryptography.hazmat.primitives import constant_time, serialization +from cryptography.hazmat.primitives.asymmetric.ec import EllipticCurvePublicKey +from cryptography.hazmat.primitives.asymmetric.rsa import RSAPublicKey +from cryptography.hazmat.primitives.asymmetric.types import ( + CertificateIssuerPublicKeyTypes, + CertificatePublicKeyTypes, +) +from cryptography.x509.certificate_transparency import ( + SignedCertificateTimestamp, +) +from cryptography.x509.general_name import ( + DirectoryName, + DNSName, + GeneralName, + IPAddress, + OtherName, + RegisteredID, + RFC822Name, + UniformResourceIdentifier, + _IPAddressTypes, +) +from cryptography.x509.name import Name, RelativeDistinguishedName +from cryptography.x509.oid import ( + CRLEntryExtensionOID, + ExtensionOID, + ObjectIdentifier, + OCSPExtensionOID, +) + +ExtensionTypeVar = typing.TypeVar( + "ExtensionTypeVar", bound="ExtensionType", covariant=True +) + + +def _key_identifier_from_public_key( + public_key: CertificatePublicKeyTypes, +) -> bytes: + if isinstance(public_key, RSAPublicKey): + data = public_key.public_bytes( + serialization.Encoding.DER, + serialization.PublicFormat.PKCS1, + ) + elif isinstance(public_key, EllipticCurvePublicKey): + data = public_key.public_bytes( + serialization.Encoding.X962, + serialization.PublicFormat.UncompressedPoint, + ) + else: + # This is a very slow way to do this. + serialized = public_key.public_bytes( + serialization.Encoding.DER, + serialization.PublicFormat.SubjectPublicKeyInfo, + ) + data = asn1.parse_spki_for_data(serialized) + + return hashlib.sha1(data).digest() + + +def _make_sequence_methods(field_name: str): + def len_method(self) -> int: + return len(getattr(self, field_name)) + + def iter_method(self): + return iter(getattr(self, field_name)) + + def getitem_method(self, idx): + return getattr(self, field_name)[idx] + + return len_method, iter_method, getitem_method + + +class DuplicateExtension(Exception): + def __init__(self, msg: str, oid: ObjectIdentifier) -> None: + super().__init__(msg) + self.oid = oid + + +class ExtensionNotFound(Exception): + def __init__(self, msg: str, oid: ObjectIdentifier) -> None: + super().__init__(msg) + self.oid = oid + + +class ExtensionType(metaclass=abc.ABCMeta): + oid: typing.ClassVar[ObjectIdentifier] + + def public_bytes(self) -> bytes: + """ + Serializes the extension type to DER. + """ + raise NotImplementedError( + f"public_bytes is not implemented for extension type {self!r}" + ) + + +class Extensions: + def __init__( + self, extensions: typing.Iterable[Extension[ExtensionType]] + ) -> None: + self._extensions = list(extensions) + + def get_extension_for_oid( + self, oid: ObjectIdentifier + ) -> Extension[ExtensionType]: + for ext in self: + if ext.oid == oid: + return ext + + raise ExtensionNotFound(f"No {oid} extension was found", oid) + + def get_extension_for_class( + self, extclass: type[ExtensionTypeVar] + ) -> Extension[ExtensionTypeVar]: + if extclass is UnrecognizedExtension: + raise TypeError( + "UnrecognizedExtension can't be used with " + "get_extension_for_class because more than one instance of the" + " class may be present." + ) + + for ext in self: + if isinstance(ext.value, extclass): + return ext + + raise ExtensionNotFound( + f"No {extclass} extension was found", extclass.oid + ) + + __len__, __iter__, __getitem__ = _make_sequence_methods("_extensions") + + def __repr__(self) -> str: + return f"" + + +class CRLNumber(ExtensionType): + oid = ExtensionOID.CRL_NUMBER + + def __init__(self, crl_number: int) -> None: + if not isinstance(crl_number, int): + raise TypeError("crl_number must be an integer") + + self._crl_number = crl_number + + def __eq__(self, other: object) -> bool: + if not isinstance(other, CRLNumber): + return NotImplemented + + return self.crl_number == other.crl_number + + def __hash__(self) -> int: + return hash(self.crl_number) + + def __repr__(self) -> str: + return f"" + + @property + def crl_number(self) -> int: + return self._crl_number + + def public_bytes(self) -> bytes: + return rust_x509.encode_extension_value(self) + + +class AuthorityKeyIdentifier(ExtensionType): + oid = ExtensionOID.AUTHORITY_KEY_IDENTIFIER + + def __init__( + self, + key_identifier: bytes | None, + authority_cert_issuer: typing.Iterable[GeneralName] | None, + authority_cert_serial_number: int | None, + ) -> None: + if (authority_cert_issuer is None) != ( + authority_cert_serial_number is None + ): + raise ValueError( + "authority_cert_issuer and authority_cert_serial_number " + "must both be present or both None" + ) + + if authority_cert_issuer is not None: + authority_cert_issuer = list(authority_cert_issuer) + if not all( + isinstance(x, GeneralName) for x in authority_cert_issuer + ): + raise TypeError( + "authority_cert_issuer must be a list of GeneralName " + "objects" + ) + + if authority_cert_serial_number is not None and not isinstance( + authority_cert_serial_number, int + ): + raise TypeError("authority_cert_serial_number must be an integer") + + self._key_identifier = key_identifier + self._authority_cert_issuer = authority_cert_issuer + self._authority_cert_serial_number = authority_cert_serial_number + + # This takes a subset of CertificatePublicKeyTypes because an issuer + # cannot have an X25519/X448 key. This introduces some unfortunate + # asymmetry that requires typing users to explicitly + # narrow their type, but we should make this accurate and not just + # convenient. + @classmethod + def from_issuer_public_key( + cls, public_key: CertificateIssuerPublicKeyTypes + ) -> AuthorityKeyIdentifier: + digest = _key_identifier_from_public_key(public_key) + return cls( + key_identifier=digest, + authority_cert_issuer=None, + authority_cert_serial_number=None, + ) + + @classmethod + def from_issuer_subject_key_identifier( + cls, ski: SubjectKeyIdentifier + ) -> AuthorityKeyIdentifier: + return cls( + key_identifier=ski.digest, + authority_cert_issuer=None, + authority_cert_serial_number=None, + ) + + def __repr__(self) -> str: + return ( + f"" + ) + + def __eq__(self, other: object) -> bool: + if not isinstance(other, AuthorityKeyIdentifier): + return NotImplemented + + return ( + self.key_identifier == other.key_identifier + and self.authority_cert_issuer == other.authority_cert_issuer + and self.authority_cert_serial_number + == other.authority_cert_serial_number + ) + + def __hash__(self) -> int: + if self.authority_cert_issuer is None: + aci = None + else: + aci = tuple(self.authority_cert_issuer) + return hash( + (self.key_identifier, aci, self.authority_cert_serial_number) + ) + + @property + def key_identifier(self) -> bytes | None: + return self._key_identifier + + @property + def authority_cert_issuer( + self, + ) -> list[GeneralName] | None: + return self._authority_cert_issuer + + @property + def authority_cert_serial_number(self) -> int | None: + return self._authority_cert_serial_number + + def public_bytes(self) -> bytes: + return rust_x509.encode_extension_value(self) + + +class SubjectKeyIdentifier(ExtensionType): + oid = ExtensionOID.SUBJECT_KEY_IDENTIFIER + + def __init__(self, digest: bytes) -> None: + self._digest = digest + + @classmethod + def from_public_key( + cls, public_key: CertificatePublicKeyTypes + ) -> SubjectKeyIdentifier: + return cls(_key_identifier_from_public_key(public_key)) + + @property + def digest(self) -> bytes: + return self._digest + + @property + def key_identifier(self) -> bytes: + return self._digest + + def __repr__(self) -> str: + return f"" + + def __eq__(self, other: object) -> bool: + if not isinstance(other, SubjectKeyIdentifier): + return NotImplemented + + return constant_time.bytes_eq(self.digest, other.digest) + + def __hash__(self) -> int: + return hash(self.digest) + + def public_bytes(self) -> bytes: + return rust_x509.encode_extension_value(self) + + +class AuthorityInformationAccess(ExtensionType): + oid = ExtensionOID.AUTHORITY_INFORMATION_ACCESS + + def __init__( + self, descriptions: typing.Iterable[AccessDescription] + ) -> None: + descriptions = list(descriptions) + if not all(isinstance(x, AccessDescription) for x in descriptions): + raise TypeError( + "Every item in the descriptions list must be an " + "AccessDescription" + ) + + self._descriptions = descriptions + + __len__, __iter__, __getitem__ = _make_sequence_methods("_descriptions") + + def __repr__(self) -> str: + return f"" + + def __eq__(self, other: object) -> bool: + if not isinstance(other, AuthorityInformationAccess): + return NotImplemented + + return self._descriptions == other._descriptions + + def __hash__(self) -> int: + return hash(tuple(self._descriptions)) + + def public_bytes(self) -> bytes: + return rust_x509.encode_extension_value(self) + + +class SubjectInformationAccess(ExtensionType): + oid = ExtensionOID.SUBJECT_INFORMATION_ACCESS + + def __init__( + self, descriptions: typing.Iterable[AccessDescription] + ) -> None: + descriptions = list(descriptions) + if not all(isinstance(x, AccessDescription) for x in descriptions): + raise TypeError( + "Every item in the descriptions list must be an " + "AccessDescription" + ) + + self._descriptions = descriptions + + __len__, __iter__, __getitem__ = _make_sequence_methods("_descriptions") + + def __repr__(self) -> str: + return f"" + + def __eq__(self, other: object) -> bool: + if not isinstance(other, SubjectInformationAccess): + return NotImplemented + + return self._descriptions == other._descriptions + + def __hash__(self) -> int: + return hash(tuple(self._descriptions)) + + def public_bytes(self) -> bytes: + return rust_x509.encode_extension_value(self) + + +class AccessDescription: + def __init__( + self, access_method: ObjectIdentifier, access_location: GeneralName + ) -> None: + if not isinstance(access_method, ObjectIdentifier): + raise TypeError("access_method must be an ObjectIdentifier") + + if not isinstance(access_location, GeneralName): + raise TypeError("access_location must be a GeneralName") + + self._access_method = access_method + self._access_location = access_location + + def __repr__(self) -> str: + return ( + "".format(self) + ) + + def __eq__(self, other: object) -> bool: + if not isinstance(other, AccessDescription): + return NotImplemented + + return ( + self.access_method == other.access_method + and self.access_location == other.access_location + ) + + def __hash__(self) -> int: + return hash((self.access_method, self.access_location)) + + @property + def access_method(self) -> ObjectIdentifier: + return self._access_method + + @property + def access_location(self) -> GeneralName: + return self._access_location + + +class BasicConstraints(ExtensionType): + oid = ExtensionOID.BASIC_CONSTRAINTS + + def __init__(self, ca: bool, path_length: int | None) -> None: + if not isinstance(ca, bool): + raise TypeError("ca must be a boolean value") + + if path_length is not None and not ca: + raise ValueError("path_length must be None when ca is False") + + if path_length is not None and ( + not isinstance(path_length, int) or path_length < 0 + ): + raise TypeError( + "path_length must be a non-negative integer or None" + ) + + self._ca = ca + self._path_length = path_length + + @property + def ca(self) -> bool: + return self._ca + + @property + def path_length(self) -> int | None: + return self._path_length + + def __repr__(self) -> str: + return ( + "" + ).format(self) + + def __eq__(self, other: object) -> bool: + if not isinstance(other, BasicConstraints): + return NotImplemented + + return self.ca == other.ca and self.path_length == other.path_length + + def __hash__(self) -> int: + return hash((self.ca, self.path_length)) + + def public_bytes(self) -> bytes: + return rust_x509.encode_extension_value(self) + + +class DeltaCRLIndicator(ExtensionType): + oid = ExtensionOID.DELTA_CRL_INDICATOR + + def __init__(self, crl_number: int) -> None: + if not isinstance(crl_number, int): + raise TypeError("crl_number must be an integer") + + self._crl_number = crl_number + + @property + def crl_number(self) -> int: + return self._crl_number + + def __eq__(self, other: object) -> bool: + if not isinstance(other, DeltaCRLIndicator): + return NotImplemented + + return self.crl_number == other.crl_number + + def __hash__(self) -> int: + return hash(self.crl_number) + + def __repr__(self) -> str: + return f"" + + def public_bytes(self) -> bytes: + return rust_x509.encode_extension_value(self) + + +class CRLDistributionPoints(ExtensionType): + oid = ExtensionOID.CRL_DISTRIBUTION_POINTS + + def __init__( + self, distribution_points: typing.Iterable[DistributionPoint] + ) -> None: + distribution_points = list(distribution_points) + if not all( + isinstance(x, DistributionPoint) for x in distribution_points + ): + raise TypeError( + "distribution_points must be a list of DistributionPoint " + "objects" + ) + + self._distribution_points = distribution_points + + __len__, __iter__, __getitem__ = _make_sequence_methods( + "_distribution_points" + ) + + def __repr__(self) -> str: + return f"" + + def __eq__(self, other: object) -> bool: + if not isinstance(other, CRLDistributionPoints): + return NotImplemented + + return self._distribution_points == other._distribution_points + + def __hash__(self) -> int: + return hash(tuple(self._distribution_points)) + + def public_bytes(self) -> bytes: + return rust_x509.encode_extension_value(self) + + +class FreshestCRL(ExtensionType): + oid = ExtensionOID.FRESHEST_CRL + + def __init__( + self, distribution_points: typing.Iterable[DistributionPoint] + ) -> None: + distribution_points = list(distribution_points) + if not all( + isinstance(x, DistributionPoint) for x in distribution_points + ): + raise TypeError( + "distribution_points must be a list of DistributionPoint " + "objects" + ) + + self._distribution_points = distribution_points + + __len__, __iter__, __getitem__ = _make_sequence_methods( + "_distribution_points" + ) + + def __repr__(self) -> str: + return f"" + + def __eq__(self, other: object) -> bool: + if not isinstance(other, FreshestCRL): + return NotImplemented + + return self._distribution_points == other._distribution_points + + def __hash__(self) -> int: + return hash(tuple(self._distribution_points)) + + def public_bytes(self) -> bytes: + return rust_x509.encode_extension_value(self) + + +class DistributionPoint: + def __init__( + self, + full_name: typing.Iterable[GeneralName] | None, + relative_name: RelativeDistinguishedName | None, + reasons: frozenset[ReasonFlags] | None, + crl_issuer: typing.Iterable[GeneralName] | None, + ) -> None: + if full_name and relative_name: + raise ValueError( + "You cannot provide both full_name and relative_name, at " + "least one must be None." + ) + if not full_name and not relative_name and not crl_issuer: + raise ValueError( + "Either full_name, relative_name or crl_issuer must be " + "provided." + ) + + if full_name is not None: + full_name = list(full_name) + if not all(isinstance(x, GeneralName) for x in full_name): + raise TypeError( + "full_name must be a list of GeneralName objects" + ) + + if relative_name: + if not isinstance(relative_name, RelativeDistinguishedName): + raise TypeError( + "relative_name must be a RelativeDistinguishedName" + ) + + if crl_issuer is not None: + crl_issuer = list(crl_issuer) + if not all(isinstance(x, GeneralName) for x in crl_issuer): + raise TypeError( + "crl_issuer must be None or a list of general names" + ) + + if reasons and ( + not isinstance(reasons, frozenset) + or not all(isinstance(x, ReasonFlags) for x in reasons) + ): + raise TypeError("reasons must be None or frozenset of ReasonFlags") + + if reasons and ( + ReasonFlags.unspecified in reasons + or ReasonFlags.remove_from_crl in reasons + ): + raise ValueError( + "unspecified and remove_from_crl are not valid reasons in a " + "DistributionPoint" + ) + + self._full_name = full_name + self._relative_name = relative_name + self._reasons = reasons + self._crl_issuer = crl_issuer + + def __repr__(self) -> str: + return ( + "".format(self) + ) + + def __eq__(self, other: object) -> bool: + if not isinstance(other, DistributionPoint): + return NotImplemented + + return ( + self.full_name == other.full_name + and self.relative_name == other.relative_name + and self.reasons == other.reasons + and self.crl_issuer == other.crl_issuer + ) + + def __hash__(self) -> int: + if self.full_name is not None: + fn: tuple[GeneralName, ...] | None = tuple(self.full_name) + else: + fn = None + + if self.crl_issuer is not None: + crl_issuer: tuple[GeneralName, ...] | None = tuple(self.crl_issuer) + else: + crl_issuer = None + + return hash((fn, self.relative_name, self.reasons, crl_issuer)) + + @property + def full_name(self) -> list[GeneralName] | None: + return self._full_name + + @property + def relative_name(self) -> RelativeDistinguishedName | None: + return self._relative_name + + @property + def reasons(self) -> frozenset[ReasonFlags] | None: + return self._reasons + + @property + def crl_issuer(self) -> list[GeneralName] | None: + return self._crl_issuer + + +class ReasonFlags(utils.Enum): + unspecified = "unspecified" + key_compromise = "keyCompromise" + ca_compromise = "cACompromise" + affiliation_changed = "affiliationChanged" + superseded = "superseded" + cessation_of_operation = "cessationOfOperation" + certificate_hold = "certificateHold" + privilege_withdrawn = "privilegeWithdrawn" + aa_compromise = "aACompromise" + remove_from_crl = "removeFromCRL" + + +# These are distribution point bit string mappings. Not to be confused with +# CRLReason reason flags bit string mappings. +# ReasonFlags ::= BIT STRING { +# unused (0), +# keyCompromise (1), +# cACompromise (2), +# affiliationChanged (3), +# superseded (4), +# cessationOfOperation (5), +# certificateHold (6), +# privilegeWithdrawn (7), +# aACompromise (8) } +_REASON_BIT_MAPPING = { + 1: ReasonFlags.key_compromise, + 2: ReasonFlags.ca_compromise, + 3: ReasonFlags.affiliation_changed, + 4: ReasonFlags.superseded, + 5: ReasonFlags.cessation_of_operation, + 6: ReasonFlags.certificate_hold, + 7: ReasonFlags.privilege_withdrawn, + 8: ReasonFlags.aa_compromise, +} + +_CRLREASONFLAGS = { + ReasonFlags.key_compromise: 1, + ReasonFlags.ca_compromise: 2, + ReasonFlags.affiliation_changed: 3, + ReasonFlags.superseded: 4, + ReasonFlags.cessation_of_operation: 5, + ReasonFlags.certificate_hold: 6, + ReasonFlags.privilege_withdrawn: 7, + ReasonFlags.aa_compromise: 8, +} + + +class PolicyConstraints(ExtensionType): + oid = ExtensionOID.POLICY_CONSTRAINTS + + def __init__( + self, + require_explicit_policy: int | None, + inhibit_policy_mapping: int | None, + ) -> None: + if require_explicit_policy is not None and not isinstance( + require_explicit_policy, int + ): + raise TypeError( + "require_explicit_policy must be a non-negative integer or " + "None" + ) + + if inhibit_policy_mapping is not None and not isinstance( + inhibit_policy_mapping, int + ): + raise TypeError( + "inhibit_policy_mapping must be a non-negative integer or None" + ) + + if inhibit_policy_mapping is None and require_explicit_policy is None: + raise ValueError( + "At least one of require_explicit_policy and " + "inhibit_policy_mapping must not be None" + ) + + self._require_explicit_policy = require_explicit_policy + self._inhibit_policy_mapping = inhibit_policy_mapping + + def __repr__(self) -> str: + return ( + "".format(self) + ) + + def __eq__(self, other: object) -> bool: + if not isinstance(other, PolicyConstraints): + return NotImplemented + + return ( + self.require_explicit_policy == other.require_explicit_policy + and self.inhibit_policy_mapping == other.inhibit_policy_mapping + ) + + def __hash__(self) -> int: + return hash( + (self.require_explicit_policy, self.inhibit_policy_mapping) + ) + + @property + def require_explicit_policy(self) -> int | None: + return self._require_explicit_policy + + @property + def inhibit_policy_mapping(self) -> int | None: + return self._inhibit_policy_mapping + + def public_bytes(self) -> bytes: + return rust_x509.encode_extension_value(self) + + +class CertificatePolicies(ExtensionType): + oid = ExtensionOID.CERTIFICATE_POLICIES + + def __init__(self, policies: typing.Iterable[PolicyInformation]) -> None: + policies = list(policies) + if not all(isinstance(x, PolicyInformation) for x in policies): + raise TypeError( + "Every item in the policies list must be a " + "PolicyInformation" + ) + + self._policies = policies + + __len__, __iter__, __getitem__ = _make_sequence_methods("_policies") + + def __repr__(self) -> str: + return f"" + + def __eq__(self, other: object) -> bool: + if not isinstance(other, CertificatePolicies): + return NotImplemented + + return self._policies == other._policies + + def __hash__(self) -> int: + return hash(tuple(self._policies)) + + def public_bytes(self) -> bytes: + return rust_x509.encode_extension_value(self) + + +class PolicyInformation: + def __init__( + self, + policy_identifier: ObjectIdentifier, + policy_qualifiers: typing.Iterable[str | UserNotice] | None, + ) -> None: + if not isinstance(policy_identifier, ObjectIdentifier): + raise TypeError("policy_identifier must be an ObjectIdentifier") + + self._policy_identifier = policy_identifier + + if policy_qualifiers is not None: + policy_qualifiers = list(policy_qualifiers) + if not all( + isinstance(x, (str, UserNotice)) for x in policy_qualifiers + ): + raise TypeError( + "policy_qualifiers must be a list of strings and/or " + "UserNotice objects or None" + ) + + self._policy_qualifiers = policy_qualifiers + + def __repr__(self) -> str: + return ( + "".format(self) + ) + + def __eq__(self, other: object) -> bool: + if not isinstance(other, PolicyInformation): + return NotImplemented + + return ( + self.policy_identifier == other.policy_identifier + and self.policy_qualifiers == other.policy_qualifiers + ) + + def __hash__(self) -> int: + if self.policy_qualifiers is not None: + pq: tuple[str | UserNotice, ...] | None = tuple( + self.policy_qualifiers + ) + else: + pq = None + + return hash((self.policy_identifier, pq)) + + @property + def policy_identifier(self) -> ObjectIdentifier: + return self._policy_identifier + + @property + def policy_qualifiers( + self, + ) -> list[str | UserNotice] | None: + return self._policy_qualifiers + + +class UserNotice: + def __init__( + self, + notice_reference: NoticeReference | None, + explicit_text: str | None, + ) -> None: + if notice_reference and not isinstance( + notice_reference, NoticeReference + ): + raise TypeError( + "notice_reference must be None or a NoticeReference" + ) + + self._notice_reference = notice_reference + self._explicit_text = explicit_text + + def __repr__(self) -> str: + return ( + "".format(self) + ) + + def __eq__(self, other: object) -> bool: + if not isinstance(other, UserNotice): + return NotImplemented + + return ( + self.notice_reference == other.notice_reference + and self.explicit_text == other.explicit_text + ) + + def __hash__(self) -> int: + return hash((self.notice_reference, self.explicit_text)) + + @property + def notice_reference(self) -> NoticeReference | None: + return self._notice_reference + + @property + def explicit_text(self) -> str | None: + return self._explicit_text + + +class NoticeReference: + def __init__( + self, + organization: str | None, + notice_numbers: typing.Iterable[int], + ) -> None: + self._organization = organization + notice_numbers = list(notice_numbers) + if not all(isinstance(x, int) for x in notice_numbers): + raise TypeError("notice_numbers must be a list of integers") + + self._notice_numbers = notice_numbers + + def __repr__(self) -> str: + return ( + "".format(self) + ) + + def __eq__(self, other: object) -> bool: + if not isinstance(other, NoticeReference): + return NotImplemented + + return ( + self.organization == other.organization + and self.notice_numbers == other.notice_numbers + ) + + def __hash__(self) -> int: + return hash((self.organization, tuple(self.notice_numbers))) + + @property + def organization(self) -> str | None: + return self._organization + + @property + def notice_numbers(self) -> list[int]: + return self._notice_numbers + + +class ExtendedKeyUsage(ExtensionType): + oid = ExtensionOID.EXTENDED_KEY_USAGE + + def __init__(self, usages: typing.Iterable[ObjectIdentifier]) -> None: + usages = list(usages) + if not all(isinstance(x, ObjectIdentifier) for x in usages): + raise TypeError( + "Every item in the usages list must be an ObjectIdentifier" + ) + + self._usages = usages + + __len__, __iter__, __getitem__ = _make_sequence_methods("_usages") + + def __repr__(self) -> str: + return f"" + + def __eq__(self, other: object) -> bool: + if not isinstance(other, ExtendedKeyUsage): + return NotImplemented + + return self._usages == other._usages + + def __hash__(self) -> int: + return hash(tuple(self._usages)) + + def public_bytes(self) -> bytes: + return rust_x509.encode_extension_value(self) + + +class OCSPNoCheck(ExtensionType): + oid = ExtensionOID.OCSP_NO_CHECK + + def __eq__(self, other: object) -> bool: + if not isinstance(other, OCSPNoCheck): + return NotImplemented + + return True + + def __hash__(self) -> int: + return hash(OCSPNoCheck) + + def __repr__(self) -> str: + return "" + + def public_bytes(self) -> bytes: + return rust_x509.encode_extension_value(self) + + +class PrecertPoison(ExtensionType): + oid = ExtensionOID.PRECERT_POISON + + def __eq__(self, other: object) -> bool: + if not isinstance(other, PrecertPoison): + return NotImplemented + + return True + + def __hash__(self) -> int: + return hash(PrecertPoison) + + def __repr__(self) -> str: + return "" + + def public_bytes(self) -> bytes: + return rust_x509.encode_extension_value(self) + + +class TLSFeature(ExtensionType): + oid = ExtensionOID.TLS_FEATURE + + def __init__(self, features: typing.Iterable[TLSFeatureType]) -> None: + features = list(features) + if ( + not all(isinstance(x, TLSFeatureType) for x in features) + or len(features) == 0 + ): + raise TypeError( + "features must be a list of elements from the TLSFeatureType " + "enum" + ) + + self._features = features + + __len__, __iter__, __getitem__ = _make_sequence_methods("_features") + + def __repr__(self) -> str: + return f"" + + def __eq__(self, other: object) -> bool: + if not isinstance(other, TLSFeature): + return NotImplemented + + return self._features == other._features + + def __hash__(self) -> int: + return hash(tuple(self._features)) + + def public_bytes(self) -> bytes: + return rust_x509.encode_extension_value(self) + + +class TLSFeatureType(utils.Enum): + # status_request is defined in RFC 6066 and is used for what is commonly + # called OCSP Must-Staple when present in the TLS Feature extension in an + # X.509 certificate. + status_request = 5 + # status_request_v2 is defined in RFC 6961 and allows multiple OCSP + # responses to be provided. It is not currently in use by clients or + # servers. + status_request_v2 = 17 + + +_TLS_FEATURE_TYPE_TO_ENUM = {x.value: x for x in TLSFeatureType} + + +class InhibitAnyPolicy(ExtensionType): + oid = ExtensionOID.INHIBIT_ANY_POLICY + + def __init__(self, skip_certs: int) -> None: + if not isinstance(skip_certs, int): + raise TypeError("skip_certs must be an integer") + + if skip_certs < 0: + raise ValueError("skip_certs must be a non-negative integer") + + self._skip_certs = skip_certs + + def __repr__(self) -> str: + return f"" + + def __eq__(self, other: object) -> bool: + if not isinstance(other, InhibitAnyPolicy): + return NotImplemented + + return self.skip_certs == other.skip_certs + + def __hash__(self) -> int: + return hash(self.skip_certs) + + @property + def skip_certs(self) -> int: + return self._skip_certs + + def public_bytes(self) -> bytes: + return rust_x509.encode_extension_value(self) + + +class KeyUsage(ExtensionType): + oid = ExtensionOID.KEY_USAGE + + def __init__( + self, + digital_signature: bool, + content_commitment: bool, + key_encipherment: bool, + data_encipherment: bool, + key_agreement: bool, + key_cert_sign: bool, + crl_sign: bool, + encipher_only: bool, + decipher_only: bool, + ) -> None: + if not key_agreement and (encipher_only or decipher_only): + raise ValueError( + "encipher_only and decipher_only can only be true when " + "key_agreement is true" + ) + + self._digital_signature = digital_signature + self._content_commitment = content_commitment + self._key_encipherment = key_encipherment + self._data_encipherment = data_encipherment + self._key_agreement = key_agreement + self._key_cert_sign = key_cert_sign + self._crl_sign = crl_sign + self._encipher_only = encipher_only + self._decipher_only = decipher_only + + @property + def digital_signature(self) -> bool: + return self._digital_signature + + @property + def content_commitment(self) -> bool: + return self._content_commitment + + @property + def key_encipherment(self) -> bool: + return self._key_encipherment + + @property + def data_encipherment(self) -> bool: + return self._data_encipherment + + @property + def key_agreement(self) -> bool: + return self._key_agreement + + @property + def key_cert_sign(self) -> bool: + return self._key_cert_sign + + @property + def crl_sign(self) -> bool: + return self._crl_sign + + @property + def encipher_only(self) -> bool: + if not self.key_agreement: + raise ValueError( + "encipher_only is undefined unless key_agreement is true" + ) + else: + return self._encipher_only + + @property + def decipher_only(self) -> bool: + if not self.key_agreement: + raise ValueError( + "decipher_only is undefined unless key_agreement is true" + ) + else: + return self._decipher_only + + def __repr__(self) -> str: + try: + encipher_only = self.encipher_only + decipher_only = self.decipher_only + except ValueError: + # Users found None confusing because even though encipher/decipher + # have no meaning unless key_agreement is true, to construct an + # instance of the class you still need to pass False. + encipher_only = False + decipher_only = False + + return ( + f"" + ) + + def __eq__(self, other: object) -> bool: + if not isinstance(other, KeyUsage): + return NotImplemented + + return ( + self.digital_signature == other.digital_signature + and self.content_commitment == other.content_commitment + and self.key_encipherment == other.key_encipherment + and self.data_encipherment == other.data_encipherment + and self.key_agreement == other.key_agreement + and self.key_cert_sign == other.key_cert_sign + and self.crl_sign == other.crl_sign + and self._encipher_only == other._encipher_only + and self._decipher_only == other._decipher_only + ) + + def __hash__(self) -> int: + return hash( + ( + self.digital_signature, + self.content_commitment, + self.key_encipherment, + self.data_encipherment, + self.key_agreement, + self.key_cert_sign, + self.crl_sign, + self._encipher_only, + self._decipher_only, + ) + ) + + def public_bytes(self) -> bytes: + return rust_x509.encode_extension_value(self) + + +class NameConstraints(ExtensionType): + oid = ExtensionOID.NAME_CONSTRAINTS + + def __init__( + self, + permitted_subtrees: typing.Iterable[GeneralName] | None, + excluded_subtrees: typing.Iterable[GeneralName] | None, + ) -> None: + if permitted_subtrees is not None: + permitted_subtrees = list(permitted_subtrees) + if not permitted_subtrees: + raise ValueError( + "permitted_subtrees must be a non-empty list or None" + ) + if not all(isinstance(x, GeneralName) for x in permitted_subtrees): + raise TypeError( + "permitted_subtrees must be a list of GeneralName objects " + "or None" + ) + + self._validate_tree(permitted_subtrees) + + if excluded_subtrees is not None: + excluded_subtrees = list(excluded_subtrees) + if not excluded_subtrees: + raise ValueError( + "excluded_subtrees must be a non-empty list or None" + ) + if not all(isinstance(x, GeneralName) for x in excluded_subtrees): + raise TypeError( + "excluded_subtrees must be a list of GeneralName objects " + "or None" + ) + + self._validate_tree(excluded_subtrees) + + if permitted_subtrees is None and excluded_subtrees is None: + raise ValueError( + "At least one of permitted_subtrees and excluded_subtrees " + "must not be None" + ) + + self._permitted_subtrees = permitted_subtrees + self._excluded_subtrees = excluded_subtrees + + def __eq__(self, other: object) -> bool: + if not isinstance(other, NameConstraints): + return NotImplemented + + return ( + self.excluded_subtrees == other.excluded_subtrees + and self.permitted_subtrees == other.permitted_subtrees + ) + + def _validate_tree(self, tree: typing.Iterable[GeneralName]) -> None: + self._validate_ip_name(tree) + self._validate_dns_name(tree) + + def _validate_ip_name(self, tree: typing.Iterable[GeneralName]) -> None: + if any( + isinstance(name, IPAddress) + and not isinstance( + name.value, (ipaddress.IPv4Network, ipaddress.IPv6Network) + ) + for name in tree + ): + raise TypeError( + "IPAddress name constraints must be an IPv4Network or" + " IPv6Network object" + ) + + def _validate_dns_name(self, tree: typing.Iterable[GeneralName]) -> None: + if any( + isinstance(name, DNSName) and "*" in name.value for name in tree + ): + raise ValueError( + "DNSName name constraints must not contain the '*' wildcard" + " character" + ) + + def __repr__(self) -> str: + return ( + f"" + ) + + def __hash__(self) -> int: + if self.permitted_subtrees is not None: + ps: tuple[GeneralName, ...] | None = tuple(self.permitted_subtrees) + else: + ps = None + + if self.excluded_subtrees is not None: + es: tuple[GeneralName, ...] | None = tuple(self.excluded_subtrees) + else: + es = None + + return hash((ps, es)) + + @property + def permitted_subtrees( + self, + ) -> list[GeneralName] | None: + return self._permitted_subtrees + + @property + def excluded_subtrees( + self, + ) -> list[GeneralName] | None: + return self._excluded_subtrees + + def public_bytes(self) -> bytes: + return rust_x509.encode_extension_value(self) + + +class Extension(typing.Generic[ExtensionTypeVar]): + def __init__( + self, oid: ObjectIdentifier, critical: bool, value: ExtensionTypeVar + ) -> None: + if not isinstance(oid, ObjectIdentifier): + raise TypeError( + "oid argument must be an ObjectIdentifier instance." + ) + + if not isinstance(critical, bool): + raise TypeError("critical must be a boolean value") + + self._oid = oid + self._critical = critical + self._value = value + + @property + def oid(self) -> ObjectIdentifier: + return self._oid + + @property + def critical(self) -> bool: + return self._critical + + @property + def value(self) -> ExtensionTypeVar: + return self._value + + def __repr__(self) -> str: + return ( + f"" + ) + + def __eq__(self, other: object) -> bool: + if not isinstance(other, Extension): + return NotImplemented + + return ( + self.oid == other.oid + and self.critical == other.critical + and self.value == other.value + ) + + def __hash__(self) -> int: + return hash((self.oid, self.critical, self.value)) + + +class GeneralNames: + def __init__(self, general_names: typing.Iterable[GeneralName]) -> None: + general_names = list(general_names) + if not all(isinstance(x, GeneralName) for x in general_names): + raise TypeError( + "Every item in the general_names list must be an " + "object conforming to the GeneralName interface" + ) + + self._general_names = general_names + + __len__, __iter__, __getitem__ = _make_sequence_methods("_general_names") + + @typing.overload + def get_values_for_type( + self, + type: type[DNSName] + | type[UniformResourceIdentifier] + | type[RFC822Name], + ) -> list[str]: + ... + + @typing.overload + def get_values_for_type( + self, + type: type[DirectoryName], + ) -> list[Name]: + ... + + @typing.overload + def get_values_for_type( + self, + type: type[RegisteredID], + ) -> list[ObjectIdentifier]: + ... + + @typing.overload + def get_values_for_type( + self, type: type[IPAddress] + ) -> list[_IPAddressTypes]: + ... + + @typing.overload + def get_values_for_type(self, type: type[OtherName]) -> list[OtherName]: + ... + + def get_values_for_type( + self, + type: type[DNSName] + | type[DirectoryName] + | type[IPAddress] + | type[OtherName] + | type[RFC822Name] + | type[RegisteredID] + | type[UniformResourceIdentifier], + ) -> ( + list[_IPAddressTypes] + | list[str] + | list[OtherName] + | list[Name] + | list[ObjectIdentifier] + ): + # Return the value of each GeneralName, except for OtherName instances + # which we return directly because it has two important properties not + # just one value. + objs = (i for i in self if isinstance(i, type)) + if type != OtherName: + return [i.value for i in objs] + return list(objs) + + def __repr__(self) -> str: + return f"" + + def __eq__(self, other: object) -> bool: + if not isinstance(other, GeneralNames): + return NotImplemented + + return self._general_names == other._general_names + + def __hash__(self) -> int: + return hash(tuple(self._general_names)) + + +class SubjectAlternativeName(ExtensionType): + oid = ExtensionOID.SUBJECT_ALTERNATIVE_NAME + + def __init__(self, general_names: typing.Iterable[GeneralName]) -> None: + self._general_names = GeneralNames(general_names) + + __len__, __iter__, __getitem__ = _make_sequence_methods("_general_names") + + @typing.overload + def get_values_for_type( + self, + type: type[DNSName] + | type[UniformResourceIdentifier] + | type[RFC822Name], + ) -> list[str]: + ... + + @typing.overload + def get_values_for_type( + self, + type: type[DirectoryName], + ) -> list[Name]: + ... + + @typing.overload + def get_values_for_type( + self, + type: type[RegisteredID], + ) -> list[ObjectIdentifier]: + ... + + @typing.overload + def get_values_for_type( + self, type: type[IPAddress] + ) -> list[_IPAddressTypes]: + ... + + @typing.overload + def get_values_for_type(self, type: type[OtherName]) -> list[OtherName]: + ... + + def get_values_for_type( + self, + type: type[DNSName] + | type[DirectoryName] + | type[IPAddress] + | type[OtherName] + | type[RFC822Name] + | type[RegisteredID] + | type[UniformResourceIdentifier], + ) -> ( + list[_IPAddressTypes] + | list[str] + | list[OtherName] + | list[Name] + | list[ObjectIdentifier] + ): + return self._general_names.get_values_for_type(type) + + def __repr__(self) -> str: + return f"" + + def __eq__(self, other: object) -> bool: + if not isinstance(other, SubjectAlternativeName): + return NotImplemented + + return self._general_names == other._general_names + + def __hash__(self) -> int: + return hash(self._general_names) + + def public_bytes(self) -> bytes: + return rust_x509.encode_extension_value(self) + + +class IssuerAlternativeName(ExtensionType): + oid = ExtensionOID.ISSUER_ALTERNATIVE_NAME + + def __init__(self, general_names: typing.Iterable[GeneralName]) -> None: + self._general_names = GeneralNames(general_names) + + __len__, __iter__, __getitem__ = _make_sequence_methods("_general_names") + + @typing.overload + def get_values_for_type( + self, + type: type[DNSName] + | type[UniformResourceIdentifier] + | type[RFC822Name], + ) -> list[str]: + ... + + @typing.overload + def get_values_for_type( + self, + type: type[DirectoryName], + ) -> list[Name]: + ... + + @typing.overload + def get_values_for_type( + self, + type: type[RegisteredID], + ) -> list[ObjectIdentifier]: + ... + + @typing.overload + def get_values_for_type( + self, type: type[IPAddress] + ) -> list[_IPAddressTypes]: + ... + + @typing.overload + def get_values_for_type(self, type: type[OtherName]) -> list[OtherName]: + ... + + def get_values_for_type( + self, + type: type[DNSName] + | type[DirectoryName] + | type[IPAddress] + | type[OtherName] + | type[RFC822Name] + | type[RegisteredID] + | type[UniformResourceIdentifier], + ) -> ( + list[_IPAddressTypes] + | list[str] + | list[OtherName] + | list[Name] + | list[ObjectIdentifier] + ): + return self._general_names.get_values_for_type(type) + + def __repr__(self) -> str: + return f"" + + def __eq__(self, other: object) -> bool: + if not isinstance(other, IssuerAlternativeName): + return NotImplemented + + return self._general_names == other._general_names + + def __hash__(self) -> int: + return hash(self._general_names) + + def public_bytes(self) -> bytes: + return rust_x509.encode_extension_value(self) + + +class CertificateIssuer(ExtensionType): + oid = CRLEntryExtensionOID.CERTIFICATE_ISSUER + + def __init__(self, general_names: typing.Iterable[GeneralName]) -> None: + self._general_names = GeneralNames(general_names) + + __len__, __iter__, __getitem__ = _make_sequence_methods("_general_names") + + @typing.overload + def get_values_for_type( + self, + type: type[DNSName] + | type[UniformResourceIdentifier] + | type[RFC822Name], + ) -> list[str]: + ... + + @typing.overload + def get_values_for_type( + self, + type: type[DirectoryName], + ) -> list[Name]: + ... + + @typing.overload + def get_values_for_type( + self, + type: type[RegisteredID], + ) -> list[ObjectIdentifier]: + ... + + @typing.overload + def get_values_for_type( + self, type: type[IPAddress] + ) -> list[_IPAddressTypes]: + ... + + @typing.overload + def get_values_for_type(self, type: type[OtherName]) -> list[OtherName]: + ... + + def get_values_for_type( + self, + type: type[DNSName] + | type[DirectoryName] + | type[IPAddress] + | type[OtherName] + | type[RFC822Name] + | type[RegisteredID] + | type[UniformResourceIdentifier], + ) -> ( + list[_IPAddressTypes] + | list[str] + | list[OtherName] + | list[Name] + | list[ObjectIdentifier] + ): + return self._general_names.get_values_for_type(type) + + def __repr__(self) -> str: + return f"" + + def __eq__(self, other: object) -> bool: + if not isinstance(other, CertificateIssuer): + return NotImplemented + + return self._general_names == other._general_names + + def __hash__(self) -> int: + return hash(self._general_names) + + def public_bytes(self) -> bytes: + return rust_x509.encode_extension_value(self) + + +class CRLReason(ExtensionType): + oid = CRLEntryExtensionOID.CRL_REASON + + def __init__(self, reason: ReasonFlags) -> None: + if not isinstance(reason, ReasonFlags): + raise TypeError("reason must be an element from ReasonFlags") + + self._reason = reason + + def __repr__(self) -> str: + return f"" + + def __eq__(self, other: object) -> bool: + if not isinstance(other, CRLReason): + return NotImplemented + + return self.reason == other.reason + + def __hash__(self) -> int: + return hash(self.reason) + + @property + def reason(self) -> ReasonFlags: + return self._reason + + def public_bytes(self) -> bytes: + return rust_x509.encode_extension_value(self) + + +class InvalidityDate(ExtensionType): + oid = CRLEntryExtensionOID.INVALIDITY_DATE + + def __init__(self, invalidity_date: datetime.datetime) -> None: + if not isinstance(invalidity_date, datetime.datetime): + raise TypeError("invalidity_date must be a datetime.datetime") + + self._invalidity_date = invalidity_date + + def __repr__(self) -> str: + return f"" + + def __eq__(self, other: object) -> bool: + if not isinstance(other, InvalidityDate): + return NotImplemented + + return self.invalidity_date == other.invalidity_date + + def __hash__(self) -> int: + return hash(self.invalidity_date) + + @property + def invalidity_date(self) -> datetime.datetime: + return self._invalidity_date + + def public_bytes(self) -> bytes: + return rust_x509.encode_extension_value(self) + + +class PrecertificateSignedCertificateTimestamps(ExtensionType): + oid = ExtensionOID.PRECERT_SIGNED_CERTIFICATE_TIMESTAMPS + + def __init__( + self, + signed_certificate_timestamps: typing.Iterable[ + SignedCertificateTimestamp + ], + ) -> None: + signed_certificate_timestamps = list(signed_certificate_timestamps) + if not all( + isinstance(sct, SignedCertificateTimestamp) + for sct in signed_certificate_timestamps + ): + raise TypeError( + "Every item in the signed_certificate_timestamps list must be " + "a SignedCertificateTimestamp" + ) + self._signed_certificate_timestamps = signed_certificate_timestamps + + __len__, __iter__, __getitem__ = _make_sequence_methods( + "_signed_certificate_timestamps" + ) + + def __repr__(self) -> str: + return f"" + + def __hash__(self) -> int: + return hash(tuple(self._signed_certificate_timestamps)) + + def __eq__(self, other: object) -> bool: + if not isinstance(other, PrecertificateSignedCertificateTimestamps): + return NotImplemented + + return ( + self._signed_certificate_timestamps + == other._signed_certificate_timestamps + ) + + def public_bytes(self) -> bytes: + return rust_x509.encode_extension_value(self) + + +class SignedCertificateTimestamps(ExtensionType): + oid = ExtensionOID.SIGNED_CERTIFICATE_TIMESTAMPS + + def __init__( + self, + signed_certificate_timestamps: typing.Iterable[ + SignedCertificateTimestamp + ], + ) -> None: + signed_certificate_timestamps = list(signed_certificate_timestamps) + if not all( + isinstance(sct, SignedCertificateTimestamp) + for sct in signed_certificate_timestamps + ): + raise TypeError( + "Every item in the signed_certificate_timestamps list must be " + "a SignedCertificateTimestamp" + ) + self._signed_certificate_timestamps = signed_certificate_timestamps + + __len__, __iter__, __getitem__ = _make_sequence_methods( + "_signed_certificate_timestamps" + ) + + def __repr__(self) -> str: + return f"" + + def __hash__(self) -> int: + return hash(tuple(self._signed_certificate_timestamps)) + + def __eq__(self, other: object) -> bool: + if not isinstance(other, SignedCertificateTimestamps): + return NotImplemented + + return ( + self._signed_certificate_timestamps + == other._signed_certificate_timestamps + ) + + def public_bytes(self) -> bytes: + return rust_x509.encode_extension_value(self) + + +class OCSPNonce(ExtensionType): + oid = OCSPExtensionOID.NONCE + + def __init__(self, nonce: bytes) -> None: + if not isinstance(nonce, bytes): + raise TypeError("nonce must be bytes") + + self._nonce = nonce + + def __eq__(self, other: object) -> bool: + if not isinstance(other, OCSPNonce): + return NotImplemented + + return self.nonce == other.nonce + + def __hash__(self) -> int: + return hash(self.nonce) + + def __repr__(self) -> str: + return f"" + + @property + def nonce(self) -> bytes: + return self._nonce + + def public_bytes(self) -> bytes: + return rust_x509.encode_extension_value(self) + + +class OCSPAcceptableResponses(ExtensionType): + oid = OCSPExtensionOID.ACCEPTABLE_RESPONSES + + def __init__(self, responses: typing.Iterable[ObjectIdentifier]) -> None: + responses = list(responses) + if any(not isinstance(r, ObjectIdentifier) for r in responses): + raise TypeError("All responses must be ObjectIdentifiers") + + self._responses = responses + + def __eq__(self, other: object) -> bool: + if not isinstance(other, OCSPAcceptableResponses): + return NotImplemented + + return self._responses == other._responses + + def __hash__(self) -> int: + return hash(tuple(self._responses)) + + def __repr__(self) -> str: + return f"" + + def __iter__(self) -> typing.Iterator[ObjectIdentifier]: + return iter(self._responses) + + def public_bytes(self) -> bytes: + return rust_x509.encode_extension_value(self) + + +class IssuingDistributionPoint(ExtensionType): + oid = ExtensionOID.ISSUING_DISTRIBUTION_POINT + + def __init__( + self, + full_name: typing.Iterable[GeneralName] | None, + relative_name: RelativeDistinguishedName | None, + only_contains_user_certs: bool, + only_contains_ca_certs: bool, + only_some_reasons: frozenset[ReasonFlags] | None, + indirect_crl: bool, + only_contains_attribute_certs: bool, + ) -> None: + if full_name is not None: + full_name = list(full_name) + + if only_some_reasons and ( + not isinstance(only_some_reasons, frozenset) + or not all(isinstance(x, ReasonFlags) for x in only_some_reasons) + ): + raise TypeError( + "only_some_reasons must be None or frozenset of ReasonFlags" + ) + + if only_some_reasons and ( + ReasonFlags.unspecified in only_some_reasons + or ReasonFlags.remove_from_crl in only_some_reasons + ): + raise ValueError( + "unspecified and remove_from_crl are not valid reasons in an " + "IssuingDistributionPoint" + ) + + if not ( + isinstance(only_contains_user_certs, bool) + and isinstance(only_contains_ca_certs, bool) + and isinstance(indirect_crl, bool) + and isinstance(only_contains_attribute_certs, bool) + ): + raise TypeError( + "only_contains_user_certs, only_contains_ca_certs, " + "indirect_crl and only_contains_attribute_certs " + "must all be boolean." + ) + + crl_constraints = [ + only_contains_user_certs, + only_contains_ca_certs, + indirect_crl, + only_contains_attribute_certs, + ] + + if len([x for x in crl_constraints if x]) > 1: + raise ValueError( + "Only one of the following can be set to True: " + "only_contains_user_certs, only_contains_ca_certs, " + "indirect_crl, only_contains_attribute_certs" + ) + + if not any( + [ + only_contains_user_certs, + only_contains_ca_certs, + indirect_crl, + only_contains_attribute_certs, + full_name, + relative_name, + only_some_reasons, + ] + ): + raise ValueError( + "Cannot create empty extension: " + "if only_contains_user_certs, only_contains_ca_certs, " + "indirect_crl, and only_contains_attribute_certs are all False" + ", then either full_name, relative_name, or only_some_reasons " + "must have a value." + ) + + self._only_contains_user_certs = only_contains_user_certs + self._only_contains_ca_certs = only_contains_ca_certs + self._indirect_crl = indirect_crl + self._only_contains_attribute_certs = only_contains_attribute_certs + self._only_some_reasons = only_some_reasons + self._full_name = full_name + self._relative_name = relative_name + + def __repr__(self) -> str: + return ( + f"" + ) + + def __eq__(self, other: object) -> bool: + if not isinstance(other, IssuingDistributionPoint): + return NotImplemented + + return ( + self.full_name == other.full_name + and self.relative_name == other.relative_name + and self.only_contains_user_certs == other.only_contains_user_certs + and self.only_contains_ca_certs == other.only_contains_ca_certs + and self.only_some_reasons == other.only_some_reasons + and self.indirect_crl == other.indirect_crl + and self.only_contains_attribute_certs + == other.only_contains_attribute_certs + ) + + def __hash__(self) -> int: + return hash( + ( + self.full_name, + self.relative_name, + self.only_contains_user_certs, + self.only_contains_ca_certs, + self.only_some_reasons, + self.indirect_crl, + self.only_contains_attribute_certs, + ) + ) + + @property + def full_name(self) -> list[GeneralName] | None: + return self._full_name + + @property + def relative_name(self) -> RelativeDistinguishedName | None: + return self._relative_name + + @property + def only_contains_user_certs(self) -> bool: + return self._only_contains_user_certs + + @property + def only_contains_ca_certs(self) -> bool: + return self._only_contains_ca_certs + + @property + def only_some_reasons( + self, + ) -> frozenset[ReasonFlags] | None: + return self._only_some_reasons + + @property + def indirect_crl(self) -> bool: + return self._indirect_crl + + @property + def only_contains_attribute_certs(self) -> bool: + return self._only_contains_attribute_certs + + def public_bytes(self) -> bytes: + return rust_x509.encode_extension_value(self) + + +class MSCertificateTemplate(ExtensionType): + oid = ExtensionOID.MS_CERTIFICATE_TEMPLATE + + def __init__( + self, + template_id: ObjectIdentifier, + major_version: int | None, + minor_version: int | None, + ) -> None: + if not isinstance(template_id, ObjectIdentifier): + raise TypeError("oid must be an ObjectIdentifier") + self._template_id = template_id + if ( + major_version is not None and not isinstance(major_version, int) + ) or ( + minor_version is not None and not isinstance(minor_version, int) + ): + raise TypeError( + "major_version and minor_version must be integers or None" + ) + self._major_version = major_version + self._minor_version = minor_version + + @property + def template_id(self) -> ObjectIdentifier: + return self._template_id + + @property + def major_version(self) -> int | None: + return self._major_version + + @property + def minor_version(self) -> int | None: + return self._minor_version + + def __repr__(self) -> str: + return ( + f"" + ) + + def __eq__(self, other: object) -> bool: + if not isinstance(other, MSCertificateTemplate): + return NotImplemented + + return ( + self.template_id == other.template_id + and self.major_version == other.major_version + and self.minor_version == other.minor_version + ) + + def __hash__(self) -> int: + return hash((self.template_id, self.major_version, self.minor_version)) + + def public_bytes(self) -> bytes: + return rust_x509.encode_extension_value(self) + + +class UnrecognizedExtension(ExtensionType): + def __init__(self, oid: ObjectIdentifier, value: bytes) -> None: + if not isinstance(oid, ObjectIdentifier): + raise TypeError("oid must be an ObjectIdentifier") + self._oid = oid + self._value = value + + @property + def oid(self) -> ObjectIdentifier: # type: ignore[override] + return self._oid + + @property + def value(self) -> bytes: + return self._value + + def __repr__(self) -> str: + return ( + f"" + ) + + def __eq__(self, other: object) -> bool: + if not isinstance(other, UnrecognizedExtension): + return NotImplemented + + return self.oid == other.oid and self.value == other.value + + def __hash__(self) -> int: + return hash((self.oid, self.value)) + + def public_bytes(self) -> bytes: + return self.value diff --git a/.venv/Lib/site-packages/cryptography/x509/general_name.py b/.venv/Lib/site-packages/cryptography/x509/general_name.py new file mode 100644 index 00000000..672f2875 --- /dev/null +++ b/.venv/Lib/site-packages/cryptography/x509/general_name.py @@ -0,0 +1,281 @@ +# This file is dual licensed under the terms of the Apache License, Version +# 2.0, and the BSD License. See the LICENSE file in the root of this repository +# for complete details. + +from __future__ import annotations + +import abc +import ipaddress +import typing +from email.utils import parseaddr + +from cryptography.x509.name import Name +from cryptography.x509.oid import ObjectIdentifier + +_IPAddressTypes = typing.Union[ + ipaddress.IPv4Address, + ipaddress.IPv6Address, + ipaddress.IPv4Network, + ipaddress.IPv6Network, +] + + +class UnsupportedGeneralNameType(Exception): + pass + + +class GeneralName(metaclass=abc.ABCMeta): + @property + @abc.abstractmethod + def value(self) -> typing.Any: + """ + Return the value of the object + """ + + +class RFC822Name(GeneralName): + def __init__(self, value: str) -> None: + if isinstance(value, str): + try: + value.encode("ascii") + except UnicodeEncodeError: + raise ValueError( + "RFC822Name values should be passed as an A-label string. " + "This means unicode characters should be encoded via " + "a library like idna." + ) + else: + raise TypeError("value must be string") + + name, address = parseaddr(value) + if name or not address: + # parseaddr has found a name (e.g. Name ) or the entire + # value is an empty string. + raise ValueError("Invalid rfc822name value") + + self._value = value + + @property + def value(self) -> str: + return self._value + + @classmethod + def _init_without_validation(cls, value: str) -> RFC822Name: + instance = cls.__new__(cls) + instance._value = value + return instance + + def __repr__(self) -> str: + return f"" + + def __eq__(self, other: object) -> bool: + if not isinstance(other, RFC822Name): + return NotImplemented + + return self.value == other.value + + def __hash__(self) -> int: + return hash(self.value) + + +class DNSName(GeneralName): + def __init__(self, value: str) -> None: + if isinstance(value, str): + try: + value.encode("ascii") + except UnicodeEncodeError: + raise ValueError( + "DNSName values should be passed as an A-label string. " + "This means unicode characters should be encoded via " + "a library like idna." + ) + else: + raise TypeError("value must be string") + + self._value = value + + @property + def value(self) -> str: + return self._value + + @classmethod + def _init_without_validation(cls, value: str) -> DNSName: + instance = cls.__new__(cls) + instance._value = value + return instance + + def __repr__(self) -> str: + return f"" + + def __eq__(self, other: object) -> bool: + if not isinstance(other, DNSName): + return NotImplemented + + return self.value == other.value + + def __hash__(self) -> int: + return hash(self.value) + + +class UniformResourceIdentifier(GeneralName): + def __init__(self, value: str) -> None: + if isinstance(value, str): + try: + value.encode("ascii") + except UnicodeEncodeError: + raise ValueError( + "URI values should be passed as an A-label string. " + "This means unicode characters should be encoded via " + "a library like idna." + ) + else: + raise TypeError("value must be string") + + self._value = value + + @property + def value(self) -> str: + return self._value + + @classmethod + def _init_without_validation(cls, value: str) -> UniformResourceIdentifier: + instance = cls.__new__(cls) + instance._value = value + return instance + + def __repr__(self) -> str: + return f"" + + def __eq__(self, other: object) -> bool: + if not isinstance(other, UniformResourceIdentifier): + return NotImplemented + + return self.value == other.value + + def __hash__(self) -> int: + return hash(self.value) + + +class DirectoryName(GeneralName): + def __init__(self, value: Name) -> None: + if not isinstance(value, Name): + raise TypeError("value must be a Name") + + self._value = value + + @property + def value(self) -> Name: + return self._value + + def __repr__(self) -> str: + return f"" + + def __eq__(self, other: object) -> bool: + if not isinstance(other, DirectoryName): + return NotImplemented + + return self.value == other.value + + def __hash__(self) -> int: + return hash(self.value) + + +class RegisteredID(GeneralName): + def __init__(self, value: ObjectIdentifier) -> None: + if not isinstance(value, ObjectIdentifier): + raise TypeError("value must be an ObjectIdentifier") + + self._value = value + + @property + def value(self) -> ObjectIdentifier: + return self._value + + def __repr__(self) -> str: + return f"" + + def __eq__(self, other: object) -> bool: + if not isinstance(other, RegisteredID): + return NotImplemented + + return self.value == other.value + + def __hash__(self) -> int: + return hash(self.value) + + +class IPAddress(GeneralName): + def __init__(self, value: _IPAddressTypes) -> None: + if not isinstance( + value, + ( + ipaddress.IPv4Address, + ipaddress.IPv6Address, + ipaddress.IPv4Network, + ipaddress.IPv6Network, + ), + ): + raise TypeError( + "value must be an instance of ipaddress.IPv4Address, " + "ipaddress.IPv6Address, ipaddress.IPv4Network, or " + "ipaddress.IPv6Network" + ) + + self._value = value + + @property + def value(self) -> _IPAddressTypes: + return self._value + + def _packed(self) -> bytes: + if isinstance( + self.value, (ipaddress.IPv4Address, ipaddress.IPv6Address) + ): + return self.value.packed + else: + return ( + self.value.network_address.packed + self.value.netmask.packed + ) + + def __repr__(self) -> str: + return f"" + + def __eq__(self, other: object) -> bool: + if not isinstance(other, IPAddress): + return NotImplemented + + return self.value == other.value + + def __hash__(self) -> int: + return hash(self.value) + + +class OtherName(GeneralName): + def __init__(self, type_id: ObjectIdentifier, value: bytes) -> None: + if not isinstance(type_id, ObjectIdentifier): + raise TypeError("type_id must be an ObjectIdentifier") + if not isinstance(value, bytes): + raise TypeError("value must be a binary string") + + self._type_id = type_id + self._value = value + + @property + def type_id(self) -> ObjectIdentifier: + return self._type_id + + @property + def value(self) -> bytes: + return self._value + + def __repr__(self) -> str: + return f"" + + def __eq__(self, other: object) -> bool: + if not isinstance(other, OtherName): + return NotImplemented + + return self.type_id == other.type_id and self.value == other.value + + def __hash__(self) -> int: + return hash((self.type_id, self.value)) diff --git a/.venv/Lib/site-packages/cryptography/x509/name.py b/.venv/Lib/site-packages/cryptography/x509/name.py new file mode 100644 index 00000000..5e8ccfff --- /dev/null +++ b/.venv/Lib/site-packages/cryptography/x509/name.py @@ -0,0 +1,456 @@ +# This file is dual licensed under the terms of the Apache License, Version +# 2.0, and the BSD License. See the LICENSE file in the root of this repository +# for complete details. + +from __future__ import annotations + +import binascii +import re +import sys +import typing +import warnings + +from cryptography import utils +from cryptography.hazmat.bindings._rust import x509 as rust_x509 +from cryptography.x509.oid import NameOID, ObjectIdentifier + + +class _ASN1Type(utils.Enum): + BitString = 3 + OctetString = 4 + UTF8String = 12 + NumericString = 18 + PrintableString = 19 + T61String = 20 + IA5String = 22 + UTCTime = 23 + GeneralizedTime = 24 + VisibleString = 26 + UniversalString = 28 + BMPString = 30 + + +_ASN1_TYPE_TO_ENUM = {i.value: i for i in _ASN1Type} +_NAMEOID_DEFAULT_TYPE: dict[ObjectIdentifier, _ASN1Type] = { + NameOID.COUNTRY_NAME: _ASN1Type.PrintableString, + NameOID.JURISDICTION_COUNTRY_NAME: _ASN1Type.PrintableString, + NameOID.SERIAL_NUMBER: _ASN1Type.PrintableString, + NameOID.DN_QUALIFIER: _ASN1Type.PrintableString, + NameOID.EMAIL_ADDRESS: _ASN1Type.IA5String, + NameOID.DOMAIN_COMPONENT: _ASN1Type.IA5String, +} + +# Type alias +_OidNameMap = typing.Mapping[ObjectIdentifier, str] +_NameOidMap = typing.Mapping[str, ObjectIdentifier] + +#: Short attribute names from RFC 4514: +#: https://tools.ietf.org/html/rfc4514#page-7 +_NAMEOID_TO_NAME: _OidNameMap = { + NameOID.COMMON_NAME: "CN", + NameOID.LOCALITY_NAME: "L", + NameOID.STATE_OR_PROVINCE_NAME: "ST", + NameOID.ORGANIZATION_NAME: "O", + NameOID.ORGANIZATIONAL_UNIT_NAME: "OU", + NameOID.COUNTRY_NAME: "C", + NameOID.STREET_ADDRESS: "STREET", + NameOID.DOMAIN_COMPONENT: "DC", + NameOID.USER_ID: "UID", +} +_NAME_TO_NAMEOID = {v: k for k, v in _NAMEOID_TO_NAME.items()} + + +def _escape_dn_value(val: str | bytes) -> str: + """Escape special characters in RFC4514 Distinguished Name value.""" + + if not val: + return "" + + # RFC 4514 Section 2.4 defines the value as being the # (U+0023) character + # followed by the hexadecimal encoding of the octets. + if isinstance(val, bytes): + return "#" + binascii.hexlify(val).decode("utf8") + + # See https://tools.ietf.org/html/rfc4514#section-2.4 + val = val.replace("\\", "\\\\") + val = val.replace('"', '\\"') + val = val.replace("+", "\\+") + val = val.replace(",", "\\,") + val = val.replace(";", "\\;") + val = val.replace("<", "\\<") + val = val.replace(">", "\\>") + val = val.replace("\0", "\\00") + + if val[0] in ("#", " "): + val = "\\" + val + if val[-1] == " ": + val = val[:-1] + "\\ " + + return val + + +def _unescape_dn_value(val: str) -> str: + if not val: + return "" + + # See https://tools.ietf.org/html/rfc4514#section-3 + + # special = escaped / SPACE / SHARP / EQUALS + # escaped = DQUOTE / PLUS / COMMA / SEMI / LANGLE / RANGLE + def sub(m): + val = m.group(1) + # Regular escape + if len(val) == 1: + return val + # Hex-value scape + return chr(int(val, 16)) + + return _RFC4514NameParser._PAIR_RE.sub(sub, val) + + +class NameAttribute: + def __init__( + self, + oid: ObjectIdentifier, + value: str | bytes, + _type: _ASN1Type | None = None, + *, + _validate: bool = True, + ) -> None: + if not isinstance(oid, ObjectIdentifier): + raise TypeError( + "oid argument must be an ObjectIdentifier instance." + ) + if _type == _ASN1Type.BitString: + if oid != NameOID.X500_UNIQUE_IDENTIFIER: + raise TypeError( + "oid must be X500_UNIQUE_IDENTIFIER for BitString type." + ) + if not isinstance(value, bytes): + raise TypeError("value must be bytes for BitString") + else: + if not isinstance(value, str): + raise TypeError("value argument must be a str") + + if oid in (NameOID.COUNTRY_NAME, NameOID.JURISDICTION_COUNTRY_NAME): + assert isinstance(value, str) + c_len = len(value.encode("utf8")) + if c_len != 2 and _validate is True: + raise ValueError( + "Country name must be a 2 character country code" + ) + elif c_len != 2: + warnings.warn( + "Country names should be two characters, but the " + f"attribute is {c_len} characters in length.", + stacklevel=2, + ) + + # The appropriate ASN1 string type varies by OID and is defined across + # multiple RFCs including 2459, 3280, and 5280. In general UTF8String + # is preferred (2459), but 3280 and 5280 specify several OIDs with + # alternate types. This means when we see the sentinel value we need + # to look up whether the OID has a non-UTF8 type. If it does, set it + # to that. Otherwise, UTF8! + if _type is None: + _type = _NAMEOID_DEFAULT_TYPE.get(oid, _ASN1Type.UTF8String) + + if not isinstance(_type, _ASN1Type): + raise TypeError("_type must be from the _ASN1Type enum") + + self._oid = oid + self._value = value + self._type = _type + + @property + def oid(self) -> ObjectIdentifier: + return self._oid + + @property + def value(self) -> str | bytes: + return self._value + + @property + def rfc4514_attribute_name(self) -> str: + """ + The short attribute name (for example "CN") if available, + otherwise the OID dotted string. + """ + return _NAMEOID_TO_NAME.get(self.oid, self.oid.dotted_string) + + def rfc4514_string( + self, attr_name_overrides: _OidNameMap | None = None + ) -> str: + """ + Format as RFC4514 Distinguished Name string. + + Use short attribute name if available, otherwise fall back to OID + dotted string. + """ + attr_name = ( + attr_name_overrides.get(self.oid) if attr_name_overrides else None + ) + if attr_name is None: + attr_name = self.rfc4514_attribute_name + + return f"{attr_name}={_escape_dn_value(self.value)}" + + def __eq__(self, other: object) -> bool: + if not isinstance(other, NameAttribute): + return NotImplemented + + return self.oid == other.oid and self.value == other.value + + def __hash__(self) -> int: + return hash((self.oid, self.value)) + + def __repr__(self) -> str: + return f"" + + +class RelativeDistinguishedName: + def __init__(self, attributes: typing.Iterable[NameAttribute]): + attributes = list(attributes) + if not attributes: + raise ValueError("a relative distinguished name cannot be empty") + if not all(isinstance(x, NameAttribute) for x in attributes): + raise TypeError("attributes must be an iterable of NameAttribute") + + # Keep list and frozenset to preserve attribute order where it matters + self._attributes = attributes + self._attribute_set = frozenset(attributes) + + if len(self._attribute_set) != len(attributes): + raise ValueError("duplicate attributes are not allowed") + + def get_attributes_for_oid( + self, oid: ObjectIdentifier + ) -> list[NameAttribute]: + return [i for i in self if i.oid == oid] + + def rfc4514_string( + self, attr_name_overrides: _OidNameMap | None = None + ) -> str: + """ + Format as RFC4514 Distinguished Name string. + + Within each RDN, attributes are joined by '+', although that is rarely + used in certificates. + """ + return "+".join( + attr.rfc4514_string(attr_name_overrides) + for attr in self._attributes + ) + + def __eq__(self, other: object) -> bool: + if not isinstance(other, RelativeDistinguishedName): + return NotImplemented + + return self._attribute_set == other._attribute_set + + def __hash__(self) -> int: + return hash(self._attribute_set) + + def __iter__(self) -> typing.Iterator[NameAttribute]: + return iter(self._attributes) + + def __len__(self) -> int: + return len(self._attributes) + + def __repr__(self) -> str: + return f"" + + +class Name: + @typing.overload + def __init__(self, attributes: typing.Iterable[NameAttribute]) -> None: + ... + + @typing.overload + def __init__( + self, attributes: typing.Iterable[RelativeDistinguishedName] + ) -> None: + ... + + def __init__( + self, + attributes: typing.Iterable[NameAttribute | RelativeDistinguishedName], + ) -> None: + attributes = list(attributes) + if all(isinstance(x, NameAttribute) for x in attributes): + self._attributes = [ + RelativeDistinguishedName([typing.cast(NameAttribute, x)]) + for x in attributes + ] + elif all(isinstance(x, RelativeDistinguishedName) for x in attributes): + self._attributes = typing.cast( + typing.List[RelativeDistinguishedName], attributes + ) + else: + raise TypeError( + "attributes must be a list of NameAttribute" + " or a list RelativeDistinguishedName" + ) + + @classmethod + def from_rfc4514_string( + cls, + data: str, + attr_name_overrides: _NameOidMap | None = None, + ) -> Name: + return _RFC4514NameParser(data, attr_name_overrides or {}).parse() + + def rfc4514_string( + self, attr_name_overrides: _OidNameMap | None = None + ) -> str: + """ + Format as RFC4514 Distinguished Name string. + For example 'CN=foobar.com,O=Foo Corp,C=US' + + An X.509 name is a two-level structure: a list of sets of attributes. + Each list element is separated by ',' and within each list element, set + elements are separated by '+'. The latter is almost never used in + real world certificates. According to RFC4514 section 2.1 the + RDNSequence must be reversed when converting to string representation. + """ + return ",".join( + attr.rfc4514_string(attr_name_overrides) + for attr in reversed(self._attributes) + ) + + def get_attributes_for_oid( + self, oid: ObjectIdentifier + ) -> list[NameAttribute]: + return [i for i in self if i.oid == oid] + + @property + def rdns(self) -> list[RelativeDistinguishedName]: + return self._attributes + + def public_bytes(self, backend: typing.Any = None) -> bytes: + return rust_x509.encode_name_bytes(self) + + def __eq__(self, other: object) -> bool: + if not isinstance(other, Name): + return NotImplemented + + return self._attributes == other._attributes + + def __hash__(self) -> int: + # TODO: this is relatively expensive, if this looks like a bottleneck + # for you, consider optimizing! + return hash(tuple(self._attributes)) + + def __iter__(self) -> typing.Iterator[NameAttribute]: + for rdn in self._attributes: + yield from rdn + + def __len__(self) -> int: + return sum(len(rdn) for rdn in self._attributes) + + def __repr__(self) -> str: + rdns = ",".join(attr.rfc4514_string() for attr in self._attributes) + return f"" + + +class _RFC4514NameParser: + _OID_RE = re.compile(r"(0|([1-9]\d*))(\.(0|([1-9]\d*)))+") + _DESCR_RE = re.compile(r"[a-zA-Z][a-zA-Z\d-]*") + + _PAIR = r"\\([\\ #=\"\+,;<>]|[\da-zA-Z]{2})" + _PAIR_RE = re.compile(_PAIR) + _LUTF1 = r"[\x01-\x1f\x21\x24-\x2A\x2D-\x3A\x3D\x3F-\x5B\x5D-\x7F]" + _SUTF1 = r"[\x01-\x21\x23-\x2A\x2D-\x3A\x3D\x3F-\x5B\x5D-\x7F]" + _TUTF1 = r"[\x01-\x1F\x21\x23-\x2A\x2D-\x3A\x3D\x3F-\x5B\x5D-\x7F]" + _UTFMB = rf"[\x80-{chr(sys.maxunicode)}]" + _LEADCHAR = rf"{_LUTF1}|{_UTFMB}" + _STRINGCHAR = rf"{_SUTF1}|{_UTFMB}" + _TRAILCHAR = rf"{_TUTF1}|{_UTFMB}" + _STRING_RE = re.compile( + rf""" + ( + ({_LEADCHAR}|{_PAIR}) + ( + ({_STRINGCHAR}|{_PAIR})* + ({_TRAILCHAR}|{_PAIR}) + )? + )? + """, + re.VERBOSE, + ) + _HEXSTRING_RE = re.compile(r"#([\da-zA-Z]{2})+") + + def __init__(self, data: str, attr_name_overrides: _NameOidMap) -> None: + self._data = data + self._idx = 0 + + self._attr_name_overrides = attr_name_overrides + + def _has_data(self) -> bool: + return self._idx < len(self._data) + + def _peek(self) -> str | None: + if self._has_data(): + return self._data[self._idx] + return None + + def _read_char(self, ch: str) -> None: + if self._peek() != ch: + raise ValueError + self._idx += 1 + + def _read_re(self, pat) -> str: + match = pat.match(self._data, pos=self._idx) + if match is None: + raise ValueError + val = match.group() + self._idx += len(val) + return val + + def parse(self) -> Name: + """ + Parses the `data` string and converts it to a Name. + + According to RFC4514 section 2.1 the RDNSequence must be + reversed when converting to string representation. So, when + we parse it, we need to reverse again to get the RDNs on the + correct order. + """ + rdns = [self._parse_rdn()] + + while self._has_data(): + self._read_char(",") + rdns.append(self._parse_rdn()) + + return Name(reversed(rdns)) + + def _parse_rdn(self) -> RelativeDistinguishedName: + nas = [self._parse_na()] + while self._peek() == "+": + self._read_char("+") + nas.append(self._parse_na()) + + return RelativeDistinguishedName(nas) + + def _parse_na(self) -> NameAttribute: + try: + oid_value = self._read_re(self._OID_RE) + except ValueError: + name = self._read_re(self._DESCR_RE) + oid = self._attr_name_overrides.get( + name, _NAME_TO_NAMEOID.get(name) + ) + if oid is None: + raise ValueError + else: + oid = ObjectIdentifier(oid_value) + + self._read_char("=") + if self._peek() == "#": + value = self._read_re(self._HEXSTRING_RE) + value = binascii.unhexlify(value[1:]).decode() + else: + raw_value = self._read_re(self._STRING_RE) + value = _unescape_dn_value(raw_value) + + return NameAttribute(oid, value) diff --git a/.venv/Lib/site-packages/cryptography/x509/ocsp.py b/.venv/Lib/site-packages/cryptography/x509/ocsp.py new file mode 100644 index 00000000..9751ceaf --- /dev/null +++ b/.venv/Lib/site-packages/cryptography/x509/ocsp.py @@ -0,0 +1,615 @@ +# This file is dual licensed under the terms of the Apache License, Version +# 2.0, and the BSD License. See the LICENSE file in the root of this repository +# for complete details. + +from __future__ import annotations + +import abc +import datetime +import typing + +from cryptography import utils, x509 +from cryptography.hazmat.bindings._rust import ocsp +from cryptography.hazmat.primitives import hashes, serialization +from cryptography.hazmat.primitives.asymmetric.types import ( + CertificateIssuerPrivateKeyTypes, +) +from cryptography.x509.base import ( + _EARLIEST_UTC_TIME, + _convert_to_naive_utc_time, + _reject_duplicate_extension, +) + + +class OCSPResponderEncoding(utils.Enum): + HASH = "By Hash" + NAME = "By Name" + + +class OCSPResponseStatus(utils.Enum): + SUCCESSFUL = 0 + MALFORMED_REQUEST = 1 + INTERNAL_ERROR = 2 + TRY_LATER = 3 + SIG_REQUIRED = 5 + UNAUTHORIZED = 6 + + +_ALLOWED_HASHES = ( + hashes.SHA1, + hashes.SHA224, + hashes.SHA256, + hashes.SHA384, + hashes.SHA512, +) + + +def _verify_algorithm(algorithm: hashes.HashAlgorithm) -> None: + if not isinstance(algorithm, _ALLOWED_HASHES): + raise ValueError( + "Algorithm must be SHA1, SHA224, SHA256, SHA384, or SHA512" + ) + + +class OCSPCertStatus(utils.Enum): + GOOD = 0 + REVOKED = 1 + UNKNOWN = 2 + + +class _SingleResponse: + def __init__( + self, + cert: x509.Certificate, + issuer: x509.Certificate, + algorithm: hashes.HashAlgorithm, + cert_status: OCSPCertStatus, + this_update: datetime.datetime, + next_update: datetime.datetime | None, + revocation_time: datetime.datetime | None, + revocation_reason: x509.ReasonFlags | None, + ): + if not isinstance(cert, x509.Certificate) or not isinstance( + issuer, x509.Certificate + ): + raise TypeError("cert and issuer must be a Certificate") + + _verify_algorithm(algorithm) + if not isinstance(this_update, datetime.datetime): + raise TypeError("this_update must be a datetime object") + if next_update is not None and not isinstance( + next_update, datetime.datetime + ): + raise TypeError("next_update must be a datetime object or None") + + self._cert = cert + self._issuer = issuer + self._algorithm = algorithm + self._this_update = this_update + self._next_update = next_update + + if not isinstance(cert_status, OCSPCertStatus): + raise TypeError( + "cert_status must be an item from the OCSPCertStatus enum" + ) + if cert_status is not OCSPCertStatus.REVOKED: + if revocation_time is not None: + raise ValueError( + "revocation_time can only be provided if the certificate " + "is revoked" + ) + if revocation_reason is not None: + raise ValueError( + "revocation_reason can only be provided if the certificate" + " is revoked" + ) + else: + if not isinstance(revocation_time, datetime.datetime): + raise TypeError("revocation_time must be a datetime object") + + revocation_time = _convert_to_naive_utc_time(revocation_time) + if revocation_time < _EARLIEST_UTC_TIME: + raise ValueError( + "The revocation_time must be on or after" + " 1950 January 1." + ) + + if revocation_reason is not None and not isinstance( + revocation_reason, x509.ReasonFlags + ): + raise TypeError( + "revocation_reason must be an item from the ReasonFlags " + "enum or None" + ) + + self._cert_status = cert_status + self._revocation_time = revocation_time + self._revocation_reason = revocation_reason + + +class OCSPRequest(metaclass=abc.ABCMeta): + @property + @abc.abstractmethod + def issuer_key_hash(self) -> bytes: + """ + The hash of the issuer public key + """ + + @property + @abc.abstractmethod + def issuer_name_hash(self) -> bytes: + """ + The hash of the issuer name + """ + + @property + @abc.abstractmethod + def hash_algorithm(self) -> hashes.HashAlgorithm: + """ + The hash algorithm used in the issuer name and key hashes + """ + + @property + @abc.abstractmethod + def serial_number(self) -> int: + """ + The serial number of the cert whose status is being checked + """ + + @abc.abstractmethod + def public_bytes(self, encoding: serialization.Encoding) -> bytes: + """ + Serializes the request to DER + """ + + @property + @abc.abstractmethod + def extensions(self) -> x509.Extensions: + """ + The list of request extensions. Not single request extensions. + """ + + +class OCSPSingleResponse(metaclass=abc.ABCMeta): + @property + @abc.abstractmethod + def certificate_status(self) -> OCSPCertStatus: + """ + The status of the certificate (an element from the OCSPCertStatus enum) + """ + + @property + @abc.abstractmethod + def revocation_time(self) -> datetime.datetime | None: + """ + The date of when the certificate was revoked or None if not + revoked. + """ + + @property + @abc.abstractmethod + def revocation_reason(self) -> x509.ReasonFlags | None: + """ + The reason the certificate was revoked or None if not specified or + not revoked. + """ + + @property + @abc.abstractmethod + def this_update(self) -> datetime.datetime: + """ + The most recent time at which the status being indicated is known by + the responder to have been correct + """ + + @property + @abc.abstractmethod + def next_update(self) -> datetime.datetime | None: + """ + The time when newer information will be available + """ + + @property + @abc.abstractmethod + def issuer_key_hash(self) -> bytes: + """ + The hash of the issuer public key + """ + + @property + @abc.abstractmethod + def issuer_name_hash(self) -> bytes: + """ + The hash of the issuer name + """ + + @property + @abc.abstractmethod + def hash_algorithm(self) -> hashes.HashAlgorithm: + """ + The hash algorithm used in the issuer name and key hashes + """ + + @property + @abc.abstractmethod + def serial_number(self) -> int: + """ + The serial number of the cert whose status is being checked + """ + + +class OCSPResponse(metaclass=abc.ABCMeta): + @property + @abc.abstractmethod + def responses(self) -> typing.Iterator[OCSPSingleResponse]: + """ + An iterator over the individual SINGLERESP structures in the + response + """ + + @property + @abc.abstractmethod + def response_status(self) -> OCSPResponseStatus: + """ + The status of the response. This is a value from the OCSPResponseStatus + enumeration + """ + + @property + @abc.abstractmethod + def signature_algorithm_oid(self) -> x509.ObjectIdentifier: + """ + The ObjectIdentifier of the signature algorithm + """ + + @property + @abc.abstractmethod + def signature_hash_algorithm( + self, + ) -> hashes.HashAlgorithm | None: + """ + Returns a HashAlgorithm corresponding to the type of the digest signed + """ + + @property + @abc.abstractmethod + def signature(self) -> bytes: + """ + The signature bytes + """ + + @property + @abc.abstractmethod + def tbs_response_bytes(self) -> bytes: + """ + The tbsResponseData bytes + """ + + @property + @abc.abstractmethod + def certificates(self) -> list[x509.Certificate]: + """ + A list of certificates used to help build a chain to verify the OCSP + response. This situation occurs when the OCSP responder uses a delegate + certificate. + """ + + @property + @abc.abstractmethod + def responder_key_hash(self) -> bytes | None: + """ + The responder's key hash or None + """ + + @property + @abc.abstractmethod + def responder_name(self) -> x509.Name | None: + """ + The responder's Name or None + """ + + @property + @abc.abstractmethod + def produced_at(self) -> datetime.datetime: + """ + The time the response was produced + """ + + @property + @abc.abstractmethod + def certificate_status(self) -> OCSPCertStatus: + """ + The status of the certificate (an element from the OCSPCertStatus enum) + """ + + @property + @abc.abstractmethod + def revocation_time(self) -> datetime.datetime | None: + """ + The date of when the certificate was revoked or None if not + revoked. + """ + + @property + @abc.abstractmethod + def revocation_reason(self) -> x509.ReasonFlags | None: + """ + The reason the certificate was revoked or None if not specified or + not revoked. + """ + + @property + @abc.abstractmethod + def this_update(self) -> datetime.datetime: + """ + The most recent time at which the status being indicated is known by + the responder to have been correct + """ + + @property + @abc.abstractmethod + def next_update(self) -> datetime.datetime | None: + """ + The time when newer information will be available + """ + + @property + @abc.abstractmethod + def issuer_key_hash(self) -> bytes: + """ + The hash of the issuer public key + """ + + @property + @abc.abstractmethod + def issuer_name_hash(self) -> bytes: + """ + The hash of the issuer name + """ + + @property + @abc.abstractmethod + def hash_algorithm(self) -> hashes.HashAlgorithm: + """ + The hash algorithm used in the issuer name and key hashes + """ + + @property + @abc.abstractmethod + def serial_number(self) -> int: + """ + The serial number of the cert whose status is being checked + """ + + @property + @abc.abstractmethod + def extensions(self) -> x509.Extensions: + """ + The list of response extensions. Not single response extensions. + """ + + @property + @abc.abstractmethod + def single_extensions(self) -> x509.Extensions: + """ + The list of single response extensions. Not response extensions. + """ + + @abc.abstractmethod + def public_bytes(self, encoding: serialization.Encoding) -> bytes: + """ + Serializes the response to DER + """ + + +class OCSPRequestBuilder: + def __init__( + self, + request: tuple[ + x509.Certificate, x509.Certificate, hashes.HashAlgorithm + ] + | None = None, + request_hash: tuple[bytes, bytes, int, hashes.HashAlgorithm] + | None = None, + extensions: list[x509.Extension[x509.ExtensionType]] = [], + ) -> None: + self._request = request + self._request_hash = request_hash + self._extensions = extensions + + def add_certificate( + self, + cert: x509.Certificate, + issuer: x509.Certificate, + algorithm: hashes.HashAlgorithm, + ) -> OCSPRequestBuilder: + if self._request is not None or self._request_hash is not None: + raise ValueError("Only one certificate can be added to a request") + + _verify_algorithm(algorithm) + if not isinstance(cert, x509.Certificate) or not isinstance( + issuer, x509.Certificate + ): + raise TypeError("cert and issuer must be a Certificate") + + return OCSPRequestBuilder( + (cert, issuer, algorithm), self._request_hash, self._extensions + ) + + def add_certificate_by_hash( + self, + issuer_name_hash: bytes, + issuer_key_hash: bytes, + serial_number: int, + algorithm: hashes.HashAlgorithm, + ) -> OCSPRequestBuilder: + if self._request is not None or self._request_hash is not None: + raise ValueError("Only one certificate can be added to a request") + + if not isinstance(serial_number, int): + raise TypeError("serial_number must be an integer") + + _verify_algorithm(algorithm) + utils._check_bytes("issuer_name_hash", issuer_name_hash) + utils._check_bytes("issuer_key_hash", issuer_key_hash) + if algorithm.digest_size != len( + issuer_name_hash + ) or algorithm.digest_size != len(issuer_key_hash): + raise ValueError( + "issuer_name_hash and issuer_key_hash must be the same length " + "as the digest size of the algorithm" + ) + + return OCSPRequestBuilder( + self._request, + (issuer_name_hash, issuer_key_hash, serial_number, algorithm), + self._extensions, + ) + + def add_extension( + self, extval: x509.ExtensionType, critical: bool + ) -> OCSPRequestBuilder: + if not isinstance(extval, x509.ExtensionType): + raise TypeError("extension must be an ExtensionType") + + extension = x509.Extension(extval.oid, critical, extval) + _reject_duplicate_extension(extension, self._extensions) + + return OCSPRequestBuilder( + self._request, self._request_hash, [*self._extensions, extension] + ) + + def build(self) -> OCSPRequest: + if self._request is None and self._request_hash is None: + raise ValueError("You must add a certificate before building") + + return ocsp.create_ocsp_request(self) + + +class OCSPResponseBuilder: + def __init__( + self, + response: _SingleResponse | None = None, + responder_id: tuple[x509.Certificate, OCSPResponderEncoding] + | None = None, + certs: list[x509.Certificate] | None = None, + extensions: list[x509.Extension[x509.ExtensionType]] = [], + ): + self._response = response + self._responder_id = responder_id + self._certs = certs + self._extensions = extensions + + def add_response( + self, + cert: x509.Certificate, + issuer: x509.Certificate, + algorithm: hashes.HashAlgorithm, + cert_status: OCSPCertStatus, + this_update: datetime.datetime, + next_update: datetime.datetime | None, + revocation_time: datetime.datetime | None, + revocation_reason: x509.ReasonFlags | None, + ) -> OCSPResponseBuilder: + if self._response is not None: + raise ValueError("Only one response per OCSPResponse.") + + singleresp = _SingleResponse( + cert, + issuer, + algorithm, + cert_status, + this_update, + next_update, + revocation_time, + revocation_reason, + ) + return OCSPResponseBuilder( + singleresp, + self._responder_id, + self._certs, + self._extensions, + ) + + def responder_id( + self, encoding: OCSPResponderEncoding, responder_cert: x509.Certificate + ) -> OCSPResponseBuilder: + if self._responder_id is not None: + raise ValueError("responder_id can only be set once") + if not isinstance(responder_cert, x509.Certificate): + raise TypeError("responder_cert must be a Certificate") + if not isinstance(encoding, OCSPResponderEncoding): + raise TypeError( + "encoding must be an element from OCSPResponderEncoding" + ) + + return OCSPResponseBuilder( + self._response, + (responder_cert, encoding), + self._certs, + self._extensions, + ) + + def certificates( + self, certs: typing.Iterable[x509.Certificate] + ) -> OCSPResponseBuilder: + if self._certs is not None: + raise ValueError("certificates may only be set once") + certs = list(certs) + if len(certs) == 0: + raise ValueError("certs must not be an empty list") + if not all(isinstance(x, x509.Certificate) for x in certs): + raise TypeError("certs must be a list of Certificates") + return OCSPResponseBuilder( + self._response, + self._responder_id, + certs, + self._extensions, + ) + + def add_extension( + self, extval: x509.ExtensionType, critical: bool + ) -> OCSPResponseBuilder: + if not isinstance(extval, x509.ExtensionType): + raise TypeError("extension must be an ExtensionType") + + extension = x509.Extension(extval.oid, critical, extval) + _reject_duplicate_extension(extension, self._extensions) + + return OCSPResponseBuilder( + self._response, + self._responder_id, + self._certs, + [*self._extensions, extension], + ) + + def sign( + self, + private_key: CertificateIssuerPrivateKeyTypes, + algorithm: hashes.HashAlgorithm | None, + ) -> OCSPResponse: + if self._response is None: + raise ValueError("You must add a response before signing") + if self._responder_id is None: + raise ValueError("You must add a responder_id before signing") + + return ocsp.create_ocsp_response( + OCSPResponseStatus.SUCCESSFUL, self, private_key, algorithm + ) + + @classmethod + def build_unsuccessful( + cls, response_status: OCSPResponseStatus + ) -> OCSPResponse: + if not isinstance(response_status, OCSPResponseStatus): + raise TypeError( + "response_status must be an item from OCSPResponseStatus" + ) + if response_status is OCSPResponseStatus.SUCCESSFUL: + raise ValueError("response_status cannot be SUCCESSFUL") + + return ocsp.create_ocsp_response(response_status, None, None, None) + + +load_der_ocsp_request = ocsp.load_der_ocsp_request +load_der_ocsp_response = ocsp.load_der_ocsp_response diff --git a/.venv/Lib/site-packages/cryptography/x509/oid.py b/.venv/Lib/site-packages/cryptography/x509/oid.py new file mode 100644 index 00000000..cda50cce --- /dev/null +++ b/.venv/Lib/site-packages/cryptography/x509/oid.py @@ -0,0 +1,33 @@ +# This file is dual licensed under the terms of the Apache License, Version +# 2.0, and the BSD License. See the LICENSE file in the root of this repository +# for complete details. + +from __future__ import annotations + +from cryptography.hazmat._oid import ( + AttributeOID, + AuthorityInformationAccessOID, + CertificatePoliciesOID, + CRLEntryExtensionOID, + ExtendedKeyUsageOID, + ExtensionOID, + NameOID, + ObjectIdentifier, + OCSPExtensionOID, + SignatureAlgorithmOID, + SubjectInformationAccessOID, +) + +__all__ = [ + "AttributeOID", + "AuthorityInformationAccessOID", + "CRLEntryExtensionOID", + "CertificatePoliciesOID", + "ExtendedKeyUsageOID", + "ExtensionOID", + "NameOID", + "OCSPExtensionOID", + "ObjectIdentifier", + "SignatureAlgorithmOID", + "SubjectInformationAccessOID", +] diff --git a/.venv/Lib/site-packages/cryptography/x509/verification.py b/.venv/Lib/site-packages/cryptography/x509/verification.py new file mode 100644 index 00000000..ab1a37ae --- /dev/null +++ b/.venv/Lib/site-packages/cryptography/x509/verification.py @@ -0,0 +1,24 @@ +# This file is dual licensed under the terms of the Apache License, Version +# 2.0, and the BSD License. See the LICENSE file in the root of this repository +# for complete details. + +from __future__ import annotations + +import typing + +from cryptography.hazmat.bindings._rust import x509 as rust_x509 +from cryptography.x509.general_name import DNSName, IPAddress + +__all__ = [ + "Store", + "Subject", + "ServerVerifier", + "PolicyBuilder", + "VerificationError", +] + +Store = rust_x509.Store +Subject = typing.Union[DNSName, IPAddress] +ServerVerifier = rust_x509.ServerVerifier +PolicyBuilder = rust_x509.PolicyBuilder +VerificationError = rust_x509.VerificationError diff --git a/.venv/Lib/site-packages/dotenv/__init__.py b/.venv/Lib/site-packages/dotenv/__init__.py new file mode 100644 index 00000000..7f4c631b --- /dev/null +++ b/.venv/Lib/site-packages/dotenv/__init__.py @@ -0,0 +1,49 @@ +from typing import Any, Optional + +from .main import (dotenv_values, find_dotenv, get_key, load_dotenv, set_key, + unset_key) + + +def load_ipython_extension(ipython: Any) -> None: + from .ipython import load_ipython_extension + load_ipython_extension(ipython) + + +def get_cli_string( + path: Optional[str] = None, + action: Optional[str] = None, + key: Optional[str] = None, + value: Optional[str] = None, + quote: Optional[str] = None, +): + """Returns a string suitable for running as a shell script. + + Useful for converting a arguments passed to a fabric task + to be passed to a `local` or `run` command. + """ + command = ['dotenv'] + if quote: + command.append(f'-q {quote}') + if path: + command.append(f'-f {path}') + if action: + command.append(action) + if key: + command.append(key) + if value: + if ' ' in value: + command.append(f'"{value}"') + else: + command.append(value) + + return ' '.join(command).strip() + + +__all__ = ['get_cli_string', + 'load_dotenv', + 'dotenv_values', + 'get_key', + 'set_key', + 'unset_key', + 'find_dotenv', + 'load_ipython_extension'] diff --git a/.venv/Lib/site-packages/dotenv/__main__.py b/.venv/Lib/site-packages/dotenv/__main__.py new file mode 100644 index 00000000..3977f55a --- /dev/null +++ b/.venv/Lib/site-packages/dotenv/__main__.py @@ -0,0 +1,6 @@ +"""Entry point for cli, enables execution with `python -m dotenv`""" + +from .cli import cli + +if __name__ == "__main__": + cli() diff --git a/.venv/Lib/site-packages/dotenv/__pycache__/__init__.cpython-311.pyc b/.venv/Lib/site-packages/dotenv/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 00000000..76edf420 Binary files /dev/null and b/.venv/Lib/site-packages/dotenv/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/dotenv/__pycache__/__main__.cpython-311.pyc b/.venv/Lib/site-packages/dotenv/__pycache__/__main__.cpython-311.pyc new file mode 100644 index 00000000..859d8966 Binary files /dev/null and b/.venv/Lib/site-packages/dotenv/__pycache__/__main__.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/dotenv/__pycache__/cli.cpython-311.pyc b/.venv/Lib/site-packages/dotenv/__pycache__/cli.cpython-311.pyc new file mode 100644 index 00000000..9005b323 Binary files /dev/null and b/.venv/Lib/site-packages/dotenv/__pycache__/cli.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/dotenv/__pycache__/ipython.cpython-311.pyc b/.venv/Lib/site-packages/dotenv/__pycache__/ipython.cpython-311.pyc new file mode 100644 index 00000000..3c2548af Binary files /dev/null and b/.venv/Lib/site-packages/dotenv/__pycache__/ipython.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/dotenv/__pycache__/main.cpython-311.pyc b/.venv/Lib/site-packages/dotenv/__pycache__/main.cpython-311.pyc new file mode 100644 index 00000000..a038f592 Binary files /dev/null and b/.venv/Lib/site-packages/dotenv/__pycache__/main.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/dotenv/__pycache__/parser.cpython-311.pyc b/.venv/Lib/site-packages/dotenv/__pycache__/parser.cpython-311.pyc new file mode 100644 index 00000000..7a372982 Binary files /dev/null and b/.venv/Lib/site-packages/dotenv/__pycache__/parser.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/dotenv/__pycache__/variables.cpython-311.pyc b/.venv/Lib/site-packages/dotenv/__pycache__/variables.cpython-311.pyc new file mode 100644 index 00000000..8a8be986 Binary files /dev/null and b/.venv/Lib/site-packages/dotenv/__pycache__/variables.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/dotenv/__pycache__/version.cpython-311.pyc b/.venv/Lib/site-packages/dotenv/__pycache__/version.cpython-311.pyc new file mode 100644 index 00000000..1303db44 Binary files /dev/null and b/.venv/Lib/site-packages/dotenv/__pycache__/version.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/dotenv/cli.py b/.venv/Lib/site-packages/dotenv/cli.py new file mode 100644 index 00000000..65ead461 --- /dev/null +++ b/.venv/Lib/site-packages/dotenv/cli.py @@ -0,0 +1,199 @@ +import json +import os +import shlex +import sys +from contextlib import contextmanager +from subprocess import Popen +from typing import Any, Dict, IO, Iterator, List + +try: + import click +except ImportError: + sys.stderr.write('It seems python-dotenv is not installed with cli option. \n' + 'Run pip install "python-dotenv[cli]" to fix this.') + sys.exit(1) + +from .main import dotenv_values, set_key, unset_key +from .version import __version__ + + +def enumerate_env(): + """ + Return a path for the ${pwd}/.env file. + + If pwd does not exist, return None. + """ + try: + cwd = os.getcwd() + except FileNotFoundError: + return None + path = os.path.join(cwd, '.env') + return path + + +@click.group() +@click.option('-f', '--file', default=enumerate_env(), + type=click.Path(file_okay=True), + help="Location of the .env file, defaults to .env file in current working directory.") +@click.option('-q', '--quote', default='always', + type=click.Choice(['always', 'never', 'auto']), + help="Whether to quote or not the variable values. Default mode is always. This does not affect parsing.") +@click.option('-e', '--export', default=False, + type=click.BOOL, + help="Whether to write the dot file as an executable bash script.") +@click.version_option(version=__version__) +@click.pass_context +def cli(ctx: click.Context, file: Any, quote: Any, export: Any) -> None: + """This script is used to set, get or unset values from a .env file.""" + ctx.obj = {'QUOTE': quote, 'EXPORT': export, 'FILE': file} + + +@contextmanager +def stream_file(path: os.PathLike) -> Iterator[IO[str]]: + """ + Open a file and yield the corresponding (decoded) stream. + + Exits with error code 2 if the file cannot be opened. + """ + + try: + with open(path) as stream: + yield stream + except OSError as exc: + print(f"Error opening env file: {exc}", file=sys.stderr) + exit(2) + + +@cli.command() +@click.pass_context +@click.option('--format', default='simple', + type=click.Choice(['simple', 'json', 'shell', 'export']), + help="The format in which to display the list. Default format is simple, " + "which displays name=value without quotes.") +def list(ctx: click.Context, format: bool) -> None: + """Display all the stored key/value.""" + file = ctx.obj['FILE'] + + with stream_file(file) as stream: + values = dotenv_values(stream=stream) + + if format == 'json': + click.echo(json.dumps(values, indent=2, sort_keys=True)) + else: + prefix = 'export ' if format == 'export' else '' + for k in sorted(values): + v = values[k] + if v is not None: + if format in ('export', 'shell'): + v = shlex.quote(v) + click.echo(f'{prefix}{k}={v}') + + +@cli.command() +@click.pass_context +@click.argument('key', required=True) +@click.argument('value', required=True) +def set(ctx: click.Context, key: Any, value: Any) -> None: + """Store the given key/value.""" + file = ctx.obj['FILE'] + quote = ctx.obj['QUOTE'] + export = ctx.obj['EXPORT'] + success, key, value = set_key(file, key, value, quote, export) + if success: + click.echo(f'{key}={value}') + else: + exit(1) + + +@cli.command() +@click.pass_context +@click.argument('key', required=True) +def get(ctx: click.Context, key: Any) -> None: + """Retrieve the value for the given key.""" + file = ctx.obj['FILE'] + + with stream_file(file) as stream: + values = dotenv_values(stream=stream) + + stored_value = values.get(key) + if stored_value: + click.echo(stored_value) + else: + exit(1) + + +@cli.command() +@click.pass_context +@click.argument('key', required=True) +def unset(ctx: click.Context, key: Any) -> None: + """Removes the given key.""" + file = ctx.obj['FILE'] + quote = ctx.obj['QUOTE'] + success, key = unset_key(file, key, quote) + if success: + click.echo(f"Successfully removed {key}") + else: + exit(1) + + +@cli.command(context_settings={'ignore_unknown_options': True}) +@click.pass_context +@click.option( + "--override/--no-override", + default=True, + help="Override variables from the environment file with those from the .env file.", +) +@click.argument('commandline', nargs=-1, type=click.UNPROCESSED) +def run(ctx: click.Context, override: bool, commandline: List[str]) -> None: + """Run command with environment variables present.""" + file = ctx.obj['FILE'] + if not os.path.isfile(file): + raise click.BadParameter( + f'Invalid value for \'-f\' "{file}" does not exist.', + ctx=ctx + ) + dotenv_as_dict = { + k: v + for (k, v) in dotenv_values(file).items() + if v is not None and (override or k not in os.environ) + } + + if not commandline: + click.echo('No command given.') + exit(1) + ret = run_command(commandline, dotenv_as_dict) + exit(ret) + + +def run_command(command: List[str], env: Dict[str, str]) -> int: + """Run command in sub process. + + Runs the command in a sub process with the variables from `env` + added in the current environment variables. + + Parameters + ---------- + command: List[str] + The command and it's parameters + env: Dict + The additional environment variables + + Returns + ------- + int + The return code of the command + + """ + # copy the current environment variables and add the vales from + # `env` + cmd_env = os.environ.copy() + cmd_env.update(env) + + p = Popen(command, + universal_newlines=True, + bufsize=0, + shell=False, + env=cmd_env) + _, _ = p.communicate() + + return p.returncode diff --git a/.venv/Lib/site-packages/dotenv/ipython.py b/.venv/Lib/site-packages/dotenv/ipython.py new file mode 100644 index 00000000..7df727cd --- /dev/null +++ b/.venv/Lib/site-packages/dotenv/ipython.py @@ -0,0 +1,39 @@ +from IPython.core.magic import Magics, line_magic, magics_class # type: ignore +from IPython.core.magic_arguments import (argument, magic_arguments, # type: ignore + parse_argstring) # type: ignore + +from .main import find_dotenv, load_dotenv + + +@magics_class +class IPythonDotEnv(Magics): + + @magic_arguments() + @argument( + '-o', '--override', action='store_true', + help="Indicate to override existing variables" + ) + @argument( + '-v', '--verbose', action='store_true', + help="Indicate function calls to be verbose" + ) + @argument('dotenv_path', nargs='?', type=str, default='.env', + help='Search in increasingly higher folders for the `dotenv_path`') + @line_magic + def dotenv(self, line): + args = parse_argstring(self.dotenv, line) + # Locate the .env file + dotenv_path = args.dotenv_path + try: + dotenv_path = find_dotenv(dotenv_path, True, True) + except IOError: + print("cannot find .env file") + return + + # Load the .env file + load_dotenv(dotenv_path, verbose=args.verbose, override=args.override) + + +def load_ipython_extension(ipython): + """Register the %dotenv magic.""" + ipython.register_magics(IPythonDotEnv) diff --git a/.venv/Lib/site-packages/dotenv/main.py b/.venv/Lib/site-packages/dotenv/main.py new file mode 100644 index 00000000..7bc54285 --- /dev/null +++ b/.venv/Lib/site-packages/dotenv/main.py @@ -0,0 +1,392 @@ +import io +import logging +import os +import pathlib +import shutil +import sys +import tempfile +from collections import OrderedDict +from contextlib import contextmanager +from typing import (IO, Dict, Iterable, Iterator, Mapping, Optional, Tuple, + Union) + +from .parser import Binding, parse_stream +from .variables import parse_variables + +# A type alias for a string path to be used for the paths in this file. +# These paths may flow to `open()` and `shutil.move()`; `shutil.move()` +# only accepts string paths, not byte paths or file descriptors. See +# https://github.com/python/typeshed/pull/6832. +StrPath = Union[str, 'os.PathLike[str]'] + +logger = logging.getLogger(__name__) + + +def with_warn_for_invalid_lines(mappings: Iterator[Binding]) -> Iterator[Binding]: + for mapping in mappings: + if mapping.error: + logger.warning( + "Python-dotenv could not parse statement starting at line %s", + mapping.original.line, + ) + yield mapping + + +class DotEnv: + def __init__( + self, + dotenv_path: Optional[StrPath], + stream: Optional[IO[str]] = None, + verbose: bool = False, + encoding: Optional[str] = None, + interpolate: bool = True, + override: bool = True, + ) -> None: + self.dotenv_path: Optional[StrPath] = dotenv_path + self.stream: Optional[IO[str]] = stream + self._dict: Optional[Dict[str, Optional[str]]] = None + self.verbose: bool = verbose + self.encoding: Optional[str] = encoding + self.interpolate: bool = interpolate + self.override: bool = override + + @contextmanager + def _get_stream(self) -> Iterator[IO[str]]: + if self.dotenv_path and os.path.isfile(self.dotenv_path): + with open(self.dotenv_path, encoding=self.encoding) as stream: + yield stream + elif self.stream is not None: + yield self.stream + else: + if self.verbose: + logger.info( + "Python-dotenv could not find configuration file %s.", + self.dotenv_path or '.env', + ) + yield io.StringIO('') + + def dict(self) -> Dict[str, Optional[str]]: + """Return dotenv as dict""" + if self._dict: + return self._dict + + raw_values = self.parse() + + if self.interpolate: + self._dict = OrderedDict(resolve_variables(raw_values, override=self.override)) + else: + self._dict = OrderedDict(raw_values) + + return self._dict + + def parse(self) -> Iterator[Tuple[str, Optional[str]]]: + with self._get_stream() as stream: + for mapping in with_warn_for_invalid_lines(parse_stream(stream)): + if mapping.key is not None: + yield mapping.key, mapping.value + + def set_as_environment_variables(self) -> bool: + """ + Load the current dotenv as system environment variable. + """ + if not self.dict(): + return False + + for k, v in self.dict().items(): + if k in os.environ and not self.override: + continue + if v is not None: + os.environ[k] = v + + return True + + def get(self, key: str) -> Optional[str]: + """ + """ + data = self.dict() + + if key in data: + return data[key] + + if self.verbose: + logger.warning("Key %s not found in %s.", key, self.dotenv_path) + + return None + + +def get_key( + dotenv_path: StrPath, + key_to_get: str, + encoding: Optional[str] = "utf-8", +) -> Optional[str]: + """ + Get the value of a given key from the given .env. + + Returns `None` if the key isn't found or doesn't have a value. + """ + return DotEnv(dotenv_path, verbose=True, encoding=encoding).get(key_to_get) + + +@contextmanager +def rewrite( + path: StrPath, + encoding: Optional[str], +) -> Iterator[Tuple[IO[str], IO[str]]]: + pathlib.Path(path).touch() + + with tempfile.NamedTemporaryFile(mode="w", encoding=encoding, delete=False) as dest: + error = None + try: + with open(path, encoding=encoding) as source: + yield (source, dest) + except BaseException as err: + error = err + + if error is None: + shutil.move(dest.name, path) + else: + os.unlink(dest.name) + raise error from None + + +def set_key( + dotenv_path: StrPath, + key_to_set: str, + value_to_set: str, + quote_mode: str = "always", + export: bool = False, + encoding: Optional[str] = "utf-8", +) -> Tuple[Optional[bool], str, str]: + """ + Adds or Updates a key/value to the given .env + + If the .env path given doesn't exist, fails instead of risking creating + an orphan .env somewhere in the filesystem + """ + if quote_mode not in ("always", "auto", "never"): + raise ValueError(f"Unknown quote_mode: {quote_mode}") + + quote = ( + quote_mode == "always" + or (quote_mode == "auto" and not value_to_set.isalnum()) + ) + + if quote: + value_out = "'{}'".format(value_to_set.replace("'", "\\'")) + else: + value_out = value_to_set + if export: + line_out = f'export {key_to_set}={value_out}\n' + else: + line_out = f"{key_to_set}={value_out}\n" + + with rewrite(dotenv_path, encoding=encoding) as (source, dest): + replaced = False + missing_newline = False + for mapping in with_warn_for_invalid_lines(parse_stream(source)): + if mapping.key == key_to_set: + dest.write(line_out) + replaced = True + else: + dest.write(mapping.original.string) + missing_newline = not mapping.original.string.endswith("\n") + if not replaced: + if missing_newline: + dest.write("\n") + dest.write(line_out) + + return True, key_to_set, value_to_set + + +def unset_key( + dotenv_path: StrPath, + key_to_unset: str, + quote_mode: str = "always", + encoding: Optional[str] = "utf-8", +) -> Tuple[Optional[bool], str]: + """ + Removes a given key from the given `.env` file. + + If the .env path given doesn't exist, fails. + If the given key doesn't exist in the .env, fails. + """ + if not os.path.exists(dotenv_path): + logger.warning("Can't delete from %s - it doesn't exist.", dotenv_path) + return None, key_to_unset + + removed = False + with rewrite(dotenv_path, encoding=encoding) as (source, dest): + for mapping in with_warn_for_invalid_lines(parse_stream(source)): + if mapping.key == key_to_unset: + removed = True + else: + dest.write(mapping.original.string) + + if not removed: + logger.warning("Key %s not removed from %s - key doesn't exist.", key_to_unset, dotenv_path) + return None, key_to_unset + + return removed, key_to_unset + + +def resolve_variables( + values: Iterable[Tuple[str, Optional[str]]], + override: bool, +) -> Mapping[str, Optional[str]]: + new_values: Dict[str, Optional[str]] = {} + + for (name, value) in values: + if value is None: + result = None + else: + atoms = parse_variables(value) + env: Dict[str, Optional[str]] = {} + if override: + env.update(os.environ) # type: ignore + env.update(new_values) + else: + env.update(new_values) + env.update(os.environ) # type: ignore + result = "".join(atom.resolve(env) for atom in atoms) + + new_values[name] = result + + return new_values + + +def _walk_to_root(path: str) -> Iterator[str]: + """ + Yield directories starting from the given directory up to the root + """ + if not os.path.exists(path): + raise IOError('Starting path not found') + + if os.path.isfile(path): + path = os.path.dirname(path) + + last_dir = None + current_dir = os.path.abspath(path) + while last_dir != current_dir: + yield current_dir + parent_dir = os.path.abspath(os.path.join(current_dir, os.path.pardir)) + last_dir, current_dir = current_dir, parent_dir + + +def find_dotenv( + filename: str = '.env', + raise_error_if_not_found: bool = False, + usecwd: bool = False, +) -> str: + """ + Search in increasingly higher folders for the given file + + Returns path to the file if found, or an empty string otherwise + """ + + def _is_interactive(): + """ Decide whether this is running in a REPL or IPython notebook """ + try: + main = __import__('__main__', None, None, fromlist=['__file__']) + except ModuleNotFoundError: + return False + return not hasattr(main, '__file__') + + if usecwd or _is_interactive() or getattr(sys, 'frozen', False): + # Should work without __file__, e.g. in REPL or IPython notebook. + path = os.getcwd() + else: + # will work for .py files + frame = sys._getframe() + current_file = __file__ + + while frame.f_code.co_filename == current_file or not os.path.exists( + frame.f_code.co_filename + ): + assert frame.f_back is not None + frame = frame.f_back + frame_filename = frame.f_code.co_filename + path = os.path.dirname(os.path.abspath(frame_filename)) + + for dirname in _walk_to_root(path): + check_path = os.path.join(dirname, filename) + if os.path.isfile(check_path): + return check_path + + if raise_error_if_not_found: + raise IOError('File not found') + + return '' + + +def load_dotenv( + dotenv_path: Optional[StrPath] = None, + stream: Optional[IO[str]] = None, + verbose: bool = False, + override: bool = False, + interpolate: bool = True, + encoding: Optional[str] = "utf-8", +) -> bool: + """Parse a .env file and then load all the variables found as environment variables. + + Parameters: + dotenv_path: Absolute or relative path to .env file. + stream: Text stream (such as `io.StringIO`) with .env content, used if + `dotenv_path` is `None`. + verbose: Whether to output a warning the .env file is missing. + override: Whether to override the system environment variables with the variables + from the `.env` file. + encoding: Encoding to be used to read the file. + Returns: + Bool: True if at least one environment variable is set else False + + If both `dotenv_path` and `stream` are `None`, `find_dotenv()` is used to find the + .env file. + """ + if dotenv_path is None and stream is None: + dotenv_path = find_dotenv() + + dotenv = DotEnv( + dotenv_path=dotenv_path, + stream=stream, + verbose=verbose, + interpolate=interpolate, + override=override, + encoding=encoding, + ) + return dotenv.set_as_environment_variables() + + +def dotenv_values( + dotenv_path: Optional[StrPath] = None, + stream: Optional[IO[str]] = None, + verbose: bool = False, + interpolate: bool = True, + encoding: Optional[str] = "utf-8", +) -> Dict[str, Optional[str]]: + """ + Parse a .env file and return its content as a dict. + + The returned dict will have `None` values for keys without values in the .env file. + For example, `foo=bar` results in `{"foo": "bar"}` whereas `foo` alone results in + `{"foo": None}` + + Parameters: + dotenv_path: Absolute or relative path to the .env file. + stream: `StringIO` object with .env content, used if `dotenv_path` is `None`. + verbose: Whether to output a warning if the .env file is missing. + encoding: Encoding to be used to read the file. + + If both `dotenv_path` and `stream` are `None`, `find_dotenv()` is used to find the + .env file. + """ + if dotenv_path is None and stream is None: + dotenv_path = find_dotenv() + + return DotEnv( + dotenv_path=dotenv_path, + stream=stream, + verbose=verbose, + interpolate=interpolate, + override=True, + encoding=encoding, + ).dict() diff --git a/.venv/Lib/site-packages/dotenv/parser.py b/.venv/Lib/site-packages/dotenv/parser.py new file mode 100644 index 00000000..735f14a3 --- /dev/null +++ b/.venv/Lib/site-packages/dotenv/parser.py @@ -0,0 +1,175 @@ +import codecs +import re +from typing import (IO, Iterator, Match, NamedTuple, Optional, # noqa:F401 + Pattern, Sequence, Tuple) + + +def make_regex(string: str, extra_flags: int = 0) -> Pattern[str]: + return re.compile(string, re.UNICODE | extra_flags) + + +_newline = make_regex(r"(\r\n|\n|\r)") +_multiline_whitespace = make_regex(r"\s*", extra_flags=re.MULTILINE) +_whitespace = make_regex(r"[^\S\r\n]*") +_export = make_regex(r"(?:export[^\S\r\n]+)?") +_single_quoted_key = make_regex(r"'([^']+)'") +_unquoted_key = make_regex(r"([^=\#\s]+)") +_equal_sign = make_regex(r"(=[^\S\r\n]*)") +_single_quoted_value = make_regex(r"'((?:\\'|[^'])*)'") +_double_quoted_value = make_regex(r'"((?:\\"|[^"])*)"') +_unquoted_value = make_regex(r"([^\r\n]*)") +_comment = make_regex(r"(?:[^\S\r\n]*#[^\r\n]*)?") +_end_of_line = make_regex(r"[^\S\r\n]*(?:\r\n|\n|\r|$)") +_rest_of_line = make_regex(r"[^\r\n]*(?:\r|\n|\r\n)?") +_double_quote_escapes = make_regex(r"\\[\\'\"abfnrtv]") +_single_quote_escapes = make_regex(r"\\[\\']") + + +class Original(NamedTuple): + string: str + line: int + + +class Binding(NamedTuple): + key: Optional[str] + value: Optional[str] + original: Original + error: bool + + +class Position: + def __init__(self, chars: int, line: int) -> None: + self.chars = chars + self.line = line + + @classmethod + def start(cls) -> "Position": + return cls(chars=0, line=1) + + def set(self, other: "Position") -> None: + self.chars = other.chars + self.line = other.line + + def advance(self, string: str) -> None: + self.chars += len(string) + self.line += len(re.findall(_newline, string)) + + +class Error(Exception): + pass + + +class Reader: + def __init__(self, stream: IO[str]) -> None: + self.string = stream.read() + self.position = Position.start() + self.mark = Position.start() + + def has_next(self) -> bool: + return self.position.chars < len(self.string) + + def set_mark(self) -> None: + self.mark.set(self.position) + + def get_marked(self) -> Original: + return Original( + string=self.string[self.mark.chars:self.position.chars], + line=self.mark.line, + ) + + def peek(self, count: int) -> str: + return self.string[self.position.chars:self.position.chars + count] + + def read(self, count: int) -> str: + result = self.string[self.position.chars:self.position.chars + count] + if len(result) < count: + raise Error("read: End of string") + self.position.advance(result) + return result + + def read_regex(self, regex: Pattern[str]) -> Sequence[str]: + match = regex.match(self.string, self.position.chars) + if match is None: + raise Error("read_regex: Pattern not found") + self.position.advance(self.string[match.start():match.end()]) + return match.groups() + + +def decode_escapes(regex: Pattern[str], string: str) -> str: + def decode_match(match: Match[str]) -> str: + return codecs.decode(match.group(0), 'unicode-escape') # type: ignore + + return regex.sub(decode_match, string) + + +def parse_key(reader: Reader) -> Optional[str]: + char = reader.peek(1) + if char == "#": + return None + elif char == "'": + (key,) = reader.read_regex(_single_quoted_key) + else: + (key,) = reader.read_regex(_unquoted_key) + return key + + +def parse_unquoted_value(reader: Reader) -> str: + (part,) = reader.read_regex(_unquoted_value) + return re.sub(r"\s+#.*", "", part).rstrip() + + +def parse_value(reader: Reader) -> str: + char = reader.peek(1) + if char == u"'": + (value,) = reader.read_regex(_single_quoted_value) + return decode_escapes(_single_quote_escapes, value) + elif char == u'"': + (value,) = reader.read_regex(_double_quoted_value) + return decode_escapes(_double_quote_escapes, value) + elif char in (u"", u"\n", u"\r"): + return u"" + else: + return parse_unquoted_value(reader) + + +def parse_binding(reader: Reader) -> Binding: + reader.set_mark() + try: + reader.read_regex(_multiline_whitespace) + if not reader.has_next(): + return Binding( + key=None, + value=None, + original=reader.get_marked(), + error=False, + ) + reader.read_regex(_export) + key = parse_key(reader) + reader.read_regex(_whitespace) + if reader.peek(1) == "=": + reader.read_regex(_equal_sign) + value: Optional[str] = parse_value(reader) + else: + value = None + reader.read_regex(_comment) + reader.read_regex(_end_of_line) + return Binding( + key=key, + value=value, + original=reader.get_marked(), + error=False, + ) + except Error: + reader.read_regex(_rest_of_line) + return Binding( + key=None, + value=None, + original=reader.get_marked(), + error=True, + ) + + +def parse_stream(stream: IO[str]) -> Iterator[Binding]: + reader = Reader(stream) + while reader.has_next(): + yield parse_binding(reader) diff --git a/.venv/Lib/site-packages/dotenv/py.typed b/.venv/Lib/site-packages/dotenv/py.typed new file mode 100644 index 00000000..7632ecf7 --- /dev/null +++ b/.venv/Lib/site-packages/dotenv/py.typed @@ -0,0 +1 @@ +# Marker file for PEP 561 diff --git a/.venv/Lib/site-packages/dotenv/variables.py b/.venv/Lib/site-packages/dotenv/variables.py new file mode 100644 index 00000000..667f2f26 --- /dev/null +++ b/.venv/Lib/site-packages/dotenv/variables.py @@ -0,0 +1,86 @@ +import re +from abc import ABCMeta, abstractmethod +from typing import Iterator, Mapping, Optional, Pattern + +_posix_variable: Pattern[str] = re.compile( + r""" + \$\{ + (?P[^\}:]*) + (?::- + (?P[^\}]*) + )? + \} + """, + re.VERBOSE, +) + + +class Atom(metaclass=ABCMeta): + def __ne__(self, other: object) -> bool: + result = self.__eq__(other) + if result is NotImplemented: + return NotImplemented + return not result + + @abstractmethod + def resolve(self, env: Mapping[str, Optional[str]]) -> str: ... + + +class Literal(Atom): + def __init__(self, value: str) -> None: + self.value = value + + def __repr__(self) -> str: + return f"Literal(value={self.value})" + + def __eq__(self, other: object) -> bool: + if not isinstance(other, self.__class__): + return NotImplemented + return self.value == other.value + + def __hash__(self) -> int: + return hash((self.__class__, self.value)) + + def resolve(self, env: Mapping[str, Optional[str]]) -> str: + return self.value + + +class Variable(Atom): + def __init__(self, name: str, default: Optional[str]) -> None: + self.name = name + self.default = default + + def __repr__(self) -> str: + return f"Variable(name={self.name}, default={self.default})" + + def __eq__(self, other: object) -> bool: + if not isinstance(other, self.__class__): + return NotImplemented + return (self.name, self.default) == (other.name, other.default) + + def __hash__(self) -> int: + return hash((self.__class__, self.name, self.default)) + + def resolve(self, env: Mapping[str, Optional[str]]) -> str: + default = self.default if self.default is not None else "" + result = env.get(self.name, default) + return result if result is not None else "" + + +def parse_variables(value: str) -> Iterator[Atom]: + cursor = 0 + + for match in _posix_variable.finditer(value): + (start, end) = match.span() + name = match["name"] + default = match["default"] + + if start > cursor: + yield Literal(value=value[cursor:start]) + + yield Variable(name=name, default=default) + cursor = end + + length = len(value) + if cursor < length: + yield Literal(value=value[cursor:length]) diff --git a/.venv/Lib/site-packages/dotenv/version.py b/.venv/Lib/site-packages/dotenv/version.py new file mode 100644 index 00000000..5c4105cd --- /dev/null +++ b/.venv/Lib/site-packages/dotenv/version.py @@ -0,0 +1 @@ +__version__ = "1.0.1" diff --git a/.venv/Lib/site-packages/pycparser-2.21.dist-info/INSTALLER b/.venv/Lib/site-packages/pycparser-2.21.dist-info/INSTALLER new file mode 100644 index 00000000..a1b589e3 --- /dev/null +++ b/.venv/Lib/site-packages/pycparser-2.21.dist-info/INSTALLER @@ -0,0 +1 @@ +pip diff --git a/.venv/Lib/site-packages/pycparser-2.21.dist-info/LICENSE b/.venv/Lib/site-packages/pycparser-2.21.dist-info/LICENSE new file mode 100644 index 00000000..ea215f2d --- /dev/null +++ b/.venv/Lib/site-packages/pycparser-2.21.dist-info/LICENSE @@ -0,0 +1,27 @@ +pycparser -- A C parser in Python + +Copyright (c) 2008-2020, Eli Bendersky +All rights reserved. + +Redistribution and use in source and binary forms, with or without modification, +are permitted provided that the following conditions are met: + +* Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. +* Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. +* Neither the name of Eli Bendersky nor the names of its contributors may + be used to endorse or promote products derived from this software without + specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE +GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) +HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT +LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT +OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/.venv/Lib/site-packages/pycparser-2.21.dist-info/METADATA b/.venv/Lib/site-packages/pycparser-2.21.dist-info/METADATA new file mode 100644 index 00000000..1d0fbd65 --- /dev/null +++ b/.venv/Lib/site-packages/pycparser-2.21.dist-info/METADATA @@ -0,0 +1,31 @@ +Metadata-Version: 2.1 +Name: pycparser +Version: 2.21 +Summary: C parser in Python +Home-page: https://github.com/eliben/pycparser +Author: Eli Bendersky +Author-email: eliben@gmail.com +Maintainer: Eli Bendersky +License: BSD +Platform: Cross Platform +Classifier: Development Status :: 5 - Production/Stable +Classifier: License :: OSI Approved :: BSD License +Classifier: Programming Language :: Python :: 2 +Classifier: Programming Language :: Python :: 2.7 +Classifier: Programming Language :: Python :: 3 +Classifier: Programming Language :: Python :: 3.4 +Classifier: Programming Language :: Python :: 3.5 +Classifier: Programming Language :: Python :: 3.6 +Classifier: Programming Language :: Python :: 3.7 +Classifier: Programming Language :: Python :: 3.8 +Classifier: Programming Language :: Python :: 3.9 +Classifier: Programming Language :: Python :: 3.10 +Requires-Python: >=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.* + + +pycparser is a complete parser of the C language, written in +pure Python using the PLY parsing library. +It parses C code into an AST and can serve as a front-end for +C compilers or analysis tools. + + diff --git a/.venv/Lib/site-packages/pycparser-2.21.dist-info/RECORD b/.venv/Lib/site-packages/pycparser-2.21.dist-info/RECORD new file mode 100644 index 00000000..20497073 --- /dev/null +++ b/.venv/Lib/site-packages/pycparser-2.21.dist-info/RECORD @@ -0,0 +1,41 @@ +pycparser-2.21.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4 +pycparser-2.21.dist-info/LICENSE,sha256=Pn3yW437ZYyakVAZMNTZQ7BQh6g0fH4rQyVhavU1BHs,1536 +pycparser-2.21.dist-info/METADATA,sha256=GvTEQA9yKj0nvP4mknfoGpMvjaJXCQjQANcQHrRrAxc,1108 +pycparser-2.21.dist-info/RECORD,, +pycparser-2.21.dist-info/WHEEL,sha256=kGT74LWyRUZrL4VgLh6_g12IeVl_9u9ZVhadrgXZUEY,110 +pycparser-2.21.dist-info/top_level.txt,sha256=c-lPcS74L_8KoH7IE6PQF5ofyirRQNV4VhkbSFIPeWM,10 +pycparser/__init__.py,sha256=WUEp5D0fuHBH9Q8c1fYvR2eKWfj-CNghLf2MMlQLI1I,2815 +pycparser/__pycache__/__init__.cpython-311.pyc,, +pycparser/__pycache__/_ast_gen.cpython-311.pyc,, +pycparser/__pycache__/_build_tables.cpython-311.pyc,, +pycparser/__pycache__/ast_transforms.cpython-311.pyc,, +pycparser/__pycache__/c_ast.cpython-311.pyc,, +pycparser/__pycache__/c_generator.cpython-311.pyc,, +pycparser/__pycache__/c_lexer.cpython-311.pyc,, +pycparser/__pycache__/c_parser.cpython-311.pyc,, +pycparser/__pycache__/lextab.cpython-311.pyc,, +pycparser/__pycache__/plyparser.cpython-311.pyc,, +pycparser/__pycache__/yacctab.cpython-311.pyc,, +pycparser/_ast_gen.py,sha256=0JRVnDW-Jw-3IjVlo8je9rbAcp6Ko7toHAnB5zi7h0Q,10555 +pycparser/_build_tables.py,sha256=oZCd3Plhq-vkV-QuEsaahcf-jUI6-HgKsrAL9gvFzuU,1039 +pycparser/_c_ast.cfg,sha256=ld5ezE9yzIJFIVAUfw7ezJSlMi4nXKNCzfmqjOyQTNo,4255 +pycparser/ast_transforms.py,sha256=GTMYlUgWmXd5wJVyovXY1qzzAqjxzCpVVg0664dKGBs,5691 +pycparser/c_ast.py,sha256=HWeOrfYdCY0u5XaYhE1i60uVyE3yMWdcxzECUX-DqJw,31445 +pycparser/c_generator.py,sha256=yi6Mcqxv88J5ue8k5-mVGxh3iJ37iD4QyF-sWcGjC-8,17772 +pycparser/c_lexer.py,sha256=xCpjIb6vOUebBJpdifidb08y7XgAsO3T1gNGXJT93-w,17167 +pycparser/c_parser.py,sha256=_8y3i52bL6SUK21KmEEl0qzHxe-0eZRzjZGkWg8gQ4A,73680 +pycparser/lextab.py,sha256=fIxBAHYRC418oKF52M7xb8_KMj3K-tHx0TzZiKwxjPM,8504 +pycparser/ply/__init__.py,sha256=q4s86QwRsYRa20L9ueSxfh-hPihpftBjDOvYa2_SS2Y,102 +pycparser/ply/__pycache__/__init__.cpython-311.pyc,, +pycparser/ply/__pycache__/cpp.cpython-311.pyc,, +pycparser/ply/__pycache__/ctokens.cpython-311.pyc,, +pycparser/ply/__pycache__/lex.cpython-311.pyc,, +pycparser/ply/__pycache__/yacc.cpython-311.pyc,, +pycparser/ply/__pycache__/ygen.cpython-311.pyc,, +pycparser/ply/cpp.py,sha256=UtC3ylTWp5_1MKA-PLCuwKQR8zSOnlGuGGIdzj8xS98,33282 +pycparser/ply/ctokens.py,sha256=MKksnN40TehPhgVfxCJhjj_BjL943apreABKYz-bl0Y,3177 +pycparser/ply/lex.py,sha256=7Qol57x702HZwjA3ZLp-84CUEWq1EehW-N67Wzghi-M,42918 +pycparser/ply/yacc.py,sha256=eatSDkRLgRr6X3-hoDk_SQQv065R0BdL2K7fQ54CgVM,137323 +pycparser/ply/ygen.py,sha256=2JYNeYtrPz1JzLSLO3d4GsS8zJU8jY_I_CR1VI9gWrA,2251 +pycparser/plyparser.py,sha256=8tLOoEytcapvWrr1JfCf7Dog-wulBtS1YrDs8S7JfMo,4875 +pycparser/yacctab.py,sha256=j_fVNIyDWDRVk7eWMqQtlBw2AwUSV5JTrtT58l7zis0,205652 diff --git a/.venv/Lib/site-packages/pycparser-2.21.dist-info/WHEEL b/.venv/Lib/site-packages/pycparser-2.21.dist-info/WHEEL new file mode 100644 index 00000000..ef99c6cf --- /dev/null +++ b/.venv/Lib/site-packages/pycparser-2.21.dist-info/WHEEL @@ -0,0 +1,6 @@ +Wheel-Version: 1.0 +Generator: bdist_wheel (0.34.2) +Root-Is-Purelib: true +Tag: py2-none-any +Tag: py3-none-any + diff --git a/.venv/Lib/site-packages/pycparser-2.21.dist-info/top_level.txt b/.venv/Lib/site-packages/pycparser-2.21.dist-info/top_level.txt new file mode 100644 index 00000000..dc1c9e10 --- /dev/null +++ b/.venv/Lib/site-packages/pycparser-2.21.dist-info/top_level.txt @@ -0,0 +1 @@ +pycparser diff --git a/.venv/Lib/site-packages/pycparser/__init__.py b/.venv/Lib/site-packages/pycparser/__init__.py new file mode 100644 index 00000000..d82eb2d6 --- /dev/null +++ b/.venv/Lib/site-packages/pycparser/__init__.py @@ -0,0 +1,90 @@ +#----------------------------------------------------------------- +# pycparser: __init__.py +# +# This package file exports some convenience functions for +# interacting with pycparser +# +# Eli Bendersky [https://eli.thegreenplace.net/] +# License: BSD +#----------------------------------------------------------------- +__all__ = ['c_lexer', 'c_parser', 'c_ast'] +__version__ = '2.21' + +import io +from subprocess import check_output +from .c_parser import CParser + + +def preprocess_file(filename, cpp_path='cpp', cpp_args=''): + """ Preprocess a file using cpp. + + filename: + Name of the file you want to preprocess. + + cpp_path: + cpp_args: + Refer to the documentation of parse_file for the meaning of these + arguments. + + When successful, returns the preprocessed file's contents. + Errors from cpp will be printed out. + """ + path_list = [cpp_path] + if isinstance(cpp_args, list): + path_list += cpp_args + elif cpp_args != '': + path_list += [cpp_args] + path_list += [filename] + + try: + # Note the use of universal_newlines to treat all newlines + # as \n for Python's purpose + text = check_output(path_list, universal_newlines=True) + except OSError as e: + raise RuntimeError("Unable to invoke 'cpp'. " + + 'Make sure its path was passed correctly\n' + + ('Original error: %s' % e)) + + return text + + +def parse_file(filename, use_cpp=False, cpp_path='cpp', cpp_args='', + parser=None): + """ Parse a C file using pycparser. + + filename: + Name of the file you want to parse. + + use_cpp: + Set to True if you want to execute the C pre-processor + on the file prior to parsing it. + + cpp_path: + If use_cpp is True, this is the path to 'cpp' on your + system. If no path is provided, it attempts to just + execute 'cpp', so it must be in your PATH. + + cpp_args: + If use_cpp is True, set this to the command line arguments strings + to cpp. Be careful with quotes - it's best to pass a raw string + (r'') here. For example: + r'-I../utils/fake_libc_include' + If several arguments are required, pass a list of strings. + + parser: + Optional parser object to be used instead of the default CParser + + When successful, an AST is returned. ParseError can be + thrown if the file doesn't parse successfully. + + Errors from cpp will be printed out. + """ + if use_cpp: + text = preprocess_file(filename, cpp_path, cpp_args) + else: + with io.open(filename) as f: + text = f.read() + + if parser is None: + parser = CParser() + return parser.parse(text, filename) diff --git a/.venv/Lib/site-packages/pycparser/__pycache__/__init__.cpython-311.pyc b/.venv/Lib/site-packages/pycparser/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 00000000..b9ceb438 Binary files /dev/null and b/.venv/Lib/site-packages/pycparser/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/pycparser/__pycache__/_ast_gen.cpython-311.pyc b/.venv/Lib/site-packages/pycparser/__pycache__/_ast_gen.cpython-311.pyc new file mode 100644 index 00000000..42498397 Binary files /dev/null and b/.venv/Lib/site-packages/pycparser/__pycache__/_ast_gen.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/pycparser/__pycache__/_build_tables.cpython-311.pyc b/.venv/Lib/site-packages/pycparser/__pycache__/_build_tables.cpython-311.pyc new file mode 100644 index 00000000..7086bc0d Binary files /dev/null and b/.venv/Lib/site-packages/pycparser/__pycache__/_build_tables.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/pycparser/__pycache__/ast_transforms.cpython-311.pyc b/.venv/Lib/site-packages/pycparser/__pycache__/ast_transforms.cpython-311.pyc new file mode 100644 index 00000000..66172b64 Binary files /dev/null and b/.venv/Lib/site-packages/pycparser/__pycache__/ast_transforms.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/pycparser/__pycache__/c_ast.cpython-311.pyc b/.venv/Lib/site-packages/pycparser/__pycache__/c_ast.cpython-311.pyc new file mode 100644 index 00000000..6e7fbbe4 Binary files /dev/null and b/.venv/Lib/site-packages/pycparser/__pycache__/c_ast.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/pycparser/__pycache__/c_generator.cpython-311.pyc b/.venv/Lib/site-packages/pycparser/__pycache__/c_generator.cpython-311.pyc new file mode 100644 index 00000000..30750e37 Binary files /dev/null and b/.venv/Lib/site-packages/pycparser/__pycache__/c_generator.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/pycparser/__pycache__/c_lexer.cpython-311.pyc b/.venv/Lib/site-packages/pycparser/__pycache__/c_lexer.cpython-311.pyc new file mode 100644 index 00000000..a1e94ea4 Binary files /dev/null and b/.venv/Lib/site-packages/pycparser/__pycache__/c_lexer.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/pycparser/__pycache__/c_parser.cpython-311.pyc b/.venv/Lib/site-packages/pycparser/__pycache__/c_parser.cpython-311.pyc new file mode 100644 index 00000000..6be24650 Binary files /dev/null and b/.venv/Lib/site-packages/pycparser/__pycache__/c_parser.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/pycparser/__pycache__/lextab.cpython-311.pyc b/.venv/Lib/site-packages/pycparser/__pycache__/lextab.cpython-311.pyc new file mode 100644 index 00000000..4681c523 Binary files /dev/null and b/.venv/Lib/site-packages/pycparser/__pycache__/lextab.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/pycparser/__pycache__/plyparser.cpython-311.pyc b/.venv/Lib/site-packages/pycparser/__pycache__/plyparser.cpython-311.pyc new file mode 100644 index 00000000..07d809ae Binary files /dev/null and b/.venv/Lib/site-packages/pycparser/__pycache__/plyparser.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/pycparser/__pycache__/yacctab.cpython-311.pyc b/.venv/Lib/site-packages/pycparser/__pycache__/yacctab.cpython-311.pyc new file mode 100644 index 00000000..e7417b00 Binary files /dev/null and b/.venv/Lib/site-packages/pycparser/__pycache__/yacctab.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/pycparser/_ast_gen.py b/.venv/Lib/site-packages/pycparser/_ast_gen.py new file mode 100644 index 00000000..0f7d330b --- /dev/null +++ b/.venv/Lib/site-packages/pycparser/_ast_gen.py @@ -0,0 +1,336 @@ +#----------------------------------------------------------------- +# _ast_gen.py +# +# Generates the AST Node classes from a specification given in +# a configuration file +# +# The design of this module was inspired by astgen.py from the +# Python 2.5 code-base. +# +# Eli Bendersky [https://eli.thegreenplace.net/] +# License: BSD +#----------------------------------------------------------------- +from string import Template + + +class ASTCodeGenerator(object): + def __init__(self, cfg_filename='_c_ast.cfg'): + """ Initialize the code generator from a configuration + file. + """ + self.cfg_filename = cfg_filename + self.node_cfg = [NodeCfg(name, contents) + for (name, contents) in self.parse_cfgfile(cfg_filename)] + + def generate(self, file=None): + """ Generates the code into file, an open file buffer. + """ + src = Template(_PROLOGUE_COMMENT).substitute( + cfg_filename=self.cfg_filename) + + src += _PROLOGUE_CODE + for node_cfg in self.node_cfg: + src += node_cfg.generate_source() + '\n\n' + + file.write(src) + + def parse_cfgfile(self, filename): + """ Parse the configuration file and yield pairs of + (name, contents) for each node. + """ + with open(filename, "r") as f: + for line in f: + line = line.strip() + if not line or line.startswith('#'): + continue + colon_i = line.find(':') + lbracket_i = line.find('[') + rbracket_i = line.find(']') + if colon_i < 1 or lbracket_i <= colon_i or rbracket_i <= lbracket_i: + raise RuntimeError("Invalid line in %s:\n%s\n" % (filename, line)) + + name = line[:colon_i] + val = line[lbracket_i + 1:rbracket_i] + vallist = [v.strip() for v in val.split(',')] if val else [] + yield name, vallist + + +class NodeCfg(object): + """ Node configuration. + + name: node name + contents: a list of contents - attributes and child nodes + See comment at the top of the configuration file for details. + """ + + def __init__(self, name, contents): + self.name = name + self.all_entries = [] + self.attr = [] + self.child = [] + self.seq_child = [] + + for entry in contents: + clean_entry = entry.rstrip('*') + self.all_entries.append(clean_entry) + + if entry.endswith('**'): + self.seq_child.append(clean_entry) + elif entry.endswith('*'): + self.child.append(clean_entry) + else: + self.attr.append(entry) + + def generate_source(self): + src = self._gen_init() + src += '\n' + self._gen_children() + src += '\n' + self._gen_iter() + src += '\n' + self._gen_attr_names() + return src + + def _gen_init(self): + src = "class %s(Node):\n" % self.name + + if self.all_entries: + args = ', '.join(self.all_entries) + slots = ', '.join("'{0}'".format(e) for e in self.all_entries) + slots += ", 'coord', '__weakref__'" + arglist = '(self, %s, coord=None)' % args + else: + slots = "'coord', '__weakref__'" + arglist = '(self, coord=None)' + + src += " __slots__ = (%s)\n" % slots + src += " def __init__%s:\n" % arglist + + for name in self.all_entries + ['coord']: + src += " self.%s = %s\n" % (name, name) + + return src + + def _gen_children(self): + src = ' def children(self):\n' + + if self.all_entries: + src += ' nodelist = []\n' + + for child in self.child: + src += ( + ' if self.%(child)s is not None:' + + ' nodelist.append(("%(child)s", self.%(child)s))\n') % ( + dict(child=child)) + + for seq_child in self.seq_child: + src += ( + ' for i, child in enumerate(self.%(child)s or []):\n' + ' nodelist.append(("%(child)s[%%d]" %% i, child))\n') % ( + dict(child=seq_child)) + + src += ' return tuple(nodelist)\n' + else: + src += ' return ()\n' + + return src + + def _gen_iter(self): + src = ' def __iter__(self):\n' + + if self.all_entries: + for child in self.child: + src += ( + ' if self.%(child)s is not None:\n' + + ' yield self.%(child)s\n') % (dict(child=child)) + + for seq_child in self.seq_child: + src += ( + ' for child in (self.%(child)s or []):\n' + ' yield child\n') % (dict(child=seq_child)) + + if not (self.child or self.seq_child): + # Empty generator + src += ( + ' return\n' + + ' yield\n') + else: + # Empty generator + src += ( + ' return\n' + + ' yield\n') + + return src + + def _gen_attr_names(self): + src = " attr_names = (" + ''.join("%r, " % nm for nm in self.attr) + ')' + return src + + +_PROLOGUE_COMMENT = \ +r'''#----------------------------------------------------------------- +# ** ATTENTION ** +# This code was automatically generated from the file: +# $cfg_filename +# +# Do not modify it directly. Modify the configuration file and +# run the generator again. +# ** ** *** ** ** +# +# pycparser: c_ast.py +# +# AST Node classes. +# +# Eli Bendersky [https://eli.thegreenplace.net/] +# License: BSD +#----------------------------------------------------------------- + +''' + +_PROLOGUE_CODE = r''' +import sys + +def _repr(obj): + """ + Get the representation of an object, with dedicated pprint-like format for lists. + """ + if isinstance(obj, list): + return '[' + (',\n '.join((_repr(e).replace('\n', '\n ') for e in obj))) + '\n]' + else: + return repr(obj) + +class Node(object): + __slots__ = () + """ Abstract base class for AST nodes. + """ + def __repr__(self): + """ Generates a python representation of the current node + """ + result = self.__class__.__name__ + '(' + + indent = '' + separator = '' + for name in self.__slots__[:-2]: + result += separator + result += indent + result += name + '=' + (_repr(getattr(self, name)).replace('\n', '\n ' + (' ' * (len(name) + len(self.__class__.__name__))))) + + separator = ',' + indent = '\n ' + (' ' * len(self.__class__.__name__)) + + result += indent + ')' + + return result + + def children(self): + """ A sequence of all children that are Nodes + """ + pass + + def show(self, buf=sys.stdout, offset=0, attrnames=False, nodenames=False, showcoord=False, _my_node_name=None): + """ Pretty print the Node and all its attributes and + children (recursively) to a buffer. + + buf: + Open IO buffer into which the Node is printed. + + offset: + Initial offset (amount of leading spaces) + + attrnames: + True if you want to see the attribute names in + name=value pairs. False to only see the values. + + nodenames: + True if you want to see the actual node names + within their parents. + + showcoord: + Do you want the coordinates of each Node to be + displayed. + """ + lead = ' ' * offset + if nodenames and _my_node_name is not None: + buf.write(lead + self.__class__.__name__+ ' <' + _my_node_name + '>: ') + else: + buf.write(lead + self.__class__.__name__+ ': ') + + if self.attr_names: + if attrnames: + nvlist = [(n, getattr(self,n)) for n in self.attr_names] + attrstr = ', '.join('%s=%s' % nv for nv in nvlist) + else: + vlist = [getattr(self, n) for n in self.attr_names] + attrstr = ', '.join('%s' % v for v in vlist) + buf.write(attrstr) + + if showcoord: + buf.write(' (at %s)' % self.coord) + buf.write('\n') + + for (child_name, child) in self.children(): + child.show( + buf, + offset=offset + 2, + attrnames=attrnames, + nodenames=nodenames, + showcoord=showcoord, + _my_node_name=child_name) + + +class NodeVisitor(object): + """ A base NodeVisitor class for visiting c_ast nodes. + Subclass it and define your own visit_XXX methods, where + XXX is the class name you want to visit with these + methods. + + For example: + + class ConstantVisitor(NodeVisitor): + def __init__(self): + self.values = [] + + def visit_Constant(self, node): + self.values.append(node.value) + + Creates a list of values of all the constant nodes + encountered below the given node. To use it: + + cv = ConstantVisitor() + cv.visit(node) + + Notes: + + * generic_visit() will be called for AST nodes for which + no visit_XXX method was defined. + * The children of nodes for which a visit_XXX was + defined will not be visited - if you need this, call + generic_visit() on the node. + You can use: + NodeVisitor.generic_visit(self, node) + * Modeled after Python's own AST visiting facilities + (the ast module of Python 3.0) + """ + + _method_cache = None + + def visit(self, node): + """ Visit a node. + """ + + if self._method_cache is None: + self._method_cache = {} + + visitor = self._method_cache.get(node.__class__.__name__, None) + if visitor is None: + method = 'visit_' + node.__class__.__name__ + visitor = getattr(self, method, self.generic_visit) + self._method_cache[node.__class__.__name__] = visitor + + return visitor(node) + + def generic_visit(self, node): + """ Called if no explicit visitor function exists for a + node. Implements preorder visiting of the node. + """ + for c in node: + self.visit(c) + +''' diff --git a/.venv/Lib/site-packages/pycparser/_build_tables.py b/.venv/Lib/site-packages/pycparser/_build_tables.py new file mode 100644 index 00000000..958381ad --- /dev/null +++ b/.venv/Lib/site-packages/pycparser/_build_tables.py @@ -0,0 +1,37 @@ +#----------------------------------------------------------------- +# pycparser: _build_tables.py +# +# A dummy for generating the lexing/parsing tables and and +# compiling them into .pyc for faster execution in optimized mode. +# Also generates AST code from the configuration file. +# Should be called from the pycparser directory. +# +# Eli Bendersky [https://eli.thegreenplace.net/] +# License: BSD +#----------------------------------------------------------------- + +# Insert '.' and '..' as first entries to the search path for modules. +# Restricted environments like embeddable python do not include the +# current working directory on startup. +import sys +sys.path[0:0] = ['.', '..'] + +# Generate c_ast.py +from _ast_gen import ASTCodeGenerator +ast_gen = ASTCodeGenerator('_c_ast.cfg') +ast_gen.generate(open('c_ast.py', 'w')) + +from pycparser import c_parser + +# Generates the tables +# +c_parser.CParser( + lex_optimize=True, + yacc_debug=False, + yacc_optimize=True) + +# Load to compile into .pyc +# +import lextab +import yacctab +import c_ast diff --git a/.venv/Lib/site-packages/pycparser/_c_ast.cfg b/.venv/Lib/site-packages/pycparser/_c_ast.cfg new file mode 100644 index 00000000..0626533e --- /dev/null +++ b/.venv/Lib/site-packages/pycparser/_c_ast.cfg @@ -0,0 +1,195 @@ +#----------------------------------------------------------------- +# pycparser: _c_ast.cfg +# +# Defines the AST Node classes used in pycparser. +# +# Each entry is a Node sub-class name, listing the attributes +# and child nodes of the class: +# * - a child node +# ** - a sequence of child nodes +# - an attribute +# +# Eli Bendersky [https://eli.thegreenplace.net/] +# License: BSD +#----------------------------------------------------------------- + +# ArrayDecl is a nested declaration of an array with the given type. +# dim: the dimension (for example, constant 42) +# dim_quals: list of dimension qualifiers, to support C99's allowing 'const' +# and 'static' within the array dimension in function declarations. +ArrayDecl: [type*, dim*, dim_quals] + +ArrayRef: [name*, subscript*] + +# op: =, +=, /= etc. +# +Assignment: [op, lvalue*, rvalue*] + +Alignas: [alignment*] + +BinaryOp: [op, left*, right*] + +Break: [] + +Case: [expr*, stmts**] + +Cast: [to_type*, expr*] + +# Compound statement in C99 is a list of block items (declarations or +# statements). +# +Compound: [block_items**] + +# Compound literal (anonymous aggregate) for C99. +# (type-name) {initializer_list} +# type: the typename +# init: InitList for the initializer list +# +CompoundLiteral: [type*, init*] + +# type: int, char, float, string, etc. +# +Constant: [type, value] + +Continue: [] + +# name: the variable being declared +# quals: list of qualifiers (const, volatile) +# funcspec: list function specifiers (i.e. inline in C99) +# storage: list of storage specifiers (extern, register, etc.) +# type: declaration type (probably nested with all the modifiers) +# init: initialization value, or None +# bitsize: bit field size, or None +# +Decl: [name, quals, align, storage, funcspec, type*, init*, bitsize*] + +DeclList: [decls**] + +Default: [stmts**] + +DoWhile: [cond*, stmt*] + +# Represents the ellipsis (...) parameter in a function +# declaration +# +EllipsisParam: [] + +# An empty statement (a semicolon ';' on its own) +# +EmptyStatement: [] + +# Enumeration type specifier +# name: an optional ID +# values: an EnumeratorList +# +Enum: [name, values*] + +# A name/value pair for enumeration values +# +Enumerator: [name, value*] + +# A list of enumerators +# +EnumeratorList: [enumerators**] + +# A list of expressions separated by the comma operator. +# +ExprList: [exprs**] + +# This is the top of the AST, representing a single C file (a +# translation unit in K&R jargon). It contains a list of +# "external-declaration"s, which is either declarations (Decl), +# Typedef or function definitions (FuncDef). +# +FileAST: [ext**] + +# for (init; cond; next) stmt +# +For: [init*, cond*, next*, stmt*] + +# name: Id +# args: ExprList +# +FuncCall: [name*, args*] + +# type (args) +# +FuncDecl: [args*, type*] + +# Function definition: a declarator for the function name and +# a body, which is a compound statement. +# There's an optional list of parameter declarations for old +# K&R-style definitions +# +FuncDef: [decl*, param_decls**, body*] + +Goto: [name] + +ID: [name] + +# Holder for types that are a simple identifier (e.g. the built +# ins void, char etc. and typedef-defined types) +# +IdentifierType: [names] + +If: [cond*, iftrue*, iffalse*] + +# An initialization list used for compound literals. +# +InitList: [exprs**] + +Label: [name, stmt*] + +# A named initializer for C99. +# The name of a NamedInitializer is a sequence of Nodes, because +# names can be hierarchical and contain constant expressions. +# +NamedInitializer: [name**, expr*] + +# a list of comma separated function parameter declarations +# +ParamList: [params**] + +PtrDecl: [quals, type*] + +Return: [expr*] + +StaticAssert: [cond*, message*] + +# name: struct tag name +# decls: declaration of members +# +Struct: [name, decls**] + +# type: . or -> +# name.field or name->field +# +StructRef: [name*, type, field*] + +Switch: [cond*, stmt*] + +# cond ? iftrue : iffalse +# +TernaryOp: [cond*, iftrue*, iffalse*] + +# A base type declaration +# +TypeDecl: [declname, quals, align, type*] + +# A typedef declaration. +# Very similar to Decl, but without some attributes +# +Typedef: [name, quals, storage, type*] + +Typename: [name, quals, align, type*] + +UnaryOp: [op, expr*] + +# name: union tag name +# decls: declaration of members +# +Union: [name, decls**] + +While: [cond*, stmt*] + +Pragma: [string] diff --git a/.venv/Lib/site-packages/pycparser/ast_transforms.py b/.venv/Lib/site-packages/pycparser/ast_transforms.py new file mode 100644 index 00000000..367dcf54 --- /dev/null +++ b/.venv/Lib/site-packages/pycparser/ast_transforms.py @@ -0,0 +1,164 @@ +#------------------------------------------------------------------------------ +# pycparser: ast_transforms.py +# +# Some utilities used by the parser to create a friendlier AST. +# +# Eli Bendersky [https://eli.thegreenplace.net/] +# License: BSD +#------------------------------------------------------------------------------ + +from . import c_ast + + +def fix_switch_cases(switch_node): + """ The 'case' statements in a 'switch' come out of parsing with one + child node, so subsequent statements are just tucked to the parent + Compound. Additionally, consecutive (fall-through) case statements + come out messy. This is a peculiarity of the C grammar. The following: + + switch (myvar) { + case 10: + k = 10; + p = k + 1; + return 10; + case 20: + case 30: + return 20; + default: + break; + } + + Creates this tree (pseudo-dump): + + Switch + ID: myvar + Compound: + Case 10: + k = 10 + p = k + 1 + return 10 + Case 20: + Case 30: + return 20 + Default: + break + + The goal of this transform is to fix this mess, turning it into the + following: + + Switch + ID: myvar + Compound: + Case 10: + k = 10 + p = k + 1 + return 10 + Case 20: + Case 30: + return 20 + Default: + break + + A fixed AST node is returned. The argument may be modified. + """ + assert isinstance(switch_node, c_ast.Switch) + if not isinstance(switch_node.stmt, c_ast.Compound): + return switch_node + + # The new Compound child for the Switch, which will collect children in the + # correct order + new_compound = c_ast.Compound([], switch_node.stmt.coord) + + # The last Case/Default node + last_case = None + + # Goes over the children of the Compound below the Switch, adding them + # either directly below new_compound or below the last Case as appropriate + # (for `switch(cond) {}`, block_items would have been None) + for child in (switch_node.stmt.block_items or []): + if isinstance(child, (c_ast.Case, c_ast.Default)): + # If it's a Case/Default: + # 1. Add it to the Compound and mark as "last case" + # 2. If its immediate child is also a Case or Default, promote it + # to a sibling. + new_compound.block_items.append(child) + _extract_nested_case(child, new_compound.block_items) + last_case = new_compound.block_items[-1] + else: + # Other statements are added as children to the last case, if it + # exists. + if last_case is None: + new_compound.block_items.append(child) + else: + last_case.stmts.append(child) + + switch_node.stmt = new_compound + return switch_node + + +def _extract_nested_case(case_node, stmts_list): + """ Recursively extract consecutive Case statements that are made nested + by the parser and add them to the stmts_list. + """ + if isinstance(case_node.stmts[0], (c_ast.Case, c_ast.Default)): + stmts_list.append(case_node.stmts.pop()) + _extract_nested_case(stmts_list[-1], stmts_list) + + +def fix_atomic_specifiers(decl): + """ Atomic specifiers like _Atomic(type) are unusually structured, + conferring a qualifier upon the contained type. + + This function fixes a decl with atomic specifiers to have a sane AST + structure, by removing spurious Typename->TypeDecl pairs and attaching + the _Atomic qualifier in the right place. + """ + # There can be multiple levels of _Atomic in a decl; fix them until a + # fixed point is reached. + while True: + decl, found = _fix_atomic_specifiers_once(decl) + if not found: + break + + # Make sure to add an _Atomic qual on the topmost decl if needed. Also + # restore the declname on the innermost TypeDecl (it gets placed in the + # wrong place during construction). + typ = decl + while not isinstance(typ, c_ast.TypeDecl): + try: + typ = typ.type + except AttributeError: + return decl + if '_Atomic' in typ.quals and '_Atomic' not in decl.quals: + decl.quals.append('_Atomic') + if typ.declname is None: + typ.declname = decl.name + + return decl + + +def _fix_atomic_specifiers_once(decl): + """ Performs one 'fix' round of atomic specifiers. + Returns (modified_decl, found) where found is True iff a fix was made. + """ + parent = decl + grandparent = None + node = decl.type + while node is not None: + if isinstance(node, c_ast.Typename) and '_Atomic' in node.quals: + break + try: + grandparent = parent + parent = node + node = node.type + except AttributeError: + # If we've reached a node without a `type` field, it means we won't + # find what we're looking for at this point; give up the search + # and return the original decl unmodified. + return decl, False + + assert isinstance(parent, c_ast.TypeDecl) + grandparent.type = node.type + if '_Atomic' not in node.type.quals: + node.type.quals.append('_Atomic') + return decl, True diff --git a/.venv/Lib/site-packages/pycparser/c_ast.py b/.venv/Lib/site-packages/pycparser/c_ast.py new file mode 100644 index 00000000..6575a2ad --- /dev/null +++ b/.venv/Lib/site-packages/pycparser/c_ast.py @@ -0,0 +1,1125 @@ +#----------------------------------------------------------------- +# ** ATTENTION ** +# This code was automatically generated from the file: +# _c_ast.cfg +# +# Do not modify it directly. Modify the configuration file and +# run the generator again. +# ** ** *** ** ** +# +# pycparser: c_ast.py +# +# AST Node classes. +# +# Eli Bendersky [https://eli.thegreenplace.net/] +# License: BSD +#----------------------------------------------------------------- + + +import sys + +def _repr(obj): + """ + Get the representation of an object, with dedicated pprint-like format for lists. + """ + if isinstance(obj, list): + return '[' + (',\n '.join((_repr(e).replace('\n', '\n ') for e in obj))) + '\n]' + else: + return repr(obj) + +class Node(object): + __slots__ = () + """ Abstract base class for AST nodes. + """ + def __repr__(self): + """ Generates a python representation of the current node + """ + result = self.__class__.__name__ + '(' + + indent = '' + separator = '' + for name in self.__slots__[:-2]: + result += separator + result += indent + result += name + '=' + (_repr(getattr(self, name)).replace('\n', '\n ' + (' ' * (len(name) + len(self.__class__.__name__))))) + + separator = ',' + indent = '\n ' + (' ' * len(self.__class__.__name__)) + + result += indent + ')' + + return result + + def children(self): + """ A sequence of all children that are Nodes + """ + pass + + def show(self, buf=sys.stdout, offset=0, attrnames=False, nodenames=False, showcoord=False, _my_node_name=None): + """ Pretty print the Node and all its attributes and + children (recursively) to a buffer. + + buf: + Open IO buffer into which the Node is printed. + + offset: + Initial offset (amount of leading spaces) + + attrnames: + True if you want to see the attribute names in + name=value pairs. False to only see the values. + + nodenames: + True if you want to see the actual node names + within their parents. + + showcoord: + Do you want the coordinates of each Node to be + displayed. + """ + lead = ' ' * offset + if nodenames and _my_node_name is not None: + buf.write(lead + self.__class__.__name__+ ' <' + _my_node_name + '>: ') + else: + buf.write(lead + self.__class__.__name__+ ': ') + + if self.attr_names: + if attrnames: + nvlist = [(n, getattr(self,n)) for n in self.attr_names] + attrstr = ', '.join('%s=%s' % nv for nv in nvlist) + else: + vlist = [getattr(self, n) for n in self.attr_names] + attrstr = ', '.join('%s' % v for v in vlist) + buf.write(attrstr) + + if showcoord: + buf.write(' (at %s)' % self.coord) + buf.write('\n') + + for (child_name, child) in self.children(): + child.show( + buf, + offset=offset + 2, + attrnames=attrnames, + nodenames=nodenames, + showcoord=showcoord, + _my_node_name=child_name) + + +class NodeVisitor(object): + """ A base NodeVisitor class for visiting c_ast nodes. + Subclass it and define your own visit_XXX methods, where + XXX is the class name you want to visit with these + methods. + + For example: + + class ConstantVisitor(NodeVisitor): + def __init__(self): + self.values = [] + + def visit_Constant(self, node): + self.values.append(node.value) + + Creates a list of values of all the constant nodes + encountered below the given node. To use it: + + cv = ConstantVisitor() + cv.visit(node) + + Notes: + + * generic_visit() will be called for AST nodes for which + no visit_XXX method was defined. + * The children of nodes for which a visit_XXX was + defined will not be visited - if you need this, call + generic_visit() on the node. + You can use: + NodeVisitor.generic_visit(self, node) + * Modeled after Python's own AST visiting facilities + (the ast module of Python 3.0) + """ + + _method_cache = None + + def visit(self, node): + """ Visit a node. + """ + + if self._method_cache is None: + self._method_cache = {} + + visitor = self._method_cache.get(node.__class__.__name__, None) + if visitor is None: + method = 'visit_' + node.__class__.__name__ + visitor = getattr(self, method, self.generic_visit) + self._method_cache[node.__class__.__name__] = visitor + + return visitor(node) + + def generic_visit(self, node): + """ Called if no explicit visitor function exists for a + node. Implements preorder visiting of the node. + """ + for c in node: + self.visit(c) + +class ArrayDecl(Node): + __slots__ = ('type', 'dim', 'dim_quals', 'coord', '__weakref__') + def __init__(self, type, dim, dim_quals, coord=None): + self.type = type + self.dim = dim + self.dim_quals = dim_quals + self.coord = coord + + def children(self): + nodelist = [] + if self.type is not None: nodelist.append(("type", self.type)) + if self.dim is not None: nodelist.append(("dim", self.dim)) + return tuple(nodelist) + + def __iter__(self): + if self.type is not None: + yield self.type + if self.dim is not None: + yield self.dim + + attr_names = ('dim_quals', ) + +class ArrayRef(Node): + __slots__ = ('name', 'subscript', 'coord', '__weakref__') + def __init__(self, name, subscript, coord=None): + self.name = name + self.subscript = subscript + self.coord = coord + + def children(self): + nodelist = [] + if self.name is not None: nodelist.append(("name", self.name)) + if self.subscript is not None: nodelist.append(("subscript", self.subscript)) + return tuple(nodelist) + + def __iter__(self): + if self.name is not None: + yield self.name + if self.subscript is not None: + yield self.subscript + + attr_names = () + +class Assignment(Node): + __slots__ = ('op', 'lvalue', 'rvalue', 'coord', '__weakref__') + def __init__(self, op, lvalue, rvalue, coord=None): + self.op = op + self.lvalue = lvalue + self.rvalue = rvalue + self.coord = coord + + def children(self): + nodelist = [] + if self.lvalue is not None: nodelist.append(("lvalue", self.lvalue)) + if self.rvalue is not None: nodelist.append(("rvalue", self.rvalue)) + return tuple(nodelist) + + def __iter__(self): + if self.lvalue is not None: + yield self.lvalue + if self.rvalue is not None: + yield self.rvalue + + attr_names = ('op', ) + +class Alignas(Node): + __slots__ = ('alignment', 'coord', '__weakref__') + def __init__(self, alignment, coord=None): + self.alignment = alignment + self.coord = coord + + def children(self): + nodelist = [] + if self.alignment is not None: nodelist.append(("alignment", self.alignment)) + return tuple(nodelist) + + def __iter__(self): + if self.alignment is not None: + yield self.alignment + + attr_names = () + +class BinaryOp(Node): + __slots__ = ('op', 'left', 'right', 'coord', '__weakref__') + def __init__(self, op, left, right, coord=None): + self.op = op + self.left = left + self.right = right + self.coord = coord + + def children(self): + nodelist = [] + if self.left is not None: nodelist.append(("left", self.left)) + if self.right is not None: nodelist.append(("right", self.right)) + return tuple(nodelist) + + def __iter__(self): + if self.left is not None: + yield self.left + if self.right is not None: + yield self.right + + attr_names = ('op', ) + +class Break(Node): + __slots__ = ('coord', '__weakref__') + def __init__(self, coord=None): + self.coord = coord + + def children(self): + return () + + def __iter__(self): + return + yield + + attr_names = () + +class Case(Node): + __slots__ = ('expr', 'stmts', 'coord', '__weakref__') + def __init__(self, expr, stmts, coord=None): + self.expr = expr + self.stmts = stmts + self.coord = coord + + def children(self): + nodelist = [] + if self.expr is not None: nodelist.append(("expr", self.expr)) + for i, child in enumerate(self.stmts or []): + nodelist.append(("stmts[%d]" % i, child)) + return tuple(nodelist) + + def __iter__(self): + if self.expr is not None: + yield self.expr + for child in (self.stmts or []): + yield child + + attr_names = () + +class Cast(Node): + __slots__ = ('to_type', 'expr', 'coord', '__weakref__') + def __init__(self, to_type, expr, coord=None): + self.to_type = to_type + self.expr = expr + self.coord = coord + + def children(self): + nodelist = [] + if self.to_type is not None: nodelist.append(("to_type", self.to_type)) + if self.expr is not None: nodelist.append(("expr", self.expr)) + return tuple(nodelist) + + def __iter__(self): + if self.to_type is not None: + yield self.to_type + if self.expr is not None: + yield self.expr + + attr_names = () + +class Compound(Node): + __slots__ = ('block_items', 'coord', '__weakref__') + def __init__(self, block_items, coord=None): + self.block_items = block_items + self.coord = coord + + def children(self): + nodelist = [] + for i, child in enumerate(self.block_items or []): + nodelist.append(("block_items[%d]" % i, child)) + return tuple(nodelist) + + def __iter__(self): + for child in (self.block_items or []): + yield child + + attr_names = () + +class CompoundLiteral(Node): + __slots__ = ('type', 'init', 'coord', '__weakref__') + def __init__(self, type, init, coord=None): + self.type = type + self.init = init + self.coord = coord + + def children(self): + nodelist = [] + if self.type is not None: nodelist.append(("type", self.type)) + if self.init is not None: nodelist.append(("init", self.init)) + return tuple(nodelist) + + def __iter__(self): + if self.type is not None: + yield self.type + if self.init is not None: + yield self.init + + attr_names = () + +class Constant(Node): + __slots__ = ('type', 'value', 'coord', '__weakref__') + def __init__(self, type, value, coord=None): + self.type = type + self.value = value + self.coord = coord + + def children(self): + nodelist = [] + return tuple(nodelist) + + def __iter__(self): + return + yield + + attr_names = ('type', 'value', ) + +class Continue(Node): + __slots__ = ('coord', '__weakref__') + def __init__(self, coord=None): + self.coord = coord + + def children(self): + return () + + def __iter__(self): + return + yield + + attr_names = () + +class Decl(Node): + __slots__ = ('name', 'quals', 'align', 'storage', 'funcspec', 'type', 'init', 'bitsize', 'coord', '__weakref__') + def __init__(self, name, quals, align, storage, funcspec, type, init, bitsize, coord=None): + self.name = name + self.quals = quals + self.align = align + self.storage = storage + self.funcspec = funcspec + self.type = type + self.init = init + self.bitsize = bitsize + self.coord = coord + + def children(self): + nodelist = [] + if self.type is not None: nodelist.append(("type", self.type)) + if self.init is not None: nodelist.append(("init", self.init)) + if self.bitsize is not None: nodelist.append(("bitsize", self.bitsize)) + return tuple(nodelist) + + def __iter__(self): + if self.type is not None: + yield self.type + if self.init is not None: + yield self.init + if self.bitsize is not None: + yield self.bitsize + + attr_names = ('name', 'quals', 'align', 'storage', 'funcspec', ) + +class DeclList(Node): + __slots__ = ('decls', 'coord', '__weakref__') + def __init__(self, decls, coord=None): + self.decls = decls + self.coord = coord + + def children(self): + nodelist = [] + for i, child in enumerate(self.decls or []): + nodelist.append(("decls[%d]" % i, child)) + return tuple(nodelist) + + def __iter__(self): + for child in (self.decls or []): + yield child + + attr_names = () + +class Default(Node): + __slots__ = ('stmts', 'coord', '__weakref__') + def __init__(self, stmts, coord=None): + self.stmts = stmts + self.coord = coord + + def children(self): + nodelist = [] + for i, child in enumerate(self.stmts or []): + nodelist.append(("stmts[%d]" % i, child)) + return tuple(nodelist) + + def __iter__(self): + for child in (self.stmts or []): + yield child + + attr_names = () + +class DoWhile(Node): + __slots__ = ('cond', 'stmt', 'coord', '__weakref__') + def __init__(self, cond, stmt, coord=None): + self.cond = cond + self.stmt = stmt + self.coord = coord + + def children(self): + nodelist = [] + if self.cond is not None: nodelist.append(("cond", self.cond)) + if self.stmt is not None: nodelist.append(("stmt", self.stmt)) + return tuple(nodelist) + + def __iter__(self): + if self.cond is not None: + yield self.cond + if self.stmt is not None: + yield self.stmt + + attr_names = () + +class EllipsisParam(Node): + __slots__ = ('coord', '__weakref__') + def __init__(self, coord=None): + self.coord = coord + + def children(self): + return () + + def __iter__(self): + return + yield + + attr_names = () + +class EmptyStatement(Node): + __slots__ = ('coord', '__weakref__') + def __init__(self, coord=None): + self.coord = coord + + def children(self): + return () + + def __iter__(self): + return + yield + + attr_names = () + +class Enum(Node): + __slots__ = ('name', 'values', 'coord', '__weakref__') + def __init__(self, name, values, coord=None): + self.name = name + self.values = values + self.coord = coord + + def children(self): + nodelist = [] + if self.values is not None: nodelist.append(("values", self.values)) + return tuple(nodelist) + + def __iter__(self): + if self.values is not None: + yield self.values + + attr_names = ('name', ) + +class Enumerator(Node): + __slots__ = ('name', 'value', 'coord', '__weakref__') + def __init__(self, name, value, coord=None): + self.name = name + self.value = value + self.coord = coord + + def children(self): + nodelist = [] + if self.value is not None: nodelist.append(("value", self.value)) + return tuple(nodelist) + + def __iter__(self): + if self.value is not None: + yield self.value + + attr_names = ('name', ) + +class EnumeratorList(Node): + __slots__ = ('enumerators', 'coord', '__weakref__') + def __init__(self, enumerators, coord=None): + self.enumerators = enumerators + self.coord = coord + + def children(self): + nodelist = [] + for i, child in enumerate(self.enumerators or []): + nodelist.append(("enumerators[%d]" % i, child)) + return tuple(nodelist) + + def __iter__(self): + for child in (self.enumerators or []): + yield child + + attr_names = () + +class ExprList(Node): + __slots__ = ('exprs', 'coord', '__weakref__') + def __init__(self, exprs, coord=None): + self.exprs = exprs + self.coord = coord + + def children(self): + nodelist = [] + for i, child in enumerate(self.exprs or []): + nodelist.append(("exprs[%d]" % i, child)) + return tuple(nodelist) + + def __iter__(self): + for child in (self.exprs or []): + yield child + + attr_names = () + +class FileAST(Node): + __slots__ = ('ext', 'coord', '__weakref__') + def __init__(self, ext, coord=None): + self.ext = ext + self.coord = coord + + def children(self): + nodelist = [] + for i, child in enumerate(self.ext or []): + nodelist.append(("ext[%d]" % i, child)) + return tuple(nodelist) + + def __iter__(self): + for child in (self.ext or []): + yield child + + attr_names = () + +class For(Node): + __slots__ = ('init', 'cond', 'next', 'stmt', 'coord', '__weakref__') + def __init__(self, init, cond, next, stmt, coord=None): + self.init = init + self.cond = cond + self.next = next + self.stmt = stmt + self.coord = coord + + def children(self): + nodelist = [] + if self.init is not None: nodelist.append(("init", self.init)) + if self.cond is not None: nodelist.append(("cond", self.cond)) + if self.next is not None: nodelist.append(("next", self.next)) + if self.stmt is not None: nodelist.append(("stmt", self.stmt)) + return tuple(nodelist) + + def __iter__(self): + if self.init is not None: + yield self.init + if self.cond is not None: + yield self.cond + if self.next is not None: + yield self.next + if self.stmt is not None: + yield self.stmt + + attr_names = () + +class FuncCall(Node): + __slots__ = ('name', 'args', 'coord', '__weakref__') + def __init__(self, name, args, coord=None): + self.name = name + self.args = args + self.coord = coord + + def children(self): + nodelist = [] + if self.name is not None: nodelist.append(("name", self.name)) + if self.args is not None: nodelist.append(("args", self.args)) + return tuple(nodelist) + + def __iter__(self): + if self.name is not None: + yield self.name + if self.args is not None: + yield self.args + + attr_names = () + +class FuncDecl(Node): + __slots__ = ('args', 'type', 'coord', '__weakref__') + def __init__(self, args, type, coord=None): + self.args = args + self.type = type + self.coord = coord + + def children(self): + nodelist = [] + if self.args is not None: nodelist.append(("args", self.args)) + if self.type is not None: nodelist.append(("type", self.type)) + return tuple(nodelist) + + def __iter__(self): + if self.args is not None: + yield self.args + if self.type is not None: + yield self.type + + attr_names = () + +class FuncDef(Node): + __slots__ = ('decl', 'param_decls', 'body', 'coord', '__weakref__') + def __init__(self, decl, param_decls, body, coord=None): + self.decl = decl + self.param_decls = param_decls + self.body = body + self.coord = coord + + def children(self): + nodelist = [] + if self.decl is not None: nodelist.append(("decl", self.decl)) + if self.body is not None: nodelist.append(("body", self.body)) + for i, child in enumerate(self.param_decls or []): + nodelist.append(("param_decls[%d]" % i, child)) + return tuple(nodelist) + + def __iter__(self): + if self.decl is not None: + yield self.decl + if self.body is not None: + yield self.body + for child in (self.param_decls or []): + yield child + + attr_names = () + +class Goto(Node): + __slots__ = ('name', 'coord', '__weakref__') + def __init__(self, name, coord=None): + self.name = name + self.coord = coord + + def children(self): + nodelist = [] + return tuple(nodelist) + + def __iter__(self): + return + yield + + attr_names = ('name', ) + +class ID(Node): + __slots__ = ('name', 'coord', '__weakref__') + def __init__(self, name, coord=None): + self.name = name + self.coord = coord + + def children(self): + nodelist = [] + return tuple(nodelist) + + def __iter__(self): + return + yield + + attr_names = ('name', ) + +class IdentifierType(Node): + __slots__ = ('names', 'coord', '__weakref__') + def __init__(self, names, coord=None): + self.names = names + self.coord = coord + + def children(self): + nodelist = [] + return tuple(nodelist) + + def __iter__(self): + return + yield + + attr_names = ('names', ) + +class If(Node): + __slots__ = ('cond', 'iftrue', 'iffalse', 'coord', '__weakref__') + def __init__(self, cond, iftrue, iffalse, coord=None): + self.cond = cond + self.iftrue = iftrue + self.iffalse = iffalse + self.coord = coord + + def children(self): + nodelist = [] + if self.cond is not None: nodelist.append(("cond", self.cond)) + if self.iftrue is not None: nodelist.append(("iftrue", self.iftrue)) + if self.iffalse is not None: nodelist.append(("iffalse", self.iffalse)) + return tuple(nodelist) + + def __iter__(self): + if self.cond is not None: + yield self.cond + if self.iftrue is not None: + yield self.iftrue + if self.iffalse is not None: + yield self.iffalse + + attr_names = () + +class InitList(Node): + __slots__ = ('exprs', 'coord', '__weakref__') + def __init__(self, exprs, coord=None): + self.exprs = exprs + self.coord = coord + + def children(self): + nodelist = [] + for i, child in enumerate(self.exprs or []): + nodelist.append(("exprs[%d]" % i, child)) + return tuple(nodelist) + + def __iter__(self): + for child in (self.exprs or []): + yield child + + attr_names = () + +class Label(Node): + __slots__ = ('name', 'stmt', 'coord', '__weakref__') + def __init__(self, name, stmt, coord=None): + self.name = name + self.stmt = stmt + self.coord = coord + + def children(self): + nodelist = [] + if self.stmt is not None: nodelist.append(("stmt", self.stmt)) + return tuple(nodelist) + + def __iter__(self): + if self.stmt is not None: + yield self.stmt + + attr_names = ('name', ) + +class NamedInitializer(Node): + __slots__ = ('name', 'expr', 'coord', '__weakref__') + def __init__(self, name, expr, coord=None): + self.name = name + self.expr = expr + self.coord = coord + + def children(self): + nodelist = [] + if self.expr is not None: nodelist.append(("expr", self.expr)) + for i, child in enumerate(self.name or []): + nodelist.append(("name[%d]" % i, child)) + return tuple(nodelist) + + def __iter__(self): + if self.expr is not None: + yield self.expr + for child in (self.name or []): + yield child + + attr_names = () + +class ParamList(Node): + __slots__ = ('params', 'coord', '__weakref__') + def __init__(self, params, coord=None): + self.params = params + self.coord = coord + + def children(self): + nodelist = [] + for i, child in enumerate(self.params or []): + nodelist.append(("params[%d]" % i, child)) + return tuple(nodelist) + + def __iter__(self): + for child in (self.params or []): + yield child + + attr_names = () + +class PtrDecl(Node): + __slots__ = ('quals', 'type', 'coord', '__weakref__') + def __init__(self, quals, type, coord=None): + self.quals = quals + self.type = type + self.coord = coord + + def children(self): + nodelist = [] + if self.type is not None: nodelist.append(("type", self.type)) + return tuple(nodelist) + + def __iter__(self): + if self.type is not None: + yield self.type + + attr_names = ('quals', ) + +class Return(Node): + __slots__ = ('expr', 'coord', '__weakref__') + def __init__(self, expr, coord=None): + self.expr = expr + self.coord = coord + + def children(self): + nodelist = [] + if self.expr is not None: nodelist.append(("expr", self.expr)) + return tuple(nodelist) + + def __iter__(self): + if self.expr is not None: + yield self.expr + + attr_names = () + +class StaticAssert(Node): + __slots__ = ('cond', 'message', 'coord', '__weakref__') + def __init__(self, cond, message, coord=None): + self.cond = cond + self.message = message + self.coord = coord + + def children(self): + nodelist = [] + if self.cond is not None: nodelist.append(("cond", self.cond)) + if self.message is not None: nodelist.append(("message", self.message)) + return tuple(nodelist) + + def __iter__(self): + if self.cond is not None: + yield self.cond + if self.message is not None: + yield self.message + + attr_names = () + +class Struct(Node): + __slots__ = ('name', 'decls', 'coord', '__weakref__') + def __init__(self, name, decls, coord=None): + self.name = name + self.decls = decls + self.coord = coord + + def children(self): + nodelist = [] + for i, child in enumerate(self.decls or []): + nodelist.append(("decls[%d]" % i, child)) + return tuple(nodelist) + + def __iter__(self): + for child in (self.decls or []): + yield child + + attr_names = ('name', ) + +class StructRef(Node): + __slots__ = ('name', 'type', 'field', 'coord', '__weakref__') + def __init__(self, name, type, field, coord=None): + self.name = name + self.type = type + self.field = field + self.coord = coord + + def children(self): + nodelist = [] + if self.name is not None: nodelist.append(("name", self.name)) + if self.field is not None: nodelist.append(("field", self.field)) + return tuple(nodelist) + + def __iter__(self): + if self.name is not None: + yield self.name + if self.field is not None: + yield self.field + + attr_names = ('type', ) + +class Switch(Node): + __slots__ = ('cond', 'stmt', 'coord', '__weakref__') + def __init__(self, cond, stmt, coord=None): + self.cond = cond + self.stmt = stmt + self.coord = coord + + def children(self): + nodelist = [] + if self.cond is not None: nodelist.append(("cond", self.cond)) + if self.stmt is not None: nodelist.append(("stmt", self.stmt)) + return tuple(nodelist) + + def __iter__(self): + if self.cond is not None: + yield self.cond + if self.stmt is not None: + yield self.stmt + + attr_names = () + +class TernaryOp(Node): + __slots__ = ('cond', 'iftrue', 'iffalse', 'coord', '__weakref__') + def __init__(self, cond, iftrue, iffalse, coord=None): + self.cond = cond + self.iftrue = iftrue + self.iffalse = iffalse + self.coord = coord + + def children(self): + nodelist = [] + if self.cond is not None: nodelist.append(("cond", self.cond)) + if self.iftrue is not None: nodelist.append(("iftrue", self.iftrue)) + if self.iffalse is not None: nodelist.append(("iffalse", self.iffalse)) + return tuple(nodelist) + + def __iter__(self): + if self.cond is not None: + yield self.cond + if self.iftrue is not None: + yield self.iftrue + if self.iffalse is not None: + yield self.iffalse + + attr_names = () + +class TypeDecl(Node): + __slots__ = ('declname', 'quals', 'align', 'type', 'coord', '__weakref__') + def __init__(self, declname, quals, align, type, coord=None): + self.declname = declname + self.quals = quals + self.align = align + self.type = type + self.coord = coord + + def children(self): + nodelist = [] + if self.type is not None: nodelist.append(("type", self.type)) + return tuple(nodelist) + + def __iter__(self): + if self.type is not None: + yield self.type + + attr_names = ('declname', 'quals', 'align', ) + +class Typedef(Node): + __slots__ = ('name', 'quals', 'storage', 'type', 'coord', '__weakref__') + def __init__(self, name, quals, storage, type, coord=None): + self.name = name + self.quals = quals + self.storage = storage + self.type = type + self.coord = coord + + def children(self): + nodelist = [] + if self.type is not None: nodelist.append(("type", self.type)) + return tuple(nodelist) + + def __iter__(self): + if self.type is not None: + yield self.type + + attr_names = ('name', 'quals', 'storage', ) + +class Typename(Node): + __slots__ = ('name', 'quals', 'align', 'type', 'coord', '__weakref__') + def __init__(self, name, quals, align, type, coord=None): + self.name = name + self.quals = quals + self.align = align + self.type = type + self.coord = coord + + def children(self): + nodelist = [] + if self.type is not None: nodelist.append(("type", self.type)) + return tuple(nodelist) + + def __iter__(self): + if self.type is not None: + yield self.type + + attr_names = ('name', 'quals', 'align', ) + +class UnaryOp(Node): + __slots__ = ('op', 'expr', 'coord', '__weakref__') + def __init__(self, op, expr, coord=None): + self.op = op + self.expr = expr + self.coord = coord + + def children(self): + nodelist = [] + if self.expr is not None: nodelist.append(("expr", self.expr)) + return tuple(nodelist) + + def __iter__(self): + if self.expr is not None: + yield self.expr + + attr_names = ('op', ) + +class Union(Node): + __slots__ = ('name', 'decls', 'coord', '__weakref__') + def __init__(self, name, decls, coord=None): + self.name = name + self.decls = decls + self.coord = coord + + def children(self): + nodelist = [] + for i, child in enumerate(self.decls or []): + nodelist.append(("decls[%d]" % i, child)) + return tuple(nodelist) + + def __iter__(self): + for child in (self.decls or []): + yield child + + attr_names = ('name', ) + +class While(Node): + __slots__ = ('cond', 'stmt', 'coord', '__weakref__') + def __init__(self, cond, stmt, coord=None): + self.cond = cond + self.stmt = stmt + self.coord = coord + + def children(self): + nodelist = [] + if self.cond is not None: nodelist.append(("cond", self.cond)) + if self.stmt is not None: nodelist.append(("stmt", self.stmt)) + return tuple(nodelist) + + def __iter__(self): + if self.cond is not None: + yield self.cond + if self.stmt is not None: + yield self.stmt + + attr_names = () + +class Pragma(Node): + __slots__ = ('string', 'coord', '__weakref__') + def __init__(self, string, coord=None): + self.string = string + self.coord = coord + + def children(self): + nodelist = [] + return tuple(nodelist) + + def __iter__(self): + return + yield + + attr_names = ('string', ) + diff --git a/.venv/Lib/site-packages/pycparser/c_generator.py b/.venv/Lib/site-packages/pycparser/c_generator.py new file mode 100644 index 00000000..1057b2c6 --- /dev/null +++ b/.venv/Lib/site-packages/pycparser/c_generator.py @@ -0,0 +1,502 @@ +#------------------------------------------------------------------------------ +# pycparser: c_generator.py +# +# C code generator from pycparser AST nodes. +# +# Eli Bendersky [https://eli.thegreenplace.net/] +# License: BSD +#------------------------------------------------------------------------------ +from . import c_ast + + +class CGenerator(object): + """ Uses the same visitor pattern as c_ast.NodeVisitor, but modified to + return a value from each visit method, using string accumulation in + generic_visit. + """ + def __init__(self, reduce_parentheses=False): + """ Constructs C-code generator + + reduce_parentheses: + if True, eliminates needless parentheses on binary operators + """ + # Statements start with indentation of self.indent_level spaces, using + # the _make_indent method. + self.indent_level = 0 + self.reduce_parentheses = reduce_parentheses + + def _make_indent(self): + return ' ' * self.indent_level + + def visit(self, node): + method = 'visit_' + node.__class__.__name__ + return getattr(self, method, self.generic_visit)(node) + + def generic_visit(self, node): + if node is None: + return '' + else: + return ''.join(self.visit(c) for c_name, c in node.children()) + + def visit_Constant(self, n): + return n.value + + def visit_ID(self, n): + return n.name + + def visit_Pragma(self, n): + ret = '#pragma' + if n.string: + ret += ' ' + n.string + return ret + + def visit_ArrayRef(self, n): + arrref = self._parenthesize_unless_simple(n.name) + return arrref + '[' + self.visit(n.subscript) + ']' + + def visit_StructRef(self, n): + sref = self._parenthesize_unless_simple(n.name) + return sref + n.type + self.visit(n.field) + + def visit_FuncCall(self, n): + fref = self._parenthesize_unless_simple(n.name) + return fref + '(' + self.visit(n.args) + ')' + + def visit_UnaryOp(self, n): + if n.op == 'sizeof': + # Always parenthesize the argument of sizeof since it can be + # a name. + return 'sizeof(%s)' % self.visit(n.expr) + else: + operand = self._parenthesize_unless_simple(n.expr) + if n.op == 'p++': + return '%s++' % operand + elif n.op == 'p--': + return '%s--' % operand + else: + return '%s%s' % (n.op, operand) + + # Precedence map of binary operators: + precedence_map = { + # Should be in sync with c_parser.CParser.precedence + # Higher numbers are stronger binding + '||': 0, # weakest binding + '&&': 1, + '|': 2, + '^': 3, + '&': 4, + '==': 5, '!=': 5, + '>': 6, '>=': 6, '<': 6, '<=': 6, + '>>': 7, '<<': 7, + '+': 8, '-': 8, + '*': 9, '/': 9, '%': 9 # strongest binding + } + + def visit_BinaryOp(self, n): + # Note: all binary operators are left-to-right associative + # + # If `n.left.op` has a stronger or equally binding precedence in + # comparison to `n.op`, no parenthesis are needed for the left: + # e.g., `(a*b) + c` is equivalent to `a*b + c`, as well as + # `(a+b) - c` is equivalent to `a+b - c` (same precedence). + # If the left operator is weaker binding than the current, then + # parentheses are necessary: + # e.g., `(a+b) * c` is NOT equivalent to `a+b * c`. + lval_str = self._parenthesize_if( + n.left, + lambda d: not (self._is_simple_node(d) or + self.reduce_parentheses and isinstance(d, c_ast.BinaryOp) and + self.precedence_map[d.op] >= self.precedence_map[n.op])) + # If `n.right.op` has a stronger -but not equal- binding precedence, + # parenthesis can be omitted on the right: + # e.g., `a + (b*c)` is equivalent to `a + b*c`. + # If the right operator is weaker or equally binding, then parentheses + # are necessary: + # e.g., `a * (b+c)` is NOT equivalent to `a * b+c` and + # `a - (b+c)` is NOT equivalent to `a - b+c` (same precedence). + rval_str = self._parenthesize_if( + n.right, + lambda d: not (self._is_simple_node(d) or + self.reduce_parentheses and isinstance(d, c_ast.BinaryOp) and + self.precedence_map[d.op] > self.precedence_map[n.op])) + return '%s %s %s' % (lval_str, n.op, rval_str) + + def visit_Assignment(self, n): + rval_str = self._parenthesize_if( + n.rvalue, + lambda n: isinstance(n, c_ast.Assignment)) + return '%s %s %s' % (self.visit(n.lvalue), n.op, rval_str) + + def visit_IdentifierType(self, n): + return ' '.join(n.names) + + def _visit_expr(self, n): + if isinstance(n, c_ast.InitList): + return '{' + self.visit(n) + '}' + elif isinstance(n, c_ast.ExprList): + return '(' + self.visit(n) + ')' + else: + return self.visit(n) + + def visit_Decl(self, n, no_type=False): + # no_type is used when a Decl is part of a DeclList, where the type is + # explicitly only for the first declaration in a list. + # + s = n.name if no_type else self._generate_decl(n) + if n.bitsize: s += ' : ' + self.visit(n.bitsize) + if n.init: + s += ' = ' + self._visit_expr(n.init) + return s + + def visit_DeclList(self, n): + s = self.visit(n.decls[0]) + if len(n.decls) > 1: + s += ', ' + ', '.join(self.visit_Decl(decl, no_type=True) + for decl in n.decls[1:]) + return s + + def visit_Typedef(self, n): + s = '' + if n.storage: s += ' '.join(n.storage) + ' ' + s += self._generate_type(n.type) + return s + + def visit_Cast(self, n): + s = '(' + self._generate_type(n.to_type, emit_declname=False) + ')' + return s + ' ' + self._parenthesize_unless_simple(n.expr) + + def visit_ExprList(self, n): + visited_subexprs = [] + for expr in n.exprs: + visited_subexprs.append(self._visit_expr(expr)) + return ', '.join(visited_subexprs) + + def visit_InitList(self, n): + visited_subexprs = [] + for expr in n.exprs: + visited_subexprs.append(self._visit_expr(expr)) + return ', '.join(visited_subexprs) + + def visit_Enum(self, n): + return self._generate_struct_union_enum(n, name='enum') + + def visit_Alignas(self, n): + return '_Alignas({})'.format(self.visit(n.alignment)) + + def visit_Enumerator(self, n): + if not n.value: + return '{indent}{name},\n'.format( + indent=self._make_indent(), + name=n.name, + ) + else: + return '{indent}{name} = {value},\n'.format( + indent=self._make_indent(), + name=n.name, + value=self.visit(n.value), + ) + + def visit_FuncDef(self, n): + decl = self.visit(n.decl) + self.indent_level = 0 + body = self.visit(n.body) + if n.param_decls: + knrdecls = ';\n'.join(self.visit(p) for p in n.param_decls) + return decl + '\n' + knrdecls + ';\n' + body + '\n' + else: + return decl + '\n' + body + '\n' + + def visit_FileAST(self, n): + s = '' + for ext in n.ext: + if isinstance(ext, c_ast.FuncDef): + s += self.visit(ext) + elif isinstance(ext, c_ast.Pragma): + s += self.visit(ext) + '\n' + else: + s += self.visit(ext) + ';\n' + return s + + def visit_Compound(self, n): + s = self._make_indent() + '{\n' + self.indent_level += 2 + if n.block_items: + s += ''.join(self._generate_stmt(stmt) for stmt in n.block_items) + self.indent_level -= 2 + s += self._make_indent() + '}\n' + return s + + def visit_CompoundLiteral(self, n): + return '(' + self.visit(n.type) + '){' + self.visit(n.init) + '}' + + + def visit_EmptyStatement(self, n): + return ';' + + def visit_ParamList(self, n): + return ', '.join(self.visit(param) for param in n.params) + + def visit_Return(self, n): + s = 'return' + if n.expr: s += ' ' + self.visit(n.expr) + return s + ';' + + def visit_Break(self, n): + return 'break;' + + def visit_Continue(self, n): + return 'continue;' + + def visit_TernaryOp(self, n): + s = '(' + self._visit_expr(n.cond) + ') ? ' + s += '(' + self._visit_expr(n.iftrue) + ') : ' + s += '(' + self._visit_expr(n.iffalse) + ')' + return s + + def visit_If(self, n): + s = 'if (' + if n.cond: s += self.visit(n.cond) + s += ')\n' + s += self._generate_stmt(n.iftrue, add_indent=True) + if n.iffalse: + s += self._make_indent() + 'else\n' + s += self._generate_stmt(n.iffalse, add_indent=True) + return s + + def visit_For(self, n): + s = 'for (' + if n.init: s += self.visit(n.init) + s += ';' + if n.cond: s += ' ' + self.visit(n.cond) + s += ';' + if n.next: s += ' ' + self.visit(n.next) + s += ')\n' + s += self._generate_stmt(n.stmt, add_indent=True) + return s + + def visit_While(self, n): + s = 'while (' + if n.cond: s += self.visit(n.cond) + s += ')\n' + s += self._generate_stmt(n.stmt, add_indent=True) + return s + + def visit_DoWhile(self, n): + s = 'do\n' + s += self._generate_stmt(n.stmt, add_indent=True) + s += self._make_indent() + 'while (' + if n.cond: s += self.visit(n.cond) + s += ');' + return s + + def visit_StaticAssert(self, n): + s = '_Static_assert(' + s += self.visit(n.cond) + if n.message: + s += ',' + s += self.visit(n.message) + s += ')' + return s + + def visit_Switch(self, n): + s = 'switch (' + self.visit(n.cond) + ')\n' + s += self._generate_stmt(n.stmt, add_indent=True) + return s + + def visit_Case(self, n): + s = 'case ' + self.visit(n.expr) + ':\n' + for stmt in n.stmts: + s += self._generate_stmt(stmt, add_indent=True) + return s + + def visit_Default(self, n): + s = 'default:\n' + for stmt in n.stmts: + s += self._generate_stmt(stmt, add_indent=True) + return s + + def visit_Label(self, n): + return n.name + ':\n' + self._generate_stmt(n.stmt) + + def visit_Goto(self, n): + return 'goto ' + n.name + ';' + + def visit_EllipsisParam(self, n): + return '...' + + def visit_Struct(self, n): + return self._generate_struct_union_enum(n, 'struct') + + def visit_Typename(self, n): + return self._generate_type(n.type) + + def visit_Union(self, n): + return self._generate_struct_union_enum(n, 'union') + + def visit_NamedInitializer(self, n): + s = '' + for name in n.name: + if isinstance(name, c_ast.ID): + s += '.' + name.name + else: + s += '[' + self.visit(name) + ']' + s += ' = ' + self._visit_expr(n.expr) + return s + + def visit_FuncDecl(self, n): + return self._generate_type(n) + + def visit_ArrayDecl(self, n): + return self._generate_type(n, emit_declname=False) + + def visit_TypeDecl(self, n): + return self._generate_type(n, emit_declname=False) + + def visit_PtrDecl(self, n): + return self._generate_type(n, emit_declname=False) + + def _generate_struct_union_enum(self, n, name): + """ Generates code for structs, unions, and enums. name should be + 'struct', 'union', or 'enum'. + """ + if name in ('struct', 'union'): + members = n.decls + body_function = self._generate_struct_union_body + else: + assert name == 'enum' + members = None if n.values is None else n.values.enumerators + body_function = self._generate_enum_body + s = name + ' ' + (n.name or '') + if members is not None: + # None means no members + # Empty sequence means an empty list of members + s += '\n' + s += self._make_indent() + self.indent_level += 2 + s += '{\n' + s += body_function(members) + self.indent_level -= 2 + s += self._make_indent() + '}' + return s + + def _generate_struct_union_body(self, members): + return ''.join(self._generate_stmt(decl) for decl in members) + + def _generate_enum_body(self, members): + # `[:-2] + '\n'` removes the final `,` from the enumerator list + return ''.join(self.visit(value) for value in members)[:-2] + '\n' + + def _generate_stmt(self, n, add_indent=False): + """ Generation from a statement node. This method exists as a wrapper + for individual visit_* methods to handle different treatment of + some statements in this context. + """ + typ = type(n) + if add_indent: self.indent_level += 2 + indent = self._make_indent() + if add_indent: self.indent_level -= 2 + + if typ in ( + c_ast.Decl, c_ast.Assignment, c_ast.Cast, c_ast.UnaryOp, + c_ast.BinaryOp, c_ast.TernaryOp, c_ast.FuncCall, c_ast.ArrayRef, + c_ast.StructRef, c_ast.Constant, c_ast.ID, c_ast.Typedef, + c_ast.ExprList): + # These can also appear in an expression context so no semicolon + # is added to them automatically + # + return indent + self.visit(n) + ';\n' + elif typ in (c_ast.Compound,): + # No extra indentation required before the opening brace of a + # compound - because it consists of multiple lines it has to + # compute its own indentation. + # + return self.visit(n) + elif typ in (c_ast.If,): + return indent + self.visit(n) + else: + return indent + self.visit(n) + '\n' + + def _generate_decl(self, n): + """ Generation from a Decl node. + """ + s = '' + if n.funcspec: s = ' '.join(n.funcspec) + ' ' + if n.storage: s += ' '.join(n.storage) + ' ' + if n.align: s += self.visit(n.align[0]) + ' ' + s += self._generate_type(n.type) + return s + + def _generate_type(self, n, modifiers=[], emit_declname = True): + """ Recursive generation from a type node. n is the type node. + modifiers collects the PtrDecl, ArrayDecl and FuncDecl modifiers + encountered on the way down to a TypeDecl, to allow proper + generation from it. + """ + typ = type(n) + #~ print(n, modifiers) + + if typ == c_ast.TypeDecl: + s = '' + if n.quals: s += ' '.join(n.quals) + ' ' + s += self.visit(n.type) + + nstr = n.declname if n.declname and emit_declname else '' + # Resolve modifiers. + # Wrap in parens to distinguish pointer to array and pointer to + # function syntax. + # + for i, modifier in enumerate(modifiers): + if isinstance(modifier, c_ast.ArrayDecl): + if (i != 0 and + isinstance(modifiers[i - 1], c_ast.PtrDecl)): + nstr = '(' + nstr + ')' + nstr += '[' + if modifier.dim_quals: + nstr += ' '.join(modifier.dim_quals) + ' ' + nstr += self.visit(modifier.dim) + ']' + elif isinstance(modifier, c_ast.FuncDecl): + if (i != 0 and + isinstance(modifiers[i - 1], c_ast.PtrDecl)): + nstr = '(' + nstr + ')' + nstr += '(' + self.visit(modifier.args) + ')' + elif isinstance(modifier, c_ast.PtrDecl): + if modifier.quals: + nstr = '* %s%s' % (' '.join(modifier.quals), + ' ' + nstr if nstr else '') + else: + nstr = '*' + nstr + if nstr: s += ' ' + nstr + return s + elif typ == c_ast.Decl: + return self._generate_decl(n.type) + elif typ == c_ast.Typename: + return self._generate_type(n.type, emit_declname = emit_declname) + elif typ == c_ast.IdentifierType: + return ' '.join(n.names) + ' ' + elif typ in (c_ast.ArrayDecl, c_ast.PtrDecl, c_ast.FuncDecl): + return self._generate_type(n.type, modifiers + [n], + emit_declname = emit_declname) + else: + return self.visit(n) + + def _parenthesize_if(self, n, condition): + """ Visits 'n' and returns its string representation, parenthesized + if the condition function applied to the node returns True. + """ + s = self._visit_expr(n) + if condition(n): + return '(' + s + ')' + else: + return s + + def _parenthesize_unless_simple(self, n): + """ Common use case for _parenthesize_if + """ + return self._parenthesize_if(n, lambda d: not self._is_simple_node(d)) + + def _is_simple_node(self, n): + """ Returns True for nodes that are "simple" - i.e. nodes that always + have higher precedence than operators. + """ + return isinstance(n, (c_ast.Constant, c_ast.ID, c_ast.ArrayRef, + c_ast.StructRef, c_ast.FuncCall)) diff --git a/.venv/Lib/site-packages/pycparser/c_lexer.py b/.venv/Lib/site-packages/pycparser/c_lexer.py new file mode 100644 index 00000000..d68d8ebf --- /dev/null +++ b/.venv/Lib/site-packages/pycparser/c_lexer.py @@ -0,0 +1,554 @@ +#------------------------------------------------------------------------------ +# pycparser: c_lexer.py +# +# CLexer class: lexer for the C language +# +# Eli Bendersky [https://eli.thegreenplace.net/] +# License: BSD +#------------------------------------------------------------------------------ +import re + +from .ply import lex +from .ply.lex import TOKEN + + +class CLexer(object): + """ A lexer for the C language. After building it, set the + input text with input(), and call token() to get new + tokens. + + The public attribute filename can be set to an initial + filename, but the lexer will update it upon #line + directives. + """ + def __init__(self, error_func, on_lbrace_func, on_rbrace_func, + type_lookup_func): + """ Create a new Lexer. + + error_func: + An error function. Will be called with an error + message, line and column as arguments, in case of + an error during lexing. + + on_lbrace_func, on_rbrace_func: + Called when an LBRACE or RBRACE is encountered + (likely to push/pop type_lookup_func's scope) + + type_lookup_func: + A type lookup function. Given a string, it must + return True IFF this string is a name of a type + that was defined with a typedef earlier. + """ + self.error_func = error_func + self.on_lbrace_func = on_lbrace_func + self.on_rbrace_func = on_rbrace_func + self.type_lookup_func = type_lookup_func + self.filename = '' + + # Keeps track of the last token returned from self.token() + self.last_token = None + + # Allow either "# line" or "# " to support GCC's + # cpp output + # + self.line_pattern = re.compile(r'([ \t]*line\W)|([ \t]*\d+)') + self.pragma_pattern = re.compile(r'[ \t]*pragma\W') + + def build(self, **kwargs): + """ Builds the lexer from the specification. Must be + called after the lexer object is created. + + This method exists separately, because the PLY + manual warns against calling lex.lex inside + __init__ + """ + self.lexer = lex.lex(object=self, **kwargs) + + def reset_lineno(self): + """ Resets the internal line number counter of the lexer. + """ + self.lexer.lineno = 1 + + def input(self, text): + self.lexer.input(text) + + def token(self): + self.last_token = self.lexer.token() + return self.last_token + + def find_tok_column(self, token): + """ Find the column of the token in its line. + """ + last_cr = self.lexer.lexdata.rfind('\n', 0, token.lexpos) + return token.lexpos - last_cr + + ######################-- PRIVATE --###################### + + ## + ## Internal auxiliary methods + ## + def _error(self, msg, token): + location = self._make_tok_location(token) + self.error_func(msg, location[0], location[1]) + self.lexer.skip(1) + + def _make_tok_location(self, token): + return (token.lineno, self.find_tok_column(token)) + + ## + ## Reserved keywords + ## + keywords = ( + 'AUTO', 'BREAK', 'CASE', 'CHAR', 'CONST', + 'CONTINUE', 'DEFAULT', 'DO', 'DOUBLE', 'ELSE', 'ENUM', 'EXTERN', + 'FLOAT', 'FOR', 'GOTO', 'IF', 'INLINE', 'INT', 'LONG', + 'REGISTER', 'OFFSETOF', + 'RESTRICT', 'RETURN', 'SHORT', 'SIGNED', 'SIZEOF', 'STATIC', 'STRUCT', + 'SWITCH', 'TYPEDEF', 'UNION', 'UNSIGNED', 'VOID', + 'VOLATILE', 'WHILE', '__INT128', + ) + + keywords_new = ( + '_BOOL', '_COMPLEX', + '_NORETURN', '_THREAD_LOCAL', '_STATIC_ASSERT', + '_ATOMIC', '_ALIGNOF', '_ALIGNAS', + ) + + keyword_map = {} + + for keyword in keywords: + keyword_map[keyword.lower()] = keyword + + for keyword in keywords_new: + keyword_map[keyword[:2].upper() + keyword[2:].lower()] = keyword + + ## + ## All the tokens recognized by the lexer + ## + tokens = keywords + keywords_new + ( + # Identifiers + 'ID', + + # Type identifiers (identifiers previously defined as + # types with typedef) + 'TYPEID', + + # constants + 'INT_CONST_DEC', 'INT_CONST_OCT', 'INT_CONST_HEX', 'INT_CONST_BIN', 'INT_CONST_CHAR', + 'FLOAT_CONST', 'HEX_FLOAT_CONST', + 'CHAR_CONST', + 'WCHAR_CONST', + 'U8CHAR_CONST', + 'U16CHAR_CONST', + 'U32CHAR_CONST', + + # String literals + 'STRING_LITERAL', + 'WSTRING_LITERAL', + 'U8STRING_LITERAL', + 'U16STRING_LITERAL', + 'U32STRING_LITERAL', + + # Operators + 'PLUS', 'MINUS', 'TIMES', 'DIVIDE', 'MOD', + 'OR', 'AND', 'NOT', 'XOR', 'LSHIFT', 'RSHIFT', + 'LOR', 'LAND', 'LNOT', + 'LT', 'LE', 'GT', 'GE', 'EQ', 'NE', + + # Assignment + 'EQUALS', 'TIMESEQUAL', 'DIVEQUAL', 'MODEQUAL', + 'PLUSEQUAL', 'MINUSEQUAL', + 'LSHIFTEQUAL','RSHIFTEQUAL', 'ANDEQUAL', 'XOREQUAL', + 'OREQUAL', + + # Increment/decrement + 'PLUSPLUS', 'MINUSMINUS', + + # Structure dereference (->) + 'ARROW', + + # Conditional operator (?) + 'CONDOP', + + # Delimiters + 'LPAREN', 'RPAREN', # ( ) + 'LBRACKET', 'RBRACKET', # [ ] + 'LBRACE', 'RBRACE', # { } + 'COMMA', 'PERIOD', # . , + 'SEMI', 'COLON', # ; : + + # Ellipsis (...) + 'ELLIPSIS', + + # pre-processor + 'PPHASH', # '#' + 'PPPRAGMA', # 'pragma' + 'PPPRAGMASTR', + ) + + ## + ## Regexes for use in tokens + ## + ## + + # valid C identifiers (K&R2: A.2.3), plus '$' (supported by some compilers) + identifier = r'[a-zA-Z_$][0-9a-zA-Z_$]*' + + hex_prefix = '0[xX]' + hex_digits = '[0-9a-fA-F]+' + bin_prefix = '0[bB]' + bin_digits = '[01]+' + + # integer constants (K&R2: A.2.5.1) + integer_suffix_opt = r'(([uU]ll)|([uU]LL)|(ll[uU]?)|(LL[uU]?)|([uU][lL])|([lL][uU]?)|[uU])?' + decimal_constant = '(0'+integer_suffix_opt+')|([1-9][0-9]*'+integer_suffix_opt+')' + octal_constant = '0[0-7]*'+integer_suffix_opt + hex_constant = hex_prefix+hex_digits+integer_suffix_opt + bin_constant = bin_prefix+bin_digits+integer_suffix_opt + + bad_octal_constant = '0[0-7]*[89]' + + # character constants (K&R2: A.2.5.2) + # Note: a-zA-Z and '.-~^_!=&;,' are allowed as escape chars to support #line + # directives with Windows paths as filenames (..\..\dir\file) + # For the same reason, decimal_escape allows all digit sequences. We want to + # parse all correct code, even if it means to sometimes parse incorrect + # code. + # + # The original regexes were taken verbatim from the C syntax definition, + # and were later modified to avoid worst-case exponential running time. + # + # simple_escape = r"""([a-zA-Z._~!=&\^\-\\?'"])""" + # decimal_escape = r"""(\d+)""" + # hex_escape = r"""(x[0-9a-fA-F]+)""" + # bad_escape = r"""([\\][^a-zA-Z._~^!=&\^\-\\?'"x0-7])""" + # + # The following modifications were made to avoid the ambiguity that allowed backtracking: + # (https://github.com/eliben/pycparser/issues/61) + # + # - \x was removed from simple_escape, unless it was not followed by a hex digit, to avoid ambiguity with hex_escape. + # - hex_escape allows one or more hex characters, but requires that the next character(if any) is not hex + # - decimal_escape allows one or more decimal characters, but requires that the next character(if any) is not a decimal + # - bad_escape does not allow any decimals (8-9), to avoid conflicting with the permissive decimal_escape. + # + # Without this change, python's `re` module would recursively try parsing each ambiguous escape sequence in multiple ways. + # e.g. `\123` could be parsed as `\1`+`23`, `\12`+`3`, and `\123`. + + simple_escape = r"""([a-wyzA-Z._~!=&\^\-\\?'"]|x(?![0-9a-fA-F]))""" + decimal_escape = r"""(\d+)(?!\d)""" + hex_escape = r"""(x[0-9a-fA-F]+)(?![0-9a-fA-F])""" + bad_escape = r"""([\\][^a-zA-Z._~^!=&\^\-\\?'"x0-9])""" + + escape_sequence = r"""(\\("""+simple_escape+'|'+decimal_escape+'|'+hex_escape+'))' + + # This complicated regex with lookahead might be slow for strings, so because all of the valid escapes (including \x) allowed + # 0 or more non-escaped characters after the first character, simple_escape+decimal_escape+hex_escape got simplified to + + escape_sequence_start_in_string = r"""(\\[0-9a-zA-Z._~!=&\^\-\\?'"])""" + + cconst_char = r"""([^'\\\n]|"""+escape_sequence+')' + char_const = "'"+cconst_char+"'" + wchar_const = 'L'+char_const + u8char_const = 'u8'+char_const + u16char_const = 'u'+char_const + u32char_const = 'U'+char_const + multicharacter_constant = "'"+cconst_char+"{2,4}'" + unmatched_quote = "('"+cconst_char+"*\\n)|('"+cconst_char+"*$)" + bad_char_const = r"""('"""+cconst_char+"""[^'\n]+')|('')|('"""+bad_escape+r"""[^'\n]*')""" + + # string literals (K&R2: A.2.6) + string_char = r"""([^"\\\n]|"""+escape_sequence_start_in_string+')' + string_literal = '"'+string_char+'*"' + wstring_literal = 'L'+string_literal + u8string_literal = 'u8'+string_literal + u16string_literal = 'u'+string_literal + u32string_literal = 'U'+string_literal + bad_string_literal = '"'+string_char+'*'+bad_escape+string_char+'*"' + + # floating constants (K&R2: A.2.5.3) + exponent_part = r"""([eE][-+]?[0-9]+)""" + fractional_constant = r"""([0-9]*\.[0-9]+)|([0-9]+\.)""" + floating_constant = '(((('+fractional_constant+')'+exponent_part+'?)|([0-9]+'+exponent_part+'))[FfLl]?)' + binary_exponent_part = r'''([pP][+-]?[0-9]+)''' + hex_fractional_constant = '((('+hex_digits+r""")?\."""+hex_digits+')|('+hex_digits+r"""\.))""" + hex_floating_constant = '('+hex_prefix+'('+hex_digits+'|'+hex_fractional_constant+')'+binary_exponent_part+'[FfLl]?)' + + ## + ## Lexer states: used for preprocessor \n-terminated directives + ## + states = ( + # ppline: preprocessor line directives + # + ('ppline', 'exclusive'), + + # pppragma: pragma + # + ('pppragma', 'exclusive'), + ) + + def t_PPHASH(self, t): + r'[ \t]*\#' + if self.line_pattern.match(t.lexer.lexdata, pos=t.lexer.lexpos): + t.lexer.begin('ppline') + self.pp_line = self.pp_filename = None + elif self.pragma_pattern.match(t.lexer.lexdata, pos=t.lexer.lexpos): + t.lexer.begin('pppragma') + else: + t.type = 'PPHASH' + return t + + ## + ## Rules for the ppline state + ## + @TOKEN(string_literal) + def t_ppline_FILENAME(self, t): + if self.pp_line is None: + self._error('filename before line number in #line', t) + else: + self.pp_filename = t.value.lstrip('"').rstrip('"') + + @TOKEN(decimal_constant) + def t_ppline_LINE_NUMBER(self, t): + if self.pp_line is None: + self.pp_line = t.value + else: + # Ignore: GCC's cpp sometimes inserts a numeric flag + # after the file name + pass + + def t_ppline_NEWLINE(self, t): + r'\n' + if self.pp_line is None: + self._error('line number missing in #line', t) + else: + self.lexer.lineno = int(self.pp_line) + + if self.pp_filename is not None: + self.filename = self.pp_filename + + t.lexer.begin('INITIAL') + + def t_ppline_PPLINE(self, t): + r'line' + pass + + t_ppline_ignore = ' \t' + + def t_ppline_error(self, t): + self._error('invalid #line directive', t) + + ## + ## Rules for the pppragma state + ## + def t_pppragma_NEWLINE(self, t): + r'\n' + t.lexer.lineno += 1 + t.lexer.begin('INITIAL') + + def t_pppragma_PPPRAGMA(self, t): + r'pragma' + return t + + t_pppragma_ignore = ' \t' + + def t_pppragma_STR(self, t): + '.+' + t.type = 'PPPRAGMASTR' + return t + + def t_pppragma_error(self, t): + self._error('invalid #pragma directive', t) + + ## + ## Rules for the normal state + ## + t_ignore = ' \t' + + # Newlines + def t_NEWLINE(self, t): + r'\n+' + t.lexer.lineno += t.value.count("\n") + + # Operators + t_PLUS = r'\+' + t_MINUS = r'-' + t_TIMES = r'\*' + t_DIVIDE = r'/' + t_MOD = r'%' + t_OR = r'\|' + t_AND = r'&' + t_NOT = r'~' + t_XOR = r'\^' + t_LSHIFT = r'<<' + t_RSHIFT = r'>>' + t_LOR = r'\|\|' + t_LAND = r'&&' + t_LNOT = r'!' + t_LT = r'<' + t_GT = r'>' + t_LE = r'<=' + t_GE = r'>=' + t_EQ = r'==' + t_NE = r'!=' + + # Assignment operators + t_EQUALS = r'=' + t_TIMESEQUAL = r'\*=' + t_DIVEQUAL = r'/=' + t_MODEQUAL = r'%=' + t_PLUSEQUAL = r'\+=' + t_MINUSEQUAL = r'-=' + t_LSHIFTEQUAL = r'<<=' + t_RSHIFTEQUAL = r'>>=' + t_ANDEQUAL = r'&=' + t_OREQUAL = r'\|=' + t_XOREQUAL = r'\^=' + + # Increment/decrement + t_PLUSPLUS = r'\+\+' + t_MINUSMINUS = r'--' + + # -> + t_ARROW = r'->' + + # ? + t_CONDOP = r'\?' + + # Delimiters + t_LPAREN = r'\(' + t_RPAREN = r'\)' + t_LBRACKET = r'\[' + t_RBRACKET = r'\]' + t_COMMA = r',' + t_PERIOD = r'\.' + t_SEMI = r';' + t_COLON = r':' + t_ELLIPSIS = r'\.\.\.' + + # Scope delimiters + # To see why on_lbrace_func is needed, consider: + # typedef char TT; + # void foo(int TT) { TT = 10; } + # TT x = 5; + # Outside the function, TT is a typedef, but inside (starting and ending + # with the braces) it's a parameter. The trouble begins with yacc's + # lookahead token. If we open a new scope in brace_open, then TT has + # already been read and incorrectly interpreted as TYPEID. So, we need + # to open and close scopes from within the lexer. + # Similar for the TT immediately outside the end of the function. + # + @TOKEN(r'\{') + def t_LBRACE(self, t): + self.on_lbrace_func() + return t + @TOKEN(r'\}') + def t_RBRACE(self, t): + self.on_rbrace_func() + return t + + t_STRING_LITERAL = string_literal + + # The following floating and integer constants are defined as + # functions to impose a strict order (otherwise, decimal + # is placed before the others because its regex is longer, + # and this is bad) + # + @TOKEN(floating_constant) + def t_FLOAT_CONST(self, t): + return t + + @TOKEN(hex_floating_constant) + def t_HEX_FLOAT_CONST(self, t): + return t + + @TOKEN(hex_constant) + def t_INT_CONST_HEX(self, t): + return t + + @TOKEN(bin_constant) + def t_INT_CONST_BIN(self, t): + return t + + @TOKEN(bad_octal_constant) + def t_BAD_CONST_OCT(self, t): + msg = "Invalid octal constant" + self._error(msg, t) + + @TOKEN(octal_constant) + def t_INT_CONST_OCT(self, t): + return t + + @TOKEN(decimal_constant) + def t_INT_CONST_DEC(self, t): + return t + + # Must come before bad_char_const, to prevent it from + # catching valid char constants as invalid + # + @TOKEN(multicharacter_constant) + def t_INT_CONST_CHAR(self, t): + return t + + @TOKEN(char_const) + def t_CHAR_CONST(self, t): + return t + + @TOKEN(wchar_const) + def t_WCHAR_CONST(self, t): + return t + + @TOKEN(u8char_const) + def t_U8CHAR_CONST(self, t): + return t + + @TOKEN(u16char_const) + def t_U16CHAR_CONST(self, t): + return t + + @TOKEN(u32char_const) + def t_U32CHAR_CONST(self, t): + return t + + @TOKEN(unmatched_quote) + def t_UNMATCHED_QUOTE(self, t): + msg = "Unmatched '" + self._error(msg, t) + + @TOKEN(bad_char_const) + def t_BAD_CHAR_CONST(self, t): + msg = "Invalid char constant %s" % t.value + self._error(msg, t) + + @TOKEN(wstring_literal) + def t_WSTRING_LITERAL(self, t): + return t + + @TOKEN(u8string_literal) + def t_U8STRING_LITERAL(self, t): + return t + + @TOKEN(u16string_literal) + def t_U16STRING_LITERAL(self, t): + return t + + @TOKEN(u32string_literal) + def t_U32STRING_LITERAL(self, t): + return t + + # unmatched string literals are caught by the preprocessor + + @TOKEN(bad_string_literal) + def t_BAD_STRING_LITERAL(self, t): + msg = "String contains invalid escape code" + self._error(msg, t) + + @TOKEN(identifier) + def t_ID(self, t): + t.type = self.keyword_map.get(t.value, "ID") + if t.type == 'ID' and self.type_lookup_func(t.value): + t.type = "TYPEID" + return t + + def t_error(self, t): + msg = 'Illegal character %s' % repr(t.value[0]) + self._error(msg, t) diff --git a/.venv/Lib/site-packages/pycparser/c_parser.py b/.venv/Lib/site-packages/pycparser/c_parser.py new file mode 100644 index 00000000..640a7594 --- /dev/null +++ b/.venv/Lib/site-packages/pycparser/c_parser.py @@ -0,0 +1,1936 @@ +#------------------------------------------------------------------------------ +# pycparser: c_parser.py +# +# CParser class: Parser and AST builder for the C language +# +# Eli Bendersky [https://eli.thegreenplace.net/] +# License: BSD +#------------------------------------------------------------------------------ +from .ply import yacc + +from . import c_ast +from .c_lexer import CLexer +from .plyparser import PLYParser, ParseError, parameterized, template +from .ast_transforms import fix_switch_cases, fix_atomic_specifiers + + +@template +class CParser(PLYParser): + def __init__( + self, + lex_optimize=True, + lexer=CLexer, + lextab='pycparser.lextab', + yacc_optimize=True, + yacctab='pycparser.yacctab', + yacc_debug=False, + taboutputdir=''): + """ Create a new CParser. + + Some arguments for controlling the debug/optimization + level of the parser are provided. The defaults are + tuned for release/performance mode. + The simple rules for using them are: + *) When tweaking CParser/CLexer, set these to False + *) When releasing a stable parser, set to True + + lex_optimize: + Set to False when you're modifying the lexer. + Otherwise, changes in the lexer won't be used, if + some lextab.py file exists. + When releasing with a stable lexer, set to True + to save the re-generation of the lexer table on + each run. + + lexer: + Set this parameter to define the lexer to use if + you're not using the default CLexer. + + lextab: + Points to the lex table that's used for optimized + mode. Only if you're modifying the lexer and want + some tests to avoid re-generating the table, make + this point to a local lex table file (that's been + earlier generated with lex_optimize=True) + + yacc_optimize: + Set to False when you're modifying the parser. + Otherwise, changes in the parser won't be used, if + some parsetab.py file exists. + When releasing with a stable parser, set to True + to save the re-generation of the parser table on + each run. + + yacctab: + Points to the yacc table that's used for optimized + mode. Only if you're modifying the parser, make + this point to a local yacc table file + + yacc_debug: + Generate a parser.out file that explains how yacc + built the parsing table from the grammar. + + taboutputdir: + Set this parameter to control the location of generated + lextab and yacctab files. + """ + self.clex = lexer( + error_func=self._lex_error_func, + on_lbrace_func=self._lex_on_lbrace_func, + on_rbrace_func=self._lex_on_rbrace_func, + type_lookup_func=self._lex_type_lookup_func) + + self.clex.build( + optimize=lex_optimize, + lextab=lextab, + outputdir=taboutputdir) + self.tokens = self.clex.tokens + + rules_with_opt = [ + 'abstract_declarator', + 'assignment_expression', + 'declaration_list', + 'declaration_specifiers_no_type', + 'designation', + 'expression', + 'identifier_list', + 'init_declarator_list', + 'id_init_declarator_list', + 'initializer_list', + 'parameter_type_list', + 'block_item_list', + 'type_qualifier_list', + 'struct_declarator_list' + ] + + for rule in rules_with_opt: + self._create_opt_rule(rule) + + self.cparser = yacc.yacc( + module=self, + start='translation_unit_or_empty', + debug=yacc_debug, + optimize=yacc_optimize, + tabmodule=yacctab, + outputdir=taboutputdir) + + # Stack of scopes for keeping track of symbols. _scope_stack[-1] is + # the current (topmost) scope. Each scope is a dictionary that + # specifies whether a name is a type. If _scope_stack[n][name] is + # True, 'name' is currently a type in the scope. If it's False, + # 'name' is used in the scope but not as a type (for instance, if we + # saw: int name; + # If 'name' is not a key in _scope_stack[n] then 'name' was not defined + # in this scope at all. + self._scope_stack = [dict()] + + # Keeps track of the last token given to yacc (the lookahead token) + self._last_yielded_token = None + + def parse(self, text, filename='', debug=False): + """ Parses C code and returns an AST. + + text: + A string containing the C source code + + filename: + Name of the file being parsed (for meaningful + error messages) + + debug: + Debug flag to YACC + """ + self.clex.filename = filename + self.clex.reset_lineno() + self._scope_stack = [dict()] + self._last_yielded_token = None + return self.cparser.parse( + input=text, + lexer=self.clex, + debug=debug) + + ######################-- PRIVATE --###################### + + def _push_scope(self): + self._scope_stack.append(dict()) + + def _pop_scope(self): + assert len(self._scope_stack) > 1 + self._scope_stack.pop() + + def _add_typedef_name(self, name, coord): + """ Add a new typedef name (ie a TYPEID) to the current scope + """ + if not self._scope_stack[-1].get(name, True): + self._parse_error( + "Typedef %r previously declared as non-typedef " + "in this scope" % name, coord) + self._scope_stack[-1][name] = True + + def _add_identifier(self, name, coord): + """ Add a new object, function, or enum member name (ie an ID) to the + current scope + """ + if self._scope_stack[-1].get(name, False): + self._parse_error( + "Non-typedef %r previously declared as typedef " + "in this scope" % name, coord) + self._scope_stack[-1][name] = False + + def _is_type_in_scope(self, name): + """ Is *name* a typedef-name in the current scope? + """ + for scope in reversed(self._scope_stack): + # If name is an identifier in this scope it shadows typedefs in + # higher scopes. + in_scope = scope.get(name) + if in_scope is not None: return in_scope + return False + + def _lex_error_func(self, msg, line, column): + self._parse_error(msg, self._coord(line, column)) + + def _lex_on_lbrace_func(self): + self._push_scope() + + def _lex_on_rbrace_func(self): + self._pop_scope() + + def _lex_type_lookup_func(self, name): + """ Looks up types that were previously defined with + typedef. + Passed to the lexer for recognizing identifiers that + are types. + """ + is_type = self._is_type_in_scope(name) + return is_type + + def _get_yacc_lookahead_token(self): + """ We need access to yacc's lookahead token in certain cases. + This is the last token yacc requested from the lexer, so we + ask the lexer. + """ + return self.clex.last_token + + # To understand what's going on here, read sections A.8.5 and + # A.8.6 of K&R2 very carefully. + # + # A C type consists of a basic type declaration, with a list + # of modifiers. For example: + # + # int *c[5]; + # + # The basic declaration here is 'int c', and the pointer and + # the array are the modifiers. + # + # Basic declarations are represented by TypeDecl (from module c_ast) and the + # modifiers are FuncDecl, PtrDecl and ArrayDecl. + # + # The standard states that whenever a new modifier is parsed, it should be + # added to the end of the list of modifiers. For example: + # + # K&R2 A.8.6.2: Array Declarators + # + # In a declaration T D where D has the form + # D1 [constant-expression-opt] + # and the type of the identifier in the declaration T D1 is + # "type-modifier T", the type of the + # identifier of D is "type-modifier array of T" + # + # This is what this method does. The declarator it receives + # can be a list of declarators ending with TypeDecl. It + # tacks the modifier to the end of this list, just before + # the TypeDecl. + # + # Additionally, the modifier may be a list itself. This is + # useful for pointers, that can come as a chain from the rule + # p_pointer. In this case, the whole modifier list is spliced + # into the new location. + def _type_modify_decl(self, decl, modifier): + """ Tacks a type modifier on a declarator, and returns + the modified declarator. + + Note: the declarator and modifier may be modified + """ + #~ print '****' + #~ decl.show(offset=3) + #~ modifier.show(offset=3) + #~ print '****' + + modifier_head = modifier + modifier_tail = modifier + + # The modifier may be a nested list. Reach its tail. + while modifier_tail.type: + modifier_tail = modifier_tail.type + + # If the decl is a basic type, just tack the modifier onto it. + if isinstance(decl, c_ast.TypeDecl): + modifier_tail.type = decl + return modifier + else: + # Otherwise, the decl is a list of modifiers. Reach + # its tail and splice the modifier onto the tail, + # pointing to the underlying basic type. + decl_tail = decl + + while not isinstance(decl_tail.type, c_ast.TypeDecl): + decl_tail = decl_tail.type + + modifier_tail.type = decl_tail.type + decl_tail.type = modifier_head + return decl + + # Due to the order in which declarators are constructed, + # they have to be fixed in order to look like a normal AST. + # + # When a declaration arrives from syntax construction, it has + # these problems: + # * The innermost TypeDecl has no type (because the basic + # type is only known at the uppermost declaration level) + # * The declaration has no variable name, since that is saved + # in the innermost TypeDecl + # * The typename of the declaration is a list of type + # specifiers, and not a node. Here, basic identifier types + # should be separated from more complex types like enums + # and structs. + # + # This method fixes these problems. + def _fix_decl_name_type(self, decl, typename): + """ Fixes a declaration. Modifies decl. + """ + # Reach the underlying basic type + # + type = decl + while not isinstance(type, c_ast.TypeDecl): + type = type.type + + decl.name = type.declname + type.quals = decl.quals[:] + + # The typename is a list of types. If any type in this + # list isn't an IdentifierType, it must be the only + # type in the list (it's illegal to declare "int enum ..") + # If all the types are basic, they're collected in the + # IdentifierType holder. + for tn in typename: + if not isinstance(tn, c_ast.IdentifierType): + if len(typename) > 1: + self._parse_error( + "Invalid multiple types specified", tn.coord) + else: + type.type = tn + return decl + + if not typename: + # Functions default to returning int + # + if not isinstance(decl.type, c_ast.FuncDecl): + self._parse_error( + "Missing type in declaration", decl.coord) + type.type = c_ast.IdentifierType( + ['int'], + coord=decl.coord) + else: + # At this point, we know that typename is a list of IdentifierType + # nodes. Concatenate all the names into a single list. + # + type.type = c_ast.IdentifierType( + [name for id in typename for name in id.names], + coord=typename[0].coord) + return decl + + def _add_declaration_specifier(self, declspec, newspec, kind, append=False): + """ Declaration specifiers are represented by a dictionary + with the entries: + * qual: a list of type qualifiers + * storage: a list of storage type qualifiers + * type: a list of type specifiers + * function: a list of function specifiers + * alignment: a list of alignment specifiers + + This method is given a declaration specifier, and a + new specifier of a given kind. + If `append` is True, the new specifier is added to the end of + the specifiers list, otherwise it's added at the beginning. + Returns the declaration specifier, with the new + specifier incorporated. + """ + spec = declspec or dict(qual=[], storage=[], type=[], function=[], alignment=[]) + + if append: + spec[kind].append(newspec) + else: + spec[kind].insert(0, newspec) + + return spec + + def _build_declarations(self, spec, decls, typedef_namespace=False): + """ Builds a list of declarations all sharing the given specifiers. + If typedef_namespace is true, each declared name is added + to the "typedef namespace", which also includes objects, + functions, and enum constants. + """ + is_typedef = 'typedef' in spec['storage'] + declarations = [] + + # Bit-fields are allowed to be unnamed. + if decls[0].get('bitsize') is not None: + pass + + # When redeclaring typedef names as identifiers in inner scopes, a + # problem can occur where the identifier gets grouped into + # spec['type'], leaving decl as None. This can only occur for the + # first declarator. + elif decls[0]['decl'] is None: + if len(spec['type']) < 2 or len(spec['type'][-1].names) != 1 or \ + not self._is_type_in_scope(spec['type'][-1].names[0]): + coord = '?' + for t in spec['type']: + if hasattr(t, 'coord'): + coord = t.coord + break + self._parse_error('Invalid declaration', coord) + + # Make this look as if it came from "direct_declarator:ID" + decls[0]['decl'] = c_ast.TypeDecl( + declname=spec['type'][-1].names[0], + type=None, + quals=None, + align=spec['alignment'], + coord=spec['type'][-1].coord) + # Remove the "new" type's name from the end of spec['type'] + del spec['type'][-1] + + # A similar problem can occur where the declaration ends up looking + # like an abstract declarator. Give it a name if this is the case. + elif not isinstance(decls[0]['decl'], ( + c_ast.Enum, c_ast.Struct, c_ast.Union, c_ast.IdentifierType)): + decls_0_tail = decls[0]['decl'] + while not isinstance(decls_0_tail, c_ast.TypeDecl): + decls_0_tail = decls_0_tail.type + if decls_0_tail.declname is None: + decls_0_tail.declname = spec['type'][-1].names[0] + del spec['type'][-1] + + for decl in decls: + assert decl['decl'] is not None + if is_typedef: + declaration = c_ast.Typedef( + name=None, + quals=spec['qual'], + storage=spec['storage'], + type=decl['decl'], + coord=decl['decl'].coord) + else: + declaration = c_ast.Decl( + name=None, + quals=spec['qual'], + align=spec['alignment'], + storage=spec['storage'], + funcspec=spec['function'], + type=decl['decl'], + init=decl.get('init'), + bitsize=decl.get('bitsize'), + coord=decl['decl'].coord) + + if isinstance(declaration.type, ( + c_ast.Enum, c_ast.Struct, c_ast.Union, + c_ast.IdentifierType)): + fixed_decl = declaration + else: + fixed_decl = self._fix_decl_name_type(declaration, spec['type']) + + # Add the type name defined by typedef to a + # symbol table (for usage in the lexer) + if typedef_namespace: + if is_typedef: + self._add_typedef_name(fixed_decl.name, fixed_decl.coord) + else: + self._add_identifier(fixed_decl.name, fixed_decl.coord) + + fixed_decl = fix_atomic_specifiers(fixed_decl) + declarations.append(fixed_decl) + + return declarations + + def _build_function_definition(self, spec, decl, param_decls, body): + """ Builds a function definition. + """ + if 'typedef' in spec['storage']: + self._parse_error("Invalid typedef", decl.coord) + + declaration = self._build_declarations( + spec=spec, + decls=[dict(decl=decl, init=None)], + typedef_namespace=True)[0] + + return c_ast.FuncDef( + decl=declaration, + param_decls=param_decls, + body=body, + coord=decl.coord) + + def _select_struct_union_class(self, token): + """ Given a token (either STRUCT or UNION), selects the + appropriate AST class. + """ + if token == 'struct': + return c_ast.Struct + else: + return c_ast.Union + + ## + ## Precedence and associativity of operators + ## + # If this changes, c_generator.CGenerator.precedence_map needs to change as + # well + precedence = ( + ('left', 'LOR'), + ('left', 'LAND'), + ('left', 'OR'), + ('left', 'XOR'), + ('left', 'AND'), + ('left', 'EQ', 'NE'), + ('left', 'GT', 'GE', 'LT', 'LE'), + ('left', 'RSHIFT', 'LSHIFT'), + ('left', 'PLUS', 'MINUS'), + ('left', 'TIMES', 'DIVIDE', 'MOD') + ) + + ## + ## Grammar productions + ## Implementation of the BNF defined in K&R2 A.13 + ## + + # Wrapper around a translation unit, to allow for empty input. + # Not strictly part of the C99 Grammar, but useful in practice. + def p_translation_unit_or_empty(self, p): + """ translation_unit_or_empty : translation_unit + | empty + """ + if p[1] is None: + p[0] = c_ast.FileAST([]) + else: + p[0] = c_ast.FileAST(p[1]) + + def p_translation_unit_1(self, p): + """ translation_unit : external_declaration + """ + # Note: external_declaration is already a list + p[0] = p[1] + + def p_translation_unit_2(self, p): + """ translation_unit : translation_unit external_declaration + """ + p[1].extend(p[2]) + p[0] = p[1] + + # Declarations always come as lists (because they can be + # several in one line), so we wrap the function definition + # into a list as well, to make the return value of + # external_declaration homogeneous. + def p_external_declaration_1(self, p): + """ external_declaration : function_definition + """ + p[0] = [p[1]] + + def p_external_declaration_2(self, p): + """ external_declaration : declaration + """ + p[0] = p[1] + + def p_external_declaration_3(self, p): + """ external_declaration : pp_directive + | pppragma_directive + """ + p[0] = [p[1]] + + def p_external_declaration_4(self, p): + """ external_declaration : SEMI + """ + p[0] = [] + + def p_external_declaration_5(self, p): + """ external_declaration : static_assert + """ + p[0] = p[1] + + def p_static_assert_declaration(self, p): + """ static_assert : _STATIC_ASSERT LPAREN constant_expression COMMA unified_string_literal RPAREN + | _STATIC_ASSERT LPAREN constant_expression RPAREN + """ + if len(p) == 5: + p[0] = [c_ast.StaticAssert(p[3], None, self._token_coord(p, 1))] + else: + p[0] = [c_ast.StaticAssert(p[3], p[5], self._token_coord(p, 1))] + + def p_pp_directive(self, p): + """ pp_directive : PPHASH + """ + self._parse_error('Directives not supported yet', + self._token_coord(p, 1)) + + def p_pppragma_directive(self, p): + """ pppragma_directive : PPPRAGMA + | PPPRAGMA PPPRAGMASTR + """ + if len(p) == 3: + p[0] = c_ast.Pragma(p[2], self._token_coord(p, 2)) + else: + p[0] = c_ast.Pragma("", self._token_coord(p, 1)) + + # In function definitions, the declarator can be followed by + # a declaration list, for old "K&R style" function definitios. + def p_function_definition_1(self, p): + """ function_definition : id_declarator declaration_list_opt compound_statement + """ + # no declaration specifiers - 'int' becomes the default type + spec = dict( + qual=[], + alignment=[], + storage=[], + type=[c_ast.IdentifierType(['int'], + coord=self._token_coord(p, 1))], + function=[]) + + p[0] = self._build_function_definition( + spec=spec, + decl=p[1], + param_decls=p[2], + body=p[3]) + + def p_function_definition_2(self, p): + """ function_definition : declaration_specifiers id_declarator declaration_list_opt compound_statement + """ + spec = p[1] + + p[0] = self._build_function_definition( + spec=spec, + decl=p[2], + param_decls=p[3], + body=p[4]) + + # Note, according to C18 A.2.2 6.7.10 static_assert-declaration _Static_assert + # is a declaration, not a statement. We additionally recognise it as a statement + # to fix parsing of _Static_assert inside the functions. + # + def p_statement(self, p): + """ statement : labeled_statement + | expression_statement + | compound_statement + | selection_statement + | iteration_statement + | jump_statement + | pppragma_directive + | static_assert + """ + p[0] = p[1] + + # A pragma is generally considered a decorator rather than an actual + # statement. Still, for the purposes of analyzing an abstract syntax tree of + # C code, pragma's should not be ignored and were previously treated as a + # statement. This presents a problem for constructs that take a statement + # such as labeled_statements, selection_statements, and + # iteration_statements, causing a misleading structure in the AST. For + # example, consider the following C code. + # + # for (int i = 0; i < 3; i++) + # #pragma omp critical + # sum += 1; + # + # This code will compile and execute "sum += 1;" as the body of the for + # loop. Previous implementations of PyCParser would render the AST for this + # block of code as follows: + # + # For: + # DeclList: + # Decl: i, [], [], [] + # TypeDecl: i, [] + # IdentifierType: ['int'] + # Constant: int, 0 + # BinaryOp: < + # ID: i + # Constant: int, 3 + # UnaryOp: p++ + # ID: i + # Pragma: omp critical + # Assignment: += + # ID: sum + # Constant: int, 1 + # + # This AST misleadingly takes the Pragma as the body of the loop and the + # assignment then becomes a sibling of the loop. + # + # To solve edge cases like these, the pragmacomp_or_statement rule groups + # a pragma and its following statement (which would otherwise be orphaned) + # using a compound block, effectively turning the above code into: + # + # for (int i = 0; i < 3; i++) { + # #pragma omp critical + # sum += 1; + # } + def p_pragmacomp_or_statement(self, p): + """ pragmacomp_or_statement : pppragma_directive statement + | statement + """ + if isinstance(p[1], c_ast.Pragma) and len(p) == 3: + p[0] = c_ast.Compound( + block_items=[p[1], p[2]], + coord=self._token_coord(p, 1)) + else: + p[0] = p[1] + + # In C, declarations can come several in a line: + # int x, *px, romulo = 5; + # + # However, for the AST, we will split them to separate Decl + # nodes. + # + # This rule splits its declarations and always returns a list + # of Decl nodes, even if it's one element long. + # + def p_decl_body(self, p): + """ decl_body : declaration_specifiers init_declarator_list_opt + | declaration_specifiers_no_type id_init_declarator_list_opt + """ + spec = p[1] + + # p[2] (init_declarator_list_opt) is either a list or None + # + if p[2] is None: + # By the standard, you must have at least one declarator unless + # declaring a structure tag, a union tag, or the members of an + # enumeration. + # + ty = spec['type'] + s_u_or_e = (c_ast.Struct, c_ast.Union, c_ast.Enum) + if len(ty) == 1 and isinstance(ty[0], s_u_or_e): + decls = [c_ast.Decl( + name=None, + quals=spec['qual'], + align=spec['alignment'], + storage=spec['storage'], + funcspec=spec['function'], + type=ty[0], + init=None, + bitsize=None, + coord=ty[0].coord)] + + # However, this case can also occur on redeclared identifiers in + # an inner scope. The trouble is that the redeclared type's name + # gets grouped into declaration_specifiers; _build_declarations + # compensates for this. + # + else: + decls = self._build_declarations( + spec=spec, + decls=[dict(decl=None, init=None)], + typedef_namespace=True) + + else: + decls = self._build_declarations( + spec=spec, + decls=p[2], + typedef_namespace=True) + + p[0] = decls + + # The declaration has been split to a decl_body sub-rule and + # SEMI, because having them in a single rule created a problem + # for defining typedefs. + # + # If a typedef line was directly followed by a line using the + # type defined with the typedef, the type would not be + # recognized. This is because to reduce the declaration rule, + # the parser's lookahead asked for the token after SEMI, which + # was the type from the next line, and the lexer had no chance + # to see the updated type symbol table. + # + # Splitting solves this problem, because after seeing SEMI, + # the parser reduces decl_body, which actually adds the new + # type into the table to be seen by the lexer before the next + # line is reached. + def p_declaration(self, p): + """ declaration : decl_body SEMI + """ + p[0] = p[1] + + # Since each declaration is a list of declarations, this + # rule will combine all the declarations and return a single + # list + # + def p_declaration_list(self, p): + """ declaration_list : declaration + | declaration_list declaration + """ + p[0] = p[1] if len(p) == 2 else p[1] + p[2] + + # To know when declaration-specifiers end and declarators begin, + # we require declaration-specifiers to have at least one + # type-specifier, and disallow typedef-names after we've seen any + # type-specifier. These are both required by the spec. + # + def p_declaration_specifiers_no_type_1(self, p): + """ declaration_specifiers_no_type : type_qualifier declaration_specifiers_no_type_opt + """ + p[0] = self._add_declaration_specifier(p[2], p[1], 'qual') + + def p_declaration_specifiers_no_type_2(self, p): + """ declaration_specifiers_no_type : storage_class_specifier declaration_specifiers_no_type_opt + """ + p[0] = self._add_declaration_specifier(p[2], p[1], 'storage') + + def p_declaration_specifiers_no_type_3(self, p): + """ declaration_specifiers_no_type : function_specifier declaration_specifiers_no_type_opt + """ + p[0] = self._add_declaration_specifier(p[2], p[1], 'function') + + # Without this, `typedef _Atomic(T) U` will parse incorrectly because the + # _Atomic qualifier will match, instead of the specifier. + def p_declaration_specifiers_no_type_4(self, p): + """ declaration_specifiers_no_type : atomic_specifier declaration_specifiers_no_type_opt + """ + p[0] = self._add_declaration_specifier(p[2], p[1], 'type') + + def p_declaration_specifiers_no_type_5(self, p): + """ declaration_specifiers_no_type : alignment_specifier declaration_specifiers_no_type_opt + """ + p[0] = self._add_declaration_specifier(p[2], p[1], 'alignment') + + def p_declaration_specifiers_1(self, p): + """ declaration_specifiers : declaration_specifiers type_qualifier + """ + p[0] = self._add_declaration_specifier(p[1], p[2], 'qual', append=True) + + def p_declaration_specifiers_2(self, p): + """ declaration_specifiers : declaration_specifiers storage_class_specifier + """ + p[0] = self._add_declaration_specifier(p[1], p[2], 'storage', append=True) + + def p_declaration_specifiers_3(self, p): + """ declaration_specifiers : declaration_specifiers function_specifier + """ + p[0] = self._add_declaration_specifier(p[1], p[2], 'function', append=True) + + def p_declaration_specifiers_4(self, p): + """ declaration_specifiers : declaration_specifiers type_specifier_no_typeid + """ + p[0] = self._add_declaration_specifier(p[1], p[2], 'type', append=True) + + def p_declaration_specifiers_5(self, p): + """ declaration_specifiers : type_specifier + """ + p[0] = self._add_declaration_specifier(None, p[1], 'type') + + def p_declaration_specifiers_6(self, p): + """ declaration_specifiers : declaration_specifiers_no_type type_specifier + """ + p[0] = self._add_declaration_specifier(p[1], p[2], 'type', append=True) + + def p_declaration_specifiers_7(self, p): + """ declaration_specifiers : declaration_specifiers alignment_specifier + """ + p[0] = self._add_declaration_specifier(p[1], p[2], 'alignment', append=True) + + def p_storage_class_specifier(self, p): + """ storage_class_specifier : AUTO + | REGISTER + | STATIC + | EXTERN + | TYPEDEF + | _THREAD_LOCAL + """ + p[0] = p[1] + + def p_function_specifier(self, p): + """ function_specifier : INLINE + | _NORETURN + """ + p[0] = p[1] + + def p_type_specifier_no_typeid(self, p): + """ type_specifier_no_typeid : VOID + | _BOOL + | CHAR + | SHORT + | INT + | LONG + | FLOAT + | DOUBLE + | _COMPLEX + | SIGNED + | UNSIGNED + | __INT128 + """ + p[0] = c_ast.IdentifierType([p[1]], coord=self._token_coord(p, 1)) + + def p_type_specifier(self, p): + """ type_specifier : typedef_name + | enum_specifier + | struct_or_union_specifier + | type_specifier_no_typeid + | atomic_specifier + """ + p[0] = p[1] + + # See section 6.7.2.4 of the C11 standard. + def p_atomic_specifier(self, p): + """ atomic_specifier : _ATOMIC LPAREN type_name RPAREN + """ + typ = p[3] + typ.quals.append('_Atomic') + p[0] = typ + + def p_type_qualifier(self, p): + """ type_qualifier : CONST + | RESTRICT + | VOLATILE + | _ATOMIC + """ + p[0] = p[1] + + def p_init_declarator_list(self, p): + """ init_declarator_list : init_declarator + | init_declarator_list COMMA init_declarator + """ + p[0] = p[1] + [p[3]] if len(p) == 4 else [p[1]] + + # Returns a {decl= : init=} dictionary + # If there's no initializer, uses None + # + def p_init_declarator(self, p): + """ init_declarator : declarator + | declarator EQUALS initializer + """ + p[0] = dict(decl=p[1], init=(p[3] if len(p) > 2 else None)) + + def p_id_init_declarator_list(self, p): + """ id_init_declarator_list : id_init_declarator + | id_init_declarator_list COMMA init_declarator + """ + p[0] = p[1] + [p[3]] if len(p) == 4 else [p[1]] + + def p_id_init_declarator(self, p): + """ id_init_declarator : id_declarator + | id_declarator EQUALS initializer + """ + p[0] = dict(decl=p[1], init=(p[3] if len(p) > 2 else None)) + + # Require at least one type specifier in a specifier-qualifier-list + # + def p_specifier_qualifier_list_1(self, p): + """ specifier_qualifier_list : specifier_qualifier_list type_specifier_no_typeid + """ + p[0] = self._add_declaration_specifier(p[1], p[2], 'type', append=True) + + def p_specifier_qualifier_list_2(self, p): + """ specifier_qualifier_list : specifier_qualifier_list type_qualifier + """ + p[0] = self._add_declaration_specifier(p[1], p[2], 'qual', append=True) + + def p_specifier_qualifier_list_3(self, p): + """ specifier_qualifier_list : type_specifier + """ + p[0] = self._add_declaration_specifier(None, p[1], 'type') + + def p_specifier_qualifier_list_4(self, p): + """ specifier_qualifier_list : type_qualifier_list type_specifier + """ + p[0] = dict(qual=p[1], alignment=[], storage=[], type=[p[2]], function=[]) + + def p_specifier_qualifier_list_5(self, p): + """ specifier_qualifier_list : alignment_specifier + """ + p[0] = dict(qual=[], alignment=[p[1]], storage=[], type=[], function=[]) + + def p_specifier_qualifier_list_6(self, p): + """ specifier_qualifier_list : specifier_qualifier_list alignment_specifier + """ + p[0] = self._add_declaration_specifier(p[1], p[2], 'alignment') + + # TYPEID is allowed here (and in other struct/enum related tag names), because + # struct/enum tags reside in their own namespace and can be named the same as types + # + def p_struct_or_union_specifier_1(self, p): + """ struct_or_union_specifier : struct_or_union ID + | struct_or_union TYPEID + """ + klass = self._select_struct_union_class(p[1]) + # None means no list of members + p[0] = klass( + name=p[2], + decls=None, + coord=self._token_coord(p, 2)) + + def p_struct_or_union_specifier_2(self, p): + """ struct_or_union_specifier : struct_or_union brace_open struct_declaration_list brace_close + | struct_or_union brace_open brace_close + """ + klass = self._select_struct_union_class(p[1]) + if len(p) == 4: + # Empty sequence means an empty list of members + p[0] = klass( + name=None, + decls=[], + coord=self._token_coord(p, 2)) + else: + p[0] = klass( + name=None, + decls=p[3], + coord=self._token_coord(p, 2)) + + + def p_struct_or_union_specifier_3(self, p): + """ struct_or_union_specifier : struct_or_union ID brace_open struct_declaration_list brace_close + | struct_or_union ID brace_open brace_close + | struct_or_union TYPEID brace_open struct_declaration_list brace_close + | struct_or_union TYPEID brace_open brace_close + """ + klass = self._select_struct_union_class(p[1]) + if len(p) == 5: + # Empty sequence means an empty list of members + p[0] = klass( + name=p[2], + decls=[], + coord=self._token_coord(p, 2)) + else: + p[0] = klass( + name=p[2], + decls=p[4], + coord=self._token_coord(p, 2)) + + def p_struct_or_union(self, p): + """ struct_or_union : STRUCT + | UNION + """ + p[0] = p[1] + + # Combine all declarations into a single list + # + def p_struct_declaration_list(self, p): + """ struct_declaration_list : struct_declaration + | struct_declaration_list struct_declaration + """ + if len(p) == 2: + p[0] = p[1] or [] + else: + p[0] = p[1] + (p[2] or []) + + def p_struct_declaration_1(self, p): + """ struct_declaration : specifier_qualifier_list struct_declarator_list_opt SEMI + """ + spec = p[1] + assert 'typedef' not in spec['storage'] + + if p[2] is not None: + decls = self._build_declarations( + spec=spec, + decls=p[2]) + + elif len(spec['type']) == 1: + # Anonymous struct/union, gcc extension, C1x feature. + # Although the standard only allows structs/unions here, I see no + # reason to disallow other types since some compilers have typedefs + # here, and pycparser isn't about rejecting all invalid code. + # + node = spec['type'][0] + if isinstance(node, c_ast.Node): + decl_type = node + else: + decl_type = c_ast.IdentifierType(node) + + decls = self._build_declarations( + spec=spec, + decls=[dict(decl=decl_type)]) + + else: + # Structure/union members can have the same names as typedefs. + # The trouble is that the member's name gets grouped into + # specifier_qualifier_list; _build_declarations compensates. + # + decls = self._build_declarations( + spec=spec, + decls=[dict(decl=None, init=None)]) + + p[0] = decls + + def p_struct_declaration_2(self, p): + """ struct_declaration : SEMI + """ + p[0] = None + + def p_struct_declaration_3(self, p): + """ struct_declaration : pppragma_directive + """ + p[0] = [p[1]] + + def p_struct_declarator_list(self, p): + """ struct_declarator_list : struct_declarator + | struct_declarator_list COMMA struct_declarator + """ + p[0] = p[1] + [p[3]] if len(p) == 4 else [p[1]] + + # struct_declarator passes up a dict with the keys: decl (for + # the underlying declarator) and bitsize (for the bitsize) + # + def p_struct_declarator_1(self, p): + """ struct_declarator : declarator + """ + p[0] = {'decl': p[1], 'bitsize': None} + + def p_struct_declarator_2(self, p): + """ struct_declarator : declarator COLON constant_expression + | COLON constant_expression + """ + if len(p) > 3: + p[0] = {'decl': p[1], 'bitsize': p[3]} + else: + p[0] = {'decl': c_ast.TypeDecl(None, None, None, None), 'bitsize': p[2]} + + def p_enum_specifier_1(self, p): + """ enum_specifier : ENUM ID + | ENUM TYPEID + """ + p[0] = c_ast.Enum(p[2], None, self._token_coord(p, 1)) + + def p_enum_specifier_2(self, p): + """ enum_specifier : ENUM brace_open enumerator_list brace_close + """ + p[0] = c_ast.Enum(None, p[3], self._token_coord(p, 1)) + + def p_enum_specifier_3(self, p): + """ enum_specifier : ENUM ID brace_open enumerator_list brace_close + | ENUM TYPEID brace_open enumerator_list brace_close + """ + p[0] = c_ast.Enum(p[2], p[4], self._token_coord(p, 1)) + + def p_enumerator_list(self, p): + """ enumerator_list : enumerator + | enumerator_list COMMA + | enumerator_list COMMA enumerator + """ + if len(p) == 2: + p[0] = c_ast.EnumeratorList([p[1]], p[1].coord) + elif len(p) == 3: + p[0] = p[1] + else: + p[1].enumerators.append(p[3]) + p[0] = p[1] + + def p_alignment_specifier(self, p): + """ alignment_specifier : _ALIGNAS LPAREN type_name RPAREN + | _ALIGNAS LPAREN constant_expression RPAREN + """ + p[0] = c_ast.Alignas(p[3], self._token_coord(p, 1)) + + def p_enumerator(self, p): + """ enumerator : ID + | ID EQUALS constant_expression + """ + if len(p) == 2: + enumerator = c_ast.Enumerator( + p[1], None, + self._token_coord(p, 1)) + else: + enumerator = c_ast.Enumerator( + p[1], p[3], + self._token_coord(p, 1)) + self._add_identifier(enumerator.name, enumerator.coord) + + p[0] = enumerator + + def p_declarator(self, p): + """ declarator : id_declarator + | typeid_declarator + """ + p[0] = p[1] + + @parameterized(('id', 'ID'), ('typeid', 'TYPEID'), ('typeid_noparen', 'TYPEID')) + def p_xxx_declarator_1(self, p): + """ xxx_declarator : direct_xxx_declarator + """ + p[0] = p[1] + + @parameterized(('id', 'ID'), ('typeid', 'TYPEID'), ('typeid_noparen', 'TYPEID')) + def p_xxx_declarator_2(self, p): + """ xxx_declarator : pointer direct_xxx_declarator + """ + p[0] = self._type_modify_decl(p[2], p[1]) + + @parameterized(('id', 'ID'), ('typeid', 'TYPEID'), ('typeid_noparen', 'TYPEID')) + def p_direct_xxx_declarator_1(self, p): + """ direct_xxx_declarator : yyy + """ + p[0] = c_ast.TypeDecl( + declname=p[1], + type=None, + quals=None, + align=None, + coord=self._token_coord(p, 1)) + + @parameterized(('id', 'ID'), ('typeid', 'TYPEID')) + def p_direct_xxx_declarator_2(self, p): + """ direct_xxx_declarator : LPAREN xxx_declarator RPAREN + """ + p[0] = p[2] + + @parameterized(('id', 'ID'), ('typeid', 'TYPEID'), ('typeid_noparen', 'TYPEID')) + def p_direct_xxx_declarator_3(self, p): + """ direct_xxx_declarator : direct_xxx_declarator LBRACKET type_qualifier_list_opt assignment_expression_opt RBRACKET + """ + quals = (p[3] if len(p) > 5 else []) or [] + # Accept dimension qualifiers + # Per C99 6.7.5.3 p7 + arr = c_ast.ArrayDecl( + type=None, + dim=p[4] if len(p) > 5 else p[3], + dim_quals=quals, + coord=p[1].coord) + + p[0] = self._type_modify_decl(decl=p[1], modifier=arr) + + @parameterized(('id', 'ID'), ('typeid', 'TYPEID'), ('typeid_noparen', 'TYPEID')) + def p_direct_xxx_declarator_4(self, p): + """ direct_xxx_declarator : direct_xxx_declarator LBRACKET STATIC type_qualifier_list_opt assignment_expression RBRACKET + | direct_xxx_declarator LBRACKET type_qualifier_list STATIC assignment_expression RBRACKET + """ + # Using slice notation for PLY objects doesn't work in Python 3 for the + # version of PLY embedded with pycparser; see PLY Google Code issue 30. + # Work around that here by listing the two elements separately. + listed_quals = [item if isinstance(item, list) else [item] + for item in [p[3],p[4]]] + dim_quals = [qual for sublist in listed_quals for qual in sublist + if qual is not None] + arr = c_ast.ArrayDecl( + type=None, + dim=p[5], + dim_quals=dim_quals, + coord=p[1].coord) + + p[0] = self._type_modify_decl(decl=p[1], modifier=arr) + + # Special for VLAs + # + @parameterized(('id', 'ID'), ('typeid', 'TYPEID'), ('typeid_noparen', 'TYPEID')) + def p_direct_xxx_declarator_5(self, p): + """ direct_xxx_declarator : direct_xxx_declarator LBRACKET type_qualifier_list_opt TIMES RBRACKET + """ + arr = c_ast.ArrayDecl( + type=None, + dim=c_ast.ID(p[4], self._token_coord(p, 4)), + dim_quals=p[3] if p[3] is not None else [], + coord=p[1].coord) + + p[0] = self._type_modify_decl(decl=p[1], modifier=arr) + + @parameterized(('id', 'ID'), ('typeid', 'TYPEID'), ('typeid_noparen', 'TYPEID')) + def p_direct_xxx_declarator_6(self, p): + """ direct_xxx_declarator : direct_xxx_declarator LPAREN parameter_type_list RPAREN + | direct_xxx_declarator LPAREN identifier_list_opt RPAREN + """ + func = c_ast.FuncDecl( + args=p[3], + type=None, + coord=p[1].coord) + + # To see why _get_yacc_lookahead_token is needed, consider: + # typedef char TT; + # void foo(int TT) { TT = 10; } + # Outside the function, TT is a typedef, but inside (starting and + # ending with the braces) it's a parameter. The trouble begins with + # yacc's lookahead token. We don't know if we're declaring or + # defining a function until we see LBRACE, but if we wait for yacc to + # trigger a rule on that token, then TT will have already been read + # and incorrectly interpreted as TYPEID. We need to add the + # parameters to the scope the moment the lexer sees LBRACE. + # + if self._get_yacc_lookahead_token().type == "LBRACE": + if func.args is not None: + for param in func.args.params: + if isinstance(param, c_ast.EllipsisParam): break + self._add_identifier(param.name, param.coord) + + p[0] = self._type_modify_decl(decl=p[1], modifier=func) + + def p_pointer(self, p): + """ pointer : TIMES type_qualifier_list_opt + | TIMES type_qualifier_list_opt pointer + """ + coord = self._token_coord(p, 1) + # Pointer decls nest from inside out. This is important when different + # levels have different qualifiers. For example: + # + # char * const * p; + # + # Means "pointer to const pointer to char" + # + # While: + # + # char ** const p; + # + # Means "const pointer to pointer to char" + # + # So when we construct PtrDecl nestings, the leftmost pointer goes in + # as the most nested type. + nested_type = c_ast.PtrDecl(quals=p[2] or [], type=None, coord=coord) + if len(p) > 3: + tail_type = p[3] + while tail_type.type is not None: + tail_type = tail_type.type + tail_type.type = nested_type + p[0] = p[3] + else: + p[0] = nested_type + + def p_type_qualifier_list(self, p): + """ type_qualifier_list : type_qualifier + | type_qualifier_list type_qualifier + """ + p[0] = [p[1]] if len(p) == 2 else p[1] + [p[2]] + + def p_parameter_type_list(self, p): + """ parameter_type_list : parameter_list + | parameter_list COMMA ELLIPSIS + """ + if len(p) > 2: + p[1].params.append(c_ast.EllipsisParam(self._token_coord(p, 3))) + + p[0] = p[1] + + def p_parameter_list(self, p): + """ parameter_list : parameter_declaration + | parameter_list COMMA parameter_declaration + """ + if len(p) == 2: # single parameter + p[0] = c_ast.ParamList([p[1]], p[1].coord) + else: + p[1].params.append(p[3]) + p[0] = p[1] + + # From ISO/IEC 9899:TC2, 6.7.5.3.11: + # "If, in a parameter declaration, an identifier can be treated either + # as a typedef name or as a parameter name, it shall be taken as a + # typedef name." + # + # Inside a parameter declaration, once we've reduced declaration specifiers, + # if we shift in an LPAREN and see a TYPEID, it could be either an abstract + # declarator or a declarator nested inside parens. This rule tells us to + # always treat it as an abstract declarator. Therefore, we only accept + # `id_declarator`s and `typeid_noparen_declarator`s. + def p_parameter_declaration_1(self, p): + """ parameter_declaration : declaration_specifiers id_declarator + | declaration_specifiers typeid_noparen_declarator + """ + spec = p[1] + if not spec['type']: + spec['type'] = [c_ast.IdentifierType(['int'], + coord=self._token_coord(p, 1))] + p[0] = self._build_declarations( + spec=spec, + decls=[dict(decl=p[2])])[0] + + def p_parameter_declaration_2(self, p): + """ parameter_declaration : declaration_specifiers abstract_declarator_opt + """ + spec = p[1] + if not spec['type']: + spec['type'] = [c_ast.IdentifierType(['int'], + coord=self._token_coord(p, 1))] + + # Parameters can have the same names as typedefs. The trouble is that + # the parameter's name gets grouped into declaration_specifiers, making + # it look like an old-style declaration; compensate. + # + if len(spec['type']) > 1 and len(spec['type'][-1].names) == 1 and \ + self._is_type_in_scope(spec['type'][-1].names[0]): + decl = self._build_declarations( + spec=spec, + decls=[dict(decl=p[2], init=None)])[0] + + # This truly is an old-style parameter declaration + # + else: + decl = c_ast.Typename( + name='', + quals=spec['qual'], + align=None, + type=p[2] or c_ast.TypeDecl(None, None, None, None), + coord=self._token_coord(p, 2)) + typename = spec['type'] + decl = self._fix_decl_name_type(decl, typename) + + p[0] = decl + + def p_identifier_list(self, p): + """ identifier_list : identifier + | identifier_list COMMA identifier + """ + if len(p) == 2: # single parameter + p[0] = c_ast.ParamList([p[1]], p[1].coord) + else: + p[1].params.append(p[3]) + p[0] = p[1] + + def p_initializer_1(self, p): + """ initializer : assignment_expression + """ + p[0] = p[1] + + def p_initializer_2(self, p): + """ initializer : brace_open initializer_list_opt brace_close + | brace_open initializer_list COMMA brace_close + """ + if p[2] is None: + p[0] = c_ast.InitList([], self._token_coord(p, 1)) + else: + p[0] = p[2] + + def p_initializer_list(self, p): + """ initializer_list : designation_opt initializer + | initializer_list COMMA designation_opt initializer + """ + if len(p) == 3: # single initializer + init = p[2] if p[1] is None else c_ast.NamedInitializer(p[1], p[2]) + p[0] = c_ast.InitList([init], p[2].coord) + else: + init = p[4] if p[3] is None else c_ast.NamedInitializer(p[3], p[4]) + p[1].exprs.append(init) + p[0] = p[1] + + def p_designation(self, p): + """ designation : designator_list EQUALS + """ + p[0] = p[1] + + # Designators are represented as a list of nodes, in the order in which + # they're written in the code. + # + def p_designator_list(self, p): + """ designator_list : designator + | designator_list designator + """ + p[0] = [p[1]] if len(p) == 2 else p[1] + [p[2]] + + def p_designator(self, p): + """ designator : LBRACKET constant_expression RBRACKET + | PERIOD identifier + """ + p[0] = p[2] + + def p_type_name(self, p): + """ type_name : specifier_qualifier_list abstract_declarator_opt + """ + typename = c_ast.Typename( + name='', + quals=p[1]['qual'][:], + align=None, + type=p[2] or c_ast.TypeDecl(None, None, None, None), + coord=self._token_coord(p, 2)) + + p[0] = self._fix_decl_name_type(typename, p[1]['type']) + + def p_abstract_declarator_1(self, p): + """ abstract_declarator : pointer + """ + dummytype = c_ast.TypeDecl(None, None, None, None) + p[0] = self._type_modify_decl( + decl=dummytype, + modifier=p[1]) + + def p_abstract_declarator_2(self, p): + """ abstract_declarator : pointer direct_abstract_declarator + """ + p[0] = self._type_modify_decl(p[2], p[1]) + + def p_abstract_declarator_3(self, p): + """ abstract_declarator : direct_abstract_declarator + """ + p[0] = p[1] + + # Creating and using direct_abstract_declarator_opt here + # instead of listing both direct_abstract_declarator and the + # lack of it in the beginning of _1 and _2 caused two + # shift/reduce errors. + # + def p_direct_abstract_declarator_1(self, p): + """ direct_abstract_declarator : LPAREN abstract_declarator RPAREN """ + p[0] = p[2] + + def p_direct_abstract_declarator_2(self, p): + """ direct_abstract_declarator : direct_abstract_declarator LBRACKET assignment_expression_opt RBRACKET + """ + arr = c_ast.ArrayDecl( + type=None, + dim=p[3], + dim_quals=[], + coord=p[1].coord) + + p[0] = self._type_modify_decl(decl=p[1], modifier=arr) + + def p_direct_abstract_declarator_3(self, p): + """ direct_abstract_declarator : LBRACKET type_qualifier_list_opt assignment_expression_opt RBRACKET + """ + quals = (p[2] if len(p) > 4 else []) or [] + p[0] = c_ast.ArrayDecl( + type=c_ast.TypeDecl(None, None, None, None), + dim=p[3] if len(p) > 4 else p[2], + dim_quals=quals, + coord=self._token_coord(p, 1)) + + def p_direct_abstract_declarator_4(self, p): + """ direct_abstract_declarator : direct_abstract_declarator LBRACKET TIMES RBRACKET + """ + arr = c_ast.ArrayDecl( + type=None, + dim=c_ast.ID(p[3], self._token_coord(p, 3)), + dim_quals=[], + coord=p[1].coord) + + p[0] = self._type_modify_decl(decl=p[1], modifier=arr) + + def p_direct_abstract_declarator_5(self, p): + """ direct_abstract_declarator : LBRACKET TIMES RBRACKET + """ + p[0] = c_ast.ArrayDecl( + type=c_ast.TypeDecl(None, None, None, None), + dim=c_ast.ID(p[3], self._token_coord(p, 3)), + dim_quals=[], + coord=self._token_coord(p, 1)) + + def p_direct_abstract_declarator_6(self, p): + """ direct_abstract_declarator : direct_abstract_declarator LPAREN parameter_type_list_opt RPAREN + """ + func = c_ast.FuncDecl( + args=p[3], + type=None, + coord=p[1].coord) + + p[0] = self._type_modify_decl(decl=p[1], modifier=func) + + def p_direct_abstract_declarator_7(self, p): + """ direct_abstract_declarator : LPAREN parameter_type_list_opt RPAREN + """ + p[0] = c_ast.FuncDecl( + args=p[2], + type=c_ast.TypeDecl(None, None, None, None), + coord=self._token_coord(p, 1)) + + # declaration is a list, statement isn't. To make it consistent, block_item + # will always be a list + # + def p_block_item(self, p): + """ block_item : declaration + | statement + """ + p[0] = p[1] if isinstance(p[1], list) else [p[1]] + + # Since we made block_item a list, this just combines lists + # + def p_block_item_list(self, p): + """ block_item_list : block_item + | block_item_list block_item + """ + # Empty block items (plain ';') produce [None], so ignore them + p[0] = p[1] if (len(p) == 2 or p[2] == [None]) else p[1] + p[2] + + def p_compound_statement_1(self, p): + """ compound_statement : brace_open block_item_list_opt brace_close """ + p[0] = c_ast.Compound( + block_items=p[2], + coord=self._token_coord(p, 1)) + + def p_labeled_statement_1(self, p): + """ labeled_statement : ID COLON pragmacomp_or_statement """ + p[0] = c_ast.Label(p[1], p[3], self._token_coord(p, 1)) + + def p_labeled_statement_2(self, p): + """ labeled_statement : CASE constant_expression COLON pragmacomp_or_statement """ + p[0] = c_ast.Case(p[2], [p[4]], self._token_coord(p, 1)) + + def p_labeled_statement_3(self, p): + """ labeled_statement : DEFAULT COLON pragmacomp_or_statement """ + p[0] = c_ast.Default([p[3]], self._token_coord(p, 1)) + + def p_selection_statement_1(self, p): + """ selection_statement : IF LPAREN expression RPAREN pragmacomp_or_statement """ + p[0] = c_ast.If(p[3], p[5], None, self._token_coord(p, 1)) + + def p_selection_statement_2(self, p): + """ selection_statement : IF LPAREN expression RPAREN statement ELSE pragmacomp_or_statement """ + p[0] = c_ast.If(p[3], p[5], p[7], self._token_coord(p, 1)) + + def p_selection_statement_3(self, p): + """ selection_statement : SWITCH LPAREN expression RPAREN pragmacomp_or_statement """ + p[0] = fix_switch_cases( + c_ast.Switch(p[3], p[5], self._token_coord(p, 1))) + + def p_iteration_statement_1(self, p): + """ iteration_statement : WHILE LPAREN expression RPAREN pragmacomp_or_statement """ + p[0] = c_ast.While(p[3], p[5], self._token_coord(p, 1)) + + def p_iteration_statement_2(self, p): + """ iteration_statement : DO pragmacomp_or_statement WHILE LPAREN expression RPAREN SEMI """ + p[0] = c_ast.DoWhile(p[5], p[2], self._token_coord(p, 1)) + + def p_iteration_statement_3(self, p): + """ iteration_statement : FOR LPAREN expression_opt SEMI expression_opt SEMI expression_opt RPAREN pragmacomp_or_statement """ + p[0] = c_ast.For(p[3], p[5], p[7], p[9], self._token_coord(p, 1)) + + def p_iteration_statement_4(self, p): + """ iteration_statement : FOR LPAREN declaration expression_opt SEMI expression_opt RPAREN pragmacomp_or_statement """ + p[0] = c_ast.For(c_ast.DeclList(p[3], self._token_coord(p, 1)), + p[4], p[6], p[8], self._token_coord(p, 1)) + + def p_jump_statement_1(self, p): + """ jump_statement : GOTO ID SEMI """ + p[0] = c_ast.Goto(p[2], self._token_coord(p, 1)) + + def p_jump_statement_2(self, p): + """ jump_statement : BREAK SEMI """ + p[0] = c_ast.Break(self._token_coord(p, 1)) + + def p_jump_statement_3(self, p): + """ jump_statement : CONTINUE SEMI """ + p[0] = c_ast.Continue(self._token_coord(p, 1)) + + def p_jump_statement_4(self, p): + """ jump_statement : RETURN expression SEMI + | RETURN SEMI + """ + p[0] = c_ast.Return(p[2] if len(p) == 4 else None, self._token_coord(p, 1)) + + def p_expression_statement(self, p): + """ expression_statement : expression_opt SEMI """ + if p[1] is None: + p[0] = c_ast.EmptyStatement(self._token_coord(p, 2)) + else: + p[0] = p[1] + + def p_expression(self, p): + """ expression : assignment_expression + | expression COMMA assignment_expression + """ + if len(p) == 2: + p[0] = p[1] + else: + if not isinstance(p[1], c_ast.ExprList): + p[1] = c_ast.ExprList([p[1]], p[1].coord) + + p[1].exprs.append(p[3]) + p[0] = p[1] + + def p_parenthesized_compound_expression(self, p): + """ assignment_expression : LPAREN compound_statement RPAREN """ + p[0] = p[2] + + def p_typedef_name(self, p): + """ typedef_name : TYPEID """ + p[0] = c_ast.IdentifierType([p[1]], coord=self._token_coord(p, 1)) + + def p_assignment_expression(self, p): + """ assignment_expression : conditional_expression + | unary_expression assignment_operator assignment_expression + """ + if len(p) == 2: + p[0] = p[1] + else: + p[0] = c_ast.Assignment(p[2], p[1], p[3], p[1].coord) + + # K&R2 defines these as many separate rules, to encode + # precedence and associativity. Why work hard ? I'll just use + # the built in precedence/associativity specification feature + # of PLY. (see precedence declaration above) + # + def p_assignment_operator(self, p): + """ assignment_operator : EQUALS + | XOREQUAL + | TIMESEQUAL + | DIVEQUAL + | MODEQUAL + | PLUSEQUAL + | MINUSEQUAL + | LSHIFTEQUAL + | RSHIFTEQUAL + | ANDEQUAL + | OREQUAL + """ + p[0] = p[1] + + def p_constant_expression(self, p): + """ constant_expression : conditional_expression """ + p[0] = p[1] + + def p_conditional_expression(self, p): + """ conditional_expression : binary_expression + | binary_expression CONDOP expression COLON conditional_expression + """ + if len(p) == 2: + p[0] = p[1] + else: + p[0] = c_ast.TernaryOp(p[1], p[3], p[5], p[1].coord) + + def p_binary_expression(self, p): + """ binary_expression : cast_expression + | binary_expression TIMES binary_expression + | binary_expression DIVIDE binary_expression + | binary_expression MOD binary_expression + | binary_expression PLUS binary_expression + | binary_expression MINUS binary_expression + | binary_expression RSHIFT binary_expression + | binary_expression LSHIFT binary_expression + | binary_expression LT binary_expression + | binary_expression LE binary_expression + | binary_expression GE binary_expression + | binary_expression GT binary_expression + | binary_expression EQ binary_expression + | binary_expression NE binary_expression + | binary_expression AND binary_expression + | binary_expression OR binary_expression + | binary_expression XOR binary_expression + | binary_expression LAND binary_expression + | binary_expression LOR binary_expression + """ + if len(p) == 2: + p[0] = p[1] + else: + p[0] = c_ast.BinaryOp(p[2], p[1], p[3], p[1].coord) + + def p_cast_expression_1(self, p): + """ cast_expression : unary_expression """ + p[0] = p[1] + + def p_cast_expression_2(self, p): + """ cast_expression : LPAREN type_name RPAREN cast_expression """ + p[0] = c_ast.Cast(p[2], p[4], self._token_coord(p, 1)) + + def p_unary_expression_1(self, p): + """ unary_expression : postfix_expression """ + p[0] = p[1] + + def p_unary_expression_2(self, p): + """ unary_expression : PLUSPLUS unary_expression + | MINUSMINUS unary_expression + | unary_operator cast_expression + """ + p[0] = c_ast.UnaryOp(p[1], p[2], p[2].coord) + + def p_unary_expression_3(self, p): + """ unary_expression : SIZEOF unary_expression + | SIZEOF LPAREN type_name RPAREN + | _ALIGNOF LPAREN type_name RPAREN + """ + p[0] = c_ast.UnaryOp( + p[1], + p[2] if len(p) == 3 else p[3], + self._token_coord(p, 1)) + + def p_unary_operator(self, p): + """ unary_operator : AND + | TIMES + | PLUS + | MINUS + | NOT + | LNOT + """ + p[0] = p[1] + + def p_postfix_expression_1(self, p): + """ postfix_expression : primary_expression """ + p[0] = p[1] + + def p_postfix_expression_2(self, p): + """ postfix_expression : postfix_expression LBRACKET expression RBRACKET """ + p[0] = c_ast.ArrayRef(p[1], p[3], p[1].coord) + + def p_postfix_expression_3(self, p): + """ postfix_expression : postfix_expression LPAREN argument_expression_list RPAREN + | postfix_expression LPAREN RPAREN + """ + p[0] = c_ast.FuncCall(p[1], p[3] if len(p) == 5 else None, p[1].coord) + + def p_postfix_expression_4(self, p): + """ postfix_expression : postfix_expression PERIOD ID + | postfix_expression PERIOD TYPEID + | postfix_expression ARROW ID + | postfix_expression ARROW TYPEID + """ + field = c_ast.ID(p[3], self._token_coord(p, 3)) + p[0] = c_ast.StructRef(p[1], p[2], field, p[1].coord) + + def p_postfix_expression_5(self, p): + """ postfix_expression : postfix_expression PLUSPLUS + | postfix_expression MINUSMINUS + """ + p[0] = c_ast.UnaryOp('p' + p[2], p[1], p[1].coord) + + def p_postfix_expression_6(self, p): + """ postfix_expression : LPAREN type_name RPAREN brace_open initializer_list brace_close + | LPAREN type_name RPAREN brace_open initializer_list COMMA brace_close + """ + p[0] = c_ast.CompoundLiteral(p[2], p[5]) + + def p_primary_expression_1(self, p): + """ primary_expression : identifier """ + p[0] = p[1] + + def p_primary_expression_2(self, p): + """ primary_expression : constant """ + p[0] = p[1] + + def p_primary_expression_3(self, p): + """ primary_expression : unified_string_literal + | unified_wstring_literal + """ + p[0] = p[1] + + def p_primary_expression_4(self, p): + """ primary_expression : LPAREN expression RPAREN """ + p[0] = p[2] + + def p_primary_expression_5(self, p): + """ primary_expression : OFFSETOF LPAREN type_name COMMA offsetof_member_designator RPAREN + """ + coord = self._token_coord(p, 1) + p[0] = c_ast.FuncCall(c_ast.ID(p[1], coord), + c_ast.ExprList([p[3], p[5]], coord), + coord) + + def p_offsetof_member_designator(self, p): + """ offsetof_member_designator : identifier + | offsetof_member_designator PERIOD identifier + | offsetof_member_designator LBRACKET expression RBRACKET + """ + if len(p) == 2: + p[0] = p[1] + elif len(p) == 4: + p[0] = c_ast.StructRef(p[1], p[2], p[3], p[1].coord) + elif len(p) == 5: + p[0] = c_ast.ArrayRef(p[1], p[3], p[1].coord) + else: + raise NotImplementedError("Unexpected parsing state. len(p): %u" % len(p)) + + def p_argument_expression_list(self, p): + """ argument_expression_list : assignment_expression + | argument_expression_list COMMA assignment_expression + """ + if len(p) == 2: # single expr + p[0] = c_ast.ExprList([p[1]], p[1].coord) + else: + p[1].exprs.append(p[3]) + p[0] = p[1] + + def p_identifier(self, p): + """ identifier : ID """ + p[0] = c_ast.ID(p[1], self._token_coord(p, 1)) + + def p_constant_1(self, p): + """ constant : INT_CONST_DEC + | INT_CONST_OCT + | INT_CONST_HEX + | INT_CONST_BIN + | INT_CONST_CHAR + """ + uCount = 0 + lCount = 0 + for x in p[1][-3:]: + if x in ('l', 'L'): + lCount += 1 + elif x in ('u', 'U'): + uCount += 1 + t = '' + if uCount > 1: + raise ValueError('Constant cannot have more than one u/U suffix.') + elif lCount > 2: + raise ValueError('Constant cannot have more than two l/L suffix.') + prefix = 'unsigned ' * uCount + 'long ' * lCount + p[0] = c_ast.Constant( + prefix + 'int', p[1], self._token_coord(p, 1)) + + def p_constant_2(self, p): + """ constant : FLOAT_CONST + | HEX_FLOAT_CONST + """ + if 'x' in p[1].lower(): + t = 'float' + else: + if p[1][-1] in ('f', 'F'): + t = 'float' + elif p[1][-1] in ('l', 'L'): + t = 'long double' + else: + t = 'double' + + p[0] = c_ast.Constant( + t, p[1], self._token_coord(p, 1)) + + def p_constant_3(self, p): + """ constant : CHAR_CONST + | WCHAR_CONST + | U8CHAR_CONST + | U16CHAR_CONST + | U32CHAR_CONST + """ + p[0] = c_ast.Constant( + 'char', p[1], self._token_coord(p, 1)) + + # The "unified" string and wstring literal rules are for supporting + # concatenation of adjacent string literals. + # I.e. "hello " "world" is seen by the C compiler as a single string literal + # with the value "hello world" + # + def p_unified_string_literal(self, p): + """ unified_string_literal : STRING_LITERAL + | unified_string_literal STRING_LITERAL + """ + if len(p) == 2: # single literal + p[0] = c_ast.Constant( + 'string', p[1], self._token_coord(p, 1)) + else: + p[1].value = p[1].value[:-1] + p[2][1:] + p[0] = p[1] + + def p_unified_wstring_literal(self, p): + """ unified_wstring_literal : WSTRING_LITERAL + | U8STRING_LITERAL + | U16STRING_LITERAL + | U32STRING_LITERAL + | unified_wstring_literal WSTRING_LITERAL + | unified_wstring_literal U8STRING_LITERAL + | unified_wstring_literal U16STRING_LITERAL + | unified_wstring_literal U32STRING_LITERAL + """ + if len(p) == 2: # single literal + p[0] = c_ast.Constant( + 'string', p[1], self._token_coord(p, 1)) + else: + p[1].value = p[1].value.rstrip()[:-1] + p[2][2:] + p[0] = p[1] + + def p_brace_open(self, p): + """ brace_open : LBRACE + """ + p[0] = p[1] + p.set_lineno(0, p.lineno(1)) + + def p_brace_close(self, p): + """ brace_close : RBRACE + """ + p[0] = p[1] + p.set_lineno(0, p.lineno(1)) + + def p_empty(self, p): + 'empty : ' + p[0] = None + + def p_error(self, p): + # If error recovery is added here in the future, make sure + # _get_yacc_lookahead_token still works! + # + if p: + self._parse_error( + 'before: %s' % p.value, + self._coord(lineno=p.lineno, + column=self.clex.find_tok_column(p))) + else: + self._parse_error('At end of input', self.clex.filename) diff --git a/.venv/Lib/site-packages/pycparser/lextab.py b/.venv/Lib/site-packages/pycparser/lextab.py new file mode 100644 index 00000000..444b4656 --- /dev/null +++ b/.venv/Lib/site-packages/pycparser/lextab.py @@ -0,0 +1,10 @@ +# lextab.py. This file automatically created by PLY (version 3.10). Don't edit! +_tabversion = '3.10' +_lextokens = set(('INT_CONST_CHAR', 'VOID', 'LBRACKET', 'WCHAR_CONST', 'FLOAT_CONST', 'MINUS', 'RPAREN', 'STRUCT', 'LONG', 'PLUS', 'ELLIPSIS', 'U32STRING_LITERAL', 'GT', 'GOTO', 'ENUM', 'PERIOD', 'GE', 'INT_CONST_DEC', 'ARROW', '_STATIC_ASSERT', '__INT128', 'HEX_FLOAT_CONST', 'DOUBLE', 'MINUSEQUAL', 'INT_CONST_OCT', 'TIMESEQUAL', 'OR', 'SHORT', 'RETURN', 'RSHIFTEQUAL', '_ALIGNAS', 'RESTRICT', 'STATIC', 'SIZEOF', 'UNSIGNED', 'PLUSPLUS', 'COLON', 'WSTRING_LITERAL', 'DIVIDE', 'FOR', 'UNION', 'EQUALS', 'ELSE', 'ANDEQUAL', 'EQ', 'AND', 'TYPEID', 'LBRACE', 'PPHASH', 'INT', 'SIGNED', 'CONTINUE', 'NOT', 'OREQUAL', 'MOD', 'RSHIFT', 'DEFAULT', '_NORETURN', 'CHAR', 'WHILE', 'DIVEQUAL', '_ALIGNOF', 'EXTERN', 'LNOT', 'CASE', 'LAND', 'REGISTER', 'MODEQUAL', 'NE', 'SWITCH', 'INT_CONST_HEX', '_COMPLEX', 'PPPRAGMASTR', 'PLUSEQUAL', 'U32CHAR_CONST', 'CONDOP', 'U8STRING_LITERAL', 'BREAK', 'VOLATILE', 'PPPRAGMA', 'INLINE', 'INT_CONST_BIN', 'DO', 'U8CHAR_CONST', 'CONST', 'U16STRING_LITERAL', 'LOR', 'CHAR_CONST', 'LSHIFT', 'RBRACE', '_BOOL', 'LE', 'SEMI', '_THREAD_LOCAL', 'LT', 'COMMA', 'U16CHAR_CONST', 'OFFSETOF', '_ATOMIC', 'TYPEDEF', 'XOR', 'AUTO', 'TIMES', 'LPAREN', 'MINUSMINUS', 'ID', 'IF', 'STRING_LITERAL', 'FLOAT', 'XOREQUAL', 'LSHIFTEQUAL', 'RBRACKET')) +_lexreflags = 64 +_lexliterals = '' +_lexstateinfo = {'ppline': 'exclusive', 'pppragma': 'exclusive', 'INITIAL': 'inclusive'} +_lexstatere = {'ppline': [('(?P"([^"\\\\\\n]|(\\\\[0-9a-zA-Z._~!=&\\^\\-\\\\?\'"]))*")|(?P(0(([uU]ll)|([uU]LL)|(ll[uU]?)|(LL[uU]?)|([uU][lL])|([lL][uU]?)|[uU])?)|([1-9][0-9]*(([uU]ll)|([uU]LL)|(ll[uU]?)|(LL[uU]?)|([uU][lL])|([lL][uU]?)|[uU])?))|(?P\\n)|(?Pline)', [None, ('t_ppline_FILENAME', 'FILENAME'), None, None, ('t_ppline_LINE_NUMBER', 'LINE_NUMBER'), None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, ('t_ppline_NEWLINE', 'NEWLINE'), ('t_ppline_PPLINE', 'PPLINE')])], 'pppragma': [('(?P\\n)|(?Ppragma)|(?P.+)', [None, ('t_pppragma_NEWLINE', 'NEWLINE'), ('t_pppragma_PPPRAGMA', 'PPPRAGMA'), ('t_pppragma_STR', 'STR')])], 'INITIAL': [('(?P[ \\t]*\\#)|(?P\\n+)|(?P\\{)|(?P\\})|(?P((((([0-9]*\\.[0-9]+)|([0-9]+\\.))([eE][-+]?[0-9]+)?)|([0-9]+([eE][-+]?[0-9]+)))[FfLl]?))|(?P(0[xX]([0-9a-fA-F]+|((([0-9a-fA-F]+)?\\.[0-9a-fA-F]+)|([0-9a-fA-F]+\\.)))([pP][+-]?[0-9]+)[FfLl]?))|(?P0[xX][0-9a-fA-F]+(([uU]ll)|([uU]LL)|(ll[uU]?)|(LL[uU]?)|([uU][lL])|([lL][uU]?)|[uU])?)|(?P0[bB][01]+(([uU]ll)|([uU]LL)|(ll[uU]?)|(LL[uU]?)|([uU][lL])|([lL][uU]?)|[uU])?)', [None, ('t_PPHASH', 'PPHASH'), ('t_NEWLINE', 'NEWLINE'), ('t_LBRACE', 'LBRACE'), ('t_RBRACE', 'RBRACE'), ('t_FLOAT_CONST', 'FLOAT_CONST'), None, None, None, None, None, None, None, None, None, ('t_HEX_FLOAT_CONST', 'HEX_FLOAT_CONST'), None, None, None, None, None, None, None, ('t_INT_CONST_HEX', 'INT_CONST_HEX'), None, None, None, None, None, None, None, ('t_INT_CONST_BIN', 'INT_CONST_BIN')]), ('(?P0[0-7]*[89])|(?P0[0-7]*(([uU]ll)|([uU]LL)|(ll[uU]?)|(LL[uU]?)|([uU][lL])|([lL][uU]?)|[uU])?)|(?P(0(([uU]ll)|([uU]LL)|(ll[uU]?)|(LL[uU]?)|([uU][lL])|([lL][uU]?)|[uU])?)|([1-9][0-9]*(([uU]ll)|([uU]LL)|(ll[uU]?)|(LL[uU]?)|([uU][lL])|([lL][uU]?)|[uU])?))|(?P\'([^\'\\\\\\n]|(\\\\(([a-wyzA-Z._~!=&\\^\\-\\\\?\'"]|x(?![0-9a-fA-F]))|(\\d+)(?!\\d)|(x[0-9a-fA-F]+)(?![0-9a-fA-F])))){2,4}\')|(?P\'([^\'\\\\\\n]|(\\\\(([a-wyzA-Z._~!=&\\^\\-\\\\?\'"]|x(?![0-9a-fA-F]))|(\\d+)(?!\\d)|(x[0-9a-fA-F]+)(?![0-9a-fA-F]))))\')|(?PL\'([^\'\\\\\\n]|(\\\\(([a-wyzA-Z._~!=&\\^\\-\\\\?\'"]|x(?![0-9a-fA-F]))|(\\d+)(?!\\d)|(x[0-9a-fA-F]+)(?![0-9a-fA-F]))))\')|(?Pu8\'([^\'\\\\\\n]|(\\\\(([a-wyzA-Z._~!=&\\^\\-\\\\?\'"]|x(?![0-9a-fA-F]))|(\\d+)(?!\\d)|(x[0-9a-fA-F]+)(?![0-9a-fA-F]))))\')|(?Pu\'([^\'\\\\\\n]|(\\\\(([a-wyzA-Z._~!=&\\^\\-\\\\?\'"]|x(?![0-9a-fA-F]))|(\\d+)(?!\\d)|(x[0-9a-fA-F]+)(?![0-9a-fA-F]))))\')|(?PU\'([^\'\\\\\\n]|(\\\\(([a-wyzA-Z._~!=&\\^\\-\\\\?\'"]|x(?![0-9a-fA-F]))|(\\d+)(?!\\d)|(x[0-9a-fA-F]+)(?![0-9a-fA-F]))))\')', [None, ('t_BAD_CONST_OCT', 'BAD_CONST_OCT'), ('t_INT_CONST_OCT', 'INT_CONST_OCT'), None, None, None, None, None, None, None, ('t_INT_CONST_DEC', 'INT_CONST_DEC'), None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, ('t_INT_CONST_CHAR', 'INT_CONST_CHAR'), None, None, None, None, None, None, ('t_CHAR_CONST', 'CHAR_CONST'), None, None, None, None, None, None, ('t_WCHAR_CONST', 'WCHAR_CONST'), None, None, None, None, None, None, ('t_U8CHAR_CONST', 'U8CHAR_CONST'), None, None, None, None, None, None, ('t_U16CHAR_CONST', 'U16CHAR_CONST'), None, None, None, None, None, None, ('t_U32CHAR_CONST', 'U32CHAR_CONST')]), ('(?P(\'([^\'\\\\\\n]|(\\\\(([a-wyzA-Z._~!=&\\^\\-\\\\?\'"]|x(?![0-9a-fA-F]))|(\\d+)(?!\\d)|(x[0-9a-fA-F]+)(?![0-9a-fA-F]))))*\\n)|(\'([^\'\\\\\\n]|(\\\\(([a-wyzA-Z._~!=&\\^\\-\\\\?\'"]|x(?![0-9a-fA-F]))|(\\d+)(?!\\d)|(x[0-9a-fA-F]+)(?![0-9a-fA-F]))))*$))|(?P(\'([^\'\\\\\\n]|(\\\\(([a-wyzA-Z._~!=&\\^\\-\\\\?\'"]|x(?![0-9a-fA-F]))|(\\d+)(?!\\d)|(x[0-9a-fA-F]+)(?![0-9a-fA-F]))))[^\'\n]+\')|(\'\')|(\'([\\\\][^a-zA-Z._~^!=&\\^\\-\\\\?\'"x0-9])[^\'\\n]*\'))|(?PL"([^"\\\\\\n]|(\\\\[0-9a-zA-Z._~!=&\\^\\-\\\\?\'"]))*")|(?Pu8"([^"\\\\\\n]|(\\\\[0-9a-zA-Z._~!=&\\^\\-\\\\?\'"]))*")|(?Pu"([^"\\\\\\n]|(\\\\[0-9a-zA-Z._~!=&\\^\\-\\\\?\'"]))*")|(?PU"([^"\\\\\\n]|(\\\\[0-9a-zA-Z._~!=&\\^\\-\\\\?\'"]))*")|(?P"([^"\\\\\\n]|(\\\\[0-9a-zA-Z._~!=&\\^\\-\\\\?\'"]))*([\\\\][^a-zA-Z._~^!=&\\^\\-\\\\?\'"x0-9])([^"\\\\\\n]|(\\\\[0-9a-zA-Z._~!=&\\^\\-\\\\?\'"]))*")|(?P[a-zA-Z_$][0-9a-zA-Z_$]*)|(?P"([^"\\\\\\n]|(\\\\[0-9a-zA-Z._~!=&\\^\\-\\\\?\'"]))*")|(?P\\.\\.\\.)|(?P\\+\\+)|(?P\\|\\|)|(?P\\^=)|(?P\\|=)|(?P<<=)|(?P>>=)|(?P\\+=)|(?P\\*=)', [None, ('t_UNMATCHED_QUOTE', 'UNMATCHED_QUOTE'), None, None, None, None, None, None, None, None, None, None, None, None, None, None, ('t_BAD_CHAR_CONST', 'BAD_CHAR_CONST'), None, None, None, None, None, None, None, None, None, None, ('t_WSTRING_LITERAL', 'WSTRING_LITERAL'), None, None, ('t_U8STRING_LITERAL', 'U8STRING_LITERAL'), None, None, ('t_U16STRING_LITERAL', 'U16STRING_LITERAL'), None, None, ('t_U32STRING_LITERAL', 'U32STRING_LITERAL'), None, None, ('t_BAD_STRING_LITERAL', 'BAD_STRING_LITERAL'), None, None, None, None, None, ('t_ID', 'ID'), (None, 'STRING_LITERAL'), None, None, (None, 'ELLIPSIS'), (None, 'PLUSPLUS'), (None, 'LOR'), (None, 'XOREQUAL'), (None, 'OREQUAL'), (None, 'LSHIFTEQUAL'), (None, 'RSHIFTEQUAL'), (None, 'PLUSEQUAL'), (None, 'TIMESEQUAL')]), ('(?P\\+)|(?P%=)|(?P/=)|(?P\\])|(?P\\?)|(?P\\^)|(?P<<)|(?P<=)|(?P\\()|(?P->)|(?P==)|(?P!=)|(?P--)|(?P\\|)|(?P\\*)|(?P\\[)|(?P>=)|(?P\\))|(?P&&)|(?P>>)|(?P-=)|(?P\\.)|(?P&=)|(?P=)|(?P<)|(?P,)|(?P/)|(?P&)|(?P%)|(?P;)|(?P-)|(?P>)|(?P:)|(?P~)|(?P!)', [None, (None, 'PLUS'), (None, 'MODEQUAL'), (None, 'DIVEQUAL'), (None, 'RBRACKET'), (None, 'CONDOP'), (None, 'XOR'), (None, 'LSHIFT'), (None, 'LE'), (None, 'LPAREN'), (None, 'ARROW'), (None, 'EQ'), (None, 'NE'), (None, 'MINUSMINUS'), (None, 'OR'), (None, 'TIMES'), (None, 'LBRACKET'), (None, 'GE'), (None, 'RPAREN'), (None, 'LAND'), (None, 'RSHIFT'), (None, 'MINUSEQUAL'), (None, 'PERIOD'), (None, 'ANDEQUAL'), (None, 'EQUALS'), (None, 'LT'), (None, 'COMMA'), (None, 'DIVIDE'), (None, 'AND'), (None, 'MOD'), (None, 'SEMI'), (None, 'MINUS'), (None, 'GT'), (None, 'COLON'), (None, 'NOT'), (None, 'LNOT')])]} +_lexstateignore = {'ppline': ' \t', 'pppragma': ' \t', 'INITIAL': ' \t'} +_lexstateerrorf = {'ppline': 't_ppline_error', 'pppragma': 't_pppragma_error', 'INITIAL': 't_error'} +_lexstateeoff = {} diff --git a/.venv/Lib/site-packages/pycparser/ply/__init__.py b/.venv/Lib/site-packages/pycparser/ply/__init__.py new file mode 100644 index 00000000..6e53cddc --- /dev/null +++ b/.venv/Lib/site-packages/pycparser/ply/__init__.py @@ -0,0 +1,5 @@ +# PLY package +# Author: David Beazley (dave@dabeaz.com) + +__version__ = '3.9' +__all__ = ['lex','yacc'] diff --git a/.venv/Lib/site-packages/pycparser/ply/__pycache__/__init__.cpython-311.pyc b/.venv/Lib/site-packages/pycparser/ply/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 00000000..30ef9b28 Binary files /dev/null and b/.venv/Lib/site-packages/pycparser/ply/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/pycparser/ply/__pycache__/cpp.cpython-311.pyc b/.venv/Lib/site-packages/pycparser/ply/__pycache__/cpp.cpython-311.pyc new file mode 100644 index 00000000..ddb41a2c Binary files /dev/null and b/.venv/Lib/site-packages/pycparser/ply/__pycache__/cpp.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/pycparser/ply/__pycache__/ctokens.cpython-311.pyc b/.venv/Lib/site-packages/pycparser/ply/__pycache__/ctokens.cpython-311.pyc new file mode 100644 index 00000000..e9c078f9 Binary files /dev/null and b/.venv/Lib/site-packages/pycparser/ply/__pycache__/ctokens.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/pycparser/ply/__pycache__/lex.cpython-311.pyc b/.venv/Lib/site-packages/pycparser/ply/__pycache__/lex.cpython-311.pyc new file mode 100644 index 00000000..5b2173b9 Binary files /dev/null and b/.venv/Lib/site-packages/pycparser/ply/__pycache__/lex.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/pycparser/ply/__pycache__/yacc.cpython-311.pyc b/.venv/Lib/site-packages/pycparser/ply/__pycache__/yacc.cpython-311.pyc new file mode 100644 index 00000000..6dad032f Binary files /dev/null and b/.venv/Lib/site-packages/pycparser/ply/__pycache__/yacc.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/pycparser/ply/__pycache__/ygen.cpython-311.pyc b/.venv/Lib/site-packages/pycparser/ply/__pycache__/ygen.cpython-311.pyc new file mode 100644 index 00000000..055d3b15 Binary files /dev/null and b/.venv/Lib/site-packages/pycparser/ply/__pycache__/ygen.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/pycparser/ply/cpp.py b/.venv/Lib/site-packages/pycparser/ply/cpp.py new file mode 100644 index 00000000..86273eac --- /dev/null +++ b/.venv/Lib/site-packages/pycparser/ply/cpp.py @@ -0,0 +1,905 @@ +# ----------------------------------------------------------------------------- +# cpp.py +# +# Author: David Beazley (http://www.dabeaz.com) +# Copyright (C) 2017 +# All rights reserved +# +# This module implements an ANSI-C style lexical preprocessor for PLY. +# ----------------------------------------------------------------------------- +import sys + +# Some Python 3 compatibility shims +if sys.version_info.major < 3: + STRING_TYPES = (str, unicode) +else: + STRING_TYPES = str + xrange = range + +# ----------------------------------------------------------------------------- +# Default preprocessor lexer definitions. These tokens are enough to get +# a basic preprocessor working. Other modules may import these if they want +# ----------------------------------------------------------------------------- + +tokens = ( + 'CPP_ID','CPP_INTEGER', 'CPP_FLOAT', 'CPP_STRING', 'CPP_CHAR', 'CPP_WS', 'CPP_COMMENT1', 'CPP_COMMENT2', 'CPP_POUND','CPP_DPOUND' +) + +literals = "+-*/%|&~^<>=!?()[]{}.,;:\\\'\"" + +# Whitespace +def t_CPP_WS(t): + r'\s+' + t.lexer.lineno += t.value.count("\n") + return t + +t_CPP_POUND = r'\#' +t_CPP_DPOUND = r'\#\#' + +# Identifier +t_CPP_ID = r'[A-Za-z_][\w_]*' + +# Integer literal +def CPP_INTEGER(t): + r'(((((0x)|(0X))[0-9a-fA-F]+)|(\d+))([uU][lL]|[lL][uU]|[uU]|[lL])?)' + return t + +t_CPP_INTEGER = CPP_INTEGER + +# Floating literal +t_CPP_FLOAT = r'((\d+)(\.\d+)(e(\+|-)?(\d+))? | (\d+)e(\+|-)?(\d+))([lL]|[fF])?' + +# String literal +def t_CPP_STRING(t): + r'\"([^\\\n]|(\\(.|\n)))*?\"' + t.lexer.lineno += t.value.count("\n") + return t + +# Character constant 'c' or L'c' +def t_CPP_CHAR(t): + r'(L)?\'([^\\\n]|(\\(.|\n)))*?\'' + t.lexer.lineno += t.value.count("\n") + return t + +# Comment +def t_CPP_COMMENT1(t): + r'(/\*(.|\n)*?\*/)' + ncr = t.value.count("\n") + t.lexer.lineno += ncr + # replace with one space or a number of '\n' + t.type = 'CPP_WS'; t.value = '\n' * ncr if ncr else ' ' + return t + +# Line comment +def t_CPP_COMMENT2(t): + r'(//.*?(\n|$))' + # replace with '/n' + t.type = 'CPP_WS'; t.value = '\n' + return t + +def t_error(t): + t.type = t.value[0] + t.value = t.value[0] + t.lexer.skip(1) + return t + +import re +import copy +import time +import os.path + +# ----------------------------------------------------------------------------- +# trigraph() +# +# Given an input string, this function replaces all trigraph sequences. +# The following mapping is used: +# +# ??= # +# ??/ \ +# ??' ^ +# ??( [ +# ??) ] +# ??! | +# ??< { +# ??> } +# ??- ~ +# ----------------------------------------------------------------------------- + +_trigraph_pat = re.compile(r'''\?\?[=/\'\(\)\!<>\-]''') +_trigraph_rep = { + '=':'#', + '/':'\\', + "'":'^', + '(':'[', + ')':']', + '!':'|', + '<':'{', + '>':'}', + '-':'~' +} + +def trigraph(input): + return _trigraph_pat.sub(lambda g: _trigraph_rep[g.group()[-1]],input) + +# ------------------------------------------------------------------ +# Macro object +# +# This object holds information about preprocessor macros +# +# .name - Macro name (string) +# .value - Macro value (a list of tokens) +# .arglist - List of argument names +# .variadic - Boolean indicating whether or not variadic macro +# .vararg - Name of the variadic parameter +# +# When a macro is created, the macro replacement token sequence is +# pre-scanned and used to create patch lists that are later used +# during macro expansion +# ------------------------------------------------------------------ + +class Macro(object): + def __init__(self,name,value,arglist=None,variadic=False): + self.name = name + self.value = value + self.arglist = arglist + self.variadic = variadic + if variadic: + self.vararg = arglist[-1] + self.source = None + +# ------------------------------------------------------------------ +# Preprocessor object +# +# Object representing a preprocessor. Contains macro definitions, +# include directories, and other information +# ------------------------------------------------------------------ + +class Preprocessor(object): + def __init__(self,lexer=None): + if lexer is None: + lexer = lex.lexer + self.lexer = lexer + self.macros = { } + self.path = [] + self.temp_path = [] + + # Probe the lexer for selected tokens + self.lexprobe() + + tm = time.localtime() + self.define("__DATE__ \"%s\"" % time.strftime("%b %d %Y",tm)) + self.define("__TIME__ \"%s\"" % time.strftime("%H:%M:%S",tm)) + self.parser = None + + # ----------------------------------------------------------------------------- + # tokenize() + # + # Utility function. Given a string of text, tokenize into a list of tokens + # ----------------------------------------------------------------------------- + + def tokenize(self,text): + tokens = [] + self.lexer.input(text) + while True: + tok = self.lexer.token() + if not tok: break + tokens.append(tok) + return tokens + + # --------------------------------------------------------------------- + # error() + # + # Report a preprocessor error/warning of some kind + # ---------------------------------------------------------------------- + + def error(self,file,line,msg): + print("%s:%d %s" % (file,line,msg)) + + # ---------------------------------------------------------------------- + # lexprobe() + # + # This method probes the preprocessor lexer object to discover + # the token types of symbols that are important to the preprocessor. + # If this works right, the preprocessor will simply "work" + # with any suitable lexer regardless of how tokens have been named. + # ---------------------------------------------------------------------- + + def lexprobe(self): + + # Determine the token type for identifiers + self.lexer.input("identifier") + tok = self.lexer.token() + if not tok or tok.value != "identifier": + print("Couldn't determine identifier type") + else: + self.t_ID = tok.type + + # Determine the token type for integers + self.lexer.input("12345") + tok = self.lexer.token() + if not tok or int(tok.value) != 12345: + print("Couldn't determine integer type") + else: + self.t_INTEGER = tok.type + self.t_INTEGER_TYPE = type(tok.value) + + # Determine the token type for strings enclosed in double quotes + self.lexer.input("\"filename\"") + tok = self.lexer.token() + if not tok or tok.value != "\"filename\"": + print("Couldn't determine string type") + else: + self.t_STRING = tok.type + + # Determine the token type for whitespace--if any + self.lexer.input(" ") + tok = self.lexer.token() + if not tok or tok.value != " ": + self.t_SPACE = None + else: + self.t_SPACE = tok.type + + # Determine the token type for newlines + self.lexer.input("\n") + tok = self.lexer.token() + if not tok or tok.value != "\n": + self.t_NEWLINE = None + print("Couldn't determine token for newlines") + else: + self.t_NEWLINE = tok.type + + self.t_WS = (self.t_SPACE, self.t_NEWLINE) + + # Check for other characters used by the preprocessor + chars = [ '<','>','#','##','\\','(',')',',','.'] + for c in chars: + self.lexer.input(c) + tok = self.lexer.token() + if not tok or tok.value != c: + print("Unable to lex '%s' required for preprocessor" % c) + + # ---------------------------------------------------------------------- + # add_path() + # + # Adds a search path to the preprocessor. + # ---------------------------------------------------------------------- + + def add_path(self,path): + self.path.append(path) + + # ---------------------------------------------------------------------- + # group_lines() + # + # Given an input string, this function splits it into lines. Trailing whitespace + # is removed. Any line ending with \ is grouped with the next line. This + # function forms the lowest level of the preprocessor---grouping into text into + # a line-by-line format. + # ---------------------------------------------------------------------- + + def group_lines(self,input): + lex = self.lexer.clone() + lines = [x.rstrip() for x in input.splitlines()] + for i in xrange(len(lines)): + j = i+1 + while lines[i].endswith('\\') and (j < len(lines)): + lines[i] = lines[i][:-1]+lines[j] + lines[j] = "" + j += 1 + + input = "\n".join(lines) + lex.input(input) + lex.lineno = 1 + + current_line = [] + while True: + tok = lex.token() + if not tok: + break + current_line.append(tok) + if tok.type in self.t_WS and '\n' in tok.value: + yield current_line + current_line = [] + + if current_line: + yield current_line + + # ---------------------------------------------------------------------- + # tokenstrip() + # + # Remove leading/trailing whitespace tokens from a token list + # ---------------------------------------------------------------------- + + def tokenstrip(self,tokens): + i = 0 + while i < len(tokens) and tokens[i].type in self.t_WS: + i += 1 + del tokens[:i] + i = len(tokens)-1 + while i >= 0 and tokens[i].type in self.t_WS: + i -= 1 + del tokens[i+1:] + return tokens + + + # ---------------------------------------------------------------------- + # collect_args() + # + # Collects comma separated arguments from a list of tokens. The arguments + # must be enclosed in parenthesis. Returns a tuple (tokencount,args,positions) + # where tokencount is the number of tokens consumed, args is a list of arguments, + # and positions is a list of integers containing the starting index of each + # argument. Each argument is represented by a list of tokens. + # + # When collecting arguments, leading and trailing whitespace is removed + # from each argument. + # + # This function properly handles nested parenthesis and commas---these do not + # define new arguments. + # ---------------------------------------------------------------------- + + def collect_args(self,tokenlist): + args = [] + positions = [] + current_arg = [] + nesting = 1 + tokenlen = len(tokenlist) + + # Search for the opening '('. + i = 0 + while (i < tokenlen) and (tokenlist[i].type in self.t_WS): + i += 1 + + if (i < tokenlen) and (tokenlist[i].value == '('): + positions.append(i+1) + else: + self.error(self.source,tokenlist[0].lineno,"Missing '(' in macro arguments") + return 0, [], [] + + i += 1 + + while i < tokenlen: + t = tokenlist[i] + if t.value == '(': + current_arg.append(t) + nesting += 1 + elif t.value == ')': + nesting -= 1 + if nesting == 0: + if current_arg: + args.append(self.tokenstrip(current_arg)) + positions.append(i) + return i+1,args,positions + current_arg.append(t) + elif t.value == ',' and nesting == 1: + args.append(self.tokenstrip(current_arg)) + positions.append(i+1) + current_arg = [] + else: + current_arg.append(t) + i += 1 + + # Missing end argument + self.error(self.source,tokenlist[-1].lineno,"Missing ')' in macro arguments") + return 0, [],[] + + # ---------------------------------------------------------------------- + # macro_prescan() + # + # Examine the macro value (token sequence) and identify patch points + # This is used to speed up macro expansion later on---we'll know + # right away where to apply patches to the value to form the expansion + # ---------------------------------------------------------------------- + + def macro_prescan(self,macro): + macro.patch = [] # Standard macro arguments + macro.str_patch = [] # String conversion expansion + macro.var_comma_patch = [] # Variadic macro comma patch + i = 0 + while i < len(macro.value): + if macro.value[i].type == self.t_ID and macro.value[i].value in macro.arglist: + argnum = macro.arglist.index(macro.value[i].value) + # Conversion of argument to a string + if i > 0 and macro.value[i-1].value == '#': + macro.value[i] = copy.copy(macro.value[i]) + macro.value[i].type = self.t_STRING + del macro.value[i-1] + macro.str_patch.append((argnum,i-1)) + continue + # Concatenation + elif (i > 0 and macro.value[i-1].value == '##'): + macro.patch.append(('c',argnum,i-1)) + del macro.value[i-1] + continue + elif ((i+1) < len(macro.value) and macro.value[i+1].value == '##'): + macro.patch.append(('c',argnum,i)) + i += 1 + continue + # Standard expansion + else: + macro.patch.append(('e',argnum,i)) + elif macro.value[i].value == '##': + if macro.variadic and (i > 0) and (macro.value[i-1].value == ',') and \ + ((i+1) < len(macro.value)) and (macro.value[i+1].type == self.t_ID) and \ + (macro.value[i+1].value == macro.vararg): + macro.var_comma_patch.append(i-1) + i += 1 + macro.patch.sort(key=lambda x: x[2],reverse=True) + + # ---------------------------------------------------------------------- + # macro_expand_args() + # + # Given a Macro and list of arguments (each a token list), this method + # returns an expanded version of a macro. The return value is a token sequence + # representing the replacement macro tokens + # ---------------------------------------------------------------------- + + def macro_expand_args(self,macro,args): + # Make a copy of the macro token sequence + rep = [copy.copy(_x) for _x in macro.value] + + # Make string expansion patches. These do not alter the length of the replacement sequence + + str_expansion = {} + for argnum, i in macro.str_patch: + if argnum not in str_expansion: + str_expansion[argnum] = ('"%s"' % "".join([x.value for x in args[argnum]])).replace("\\","\\\\") + rep[i] = copy.copy(rep[i]) + rep[i].value = str_expansion[argnum] + + # Make the variadic macro comma patch. If the variadic macro argument is empty, we get rid + comma_patch = False + if macro.variadic and not args[-1]: + for i in macro.var_comma_patch: + rep[i] = None + comma_patch = True + + # Make all other patches. The order of these matters. It is assumed that the patch list + # has been sorted in reverse order of patch location since replacements will cause the + # size of the replacement sequence to expand from the patch point. + + expanded = { } + for ptype, argnum, i in macro.patch: + # Concatenation. Argument is left unexpanded + if ptype == 'c': + rep[i:i+1] = args[argnum] + # Normal expansion. Argument is macro expanded first + elif ptype == 'e': + if argnum not in expanded: + expanded[argnum] = self.expand_macros(args[argnum]) + rep[i:i+1] = expanded[argnum] + + # Get rid of removed comma if necessary + if comma_patch: + rep = [_i for _i in rep if _i] + + return rep + + + # ---------------------------------------------------------------------- + # expand_macros() + # + # Given a list of tokens, this function performs macro expansion. + # The expanded argument is a dictionary that contains macros already + # expanded. This is used to prevent infinite recursion. + # ---------------------------------------------------------------------- + + def expand_macros(self,tokens,expanded=None): + if expanded is None: + expanded = {} + i = 0 + while i < len(tokens): + t = tokens[i] + if t.type == self.t_ID: + if t.value in self.macros and t.value not in expanded: + # Yes, we found a macro match + expanded[t.value] = True + + m = self.macros[t.value] + if not m.arglist: + # A simple macro + ex = self.expand_macros([copy.copy(_x) for _x in m.value],expanded) + for e in ex: + e.lineno = t.lineno + tokens[i:i+1] = ex + i += len(ex) + else: + # A macro with arguments + j = i + 1 + while j < len(tokens) and tokens[j].type in self.t_WS: + j += 1 + if tokens[j].value == '(': + tokcount,args,positions = self.collect_args(tokens[j:]) + if not m.variadic and len(args) != len(m.arglist): + self.error(self.source,t.lineno,"Macro %s requires %d arguments" % (t.value,len(m.arglist))) + i = j + tokcount + elif m.variadic and len(args) < len(m.arglist)-1: + if len(m.arglist) > 2: + self.error(self.source,t.lineno,"Macro %s must have at least %d arguments" % (t.value, len(m.arglist)-1)) + else: + self.error(self.source,t.lineno,"Macro %s must have at least %d argument" % (t.value, len(m.arglist)-1)) + i = j + tokcount + else: + if m.variadic: + if len(args) == len(m.arglist)-1: + args.append([]) + else: + args[len(m.arglist)-1] = tokens[j+positions[len(m.arglist)-1]:j+tokcount-1] + del args[len(m.arglist):] + + # Get macro replacement text + rep = self.macro_expand_args(m,args) + rep = self.expand_macros(rep,expanded) + for r in rep: + r.lineno = t.lineno + tokens[i:j+tokcount] = rep + i += len(rep) + del expanded[t.value] + continue + elif t.value == '__LINE__': + t.type = self.t_INTEGER + t.value = self.t_INTEGER_TYPE(t.lineno) + + i += 1 + return tokens + + # ---------------------------------------------------------------------- + # evalexpr() + # + # Evaluate an expression token sequence for the purposes of evaluating + # integral expressions. + # ---------------------------------------------------------------------- + + def evalexpr(self,tokens): + # tokens = tokenize(line) + # Search for defined macros + i = 0 + while i < len(tokens): + if tokens[i].type == self.t_ID and tokens[i].value == 'defined': + j = i + 1 + needparen = False + result = "0L" + while j < len(tokens): + if tokens[j].type in self.t_WS: + j += 1 + continue + elif tokens[j].type == self.t_ID: + if tokens[j].value in self.macros: + result = "1L" + else: + result = "0L" + if not needparen: break + elif tokens[j].value == '(': + needparen = True + elif tokens[j].value == ')': + break + else: + self.error(self.source,tokens[i].lineno,"Malformed defined()") + j += 1 + tokens[i].type = self.t_INTEGER + tokens[i].value = self.t_INTEGER_TYPE(result) + del tokens[i+1:j+1] + i += 1 + tokens = self.expand_macros(tokens) + for i,t in enumerate(tokens): + if t.type == self.t_ID: + tokens[i] = copy.copy(t) + tokens[i].type = self.t_INTEGER + tokens[i].value = self.t_INTEGER_TYPE("0L") + elif t.type == self.t_INTEGER: + tokens[i] = copy.copy(t) + # Strip off any trailing suffixes + tokens[i].value = str(tokens[i].value) + while tokens[i].value[-1] not in "0123456789abcdefABCDEF": + tokens[i].value = tokens[i].value[:-1] + + expr = "".join([str(x.value) for x in tokens]) + expr = expr.replace("&&"," and ") + expr = expr.replace("||"," or ") + expr = expr.replace("!"," not ") + try: + result = eval(expr) + except Exception: + self.error(self.source,tokens[0].lineno,"Couldn't evaluate expression") + result = 0 + return result + + # ---------------------------------------------------------------------- + # parsegen() + # + # Parse an input string/ + # ---------------------------------------------------------------------- + def parsegen(self,input,source=None): + + # Replace trigraph sequences + t = trigraph(input) + lines = self.group_lines(t) + + if not source: + source = "" + + self.define("__FILE__ \"%s\"" % source) + + self.source = source + chunk = [] + enable = True + iftrigger = False + ifstack = [] + + for x in lines: + for i,tok in enumerate(x): + if tok.type not in self.t_WS: break + if tok.value == '#': + # Preprocessor directive + + # insert necessary whitespace instead of eaten tokens + for tok in x: + if tok.type in self.t_WS and '\n' in tok.value: + chunk.append(tok) + + dirtokens = self.tokenstrip(x[i+1:]) + if dirtokens: + name = dirtokens[0].value + args = self.tokenstrip(dirtokens[1:]) + else: + name = "" + args = [] + + if name == 'define': + if enable: + for tok in self.expand_macros(chunk): + yield tok + chunk = [] + self.define(args) + elif name == 'include': + if enable: + for tok in self.expand_macros(chunk): + yield tok + chunk = [] + oldfile = self.macros['__FILE__'] + for tok in self.include(args): + yield tok + self.macros['__FILE__'] = oldfile + self.source = source + elif name == 'undef': + if enable: + for tok in self.expand_macros(chunk): + yield tok + chunk = [] + self.undef(args) + elif name == 'ifdef': + ifstack.append((enable,iftrigger)) + if enable: + if not args[0].value in self.macros: + enable = False + iftrigger = False + else: + iftrigger = True + elif name == 'ifndef': + ifstack.append((enable,iftrigger)) + if enable: + if args[0].value in self.macros: + enable = False + iftrigger = False + else: + iftrigger = True + elif name == 'if': + ifstack.append((enable,iftrigger)) + if enable: + result = self.evalexpr(args) + if not result: + enable = False + iftrigger = False + else: + iftrigger = True + elif name == 'elif': + if ifstack: + if ifstack[-1][0]: # We only pay attention if outer "if" allows this + if enable: # If already true, we flip enable False + enable = False + elif not iftrigger: # If False, but not triggered yet, we'll check expression + result = self.evalexpr(args) + if result: + enable = True + iftrigger = True + else: + self.error(self.source,dirtokens[0].lineno,"Misplaced #elif") + + elif name == 'else': + if ifstack: + if ifstack[-1][0]: + if enable: + enable = False + elif not iftrigger: + enable = True + iftrigger = True + else: + self.error(self.source,dirtokens[0].lineno,"Misplaced #else") + + elif name == 'endif': + if ifstack: + enable,iftrigger = ifstack.pop() + else: + self.error(self.source,dirtokens[0].lineno,"Misplaced #endif") + else: + # Unknown preprocessor directive + pass + + else: + # Normal text + if enable: + chunk.extend(x) + + for tok in self.expand_macros(chunk): + yield tok + chunk = [] + + # ---------------------------------------------------------------------- + # include() + # + # Implementation of file-inclusion + # ---------------------------------------------------------------------- + + def include(self,tokens): + # Try to extract the filename and then process an include file + if not tokens: + return + if tokens: + if tokens[0].value != '<' and tokens[0].type != self.t_STRING: + tokens = self.expand_macros(tokens) + + if tokens[0].value == '<': + # Include <...> + i = 1 + while i < len(tokens): + if tokens[i].value == '>': + break + i += 1 + else: + print("Malformed #include <...>") + return + filename = "".join([x.value for x in tokens[1:i]]) + path = self.path + [""] + self.temp_path + elif tokens[0].type == self.t_STRING: + filename = tokens[0].value[1:-1] + path = self.temp_path + [""] + self.path + else: + print("Malformed #include statement") + return + for p in path: + iname = os.path.join(p,filename) + try: + data = open(iname,"r").read() + dname = os.path.dirname(iname) + if dname: + self.temp_path.insert(0,dname) + for tok in self.parsegen(data,filename): + yield tok + if dname: + del self.temp_path[0] + break + except IOError: + pass + else: + print("Couldn't find '%s'" % filename) + + # ---------------------------------------------------------------------- + # define() + # + # Define a new macro + # ---------------------------------------------------------------------- + + def define(self,tokens): + if isinstance(tokens,STRING_TYPES): + tokens = self.tokenize(tokens) + + linetok = tokens + try: + name = linetok[0] + if len(linetok) > 1: + mtype = linetok[1] + else: + mtype = None + if not mtype: + m = Macro(name.value,[]) + self.macros[name.value] = m + elif mtype.type in self.t_WS: + # A normal macro + m = Macro(name.value,self.tokenstrip(linetok[2:])) + self.macros[name.value] = m + elif mtype.value == '(': + # A macro with arguments + tokcount, args, positions = self.collect_args(linetok[1:]) + variadic = False + for a in args: + if variadic: + print("No more arguments may follow a variadic argument") + break + astr = "".join([str(_i.value) for _i in a]) + if astr == "...": + variadic = True + a[0].type = self.t_ID + a[0].value = '__VA_ARGS__' + variadic = True + del a[1:] + continue + elif astr[-3:] == "..." and a[0].type == self.t_ID: + variadic = True + del a[1:] + # If, for some reason, "." is part of the identifier, strip off the name for the purposes + # of macro expansion + if a[0].value[-3:] == '...': + a[0].value = a[0].value[:-3] + continue + if len(a) > 1 or a[0].type != self.t_ID: + print("Invalid macro argument") + break + else: + mvalue = self.tokenstrip(linetok[1+tokcount:]) + i = 0 + while i < len(mvalue): + if i+1 < len(mvalue): + if mvalue[i].type in self.t_WS and mvalue[i+1].value == '##': + del mvalue[i] + continue + elif mvalue[i].value == '##' and mvalue[i+1].type in self.t_WS: + del mvalue[i+1] + i += 1 + m = Macro(name.value,mvalue,[x[0].value for x in args],variadic) + self.macro_prescan(m) + self.macros[name.value] = m + else: + print("Bad macro definition") + except LookupError: + print("Bad macro definition") + + # ---------------------------------------------------------------------- + # undef() + # + # Undefine a macro + # ---------------------------------------------------------------------- + + def undef(self,tokens): + id = tokens[0].value + try: + del self.macros[id] + except LookupError: + pass + + # ---------------------------------------------------------------------- + # parse() + # + # Parse input text. + # ---------------------------------------------------------------------- + def parse(self,input,source=None,ignore={}): + self.ignore = ignore + self.parser = self.parsegen(input,source) + + # ---------------------------------------------------------------------- + # token() + # + # Method to return individual tokens + # ---------------------------------------------------------------------- + def token(self): + try: + while True: + tok = next(self.parser) + if tok.type not in self.ignore: return tok + except StopIteration: + self.parser = None + return None + +if __name__ == '__main__': + import ply.lex as lex + lexer = lex.lex() + + # Run a preprocessor + import sys + f = open(sys.argv[1]) + input = f.read() + + p = Preprocessor(lexer) + p.parse(input,sys.argv[1]) + while True: + tok = p.token() + if not tok: break + print(p.source, tok) diff --git a/.venv/Lib/site-packages/pycparser/ply/ctokens.py b/.venv/Lib/site-packages/pycparser/ply/ctokens.py new file mode 100644 index 00000000..f6f6952d --- /dev/null +++ b/.venv/Lib/site-packages/pycparser/ply/ctokens.py @@ -0,0 +1,133 @@ +# ---------------------------------------------------------------------- +# ctokens.py +# +# Token specifications for symbols in ANSI C and C++. This file is +# meant to be used as a library in other tokenizers. +# ---------------------------------------------------------------------- + +# Reserved words + +tokens = [ + # Literals (identifier, integer constant, float constant, string constant, char const) + 'ID', 'TYPEID', 'INTEGER', 'FLOAT', 'STRING', 'CHARACTER', + + # Operators (+,-,*,/,%,|,&,~,^,<<,>>, ||, &&, !, <, <=, >, >=, ==, !=) + 'PLUS', 'MINUS', 'TIMES', 'DIVIDE', 'MODULO', + 'OR', 'AND', 'NOT', 'XOR', 'LSHIFT', 'RSHIFT', + 'LOR', 'LAND', 'LNOT', + 'LT', 'LE', 'GT', 'GE', 'EQ', 'NE', + + # Assignment (=, *=, /=, %=, +=, -=, <<=, >>=, &=, ^=, |=) + 'EQUALS', 'TIMESEQUAL', 'DIVEQUAL', 'MODEQUAL', 'PLUSEQUAL', 'MINUSEQUAL', + 'LSHIFTEQUAL','RSHIFTEQUAL', 'ANDEQUAL', 'XOREQUAL', 'OREQUAL', + + # Increment/decrement (++,--) + 'INCREMENT', 'DECREMENT', + + # Structure dereference (->) + 'ARROW', + + # Ternary operator (?) + 'TERNARY', + + # Delimeters ( ) [ ] { } , . ; : + 'LPAREN', 'RPAREN', + 'LBRACKET', 'RBRACKET', + 'LBRACE', 'RBRACE', + 'COMMA', 'PERIOD', 'SEMI', 'COLON', + + # Ellipsis (...) + 'ELLIPSIS', +] + +# Operators +t_PLUS = r'\+' +t_MINUS = r'-' +t_TIMES = r'\*' +t_DIVIDE = r'/' +t_MODULO = r'%' +t_OR = r'\|' +t_AND = r'&' +t_NOT = r'~' +t_XOR = r'\^' +t_LSHIFT = r'<<' +t_RSHIFT = r'>>' +t_LOR = r'\|\|' +t_LAND = r'&&' +t_LNOT = r'!' +t_LT = r'<' +t_GT = r'>' +t_LE = r'<=' +t_GE = r'>=' +t_EQ = r'==' +t_NE = r'!=' + +# Assignment operators + +t_EQUALS = r'=' +t_TIMESEQUAL = r'\*=' +t_DIVEQUAL = r'/=' +t_MODEQUAL = r'%=' +t_PLUSEQUAL = r'\+=' +t_MINUSEQUAL = r'-=' +t_LSHIFTEQUAL = r'<<=' +t_RSHIFTEQUAL = r'>>=' +t_ANDEQUAL = r'&=' +t_OREQUAL = r'\|=' +t_XOREQUAL = r'\^=' + +# Increment/decrement +t_INCREMENT = r'\+\+' +t_DECREMENT = r'--' + +# -> +t_ARROW = r'->' + +# ? +t_TERNARY = r'\?' + +# Delimeters +t_LPAREN = r'\(' +t_RPAREN = r'\)' +t_LBRACKET = r'\[' +t_RBRACKET = r'\]' +t_LBRACE = r'\{' +t_RBRACE = r'\}' +t_COMMA = r',' +t_PERIOD = r'\.' +t_SEMI = r';' +t_COLON = r':' +t_ELLIPSIS = r'\.\.\.' + +# Identifiers +t_ID = r'[A-Za-z_][A-Za-z0-9_]*' + +# Integer literal +t_INTEGER = r'\d+([uU]|[lL]|[uU][lL]|[lL][uU])?' + +# Floating literal +t_FLOAT = r'((\d+)(\.\d+)(e(\+|-)?(\d+))? | (\d+)e(\+|-)?(\d+))([lL]|[fF])?' + +# String literal +t_STRING = r'\"([^\\\n]|(\\.))*?\"' + +# Character constant 'c' or L'c' +t_CHARACTER = r'(L)?\'([^\\\n]|(\\.))*?\'' + +# Comment (C-Style) +def t_COMMENT(t): + r'/\*(.|\n)*?\*/' + t.lexer.lineno += t.value.count('\n') + return t + +# Comment (C++-Style) +def t_CPPCOMMENT(t): + r'//.*\n' + t.lexer.lineno += 1 + return t + + + + + + diff --git a/.venv/Lib/site-packages/pycparser/ply/lex.py b/.venv/Lib/site-packages/pycparser/ply/lex.py new file mode 100644 index 00000000..4bdd76ca --- /dev/null +++ b/.venv/Lib/site-packages/pycparser/ply/lex.py @@ -0,0 +1,1099 @@ +# ----------------------------------------------------------------------------- +# ply: lex.py +# +# Copyright (C) 2001-2017 +# David M. Beazley (Dabeaz LLC) +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# * Neither the name of the David Beazley or Dabeaz LLC may be used to +# endorse or promote products derived from this software without +# specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# ----------------------------------------------------------------------------- + +__version__ = '3.10' +__tabversion__ = '3.10' + +import re +import sys +import types +import copy +import os +import inspect + +# This tuple contains known string types +try: + # Python 2.6 + StringTypes = (types.StringType, types.UnicodeType) +except AttributeError: + # Python 3.0 + StringTypes = (str, bytes) + +# This regular expression is used to match valid token names +_is_identifier = re.compile(r'^[a-zA-Z0-9_]+$') + +# Exception thrown when invalid token encountered and no default error +# handler is defined. +class LexError(Exception): + def __init__(self, message, s): + self.args = (message,) + self.text = s + + +# Token class. This class is used to represent the tokens produced. +class LexToken(object): + def __str__(self): + return 'LexToken(%s,%r,%d,%d)' % (self.type, self.value, self.lineno, self.lexpos) + + def __repr__(self): + return str(self) + + +# This object is a stand-in for a logging object created by the +# logging module. + +class PlyLogger(object): + def __init__(self, f): + self.f = f + + def critical(self, msg, *args, **kwargs): + self.f.write((msg % args) + '\n') + + def warning(self, msg, *args, **kwargs): + self.f.write('WARNING: ' + (msg % args) + '\n') + + def error(self, msg, *args, **kwargs): + self.f.write('ERROR: ' + (msg % args) + '\n') + + info = critical + debug = critical + + +# Null logger is used when no output is generated. Does nothing. +class NullLogger(object): + def __getattribute__(self, name): + return self + + def __call__(self, *args, **kwargs): + return self + + +# ----------------------------------------------------------------------------- +# === Lexing Engine === +# +# The following Lexer class implements the lexer runtime. There are only +# a few public methods and attributes: +# +# input() - Store a new string in the lexer +# token() - Get the next token +# clone() - Clone the lexer +# +# lineno - Current line number +# lexpos - Current position in the input string +# ----------------------------------------------------------------------------- + +class Lexer: + def __init__(self): + self.lexre = None # Master regular expression. This is a list of + # tuples (re, findex) where re is a compiled + # regular expression and findex is a list + # mapping regex group numbers to rules + self.lexretext = None # Current regular expression strings + self.lexstatere = {} # Dictionary mapping lexer states to master regexs + self.lexstateretext = {} # Dictionary mapping lexer states to regex strings + self.lexstaterenames = {} # Dictionary mapping lexer states to symbol names + self.lexstate = 'INITIAL' # Current lexer state + self.lexstatestack = [] # Stack of lexer states + self.lexstateinfo = None # State information + self.lexstateignore = {} # Dictionary of ignored characters for each state + self.lexstateerrorf = {} # Dictionary of error functions for each state + self.lexstateeoff = {} # Dictionary of eof functions for each state + self.lexreflags = 0 # Optional re compile flags + self.lexdata = None # Actual input data (as a string) + self.lexpos = 0 # Current position in input text + self.lexlen = 0 # Length of the input text + self.lexerrorf = None # Error rule (if any) + self.lexeoff = None # EOF rule (if any) + self.lextokens = None # List of valid tokens + self.lexignore = '' # Ignored characters + self.lexliterals = '' # Literal characters that can be passed through + self.lexmodule = None # Module + self.lineno = 1 # Current line number + self.lexoptimize = False # Optimized mode + + def clone(self, object=None): + c = copy.copy(self) + + # If the object parameter has been supplied, it means we are attaching the + # lexer to a new object. In this case, we have to rebind all methods in + # the lexstatere and lexstateerrorf tables. + + if object: + newtab = {} + for key, ritem in self.lexstatere.items(): + newre = [] + for cre, findex in ritem: + newfindex = [] + for f in findex: + if not f or not f[0]: + newfindex.append(f) + continue + newfindex.append((getattr(object, f[0].__name__), f[1])) + newre.append((cre, newfindex)) + newtab[key] = newre + c.lexstatere = newtab + c.lexstateerrorf = {} + for key, ef in self.lexstateerrorf.items(): + c.lexstateerrorf[key] = getattr(object, ef.__name__) + c.lexmodule = object + return c + + # ------------------------------------------------------------ + # writetab() - Write lexer information to a table file + # ------------------------------------------------------------ + def writetab(self, lextab, outputdir=''): + if isinstance(lextab, types.ModuleType): + raise IOError("Won't overwrite existing lextab module") + basetabmodule = lextab.split('.')[-1] + filename = os.path.join(outputdir, basetabmodule) + '.py' + with open(filename, 'w') as tf: + tf.write('# %s.py. This file automatically created by PLY (version %s). Don\'t edit!\n' % (basetabmodule, __version__)) + tf.write('_tabversion = %s\n' % repr(__tabversion__)) + tf.write('_lextokens = set(%s)\n' % repr(tuple(self.lextokens))) + tf.write('_lexreflags = %s\n' % repr(self.lexreflags)) + tf.write('_lexliterals = %s\n' % repr(self.lexliterals)) + tf.write('_lexstateinfo = %s\n' % repr(self.lexstateinfo)) + + # Rewrite the lexstatere table, replacing function objects with function names + tabre = {} + for statename, lre in self.lexstatere.items(): + titem = [] + for (pat, func), retext, renames in zip(lre, self.lexstateretext[statename], self.lexstaterenames[statename]): + titem.append((retext, _funcs_to_names(func, renames))) + tabre[statename] = titem + + tf.write('_lexstatere = %s\n' % repr(tabre)) + tf.write('_lexstateignore = %s\n' % repr(self.lexstateignore)) + + taberr = {} + for statename, ef in self.lexstateerrorf.items(): + taberr[statename] = ef.__name__ if ef else None + tf.write('_lexstateerrorf = %s\n' % repr(taberr)) + + tabeof = {} + for statename, ef in self.lexstateeoff.items(): + tabeof[statename] = ef.__name__ if ef else None + tf.write('_lexstateeoff = %s\n' % repr(tabeof)) + + # ------------------------------------------------------------ + # readtab() - Read lexer information from a tab file + # ------------------------------------------------------------ + def readtab(self, tabfile, fdict): + if isinstance(tabfile, types.ModuleType): + lextab = tabfile + else: + exec('import %s' % tabfile) + lextab = sys.modules[tabfile] + + if getattr(lextab, '_tabversion', '0.0') != __tabversion__: + raise ImportError('Inconsistent PLY version') + + self.lextokens = lextab._lextokens + self.lexreflags = lextab._lexreflags + self.lexliterals = lextab._lexliterals + self.lextokens_all = self.lextokens | set(self.lexliterals) + self.lexstateinfo = lextab._lexstateinfo + self.lexstateignore = lextab._lexstateignore + self.lexstatere = {} + self.lexstateretext = {} + for statename, lre in lextab._lexstatere.items(): + titem = [] + txtitem = [] + for pat, func_name in lre: + titem.append((re.compile(pat, lextab._lexreflags), _names_to_funcs(func_name, fdict))) + + self.lexstatere[statename] = titem + self.lexstateretext[statename] = txtitem + + self.lexstateerrorf = {} + for statename, ef in lextab._lexstateerrorf.items(): + self.lexstateerrorf[statename] = fdict[ef] + + self.lexstateeoff = {} + for statename, ef in lextab._lexstateeoff.items(): + self.lexstateeoff[statename] = fdict[ef] + + self.begin('INITIAL') + + # ------------------------------------------------------------ + # input() - Push a new string into the lexer + # ------------------------------------------------------------ + def input(self, s): + # Pull off the first character to see if s looks like a string + c = s[:1] + if not isinstance(c, StringTypes): + raise ValueError('Expected a string') + self.lexdata = s + self.lexpos = 0 + self.lexlen = len(s) + + # ------------------------------------------------------------ + # begin() - Changes the lexing state + # ------------------------------------------------------------ + def begin(self, state): + if state not in self.lexstatere: + raise ValueError('Undefined state') + self.lexre = self.lexstatere[state] + self.lexretext = self.lexstateretext[state] + self.lexignore = self.lexstateignore.get(state, '') + self.lexerrorf = self.lexstateerrorf.get(state, None) + self.lexeoff = self.lexstateeoff.get(state, None) + self.lexstate = state + + # ------------------------------------------------------------ + # push_state() - Changes the lexing state and saves old on stack + # ------------------------------------------------------------ + def push_state(self, state): + self.lexstatestack.append(self.lexstate) + self.begin(state) + + # ------------------------------------------------------------ + # pop_state() - Restores the previous state + # ------------------------------------------------------------ + def pop_state(self): + self.begin(self.lexstatestack.pop()) + + # ------------------------------------------------------------ + # current_state() - Returns the current lexing state + # ------------------------------------------------------------ + def current_state(self): + return self.lexstate + + # ------------------------------------------------------------ + # skip() - Skip ahead n characters + # ------------------------------------------------------------ + def skip(self, n): + self.lexpos += n + + # ------------------------------------------------------------ + # opttoken() - Return the next token from the Lexer + # + # Note: This function has been carefully implemented to be as fast + # as possible. Don't make changes unless you really know what + # you are doing + # ------------------------------------------------------------ + def token(self): + # Make local copies of frequently referenced attributes + lexpos = self.lexpos + lexlen = self.lexlen + lexignore = self.lexignore + lexdata = self.lexdata + + while lexpos < lexlen: + # This code provides some short-circuit code for whitespace, tabs, and other ignored characters + if lexdata[lexpos] in lexignore: + lexpos += 1 + continue + + # Look for a regular expression match + for lexre, lexindexfunc in self.lexre: + m = lexre.match(lexdata, lexpos) + if not m: + continue + + # Create a token for return + tok = LexToken() + tok.value = m.group() + tok.lineno = self.lineno + tok.lexpos = lexpos + + i = m.lastindex + func, tok.type = lexindexfunc[i] + + if not func: + # If no token type was set, it's an ignored token + if tok.type: + self.lexpos = m.end() + return tok + else: + lexpos = m.end() + break + + lexpos = m.end() + + # If token is processed by a function, call it + + tok.lexer = self # Set additional attributes useful in token rules + self.lexmatch = m + self.lexpos = lexpos + + newtok = func(tok) + + # Every function must return a token, if nothing, we just move to next token + if not newtok: + lexpos = self.lexpos # This is here in case user has updated lexpos. + lexignore = self.lexignore # This is here in case there was a state change + break + + # Verify type of the token. If not in the token map, raise an error + if not self.lexoptimize: + if newtok.type not in self.lextokens_all: + raise LexError("%s:%d: Rule '%s' returned an unknown token type '%s'" % ( + func.__code__.co_filename, func.__code__.co_firstlineno, + func.__name__, newtok.type), lexdata[lexpos:]) + + return newtok + else: + # No match, see if in literals + if lexdata[lexpos] in self.lexliterals: + tok = LexToken() + tok.value = lexdata[lexpos] + tok.lineno = self.lineno + tok.type = tok.value + tok.lexpos = lexpos + self.lexpos = lexpos + 1 + return tok + + # No match. Call t_error() if defined. + if self.lexerrorf: + tok = LexToken() + tok.value = self.lexdata[lexpos:] + tok.lineno = self.lineno + tok.type = 'error' + tok.lexer = self + tok.lexpos = lexpos + self.lexpos = lexpos + newtok = self.lexerrorf(tok) + if lexpos == self.lexpos: + # Error method didn't change text position at all. This is an error. + raise LexError("Scanning error. Illegal character '%s'" % (lexdata[lexpos]), lexdata[lexpos:]) + lexpos = self.lexpos + if not newtok: + continue + return newtok + + self.lexpos = lexpos + raise LexError("Illegal character '%s' at index %d" % (lexdata[lexpos], lexpos), lexdata[lexpos:]) + + if self.lexeoff: + tok = LexToken() + tok.type = 'eof' + tok.value = '' + tok.lineno = self.lineno + tok.lexpos = lexpos + tok.lexer = self + self.lexpos = lexpos + newtok = self.lexeoff(tok) + return newtok + + self.lexpos = lexpos + 1 + if self.lexdata is None: + raise RuntimeError('No input string given with input()') + return None + + # Iterator interface + def __iter__(self): + return self + + def next(self): + t = self.token() + if t is None: + raise StopIteration + return t + + __next__ = next + +# ----------------------------------------------------------------------------- +# ==== Lex Builder === +# +# The functions and classes below are used to collect lexing information +# and build a Lexer object from it. +# ----------------------------------------------------------------------------- + +# ----------------------------------------------------------------------------- +# _get_regex(func) +# +# Returns the regular expression assigned to a function either as a doc string +# or as a .regex attribute attached by the @TOKEN decorator. +# ----------------------------------------------------------------------------- +def _get_regex(func): + return getattr(func, 'regex', func.__doc__) + +# ----------------------------------------------------------------------------- +# get_caller_module_dict() +# +# This function returns a dictionary containing all of the symbols defined within +# a caller further down the call stack. This is used to get the environment +# associated with the yacc() call if none was provided. +# ----------------------------------------------------------------------------- +def get_caller_module_dict(levels): + f = sys._getframe(levels) + ldict = f.f_globals.copy() + if f.f_globals != f.f_locals: + ldict.update(f.f_locals) + return ldict + +# ----------------------------------------------------------------------------- +# _funcs_to_names() +# +# Given a list of regular expression functions, this converts it to a list +# suitable for output to a table file +# ----------------------------------------------------------------------------- +def _funcs_to_names(funclist, namelist): + result = [] + for f, name in zip(funclist, namelist): + if f and f[0]: + result.append((name, f[1])) + else: + result.append(f) + return result + +# ----------------------------------------------------------------------------- +# _names_to_funcs() +# +# Given a list of regular expression function names, this converts it back to +# functions. +# ----------------------------------------------------------------------------- +def _names_to_funcs(namelist, fdict): + result = [] + for n in namelist: + if n and n[0]: + result.append((fdict[n[0]], n[1])) + else: + result.append(n) + return result + +# ----------------------------------------------------------------------------- +# _form_master_re() +# +# This function takes a list of all of the regex components and attempts to +# form the master regular expression. Given limitations in the Python re +# module, it may be necessary to break the master regex into separate expressions. +# ----------------------------------------------------------------------------- +def _form_master_re(relist, reflags, ldict, toknames): + if not relist: + return [] + regex = '|'.join(relist) + try: + lexre = re.compile(regex, reflags) + + # Build the index to function map for the matching engine + lexindexfunc = [None] * (max(lexre.groupindex.values()) + 1) + lexindexnames = lexindexfunc[:] + + for f, i in lexre.groupindex.items(): + handle = ldict.get(f, None) + if type(handle) in (types.FunctionType, types.MethodType): + lexindexfunc[i] = (handle, toknames[f]) + lexindexnames[i] = f + elif handle is not None: + lexindexnames[i] = f + if f.find('ignore_') > 0: + lexindexfunc[i] = (None, None) + else: + lexindexfunc[i] = (None, toknames[f]) + + return [(lexre, lexindexfunc)], [regex], [lexindexnames] + except Exception: + m = int(len(relist)/2) + if m == 0: + m = 1 + llist, lre, lnames = _form_master_re(relist[:m], reflags, ldict, toknames) + rlist, rre, rnames = _form_master_re(relist[m:], reflags, ldict, toknames) + return (llist+rlist), (lre+rre), (lnames+rnames) + +# ----------------------------------------------------------------------------- +# def _statetoken(s,names) +# +# Given a declaration name s of the form "t_" and a dictionary whose keys are +# state names, this function returns a tuple (states,tokenname) where states +# is a tuple of state names and tokenname is the name of the token. For example, +# calling this with s = "t_foo_bar_SPAM" might return (('foo','bar'),'SPAM') +# ----------------------------------------------------------------------------- +def _statetoken(s, names): + nonstate = 1 + parts = s.split('_') + for i, part in enumerate(parts[1:], 1): + if part not in names and part != 'ANY': + break + + if i > 1: + states = tuple(parts[1:i]) + else: + states = ('INITIAL',) + + if 'ANY' in states: + states = tuple(names) + + tokenname = '_'.join(parts[i:]) + return (states, tokenname) + + +# ----------------------------------------------------------------------------- +# LexerReflect() +# +# This class represents information needed to build a lexer as extracted from a +# user's input file. +# ----------------------------------------------------------------------------- +class LexerReflect(object): + def __init__(self, ldict, log=None, reflags=0): + self.ldict = ldict + self.error_func = None + self.tokens = [] + self.reflags = reflags + self.stateinfo = {'INITIAL': 'inclusive'} + self.modules = set() + self.error = False + self.log = PlyLogger(sys.stderr) if log is None else log + + # Get all of the basic information + def get_all(self): + self.get_tokens() + self.get_literals() + self.get_states() + self.get_rules() + + # Validate all of the information + def validate_all(self): + self.validate_tokens() + self.validate_literals() + self.validate_rules() + return self.error + + # Get the tokens map + def get_tokens(self): + tokens = self.ldict.get('tokens', None) + if not tokens: + self.log.error('No token list is defined') + self.error = True + return + + if not isinstance(tokens, (list, tuple)): + self.log.error('tokens must be a list or tuple') + self.error = True + return + + if not tokens: + self.log.error('tokens is empty') + self.error = True + return + + self.tokens = tokens + + # Validate the tokens + def validate_tokens(self): + terminals = {} + for n in self.tokens: + if not _is_identifier.match(n): + self.log.error("Bad token name '%s'", n) + self.error = True + if n in terminals: + self.log.warning("Token '%s' multiply defined", n) + terminals[n] = 1 + + # Get the literals specifier + def get_literals(self): + self.literals = self.ldict.get('literals', '') + if not self.literals: + self.literals = '' + + # Validate literals + def validate_literals(self): + try: + for c in self.literals: + if not isinstance(c, StringTypes) or len(c) > 1: + self.log.error('Invalid literal %s. Must be a single character', repr(c)) + self.error = True + + except TypeError: + self.log.error('Invalid literals specification. literals must be a sequence of characters') + self.error = True + + def get_states(self): + self.states = self.ldict.get('states', None) + # Build statemap + if self.states: + if not isinstance(self.states, (tuple, list)): + self.log.error('states must be defined as a tuple or list') + self.error = True + else: + for s in self.states: + if not isinstance(s, tuple) or len(s) != 2: + self.log.error("Invalid state specifier %s. Must be a tuple (statename,'exclusive|inclusive')", repr(s)) + self.error = True + continue + name, statetype = s + if not isinstance(name, StringTypes): + self.log.error('State name %s must be a string', repr(name)) + self.error = True + continue + if not (statetype == 'inclusive' or statetype == 'exclusive'): + self.log.error("State type for state %s must be 'inclusive' or 'exclusive'", name) + self.error = True + continue + if name in self.stateinfo: + self.log.error("State '%s' already defined", name) + self.error = True + continue + self.stateinfo[name] = statetype + + # Get all of the symbols with a t_ prefix and sort them into various + # categories (functions, strings, error functions, and ignore characters) + + def get_rules(self): + tsymbols = [f for f in self.ldict if f[:2] == 't_'] + + # Now build up a list of functions and a list of strings + self.toknames = {} # Mapping of symbols to token names + self.funcsym = {} # Symbols defined as functions + self.strsym = {} # Symbols defined as strings + self.ignore = {} # Ignore strings by state + self.errorf = {} # Error functions by state + self.eoff = {} # EOF functions by state + + for s in self.stateinfo: + self.funcsym[s] = [] + self.strsym[s] = [] + + if len(tsymbols) == 0: + self.log.error('No rules of the form t_rulename are defined') + self.error = True + return + + for f in tsymbols: + t = self.ldict[f] + states, tokname = _statetoken(f, self.stateinfo) + self.toknames[f] = tokname + + if hasattr(t, '__call__'): + if tokname == 'error': + for s in states: + self.errorf[s] = t + elif tokname == 'eof': + for s in states: + self.eoff[s] = t + elif tokname == 'ignore': + line = t.__code__.co_firstlineno + file = t.__code__.co_filename + self.log.error("%s:%d: Rule '%s' must be defined as a string", file, line, t.__name__) + self.error = True + else: + for s in states: + self.funcsym[s].append((f, t)) + elif isinstance(t, StringTypes): + if tokname == 'ignore': + for s in states: + self.ignore[s] = t + if '\\' in t: + self.log.warning("%s contains a literal backslash '\\'", f) + + elif tokname == 'error': + self.log.error("Rule '%s' must be defined as a function", f) + self.error = True + else: + for s in states: + self.strsym[s].append((f, t)) + else: + self.log.error('%s not defined as a function or string', f) + self.error = True + + # Sort the functions by line number + for f in self.funcsym.values(): + f.sort(key=lambda x: x[1].__code__.co_firstlineno) + + # Sort the strings by regular expression length + for s in self.strsym.values(): + s.sort(key=lambda x: len(x[1]), reverse=True) + + # Validate all of the t_rules collected + def validate_rules(self): + for state in self.stateinfo: + # Validate all rules defined by functions + + for fname, f in self.funcsym[state]: + line = f.__code__.co_firstlineno + file = f.__code__.co_filename + module = inspect.getmodule(f) + self.modules.add(module) + + tokname = self.toknames[fname] + if isinstance(f, types.MethodType): + reqargs = 2 + else: + reqargs = 1 + nargs = f.__code__.co_argcount + if nargs > reqargs: + self.log.error("%s:%d: Rule '%s' has too many arguments", file, line, f.__name__) + self.error = True + continue + + if nargs < reqargs: + self.log.error("%s:%d: Rule '%s' requires an argument", file, line, f.__name__) + self.error = True + continue + + if not _get_regex(f): + self.log.error("%s:%d: No regular expression defined for rule '%s'", file, line, f.__name__) + self.error = True + continue + + try: + c = re.compile('(?P<%s>%s)' % (fname, _get_regex(f)), self.reflags) + if c.match(''): + self.log.error("%s:%d: Regular expression for rule '%s' matches empty string", file, line, f.__name__) + self.error = True + except re.error as e: + self.log.error("%s:%d: Invalid regular expression for rule '%s'. %s", file, line, f.__name__, e) + if '#' in _get_regex(f): + self.log.error("%s:%d. Make sure '#' in rule '%s' is escaped with '\\#'", file, line, f.__name__) + self.error = True + + # Validate all rules defined by strings + for name, r in self.strsym[state]: + tokname = self.toknames[name] + if tokname == 'error': + self.log.error("Rule '%s' must be defined as a function", name) + self.error = True + continue + + if tokname not in self.tokens and tokname.find('ignore_') < 0: + self.log.error("Rule '%s' defined for an unspecified token %s", name, tokname) + self.error = True + continue + + try: + c = re.compile('(?P<%s>%s)' % (name, r), self.reflags) + if (c.match('')): + self.log.error("Regular expression for rule '%s' matches empty string", name) + self.error = True + except re.error as e: + self.log.error("Invalid regular expression for rule '%s'. %s", name, e) + if '#' in r: + self.log.error("Make sure '#' in rule '%s' is escaped with '\\#'", name) + self.error = True + + if not self.funcsym[state] and not self.strsym[state]: + self.log.error("No rules defined for state '%s'", state) + self.error = True + + # Validate the error function + efunc = self.errorf.get(state, None) + if efunc: + f = efunc + line = f.__code__.co_firstlineno + file = f.__code__.co_filename + module = inspect.getmodule(f) + self.modules.add(module) + + if isinstance(f, types.MethodType): + reqargs = 2 + else: + reqargs = 1 + nargs = f.__code__.co_argcount + if nargs > reqargs: + self.log.error("%s:%d: Rule '%s' has too many arguments", file, line, f.__name__) + self.error = True + + if nargs < reqargs: + self.log.error("%s:%d: Rule '%s' requires an argument", file, line, f.__name__) + self.error = True + + for module in self.modules: + self.validate_module(module) + + # ----------------------------------------------------------------------------- + # validate_module() + # + # This checks to see if there are duplicated t_rulename() functions or strings + # in the parser input file. This is done using a simple regular expression + # match on each line in the source code of the given module. + # ----------------------------------------------------------------------------- + + def validate_module(self, module): + try: + lines, linen = inspect.getsourcelines(module) + except IOError: + return + + fre = re.compile(r'\s*def\s+(t_[a-zA-Z_0-9]*)\(') + sre = re.compile(r'\s*(t_[a-zA-Z_0-9]*)\s*=') + + counthash = {} + linen += 1 + for line in lines: + m = fre.match(line) + if not m: + m = sre.match(line) + if m: + name = m.group(1) + prev = counthash.get(name) + if not prev: + counthash[name] = linen + else: + filename = inspect.getsourcefile(module) + self.log.error('%s:%d: Rule %s redefined. Previously defined on line %d', filename, linen, name, prev) + self.error = True + linen += 1 + +# ----------------------------------------------------------------------------- +# lex(module) +# +# Build all of the regular expression rules from definitions in the supplied module +# ----------------------------------------------------------------------------- +def lex(module=None, object=None, debug=False, optimize=False, lextab='lextab', + reflags=int(re.VERBOSE), nowarn=False, outputdir=None, debuglog=None, errorlog=None): + + if lextab is None: + lextab = 'lextab' + + global lexer + + ldict = None + stateinfo = {'INITIAL': 'inclusive'} + lexobj = Lexer() + lexobj.lexoptimize = optimize + global token, input + + if errorlog is None: + errorlog = PlyLogger(sys.stderr) + + if debug: + if debuglog is None: + debuglog = PlyLogger(sys.stderr) + + # Get the module dictionary used for the lexer + if object: + module = object + + # Get the module dictionary used for the parser + if module: + _items = [(k, getattr(module, k)) for k in dir(module)] + ldict = dict(_items) + # If no __file__ attribute is available, try to obtain it from the __module__ instead + if '__file__' not in ldict: + ldict['__file__'] = sys.modules[ldict['__module__']].__file__ + else: + ldict = get_caller_module_dict(2) + + # Determine if the module is package of a package or not. + # If so, fix the tabmodule setting so that tables load correctly + pkg = ldict.get('__package__') + if pkg and isinstance(lextab, str): + if '.' not in lextab: + lextab = pkg + '.' + lextab + + # Collect parser information from the dictionary + linfo = LexerReflect(ldict, log=errorlog, reflags=reflags) + linfo.get_all() + if not optimize: + if linfo.validate_all(): + raise SyntaxError("Can't build lexer") + + if optimize and lextab: + try: + lexobj.readtab(lextab, ldict) + token = lexobj.token + input = lexobj.input + lexer = lexobj + return lexobj + + except ImportError: + pass + + # Dump some basic debugging information + if debug: + debuglog.info('lex: tokens = %r', linfo.tokens) + debuglog.info('lex: literals = %r', linfo.literals) + debuglog.info('lex: states = %r', linfo.stateinfo) + + # Build a dictionary of valid token names + lexobj.lextokens = set() + for n in linfo.tokens: + lexobj.lextokens.add(n) + + # Get literals specification + if isinstance(linfo.literals, (list, tuple)): + lexobj.lexliterals = type(linfo.literals[0])().join(linfo.literals) + else: + lexobj.lexliterals = linfo.literals + + lexobj.lextokens_all = lexobj.lextokens | set(lexobj.lexliterals) + + # Get the stateinfo dictionary + stateinfo = linfo.stateinfo + + regexs = {} + # Build the master regular expressions + for state in stateinfo: + regex_list = [] + + # Add rules defined by functions first + for fname, f in linfo.funcsym[state]: + line = f.__code__.co_firstlineno + file = f.__code__.co_filename + regex_list.append('(?P<%s>%s)' % (fname, _get_regex(f))) + if debug: + debuglog.info("lex: Adding rule %s -> '%s' (state '%s')", fname, _get_regex(f), state) + + # Now add all of the simple rules + for name, r in linfo.strsym[state]: + regex_list.append('(?P<%s>%s)' % (name, r)) + if debug: + debuglog.info("lex: Adding rule %s -> '%s' (state '%s')", name, r, state) + + regexs[state] = regex_list + + # Build the master regular expressions + + if debug: + debuglog.info('lex: ==== MASTER REGEXS FOLLOW ====') + + for state in regexs: + lexre, re_text, re_names = _form_master_re(regexs[state], reflags, ldict, linfo.toknames) + lexobj.lexstatere[state] = lexre + lexobj.lexstateretext[state] = re_text + lexobj.lexstaterenames[state] = re_names + if debug: + for i, text in enumerate(re_text): + debuglog.info("lex: state '%s' : regex[%d] = '%s'", state, i, text) + + # For inclusive states, we need to add the regular expressions from the INITIAL state + for state, stype in stateinfo.items(): + if state != 'INITIAL' and stype == 'inclusive': + lexobj.lexstatere[state].extend(lexobj.lexstatere['INITIAL']) + lexobj.lexstateretext[state].extend(lexobj.lexstateretext['INITIAL']) + lexobj.lexstaterenames[state].extend(lexobj.lexstaterenames['INITIAL']) + + lexobj.lexstateinfo = stateinfo + lexobj.lexre = lexobj.lexstatere['INITIAL'] + lexobj.lexretext = lexobj.lexstateretext['INITIAL'] + lexobj.lexreflags = reflags + + # Set up ignore variables + lexobj.lexstateignore = linfo.ignore + lexobj.lexignore = lexobj.lexstateignore.get('INITIAL', '') + + # Set up error functions + lexobj.lexstateerrorf = linfo.errorf + lexobj.lexerrorf = linfo.errorf.get('INITIAL', None) + if not lexobj.lexerrorf: + errorlog.warning('No t_error rule is defined') + + # Set up eof functions + lexobj.lexstateeoff = linfo.eoff + lexobj.lexeoff = linfo.eoff.get('INITIAL', None) + + # Check state information for ignore and error rules + for s, stype in stateinfo.items(): + if stype == 'exclusive': + if s not in linfo.errorf: + errorlog.warning("No error rule is defined for exclusive state '%s'", s) + if s not in linfo.ignore and lexobj.lexignore: + errorlog.warning("No ignore rule is defined for exclusive state '%s'", s) + elif stype == 'inclusive': + if s not in linfo.errorf: + linfo.errorf[s] = linfo.errorf.get('INITIAL', None) + if s not in linfo.ignore: + linfo.ignore[s] = linfo.ignore.get('INITIAL', '') + + # Create global versions of the token() and input() functions + token = lexobj.token + input = lexobj.input + lexer = lexobj + + # If in optimize mode, we write the lextab + if lextab and optimize: + if outputdir is None: + # If no output directory is set, the location of the output files + # is determined according to the following rules: + # - If lextab specifies a package, files go into that package directory + # - Otherwise, files go in the same directory as the specifying module + if isinstance(lextab, types.ModuleType): + srcfile = lextab.__file__ + else: + if '.' not in lextab: + srcfile = ldict['__file__'] + else: + parts = lextab.split('.') + pkgname = '.'.join(parts[:-1]) + exec('import %s' % pkgname) + srcfile = getattr(sys.modules[pkgname], '__file__', '') + outputdir = os.path.dirname(srcfile) + try: + lexobj.writetab(lextab, outputdir) + except IOError as e: + errorlog.warning("Couldn't write lextab module %r. %s" % (lextab, e)) + + return lexobj + +# ----------------------------------------------------------------------------- +# runmain() +# +# This runs the lexer as a main program +# ----------------------------------------------------------------------------- + +def runmain(lexer=None, data=None): + if not data: + try: + filename = sys.argv[1] + f = open(filename) + data = f.read() + f.close() + except IndexError: + sys.stdout.write('Reading from standard input (type EOF to end):\n') + data = sys.stdin.read() + + if lexer: + _input = lexer.input + else: + _input = input + _input(data) + if lexer: + _token = lexer.token + else: + _token = token + + while True: + tok = _token() + if not tok: + break + sys.stdout.write('(%s,%r,%d,%d)\n' % (tok.type, tok.value, tok.lineno, tok.lexpos)) + +# ----------------------------------------------------------------------------- +# @TOKEN(regex) +# +# This decorator function can be used to set the regex expression on a function +# when its docstring might need to be set in an alternative way +# ----------------------------------------------------------------------------- + +def TOKEN(r): + def set_regex(f): + if hasattr(r, '__call__'): + f.regex = _get_regex(r) + else: + f.regex = r + return f + return set_regex + +# Alternative spelling of the TOKEN decorator +Token = TOKEN diff --git a/.venv/Lib/site-packages/pycparser/ply/yacc.py b/.venv/Lib/site-packages/pycparser/ply/yacc.py new file mode 100644 index 00000000..20b4f286 --- /dev/null +++ b/.venv/Lib/site-packages/pycparser/ply/yacc.py @@ -0,0 +1,3494 @@ +# ----------------------------------------------------------------------------- +# ply: yacc.py +# +# Copyright (C) 2001-2017 +# David M. Beazley (Dabeaz LLC) +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# * Neither the name of the David Beazley or Dabeaz LLC may be used to +# endorse or promote products derived from this software without +# specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# ----------------------------------------------------------------------------- +# +# This implements an LR parser that is constructed from grammar rules defined +# as Python functions. The grammer is specified by supplying the BNF inside +# Python documentation strings. The inspiration for this technique was borrowed +# from John Aycock's Spark parsing system. PLY might be viewed as cross between +# Spark and the GNU bison utility. +# +# The current implementation is only somewhat object-oriented. The +# LR parser itself is defined in terms of an object (which allows multiple +# parsers to co-exist). However, most of the variables used during table +# construction are defined in terms of global variables. Users shouldn't +# notice unless they are trying to define multiple parsers at the same +# time using threads (in which case they should have their head examined). +# +# This implementation supports both SLR and LALR(1) parsing. LALR(1) +# support was originally implemented by Elias Ioup (ezioup@alumni.uchicago.edu), +# using the algorithm found in Aho, Sethi, and Ullman "Compilers: Principles, +# Techniques, and Tools" (The Dragon Book). LALR(1) has since been replaced +# by the more efficient DeRemer and Pennello algorithm. +# +# :::::::: WARNING ::::::: +# +# Construction of LR parsing tables is fairly complicated and expensive. +# To make this module run fast, a *LOT* of work has been put into +# optimization---often at the expensive of readability and what might +# consider to be good Python "coding style." Modify the code at your +# own risk! +# ---------------------------------------------------------------------------- + +import re +import types +import sys +import os.path +import inspect +import base64 +import warnings + +__version__ = '3.10' +__tabversion__ = '3.10' + +#----------------------------------------------------------------------------- +# === User configurable parameters === +# +# Change these to modify the default behavior of yacc (if you wish) +#----------------------------------------------------------------------------- + +yaccdebug = True # Debugging mode. If set, yacc generates a + # a 'parser.out' file in the current directory + +debug_file = 'parser.out' # Default name of the debugging file +tab_module = 'parsetab' # Default name of the table module +default_lr = 'LALR' # Default LR table generation method + +error_count = 3 # Number of symbols that must be shifted to leave recovery mode + +yaccdevel = False # Set to True if developing yacc. This turns off optimized + # implementations of certain functions. + +resultlimit = 40 # Size limit of results when running in debug mode. + +pickle_protocol = 0 # Protocol to use when writing pickle files + +# String type-checking compatibility +if sys.version_info[0] < 3: + string_types = basestring +else: + string_types = str + +MAXINT = sys.maxsize + +# This object is a stand-in for a logging object created by the +# logging module. PLY will use this by default to create things +# such as the parser.out file. If a user wants more detailed +# information, they can create their own logging object and pass +# it into PLY. + +class PlyLogger(object): + def __init__(self, f): + self.f = f + + def debug(self, msg, *args, **kwargs): + self.f.write((msg % args) + '\n') + + info = debug + + def warning(self, msg, *args, **kwargs): + self.f.write('WARNING: ' + (msg % args) + '\n') + + def error(self, msg, *args, **kwargs): + self.f.write('ERROR: ' + (msg % args) + '\n') + + critical = debug + +# Null logger is used when no output is generated. Does nothing. +class NullLogger(object): + def __getattribute__(self, name): + return self + + def __call__(self, *args, **kwargs): + return self + +# Exception raised for yacc-related errors +class YaccError(Exception): + pass + +# Format the result message that the parser produces when running in debug mode. +def format_result(r): + repr_str = repr(r) + if '\n' in repr_str: + repr_str = repr(repr_str) + if len(repr_str) > resultlimit: + repr_str = repr_str[:resultlimit] + ' ...' + result = '<%s @ 0x%x> (%s)' % (type(r).__name__, id(r), repr_str) + return result + +# Format stack entries when the parser is running in debug mode +def format_stack_entry(r): + repr_str = repr(r) + if '\n' in repr_str: + repr_str = repr(repr_str) + if len(repr_str) < 16: + return repr_str + else: + return '<%s @ 0x%x>' % (type(r).__name__, id(r)) + +# Panic mode error recovery support. This feature is being reworked--much of the +# code here is to offer a deprecation/backwards compatible transition + +_errok = None +_token = None +_restart = None +_warnmsg = '''PLY: Don't use global functions errok(), token(), and restart() in p_error(). +Instead, invoke the methods on the associated parser instance: + + def p_error(p): + ... + # Use parser.errok(), parser.token(), parser.restart() + ... + + parser = yacc.yacc() +''' + +def errok(): + warnings.warn(_warnmsg) + return _errok() + +def restart(): + warnings.warn(_warnmsg) + return _restart() + +def token(): + warnings.warn(_warnmsg) + return _token() + +# Utility function to call the p_error() function with some deprecation hacks +def call_errorfunc(errorfunc, token, parser): + global _errok, _token, _restart + _errok = parser.errok + _token = parser.token + _restart = parser.restart + r = errorfunc(token) + try: + del _errok, _token, _restart + except NameError: + pass + return r + +#----------------------------------------------------------------------------- +# === LR Parsing Engine === +# +# The following classes are used for the LR parser itself. These are not +# used during table construction and are independent of the actual LR +# table generation algorithm +#----------------------------------------------------------------------------- + +# This class is used to hold non-terminal grammar symbols during parsing. +# It normally has the following attributes set: +# .type = Grammar symbol type +# .value = Symbol value +# .lineno = Starting line number +# .endlineno = Ending line number (optional, set automatically) +# .lexpos = Starting lex position +# .endlexpos = Ending lex position (optional, set automatically) + +class YaccSymbol: + def __str__(self): + return self.type + + def __repr__(self): + return str(self) + +# This class is a wrapper around the objects actually passed to each +# grammar rule. Index lookup and assignment actually assign the +# .value attribute of the underlying YaccSymbol object. +# The lineno() method returns the line number of a given +# item (or 0 if not defined). The linespan() method returns +# a tuple of (startline,endline) representing the range of lines +# for a symbol. The lexspan() method returns a tuple (lexpos,endlexpos) +# representing the range of positional information for a symbol. + +class YaccProduction: + def __init__(self, s, stack=None): + self.slice = s + self.stack = stack + self.lexer = None + self.parser = None + + def __getitem__(self, n): + if isinstance(n, slice): + return [s.value for s in self.slice[n]] + elif n >= 0: + return self.slice[n].value + else: + return self.stack[n].value + + def __setitem__(self, n, v): + self.slice[n].value = v + + def __getslice__(self, i, j): + return [s.value for s in self.slice[i:j]] + + def __len__(self): + return len(self.slice) + + def lineno(self, n): + return getattr(self.slice[n], 'lineno', 0) + + def set_lineno(self, n, lineno): + self.slice[n].lineno = lineno + + def linespan(self, n): + startline = getattr(self.slice[n], 'lineno', 0) + endline = getattr(self.slice[n], 'endlineno', startline) + return startline, endline + + def lexpos(self, n): + return getattr(self.slice[n], 'lexpos', 0) + + def lexspan(self, n): + startpos = getattr(self.slice[n], 'lexpos', 0) + endpos = getattr(self.slice[n], 'endlexpos', startpos) + return startpos, endpos + + def error(self): + raise SyntaxError + +# ----------------------------------------------------------------------------- +# == LRParser == +# +# The LR Parsing engine. +# ----------------------------------------------------------------------------- + +class LRParser: + def __init__(self, lrtab, errorf): + self.productions = lrtab.lr_productions + self.action = lrtab.lr_action + self.goto = lrtab.lr_goto + self.errorfunc = errorf + self.set_defaulted_states() + self.errorok = True + + def errok(self): + self.errorok = True + + def restart(self): + del self.statestack[:] + del self.symstack[:] + sym = YaccSymbol() + sym.type = '$end' + self.symstack.append(sym) + self.statestack.append(0) + + # Defaulted state support. + # This method identifies parser states where there is only one possible reduction action. + # For such states, the parser can make a choose to make a rule reduction without consuming + # the next look-ahead token. This delayed invocation of the tokenizer can be useful in + # certain kinds of advanced parsing situations where the lexer and parser interact with + # each other or change states (i.e., manipulation of scope, lexer states, etc.). + # + # See: https://www.gnu.org/software/bison/manual/html_node/Default-Reductions.html#Default-Reductions + def set_defaulted_states(self): + self.defaulted_states = {} + for state, actions in self.action.items(): + rules = list(actions.values()) + if len(rules) == 1 and rules[0] < 0: + self.defaulted_states[state] = rules[0] + + def disable_defaulted_states(self): + self.defaulted_states = {} + + def parse(self, input=None, lexer=None, debug=False, tracking=False, tokenfunc=None): + if debug or yaccdevel: + if isinstance(debug, int): + debug = PlyLogger(sys.stderr) + return self.parsedebug(input, lexer, debug, tracking, tokenfunc) + elif tracking: + return self.parseopt(input, lexer, debug, tracking, tokenfunc) + else: + return self.parseopt_notrack(input, lexer, debug, tracking, tokenfunc) + + + # !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! + # parsedebug(). + # + # This is the debugging enabled version of parse(). All changes made to the + # parsing engine should be made here. Optimized versions of this function + # are automatically created by the ply/ygen.py script. This script cuts out + # sections enclosed in markers such as this: + # + # #--! DEBUG + # statements + # #--! DEBUG + # + # !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! + + def parsedebug(self, input=None, lexer=None, debug=False, tracking=False, tokenfunc=None): + #--! parsedebug-start + lookahead = None # Current lookahead symbol + lookaheadstack = [] # Stack of lookahead symbols + actions = self.action # Local reference to action table (to avoid lookup on self.) + goto = self.goto # Local reference to goto table (to avoid lookup on self.) + prod = self.productions # Local reference to production list (to avoid lookup on self.) + defaulted_states = self.defaulted_states # Local reference to defaulted states + pslice = YaccProduction(None) # Production object passed to grammar rules + errorcount = 0 # Used during error recovery + + #--! DEBUG + debug.info('PLY: PARSE DEBUG START') + #--! DEBUG + + # If no lexer was given, we will try to use the lex module + if not lexer: + from . import lex + lexer = lex.lexer + + # Set up the lexer and parser objects on pslice + pslice.lexer = lexer + pslice.parser = self + + # If input was supplied, pass to lexer + if input is not None: + lexer.input(input) + + if tokenfunc is None: + # Tokenize function + get_token = lexer.token + else: + get_token = tokenfunc + + # Set the parser() token method (sometimes used in error recovery) + self.token = get_token + + # Set up the state and symbol stacks + + statestack = [] # Stack of parsing states + self.statestack = statestack + symstack = [] # Stack of grammar symbols + self.symstack = symstack + + pslice.stack = symstack # Put in the production + errtoken = None # Err token + + # The start state is assumed to be (0,$end) + + statestack.append(0) + sym = YaccSymbol() + sym.type = '$end' + symstack.append(sym) + state = 0 + while True: + # Get the next symbol on the input. If a lookahead symbol + # is already set, we just use that. Otherwise, we'll pull + # the next token off of the lookaheadstack or from the lexer + + #--! DEBUG + debug.debug('') + debug.debug('State : %s', state) + #--! DEBUG + + if state not in defaulted_states: + if not lookahead: + if not lookaheadstack: + lookahead = get_token() # Get the next token + else: + lookahead = lookaheadstack.pop() + if not lookahead: + lookahead = YaccSymbol() + lookahead.type = '$end' + + # Check the action table + ltype = lookahead.type + t = actions[state].get(ltype) + else: + t = defaulted_states[state] + #--! DEBUG + debug.debug('Defaulted state %s: Reduce using %d', state, -t) + #--! DEBUG + + #--! DEBUG + debug.debug('Stack : %s', + ('%s . %s' % (' '.join([xx.type for xx in symstack][1:]), str(lookahead))).lstrip()) + #--! DEBUG + + if t is not None: + if t > 0: + # shift a symbol on the stack + statestack.append(t) + state = t + + #--! DEBUG + debug.debug('Action : Shift and goto state %s', t) + #--! DEBUG + + symstack.append(lookahead) + lookahead = None + + # Decrease error count on successful shift + if errorcount: + errorcount -= 1 + continue + + if t < 0: + # reduce a symbol on the stack, emit a production + p = prod[-t] + pname = p.name + plen = p.len + + # Get production function + sym = YaccSymbol() + sym.type = pname # Production name + sym.value = None + + #--! DEBUG + if plen: + debug.info('Action : Reduce rule [%s] with %s and goto state %d', p.str, + '['+','.join([format_stack_entry(_v.value) for _v in symstack[-plen:]])+']', + goto[statestack[-1-plen]][pname]) + else: + debug.info('Action : Reduce rule [%s] with %s and goto state %d', p.str, [], + goto[statestack[-1]][pname]) + + #--! DEBUG + + if plen: + targ = symstack[-plen-1:] + targ[0] = sym + + #--! TRACKING + if tracking: + t1 = targ[1] + sym.lineno = t1.lineno + sym.lexpos = t1.lexpos + t1 = targ[-1] + sym.endlineno = getattr(t1, 'endlineno', t1.lineno) + sym.endlexpos = getattr(t1, 'endlexpos', t1.lexpos) + #--! TRACKING + + # !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! + # The code enclosed in this section is duplicated + # below as a performance optimization. Make sure + # changes get made in both locations. + + pslice.slice = targ + + try: + # Call the grammar rule with our special slice object + del symstack[-plen:] + self.state = state + p.callable(pslice) + del statestack[-plen:] + #--! DEBUG + debug.info('Result : %s', format_result(pslice[0])) + #--! DEBUG + symstack.append(sym) + state = goto[statestack[-1]][pname] + statestack.append(state) + except SyntaxError: + # If an error was set. Enter error recovery state + lookaheadstack.append(lookahead) # Save the current lookahead token + symstack.extend(targ[1:-1]) # Put the production slice back on the stack + statestack.pop() # Pop back one state (before the reduce) + state = statestack[-1] + sym.type = 'error' + sym.value = 'error' + lookahead = sym + errorcount = error_count + self.errorok = False + + continue + # !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! + + else: + + #--! TRACKING + if tracking: + sym.lineno = lexer.lineno + sym.lexpos = lexer.lexpos + #--! TRACKING + + targ = [sym] + + # !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! + # The code enclosed in this section is duplicated + # above as a performance optimization. Make sure + # changes get made in both locations. + + pslice.slice = targ + + try: + # Call the grammar rule with our special slice object + self.state = state + p.callable(pslice) + #--! DEBUG + debug.info('Result : %s', format_result(pslice[0])) + #--! DEBUG + symstack.append(sym) + state = goto[statestack[-1]][pname] + statestack.append(state) + except SyntaxError: + # If an error was set. Enter error recovery state + lookaheadstack.append(lookahead) # Save the current lookahead token + statestack.pop() # Pop back one state (before the reduce) + state = statestack[-1] + sym.type = 'error' + sym.value = 'error' + lookahead = sym + errorcount = error_count + self.errorok = False + + continue + # !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! + + if t == 0: + n = symstack[-1] + result = getattr(n, 'value', None) + #--! DEBUG + debug.info('Done : Returning %s', format_result(result)) + debug.info('PLY: PARSE DEBUG END') + #--! DEBUG + return result + + if t is None: + + #--! DEBUG + debug.error('Error : %s', + ('%s . %s' % (' '.join([xx.type for xx in symstack][1:]), str(lookahead))).lstrip()) + #--! DEBUG + + # We have some kind of parsing error here. To handle + # this, we are going to push the current token onto + # the tokenstack and replace it with an 'error' token. + # If there are any synchronization rules, they may + # catch it. + # + # In addition to pushing the error token, we call call + # the user defined p_error() function if this is the + # first syntax error. This function is only called if + # errorcount == 0. + if errorcount == 0 or self.errorok: + errorcount = error_count + self.errorok = False + errtoken = lookahead + if errtoken.type == '$end': + errtoken = None # End of file! + if self.errorfunc: + if errtoken and not hasattr(errtoken, 'lexer'): + errtoken.lexer = lexer + self.state = state + tok = call_errorfunc(self.errorfunc, errtoken, self) + if self.errorok: + # User must have done some kind of panic + # mode recovery on their own. The + # returned token is the next lookahead + lookahead = tok + errtoken = None + continue + else: + if errtoken: + if hasattr(errtoken, 'lineno'): + lineno = lookahead.lineno + else: + lineno = 0 + if lineno: + sys.stderr.write('yacc: Syntax error at line %d, token=%s\n' % (lineno, errtoken.type)) + else: + sys.stderr.write('yacc: Syntax error, token=%s' % errtoken.type) + else: + sys.stderr.write('yacc: Parse error in input. EOF\n') + return + + else: + errorcount = error_count + + # case 1: the statestack only has 1 entry on it. If we're in this state, the + # entire parse has been rolled back and we're completely hosed. The token is + # discarded and we just keep going. + + if len(statestack) <= 1 and lookahead.type != '$end': + lookahead = None + errtoken = None + state = 0 + # Nuke the pushback stack + del lookaheadstack[:] + continue + + # case 2: the statestack has a couple of entries on it, but we're + # at the end of the file. nuke the top entry and generate an error token + + # Start nuking entries on the stack + if lookahead.type == '$end': + # Whoa. We're really hosed here. Bail out + return + + if lookahead.type != 'error': + sym = symstack[-1] + if sym.type == 'error': + # Hmmm. Error is on top of stack, we'll just nuke input + # symbol and continue + #--! TRACKING + if tracking: + sym.endlineno = getattr(lookahead, 'lineno', sym.lineno) + sym.endlexpos = getattr(lookahead, 'lexpos', sym.lexpos) + #--! TRACKING + lookahead = None + continue + + # Create the error symbol for the first time and make it the new lookahead symbol + t = YaccSymbol() + t.type = 'error' + + if hasattr(lookahead, 'lineno'): + t.lineno = t.endlineno = lookahead.lineno + if hasattr(lookahead, 'lexpos'): + t.lexpos = t.endlexpos = lookahead.lexpos + t.value = lookahead + lookaheadstack.append(lookahead) + lookahead = t + else: + sym = symstack.pop() + #--! TRACKING + if tracking: + lookahead.lineno = sym.lineno + lookahead.lexpos = sym.lexpos + #--! TRACKING + statestack.pop() + state = statestack[-1] + + continue + + # Call an error function here + raise RuntimeError('yacc: internal parser error!!!\n') + + #--! parsedebug-end + + # !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! + # parseopt(). + # + # Optimized version of parse() method. DO NOT EDIT THIS CODE DIRECTLY! + # This code is automatically generated by the ply/ygen.py script. Make + # changes to the parsedebug() method instead. + # !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! + + def parseopt(self, input=None, lexer=None, debug=False, tracking=False, tokenfunc=None): + #--! parseopt-start + lookahead = None # Current lookahead symbol + lookaheadstack = [] # Stack of lookahead symbols + actions = self.action # Local reference to action table (to avoid lookup on self.) + goto = self.goto # Local reference to goto table (to avoid lookup on self.) + prod = self.productions # Local reference to production list (to avoid lookup on self.) + defaulted_states = self.defaulted_states # Local reference to defaulted states + pslice = YaccProduction(None) # Production object passed to grammar rules + errorcount = 0 # Used during error recovery + + + # If no lexer was given, we will try to use the lex module + if not lexer: + from . import lex + lexer = lex.lexer + + # Set up the lexer and parser objects on pslice + pslice.lexer = lexer + pslice.parser = self + + # If input was supplied, pass to lexer + if input is not None: + lexer.input(input) + + if tokenfunc is None: + # Tokenize function + get_token = lexer.token + else: + get_token = tokenfunc + + # Set the parser() token method (sometimes used in error recovery) + self.token = get_token + + # Set up the state and symbol stacks + + statestack = [] # Stack of parsing states + self.statestack = statestack + symstack = [] # Stack of grammar symbols + self.symstack = symstack + + pslice.stack = symstack # Put in the production + errtoken = None # Err token + + # The start state is assumed to be (0,$end) + + statestack.append(0) + sym = YaccSymbol() + sym.type = '$end' + symstack.append(sym) + state = 0 + while True: + # Get the next symbol on the input. If a lookahead symbol + # is already set, we just use that. Otherwise, we'll pull + # the next token off of the lookaheadstack or from the lexer + + + if state not in defaulted_states: + if not lookahead: + if not lookaheadstack: + lookahead = get_token() # Get the next token + else: + lookahead = lookaheadstack.pop() + if not lookahead: + lookahead = YaccSymbol() + lookahead.type = '$end' + + # Check the action table + ltype = lookahead.type + t = actions[state].get(ltype) + else: + t = defaulted_states[state] + + + if t is not None: + if t > 0: + # shift a symbol on the stack + statestack.append(t) + state = t + + + symstack.append(lookahead) + lookahead = None + + # Decrease error count on successful shift + if errorcount: + errorcount -= 1 + continue + + if t < 0: + # reduce a symbol on the stack, emit a production + p = prod[-t] + pname = p.name + plen = p.len + + # Get production function + sym = YaccSymbol() + sym.type = pname # Production name + sym.value = None + + + if plen: + targ = symstack[-plen-1:] + targ[0] = sym + + #--! TRACKING + if tracking: + t1 = targ[1] + sym.lineno = t1.lineno + sym.lexpos = t1.lexpos + t1 = targ[-1] + sym.endlineno = getattr(t1, 'endlineno', t1.lineno) + sym.endlexpos = getattr(t1, 'endlexpos', t1.lexpos) + #--! TRACKING + + # !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! + # The code enclosed in this section is duplicated + # below as a performance optimization. Make sure + # changes get made in both locations. + + pslice.slice = targ + + try: + # Call the grammar rule with our special slice object + del symstack[-plen:] + self.state = state + p.callable(pslice) + del statestack[-plen:] + symstack.append(sym) + state = goto[statestack[-1]][pname] + statestack.append(state) + except SyntaxError: + # If an error was set. Enter error recovery state + lookaheadstack.append(lookahead) # Save the current lookahead token + symstack.extend(targ[1:-1]) # Put the production slice back on the stack + statestack.pop() # Pop back one state (before the reduce) + state = statestack[-1] + sym.type = 'error' + sym.value = 'error' + lookahead = sym + errorcount = error_count + self.errorok = False + + continue + # !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! + + else: + + #--! TRACKING + if tracking: + sym.lineno = lexer.lineno + sym.lexpos = lexer.lexpos + #--! TRACKING + + targ = [sym] + + # !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! + # The code enclosed in this section is duplicated + # above as a performance optimization. Make sure + # changes get made in both locations. + + pslice.slice = targ + + try: + # Call the grammar rule with our special slice object + self.state = state + p.callable(pslice) + symstack.append(sym) + state = goto[statestack[-1]][pname] + statestack.append(state) + except SyntaxError: + # If an error was set. Enter error recovery state + lookaheadstack.append(lookahead) # Save the current lookahead token + statestack.pop() # Pop back one state (before the reduce) + state = statestack[-1] + sym.type = 'error' + sym.value = 'error' + lookahead = sym + errorcount = error_count + self.errorok = False + + continue + # !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! + + if t == 0: + n = symstack[-1] + result = getattr(n, 'value', None) + return result + + if t is None: + + + # We have some kind of parsing error here. To handle + # this, we are going to push the current token onto + # the tokenstack and replace it with an 'error' token. + # If there are any synchronization rules, they may + # catch it. + # + # In addition to pushing the error token, we call call + # the user defined p_error() function if this is the + # first syntax error. This function is only called if + # errorcount == 0. + if errorcount == 0 or self.errorok: + errorcount = error_count + self.errorok = False + errtoken = lookahead + if errtoken.type == '$end': + errtoken = None # End of file! + if self.errorfunc: + if errtoken and not hasattr(errtoken, 'lexer'): + errtoken.lexer = lexer + self.state = state + tok = call_errorfunc(self.errorfunc, errtoken, self) + if self.errorok: + # User must have done some kind of panic + # mode recovery on their own. The + # returned token is the next lookahead + lookahead = tok + errtoken = None + continue + else: + if errtoken: + if hasattr(errtoken, 'lineno'): + lineno = lookahead.lineno + else: + lineno = 0 + if lineno: + sys.stderr.write('yacc: Syntax error at line %d, token=%s\n' % (lineno, errtoken.type)) + else: + sys.stderr.write('yacc: Syntax error, token=%s' % errtoken.type) + else: + sys.stderr.write('yacc: Parse error in input. EOF\n') + return + + else: + errorcount = error_count + + # case 1: the statestack only has 1 entry on it. If we're in this state, the + # entire parse has been rolled back and we're completely hosed. The token is + # discarded and we just keep going. + + if len(statestack) <= 1 and lookahead.type != '$end': + lookahead = None + errtoken = None + state = 0 + # Nuke the pushback stack + del lookaheadstack[:] + continue + + # case 2: the statestack has a couple of entries on it, but we're + # at the end of the file. nuke the top entry and generate an error token + + # Start nuking entries on the stack + if lookahead.type == '$end': + # Whoa. We're really hosed here. Bail out + return + + if lookahead.type != 'error': + sym = symstack[-1] + if sym.type == 'error': + # Hmmm. Error is on top of stack, we'll just nuke input + # symbol and continue + #--! TRACKING + if tracking: + sym.endlineno = getattr(lookahead, 'lineno', sym.lineno) + sym.endlexpos = getattr(lookahead, 'lexpos', sym.lexpos) + #--! TRACKING + lookahead = None + continue + + # Create the error symbol for the first time and make it the new lookahead symbol + t = YaccSymbol() + t.type = 'error' + + if hasattr(lookahead, 'lineno'): + t.lineno = t.endlineno = lookahead.lineno + if hasattr(lookahead, 'lexpos'): + t.lexpos = t.endlexpos = lookahead.lexpos + t.value = lookahead + lookaheadstack.append(lookahead) + lookahead = t + else: + sym = symstack.pop() + #--! TRACKING + if tracking: + lookahead.lineno = sym.lineno + lookahead.lexpos = sym.lexpos + #--! TRACKING + statestack.pop() + state = statestack[-1] + + continue + + # Call an error function here + raise RuntimeError('yacc: internal parser error!!!\n') + + #--! parseopt-end + + # !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! + # parseopt_notrack(). + # + # Optimized version of parseopt() with line number tracking removed. + # DO NOT EDIT THIS CODE DIRECTLY. This code is automatically generated + # by the ply/ygen.py script. Make changes to the parsedebug() method instead. + # !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! + + def parseopt_notrack(self, input=None, lexer=None, debug=False, tracking=False, tokenfunc=None): + #--! parseopt-notrack-start + lookahead = None # Current lookahead symbol + lookaheadstack = [] # Stack of lookahead symbols + actions = self.action # Local reference to action table (to avoid lookup on self.) + goto = self.goto # Local reference to goto table (to avoid lookup on self.) + prod = self.productions # Local reference to production list (to avoid lookup on self.) + defaulted_states = self.defaulted_states # Local reference to defaulted states + pslice = YaccProduction(None) # Production object passed to grammar rules + errorcount = 0 # Used during error recovery + + + # If no lexer was given, we will try to use the lex module + if not lexer: + from . import lex + lexer = lex.lexer + + # Set up the lexer and parser objects on pslice + pslice.lexer = lexer + pslice.parser = self + + # If input was supplied, pass to lexer + if input is not None: + lexer.input(input) + + if tokenfunc is None: + # Tokenize function + get_token = lexer.token + else: + get_token = tokenfunc + + # Set the parser() token method (sometimes used in error recovery) + self.token = get_token + + # Set up the state and symbol stacks + + statestack = [] # Stack of parsing states + self.statestack = statestack + symstack = [] # Stack of grammar symbols + self.symstack = symstack + + pslice.stack = symstack # Put in the production + errtoken = None # Err token + + # The start state is assumed to be (0,$end) + + statestack.append(0) + sym = YaccSymbol() + sym.type = '$end' + symstack.append(sym) + state = 0 + while True: + # Get the next symbol on the input. If a lookahead symbol + # is already set, we just use that. Otherwise, we'll pull + # the next token off of the lookaheadstack or from the lexer + + + if state not in defaulted_states: + if not lookahead: + if not lookaheadstack: + lookahead = get_token() # Get the next token + else: + lookahead = lookaheadstack.pop() + if not lookahead: + lookahead = YaccSymbol() + lookahead.type = '$end' + + # Check the action table + ltype = lookahead.type + t = actions[state].get(ltype) + else: + t = defaulted_states[state] + + + if t is not None: + if t > 0: + # shift a symbol on the stack + statestack.append(t) + state = t + + + symstack.append(lookahead) + lookahead = None + + # Decrease error count on successful shift + if errorcount: + errorcount -= 1 + continue + + if t < 0: + # reduce a symbol on the stack, emit a production + p = prod[-t] + pname = p.name + plen = p.len + + # Get production function + sym = YaccSymbol() + sym.type = pname # Production name + sym.value = None + + + if plen: + targ = symstack[-plen-1:] + targ[0] = sym + + + # !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! + # The code enclosed in this section is duplicated + # below as a performance optimization. Make sure + # changes get made in both locations. + + pslice.slice = targ + + try: + # Call the grammar rule with our special slice object + del symstack[-plen:] + self.state = state + p.callable(pslice) + del statestack[-plen:] + symstack.append(sym) + state = goto[statestack[-1]][pname] + statestack.append(state) + except SyntaxError: + # If an error was set. Enter error recovery state + lookaheadstack.append(lookahead) # Save the current lookahead token + symstack.extend(targ[1:-1]) # Put the production slice back on the stack + statestack.pop() # Pop back one state (before the reduce) + state = statestack[-1] + sym.type = 'error' + sym.value = 'error' + lookahead = sym + errorcount = error_count + self.errorok = False + + continue + # !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! + + else: + + + targ = [sym] + + # !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! + # The code enclosed in this section is duplicated + # above as a performance optimization. Make sure + # changes get made in both locations. + + pslice.slice = targ + + try: + # Call the grammar rule with our special slice object + self.state = state + p.callable(pslice) + symstack.append(sym) + state = goto[statestack[-1]][pname] + statestack.append(state) + except SyntaxError: + # If an error was set. Enter error recovery state + lookaheadstack.append(lookahead) # Save the current lookahead token + statestack.pop() # Pop back one state (before the reduce) + state = statestack[-1] + sym.type = 'error' + sym.value = 'error' + lookahead = sym + errorcount = error_count + self.errorok = False + + continue + # !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! + + if t == 0: + n = symstack[-1] + result = getattr(n, 'value', None) + return result + + if t is None: + + + # We have some kind of parsing error here. To handle + # this, we are going to push the current token onto + # the tokenstack and replace it with an 'error' token. + # If there are any synchronization rules, they may + # catch it. + # + # In addition to pushing the error token, we call call + # the user defined p_error() function if this is the + # first syntax error. This function is only called if + # errorcount == 0. + if errorcount == 0 or self.errorok: + errorcount = error_count + self.errorok = False + errtoken = lookahead + if errtoken.type == '$end': + errtoken = None # End of file! + if self.errorfunc: + if errtoken and not hasattr(errtoken, 'lexer'): + errtoken.lexer = lexer + self.state = state + tok = call_errorfunc(self.errorfunc, errtoken, self) + if self.errorok: + # User must have done some kind of panic + # mode recovery on their own. The + # returned token is the next lookahead + lookahead = tok + errtoken = None + continue + else: + if errtoken: + if hasattr(errtoken, 'lineno'): + lineno = lookahead.lineno + else: + lineno = 0 + if lineno: + sys.stderr.write('yacc: Syntax error at line %d, token=%s\n' % (lineno, errtoken.type)) + else: + sys.stderr.write('yacc: Syntax error, token=%s' % errtoken.type) + else: + sys.stderr.write('yacc: Parse error in input. EOF\n') + return + + else: + errorcount = error_count + + # case 1: the statestack only has 1 entry on it. If we're in this state, the + # entire parse has been rolled back and we're completely hosed. The token is + # discarded and we just keep going. + + if len(statestack) <= 1 and lookahead.type != '$end': + lookahead = None + errtoken = None + state = 0 + # Nuke the pushback stack + del lookaheadstack[:] + continue + + # case 2: the statestack has a couple of entries on it, but we're + # at the end of the file. nuke the top entry and generate an error token + + # Start nuking entries on the stack + if lookahead.type == '$end': + # Whoa. We're really hosed here. Bail out + return + + if lookahead.type != 'error': + sym = symstack[-1] + if sym.type == 'error': + # Hmmm. Error is on top of stack, we'll just nuke input + # symbol and continue + lookahead = None + continue + + # Create the error symbol for the first time and make it the new lookahead symbol + t = YaccSymbol() + t.type = 'error' + + if hasattr(lookahead, 'lineno'): + t.lineno = t.endlineno = lookahead.lineno + if hasattr(lookahead, 'lexpos'): + t.lexpos = t.endlexpos = lookahead.lexpos + t.value = lookahead + lookaheadstack.append(lookahead) + lookahead = t + else: + sym = symstack.pop() + statestack.pop() + state = statestack[-1] + + continue + + # Call an error function here + raise RuntimeError('yacc: internal parser error!!!\n') + + #--! parseopt-notrack-end + +# ----------------------------------------------------------------------------- +# === Grammar Representation === +# +# The following functions, classes, and variables are used to represent and +# manipulate the rules that make up a grammar. +# ----------------------------------------------------------------------------- + +# regex matching identifiers +_is_identifier = re.compile(r'^[a-zA-Z0-9_-]+$') + +# ----------------------------------------------------------------------------- +# class Production: +# +# This class stores the raw information about a single production or grammar rule. +# A grammar rule refers to a specification such as this: +# +# expr : expr PLUS term +# +# Here are the basic attributes defined on all productions +# +# name - Name of the production. For example 'expr' +# prod - A list of symbols on the right side ['expr','PLUS','term'] +# prec - Production precedence level +# number - Production number. +# func - Function that executes on reduce +# file - File where production function is defined +# lineno - Line number where production function is defined +# +# The following attributes are defined or optional. +# +# len - Length of the production (number of symbols on right hand side) +# usyms - Set of unique symbols found in the production +# ----------------------------------------------------------------------------- + +class Production(object): + reduced = 0 + def __init__(self, number, name, prod, precedence=('right', 0), func=None, file='', line=0): + self.name = name + self.prod = tuple(prod) + self.number = number + self.func = func + self.callable = None + self.file = file + self.line = line + self.prec = precedence + + # Internal settings used during table construction + + self.len = len(self.prod) # Length of the production + + # Create a list of unique production symbols used in the production + self.usyms = [] + for s in self.prod: + if s not in self.usyms: + self.usyms.append(s) + + # List of all LR items for the production + self.lr_items = [] + self.lr_next = None + + # Create a string representation + if self.prod: + self.str = '%s -> %s' % (self.name, ' '.join(self.prod)) + else: + self.str = '%s -> ' % self.name + + def __str__(self): + return self.str + + def __repr__(self): + return 'Production(' + str(self) + ')' + + def __len__(self): + return len(self.prod) + + def __nonzero__(self): + return 1 + + def __getitem__(self, index): + return self.prod[index] + + # Return the nth lr_item from the production (or None if at the end) + def lr_item(self, n): + if n > len(self.prod): + return None + p = LRItem(self, n) + # Precompute the list of productions immediately following. + try: + p.lr_after = Prodnames[p.prod[n+1]] + except (IndexError, KeyError): + p.lr_after = [] + try: + p.lr_before = p.prod[n-1] + except IndexError: + p.lr_before = None + return p + + # Bind the production function name to a callable + def bind(self, pdict): + if self.func: + self.callable = pdict[self.func] + +# This class serves as a minimal standin for Production objects when +# reading table data from files. It only contains information +# actually used by the LR parsing engine, plus some additional +# debugging information. +class MiniProduction(object): + def __init__(self, str, name, len, func, file, line): + self.name = name + self.len = len + self.func = func + self.callable = None + self.file = file + self.line = line + self.str = str + + def __str__(self): + return self.str + + def __repr__(self): + return 'MiniProduction(%s)' % self.str + + # Bind the production function name to a callable + def bind(self, pdict): + if self.func: + self.callable = pdict[self.func] + + +# ----------------------------------------------------------------------------- +# class LRItem +# +# This class represents a specific stage of parsing a production rule. For +# example: +# +# expr : expr . PLUS term +# +# In the above, the "." represents the current location of the parse. Here +# basic attributes: +# +# name - Name of the production. For example 'expr' +# prod - A list of symbols on the right side ['expr','.', 'PLUS','term'] +# number - Production number. +# +# lr_next Next LR item. Example, if we are ' expr -> expr . PLUS term' +# then lr_next refers to 'expr -> expr PLUS . term' +# lr_index - LR item index (location of the ".") in the prod list. +# lookaheads - LALR lookahead symbols for this item +# len - Length of the production (number of symbols on right hand side) +# lr_after - List of all productions that immediately follow +# lr_before - Grammar symbol immediately before +# ----------------------------------------------------------------------------- + +class LRItem(object): + def __init__(self, p, n): + self.name = p.name + self.prod = list(p.prod) + self.number = p.number + self.lr_index = n + self.lookaheads = {} + self.prod.insert(n, '.') + self.prod = tuple(self.prod) + self.len = len(self.prod) + self.usyms = p.usyms + + def __str__(self): + if self.prod: + s = '%s -> %s' % (self.name, ' '.join(self.prod)) + else: + s = '%s -> ' % self.name + return s + + def __repr__(self): + return 'LRItem(' + str(self) + ')' + +# ----------------------------------------------------------------------------- +# rightmost_terminal() +# +# Return the rightmost terminal from a list of symbols. Used in add_production() +# ----------------------------------------------------------------------------- +def rightmost_terminal(symbols, terminals): + i = len(symbols) - 1 + while i >= 0: + if symbols[i] in terminals: + return symbols[i] + i -= 1 + return None + +# ----------------------------------------------------------------------------- +# === GRAMMAR CLASS === +# +# The following class represents the contents of the specified grammar along +# with various computed properties such as first sets, follow sets, LR items, etc. +# This data is used for critical parts of the table generation process later. +# ----------------------------------------------------------------------------- + +class GrammarError(YaccError): + pass + +class Grammar(object): + def __init__(self, terminals): + self.Productions = [None] # A list of all of the productions. The first + # entry is always reserved for the purpose of + # building an augmented grammar + + self.Prodnames = {} # A dictionary mapping the names of nonterminals to a list of all + # productions of that nonterminal. + + self.Prodmap = {} # A dictionary that is only used to detect duplicate + # productions. + + self.Terminals = {} # A dictionary mapping the names of terminal symbols to a + # list of the rules where they are used. + + for term in terminals: + self.Terminals[term] = [] + + self.Terminals['error'] = [] + + self.Nonterminals = {} # A dictionary mapping names of nonterminals to a list + # of rule numbers where they are used. + + self.First = {} # A dictionary of precomputed FIRST(x) symbols + + self.Follow = {} # A dictionary of precomputed FOLLOW(x) symbols + + self.Precedence = {} # Precedence rules for each terminal. Contains tuples of the + # form ('right',level) or ('nonassoc', level) or ('left',level) + + self.UsedPrecedence = set() # Precedence rules that were actually used by the grammer. + # This is only used to provide error checking and to generate + # a warning about unused precedence rules. + + self.Start = None # Starting symbol for the grammar + + + def __len__(self): + return len(self.Productions) + + def __getitem__(self, index): + return self.Productions[index] + + # ----------------------------------------------------------------------------- + # set_precedence() + # + # Sets the precedence for a given terminal. assoc is the associativity such as + # 'left','right', or 'nonassoc'. level is a numeric level. + # + # ----------------------------------------------------------------------------- + + def set_precedence(self, term, assoc, level): + assert self.Productions == [None], 'Must call set_precedence() before add_production()' + if term in self.Precedence: + raise GrammarError('Precedence already specified for terminal %r' % term) + if assoc not in ['left', 'right', 'nonassoc']: + raise GrammarError("Associativity must be one of 'left','right', or 'nonassoc'") + self.Precedence[term] = (assoc, level) + + # ----------------------------------------------------------------------------- + # add_production() + # + # Given an action function, this function assembles a production rule and + # computes its precedence level. + # + # The production rule is supplied as a list of symbols. For example, + # a rule such as 'expr : expr PLUS term' has a production name of 'expr' and + # symbols ['expr','PLUS','term']. + # + # Precedence is determined by the precedence of the right-most non-terminal + # or the precedence of a terminal specified by %prec. + # + # A variety of error checks are performed to make sure production symbols + # are valid and that %prec is used correctly. + # ----------------------------------------------------------------------------- + + def add_production(self, prodname, syms, func=None, file='', line=0): + + if prodname in self.Terminals: + raise GrammarError('%s:%d: Illegal rule name %r. Already defined as a token' % (file, line, prodname)) + if prodname == 'error': + raise GrammarError('%s:%d: Illegal rule name %r. error is a reserved word' % (file, line, prodname)) + if not _is_identifier.match(prodname): + raise GrammarError('%s:%d: Illegal rule name %r' % (file, line, prodname)) + + # Look for literal tokens + for n, s in enumerate(syms): + if s[0] in "'\"": + try: + c = eval(s) + if (len(c) > 1): + raise GrammarError('%s:%d: Literal token %s in rule %r may only be a single character' % + (file, line, s, prodname)) + if c not in self.Terminals: + self.Terminals[c] = [] + syms[n] = c + continue + except SyntaxError: + pass + if not _is_identifier.match(s) and s != '%prec': + raise GrammarError('%s:%d: Illegal name %r in rule %r' % (file, line, s, prodname)) + + # Determine the precedence level + if '%prec' in syms: + if syms[-1] == '%prec': + raise GrammarError('%s:%d: Syntax error. Nothing follows %%prec' % (file, line)) + if syms[-2] != '%prec': + raise GrammarError('%s:%d: Syntax error. %%prec can only appear at the end of a grammar rule' % + (file, line)) + precname = syms[-1] + prodprec = self.Precedence.get(precname) + if not prodprec: + raise GrammarError('%s:%d: Nothing known about the precedence of %r' % (file, line, precname)) + else: + self.UsedPrecedence.add(precname) + del syms[-2:] # Drop %prec from the rule + else: + # If no %prec, precedence is determined by the rightmost terminal symbol + precname = rightmost_terminal(syms, self.Terminals) + prodprec = self.Precedence.get(precname, ('right', 0)) + + # See if the rule is already in the rulemap + map = '%s -> %s' % (prodname, syms) + if map in self.Prodmap: + m = self.Prodmap[map] + raise GrammarError('%s:%d: Duplicate rule %s. ' % (file, line, m) + + 'Previous definition at %s:%d' % (m.file, m.line)) + + # From this point on, everything is valid. Create a new Production instance + pnumber = len(self.Productions) + if prodname not in self.Nonterminals: + self.Nonterminals[prodname] = [] + + # Add the production number to Terminals and Nonterminals + for t in syms: + if t in self.Terminals: + self.Terminals[t].append(pnumber) + else: + if t not in self.Nonterminals: + self.Nonterminals[t] = [] + self.Nonterminals[t].append(pnumber) + + # Create a production and add it to the list of productions + p = Production(pnumber, prodname, syms, prodprec, func, file, line) + self.Productions.append(p) + self.Prodmap[map] = p + + # Add to the global productions list + try: + self.Prodnames[prodname].append(p) + except KeyError: + self.Prodnames[prodname] = [p] + + # ----------------------------------------------------------------------------- + # set_start() + # + # Sets the starting symbol and creates the augmented grammar. Production + # rule 0 is S' -> start where start is the start symbol. + # ----------------------------------------------------------------------------- + + def set_start(self, start=None): + if not start: + start = self.Productions[1].name + if start not in self.Nonterminals: + raise GrammarError('start symbol %s undefined' % start) + self.Productions[0] = Production(0, "S'", [start]) + self.Nonterminals[start].append(0) + self.Start = start + + # ----------------------------------------------------------------------------- + # find_unreachable() + # + # Find all of the nonterminal symbols that can't be reached from the starting + # symbol. Returns a list of nonterminals that can't be reached. + # ----------------------------------------------------------------------------- + + def find_unreachable(self): + + # Mark all symbols that are reachable from a symbol s + def mark_reachable_from(s): + if s in reachable: + return + reachable.add(s) + for p in self.Prodnames.get(s, []): + for r in p.prod: + mark_reachable_from(r) + + reachable = set() + mark_reachable_from(self.Productions[0].prod[0]) + return [s for s in self.Nonterminals if s not in reachable] + + # ----------------------------------------------------------------------------- + # infinite_cycles() + # + # This function looks at the various parsing rules and tries to detect + # infinite recursion cycles (grammar rules where there is no possible way + # to derive a string of only terminals). + # ----------------------------------------------------------------------------- + + def infinite_cycles(self): + terminates = {} + + # Terminals: + for t in self.Terminals: + terminates[t] = True + + terminates['$end'] = True + + # Nonterminals: + + # Initialize to false: + for n in self.Nonterminals: + terminates[n] = False + + # Then propagate termination until no change: + while True: + some_change = False + for (n, pl) in self.Prodnames.items(): + # Nonterminal n terminates iff any of its productions terminates. + for p in pl: + # Production p terminates iff all of its rhs symbols terminate. + for s in p.prod: + if not terminates[s]: + # The symbol s does not terminate, + # so production p does not terminate. + p_terminates = False + break + else: + # didn't break from the loop, + # so every symbol s terminates + # so production p terminates. + p_terminates = True + + if p_terminates: + # symbol n terminates! + if not terminates[n]: + terminates[n] = True + some_change = True + # Don't need to consider any more productions for this n. + break + + if not some_change: + break + + infinite = [] + for (s, term) in terminates.items(): + if not term: + if s not in self.Prodnames and s not in self.Terminals and s != 'error': + # s is used-but-not-defined, and we've already warned of that, + # so it would be overkill to say that it's also non-terminating. + pass + else: + infinite.append(s) + + return infinite + + # ----------------------------------------------------------------------------- + # undefined_symbols() + # + # Find all symbols that were used the grammar, but not defined as tokens or + # grammar rules. Returns a list of tuples (sym, prod) where sym in the symbol + # and prod is the production where the symbol was used. + # ----------------------------------------------------------------------------- + def undefined_symbols(self): + result = [] + for p in self.Productions: + if not p: + continue + + for s in p.prod: + if s not in self.Prodnames and s not in self.Terminals and s != 'error': + result.append((s, p)) + return result + + # ----------------------------------------------------------------------------- + # unused_terminals() + # + # Find all terminals that were defined, but not used by the grammar. Returns + # a list of all symbols. + # ----------------------------------------------------------------------------- + def unused_terminals(self): + unused_tok = [] + for s, v in self.Terminals.items(): + if s != 'error' and not v: + unused_tok.append(s) + + return unused_tok + + # ------------------------------------------------------------------------------ + # unused_rules() + # + # Find all grammar rules that were defined, but not used (maybe not reachable) + # Returns a list of productions. + # ------------------------------------------------------------------------------ + + def unused_rules(self): + unused_prod = [] + for s, v in self.Nonterminals.items(): + if not v: + p = self.Prodnames[s][0] + unused_prod.append(p) + return unused_prod + + # ----------------------------------------------------------------------------- + # unused_precedence() + # + # Returns a list of tuples (term,precedence) corresponding to precedence + # rules that were never used by the grammar. term is the name of the terminal + # on which precedence was applied and precedence is a string such as 'left' or + # 'right' corresponding to the type of precedence. + # ----------------------------------------------------------------------------- + + def unused_precedence(self): + unused = [] + for termname in self.Precedence: + if not (termname in self.Terminals or termname in self.UsedPrecedence): + unused.append((termname, self.Precedence[termname][0])) + + return unused + + # ------------------------------------------------------------------------- + # _first() + # + # Compute the value of FIRST1(beta) where beta is a tuple of symbols. + # + # During execution of compute_first1, the result may be incomplete. + # Afterward (e.g., when called from compute_follow()), it will be complete. + # ------------------------------------------------------------------------- + def _first(self, beta): + + # We are computing First(x1,x2,x3,...,xn) + result = [] + for x in beta: + x_produces_empty = False + + # Add all the non- symbols of First[x] to the result. + for f in self.First[x]: + if f == '': + x_produces_empty = True + else: + if f not in result: + result.append(f) + + if x_produces_empty: + # We have to consider the next x in beta, + # i.e. stay in the loop. + pass + else: + # We don't have to consider any further symbols in beta. + break + else: + # There was no 'break' from the loop, + # so x_produces_empty was true for all x in beta, + # so beta produces empty as well. + result.append('') + + return result + + # ------------------------------------------------------------------------- + # compute_first() + # + # Compute the value of FIRST1(X) for all symbols + # ------------------------------------------------------------------------- + def compute_first(self): + if self.First: + return self.First + + # Terminals: + for t in self.Terminals: + self.First[t] = [t] + + self.First['$end'] = ['$end'] + + # Nonterminals: + + # Initialize to the empty set: + for n in self.Nonterminals: + self.First[n] = [] + + # Then propagate symbols until no change: + while True: + some_change = False + for n in self.Nonterminals: + for p in self.Prodnames[n]: + for f in self._first(p.prod): + if f not in self.First[n]: + self.First[n].append(f) + some_change = True + if not some_change: + break + + return self.First + + # --------------------------------------------------------------------- + # compute_follow() + # + # Computes all of the follow sets for every non-terminal symbol. The + # follow set is the set of all symbols that might follow a given + # non-terminal. See the Dragon book, 2nd Ed. p. 189. + # --------------------------------------------------------------------- + def compute_follow(self, start=None): + # If already computed, return the result + if self.Follow: + return self.Follow + + # If first sets not computed yet, do that first. + if not self.First: + self.compute_first() + + # Add '$end' to the follow list of the start symbol + for k in self.Nonterminals: + self.Follow[k] = [] + + if not start: + start = self.Productions[1].name + + self.Follow[start] = ['$end'] + + while True: + didadd = False + for p in self.Productions[1:]: + # Here is the production set + for i, B in enumerate(p.prod): + if B in self.Nonterminals: + # Okay. We got a non-terminal in a production + fst = self._first(p.prod[i+1:]) + hasempty = False + for f in fst: + if f != '' and f not in self.Follow[B]: + self.Follow[B].append(f) + didadd = True + if f == '': + hasempty = True + if hasempty or i == (len(p.prod)-1): + # Add elements of follow(a) to follow(b) + for f in self.Follow[p.name]: + if f not in self.Follow[B]: + self.Follow[B].append(f) + didadd = True + if not didadd: + break + return self.Follow + + + # ----------------------------------------------------------------------------- + # build_lritems() + # + # This function walks the list of productions and builds a complete set of the + # LR items. The LR items are stored in two ways: First, they are uniquely + # numbered and placed in the list _lritems. Second, a linked list of LR items + # is built for each production. For example: + # + # E -> E PLUS E + # + # Creates the list + # + # [E -> . E PLUS E, E -> E . PLUS E, E -> E PLUS . E, E -> E PLUS E . ] + # ----------------------------------------------------------------------------- + + def build_lritems(self): + for p in self.Productions: + lastlri = p + i = 0 + lr_items = [] + while True: + if i > len(p): + lri = None + else: + lri = LRItem(p, i) + # Precompute the list of productions immediately following + try: + lri.lr_after = self.Prodnames[lri.prod[i+1]] + except (IndexError, KeyError): + lri.lr_after = [] + try: + lri.lr_before = lri.prod[i-1] + except IndexError: + lri.lr_before = None + + lastlri.lr_next = lri + if not lri: + break + lr_items.append(lri) + lastlri = lri + i += 1 + p.lr_items = lr_items + +# ----------------------------------------------------------------------------- +# == Class LRTable == +# +# This basic class represents a basic table of LR parsing information. +# Methods for generating the tables are not defined here. They are defined +# in the derived class LRGeneratedTable. +# ----------------------------------------------------------------------------- + +class VersionError(YaccError): + pass + +class LRTable(object): + def __init__(self): + self.lr_action = None + self.lr_goto = None + self.lr_productions = None + self.lr_method = None + + def read_table(self, module): + if isinstance(module, types.ModuleType): + parsetab = module + else: + exec('import %s' % module) + parsetab = sys.modules[module] + + if parsetab._tabversion != __tabversion__: + raise VersionError('yacc table file version is out of date') + + self.lr_action = parsetab._lr_action + self.lr_goto = parsetab._lr_goto + + self.lr_productions = [] + for p in parsetab._lr_productions: + self.lr_productions.append(MiniProduction(*p)) + + self.lr_method = parsetab._lr_method + return parsetab._lr_signature + + def read_pickle(self, filename): + try: + import cPickle as pickle + except ImportError: + import pickle + + if not os.path.exists(filename): + raise ImportError + + in_f = open(filename, 'rb') + + tabversion = pickle.load(in_f) + if tabversion != __tabversion__: + raise VersionError('yacc table file version is out of date') + self.lr_method = pickle.load(in_f) + signature = pickle.load(in_f) + self.lr_action = pickle.load(in_f) + self.lr_goto = pickle.load(in_f) + productions = pickle.load(in_f) + + self.lr_productions = [] + for p in productions: + self.lr_productions.append(MiniProduction(*p)) + + in_f.close() + return signature + + # Bind all production function names to callable objects in pdict + def bind_callables(self, pdict): + for p in self.lr_productions: + p.bind(pdict) + + +# ----------------------------------------------------------------------------- +# === LR Generator === +# +# The following classes and functions are used to generate LR parsing tables on +# a grammar. +# ----------------------------------------------------------------------------- + +# ----------------------------------------------------------------------------- +# digraph() +# traverse() +# +# The following two functions are used to compute set valued functions +# of the form: +# +# F(x) = F'(x) U U{F(y) | x R y} +# +# This is used to compute the values of Read() sets as well as FOLLOW sets +# in LALR(1) generation. +# +# Inputs: X - An input set +# R - A relation +# FP - Set-valued function +# ------------------------------------------------------------------------------ + +def digraph(X, R, FP): + N = {} + for x in X: + N[x] = 0 + stack = [] + F = {} + for x in X: + if N[x] == 0: + traverse(x, N, stack, F, X, R, FP) + return F + +def traverse(x, N, stack, F, X, R, FP): + stack.append(x) + d = len(stack) + N[x] = d + F[x] = FP(x) # F(X) <- F'(x) + + rel = R(x) # Get y's related to x + for y in rel: + if N[y] == 0: + traverse(y, N, stack, F, X, R, FP) + N[x] = min(N[x], N[y]) + for a in F.get(y, []): + if a not in F[x]: + F[x].append(a) + if N[x] == d: + N[stack[-1]] = MAXINT + F[stack[-1]] = F[x] + element = stack.pop() + while element != x: + N[stack[-1]] = MAXINT + F[stack[-1]] = F[x] + element = stack.pop() + +class LALRError(YaccError): + pass + +# ----------------------------------------------------------------------------- +# == LRGeneratedTable == +# +# This class implements the LR table generation algorithm. There are no +# public methods except for write() +# ----------------------------------------------------------------------------- + +class LRGeneratedTable(LRTable): + def __init__(self, grammar, method='LALR', log=None): + if method not in ['SLR', 'LALR']: + raise LALRError('Unsupported method %s' % method) + + self.grammar = grammar + self.lr_method = method + + # Set up the logger + if not log: + log = NullLogger() + self.log = log + + # Internal attributes + self.lr_action = {} # Action table + self.lr_goto = {} # Goto table + self.lr_productions = grammar.Productions # Copy of grammar Production array + self.lr_goto_cache = {} # Cache of computed gotos + self.lr0_cidhash = {} # Cache of closures + + self._add_count = 0 # Internal counter used to detect cycles + + # Diagonistic information filled in by the table generator + self.sr_conflict = 0 + self.rr_conflict = 0 + self.conflicts = [] # List of conflicts + + self.sr_conflicts = [] + self.rr_conflicts = [] + + # Build the tables + self.grammar.build_lritems() + self.grammar.compute_first() + self.grammar.compute_follow() + self.lr_parse_table() + + # Compute the LR(0) closure operation on I, where I is a set of LR(0) items. + + def lr0_closure(self, I): + self._add_count += 1 + + # Add everything in I to J + J = I[:] + didadd = True + while didadd: + didadd = False + for j in J: + for x in j.lr_after: + if getattr(x, 'lr0_added', 0) == self._add_count: + continue + # Add B --> .G to J + J.append(x.lr_next) + x.lr0_added = self._add_count + didadd = True + + return J + + # Compute the LR(0) goto function goto(I,X) where I is a set + # of LR(0) items and X is a grammar symbol. This function is written + # in a way that guarantees uniqueness of the generated goto sets + # (i.e. the same goto set will never be returned as two different Python + # objects). With uniqueness, we can later do fast set comparisons using + # id(obj) instead of element-wise comparison. + + def lr0_goto(self, I, x): + # First we look for a previously cached entry + g = self.lr_goto_cache.get((id(I), x)) + if g: + return g + + # Now we generate the goto set in a way that guarantees uniqueness + # of the result + + s = self.lr_goto_cache.get(x) + if not s: + s = {} + self.lr_goto_cache[x] = s + + gs = [] + for p in I: + n = p.lr_next + if n and n.lr_before == x: + s1 = s.get(id(n)) + if not s1: + s1 = {} + s[id(n)] = s1 + gs.append(n) + s = s1 + g = s.get('$end') + if not g: + if gs: + g = self.lr0_closure(gs) + s['$end'] = g + else: + s['$end'] = gs + self.lr_goto_cache[(id(I), x)] = g + return g + + # Compute the LR(0) sets of item function + def lr0_items(self): + C = [self.lr0_closure([self.grammar.Productions[0].lr_next])] + i = 0 + for I in C: + self.lr0_cidhash[id(I)] = i + i += 1 + + # Loop over the items in C and each grammar symbols + i = 0 + while i < len(C): + I = C[i] + i += 1 + + # Collect all of the symbols that could possibly be in the goto(I,X) sets + asyms = {} + for ii in I: + for s in ii.usyms: + asyms[s] = None + + for x in asyms: + g = self.lr0_goto(I, x) + if not g or id(g) in self.lr0_cidhash: + continue + self.lr0_cidhash[id(g)] = len(C) + C.append(g) + + return C + + # ----------------------------------------------------------------------------- + # ==== LALR(1) Parsing ==== + # + # LALR(1) parsing is almost exactly the same as SLR except that instead of + # relying upon Follow() sets when performing reductions, a more selective + # lookahead set that incorporates the state of the LR(0) machine is utilized. + # Thus, we mainly just have to focus on calculating the lookahead sets. + # + # The method used here is due to DeRemer and Pennelo (1982). + # + # DeRemer, F. L., and T. J. Pennelo: "Efficient Computation of LALR(1) + # Lookahead Sets", ACM Transactions on Programming Languages and Systems, + # Vol. 4, No. 4, Oct. 1982, pp. 615-649 + # + # Further details can also be found in: + # + # J. Tremblay and P. Sorenson, "The Theory and Practice of Compiler Writing", + # McGraw-Hill Book Company, (1985). + # + # ----------------------------------------------------------------------------- + + # ----------------------------------------------------------------------------- + # compute_nullable_nonterminals() + # + # Creates a dictionary containing all of the non-terminals that might produce + # an empty production. + # ----------------------------------------------------------------------------- + + def compute_nullable_nonterminals(self): + nullable = set() + num_nullable = 0 + while True: + for p in self.grammar.Productions[1:]: + if p.len == 0: + nullable.add(p.name) + continue + for t in p.prod: + if t not in nullable: + break + else: + nullable.add(p.name) + if len(nullable) == num_nullable: + break + num_nullable = len(nullable) + return nullable + + # ----------------------------------------------------------------------------- + # find_nonterminal_trans(C) + # + # Given a set of LR(0) items, this functions finds all of the non-terminal + # transitions. These are transitions in which a dot appears immediately before + # a non-terminal. Returns a list of tuples of the form (state,N) where state + # is the state number and N is the nonterminal symbol. + # + # The input C is the set of LR(0) items. + # ----------------------------------------------------------------------------- + + def find_nonterminal_transitions(self, C): + trans = [] + for stateno, state in enumerate(C): + for p in state: + if p.lr_index < p.len - 1: + t = (stateno, p.prod[p.lr_index+1]) + if t[1] in self.grammar.Nonterminals: + if t not in trans: + trans.append(t) + return trans + + # ----------------------------------------------------------------------------- + # dr_relation() + # + # Computes the DR(p,A) relationships for non-terminal transitions. The input + # is a tuple (state,N) where state is a number and N is a nonterminal symbol. + # + # Returns a list of terminals. + # ----------------------------------------------------------------------------- + + def dr_relation(self, C, trans, nullable): + dr_set = {} + state, N = trans + terms = [] + + g = self.lr0_goto(C[state], N) + for p in g: + if p.lr_index < p.len - 1: + a = p.prod[p.lr_index+1] + if a in self.grammar.Terminals: + if a not in terms: + terms.append(a) + + # This extra bit is to handle the start state + if state == 0 and N == self.grammar.Productions[0].prod[0]: + terms.append('$end') + + return terms + + # ----------------------------------------------------------------------------- + # reads_relation() + # + # Computes the READS() relation (p,A) READS (t,C). + # ----------------------------------------------------------------------------- + + def reads_relation(self, C, trans, empty): + # Look for empty transitions + rel = [] + state, N = trans + + g = self.lr0_goto(C[state], N) + j = self.lr0_cidhash.get(id(g), -1) + for p in g: + if p.lr_index < p.len - 1: + a = p.prod[p.lr_index + 1] + if a in empty: + rel.append((j, a)) + + return rel + + # ----------------------------------------------------------------------------- + # compute_lookback_includes() + # + # Determines the lookback and includes relations + # + # LOOKBACK: + # + # This relation is determined by running the LR(0) state machine forward. + # For example, starting with a production "N : . A B C", we run it forward + # to obtain "N : A B C ." We then build a relationship between this final + # state and the starting state. These relationships are stored in a dictionary + # lookdict. + # + # INCLUDES: + # + # Computes the INCLUDE() relation (p,A) INCLUDES (p',B). + # + # This relation is used to determine non-terminal transitions that occur + # inside of other non-terminal transition states. (p,A) INCLUDES (p', B) + # if the following holds: + # + # B -> LAT, where T -> epsilon and p' -L-> p + # + # L is essentially a prefix (which may be empty), T is a suffix that must be + # able to derive an empty string. State p' must lead to state p with the string L. + # + # ----------------------------------------------------------------------------- + + def compute_lookback_includes(self, C, trans, nullable): + lookdict = {} # Dictionary of lookback relations + includedict = {} # Dictionary of include relations + + # Make a dictionary of non-terminal transitions + dtrans = {} + for t in trans: + dtrans[t] = 1 + + # Loop over all transitions and compute lookbacks and includes + for state, N in trans: + lookb = [] + includes = [] + for p in C[state]: + if p.name != N: + continue + + # Okay, we have a name match. We now follow the production all the way + # through the state machine until we get the . on the right hand side + + lr_index = p.lr_index + j = state + while lr_index < p.len - 1: + lr_index = lr_index + 1 + t = p.prod[lr_index] + + # Check to see if this symbol and state are a non-terminal transition + if (j, t) in dtrans: + # Yes. Okay, there is some chance that this is an includes relation + # the only way to know for certain is whether the rest of the + # production derives empty + + li = lr_index + 1 + while li < p.len: + if p.prod[li] in self.grammar.Terminals: + break # No forget it + if p.prod[li] not in nullable: + break + li = li + 1 + else: + # Appears to be a relation between (j,t) and (state,N) + includes.append((j, t)) + + g = self.lr0_goto(C[j], t) # Go to next set + j = self.lr0_cidhash.get(id(g), -1) # Go to next state + + # When we get here, j is the final state, now we have to locate the production + for r in C[j]: + if r.name != p.name: + continue + if r.len != p.len: + continue + i = 0 + # This look is comparing a production ". A B C" with "A B C ." + while i < r.lr_index: + if r.prod[i] != p.prod[i+1]: + break + i = i + 1 + else: + lookb.append((j, r)) + for i in includes: + if i not in includedict: + includedict[i] = [] + includedict[i].append((state, N)) + lookdict[(state, N)] = lookb + + return lookdict, includedict + + # ----------------------------------------------------------------------------- + # compute_read_sets() + # + # Given a set of LR(0) items, this function computes the read sets. + # + # Inputs: C = Set of LR(0) items + # ntrans = Set of nonterminal transitions + # nullable = Set of empty transitions + # + # Returns a set containing the read sets + # ----------------------------------------------------------------------------- + + def compute_read_sets(self, C, ntrans, nullable): + FP = lambda x: self.dr_relation(C, x, nullable) + R = lambda x: self.reads_relation(C, x, nullable) + F = digraph(ntrans, R, FP) + return F + + # ----------------------------------------------------------------------------- + # compute_follow_sets() + # + # Given a set of LR(0) items, a set of non-terminal transitions, a readset, + # and an include set, this function computes the follow sets + # + # Follow(p,A) = Read(p,A) U U {Follow(p',B) | (p,A) INCLUDES (p',B)} + # + # Inputs: + # ntrans = Set of nonterminal transitions + # readsets = Readset (previously computed) + # inclsets = Include sets (previously computed) + # + # Returns a set containing the follow sets + # ----------------------------------------------------------------------------- + + def compute_follow_sets(self, ntrans, readsets, inclsets): + FP = lambda x: readsets[x] + R = lambda x: inclsets.get(x, []) + F = digraph(ntrans, R, FP) + return F + + # ----------------------------------------------------------------------------- + # add_lookaheads() + # + # Attaches the lookahead symbols to grammar rules. + # + # Inputs: lookbacks - Set of lookback relations + # followset - Computed follow set + # + # This function directly attaches the lookaheads to productions contained + # in the lookbacks set + # ----------------------------------------------------------------------------- + + def add_lookaheads(self, lookbacks, followset): + for trans, lb in lookbacks.items(): + # Loop over productions in lookback + for state, p in lb: + if state not in p.lookaheads: + p.lookaheads[state] = [] + f = followset.get(trans, []) + for a in f: + if a not in p.lookaheads[state]: + p.lookaheads[state].append(a) + + # ----------------------------------------------------------------------------- + # add_lalr_lookaheads() + # + # This function does all of the work of adding lookahead information for use + # with LALR parsing + # ----------------------------------------------------------------------------- + + def add_lalr_lookaheads(self, C): + # Determine all of the nullable nonterminals + nullable = self.compute_nullable_nonterminals() + + # Find all non-terminal transitions + trans = self.find_nonterminal_transitions(C) + + # Compute read sets + readsets = self.compute_read_sets(C, trans, nullable) + + # Compute lookback/includes relations + lookd, included = self.compute_lookback_includes(C, trans, nullable) + + # Compute LALR FOLLOW sets + followsets = self.compute_follow_sets(trans, readsets, included) + + # Add all of the lookaheads + self.add_lookaheads(lookd, followsets) + + # ----------------------------------------------------------------------------- + # lr_parse_table() + # + # This function constructs the parse tables for SLR or LALR + # ----------------------------------------------------------------------------- + def lr_parse_table(self): + Productions = self.grammar.Productions + Precedence = self.grammar.Precedence + goto = self.lr_goto # Goto array + action = self.lr_action # Action array + log = self.log # Logger for output + + actionp = {} # Action production array (temporary) + + log.info('Parsing method: %s', self.lr_method) + + # Step 1: Construct C = { I0, I1, ... IN}, collection of LR(0) items + # This determines the number of states + + C = self.lr0_items() + + if self.lr_method == 'LALR': + self.add_lalr_lookaheads(C) + + # Build the parser table, state by state + st = 0 + for I in C: + # Loop over each production in I + actlist = [] # List of actions + st_action = {} + st_actionp = {} + st_goto = {} + log.info('') + log.info('state %d', st) + log.info('') + for p in I: + log.info(' (%d) %s', p.number, p) + log.info('') + + for p in I: + if p.len == p.lr_index + 1: + if p.name == "S'": + # Start symbol. Accept! + st_action['$end'] = 0 + st_actionp['$end'] = p + else: + # We are at the end of a production. Reduce! + if self.lr_method == 'LALR': + laheads = p.lookaheads[st] + else: + laheads = self.grammar.Follow[p.name] + for a in laheads: + actlist.append((a, p, 'reduce using rule %d (%s)' % (p.number, p))) + r = st_action.get(a) + if r is not None: + # Whoa. Have a shift/reduce or reduce/reduce conflict + if r > 0: + # Need to decide on shift or reduce here + # By default we favor shifting. Need to add + # some precedence rules here. + + # Shift precedence comes from the token + sprec, slevel = Precedence.get(a, ('right', 0)) + + # Reduce precedence comes from rule being reduced (p) + rprec, rlevel = Productions[p.number].prec + + if (slevel < rlevel) or ((slevel == rlevel) and (rprec == 'left')): + # We really need to reduce here. + st_action[a] = -p.number + st_actionp[a] = p + if not slevel and not rlevel: + log.info(' ! shift/reduce conflict for %s resolved as reduce', a) + self.sr_conflicts.append((st, a, 'reduce')) + Productions[p.number].reduced += 1 + elif (slevel == rlevel) and (rprec == 'nonassoc'): + st_action[a] = None + else: + # Hmmm. Guess we'll keep the shift + if not rlevel: + log.info(' ! shift/reduce conflict for %s resolved as shift', a) + self.sr_conflicts.append((st, a, 'shift')) + elif r < 0: + # Reduce/reduce conflict. In this case, we favor the rule + # that was defined first in the grammar file + oldp = Productions[-r] + pp = Productions[p.number] + if oldp.line > pp.line: + st_action[a] = -p.number + st_actionp[a] = p + chosenp, rejectp = pp, oldp + Productions[p.number].reduced += 1 + Productions[oldp.number].reduced -= 1 + else: + chosenp, rejectp = oldp, pp + self.rr_conflicts.append((st, chosenp, rejectp)) + log.info(' ! reduce/reduce conflict for %s resolved using rule %d (%s)', + a, st_actionp[a].number, st_actionp[a]) + else: + raise LALRError('Unknown conflict in state %d' % st) + else: + st_action[a] = -p.number + st_actionp[a] = p + Productions[p.number].reduced += 1 + else: + i = p.lr_index + a = p.prod[i+1] # Get symbol right after the "." + if a in self.grammar.Terminals: + g = self.lr0_goto(I, a) + j = self.lr0_cidhash.get(id(g), -1) + if j >= 0: + # We are in a shift state + actlist.append((a, p, 'shift and go to state %d' % j)) + r = st_action.get(a) + if r is not None: + # Whoa have a shift/reduce or shift/shift conflict + if r > 0: + if r != j: + raise LALRError('Shift/shift conflict in state %d' % st) + elif r < 0: + # Do a precedence check. + # - if precedence of reduce rule is higher, we reduce. + # - if precedence of reduce is same and left assoc, we reduce. + # - otherwise we shift + + # Shift precedence comes from the token + sprec, slevel = Precedence.get(a, ('right', 0)) + + # Reduce precedence comes from the rule that could have been reduced + rprec, rlevel = Productions[st_actionp[a].number].prec + + if (slevel > rlevel) or ((slevel == rlevel) and (rprec == 'right')): + # We decide to shift here... highest precedence to shift + Productions[st_actionp[a].number].reduced -= 1 + st_action[a] = j + st_actionp[a] = p + if not rlevel: + log.info(' ! shift/reduce conflict for %s resolved as shift', a) + self.sr_conflicts.append((st, a, 'shift')) + elif (slevel == rlevel) and (rprec == 'nonassoc'): + st_action[a] = None + else: + # Hmmm. Guess we'll keep the reduce + if not slevel and not rlevel: + log.info(' ! shift/reduce conflict for %s resolved as reduce', a) + self.sr_conflicts.append((st, a, 'reduce')) + + else: + raise LALRError('Unknown conflict in state %d' % st) + else: + st_action[a] = j + st_actionp[a] = p + + # Print the actions associated with each terminal + _actprint = {} + for a, p, m in actlist: + if a in st_action: + if p is st_actionp[a]: + log.info(' %-15s %s', a, m) + _actprint[(a, m)] = 1 + log.info('') + # Print the actions that were not used. (debugging) + not_used = 0 + for a, p, m in actlist: + if a in st_action: + if p is not st_actionp[a]: + if not (a, m) in _actprint: + log.debug(' ! %-15s [ %s ]', a, m) + not_used = 1 + _actprint[(a, m)] = 1 + if not_used: + log.debug('') + + # Construct the goto table for this state + + nkeys = {} + for ii in I: + for s in ii.usyms: + if s in self.grammar.Nonterminals: + nkeys[s] = None + for n in nkeys: + g = self.lr0_goto(I, n) + j = self.lr0_cidhash.get(id(g), -1) + if j >= 0: + st_goto[n] = j + log.info(' %-30s shift and go to state %d', n, j) + + action[st] = st_action + actionp[st] = st_actionp + goto[st] = st_goto + st += 1 + + # ----------------------------------------------------------------------------- + # write() + # + # This function writes the LR parsing tables to a file + # ----------------------------------------------------------------------------- + + def write_table(self, tabmodule, outputdir='', signature=''): + if isinstance(tabmodule, types.ModuleType): + raise IOError("Won't overwrite existing tabmodule") + + basemodulename = tabmodule.split('.')[-1] + filename = os.path.join(outputdir, basemodulename) + '.py' + try: + f = open(filename, 'w') + + f.write(''' +# %s +# This file is automatically generated. Do not edit. +_tabversion = %r + +_lr_method = %r + +_lr_signature = %r + ''' % (os.path.basename(filename), __tabversion__, self.lr_method, signature)) + + # Change smaller to 0 to go back to original tables + smaller = 1 + + # Factor out names to try and make smaller + if smaller: + items = {} + + for s, nd in self.lr_action.items(): + for name, v in nd.items(): + i = items.get(name) + if not i: + i = ([], []) + items[name] = i + i[0].append(s) + i[1].append(v) + + f.write('\n_lr_action_items = {') + for k, v in items.items(): + f.write('%r:([' % k) + for i in v[0]: + f.write('%r,' % i) + f.write('],[') + for i in v[1]: + f.write('%r,' % i) + + f.write(']),') + f.write('}\n') + + f.write(''' +_lr_action = {} +for _k, _v in _lr_action_items.items(): + for _x,_y in zip(_v[0],_v[1]): + if not _x in _lr_action: _lr_action[_x] = {} + _lr_action[_x][_k] = _y +del _lr_action_items +''') + + else: + f.write('\n_lr_action = { ') + for k, v in self.lr_action.items(): + f.write('(%r,%r):%r,' % (k[0], k[1], v)) + f.write('}\n') + + if smaller: + # Factor out names to try and make smaller + items = {} + + for s, nd in self.lr_goto.items(): + for name, v in nd.items(): + i = items.get(name) + if not i: + i = ([], []) + items[name] = i + i[0].append(s) + i[1].append(v) + + f.write('\n_lr_goto_items = {') + for k, v in items.items(): + f.write('%r:([' % k) + for i in v[0]: + f.write('%r,' % i) + f.write('],[') + for i in v[1]: + f.write('%r,' % i) + + f.write(']),') + f.write('}\n') + + f.write(''' +_lr_goto = {} +for _k, _v in _lr_goto_items.items(): + for _x, _y in zip(_v[0], _v[1]): + if not _x in _lr_goto: _lr_goto[_x] = {} + _lr_goto[_x][_k] = _y +del _lr_goto_items +''') + else: + f.write('\n_lr_goto = { ') + for k, v in self.lr_goto.items(): + f.write('(%r,%r):%r,' % (k[0], k[1], v)) + f.write('}\n') + + # Write production table + f.write('_lr_productions = [\n') + for p in self.lr_productions: + if p.func: + f.write(' (%r,%r,%d,%r,%r,%d),\n' % (p.str, p.name, p.len, + p.func, os.path.basename(p.file), p.line)) + else: + f.write(' (%r,%r,%d,None,None,None),\n' % (str(p), p.name, p.len)) + f.write(']\n') + f.close() + + except IOError as e: + raise + + + # ----------------------------------------------------------------------------- + # pickle_table() + # + # This function pickles the LR parsing tables to a supplied file object + # ----------------------------------------------------------------------------- + + def pickle_table(self, filename, signature=''): + try: + import cPickle as pickle + except ImportError: + import pickle + with open(filename, 'wb') as outf: + pickle.dump(__tabversion__, outf, pickle_protocol) + pickle.dump(self.lr_method, outf, pickle_protocol) + pickle.dump(signature, outf, pickle_protocol) + pickle.dump(self.lr_action, outf, pickle_protocol) + pickle.dump(self.lr_goto, outf, pickle_protocol) + + outp = [] + for p in self.lr_productions: + if p.func: + outp.append((p.str, p.name, p.len, p.func, os.path.basename(p.file), p.line)) + else: + outp.append((str(p), p.name, p.len, None, None, None)) + pickle.dump(outp, outf, pickle_protocol) + +# ----------------------------------------------------------------------------- +# === INTROSPECTION === +# +# The following functions and classes are used to implement the PLY +# introspection features followed by the yacc() function itself. +# ----------------------------------------------------------------------------- + +# ----------------------------------------------------------------------------- +# get_caller_module_dict() +# +# This function returns a dictionary containing all of the symbols defined within +# a caller further down the call stack. This is used to get the environment +# associated with the yacc() call if none was provided. +# ----------------------------------------------------------------------------- + +def get_caller_module_dict(levels): + f = sys._getframe(levels) + ldict = f.f_globals.copy() + if f.f_globals != f.f_locals: + ldict.update(f.f_locals) + return ldict + +# ----------------------------------------------------------------------------- +# parse_grammar() +# +# This takes a raw grammar rule string and parses it into production data +# ----------------------------------------------------------------------------- +def parse_grammar(doc, file, line): + grammar = [] + # Split the doc string into lines + pstrings = doc.splitlines() + lastp = None + dline = line + for ps in pstrings: + dline += 1 + p = ps.split() + if not p: + continue + try: + if p[0] == '|': + # This is a continuation of a previous rule + if not lastp: + raise SyntaxError("%s:%d: Misplaced '|'" % (file, dline)) + prodname = lastp + syms = p[1:] + else: + prodname = p[0] + lastp = prodname + syms = p[2:] + assign = p[1] + if assign != ':' and assign != '::=': + raise SyntaxError("%s:%d: Syntax error. Expected ':'" % (file, dline)) + + grammar.append((file, dline, prodname, syms)) + except SyntaxError: + raise + except Exception: + raise SyntaxError('%s:%d: Syntax error in rule %r' % (file, dline, ps.strip())) + + return grammar + +# ----------------------------------------------------------------------------- +# ParserReflect() +# +# This class represents information extracted for building a parser including +# start symbol, error function, tokens, precedence list, action functions, +# etc. +# ----------------------------------------------------------------------------- +class ParserReflect(object): + def __init__(self, pdict, log=None): + self.pdict = pdict + self.start = None + self.error_func = None + self.tokens = None + self.modules = set() + self.grammar = [] + self.error = False + + if log is None: + self.log = PlyLogger(sys.stderr) + else: + self.log = log + + # Get all of the basic information + def get_all(self): + self.get_start() + self.get_error_func() + self.get_tokens() + self.get_precedence() + self.get_pfunctions() + + # Validate all of the information + def validate_all(self): + self.validate_start() + self.validate_error_func() + self.validate_tokens() + self.validate_precedence() + self.validate_pfunctions() + self.validate_modules() + return self.error + + # Compute a signature over the grammar + def signature(self): + parts = [] + try: + if self.start: + parts.append(self.start) + if self.prec: + parts.append(''.join([''.join(p) for p in self.prec])) + if self.tokens: + parts.append(' '.join(self.tokens)) + for f in self.pfuncs: + if f[3]: + parts.append(f[3]) + except (TypeError, ValueError): + pass + return ''.join(parts) + + # ----------------------------------------------------------------------------- + # validate_modules() + # + # This method checks to see if there are duplicated p_rulename() functions + # in the parser module file. Without this function, it is really easy for + # users to make mistakes by cutting and pasting code fragments (and it's a real + # bugger to try and figure out why the resulting parser doesn't work). Therefore, + # we just do a little regular expression pattern matching of def statements + # to try and detect duplicates. + # ----------------------------------------------------------------------------- + + def validate_modules(self): + # Match def p_funcname( + fre = re.compile(r'\s*def\s+(p_[a-zA-Z_0-9]*)\(') + + for module in self.modules: + try: + lines, linen = inspect.getsourcelines(module) + except IOError: + continue + + counthash = {} + for linen, line in enumerate(lines): + linen += 1 + m = fre.match(line) + if m: + name = m.group(1) + prev = counthash.get(name) + if not prev: + counthash[name] = linen + else: + filename = inspect.getsourcefile(module) + self.log.warning('%s:%d: Function %s redefined. Previously defined on line %d', + filename, linen, name, prev) + + # Get the start symbol + def get_start(self): + self.start = self.pdict.get('start') + + # Validate the start symbol + def validate_start(self): + if self.start is not None: + if not isinstance(self.start, string_types): + self.log.error("'start' must be a string") + + # Look for error handler + def get_error_func(self): + self.error_func = self.pdict.get('p_error') + + # Validate the error function + def validate_error_func(self): + if self.error_func: + if isinstance(self.error_func, types.FunctionType): + ismethod = 0 + elif isinstance(self.error_func, types.MethodType): + ismethod = 1 + else: + self.log.error("'p_error' defined, but is not a function or method") + self.error = True + return + + eline = self.error_func.__code__.co_firstlineno + efile = self.error_func.__code__.co_filename + module = inspect.getmodule(self.error_func) + self.modules.add(module) + + argcount = self.error_func.__code__.co_argcount - ismethod + if argcount != 1: + self.log.error('%s:%d: p_error() requires 1 argument', efile, eline) + self.error = True + + # Get the tokens map + def get_tokens(self): + tokens = self.pdict.get('tokens') + if not tokens: + self.log.error('No token list is defined') + self.error = True + return + + if not isinstance(tokens, (list, tuple)): + self.log.error('tokens must be a list or tuple') + self.error = True + return + + if not tokens: + self.log.error('tokens is empty') + self.error = True + return + + self.tokens = tokens + + # Validate the tokens + def validate_tokens(self): + # Validate the tokens. + if 'error' in self.tokens: + self.log.error("Illegal token name 'error'. Is a reserved word") + self.error = True + return + + terminals = set() + for n in self.tokens: + if n in terminals: + self.log.warning('Token %r multiply defined', n) + terminals.add(n) + + # Get the precedence map (if any) + def get_precedence(self): + self.prec = self.pdict.get('precedence') + + # Validate and parse the precedence map + def validate_precedence(self): + preclist = [] + if self.prec: + if not isinstance(self.prec, (list, tuple)): + self.log.error('precedence must be a list or tuple') + self.error = True + return + for level, p in enumerate(self.prec): + if not isinstance(p, (list, tuple)): + self.log.error('Bad precedence table') + self.error = True + return + + if len(p) < 2: + self.log.error('Malformed precedence entry %s. Must be (assoc, term, ..., term)', p) + self.error = True + return + assoc = p[0] + if not isinstance(assoc, string_types): + self.log.error('precedence associativity must be a string') + self.error = True + return + for term in p[1:]: + if not isinstance(term, string_types): + self.log.error('precedence items must be strings') + self.error = True + return + preclist.append((term, assoc, level+1)) + self.preclist = preclist + + # Get all p_functions from the grammar + def get_pfunctions(self): + p_functions = [] + for name, item in self.pdict.items(): + if not name.startswith('p_') or name == 'p_error': + continue + if isinstance(item, (types.FunctionType, types.MethodType)): + line = getattr(item, 'co_firstlineno', item.__code__.co_firstlineno) + module = inspect.getmodule(item) + p_functions.append((line, module, name, item.__doc__)) + + # Sort all of the actions by line number; make sure to stringify + # modules to make them sortable, since `line` may not uniquely sort all + # p functions + p_functions.sort(key=lambda p_function: ( + p_function[0], + str(p_function[1]), + p_function[2], + p_function[3])) + self.pfuncs = p_functions + + # Validate all of the p_functions + def validate_pfunctions(self): + grammar = [] + # Check for non-empty symbols + if len(self.pfuncs) == 0: + self.log.error('no rules of the form p_rulename are defined') + self.error = True + return + + for line, module, name, doc in self.pfuncs: + file = inspect.getsourcefile(module) + func = self.pdict[name] + if isinstance(func, types.MethodType): + reqargs = 2 + else: + reqargs = 1 + if func.__code__.co_argcount > reqargs: + self.log.error('%s:%d: Rule %r has too many arguments', file, line, func.__name__) + self.error = True + elif func.__code__.co_argcount < reqargs: + self.log.error('%s:%d: Rule %r requires an argument', file, line, func.__name__) + self.error = True + elif not func.__doc__: + self.log.warning('%s:%d: No documentation string specified in function %r (ignored)', + file, line, func.__name__) + else: + try: + parsed_g = parse_grammar(doc, file, line) + for g in parsed_g: + grammar.append((name, g)) + except SyntaxError as e: + self.log.error(str(e)) + self.error = True + + # Looks like a valid grammar rule + # Mark the file in which defined. + self.modules.add(module) + + # Secondary validation step that looks for p_ definitions that are not functions + # or functions that look like they might be grammar rules. + + for n, v in self.pdict.items(): + if n.startswith('p_') and isinstance(v, (types.FunctionType, types.MethodType)): + continue + if n.startswith('t_'): + continue + if n.startswith('p_') and n != 'p_error': + self.log.warning('%r not defined as a function', n) + if ((isinstance(v, types.FunctionType) and v.__code__.co_argcount == 1) or + (isinstance(v, types.MethodType) and v.__func__.__code__.co_argcount == 2)): + if v.__doc__: + try: + doc = v.__doc__.split(' ') + if doc[1] == ':': + self.log.warning('%s:%d: Possible grammar rule %r defined without p_ prefix', + v.__code__.co_filename, v.__code__.co_firstlineno, n) + except IndexError: + pass + + self.grammar = grammar + +# ----------------------------------------------------------------------------- +# yacc(module) +# +# Build a parser +# ----------------------------------------------------------------------------- + +def yacc(method='LALR', debug=yaccdebug, module=None, tabmodule=tab_module, start=None, + check_recursion=True, optimize=False, write_tables=True, debugfile=debug_file, + outputdir=None, debuglog=None, errorlog=None, picklefile=None): + + if tabmodule is None: + tabmodule = tab_module + + # Reference to the parsing method of the last built parser + global parse + + # If pickling is enabled, table files are not created + if picklefile: + write_tables = 0 + + if errorlog is None: + errorlog = PlyLogger(sys.stderr) + + # Get the module dictionary used for the parser + if module: + _items = [(k, getattr(module, k)) for k in dir(module)] + pdict = dict(_items) + # If no __file__ attribute is available, try to obtain it from the __module__ instead + if '__file__' not in pdict: + pdict['__file__'] = sys.modules[pdict['__module__']].__file__ + else: + pdict = get_caller_module_dict(2) + + if outputdir is None: + # If no output directory is set, the location of the output files + # is determined according to the following rules: + # - If tabmodule specifies a package, files go into that package directory + # - Otherwise, files go in the same directory as the specifying module + if isinstance(tabmodule, types.ModuleType): + srcfile = tabmodule.__file__ + else: + if '.' not in tabmodule: + srcfile = pdict['__file__'] + else: + parts = tabmodule.split('.') + pkgname = '.'.join(parts[:-1]) + exec('import %s' % pkgname) + srcfile = getattr(sys.modules[pkgname], '__file__', '') + outputdir = os.path.dirname(srcfile) + + # Determine if the module is package of a package or not. + # If so, fix the tabmodule setting so that tables load correctly + pkg = pdict.get('__package__') + if pkg and isinstance(tabmodule, str): + if '.' not in tabmodule: + tabmodule = pkg + '.' + tabmodule + + + + # Set start symbol if it's specified directly using an argument + if start is not None: + pdict['start'] = start + + # Collect parser information from the dictionary + pinfo = ParserReflect(pdict, log=errorlog) + pinfo.get_all() + + if pinfo.error: + raise YaccError('Unable to build parser') + + # Check signature against table files (if any) + signature = pinfo.signature() + + # Read the tables + try: + lr = LRTable() + if picklefile: + read_signature = lr.read_pickle(picklefile) + else: + read_signature = lr.read_table(tabmodule) + if optimize or (read_signature == signature): + try: + lr.bind_callables(pinfo.pdict) + parser = LRParser(lr, pinfo.error_func) + parse = parser.parse + return parser + except Exception as e: + errorlog.warning('There was a problem loading the table file: %r', e) + except VersionError as e: + errorlog.warning(str(e)) + except ImportError: + pass + + if debuglog is None: + if debug: + try: + debuglog = PlyLogger(open(os.path.join(outputdir, debugfile), 'w')) + except IOError as e: + errorlog.warning("Couldn't open %r. %s" % (debugfile, e)) + debuglog = NullLogger() + else: + debuglog = NullLogger() + + debuglog.info('Created by PLY version %s (http://www.dabeaz.com/ply)', __version__) + + errors = False + + # Validate the parser information + if pinfo.validate_all(): + raise YaccError('Unable to build parser') + + if not pinfo.error_func: + errorlog.warning('no p_error() function is defined') + + # Create a grammar object + grammar = Grammar(pinfo.tokens) + + # Set precedence level for terminals + for term, assoc, level in pinfo.preclist: + try: + grammar.set_precedence(term, assoc, level) + except GrammarError as e: + errorlog.warning('%s', e) + + # Add productions to the grammar + for funcname, gram in pinfo.grammar: + file, line, prodname, syms = gram + try: + grammar.add_production(prodname, syms, funcname, file, line) + except GrammarError as e: + errorlog.error('%s', e) + errors = True + + # Set the grammar start symbols + try: + if start is None: + grammar.set_start(pinfo.start) + else: + grammar.set_start(start) + except GrammarError as e: + errorlog.error(str(e)) + errors = True + + if errors: + raise YaccError('Unable to build parser') + + # Verify the grammar structure + undefined_symbols = grammar.undefined_symbols() + for sym, prod in undefined_symbols: + errorlog.error('%s:%d: Symbol %r used, but not defined as a token or a rule', prod.file, prod.line, sym) + errors = True + + unused_terminals = grammar.unused_terminals() + if unused_terminals: + debuglog.info('') + debuglog.info('Unused terminals:') + debuglog.info('') + for term in unused_terminals: + errorlog.warning('Token %r defined, but not used', term) + debuglog.info(' %s', term) + + # Print out all productions to the debug log + if debug: + debuglog.info('') + debuglog.info('Grammar') + debuglog.info('') + for n, p in enumerate(grammar.Productions): + debuglog.info('Rule %-5d %s', n, p) + + # Find unused non-terminals + unused_rules = grammar.unused_rules() + for prod in unused_rules: + errorlog.warning('%s:%d: Rule %r defined, but not used', prod.file, prod.line, prod.name) + + if len(unused_terminals) == 1: + errorlog.warning('There is 1 unused token') + if len(unused_terminals) > 1: + errorlog.warning('There are %d unused tokens', len(unused_terminals)) + + if len(unused_rules) == 1: + errorlog.warning('There is 1 unused rule') + if len(unused_rules) > 1: + errorlog.warning('There are %d unused rules', len(unused_rules)) + + if debug: + debuglog.info('') + debuglog.info('Terminals, with rules where they appear') + debuglog.info('') + terms = list(grammar.Terminals) + terms.sort() + for term in terms: + debuglog.info('%-20s : %s', term, ' '.join([str(s) for s in grammar.Terminals[term]])) + + debuglog.info('') + debuglog.info('Nonterminals, with rules where they appear') + debuglog.info('') + nonterms = list(grammar.Nonterminals) + nonterms.sort() + for nonterm in nonterms: + debuglog.info('%-20s : %s', nonterm, ' '.join([str(s) for s in grammar.Nonterminals[nonterm]])) + debuglog.info('') + + if check_recursion: + unreachable = grammar.find_unreachable() + for u in unreachable: + errorlog.warning('Symbol %r is unreachable', u) + + infinite = grammar.infinite_cycles() + for inf in infinite: + errorlog.error('Infinite recursion detected for symbol %r', inf) + errors = True + + unused_prec = grammar.unused_precedence() + for term, assoc in unused_prec: + errorlog.error('Precedence rule %r defined for unknown symbol %r', assoc, term) + errors = True + + if errors: + raise YaccError('Unable to build parser') + + # Run the LRGeneratedTable on the grammar + if debug: + errorlog.debug('Generating %s tables', method) + + lr = LRGeneratedTable(grammar, method, debuglog) + + if debug: + num_sr = len(lr.sr_conflicts) + + # Report shift/reduce and reduce/reduce conflicts + if num_sr == 1: + errorlog.warning('1 shift/reduce conflict') + elif num_sr > 1: + errorlog.warning('%d shift/reduce conflicts', num_sr) + + num_rr = len(lr.rr_conflicts) + if num_rr == 1: + errorlog.warning('1 reduce/reduce conflict') + elif num_rr > 1: + errorlog.warning('%d reduce/reduce conflicts', num_rr) + + # Write out conflicts to the output file + if debug and (lr.sr_conflicts or lr.rr_conflicts): + debuglog.warning('') + debuglog.warning('Conflicts:') + debuglog.warning('') + + for state, tok, resolution in lr.sr_conflicts: + debuglog.warning('shift/reduce conflict for %s in state %d resolved as %s', tok, state, resolution) + + already_reported = set() + for state, rule, rejected in lr.rr_conflicts: + if (state, id(rule), id(rejected)) in already_reported: + continue + debuglog.warning('reduce/reduce conflict in state %d resolved using rule (%s)', state, rule) + debuglog.warning('rejected rule (%s) in state %d', rejected, state) + errorlog.warning('reduce/reduce conflict in state %d resolved using rule (%s)', state, rule) + errorlog.warning('rejected rule (%s) in state %d', rejected, state) + already_reported.add((state, id(rule), id(rejected))) + + warned_never = [] + for state, rule, rejected in lr.rr_conflicts: + if not rejected.reduced and (rejected not in warned_never): + debuglog.warning('Rule (%s) is never reduced', rejected) + errorlog.warning('Rule (%s) is never reduced', rejected) + warned_never.append(rejected) + + # Write the table file if requested + if write_tables: + try: + lr.write_table(tabmodule, outputdir, signature) + except IOError as e: + errorlog.warning("Couldn't create %r. %s" % (tabmodule, e)) + + # Write a pickled version of the tables + if picklefile: + try: + lr.pickle_table(picklefile, signature) + except IOError as e: + errorlog.warning("Couldn't create %r. %s" % (picklefile, e)) + + # Build the parser + lr.bind_callables(pinfo.pdict) + parser = LRParser(lr, pinfo.error_func) + + parse = parser.parse + return parser diff --git a/.venv/Lib/site-packages/pycparser/ply/ygen.py b/.venv/Lib/site-packages/pycparser/ply/ygen.py new file mode 100644 index 00000000..acf5ca1a --- /dev/null +++ b/.venv/Lib/site-packages/pycparser/ply/ygen.py @@ -0,0 +1,74 @@ +# ply: ygen.py +# +# This is a support program that auto-generates different versions of the YACC parsing +# function with different features removed for the purposes of performance. +# +# Users should edit the method LParser.parsedebug() in yacc.py. The source code +# for that method is then used to create the other methods. See the comments in +# yacc.py for further details. + +import os.path +import shutil + +def get_source_range(lines, tag): + srclines = enumerate(lines) + start_tag = '#--! %s-start' % tag + end_tag = '#--! %s-end' % tag + + for start_index, line in srclines: + if line.strip().startswith(start_tag): + break + + for end_index, line in srclines: + if line.strip().endswith(end_tag): + break + + return (start_index + 1, end_index) + +def filter_section(lines, tag): + filtered_lines = [] + include = True + tag_text = '#--! %s' % tag + for line in lines: + if line.strip().startswith(tag_text): + include = not include + elif include: + filtered_lines.append(line) + return filtered_lines + +def main(): + dirname = os.path.dirname(__file__) + shutil.copy2(os.path.join(dirname, 'yacc.py'), os.path.join(dirname, 'yacc.py.bak')) + with open(os.path.join(dirname, 'yacc.py'), 'r') as f: + lines = f.readlines() + + parse_start, parse_end = get_source_range(lines, 'parsedebug') + parseopt_start, parseopt_end = get_source_range(lines, 'parseopt') + parseopt_notrack_start, parseopt_notrack_end = get_source_range(lines, 'parseopt-notrack') + + # Get the original source + orig_lines = lines[parse_start:parse_end] + + # Filter the DEBUG sections out + parseopt_lines = filter_section(orig_lines, 'DEBUG') + + # Filter the TRACKING sections out + parseopt_notrack_lines = filter_section(parseopt_lines, 'TRACKING') + + # Replace the parser source sections with updated versions + lines[parseopt_notrack_start:parseopt_notrack_end] = parseopt_notrack_lines + lines[parseopt_start:parseopt_end] = parseopt_lines + + lines = [line.rstrip()+'\n' for line in lines] + with open(os.path.join(dirname, 'yacc.py'), 'w') as f: + f.writelines(lines) + + print('Updated yacc.py') + +if __name__ == '__main__': + main() + + + + + diff --git a/.venv/Lib/site-packages/pycparser/plyparser.py b/.venv/Lib/site-packages/pycparser/plyparser.py new file mode 100644 index 00000000..b8f4c439 --- /dev/null +++ b/.venv/Lib/site-packages/pycparser/plyparser.py @@ -0,0 +1,133 @@ +#----------------------------------------------------------------- +# plyparser.py +# +# PLYParser class and other utilities for simplifying programming +# parsers with PLY +# +# Eli Bendersky [https://eli.thegreenplace.net/] +# License: BSD +#----------------------------------------------------------------- + +import warnings + +class Coord(object): + """ Coordinates of a syntactic element. Consists of: + - File name + - Line number + - (optional) column number, for the Lexer + """ + __slots__ = ('file', 'line', 'column', '__weakref__') + def __init__(self, file, line, column=None): + self.file = file + self.line = line + self.column = column + + def __str__(self): + str = "%s:%s" % (self.file, self.line) + if self.column: str += ":%s" % self.column + return str + + +class ParseError(Exception): pass + + +class PLYParser(object): + def _create_opt_rule(self, rulename): + """ Given a rule name, creates an optional ply.yacc rule + for it. The name of the optional rule is + _opt + """ + optname = rulename + '_opt' + + def optrule(self, p): + p[0] = p[1] + + optrule.__doc__ = '%s : empty\n| %s' % (optname, rulename) + optrule.__name__ = 'p_%s' % optname + setattr(self.__class__, optrule.__name__, optrule) + + def _coord(self, lineno, column=None): + return Coord( + file=self.clex.filename, + line=lineno, + column=column) + + def _token_coord(self, p, token_idx): + """ Returns the coordinates for the YaccProduction object 'p' indexed + with 'token_idx'. The coordinate includes the 'lineno' and + 'column'. Both follow the lex semantic, starting from 1. + """ + last_cr = p.lexer.lexer.lexdata.rfind('\n', 0, p.lexpos(token_idx)) + if last_cr < 0: + last_cr = -1 + column = (p.lexpos(token_idx) - (last_cr)) + return self._coord(p.lineno(token_idx), column) + + def _parse_error(self, msg, coord): + raise ParseError("%s: %s" % (coord, msg)) + + +def parameterized(*params): + """ Decorator to create parameterized rules. + + Parameterized rule methods must be named starting with 'p_' and contain + 'xxx', and their docstrings may contain 'xxx' and 'yyy'. These will be + replaced by the given parameter tuples. For example, ``p_xxx_rule()`` with + docstring 'xxx_rule : yyy' when decorated with + ``@parameterized(('id', 'ID'))`` produces ``p_id_rule()`` with the docstring + 'id_rule : ID'. Using multiple tuples produces multiple rules. + """ + def decorate(rule_func): + rule_func._params = params + return rule_func + return decorate + + +def template(cls): + """ Class decorator to generate rules from parameterized rule templates. + + See `parameterized` for more information on parameterized rules. + """ + issued_nodoc_warning = False + for attr_name in dir(cls): + if attr_name.startswith('p_'): + method = getattr(cls, attr_name) + if hasattr(method, '_params'): + # Remove the template method + delattr(cls, attr_name) + # Create parameterized rules from this method; only run this if + # the method has a docstring. This is to address an issue when + # pycparser's users are installed in -OO mode which strips + # docstrings away. + # See: https://github.com/eliben/pycparser/pull/198/ and + # https://github.com/eliben/pycparser/issues/197 + # for discussion. + if method.__doc__ is not None: + _create_param_rules(cls, method) + elif not issued_nodoc_warning: + warnings.warn( + 'parsing methods must have __doc__ for pycparser to work properly', + RuntimeWarning, + stacklevel=2) + issued_nodoc_warning = True + return cls + + +def _create_param_rules(cls, func): + """ Create ply.yacc rules based on a parameterized rule function + + Generates new methods (one per each pair of parameters) based on the + template rule function `func`, and attaches them to `cls`. The rule + function's parameters must be accessible via its `_params` attribute. + """ + for xxx, yyy in func._params: + # Use the template method's body for each new method + def param_rule(self, p): + func(self, p) + + # Substitute in the params for the grammar rule and function name + param_rule.__doc__ = func.__doc__.replace('xxx', xxx).replace('yyy', yyy) + param_rule.__name__ = func.__name__.replace('xxx', xxx) + + # Attach the new method to the class + setattr(cls, param_rule.__name__, param_rule) diff --git a/.venv/Lib/site-packages/pycparser/yacctab.py b/.venv/Lib/site-packages/pycparser/yacctab.py new file mode 100644 index 00000000..0622c366 --- /dev/null +++ b/.venv/Lib/site-packages/pycparser/yacctab.py @@ -0,0 +1,366 @@ + +# yacctab.py +# This file is automatically generated. Do not edit. +_tabversion = '3.10' + +_lr_method = 'LALR' + +_lr_signature = 'translation_unit_or_emptyleftLORleftLANDleftORleftXORleftANDleftEQNEleftGTGELTLEleftRSHIFTLSHIFTleftPLUSMINUSleftTIMESDIVIDEMODAUTO BREAK CASE CHAR CONST CONTINUE DEFAULT DO DOUBLE ELSE ENUM EXTERN FLOAT FOR GOTO IF INLINE INT LONG REGISTER OFFSETOF RESTRICT RETURN SHORT SIGNED SIZEOF STATIC STRUCT SWITCH TYPEDEF UNION UNSIGNED VOID VOLATILE WHILE __INT128 _BOOL _COMPLEX _NORETURN _THREAD_LOCAL _STATIC_ASSERT _ATOMIC _ALIGNOF _ALIGNAS ID TYPEID INT_CONST_DEC INT_CONST_OCT INT_CONST_HEX INT_CONST_BIN INT_CONST_CHAR FLOAT_CONST HEX_FLOAT_CONST CHAR_CONST WCHAR_CONST U8CHAR_CONST U16CHAR_CONST U32CHAR_CONST STRING_LITERAL WSTRING_LITERAL U8STRING_LITERAL U16STRING_LITERAL U32STRING_LITERAL PLUS MINUS TIMES DIVIDE MOD OR AND NOT XOR LSHIFT RSHIFT LOR LAND LNOT LT LE GT GE EQ NE EQUALS TIMESEQUAL DIVEQUAL MODEQUAL PLUSEQUAL MINUSEQUAL LSHIFTEQUAL RSHIFTEQUAL ANDEQUAL XOREQUAL OREQUAL PLUSPLUS MINUSMINUS ARROW CONDOP LPAREN RPAREN LBRACKET RBRACKET LBRACE RBRACE COMMA PERIOD SEMI COLON ELLIPSIS PPHASH PPPRAGMA PPPRAGMASTRabstract_declarator_opt : empty\n| abstract_declaratorassignment_expression_opt : empty\n| assignment_expressionblock_item_list_opt : empty\n| block_item_listdeclaration_list_opt : empty\n| declaration_listdeclaration_specifiers_no_type_opt : empty\n| declaration_specifiers_no_typedesignation_opt : empty\n| designationexpression_opt : empty\n| expressionid_init_declarator_list_opt : empty\n| id_init_declarator_listidentifier_list_opt : empty\n| identifier_listinit_declarator_list_opt : empty\n| init_declarator_listinitializer_list_opt : empty\n| initializer_listparameter_type_list_opt : empty\n| parameter_type_liststruct_declarator_list_opt : empty\n| struct_declarator_listtype_qualifier_list_opt : empty\n| type_qualifier_list direct_id_declarator : ID\n direct_id_declarator : LPAREN id_declarator RPAREN\n direct_id_declarator : direct_id_declarator LBRACKET type_qualifier_list_opt assignment_expression_opt RBRACKET\n direct_id_declarator : direct_id_declarator LBRACKET STATIC type_qualifier_list_opt assignment_expression RBRACKET\n | direct_id_declarator LBRACKET type_qualifier_list STATIC assignment_expression RBRACKET\n direct_id_declarator : direct_id_declarator LBRACKET type_qualifier_list_opt TIMES RBRACKET\n direct_id_declarator : direct_id_declarator LPAREN parameter_type_list RPAREN\n | direct_id_declarator LPAREN identifier_list_opt RPAREN\n direct_typeid_declarator : TYPEID\n direct_typeid_declarator : LPAREN typeid_declarator RPAREN\n direct_typeid_declarator : direct_typeid_declarator LBRACKET type_qualifier_list_opt assignment_expression_opt RBRACKET\n direct_typeid_declarator : direct_typeid_declarator LBRACKET STATIC type_qualifier_list_opt assignment_expression RBRACKET\n | direct_typeid_declarator LBRACKET type_qualifier_list STATIC assignment_expression RBRACKET\n direct_typeid_declarator : direct_typeid_declarator LBRACKET type_qualifier_list_opt TIMES RBRACKET\n direct_typeid_declarator : direct_typeid_declarator LPAREN parameter_type_list RPAREN\n | direct_typeid_declarator LPAREN identifier_list_opt RPAREN\n direct_typeid_noparen_declarator : TYPEID\n direct_typeid_noparen_declarator : direct_typeid_noparen_declarator LBRACKET type_qualifier_list_opt assignment_expression_opt RBRACKET\n direct_typeid_noparen_declarator : direct_typeid_noparen_declarator LBRACKET STATIC type_qualifier_list_opt assignment_expression RBRACKET\n | direct_typeid_noparen_declarator LBRACKET type_qualifier_list STATIC assignment_expression RBRACKET\n direct_typeid_noparen_declarator : direct_typeid_noparen_declarator LBRACKET type_qualifier_list_opt TIMES RBRACKET\n direct_typeid_noparen_declarator : direct_typeid_noparen_declarator LPAREN parameter_type_list RPAREN\n | direct_typeid_noparen_declarator LPAREN identifier_list_opt RPAREN\n id_declarator : direct_id_declarator\n id_declarator : pointer direct_id_declarator\n typeid_declarator : direct_typeid_declarator\n typeid_declarator : pointer direct_typeid_declarator\n typeid_noparen_declarator : direct_typeid_noparen_declarator\n typeid_noparen_declarator : pointer direct_typeid_noparen_declarator\n translation_unit_or_empty : translation_unit\n | empty\n translation_unit : external_declaration\n translation_unit : translation_unit external_declaration\n external_declaration : function_definition\n external_declaration : declaration\n external_declaration : pp_directive\n | pppragma_directive\n external_declaration : SEMI\n external_declaration : static_assert\n static_assert : _STATIC_ASSERT LPAREN constant_expression COMMA unified_string_literal RPAREN\n | _STATIC_ASSERT LPAREN constant_expression RPAREN\n pp_directive : PPHASH\n pppragma_directive : PPPRAGMA\n | PPPRAGMA PPPRAGMASTR\n function_definition : id_declarator declaration_list_opt compound_statement\n function_definition : declaration_specifiers id_declarator declaration_list_opt compound_statement\n statement : labeled_statement\n | expression_statement\n | compound_statement\n | selection_statement\n | iteration_statement\n | jump_statement\n | pppragma_directive\n | static_assert\n pragmacomp_or_statement : pppragma_directive statement\n | statement\n decl_body : declaration_specifiers init_declarator_list_opt\n | declaration_specifiers_no_type id_init_declarator_list_opt\n declaration : decl_body SEMI\n declaration_list : declaration\n | declaration_list declaration\n declaration_specifiers_no_type : type_qualifier declaration_specifiers_no_type_opt\n declaration_specifiers_no_type : storage_class_specifier declaration_specifiers_no_type_opt\n declaration_specifiers_no_type : function_specifier declaration_specifiers_no_type_opt\n declaration_specifiers_no_type : atomic_specifier declaration_specifiers_no_type_opt\n declaration_specifiers_no_type : alignment_specifier declaration_specifiers_no_type_opt\n declaration_specifiers : declaration_specifiers type_qualifier\n declaration_specifiers : declaration_specifiers storage_class_specifier\n declaration_specifiers : declaration_specifiers function_specifier\n declaration_specifiers : declaration_specifiers type_specifier_no_typeid\n declaration_specifiers : type_specifier\n declaration_specifiers : declaration_specifiers_no_type type_specifier\n declaration_specifiers : declaration_specifiers alignment_specifier\n storage_class_specifier : AUTO\n | REGISTER\n | STATIC\n | EXTERN\n | TYPEDEF\n | _THREAD_LOCAL\n function_specifier : INLINE\n | _NORETURN\n type_specifier_no_typeid : VOID\n | _BOOL\n | CHAR\n | SHORT\n | INT\n | LONG\n | FLOAT\n | DOUBLE\n | _COMPLEX\n | SIGNED\n | UNSIGNED\n | __INT128\n type_specifier : typedef_name\n | enum_specifier\n | struct_or_union_specifier\n | type_specifier_no_typeid\n | atomic_specifier\n atomic_specifier : _ATOMIC LPAREN type_name RPAREN\n type_qualifier : CONST\n | RESTRICT\n | VOLATILE\n | _ATOMIC\n init_declarator_list : init_declarator\n | init_declarator_list COMMA init_declarator\n init_declarator : declarator\n | declarator EQUALS initializer\n id_init_declarator_list : id_init_declarator\n | id_init_declarator_list COMMA init_declarator\n id_init_declarator : id_declarator\n | id_declarator EQUALS initializer\n specifier_qualifier_list : specifier_qualifier_list type_specifier_no_typeid\n specifier_qualifier_list : specifier_qualifier_list type_qualifier\n specifier_qualifier_list : type_specifier\n specifier_qualifier_list : type_qualifier_list type_specifier\n specifier_qualifier_list : alignment_specifier\n specifier_qualifier_list : specifier_qualifier_list alignment_specifier\n struct_or_union_specifier : struct_or_union ID\n | struct_or_union TYPEID\n struct_or_union_specifier : struct_or_union brace_open struct_declaration_list brace_close\n | struct_or_union brace_open brace_close\n struct_or_union_specifier : struct_or_union ID brace_open struct_declaration_list brace_close\n | struct_or_union ID brace_open brace_close\n | struct_or_union TYPEID brace_open struct_declaration_list brace_close\n | struct_or_union TYPEID brace_open brace_close\n struct_or_union : STRUCT\n | UNION\n struct_declaration_list : struct_declaration\n | struct_declaration_list struct_declaration\n struct_declaration : specifier_qualifier_list struct_declarator_list_opt SEMI\n struct_declaration : SEMI\n struct_declaration : pppragma_directive\n struct_declarator_list : struct_declarator\n | struct_declarator_list COMMA struct_declarator\n struct_declarator : declarator\n struct_declarator : declarator COLON constant_expression\n | COLON constant_expression\n enum_specifier : ENUM ID\n | ENUM TYPEID\n enum_specifier : ENUM brace_open enumerator_list brace_close\n enum_specifier : ENUM ID brace_open enumerator_list brace_close\n | ENUM TYPEID brace_open enumerator_list brace_close\n enumerator_list : enumerator\n | enumerator_list COMMA\n | enumerator_list COMMA enumerator\n alignment_specifier : _ALIGNAS LPAREN type_name RPAREN\n | _ALIGNAS LPAREN constant_expression RPAREN\n enumerator : ID\n | ID EQUALS constant_expression\n declarator : id_declarator\n | typeid_declarator\n pointer : TIMES type_qualifier_list_opt\n | TIMES type_qualifier_list_opt pointer\n type_qualifier_list : type_qualifier\n | type_qualifier_list type_qualifier\n parameter_type_list : parameter_list\n | parameter_list COMMA ELLIPSIS\n parameter_list : parameter_declaration\n | parameter_list COMMA parameter_declaration\n parameter_declaration : declaration_specifiers id_declarator\n | declaration_specifiers typeid_noparen_declarator\n parameter_declaration : declaration_specifiers abstract_declarator_opt\n identifier_list : identifier\n | identifier_list COMMA identifier\n initializer : assignment_expression\n initializer : brace_open initializer_list_opt brace_close\n | brace_open initializer_list COMMA brace_close\n initializer_list : designation_opt initializer\n | initializer_list COMMA designation_opt initializer\n designation : designator_list EQUALS\n designator_list : designator\n | designator_list designator\n designator : LBRACKET constant_expression RBRACKET\n | PERIOD identifier\n type_name : specifier_qualifier_list abstract_declarator_opt\n abstract_declarator : pointer\n abstract_declarator : pointer direct_abstract_declarator\n abstract_declarator : direct_abstract_declarator\n direct_abstract_declarator : LPAREN abstract_declarator RPAREN direct_abstract_declarator : direct_abstract_declarator LBRACKET assignment_expression_opt RBRACKET\n direct_abstract_declarator : LBRACKET type_qualifier_list_opt assignment_expression_opt RBRACKET\n direct_abstract_declarator : direct_abstract_declarator LBRACKET TIMES RBRACKET\n direct_abstract_declarator : LBRACKET TIMES RBRACKET\n direct_abstract_declarator : direct_abstract_declarator LPAREN parameter_type_list_opt RPAREN\n direct_abstract_declarator : LPAREN parameter_type_list_opt RPAREN\n block_item : declaration\n | statement\n block_item_list : block_item\n | block_item_list block_item\n compound_statement : brace_open block_item_list_opt brace_close labeled_statement : ID COLON pragmacomp_or_statement labeled_statement : CASE constant_expression COLON pragmacomp_or_statement labeled_statement : DEFAULT COLON pragmacomp_or_statement selection_statement : IF LPAREN expression RPAREN pragmacomp_or_statement selection_statement : IF LPAREN expression RPAREN statement ELSE pragmacomp_or_statement selection_statement : SWITCH LPAREN expression RPAREN pragmacomp_or_statement iteration_statement : WHILE LPAREN expression RPAREN pragmacomp_or_statement iteration_statement : DO pragmacomp_or_statement WHILE LPAREN expression RPAREN SEMI iteration_statement : FOR LPAREN expression_opt SEMI expression_opt SEMI expression_opt RPAREN pragmacomp_or_statement iteration_statement : FOR LPAREN declaration expression_opt SEMI expression_opt RPAREN pragmacomp_or_statement jump_statement : GOTO ID SEMI jump_statement : BREAK SEMI jump_statement : CONTINUE SEMI jump_statement : RETURN expression SEMI\n | RETURN SEMI\n expression_statement : expression_opt SEMI expression : assignment_expression\n | expression COMMA assignment_expression\n assignment_expression : LPAREN compound_statement RPAREN typedef_name : TYPEID assignment_expression : conditional_expression\n | unary_expression assignment_operator assignment_expression\n assignment_operator : EQUALS\n | XOREQUAL\n | TIMESEQUAL\n | DIVEQUAL\n | MODEQUAL\n | PLUSEQUAL\n | MINUSEQUAL\n | LSHIFTEQUAL\n | RSHIFTEQUAL\n | ANDEQUAL\n | OREQUAL\n constant_expression : conditional_expression conditional_expression : binary_expression\n | binary_expression CONDOP expression COLON conditional_expression\n binary_expression : cast_expression\n | binary_expression TIMES binary_expression\n | binary_expression DIVIDE binary_expression\n | binary_expression MOD binary_expression\n | binary_expression PLUS binary_expression\n | binary_expression MINUS binary_expression\n | binary_expression RSHIFT binary_expression\n | binary_expression LSHIFT binary_expression\n | binary_expression LT binary_expression\n | binary_expression LE binary_expression\n | binary_expression GE binary_expression\n | binary_expression GT binary_expression\n | binary_expression EQ binary_expression\n | binary_expression NE binary_expression\n | binary_expression AND binary_expression\n | binary_expression OR binary_expression\n | binary_expression XOR binary_expression\n | binary_expression LAND binary_expression\n | binary_expression LOR binary_expression\n cast_expression : unary_expression cast_expression : LPAREN type_name RPAREN cast_expression unary_expression : postfix_expression unary_expression : PLUSPLUS unary_expression\n | MINUSMINUS unary_expression\n | unary_operator cast_expression\n unary_expression : SIZEOF unary_expression\n | SIZEOF LPAREN type_name RPAREN\n | _ALIGNOF LPAREN type_name RPAREN\n unary_operator : AND\n | TIMES\n | PLUS\n | MINUS\n | NOT\n | LNOT\n postfix_expression : primary_expression postfix_expression : postfix_expression LBRACKET expression RBRACKET postfix_expression : postfix_expression LPAREN argument_expression_list RPAREN\n | postfix_expression LPAREN RPAREN\n postfix_expression : postfix_expression PERIOD ID\n | postfix_expression PERIOD TYPEID\n | postfix_expression ARROW ID\n | postfix_expression ARROW TYPEID\n postfix_expression : postfix_expression PLUSPLUS\n | postfix_expression MINUSMINUS\n postfix_expression : LPAREN type_name RPAREN brace_open initializer_list brace_close\n | LPAREN type_name RPAREN brace_open initializer_list COMMA brace_close\n primary_expression : identifier primary_expression : constant primary_expression : unified_string_literal\n | unified_wstring_literal\n primary_expression : LPAREN expression RPAREN primary_expression : OFFSETOF LPAREN type_name COMMA offsetof_member_designator RPAREN\n offsetof_member_designator : identifier\n | offsetof_member_designator PERIOD identifier\n | offsetof_member_designator LBRACKET expression RBRACKET\n argument_expression_list : assignment_expression\n | argument_expression_list COMMA assignment_expression\n identifier : ID constant : INT_CONST_DEC\n | INT_CONST_OCT\n | INT_CONST_HEX\n | INT_CONST_BIN\n | INT_CONST_CHAR\n constant : FLOAT_CONST\n | HEX_FLOAT_CONST\n constant : CHAR_CONST\n | WCHAR_CONST\n | U8CHAR_CONST\n | U16CHAR_CONST\n | U32CHAR_CONST\n unified_string_literal : STRING_LITERAL\n | unified_string_literal STRING_LITERAL\n unified_wstring_literal : WSTRING_LITERAL\n | U8STRING_LITERAL\n | U16STRING_LITERAL\n | U32STRING_LITERAL\n | unified_wstring_literal WSTRING_LITERAL\n | unified_wstring_literal U8STRING_LITERAL\n | unified_wstring_literal U16STRING_LITERAL\n | unified_wstring_literal U32STRING_LITERAL\n brace_open : LBRACE\n brace_close : RBRACE\n empty : ' + +_lr_action_items = {'INT_CONST_CHAR':([3,39,58,61,76,85,97,103,105,106,116,117,119,124,128,131,135,137,146,149,150,151,165,171,173,174,175,181,191,198,201,204,205,206,218,219,220,227,229,231,233,242,243,244,245,246,247,248,249,250,251,252,253,254,255,256,257,258,259,260,265,266,282,284,285,286,289,290,291,297,298,300,301,302,303,305,307,308,319,329,332,336,338,339,352,353,354,357,358,359,360,361,362,363,364,365,366,367,368,369,373,375,377,412,413,419,421,424,425,427,428,429,430,432,434,435,437,438,439,440,441,447,459,472,475,477,481,484,488,494,496,497,499,500,502,505,506,510,513,514,515,521,522,533,535,536,537,538,539,541,542,543,549,550,553,554,555,557,558,566,569,574,575,576,577,578,579,],[-128,-129,-130,-71,-131,132,-335,-28,-182,-27,132,-337,-87,-72,-337,132,-286,-285,132,132,-283,-287,-288,132,-284,132,132,132,-336,-183,132,132,-28,-337,132,-28,-337,-337,132,132,132,132,132,132,132,132,132,132,132,132,132,132,132,132,132,132,132,132,132,132,132,132,-337,-76,-79,-82,-75,132,-77,132,132,-81,-215,-214,-80,-216,132,-78,132,132,-69,-284,132,132,-284,132,132,-244,-247,-245,-241,-242,-246,-248,132,-250,-251,-243,-249,-12,132,132,-11,132,132,132,132,-234,-233,132,-231,132,132,-217,132,-230,132,-84,-218,132,132,132,-337,-337,-198,132,132,132,-337,-284,-229,-232,132,-221,132,-83,-219,-68,132,-28,-337,132,-11,132,132,-220,132,132,132,-284,132,132,132,-337,132,-225,-224,-222,-84,132,132,132,-226,-223,132,-228,-227,]),'VOID':([0,1,2,3,4,5,6,7,10,11,12,13,14,16,17,18,19,20,21,22,23,25,26,27,29,30,33,34,36,38,39,40,42,43,44,45,46,47,48,49,50,52,53,54,55,56,58,59,60,61,62,63,64,65,66,67,68,71,75,76,80,81,82,85,86,87,89,90,91,93,94,95,96,97,98,99,100,101,105,109,111,118,119,120,121,122,123,124,129,142,147,172,174,177,180,181,182,184,185,186,187,188,189,190,191,192,198,200,211,214,223,229,231,233,239,240,241,267,269,275,278,279,280,284,285,286,289,291,298,300,301,302,303,305,308,312,313,314,315,316,317,318,328,332,340,341,342,350,422,424,425,427,428,432,435,437,438,439,442,443,446,448,449,453,454,460,496,497,500,505,506,510,511,512,536,554,555,557,558,575,576,578,579,],[6,-337,-113,-128,6,-124,-110,-106,-104,-107,-125,-105,-64,-60,-67,-99,-66,-109,6,-120,-115,-65,-102,-126,-131,-108,-238,-111,-122,-63,-129,6,-29,-121,-116,-62,-112,-70,-52,-123,-117,-337,-337,-119,-337,-114,-130,6,-118,-71,-103,-337,-9,-131,-91,-10,-96,-98,6,-131,-95,-101,-97,6,-53,-126,6,-88,6,6,-93,6,-147,-335,-146,6,-167,-166,-182,-100,-126,6,-87,-90,-94,-92,-61,-72,6,-144,-142,6,6,6,-73,6,-89,6,6,6,-149,-159,-160,-156,-336,6,-183,-30,6,6,-74,6,6,6,6,-174,-175,6,-143,-140,6,-141,-145,-76,-79,-82,-75,-77,6,-81,-215,-214,-80,-216,-78,-127,6,-153,6,-151,-148,-157,-168,-69,-36,-35,6,6,6,-234,-233,6,-231,-217,-230,-81,-84,-218,-152,-150,-158,-170,-169,-31,-34,6,-229,-232,-221,-83,-219,-68,-33,-32,-220,-225,-224,-222,-84,-226,-223,-228,-227,]),'LBRACKET':([2,3,5,6,7,10,11,12,13,18,20,22,23,26,27,30,33,34,35,36,39,42,43,44,46,48,49,50,54,56,58,60,62,68,71,73,76,77,80,81,82,86,96,97,98,100,101,103,104,105,106,109,111,127,132,133,134,136,138,139,140,141,142,143,145,147,148,152,153,154,156,160,161,163,164,166,167,168,169,176,177,187,191,198,199,200,211,216,227,230,235,236,237,238,240,241,261,263,269,275,276,278,279,280,283,310,312,314,316,317,328,340,341,342,344,345,347,355,356,371,376,402,403,404,405,407,411,414,442,443,448,449,453,454,457,458,464,465,470,472,474,482,483,488,489,490,492,511,512,518,519,520,526,527,529,530,531,532,544,545,547,550,551,559,560,563,565,570,571,572,],[-113,-128,-124,-110,-106,-104,-107,-125,-105,-99,-109,-120,-115,-102,-126,-108,-238,-111,-337,-122,-129,-29,-121,-116,-112,117,-123,-117,-119,-114,-130,-118,-103,-96,-98,128,-131,-37,-95,-101,-97,117,-147,-335,-146,-167,-166,-28,-180,-182,-27,-100,-126,128,-317,-321,-318,-303,-324,-330,-313,-319,-144,-301,-314,-142,-327,-325,-304,-322,-302,-315,-289,-328,-316,-329,-320,265,-323,-312,282,-149,-336,-183,-181,-30,282,-38,373,-326,-334,-332,-331,-333,-174,-175,-298,-297,-143,-140,282,282,-141,-145,421,-312,-127,-153,-151,-148,-168,-36,-35,282,282,459,-45,-44,-43,-199,373,-296,-295,-294,-293,-292,-305,421,-152,-150,-170,-169,-31,-34,282,459,-39,-42,-202,373,-200,-290,-291,373,-213,-207,-211,-33,-32,-41,-40,-201,549,-307,-209,-208,-210,-212,-51,-50,-306,373,-299,-46,-49,-308,-300,-48,-47,-309,]),'WCHAR_CONST':([3,39,58,61,76,85,97,103,105,106,116,117,119,124,128,131,135,137,146,149,150,151,165,171,173,174,175,181,191,198,201,204,205,206,218,219,220,227,229,231,233,242,243,244,245,246,247,248,249,250,251,252,253,254,255,256,257,258,259,260,265,266,282,284,285,286,289,290,291,297,298,300,301,302,303,305,307,308,319,329,332,336,338,339,352,353,354,357,358,359,360,361,362,363,364,365,366,367,368,369,373,375,377,412,413,419,421,424,425,427,428,429,430,432,434,435,437,438,439,440,441,447,459,472,475,477,481,484,488,494,496,497,499,500,502,505,506,510,513,514,515,521,522,533,535,536,537,538,539,541,542,543,549,550,553,554,555,557,558,566,569,574,575,576,577,578,579,],[-128,-129,-130,-71,-131,133,-335,-28,-182,-27,133,-337,-87,-72,-337,133,-286,-285,133,133,-283,-287,-288,133,-284,133,133,133,-336,-183,133,133,-28,-337,133,-28,-337,-337,133,133,133,133,133,133,133,133,133,133,133,133,133,133,133,133,133,133,133,133,133,133,133,133,-337,-76,-79,-82,-75,133,-77,133,133,-81,-215,-214,-80,-216,133,-78,133,133,-69,-284,133,133,-284,133,133,-244,-247,-245,-241,-242,-246,-248,133,-250,-251,-243,-249,-12,133,133,-11,133,133,133,133,-234,-233,133,-231,133,133,-217,133,-230,133,-84,-218,133,133,133,-337,-337,-198,133,133,133,-337,-284,-229,-232,133,-221,133,-83,-219,-68,133,-28,-337,133,-11,133,133,-220,133,133,133,-284,133,133,133,-337,133,-225,-224,-222,-84,133,133,133,-226,-223,133,-228,-227,]),'FLOAT_CONST':([3,39,58,61,76,85,97,103,105,106,116,117,119,124,128,131,135,137,146,149,150,151,165,171,173,174,175,181,191,198,201,204,205,206,218,219,220,227,229,231,233,242,243,244,245,246,247,248,249,250,251,252,253,254,255,256,257,258,259,260,265,266,282,284,285,286,289,290,291,297,298,300,301,302,303,305,307,308,319,329,332,336,338,339,352,353,354,357,358,359,360,361,362,363,364,365,366,367,368,369,373,375,377,412,413,419,421,424,425,427,428,429,430,432,434,435,437,438,439,440,441,447,459,472,475,477,481,484,488,494,496,497,499,500,502,505,506,510,513,514,515,521,522,533,535,536,537,538,539,541,542,543,549,550,553,554,555,557,558,566,569,574,575,576,577,578,579,],[-128,-129,-130,-71,-131,134,-335,-28,-182,-27,134,-337,-87,-72,-337,134,-286,-285,134,134,-283,-287,-288,134,-284,134,134,134,-336,-183,134,134,-28,-337,134,-28,-337,-337,134,134,134,134,134,134,134,134,134,134,134,134,134,134,134,134,134,134,134,134,134,134,134,134,-337,-76,-79,-82,-75,134,-77,134,134,-81,-215,-214,-80,-216,134,-78,134,134,-69,-284,134,134,-284,134,134,-244,-247,-245,-241,-242,-246,-248,134,-250,-251,-243,-249,-12,134,134,-11,134,134,134,134,-234,-233,134,-231,134,134,-217,134,-230,134,-84,-218,134,134,134,-337,-337,-198,134,134,134,-337,-284,-229,-232,134,-221,134,-83,-219,-68,134,-28,-337,134,-11,134,134,-220,134,134,134,-284,134,134,134,-337,134,-225,-224,-222,-84,134,134,134,-226,-223,134,-228,-227,]),'MINUS':([3,39,58,61,76,85,97,103,105,106,116,117,119,124,128,131,132,133,134,135,136,137,138,139,140,141,143,144,145,146,148,149,150,151,152,153,154,156,158,160,161,162,163,164,165,166,167,168,169,171,173,174,175,176,181,191,198,201,204,205,206,218,219,220,224,227,229,230,231,232,233,234,235,236,237,238,242,243,244,245,246,247,248,249,250,251,252,253,254,255,256,257,258,259,260,261,263,265,266,268,273,282,284,285,286,289,290,291,297,298,300,301,302,303,305,307,308,310,319,329,332,336,338,339,352,353,354,357,358,359,360,361,362,363,364,365,366,367,368,369,373,375,377,383,384,385,386,387,388,389,390,391,392,393,394,395,396,397,398,400,401,402,403,404,405,407,411,412,413,419,421,424,425,427,428,429,430,432,434,435,437,438,439,440,441,447,459,472,475,477,478,480,481,482,483,484,487,488,494,496,497,499,500,502,505,506,510,513,514,515,521,522,533,535,536,537,538,539,541,542,543,547,549,550,551,553,554,555,557,558,565,566,569,574,575,576,577,578,579,],[-128,-129,-130,-71,-131,135,-335,-28,-182,-27,135,-337,-87,-72,-337,135,-317,-321,-318,-286,-303,-285,-324,-330,-313,-319,-301,-274,-314,135,-327,135,-283,-287,-325,-304,-322,-302,-255,-315,-289,245,-328,-316,-288,-329,-320,-276,-323,135,-284,135,135,-312,135,-336,-183,135,135,-28,-337,135,-28,-337,-274,-337,135,-326,135,-280,135,-277,-334,-332,-331,-333,135,135,135,135,135,135,135,135,135,135,135,135,135,135,135,135,135,135,135,-298,-297,135,135,-279,-278,-337,-76,-79,-82,-75,135,-77,135,135,-81,-215,-214,-80,-216,135,-78,-312,135,135,-69,-284,135,135,-284,135,135,-244,-247,-245,-241,-242,-246,-248,135,-250,-251,-243,-249,-12,135,135,-11,245,245,245,-260,245,245,245,-259,245,245,-257,-256,245,245,245,245,245,-258,-296,-295,-294,-293,-292,-305,135,135,135,135,-234,-233,135,-231,135,135,-217,135,-230,135,-84,-218,135,135,135,-337,-337,-198,135,-281,-282,135,-290,-291,135,-275,-337,-284,-229,-232,135,-221,135,-83,-219,-68,135,-28,-337,135,-11,135,135,-220,135,135,135,-284,135,135,-306,135,-337,-299,135,-225,-224,-222,-84,-300,135,135,135,-226,-223,135,-228,-227,]),'RPAREN':([2,3,5,6,7,10,11,12,13,18,20,22,23,26,27,30,33,34,35,36,39,42,43,44,46,48,49,50,54,56,58,60,62,68,71,73,76,77,80,81,82,86,96,98,100,101,103,104,105,106,107,109,111,118,125,127,129,132,133,134,136,138,139,140,141,142,143,144,145,147,148,152,153,154,156,157,158,159,160,161,162,163,164,166,167,168,169,176,177,178,183,187,191,198,199,200,203,207,208,209,210,211,212,213,215,216,221,222,224,225,230,232,234,235,236,237,238,240,241,261,263,266,268,269,270,271,272,273,274,275,276,277,278,279,280,281,283,294,312,314,316,317,328,340,341,342,343,344,345,346,347,348,355,356,378,379,380,381,382,383,384,385,386,387,388,389,390,391,392,393,394,395,396,397,398,400,401,402,403,404,405,407,408,409,411,414,415,416,417,418,422,433,439,442,443,448,449,452,453,454,457,458,460,461,462,463,464,465,468,476,478,480,482,483,486,487,489,490,492,495,501,503,507,511,512,516,517,518,519,524,525,526,527,529,530,531,532,544,545,547,551,553,556,559,560,563,565,566,567,570,571,572,573,],[-113,-128,-124,-110,-106,-104,-107,-125,-105,-99,-109,-120,-115,-102,-126,-108,-238,-111,-337,-122,-129,-29,-121,-116,-112,-52,-123,-117,-119,-114,-130,-118,-103,-96,-98,-54,-131,-37,-95,-101,-97,-53,-147,-146,-167,-166,-28,-180,-182,-27,200,-100,-126,-337,216,-55,-337,-317,-321,-318,-303,-324,-330,-313,-319,-144,-301,-274,-314,-142,-327,-325,-304,-322,-302,240,-255,241,-315,-289,-253,-328,-316,-329,-320,-276,-323,-312,-337,-252,312,-149,-336,-183,-181,-30,332,340,-17,341,-186,-337,-18,-184,-191,-38,355,356,-274,-239,-326,-280,-277,-334,-332,-331,-333,-174,-175,-298,-297,407,-279,-143,411,413,-235,-278,-203,-140,-204,-1,-337,-141,-145,-2,-206,-14,-127,-153,-151,-148,-168,-36,-35,-337,-190,-204,-56,-188,-45,-189,-44,-43,476,477,478,479,480,-261,-273,-262,-260,-264,-268,-263,-259,-266,-271,-257,-256,-265,-272,-267,-269,-270,-258,-296,-295,-294,-293,-292,-310,483,-305,-205,-23,-24,489,490,-337,-13,-218,-152,-150,-170,-169,510,-31,-34,-204,-57,-337,-192,-185,-187,-39,-42,-240,-237,-281,-282,-290,-291,-236,-275,-213,-207,-211,532,535,537,539,-33,-32,544,545,-41,-40,-254,-311,547,-307,-209,-208,-210,-212,-51,-50,-306,-299,-337,568,-46,-49,-308,-300,-337,574,-48,-47,-309,577,]),'STRUCT':([0,1,3,7,10,11,13,14,16,17,19,20,21,25,26,27,29,30,38,39,40,42,45,47,48,52,53,55,58,59,61,62,63,64,65,66,67,75,85,86,87,90,91,93,94,95,97,99,105,118,119,120,121,122,123,124,129,172,174,180,181,182,184,185,186,188,189,190,191,198,200,214,223,229,231,233,239,240,241,267,278,284,285,286,289,291,298,300,301,302,303,305,308,312,313,315,318,332,340,341,342,350,422,424,425,427,428,432,435,437,438,439,446,453,454,460,496,497,500,505,506,510,511,512,536,554,555,557,558,575,576,578,579,],[24,-337,-128,-106,-104,-107,-105,-64,-60,-67,-66,-109,24,-65,-102,-337,-131,-108,-63,-129,24,-29,-62,-70,-52,-337,-337,-337,-130,24,-71,-103,-337,-9,-131,-91,-10,24,24,-53,-337,-88,24,24,-93,24,-335,24,-182,24,-87,-90,-94,-92,-61,-72,24,24,24,-73,24,-89,24,24,24,-159,-160,-156,-336,-183,-30,24,-74,24,24,24,24,-174,-175,24,24,-76,-79,-82,-75,-77,24,-81,-215,-214,-80,-216,-78,-127,24,24,-157,-69,-36,-35,24,24,24,-234,-233,24,-231,-217,-230,-81,-84,-218,-158,-31,-34,24,-229,-232,-221,-83,-219,-68,-33,-32,-220,-225,-224,-222,-84,-226,-223,-228,-227,]),'LONG':([0,1,2,3,4,5,6,7,10,11,12,13,14,16,17,18,19,20,21,22,23,25,26,27,29,30,33,34,36,38,39,40,42,43,44,45,46,47,48,49,50,52,53,54,55,56,58,59,60,61,62,63,64,65,66,67,68,71,75,76,80,81,82,85,86,87,89,90,91,93,94,95,96,97,98,99,100,101,105,109,111,118,119,120,121,122,123,124,129,142,147,172,174,177,180,181,182,184,185,186,187,188,189,190,191,192,198,200,211,214,223,229,231,233,239,240,241,267,269,275,278,279,280,284,285,286,289,291,298,300,301,302,303,305,308,312,313,314,315,316,317,318,328,332,340,341,342,350,422,424,425,427,428,432,435,437,438,439,442,443,446,448,449,453,454,460,496,497,500,505,506,510,511,512,536,554,555,557,558,575,576,578,579,],[23,-337,-113,-128,23,-124,-110,-106,-104,-107,-125,-105,-64,-60,-67,-99,-66,-109,23,-120,-115,-65,-102,-126,-131,-108,-238,-111,-122,-63,-129,23,-29,-121,-116,-62,-112,-70,-52,-123,-117,-337,-337,-119,-337,-114,-130,23,-118,-71,-103,-337,-9,-131,-91,-10,-96,-98,23,-131,-95,-101,-97,23,-53,-126,23,-88,23,23,-93,23,-147,-335,-146,23,-167,-166,-182,-100,-126,23,-87,-90,-94,-92,-61,-72,23,-144,-142,23,23,23,-73,23,-89,23,23,23,-149,-159,-160,-156,-336,23,-183,-30,23,23,-74,23,23,23,23,-174,-175,23,-143,-140,23,-141,-145,-76,-79,-82,-75,-77,23,-81,-215,-214,-80,-216,-78,-127,23,-153,23,-151,-148,-157,-168,-69,-36,-35,23,23,23,-234,-233,23,-231,-217,-230,-81,-84,-218,-152,-150,-158,-170,-169,-31,-34,23,-229,-232,-221,-83,-219,-68,-33,-32,-220,-225,-224,-222,-84,-226,-223,-228,-227,]),'PLUS':([3,39,58,61,76,85,97,103,105,106,116,117,119,124,128,131,132,133,134,135,136,137,138,139,140,141,143,144,145,146,148,149,150,151,152,153,154,156,158,160,161,162,163,164,165,166,167,168,169,171,173,174,175,176,181,191,198,201,204,205,206,218,219,220,224,227,229,230,231,232,233,234,235,236,237,238,242,243,244,245,246,247,248,249,250,251,252,253,254,255,256,257,258,259,260,261,263,265,266,268,273,282,284,285,286,289,290,291,297,298,300,301,302,303,305,307,308,310,319,329,332,336,338,339,352,353,354,357,358,359,360,361,362,363,364,365,366,367,368,369,373,375,377,383,384,385,386,387,388,389,390,391,392,393,394,395,396,397,398,400,401,402,403,404,405,407,411,412,413,419,421,424,425,427,428,429,430,432,434,435,437,438,439,440,441,447,459,472,475,477,478,480,481,482,483,484,487,488,494,496,497,499,500,502,505,506,510,513,514,515,521,522,533,535,536,537,538,539,541,542,543,547,549,550,551,553,554,555,557,558,565,566,569,574,575,576,577,578,579,],[-128,-129,-130,-71,-131,137,-335,-28,-182,-27,137,-337,-87,-72,-337,137,-317,-321,-318,-286,-303,-285,-324,-330,-313,-319,-301,-274,-314,137,-327,137,-283,-287,-325,-304,-322,-302,-255,-315,-289,249,-328,-316,-288,-329,-320,-276,-323,137,-284,137,137,-312,137,-336,-183,137,137,-28,-337,137,-28,-337,-274,-337,137,-326,137,-280,137,-277,-334,-332,-331,-333,137,137,137,137,137,137,137,137,137,137,137,137,137,137,137,137,137,137,137,-298,-297,137,137,-279,-278,-337,-76,-79,-82,-75,137,-77,137,137,-81,-215,-214,-80,-216,137,-78,-312,137,137,-69,-284,137,137,-284,137,137,-244,-247,-245,-241,-242,-246,-248,137,-250,-251,-243,-249,-12,137,137,-11,249,249,249,-260,249,249,249,-259,249,249,-257,-256,249,249,249,249,249,-258,-296,-295,-294,-293,-292,-305,137,137,137,137,-234,-233,137,-231,137,137,-217,137,-230,137,-84,-218,137,137,137,-337,-337,-198,137,-281,-282,137,-290,-291,137,-275,-337,-284,-229,-232,137,-221,137,-83,-219,-68,137,-28,-337,137,-11,137,137,-220,137,137,137,-284,137,137,-306,137,-337,-299,137,-225,-224,-222,-84,-300,137,137,137,-226,-223,137,-228,-227,]),'ELLIPSIS':([350,],[462,]),'U32STRING_LITERAL':([3,39,58,61,76,85,97,103,105,106,116,117,119,124,128,131,135,137,139,146,148,149,150,151,153,163,165,166,171,173,174,175,181,191,198,201,204,205,206,218,219,220,227,229,231,233,235,236,237,238,242,243,244,245,246,247,248,249,250,251,252,253,254,255,256,257,258,259,260,265,266,282,284,285,286,289,290,291,297,298,300,301,302,303,305,307,308,319,329,332,336,338,339,352,353,354,357,358,359,360,361,362,363,364,365,366,367,368,369,373,375,377,412,413,419,421,424,425,427,428,429,430,432,434,435,437,438,439,440,441,447,459,472,475,477,481,484,488,494,496,497,499,500,502,505,506,510,513,514,515,521,522,533,535,536,537,538,539,541,542,543,549,550,553,554,555,557,558,566,569,574,575,576,577,578,579,],[-128,-129,-130,-71,-131,139,-335,-28,-182,-27,139,-337,-87,-72,-337,139,-286,-285,-330,139,-327,139,-283,-287,235,-328,-288,-329,139,-284,139,139,139,-336,-183,139,139,-28,-337,139,-28,-337,-337,139,139,139,-334,-332,-331,-333,139,139,139,139,139,139,139,139,139,139,139,139,139,139,139,139,139,139,139,139,139,-337,-76,-79,-82,-75,139,-77,139,139,-81,-215,-214,-80,-216,139,-78,139,139,-69,-284,139,139,-284,139,139,-244,-247,-245,-241,-242,-246,-248,139,-250,-251,-243,-249,-12,139,139,-11,139,139,139,139,-234,-233,139,-231,139,139,-217,139,-230,139,-84,-218,139,139,139,-337,-337,-198,139,139,139,-337,-284,-229,-232,139,-221,139,-83,-219,-68,139,-28,-337,139,-11,139,139,-220,139,139,139,-284,139,139,139,-337,139,-225,-224,-222,-84,139,139,139,-226,-223,139,-228,-227,]),'GT':([132,133,134,136,138,139,140,141,143,144,145,148,152,153,154,156,158,160,161,162,163,164,166,167,168,169,176,191,224,230,232,234,235,236,237,238,261,263,268,273,310,383,384,385,386,387,388,389,390,391,392,393,394,395,396,397,398,400,401,402,403,404,405,407,411,478,480,482,483,487,547,551,565,],[-317,-321,-318,-303,-324,-330,-313,-319,-301,-274,-314,-327,-325,-304,-322,-302,-255,-315,-289,250,-328,-316,-329,-320,-276,-323,-312,-336,-274,-326,-280,-277,-334,-332,-331,-333,-298,-297,-279,-278,-312,-261,250,-262,-260,-264,250,-263,-259,-266,250,-257,-256,-265,250,250,250,250,-258,-296,-295,-294,-293,-292,-305,-281,-282,-290,-291,-275,-306,-299,-300,]),'GOTO':([61,97,119,124,181,191,284,285,286,289,291,298,300,301,302,303,305,307,308,332,424,425,428,429,432,435,437,438,439,440,496,497,500,502,505,506,510,535,536,537,539,554,555,557,558,569,574,575,576,577,578,579,],[-71,-335,-87,-72,287,-336,-76,-79,-82,-75,-77,287,-81,-215,-214,-80,-216,287,-78,-69,-234,-233,-231,287,-217,-230,287,-84,-218,287,-229,-232,-221,287,-83,-219,-68,287,-220,287,287,-225,-224,-222,-84,287,287,-226,-223,287,-228,-227,]),'ENUM':([0,1,3,7,10,11,13,14,16,17,19,20,21,25,26,27,29,30,38,39,40,42,45,47,48,52,53,55,58,59,61,62,63,64,65,66,67,75,85,86,87,90,91,93,94,95,97,99,105,118,119,120,121,122,123,124,129,172,174,180,181,182,184,185,186,188,189,190,191,198,200,214,223,229,231,233,239,240,241,267,278,284,285,286,289,291,298,300,301,302,303,305,308,312,313,315,318,332,340,341,342,350,422,424,425,427,428,432,435,437,438,439,446,453,454,460,496,497,500,505,506,510,511,512,536,554,555,557,558,575,576,578,579,],[32,-337,-128,-106,-104,-107,-105,-64,-60,-67,-66,-109,32,-65,-102,-337,-131,-108,-63,-129,32,-29,-62,-70,-52,-337,-337,-337,-130,32,-71,-103,-337,-9,-131,-91,-10,32,32,-53,-337,-88,32,32,-93,32,-335,32,-182,32,-87,-90,-94,-92,-61,-72,32,32,32,-73,32,-89,32,32,32,-159,-160,-156,-336,-183,-30,32,-74,32,32,32,32,-174,-175,32,32,-76,-79,-82,-75,-77,32,-81,-215,-214,-80,-216,-78,-127,32,32,-157,-69,-36,-35,32,32,32,-234,-233,32,-231,-217,-230,-81,-84,-218,-158,-31,-34,32,-229,-232,-221,-83,-219,-68,-33,-32,-220,-225,-224,-222,-84,-226,-223,-228,-227,]),'PERIOD':([97,132,133,134,136,138,139,140,141,143,145,148,152,153,154,156,160,161,163,164,166,167,168,169,176,191,227,230,235,236,237,238,261,263,310,371,376,402,403,404,405,407,411,470,472,474,482,483,488,520,526,527,547,550,551,563,565,572,],[-335,-317,-321,-318,-303,-324,-330,-313,-319,-301,-314,-327,-325,-304,-322,-302,-315,-289,-328,-316,-329,-320,264,-323,-312,-336,372,-326,-334,-332,-331,-333,-298,-297,-312,-199,372,-296,-295,-294,-293,-292,-305,-202,372,-200,-290,-291,372,-201,548,-307,-306,372,-299,-308,-300,-309,]),'GE':([132,133,134,136,138,139,140,141,143,144,145,148,152,153,154,156,158,160,161,162,163,164,166,167,168,169,176,191,224,230,232,234,235,236,237,238,261,263,268,273,310,383,384,385,386,387,388,389,390,391,392,393,394,395,396,397,398,400,401,402,403,404,405,407,411,478,480,482,483,487,547,551,565,],[-317,-321,-318,-303,-324,-330,-313,-319,-301,-274,-314,-327,-325,-304,-322,-302,-255,-315,-289,254,-328,-316,-329,-320,-276,-323,-312,-336,-274,-326,-280,-277,-334,-332,-331,-333,-298,-297,-279,-278,-312,-261,254,-262,-260,-264,254,-263,-259,-266,254,-257,-256,-265,254,254,254,254,-258,-296,-295,-294,-293,-292,-305,-281,-282,-290,-291,-275,-306,-299,-300,]),'INT_CONST_DEC':([3,39,58,61,76,85,97,103,105,106,116,117,119,124,128,131,135,137,146,149,150,151,165,171,173,174,175,181,191,198,201,204,205,206,218,219,220,227,229,231,233,242,243,244,245,246,247,248,249,250,251,252,253,254,255,256,257,258,259,260,265,266,282,284,285,286,289,290,291,297,298,300,301,302,303,305,307,308,319,329,332,336,338,339,352,353,354,357,358,359,360,361,362,363,364,365,366,367,368,369,373,375,377,412,413,419,421,424,425,427,428,429,430,432,434,435,437,438,439,440,441,447,459,472,475,477,481,484,488,494,496,497,499,500,502,505,506,510,513,514,515,521,522,533,535,536,537,538,539,541,542,543,549,550,553,554,555,557,558,566,569,574,575,576,577,578,579,],[-128,-129,-130,-71,-131,140,-335,-28,-182,-27,140,-337,-87,-72,-337,140,-286,-285,140,140,-283,-287,-288,140,-284,140,140,140,-336,-183,140,140,-28,-337,140,-28,-337,-337,140,140,140,140,140,140,140,140,140,140,140,140,140,140,140,140,140,140,140,140,140,140,140,140,-337,-76,-79,-82,-75,140,-77,140,140,-81,-215,-214,-80,-216,140,-78,140,140,-69,-284,140,140,-284,140,140,-244,-247,-245,-241,-242,-246,-248,140,-250,-251,-243,-249,-12,140,140,-11,140,140,140,140,-234,-233,140,-231,140,140,-217,140,-230,140,-84,-218,140,140,140,-337,-337,-198,140,140,140,-337,-284,-229,-232,140,-221,140,-83,-219,-68,140,-28,-337,140,-11,140,140,-220,140,140,140,-284,140,140,140,-337,140,-225,-224,-222,-84,140,140,140,-226,-223,140,-228,-227,]),'ARROW':([132,133,134,136,138,139,140,141,143,145,148,152,153,154,156,160,161,163,164,166,167,168,169,176,191,230,235,236,237,238,261,263,310,402,403,404,405,407,411,482,483,547,551,565,],[-317,-321,-318,-303,-324,-330,-313,-319,-301,-314,-327,-325,-304,-322,-302,-315,-289,-328,-316,-329,-320,262,-323,-312,-336,-326,-334,-332,-331,-333,-298,-297,-312,-296,-295,-294,-293,-292,-305,-290,-291,-306,-299,-300,]),'_STATIC_ASSERT':([0,14,16,17,19,25,38,45,47,59,61,97,119,123,124,180,181,191,223,284,285,286,289,291,298,300,301,302,303,305,307,308,332,424,425,428,429,432,435,437,438,439,440,496,497,500,502,505,506,510,535,536,537,539,554,555,557,558,569,574,575,576,577,578,579,],[41,-64,-60,-67,-66,-65,-63,-62,-70,41,-71,-335,-87,-61,-72,-73,41,-336,-74,-76,-79,-82,-75,-77,41,-81,-215,-214,-80,-216,41,-78,-69,-234,-233,-231,41,-217,-230,41,-84,-218,41,-229,-232,-221,41,-83,-219,-68,41,-220,41,41,-225,-224,-222,-84,41,41,-226,-223,41,-228,-227,]),'CHAR':([0,1,2,3,4,5,6,7,10,11,12,13,14,16,17,18,19,20,21,22,23,25,26,27,29,30,33,34,36,38,39,40,42,43,44,45,46,47,48,49,50,52,53,54,55,56,58,59,60,61,62,63,64,65,66,67,68,71,75,76,80,81,82,85,86,87,89,90,91,93,94,95,96,97,98,99,100,101,105,109,111,118,119,120,121,122,123,124,129,142,147,172,174,177,180,181,182,184,185,186,187,188,189,190,191,192,198,200,211,214,223,229,231,233,239,240,241,267,269,275,278,279,280,284,285,286,289,291,298,300,301,302,303,305,308,312,313,314,315,316,317,318,328,332,340,341,342,350,422,424,425,427,428,432,435,437,438,439,442,443,446,448,449,453,454,460,496,497,500,505,506,510,511,512,536,554,555,557,558,575,576,578,579,],[46,-337,-113,-128,46,-124,-110,-106,-104,-107,-125,-105,-64,-60,-67,-99,-66,-109,46,-120,-115,-65,-102,-126,-131,-108,-238,-111,-122,-63,-129,46,-29,-121,-116,-62,-112,-70,-52,-123,-117,-337,-337,-119,-337,-114,-130,46,-118,-71,-103,-337,-9,-131,-91,-10,-96,-98,46,-131,-95,-101,-97,46,-53,-126,46,-88,46,46,-93,46,-147,-335,-146,46,-167,-166,-182,-100,-126,46,-87,-90,-94,-92,-61,-72,46,-144,-142,46,46,46,-73,46,-89,46,46,46,-149,-159,-160,-156,-336,46,-183,-30,46,46,-74,46,46,46,46,-174,-175,46,-143,-140,46,-141,-145,-76,-79,-82,-75,-77,46,-81,-215,-214,-80,-216,-78,-127,46,-153,46,-151,-148,-157,-168,-69,-36,-35,46,46,46,-234,-233,46,-231,-217,-230,-81,-84,-218,-152,-150,-158,-170,-169,-31,-34,46,-229,-232,-221,-83,-219,-68,-33,-32,-220,-225,-224,-222,-84,-226,-223,-228,-227,]),'HEX_FLOAT_CONST':([3,39,58,61,76,85,97,103,105,106,116,117,119,124,128,131,135,137,146,149,150,151,165,171,173,174,175,181,191,198,201,204,205,206,218,219,220,227,229,231,233,242,243,244,245,246,247,248,249,250,251,252,253,254,255,256,257,258,259,260,265,266,282,284,285,286,289,290,291,297,298,300,301,302,303,305,307,308,319,329,332,336,338,339,352,353,354,357,358,359,360,361,362,363,364,365,366,367,368,369,373,375,377,412,413,419,421,424,425,427,428,429,430,432,434,435,437,438,439,440,441,447,459,472,475,477,481,484,488,494,496,497,499,500,502,505,506,510,513,514,515,521,522,533,535,536,537,538,539,541,542,543,549,550,553,554,555,557,558,566,569,574,575,576,577,578,579,],[-128,-129,-130,-71,-131,141,-335,-28,-182,-27,141,-337,-87,-72,-337,141,-286,-285,141,141,-283,-287,-288,141,-284,141,141,141,-336,-183,141,141,-28,-337,141,-28,-337,-337,141,141,141,141,141,141,141,141,141,141,141,141,141,141,141,141,141,141,141,141,141,141,141,141,-337,-76,-79,-82,-75,141,-77,141,141,-81,-215,-214,-80,-216,141,-78,141,141,-69,-284,141,141,-284,141,141,-244,-247,-245,-241,-242,-246,-248,141,-250,-251,-243,-249,-12,141,141,-11,141,141,141,141,-234,-233,141,-231,141,141,-217,141,-230,141,-84,-218,141,141,141,-337,-337,-198,141,141,141,-337,-284,-229,-232,141,-221,141,-83,-219,-68,141,-28,-337,141,-11,141,141,-220,141,141,141,-284,141,141,141,-337,141,-225,-224,-222,-84,141,141,141,-226,-223,141,-228,-227,]),'DOUBLE':([0,1,2,3,4,5,6,7,10,11,12,13,14,16,17,18,19,20,21,22,23,25,26,27,29,30,33,34,36,38,39,40,42,43,44,45,46,47,48,49,50,52,53,54,55,56,58,59,60,61,62,63,64,65,66,67,68,71,75,76,80,81,82,85,86,87,89,90,91,93,94,95,96,97,98,99,100,101,105,109,111,118,119,120,121,122,123,124,129,142,147,172,174,177,180,181,182,184,185,186,187,188,189,190,191,192,198,200,211,214,223,229,231,233,239,240,241,267,269,275,278,279,280,284,285,286,289,291,298,300,301,302,303,305,308,312,313,314,315,316,317,318,328,332,340,341,342,350,422,424,425,427,428,432,435,437,438,439,442,443,446,448,449,453,454,460,496,497,500,505,506,510,511,512,536,554,555,557,558,575,576,578,579,],[50,-337,-113,-128,50,-124,-110,-106,-104,-107,-125,-105,-64,-60,-67,-99,-66,-109,50,-120,-115,-65,-102,-126,-131,-108,-238,-111,-122,-63,-129,50,-29,-121,-116,-62,-112,-70,-52,-123,-117,-337,-337,-119,-337,-114,-130,50,-118,-71,-103,-337,-9,-131,-91,-10,-96,-98,50,-131,-95,-101,-97,50,-53,-126,50,-88,50,50,-93,50,-147,-335,-146,50,-167,-166,-182,-100,-126,50,-87,-90,-94,-92,-61,-72,50,-144,-142,50,50,50,-73,50,-89,50,50,50,-149,-159,-160,-156,-336,50,-183,-30,50,50,-74,50,50,50,50,-174,-175,50,-143,-140,50,-141,-145,-76,-79,-82,-75,-77,50,-81,-215,-214,-80,-216,-78,-127,50,-153,50,-151,-148,-157,-168,-69,-36,-35,50,50,50,-234,-233,50,-231,-217,-230,-81,-84,-218,-152,-150,-158,-170,-169,-31,-34,50,-229,-232,-221,-83,-219,-68,-33,-32,-220,-225,-224,-222,-84,-226,-223,-228,-227,]),'MINUSEQUAL':([132,133,134,136,138,139,140,141,143,144,145,148,152,153,154,156,160,161,163,164,166,167,168,169,176,191,224,230,232,234,235,236,237,238,261,263,268,273,310,402,403,404,405,407,411,478,480,482,483,487,547,551,565,],[-317,-321,-318,-303,-324,-330,-313,-319,-301,-274,-314,-327,-325,-304,-322,-302,-315,-289,-328,-316,-329,-320,-276,-323,-312,-336,358,-326,-280,-277,-334,-332,-331,-333,-298,-297,-279,-278,-312,-296,-295,-294,-293,-292,-305,-281,-282,-290,-291,-275,-306,-299,-300,]),'INT_CONST_OCT':([3,39,58,61,76,85,97,103,105,106,116,117,119,124,128,131,135,137,146,149,150,151,165,171,173,174,175,181,191,198,201,204,205,206,218,219,220,227,229,231,233,242,243,244,245,246,247,248,249,250,251,252,253,254,255,256,257,258,259,260,265,266,282,284,285,286,289,290,291,297,298,300,301,302,303,305,307,308,319,329,332,336,338,339,352,353,354,357,358,359,360,361,362,363,364,365,366,367,368,369,373,375,377,412,413,419,421,424,425,427,428,429,430,432,434,435,437,438,439,440,441,447,459,472,475,477,481,484,488,494,496,497,499,500,502,505,506,510,513,514,515,521,522,533,535,536,537,538,539,541,542,543,549,550,553,554,555,557,558,566,569,574,575,576,577,578,579,],[-128,-129,-130,-71,-131,145,-335,-28,-182,-27,145,-337,-87,-72,-337,145,-286,-285,145,145,-283,-287,-288,145,-284,145,145,145,-336,-183,145,145,-28,-337,145,-28,-337,-337,145,145,145,145,145,145,145,145,145,145,145,145,145,145,145,145,145,145,145,145,145,145,145,145,-337,-76,-79,-82,-75,145,-77,145,145,-81,-215,-214,-80,-216,145,-78,145,145,-69,-284,145,145,-284,145,145,-244,-247,-245,-241,-242,-246,-248,145,-250,-251,-243,-249,-12,145,145,-11,145,145,145,145,-234,-233,145,-231,145,145,-217,145,-230,145,-84,-218,145,145,145,-337,-337,-198,145,145,145,-337,-284,-229,-232,145,-221,145,-83,-219,-68,145,-28,-337,145,-11,145,145,-220,145,145,145,-284,145,145,145,-337,145,-225,-224,-222,-84,145,145,145,-226,-223,145,-228,-227,]),'TIMESEQUAL':([132,133,134,136,138,139,140,141,143,144,145,148,152,153,154,156,160,161,163,164,166,167,168,169,176,191,224,230,232,234,235,236,237,238,261,263,268,273,310,402,403,404,405,407,411,478,480,482,483,487,547,551,565,],[-317,-321,-318,-303,-324,-330,-313,-319,-301,-274,-314,-327,-325,-304,-322,-302,-315,-289,-328,-316,-329,-320,-276,-323,-312,-336,367,-326,-280,-277,-334,-332,-331,-333,-298,-297,-279,-278,-312,-296,-295,-294,-293,-292,-305,-281,-282,-290,-291,-275,-306,-299,-300,]),'OR':([132,133,134,136,138,139,140,141,143,144,145,148,152,153,154,156,158,160,161,162,163,164,166,167,168,169,176,191,224,230,232,234,235,236,237,238,261,263,268,273,310,383,384,385,386,387,388,389,390,391,392,393,394,395,396,397,398,400,401,402,403,404,405,407,411,478,480,482,483,487,547,551,565,],[-317,-321,-318,-303,-324,-330,-313,-319,-301,-274,-314,-327,-325,-304,-322,-302,-255,-315,-289,259,-328,-316,-329,-320,-276,-323,-312,-336,-274,-326,-280,-277,-334,-332,-331,-333,-298,-297,-279,-278,-312,-261,259,-262,-260,-264,-268,-263,-259,-266,-271,-257,-256,-265,259,-267,-269,-270,-258,-296,-295,-294,-293,-292,-305,-281,-282,-290,-291,-275,-306,-299,-300,]),'SHORT':([0,1,2,3,4,5,6,7,10,11,12,13,14,16,17,18,19,20,21,22,23,25,26,27,29,30,33,34,36,38,39,40,42,43,44,45,46,47,48,49,50,52,53,54,55,56,58,59,60,61,62,63,64,65,66,67,68,71,75,76,80,81,82,85,86,87,89,90,91,93,94,95,96,97,98,99,100,101,105,109,111,118,119,120,121,122,123,124,129,142,147,172,174,177,180,181,182,184,185,186,187,188,189,190,191,192,198,200,211,214,223,229,231,233,239,240,241,267,269,275,278,279,280,284,285,286,289,291,298,300,301,302,303,305,308,312,313,314,315,316,317,318,328,332,340,341,342,350,422,424,425,427,428,432,435,437,438,439,442,443,446,448,449,453,454,460,496,497,500,505,506,510,511,512,536,554,555,557,558,575,576,578,579,],[2,-337,-113,-128,2,-124,-110,-106,-104,-107,-125,-105,-64,-60,-67,-99,-66,-109,2,-120,-115,-65,-102,-126,-131,-108,-238,-111,-122,-63,-129,2,-29,-121,-116,-62,-112,-70,-52,-123,-117,-337,-337,-119,-337,-114,-130,2,-118,-71,-103,-337,-9,-131,-91,-10,-96,-98,2,-131,-95,-101,-97,2,-53,-126,2,-88,2,2,-93,2,-147,-335,-146,2,-167,-166,-182,-100,-126,2,-87,-90,-94,-92,-61,-72,2,-144,-142,2,2,2,-73,2,-89,2,2,2,-149,-159,-160,-156,-336,2,-183,-30,2,2,-74,2,2,2,2,-174,-175,2,-143,-140,2,-141,-145,-76,-79,-82,-75,-77,2,-81,-215,-214,-80,-216,-78,-127,2,-153,2,-151,-148,-157,-168,-69,-36,-35,2,2,2,-234,-233,2,-231,-217,-230,-81,-84,-218,-152,-150,-158,-170,-169,-31,-34,2,-229,-232,-221,-83,-219,-68,-33,-32,-220,-225,-224,-222,-84,-226,-223,-228,-227,]),'RETURN':([61,97,119,124,181,191,284,285,286,289,291,298,300,301,302,303,305,307,308,332,424,425,428,429,432,435,437,438,439,440,496,497,500,502,505,506,510,535,536,537,539,554,555,557,558,569,574,575,576,577,578,579,],[-71,-335,-87,-72,290,-336,-76,-79,-82,-75,-77,290,-81,-215,-214,-80,-216,290,-78,-69,-234,-233,-231,290,-217,-230,290,-84,-218,290,-229,-232,-221,290,-83,-219,-68,290,-220,290,290,-225,-224,-222,-84,290,290,-226,-223,290,-228,-227,]),'RSHIFTEQUAL':([132,133,134,136,138,139,140,141,143,144,145,148,152,153,154,156,160,161,163,164,166,167,168,169,176,191,224,230,232,234,235,236,237,238,261,263,268,273,310,402,403,404,405,407,411,478,480,482,483,487,547,551,565,],[-317,-321,-318,-303,-324,-330,-313,-319,-301,-274,-314,-327,-325,-304,-322,-302,-315,-289,-328,-316,-329,-320,-276,-323,-312,-336,368,-326,-280,-277,-334,-332,-331,-333,-298,-297,-279,-278,-312,-296,-295,-294,-293,-292,-305,-281,-282,-290,-291,-275,-306,-299,-300,]),'_ALIGNAS':([0,1,2,3,4,5,6,7,10,11,12,13,14,16,17,18,19,20,21,22,23,25,26,27,29,30,33,34,36,38,39,42,43,44,45,46,47,48,49,50,52,53,54,55,56,58,59,60,61,62,63,65,68,71,75,76,80,81,82,85,86,87,89,90,93,95,96,97,98,99,100,101,109,111,118,119,123,124,129,142,147,174,177,180,181,182,184,185,186,187,188,189,190,191,192,200,211,223,229,231,233,239,240,241,267,269,275,278,279,280,284,285,286,289,291,298,300,301,302,303,305,308,312,313,314,315,316,317,318,328,332,340,341,342,350,422,424,425,427,428,432,435,437,438,439,442,443,446,448,449,453,454,460,496,497,500,505,506,510,511,512,536,554,555,557,558,575,576,578,579,],[8,8,-113,-128,8,-124,-110,-106,-104,-107,-125,-105,-64,-60,-67,-99,-66,-109,8,-120,-115,-65,-102,8,-131,-108,-238,-111,-122,-63,-129,-29,-121,-116,-62,-112,-70,-52,-123,-117,8,8,-119,8,-114,-130,8,-118,-71,-103,8,-131,-96,-98,8,-131,-95,-101,-97,8,-53,8,8,-88,8,8,-147,-335,-146,8,-167,-166,-100,-126,8,-87,-61,-72,8,-144,-142,8,8,-73,8,-89,8,8,8,-149,-159,-160,-156,-336,8,-30,8,-74,8,8,8,8,-174,-175,8,-143,-140,8,-141,-145,-76,-79,-82,-75,-77,8,-81,-215,-214,-80,-216,-78,-127,8,-153,8,-151,-148,-157,-168,-69,-36,-35,8,8,8,-234,-233,8,-231,-217,-230,-81,-84,-218,-152,-150,-158,-170,-169,-31,-34,8,-229,-232,-221,-83,-219,-68,-33,-32,-220,-225,-224,-222,-84,-226,-223,-228,-227,]),'RESTRICT':([0,1,2,3,4,5,6,7,10,11,12,13,14,16,17,18,19,20,21,22,23,25,26,27,29,30,33,34,35,36,38,39,42,43,44,45,46,47,48,49,50,52,53,54,55,56,58,59,60,61,62,63,65,68,71,75,76,80,81,82,85,86,87,89,90,93,95,96,97,98,99,100,101,103,105,109,111,117,118,119,123,124,128,129,142,147,172,174,177,180,181,182,184,185,186,187,188,189,190,191,192,198,200,205,206,211,219,220,223,229,231,233,239,240,241,267,269,275,278,279,280,282,284,285,286,289,291,298,300,301,302,303,305,308,312,313,314,315,316,317,318,328,332,340,341,342,350,422,424,425,427,428,432,435,437,438,439,442,443,446,448,449,453,454,459,460,496,497,500,505,506,510,511,512,514,515,536,554,555,557,558,575,576,578,579,],[39,39,-113,-128,39,-124,-110,-106,-104,-107,-125,-105,-64,-60,-67,-99,-66,-109,39,-120,-115,-65,-102,39,-131,-108,-238,-111,39,-122,-63,-129,-29,-121,-116,-62,-112,-70,-52,-123,-117,39,39,-119,39,-114,-130,39,-118,-71,-103,39,-131,-96,-98,39,-131,-95,-101,-97,39,-53,39,39,-88,39,39,-147,-335,-146,39,-167,-166,39,-182,-100,-126,39,39,-87,-61,-72,39,39,-144,-142,39,39,39,-73,39,-89,39,39,39,-149,-159,-160,-156,-336,39,-183,-30,39,39,39,39,39,-74,39,39,39,39,-174,-175,39,-143,-140,39,-141,-145,39,-76,-79,-82,-75,-77,39,-81,-215,-214,-80,-216,-78,-127,39,-153,39,-151,-148,-157,-168,-69,-36,-35,39,39,39,-234,-233,39,-231,-217,-230,-81,-84,-218,-152,-150,-158,-170,-169,-31,-34,39,39,-229,-232,-221,-83,-219,-68,-33,-32,39,39,-220,-225,-224,-222,-84,-226,-223,-228,-227,]),'STATIC':([0,1,2,3,4,5,6,7,10,11,12,13,14,16,17,18,19,20,21,22,23,25,26,27,29,30,33,34,36,38,39,42,43,44,45,46,47,48,49,50,52,53,54,55,56,58,59,60,61,62,63,65,68,71,75,76,80,81,82,86,87,89,90,93,96,97,98,100,101,105,109,111,117,118,119,123,124,128,129,180,181,182,187,191,198,200,205,211,219,223,240,241,278,284,285,286,289,291,298,300,301,302,303,305,308,312,314,316,317,328,332,340,341,342,350,422,424,425,427,428,432,435,437,438,439,442,443,448,449,453,454,459,460,496,497,500,505,506,510,511,512,514,536,554,555,557,558,575,576,578,579,],[10,10,-113,-128,10,-124,-110,-106,-104,-107,-125,-105,-64,-60,-67,-99,-66,-109,10,-120,-115,-65,-102,10,-131,-108,-238,-111,-122,-63,-129,-29,-121,-116,-62,-112,-70,-52,-123,-117,10,10,-119,10,-114,-130,10,-118,-71,-103,10,-131,-96,-98,10,-131,-95,-101,-97,-53,10,10,-88,10,-147,-335,-146,-167,-166,-182,-100,-126,206,10,-87,-61,-72,220,10,-73,10,-89,-149,-336,-183,-30,338,10,353,-74,-174,-175,10,-76,-79,-82,-75,-77,10,-81,-215,-214,-80,-216,-78,-127,-153,-151,-148,-168,-69,-36,-35,10,10,10,-234,-233,10,-231,-217,-230,-81,-84,-218,-152,-150,-170,-169,-31,-34,515,10,-229,-232,-221,-83,-219,-68,-33,-32,542,-220,-225,-224,-222,-84,-226,-223,-228,-227,]),'SIZEOF':([3,39,58,61,76,85,97,103,105,106,116,117,119,124,128,131,135,137,146,149,150,151,165,171,173,174,175,181,191,198,201,204,205,206,218,219,220,227,229,231,233,242,243,244,245,246,247,248,249,250,251,252,253,254,255,256,257,258,259,260,265,266,282,284,285,286,289,290,291,297,298,300,301,302,303,305,307,308,319,329,332,336,338,339,352,353,354,357,358,359,360,361,362,363,364,365,366,367,368,369,373,375,377,412,413,419,421,424,425,427,428,429,430,432,434,435,437,438,439,440,441,447,459,472,475,477,481,484,488,494,496,497,499,500,502,505,506,510,513,514,515,521,522,533,535,536,537,538,539,541,542,543,549,550,553,554,555,557,558,566,569,574,575,576,577,578,579,],[-128,-129,-130,-71,-131,146,-335,-28,-182,-27,146,-337,-87,-72,-337,146,-286,-285,146,146,-283,-287,-288,146,-284,146,146,146,-336,-183,146,146,-28,-337,146,-28,-337,-337,146,146,146,146,146,146,146,146,146,146,146,146,146,146,146,146,146,146,146,146,146,146,146,146,-337,-76,-79,-82,-75,146,-77,146,146,-81,-215,-214,-80,-216,146,-78,146,146,-69,-284,146,146,-284,146,146,-244,-247,-245,-241,-242,-246,-248,146,-250,-251,-243,-249,-12,146,146,-11,146,146,146,146,-234,-233,146,-231,146,146,-217,146,-230,146,-84,-218,146,146,146,-337,-337,-198,146,146,146,-337,-284,-229,-232,146,-221,146,-83,-219,-68,146,-28,-337,146,-11,146,146,-220,146,146,146,-284,146,146,146,-337,146,-225,-224,-222,-84,146,146,146,-226,-223,146,-228,-227,]),'UNSIGNED':([0,1,2,3,4,5,6,7,10,11,12,13,14,16,17,18,19,20,21,22,23,25,26,27,29,30,33,34,36,38,39,40,42,43,44,45,46,47,48,49,50,52,53,54,55,56,58,59,60,61,62,63,64,65,66,67,68,71,75,76,80,81,82,85,86,87,89,90,91,93,94,95,96,97,98,99,100,101,105,109,111,118,119,120,121,122,123,124,129,142,147,172,174,177,180,181,182,184,185,186,187,188,189,190,191,192,198,200,211,214,223,229,231,233,239,240,241,267,269,275,278,279,280,284,285,286,289,291,298,300,301,302,303,305,308,312,313,314,315,316,317,318,328,332,340,341,342,350,422,424,425,427,428,432,435,437,438,439,442,443,446,448,449,453,454,460,496,497,500,505,506,510,511,512,536,554,555,557,558,575,576,578,579,],[22,-337,-113,-128,22,-124,-110,-106,-104,-107,-125,-105,-64,-60,-67,-99,-66,-109,22,-120,-115,-65,-102,-126,-131,-108,-238,-111,-122,-63,-129,22,-29,-121,-116,-62,-112,-70,-52,-123,-117,-337,-337,-119,-337,-114,-130,22,-118,-71,-103,-337,-9,-131,-91,-10,-96,-98,22,-131,-95,-101,-97,22,-53,-126,22,-88,22,22,-93,22,-147,-335,-146,22,-167,-166,-182,-100,-126,22,-87,-90,-94,-92,-61,-72,22,-144,-142,22,22,22,-73,22,-89,22,22,22,-149,-159,-160,-156,-336,22,-183,-30,22,22,-74,22,22,22,22,-174,-175,22,-143,-140,22,-141,-145,-76,-79,-82,-75,-77,22,-81,-215,-214,-80,-216,-78,-127,22,-153,22,-151,-148,-157,-168,-69,-36,-35,22,22,22,-234,-233,22,-231,-217,-230,-81,-84,-218,-152,-150,-158,-170,-169,-31,-34,22,-229,-232,-221,-83,-219,-68,-33,-32,-220,-225,-224,-222,-84,-226,-223,-228,-227,]),'UNION':([0,1,3,7,10,11,13,14,16,17,19,20,21,25,26,27,29,30,38,39,40,42,45,47,48,52,53,55,58,59,61,62,63,64,65,66,67,75,85,86,87,90,91,93,94,95,97,99,105,118,119,120,121,122,123,124,129,172,174,180,181,182,184,185,186,188,189,190,191,198,200,214,223,229,231,233,239,240,241,267,278,284,285,286,289,291,298,300,301,302,303,305,308,312,313,315,318,332,340,341,342,350,422,424,425,427,428,432,435,437,438,439,446,453,454,460,496,497,500,505,506,510,511,512,536,554,555,557,558,575,576,578,579,],[28,-337,-128,-106,-104,-107,-105,-64,-60,-67,-66,-109,28,-65,-102,-337,-131,-108,-63,-129,28,-29,-62,-70,-52,-337,-337,-337,-130,28,-71,-103,-337,-9,-131,-91,-10,28,28,-53,-337,-88,28,28,-93,28,-335,28,-182,28,-87,-90,-94,-92,-61,-72,28,28,28,-73,28,-89,28,28,28,-159,-160,-156,-336,-183,-30,28,-74,28,28,28,28,-174,-175,28,28,-76,-79,-82,-75,-77,28,-81,-215,-214,-80,-216,-78,-127,28,28,-157,-69,-36,-35,28,28,28,-234,-233,28,-231,-217,-230,-81,-84,-218,-158,-31,-34,28,-229,-232,-221,-83,-219,-68,-33,-32,-220,-225,-224,-222,-84,-226,-223,-228,-227,]),'COLON':([2,3,5,6,12,22,23,33,34,36,39,42,43,44,46,48,49,50,54,56,58,60,73,74,76,77,86,96,98,100,101,111,127,132,133,134,136,138,139,140,141,142,143,144,145,147,148,152,153,154,156,158,160,161,162,163,164,166,167,168,169,176,178,179,187,191,192,200,216,224,225,230,232,234,235,236,237,238,240,241,261,263,268,269,272,273,275,279,280,295,310,312,314,316,317,324,328,340,341,355,356,383,384,385,386,387,388,389,390,391,392,393,394,395,396,397,398,399,400,401,402,403,404,405,407,411,431,442,443,445,448,449,453,454,464,465,468,476,478,480,482,483,486,487,511,512,518,519,524,547,551,565,],[-113,-128,-124,-110,-125,-120,-115,-238,-111,-122,-129,-29,-121,-116,-112,-52,-123,-117,-119,-114,-130,-118,-54,-179,-131,-37,-53,-147,-146,-167,-166,-126,-55,-317,-321,-318,-303,-324,-330,-313,-319,-144,-301,-274,-314,-142,-327,-325,-304,-322,-302,-255,-315,-289,-253,-328,-316,-329,-320,-276,-323,-312,-252,-178,-149,-336,319,-30,-38,-274,-239,-326,-280,-277,-334,-332,-331,-333,-174,-175,-298,-297,-279,-143,-235,-278,-140,-141,-145,429,440,-127,-153,-151,-148,447,-168,-36,-35,-44,-43,-261,-273,-262,-260,-264,-268,-263,-259,-266,-271,-257,-256,-265,-272,-267,-269,481,-270,-258,-296,-295,-294,-293,-292,-305,502,-152,-150,319,-170,-169,-31,-34,-39,-42,-240,-237,-281,-282,-290,-291,-236,-275,-33,-32,-41,-40,-254,-306,-299,-300,]),'$end':([0,9,14,16,17,19,25,38,45,47,57,59,61,119,123,124,180,191,223,332,439,510,],[-337,0,-64,-60,-67,-66,-65,-63,-62,-70,-59,-58,-71,-87,-61,-72,-73,-336,-74,-69,-218,-68,]),'WSTRING_LITERAL':([3,39,58,61,76,85,97,103,105,106,116,117,119,124,128,131,135,137,139,146,148,149,150,151,153,163,165,166,171,173,174,175,181,191,198,201,204,205,206,218,219,220,227,229,231,233,235,236,237,238,242,243,244,245,246,247,248,249,250,251,252,253,254,255,256,257,258,259,260,265,266,282,284,285,286,289,290,291,297,298,300,301,302,303,305,307,308,319,329,332,336,338,339,352,353,354,357,358,359,360,361,362,363,364,365,366,367,368,369,373,375,377,412,413,419,421,424,425,427,428,429,430,432,434,435,437,438,439,440,441,447,459,472,475,477,481,484,488,494,496,497,499,500,502,505,506,510,513,514,515,521,522,533,535,536,537,538,539,541,542,543,549,550,553,554,555,557,558,566,569,574,575,576,577,578,579,],[-128,-129,-130,-71,-131,148,-335,-28,-182,-27,148,-337,-87,-72,-337,148,-286,-285,-330,148,-327,148,-283,-287,237,-328,-288,-329,148,-284,148,148,148,-336,-183,148,148,-28,-337,148,-28,-337,-337,148,148,148,-334,-332,-331,-333,148,148,148,148,148,148,148,148,148,148,148,148,148,148,148,148,148,148,148,148,148,-337,-76,-79,-82,-75,148,-77,148,148,-81,-215,-214,-80,-216,148,-78,148,148,-69,-284,148,148,-284,148,148,-244,-247,-245,-241,-242,-246,-248,148,-250,-251,-243,-249,-12,148,148,-11,148,148,148,148,-234,-233,148,-231,148,148,-217,148,-230,148,-84,-218,148,148,148,-337,-337,-198,148,148,148,-337,-284,-229,-232,148,-221,148,-83,-219,-68,148,-28,-337,148,-11,148,148,-220,148,148,148,-284,148,148,148,-337,148,-225,-224,-222,-84,148,148,148,-226,-223,148,-228,-227,]),'DIVIDE':([132,133,134,136,138,139,140,141,143,144,145,148,152,153,154,156,158,160,161,162,163,164,166,167,168,169,176,191,224,230,232,234,235,236,237,238,261,263,268,273,310,383,384,385,386,387,388,389,390,391,392,393,394,395,396,397,398,400,401,402,403,404,405,407,411,478,480,482,483,487,547,551,565,],[-317,-321,-318,-303,-324,-330,-313,-319,-301,-274,-314,-327,-325,-304,-322,-302,-255,-315,-289,252,-328,-316,-329,-320,-276,-323,-312,-336,-274,-326,-280,-277,-334,-332,-331,-333,-298,-297,-279,-278,-312,252,252,252,252,252,252,252,252,252,252,-257,-256,252,252,252,252,252,-258,-296,-295,-294,-293,-292,-305,-281,-282,-290,-291,-275,-306,-299,-300,]),'FOR':([61,97,119,124,181,191,284,285,286,289,291,298,300,301,302,303,305,307,308,332,424,425,428,429,432,435,437,438,439,440,496,497,500,502,505,506,510,535,536,537,539,554,555,557,558,569,574,575,576,577,578,579,],[-71,-335,-87,-72,292,-336,-76,-79,-82,-75,-77,292,-81,-215,-214,-80,-216,292,-78,-69,-234,-233,-231,292,-217,-230,292,-84,-218,292,-229,-232,-221,292,-83,-219,-68,292,-220,292,292,-225,-224,-222,-84,292,292,-226,-223,292,-228,-227,]),'PLUSPLUS':([3,39,58,61,76,85,97,103,105,106,116,117,119,124,128,131,132,133,134,135,136,137,138,139,140,141,143,145,146,148,149,150,151,152,153,154,156,160,161,163,164,165,166,167,168,169,171,173,174,175,176,181,191,198,201,204,205,206,218,219,220,227,229,230,231,233,235,236,237,238,242,243,244,245,246,247,248,249,250,251,252,253,254,255,256,257,258,259,260,261,263,265,266,282,284,285,286,289,290,291,297,298,300,301,302,303,305,307,308,310,319,329,332,336,338,339,352,353,354,357,358,359,360,361,362,363,364,365,366,367,368,369,373,375,377,402,403,404,405,407,411,412,413,419,421,424,425,427,428,429,430,432,434,435,437,438,439,440,441,447,459,472,475,477,481,482,483,484,488,494,496,497,499,500,502,505,506,510,513,514,515,521,522,533,535,536,537,538,539,541,542,543,547,549,550,551,553,554,555,557,558,565,566,569,574,575,576,577,578,579,],[-128,-129,-130,-71,-131,149,-335,-28,-182,-27,149,-337,-87,-72,-337,149,-317,-321,-318,-286,-303,-285,-324,-330,-313,-319,-301,-314,149,-327,149,-283,-287,-325,-304,-322,-302,-315,-289,-328,-316,-288,-329,-320,263,-323,149,-284,149,149,-312,149,-336,-183,149,149,-28,-337,149,-28,-337,-337,149,-326,149,149,-334,-332,-331,-333,149,149,149,149,149,149,149,149,149,149,149,149,149,149,149,149,149,149,149,-298,-297,149,149,-337,-76,-79,-82,-75,149,-77,149,149,-81,-215,-214,-80,-216,149,-78,-312,149,149,-69,-284,149,149,-284,149,149,-244,-247,-245,-241,-242,-246,-248,149,-250,-251,-243,-249,-12,149,149,-11,-296,-295,-294,-293,-292,-305,149,149,149,149,-234,-233,149,-231,149,149,-217,149,-230,149,-84,-218,149,149,149,-337,-337,-198,149,149,-290,-291,149,-337,-284,-229,-232,149,-221,149,-83,-219,-68,149,-28,-337,149,-11,149,149,-220,149,149,149,-284,149,149,-306,149,-337,-299,149,-225,-224,-222,-84,-300,149,149,149,-226,-223,149,-228,-227,]),'EQUALS':([42,48,73,74,75,77,78,86,110,127,132,133,134,136,138,139,140,141,143,144,145,148,152,153,154,156,160,161,163,164,166,167,168,169,176,179,191,197,200,216,224,230,232,234,235,236,237,238,261,263,268,273,310,340,341,355,356,371,376,402,403,404,405,407,411,453,454,464,465,470,474,478,480,482,483,487,511,512,518,519,520,547,551,565,],[-29,-52,-54,-179,-178,-37,131,-53,201,-55,-317,-321,-318,-303,-324,-330,-313,-319,-301,-274,-314,-327,-325,-304,-322,-302,-315,-289,-328,-316,-329,-320,-276,-323,-312,-178,-336,329,-30,-38,360,-326,-280,-277,-334,-332,-331,-333,-298,-297,-279,-278,-312,-36,-35,-44,-43,-199,475,-296,-295,-294,-293,-292,-305,-31,-34,-39,-42,-202,-200,-281,-282,-290,-291,-275,-33,-32,-41,-40,-201,-306,-299,-300,]),'ELSE':([61,124,191,284,285,286,289,291,300,303,308,332,424,425,428,435,437,438,439,496,497,500,505,506,510,536,554,555,557,558,575,576,578,579,],[-71,-72,-336,-76,-79,-82,-75,-77,-81,-80,-78,-69,-234,-233,-231,-230,-81,-84,-218,-229,-232,-221,-83,-219,-68,-220,-225,-224,-222,569,-226,-223,-228,-227,]),'ANDEQUAL':([132,133,134,136,138,139,140,141,143,144,145,148,152,153,154,156,160,161,163,164,166,167,168,169,176,191,224,230,232,234,235,236,237,238,261,263,268,273,310,402,403,404,405,407,411,478,480,482,483,487,547,551,565,],[-317,-321,-318,-303,-324,-330,-313,-319,-301,-274,-314,-327,-325,-304,-322,-302,-315,-289,-328,-316,-329,-320,-276,-323,-312,-336,365,-326,-280,-277,-334,-332,-331,-333,-298,-297,-279,-278,-312,-296,-295,-294,-293,-292,-305,-281,-282,-290,-291,-275,-306,-299,-300,]),'EQ':([132,133,134,136,138,139,140,141,143,144,145,148,152,153,154,156,158,160,161,162,163,164,166,167,168,169,176,191,224,230,232,234,235,236,237,238,261,263,268,273,310,383,384,385,386,387,388,389,390,391,392,393,394,395,396,397,398,400,401,402,403,404,405,407,411,478,480,482,483,487,547,551,565,],[-317,-321,-318,-303,-324,-330,-313,-319,-301,-274,-314,-327,-325,-304,-322,-302,-255,-315,-289,256,-328,-316,-329,-320,-276,-323,-312,-336,-274,-326,-280,-277,-334,-332,-331,-333,-298,-297,-279,-278,-312,-261,256,-262,-260,-264,-268,-263,-259,-266,256,-257,-256,-265,256,-267,256,256,-258,-296,-295,-294,-293,-292,-305,-281,-282,-290,-291,-275,-306,-299,-300,]),'AND':([3,39,58,61,76,85,97,103,105,106,116,117,119,124,128,131,132,133,134,135,136,137,138,139,140,141,143,144,145,146,148,149,150,151,152,153,154,156,158,160,161,162,163,164,165,166,167,168,169,171,173,174,175,176,181,191,198,201,204,205,206,218,219,220,224,227,229,230,231,232,233,234,235,236,237,238,242,243,244,245,246,247,248,249,250,251,252,253,254,255,256,257,258,259,260,261,263,265,266,268,273,282,284,285,286,289,290,291,297,298,300,301,302,303,305,307,308,310,319,329,332,336,338,339,352,353,354,357,358,359,360,361,362,363,364,365,366,367,368,369,373,375,377,383,384,385,386,387,388,389,390,391,392,393,394,395,396,397,398,400,401,402,403,404,405,407,411,412,413,419,421,424,425,427,428,429,430,432,434,435,437,438,439,440,441,447,459,472,475,477,478,480,481,482,483,484,487,488,494,496,497,499,500,502,505,506,510,513,514,515,521,522,533,535,536,537,538,539,541,542,543,547,549,550,551,553,554,555,557,558,565,566,569,574,575,576,577,578,579,],[-128,-129,-130,-71,-131,150,-335,-28,-182,-27,150,-337,-87,-72,-337,150,-317,-321,-318,-286,-303,-285,-324,-330,-313,-319,-301,-274,-314,150,-327,150,-283,-287,-325,-304,-322,-302,-255,-315,-289,257,-328,-316,-288,-329,-320,-276,-323,150,-284,150,150,-312,150,-336,-183,150,150,-28,-337,150,-28,-337,-274,-337,150,-326,150,-280,150,-277,-334,-332,-331,-333,150,150,150,150,150,150,150,150,150,150,150,150,150,150,150,150,150,150,150,-298,-297,150,150,-279,-278,-337,-76,-79,-82,-75,150,-77,150,150,-81,-215,-214,-80,-216,150,-78,-312,150,150,-69,-284,150,150,-284,150,150,-244,-247,-245,-241,-242,-246,-248,150,-250,-251,-243,-249,-12,150,150,-11,-261,257,-262,-260,-264,-268,-263,-259,-266,257,-257,-256,-265,257,-267,-269,257,-258,-296,-295,-294,-293,-292,-305,150,150,150,150,-234,-233,150,-231,150,150,-217,150,-230,150,-84,-218,150,150,150,-337,-337,-198,150,-281,-282,150,-290,-291,150,-275,-337,-284,-229,-232,150,-221,150,-83,-219,-68,150,-28,-337,150,-11,150,150,-220,150,150,150,-284,150,150,-306,150,-337,-299,150,-225,-224,-222,-84,-300,150,150,150,-226,-223,150,-228,-227,]),'TYPEID':([0,1,2,3,4,5,6,7,10,11,12,13,14,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,38,39,40,42,43,44,45,46,47,48,49,50,52,53,54,55,56,58,59,60,61,62,63,64,65,66,67,68,69,71,72,75,76,80,81,82,85,86,87,89,90,91,93,94,95,96,97,98,99,100,101,103,104,105,106,109,111,118,119,120,121,122,123,124,126,129,142,147,172,174,180,181,182,184,185,186,187,188,189,190,191,192,198,199,200,202,211,214,223,229,231,233,239,240,241,262,264,267,269,275,278,279,280,284,285,286,289,291,298,300,301,302,303,305,308,312,313,314,315,316,317,318,328,332,340,341,342,344,350,422,424,425,427,428,432,435,437,438,439,442,443,445,446,448,449,453,454,460,496,497,500,505,506,510,511,512,536,554,555,557,558,575,576,578,579,],[33,-337,-113,-128,77,-124,-110,-106,-104,-107,-125,-105,-64,-60,-67,-99,-66,-109,33,-120,-115,-154,-65,-102,-126,-155,-131,-108,96,100,-238,-111,-337,-122,-63,-129,33,-29,-121,-116,-62,-112,-70,-52,-123,-117,-337,-337,-119,-337,-114,-130,33,-118,-71,-103,-337,-9,-131,-91,-10,-96,77,-98,77,33,-131,-95,-101,-97,33,-53,-126,77,-88,33,33,-93,33,-147,-335,-146,33,-167,-166,-28,-180,-182,-27,-100,-126,33,-87,-90,-94,-92,-61,-72,77,33,-144,-142,33,33,-73,33,-89,33,33,33,-149,-159,-160,-156,-336,77,-183,-181,-30,77,347,33,-74,33,33,33,33,-174,-175,402,404,33,-143,-140,33,-141,-145,-76,-79,-82,-75,-77,33,-81,-215,-214,-80,-216,-78,-127,33,-153,33,-151,-148,-157,-168,-69,-36,-35,33,347,33,33,-234,-233,33,-231,-217,-230,-81,-84,-218,-152,-150,77,-158,-170,-169,-31,-34,33,-229,-232,-221,-83,-219,-68,-33,-32,-220,-225,-224,-222,-84,-226,-223,-228,-227,]),'LBRACE':([21,24,28,31,32,42,48,61,75,86,88,90,92,93,96,97,98,100,101,119,124,130,131,181,182,191,200,201,227,229,284,285,286,289,291,298,300,301,302,303,305,307,308,332,340,341,369,375,377,413,424,425,428,429,432,435,437,438,439,440,453,454,472,475,477,478,479,488,496,497,500,502,505,506,510,511,512,521,522,535,536,537,539,550,554,555,557,558,569,574,575,576,577,578,579,],[-337,-154,-155,97,97,-29,-52,-71,-337,-53,-7,-88,97,-8,97,-335,97,97,97,-87,-72,97,97,97,-89,-336,-30,97,-337,97,-76,-79,-82,-75,-77,97,-81,-215,-214,-80,-216,97,-78,-69,-36,-35,-12,97,-11,97,-234,-233,-231,97,-217,-230,97,-84,-218,97,-31,-34,-337,-198,97,97,97,-337,-229,-232,-221,97,-83,-219,-68,-33,-32,97,-11,97,-220,97,97,-337,-225,-224,-222,-84,97,97,-226,-223,97,-228,-227,]),'PPHASH':([0,14,16,17,19,25,38,45,47,59,61,119,123,124,180,191,223,332,439,510,],[47,-64,-60,-67,-66,-65,-63,-62,-70,47,-71,-87,-61,-72,-73,-336,-74,-69,-218,-68,]),'INT':([0,1,2,3,4,5,6,7,10,11,12,13,14,16,17,18,19,20,21,22,23,25,26,27,29,30,33,34,36,38,39,40,42,43,44,45,46,47,48,49,50,52,53,54,55,56,58,59,60,61,62,63,64,65,66,67,68,71,75,76,80,81,82,85,86,87,89,90,91,93,94,95,96,97,98,99,100,101,105,109,111,118,119,120,121,122,123,124,129,142,147,172,174,177,180,181,182,184,185,186,187,188,189,190,191,192,198,200,211,214,223,229,231,233,239,240,241,267,269,275,278,279,280,284,285,286,289,291,298,300,301,302,303,305,308,312,313,314,315,316,317,318,328,332,340,341,342,350,422,424,425,427,428,432,435,437,438,439,442,443,446,448,449,453,454,460,496,497,500,505,506,510,511,512,536,554,555,557,558,575,576,578,579,],[56,-337,-113,-128,56,-124,-110,-106,-104,-107,-125,-105,-64,-60,-67,-99,-66,-109,56,-120,-115,-65,-102,-126,-131,-108,-238,-111,-122,-63,-129,56,-29,-121,-116,-62,-112,-70,-52,-123,-117,-337,-337,-119,-337,-114,-130,56,-118,-71,-103,-337,-9,-131,-91,-10,-96,-98,56,-131,-95,-101,-97,56,-53,-126,56,-88,56,56,-93,56,-147,-335,-146,56,-167,-166,-182,-100,-126,56,-87,-90,-94,-92,-61,-72,56,-144,-142,56,56,56,-73,56,-89,56,56,56,-149,-159,-160,-156,-336,56,-183,-30,56,56,-74,56,56,56,56,-174,-175,56,-143,-140,56,-141,-145,-76,-79,-82,-75,-77,56,-81,-215,-214,-80,-216,-78,-127,56,-153,56,-151,-148,-157,-168,-69,-36,-35,56,56,56,-234,-233,56,-231,-217,-230,-81,-84,-218,-152,-150,-158,-170,-169,-31,-34,56,-229,-232,-221,-83,-219,-68,-33,-32,-220,-225,-224,-222,-84,-226,-223,-228,-227,]),'SIGNED':([0,1,2,3,4,5,6,7,10,11,12,13,14,16,17,18,19,20,21,22,23,25,26,27,29,30,33,34,36,38,39,40,42,43,44,45,46,47,48,49,50,52,53,54,55,56,58,59,60,61,62,63,64,65,66,67,68,71,75,76,80,81,82,85,86,87,89,90,91,93,94,95,96,97,98,99,100,101,105,109,111,118,119,120,121,122,123,124,129,142,147,172,174,177,180,181,182,184,185,186,187,188,189,190,191,192,198,200,211,214,223,229,231,233,239,240,241,267,269,275,278,279,280,284,285,286,289,291,298,300,301,302,303,305,308,312,313,314,315,316,317,318,328,332,340,341,342,350,422,424,425,427,428,432,435,437,438,439,442,443,446,448,449,453,454,460,496,497,500,505,506,510,511,512,536,554,555,557,558,575,576,578,579,],[54,-337,-113,-128,54,-124,-110,-106,-104,-107,-125,-105,-64,-60,-67,-99,-66,-109,54,-120,-115,-65,-102,-126,-131,-108,-238,-111,-122,-63,-129,54,-29,-121,-116,-62,-112,-70,-52,-123,-117,-337,-337,-119,-337,-114,-130,54,-118,-71,-103,-337,-9,-131,-91,-10,-96,-98,54,-131,-95,-101,-97,54,-53,-126,54,-88,54,54,-93,54,-147,-335,-146,54,-167,-166,-182,-100,-126,54,-87,-90,-94,-92,-61,-72,54,-144,-142,54,54,54,-73,54,-89,54,54,54,-149,-159,-160,-156,-336,54,-183,-30,54,54,-74,54,54,54,54,-174,-175,54,-143,-140,54,-141,-145,-76,-79,-82,-75,-77,54,-81,-215,-214,-80,-216,-78,-127,54,-153,54,-151,-148,-157,-168,-69,-36,-35,54,54,54,-234,-233,54,-231,-217,-230,-81,-84,-218,-152,-150,-158,-170,-169,-31,-34,54,-229,-232,-221,-83,-219,-68,-33,-32,-220,-225,-224,-222,-84,-226,-223,-228,-227,]),'CONTINUE':([61,97,119,124,181,191,284,285,286,289,291,298,300,301,302,303,305,307,308,332,424,425,428,429,432,435,437,438,439,440,496,497,500,502,505,506,510,535,536,537,539,554,555,557,558,569,574,575,576,577,578,579,],[-71,-335,-87,-72,293,-336,-76,-79,-82,-75,-77,293,-81,-215,-214,-80,-216,293,-78,-69,-234,-233,-231,293,-217,-230,293,-84,-218,293,-229,-232,-221,293,-83,-219,-68,293,-220,293,293,-225,-224,-222,-84,293,293,-226,-223,293,-228,-227,]),'NOT':([3,39,58,61,76,85,97,103,105,106,116,117,119,124,128,131,135,137,146,149,150,151,165,171,173,174,175,181,191,198,201,204,205,206,218,219,220,227,229,231,233,242,243,244,245,246,247,248,249,250,251,252,253,254,255,256,257,258,259,260,265,266,282,284,285,286,289,290,291,297,298,300,301,302,303,305,307,308,319,329,332,336,338,339,352,353,354,357,358,359,360,361,362,363,364,365,366,367,368,369,373,375,377,412,413,419,421,424,425,427,428,429,430,432,434,435,437,438,439,440,441,447,459,472,475,477,481,484,488,494,496,497,499,500,502,505,506,510,513,514,515,521,522,533,535,536,537,538,539,541,542,543,549,550,553,554,555,557,558,566,569,574,575,576,577,578,579,],[-128,-129,-130,-71,-131,151,-335,-28,-182,-27,151,-337,-87,-72,-337,151,-286,-285,151,151,-283,-287,-288,151,-284,151,151,151,-336,-183,151,151,-28,-337,151,-28,-337,-337,151,151,151,151,151,151,151,151,151,151,151,151,151,151,151,151,151,151,151,151,151,151,151,151,-337,-76,-79,-82,-75,151,-77,151,151,-81,-215,-214,-80,-216,151,-78,151,151,-69,-284,151,151,-284,151,151,-244,-247,-245,-241,-242,-246,-248,151,-250,-251,-243,-249,-12,151,151,-11,151,151,151,151,-234,-233,151,-231,151,151,-217,151,-230,151,-84,-218,151,151,151,-337,-337,-198,151,151,151,-337,-284,-229,-232,151,-221,151,-83,-219,-68,151,-28,-337,151,-11,151,151,-220,151,151,151,-284,151,151,151,-337,151,-225,-224,-222,-84,151,151,151,-226,-223,151,-228,-227,]),'OREQUAL':([132,133,134,136,138,139,140,141,143,144,145,148,152,153,154,156,160,161,163,164,166,167,168,169,176,191,224,230,232,234,235,236,237,238,261,263,268,273,310,402,403,404,405,407,411,478,480,482,483,487,547,551,565,],[-317,-321,-318,-303,-324,-330,-313,-319,-301,-274,-314,-327,-325,-304,-322,-302,-315,-289,-328,-316,-329,-320,-276,-323,-312,-336,366,-326,-280,-277,-334,-332,-331,-333,-298,-297,-279,-278,-312,-296,-295,-294,-293,-292,-305,-281,-282,-290,-291,-275,-306,-299,-300,]),'MOD':([132,133,134,136,138,139,140,141,143,144,145,148,152,153,154,156,158,160,161,162,163,164,166,167,168,169,176,191,224,230,232,234,235,236,237,238,261,263,268,273,310,383,384,385,386,387,388,389,390,391,392,393,394,395,396,397,398,400,401,402,403,404,405,407,411,478,480,482,483,487,547,551,565,],[-317,-321,-318,-303,-324,-330,-313,-319,-301,-274,-314,-327,-325,-304,-322,-302,-255,-315,-289,260,-328,-316,-329,-320,-276,-323,-312,-336,-274,-326,-280,-277,-334,-332,-331,-333,-298,-297,-279,-278,-312,260,260,260,260,260,260,260,260,260,260,-257,-256,260,260,260,260,260,-258,-296,-295,-294,-293,-292,-305,-281,-282,-290,-291,-275,-306,-299,-300,]),'RSHIFT':([132,133,134,136,138,139,140,141,143,144,145,148,152,153,154,156,158,160,161,162,163,164,166,167,168,169,176,191,224,230,232,234,235,236,237,238,261,263,268,273,310,383,384,385,386,387,388,389,390,391,392,393,394,395,396,397,398,400,401,402,403,404,405,407,411,478,480,482,483,487,547,551,565,],[-317,-321,-318,-303,-324,-330,-313,-319,-301,-274,-314,-327,-325,-304,-322,-302,-255,-315,-289,242,-328,-316,-329,-320,-276,-323,-312,-336,-274,-326,-280,-277,-334,-332,-331,-333,-298,-297,-279,-278,-312,-261,242,-262,-260,242,242,242,-259,242,242,-257,-256,242,242,242,242,242,-258,-296,-295,-294,-293,-292,-305,-281,-282,-290,-291,-275,-306,-299,-300,]),'DEFAULT':([61,97,119,124,181,191,284,285,286,289,291,298,300,301,302,303,305,307,308,332,424,425,428,429,432,435,437,438,439,440,496,497,500,502,505,506,510,535,536,537,539,554,555,557,558,569,574,575,576,577,578,579,],[-71,-335,-87,-72,295,-336,-76,-79,-82,-75,-77,295,-81,-215,-214,-80,-216,295,-78,-69,-234,-233,-231,295,-217,-230,295,-84,-218,295,-229,-232,-221,295,-83,-219,-68,295,-220,295,295,-225,-224,-222,-84,295,295,-226,-223,295,-228,-227,]),'_NORETURN':([0,1,2,3,4,5,6,7,10,11,12,13,14,16,17,18,19,20,21,22,23,25,26,27,29,30,33,34,36,38,39,42,43,44,45,46,47,48,49,50,52,53,54,55,56,58,59,60,61,62,63,65,68,71,75,76,80,81,82,86,87,89,90,93,96,97,98,100,101,109,111,118,119,123,124,129,180,181,182,187,191,200,211,223,240,241,278,284,285,286,289,291,298,300,301,302,303,305,308,312,314,316,317,328,332,340,341,342,350,422,424,425,427,428,432,435,437,438,439,442,443,448,449,453,454,460,496,497,500,505,506,510,511,512,536,554,555,557,558,575,576,578,579,],[20,20,-113,-128,20,-124,-110,-106,-104,-107,-125,-105,-64,-60,-67,-99,-66,-109,20,-120,-115,-65,-102,20,-131,-108,-238,-111,-122,-63,-129,-29,-121,-116,-62,-112,-70,-52,-123,-117,20,20,-119,20,-114,-130,20,-118,-71,-103,20,-131,-96,-98,20,-131,-95,-101,-97,-53,20,20,-88,20,-147,-335,-146,-167,-166,-100,-126,20,-87,-61,-72,20,-73,20,-89,-149,-336,-30,20,-74,-174,-175,20,-76,-79,-82,-75,-77,20,-81,-215,-214,-80,-216,-78,-127,-153,-151,-148,-168,-69,-36,-35,20,20,20,-234,-233,20,-231,-217,-230,-81,-84,-218,-152,-150,-170,-169,-31,-34,20,-229,-232,-221,-83,-219,-68,-33,-32,-220,-225,-224,-222,-84,-226,-223,-228,-227,]),'__INT128':([0,1,2,3,4,5,6,7,10,11,12,13,14,16,17,18,19,20,21,22,23,25,26,27,29,30,33,34,36,38,39,40,42,43,44,45,46,47,48,49,50,52,53,54,55,56,58,59,60,61,62,63,64,65,66,67,68,71,75,76,80,81,82,85,86,87,89,90,91,93,94,95,96,97,98,99,100,101,105,109,111,118,119,120,121,122,123,124,129,142,147,172,174,177,180,181,182,184,185,186,187,188,189,190,191,192,198,200,211,214,223,229,231,233,239,240,241,267,269,275,278,279,280,284,285,286,289,291,298,300,301,302,303,305,308,312,313,314,315,316,317,318,328,332,340,341,342,350,422,424,425,427,428,432,435,437,438,439,442,443,446,448,449,453,454,460,496,497,500,505,506,510,511,512,536,554,555,557,558,575,576,578,579,],[43,-337,-113,-128,43,-124,-110,-106,-104,-107,-125,-105,-64,-60,-67,-99,-66,-109,43,-120,-115,-65,-102,-126,-131,-108,-238,-111,-122,-63,-129,43,-29,-121,-116,-62,-112,-70,-52,-123,-117,-337,-337,-119,-337,-114,-130,43,-118,-71,-103,-337,-9,-131,-91,-10,-96,-98,43,-131,-95,-101,-97,43,-53,-126,43,-88,43,43,-93,43,-147,-335,-146,43,-167,-166,-182,-100,-126,43,-87,-90,-94,-92,-61,-72,43,-144,-142,43,43,43,-73,43,-89,43,43,43,-149,-159,-160,-156,-336,43,-183,-30,43,43,-74,43,43,43,43,-174,-175,43,-143,-140,43,-141,-145,-76,-79,-82,-75,-77,43,-81,-215,-214,-80,-216,-78,-127,43,-153,43,-151,-148,-157,-168,-69,-36,-35,43,43,43,-234,-233,43,-231,-217,-230,-81,-84,-218,-152,-150,-158,-170,-169,-31,-34,43,-229,-232,-221,-83,-219,-68,-33,-32,-220,-225,-224,-222,-84,-226,-223,-228,-227,]),'WHILE':([61,97,119,124,181,191,284,285,286,289,291,298,300,301,302,303,305,307,308,332,424,425,428,429,432,435,436,437,438,439,440,496,497,500,502,505,506,510,535,536,537,539,554,555,557,558,569,574,575,576,577,578,579,],[-71,-335,-87,-72,296,-336,-76,-79,-82,-75,-77,296,-81,-215,-214,-80,-216,296,-78,-69,-234,-233,-231,296,-217,-230,504,296,-84,-218,296,-229,-232,-221,296,-83,-219,-68,296,-220,296,296,-225,-224,-222,-84,296,296,-226,-223,296,-228,-227,]),'U8CHAR_CONST':([3,39,58,61,76,85,97,103,105,106,116,117,119,124,128,131,135,137,146,149,150,151,165,171,173,174,175,181,191,198,201,204,205,206,218,219,220,227,229,231,233,242,243,244,245,246,247,248,249,250,251,252,253,254,255,256,257,258,259,260,265,266,282,284,285,286,289,290,291,297,298,300,301,302,303,305,307,308,319,329,332,336,338,339,352,353,354,357,358,359,360,361,362,363,364,365,366,367,368,369,373,375,377,412,413,419,421,424,425,427,428,429,430,432,434,435,437,438,439,440,441,447,459,472,475,477,481,484,488,494,496,497,499,500,502,505,506,510,513,514,515,521,522,533,535,536,537,538,539,541,542,543,549,550,553,554,555,557,558,566,569,574,575,576,577,578,579,],[-128,-129,-130,-71,-131,154,-335,-28,-182,-27,154,-337,-87,-72,-337,154,-286,-285,154,154,-283,-287,-288,154,-284,154,154,154,-336,-183,154,154,-28,-337,154,-28,-337,-337,154,154,154,154,154,154,154,154,154,154,154,154,154,154,154,154,154,154,154,154,154,154,154,154,-337,-76,-79,-82,-75,154,-77,154,154,-81,-215,-214,-80,-216,154,-78,154,154,-69,-284,154,154,-284,154,154,-244,-247,-245,-241,-242,-246,-248,154,-250,-251,-243,-249,-12,154,154,-11,154,154,154,154,-234,-233,154,-231,154,154,-217,154,-230,154,-84,-218,154,154,154,-337,-337,-198,154,154,154,-337,-284,-229,-232,154,-221,154,-83,-219,-68,154,-28,-337,154,-11,154,154,-220,154,154,154,-284,154,154,154,-337,154,-225,-224,-222,-84,154,154,154,-226,-223,154,-228,-227,]),'_ALIGNOF':([3,39,58,61,76,85,97,103,105,106,116,117,119,124,128,131,135,137,146,149,150,151,165,171,173,174,175,181,191,198,201,204,205,206,218,219,220,227,229,231,233,242,243,244,245,246,247,248,249,250,251,252,253,254,255,256,257,258,259,260,265,266,282,284,285,286,289,290,291,297,298,300,301,302,303,305,307,308,319,329,332,336,338,339,352,353,354,357,358,359,360,361,362,363,364,365,366,367,368,369,373,375,377,412,413,419,421,424,425,427,428,429,430,432,434,435,437,438,439,440,441,447,459,472,475,477,481,484,488,494,496,497,499,500,502,505,506,510,513,514,515,521,522,533,535,536,537,538,539,541,542,543,549,550,553,554,555,557,558,566,569,574,575,576,577,578,579,],[-128,-129,-130,-71,-131,155,-335,-28,-182,-27,155,-337,-87,-72,-337,155,-286,-285,155,155,-283,-287,-288,155,-284,155,155,155,-336,-183,155,155,-28,-337,155,-28,-337,-337,155,155,155,155,155,155,155,155,155,155,155,155,155,155,155,155,155,155,155,155,155,155,155,155,-337,-76,-79,-82,-75,155,-77,155,155,-81,-215,-214,-80,-216,155,-78,155,155,-69,-284,155,155,-284,155,155,-244,-247,-245,-241,-242,-246,-248,155,-250,-251,-243,-249,-12,155,155,-11,155,155,155,155,-234,-233,155,-231,155,155,-217,155,-230,155,-84,-218,155,155,155,-337,-337,-198,155,155,155,-337,-284,-229,-232,155,-221,155,-83,-219,-68,155,-28,-337,155,-11,155,155,-220,155,155,155,-284,155,155,155,-337,155,-225,-224,-222,-84,155,155,155,-226,-223,155,-228,-227,]),'EXTERN':([0,1,2,3,4,5,6,7,10,11,12,13,14,16,17,18,19,20,21,22,23,25,26,27,29,30,33,34,36,38,39,42,43,44,45,46,47,48,49,50,52,53,54,55,56,58,59,60,61,62,63,65,68,71,75,76,80,81,82,86,87,89,90,93,96,97,98,100,101,109,111,118,119,123,124,129,180,181,182,187,191,200,211,223,240,241,278,284,285,286,289,291,298,300,301,302,303,305,308,312,314,316,317,328,332,340,341,342,350,422,424,425,427,428,432,435,437,438,439,442,443,448,449,453,454,460,496,497,500,505,506,510,511,512,536,554,555,557,558,575,576,578,579,],[13,13,-113,-128,13,-124,-110,-106,-104,-107,-125,-105,-64,-60,-67,-99,-66,-109,13,-120,-115,-65,-102,13,-131,-108,-238,-111,-122,-63,-129,-29,-121,-116,-62,-112,-70,-52,-123,-117,13,13,-119,13,-114,-130,13,-118,-71,-103,13,-131,-96,-98,13,-131,-95,-101,-97,-53,13,13,-88,13,-147,-335,-146,-167,-166,-100,-126,13,-87,-61,-72,13,-73,13,-89,-149,-336,-30,13,-74,-174,-175,13,-76,-79,-82,-75,-77,13,-81,-215,-214,-80,-216,-78,-127,-153,-151,-148,-168,-69,-36,-35,13,13,13,-234,-233,13,-231,-217,-230,-81,-84,-218,-152,-150,-170,-169,-31,-34,13,-229,-232,-221,-83,-219,-68,-33,-32,-220,-225,-224,-222,-84,-226,-223,-228,-227,]),'CASE':([61,97,119,124,181,191,284,285,286,289,291,298,300,301,302,303,305,307,308,332,424,425,428,429,432,435,437,438,439,440,496,497,500,502,505,506,510,535,536,537,539,554,555,557,558,569,574,575,576,577,578,579,],[-71,-335,-87,-72,297,-336,-76,-79,-82,-75,-77,297,-81,-215,-214,-80,-216,297,-78,-69,-234,-233,-231,297,-217,-230,297,-84,-218,297,-229,-232,-221,297,-83,-219,-68,297,-220,297,297,-225,-224,-222,-84,297,297,-226,-223,297,-228,-227,]),'LAND':([132,133,134,136,138,139,140,141,143,144,145,148,152,153,154,156,158,160,161,162,163,164,166,167,168,169,176,191,224,230,232,234,235,236,237,238,261,263,268,273,310,383,384,385,386,387,388,389,390,391,392,393,394,395,396,397,398,400,401,402,403,404,405,407,411,478,480,482,483,487,547,551,565,],[-317,-321,-318,-303,-324,-330,-313,-319,-301,-274,-314,-327,-325,-304,-322,-302,-255,-315,-289,255,-328,-316,-329,-320,-276,-323,-312,-336,-274,-326,-280,-277,-334,-332,-331,-333,-298,-297,-279,-278,-312,-261,255,-262,-260,-264,-268,-263,-259,-266,-271,-257,-256,-265,-272,-267,-269,-270,-258,-296,-295,-294,-293,-292,-305,-281,-282,-290,-291,-275,-306,-299,-300,]),'REGISTER':([0,1,2,3,4,5,6,7,10,11,12,13,14,16,17,18,19,20,21,22,23,25,26,27,29,30,33,34,36,38,39,42,43,44,45,46,47,48,49,50,52,53,54,55,56,58,59,60,61,62,63,65,68,71,75,76,80,81,82,86,87,89,90,93,96,97,98,100,101,109,111,118,119,123,124,129,180,181,182,187,191,200,211,223,240,241,278,284,285,286,289,291,298,300,301,302,303,305,308,312,314,316,317,328,332,340,341,342,350,422,424,425,427,428,432,435,437,438,439,442,443,448,449,453,454,460,496,497,500,505,506,510,511,512,536,554,555,557,558,575,576,578,579,],[62,62,-113,-128,62,-124,-110,-106,-104,-107,-125,-105,-64,-60,-67,-99,-66,-109,62,-120,-115,-65,-102,62,-131,-108,-238,-111,-122,-63,-129,-29,-121,-116,-62,-112,-70,-52,-123,-117,62,62,-119,62,-114,-130,62,-118,-71,-103,62,-131,-96,-98,62,-131,-95,-101,-97,-53,62,62,-88,62,-147,-335,-146,-167,-166,-100,-126,62,-87,-61,-72,62,-73,62,-89,-149,-336,-30,62,-74,-174,-175,62,-76,-79,-82,-75,-77,62,-81,-215,-214,-80,-216,-78,-127,-153,-151,-148,-168,-69,-36,-35,62,62,62,-234,-233,62,-231,-217,-230,-81,-84,-218,-152,-150,-170,-169,-31,-34,62,-229,-232,-221,-83,-219,-68,-33,-32,-220,-225,-224,-222,-84,-226,-223,-228,-227,]),'MODEQUAL':([132,133,134,136,138,139,140,141,143,144,145,148,152,153,154,156,160,161,163,164,166,167,168,169,176,191,224,230,232,234,235,236,237,238,261,263,268,273,310,402,403,404,405,407,411,478,480,482,483,487,547,551,565,],[-317,-321,-318,-303,-324,-330,-313,-319,-301,-274,-314,-327,-325,-304,-322,-302,-315,-289,-328,-316,-329,-320,-276,-323,-312,-336,359,-326,-280,-277,-334,-332,-331,-333,-298,-297,-279,-278,-312,-296,-295,-294,-293,-292,-305,-281,-282,-290,-291,-275,-306,-299,-300,]),'NE':([132,133,134,136,138,139,140,141,143,144,145,148,152,153,154,156,158,160,161,162,163,164,166,167,168,169,176,191,224,230,232,234,235,236,237,238,261,263,268,273,310,383,384,385,386,387,388,389,390,391,392,393,394,395,396,397,398,400,401,402,403,404,405,407,411,478,480,482,483,487,547,551,565,],[-317,-321,-318,-303,-324,-330,-313,-319,-301,-274,-314,-327,-325,-304,-322,-302,-255,-315,-289,247,-328,-316,-329,-320,-276,-323,-312,-336,-274,-326,-280,-277,-334,-332,-331,-333,-298,-297,-279,-278,-312,-261,247,-262,-260,-264,-268,-263,-259,-266,247,-257,-256,-265,247,-267,247,247,-258,-296,-295,-294,-293,-292,-305,-281,-282,-290,-291,-275,-306,-299,-300,]),'SWITCH':([61,97,119,124,181,191,284,285,286,289,291,298,300,301,302,303,305,307,308,332,424,425,428,429,432,435,437,438,439,440,496,497,500,502,505,506,510,535,536,537,539,554,555,557,558,569,574,575,576,577,578,579,],[-71,-335,-87,-72,299,-336,-76,-79,-82,-75,-77,299,-81,-215,-214,-80,-216,299,-78,-69,-234,-233,-231,299,-217,-230,299,-84,-218,299,-229,-232,-221,299,-83,-219,-68,299,-220,299,299,-225,-224,-222,-84,299,299,-226,-223,299,-228,-227,]),'INT_CONST_HEX':([3,39,58,61,76,85,97,103,105,106,116,117,119,124,128,131,135,137,146,149,150,151,165,171,173,174,175,181,191,198,201,204,205,206,218,219,220,227,229,231,233,242,243,244,245,246,247,248,249,250,251,252,253,254,255,256,257,258,259,260,265,266,282,284,285,286,289,290,291,297,298,300,301,302,303,305,307,308,319,329,332,336,338,339,352,353,354,357,358,359,360,361,362,363,364,365,366,367,368,369,373,375,377,412,413,419,421,424,425,427,428,429,430,432,434,435,437,438,439,440,441,447,459,472,475,477,481,484,488,494,496,497,499,500,502,505,506,510,513,514,515,521,522,533,535,536,537,538,539,541,542,543,549,550,553,554,555,557,558,566,569,574,575,576,577,578,579,],[-128,-129,-130,-71,-131,160,-335,-28,-182,-27,160,-337,-87,-72,-337,160,-286,-285,160,160,-283,-287,-288,160,-284,160,160,160,-336,-183,160,160,-28,-337,160,-28,-337,-337,160,160,160,160,160,160,160,160,160,160,160,160,160,160,160,160,160,160,160,160,160,160,160,160,-337,-76,-79,-82,-75,160,-77,160,160,-81,-215,-214,-80,-216,160,-78,160,160,-69,-284,160,160,-284,160,160,-244,-247,-245,-241,-242,-246,-248,160,-250,-251,-243,-249,-12,160,160,-11,160,160,160,160,-234,-233,160,-231,160,160,-217,160,-230,160,-84,-218,160,160,160,-337,-337,-198,160,160,160,-337,-284,-229,-232,160,-221,160,-83,-219,-68,160,-28,-337,160,-11,160,160,-220,160,160,160,-284,160,160,160,-337,160,-225,-224,-222,-84,160,160,160,-226,-223,160,-228,-227,]),'_COMPLEX':([0,1,2,3,4,5,6,7,10,11,12,13,14,16,17,18,19,20,21,22,23,25,26,27,29,30,33,34,36,38,39,40,42,43,44,45,46,47,48,49,50,52,53,54,55,56,58,59,60,61,62,63,64,65,66,67,68,71,75,76,80,81,82,85,86,87,89,90,91,93,94,95,96,97,98,99,100,101,105,109,111,118,119,120,121,122,123,124,129,142,147,172,174,177,180,181,182,184,185,186,187,188,189,190,191,192,198,200,211,214,223,229,231,233,239,240,241,267,269,275,278,279,280,284,285,286,289,291,298,300,301,302,303,305,308,312,313,314,315,316,317,318,328,332,340,341,342,350,422,424,425,427,428,432,435,437,438,439,442,443,446,448,449,453,454,460,496,497,500,505,506,510,511,512,536,554,555,557,558,575,576,578,579,],[60,-337,-113,-128,60,-124,-110,-106,-104,-107,-125,-105,-64,-60,-67,-99,-66,-109,60,-120,-115,-65,-102,-126,-131,-108,-238,-111,-122,-63,-129,60,-29,-121,-116,-62,-112,-70,-52,-123,-117,-337,-337,-119,-337,-114,-130,60,-118,-71,-103,-337,-9,-131,-91,-10,-96,-98,60,-131,-95,-101,-97,60,-53,-126,60,-88,60,60,-93,60,-147,-335,-146,60,-167,-166,-182,-100,-126,60,-87,-90,-94,-92,-61,-72,60,-144,-142,60,60,60,-73,60,-89,60,60,60,-149,-159,-160,-156,-336,60,-183,-30,60,60,-74,60,60,60,60,-174,-175,60,-143,-140,60,-141,-145,-76,-79,-82,-75,-77,60,-81,-215,-214,-80,-216,-78,-127,60,-153,60,-151,-148,-157,-168,-69,-36,-35,60,60,60,-234,-233,60,-231,-217,-230,-81,-84,-218,-152,-150,-158,-170,-169,-31,-34,60,-229,-232,-221,-83,-219,-68,-33,-32,-220,-225,-224,-222,-84,-226,-223,-228,-227,]),'PPPRAGMASTR':([61,],[124,]),'PLUSEQUAL':([132,133,134,136,138,139,140,141,143,144,145,148,152,153,154,156,160,161,163,164,166,167,168,169,176,191,224,230,232,234,235,236,237,238,261,263,268,273,310,402,403,404,405,407,411,478,480,482,483,487,547,551,565,],[-317,-321,-318,-303,-324,-330,-313,-319,-301,-274,-314,-327,-325,-304,-322,-302,-315,-289,-328,-316,-329,-320,-276,-323,-312,-336,362,-326,-280,-277,-334,-332,-331,-333,-298,-297,-279,-278,-312,-296,-295,-294,-293,-292,-305,-281,-282,-290,-291,-275,-306,-299,-300,]),'U32CHAR_CONST':([3,39,58,61,76,85,97,103,105,106,116,117,119,124,128,131,135,137,146,149,150,151,165,171,173,174,175,181,191,198,201,204,205,206,218,219,220,227,229,231,233,242,243,244,245,246,247,248,249,250,251,252,253,254,255,256,257,258,259,260,265,266,282,284,285,286,289,290,291,297,298,300,301,302,303,305,307,308,319,329,332,336,338,339,352,353,354,357,358,359,360,361,362,363,364,365,366,367,368,369,373,375,377,412,413,419,421,424,425,427,428,429,430,432,434,435,437,438,439,440,441,447,459,472,475,477,481,484,488,494,496,497,499,500,502,505,506,510,513,514,515,521,522,533,535,536,537,538,539,541,542,543,549,550,553,554,555,557,558,566,569,574,575,576,577,578,579,],[-128,-129,-130,-71,-131,138,-335,-28,-182,-27,138,-337,-87,-72,-337,138,-286,-285,138,138,-283,-287,-288,138,-284,138,138,138,-336,-183,138,138,-28,-337,138,-28,-337,-337,138,138,138,138,138,138,138,138,138,138,138,138,138,138,138,138,138,138,138,138,138,138,138,138,-337,-76,-79,-82,-75,138,-77,138,138,-81,-215,-214,-80,-216,138,-78,138,138,-69,-284,138,138,-284,138,138,-244,-247,-245,-241,-242,-246,-248,138,-250,-251,-243,-249,-12,138,138,-11,138,138,138,138,-234,-233,138,-231,138,138,-217,138,-230,138,-84,-218,138,138,138,-337,-337,-198,138,138,138,-337,-284,-229,-232,138,-221,138,-83,-219,-68,138,-28,-337,138,-11,138,138,-220,138,138,138,-284,138,138,138,-337,138,-225,-224,-222,-84,138,138,138,-226,-223,138,-228,-227,]),'CONDOP':([132,133,134,136,138,139,140,141,143,144,145,148,152,153,154,156,158,160,161,162,163,164,166,167,168,169,176,191,224,230,232,234,235,236,237,238,261,263,268,273,310,383,384,385,386,387,388,389,390,391,392,393,394,395,396,397,398,400,401,402,403,404,405,407,411,478,480,482,483,487,547,551,565,],[-317,-321,-318,-303,-324,-330,-313,-319,-301,-274,-314,-327,-325,-304,-322,-302,-255,-315,-289,258,-328,-316,-329,-320,-276,-323,-312,-336,-274,-326,-280,-277,-334,-332,-331,-333,-298,-297,-279,-278,-312,-261,-273,-262,-260,-264,-268,-263,-259,-266,-271,-257,-256,-265,-272,-267,-269,-270,-258,-296,-295,-294,-293,-292,-305,-281,-282,-290,-291,-275,-306,-299,-300,]),'U8STRING_LITERAL':([3,39,58,61,76,85,97,103,105,106,116,117,119,124,128,131,135,137,139,146,148,149,150,151,153,163,165,166,171,173,174,175,181,191,198,201,204,205,206,218,219,220,227,229,231,233,235,236,237,238,242,243,244,245,246,247,248,249,250,251,252,253,254,255,256,257,258,259,260,265,266,282,284,285,286,289,290,291,297,298,300,301,302,303,305,307,308,319,329,332,336,338,339,352,353,354,357,358,359,360,361,362,363,364,365,366,367,368,369,373,375,377,412,413,419,421,424,425,427,428,429,430,432,434,435,437,438,439,440,441,447,459,472,475,477,481,484,488,494,496,497,499,500,502,505,506,510,513,514,515,521,522,533,535,536,537,538,539,541,542,543,549,550,553,554,555,557,558,566,569,574,575,576,577,578,579,],[-128,-129,-130,-71,-131,163,-335,-28,-182,-27,163,-337,-87,-72,-337,163,-286,-285,-330,163,-327,163,-283,-287,236,-328,-288,-329,163,-284,163,163,163,-336,-183,163,163,-28,-337,163,-28,-337,-337,163,163,163,-334,-332,-331,-333,163,163,163,163,163,163,163,163,163,163,163,163,163,163,163,163,163,163,163,163,163,-337,-76,-79,-82,-75,163,-77,163,163,-81,-215,-214,-80,-216,163,-78,163,163,-69,-284,163,163,-284,163,163,-244,-247,-245,-241,-242,-246,-248,163,-250,-251,-243,-249,-12,163,163,-11,163,163,163,163,-234,-233,163,-231,163,163,-217,163,-230,163,-84,-218,163,163,163,-337,-337,-198,163,163,163,-337,-284,-229,-232,163,-221,163,-83,-219,-68,163,-28,-337,163,-11,163,163,-220,163,163,163,-284,163,163,163,-337,163,-225,-224,-222,-84,163,163,163,-226,-223,163,-228,-227,]),'BREAK':([61,97,119,124,181,191,284,285,286,289,291,298,300,301,302,303,305,307,308,332,424,425,428,429,432,435,437,438,439,440,496,497,500,502,505,506,510,535,536,537,539,554,555,557,558,569,574,575,576,577,578,579,],[-71,-335,-87,-72,304,-336,-76,-79,-82,-75,-77,304,-81,-215,-214,-80,-216,304,-78,-69,-234,-233,-231,304,-217,-230,304,-84,-218,304,-229,-232,-221,304,-83,-219,-68,304,-220,304,304,-225,-224,-222,-84,304,304,-226,-223,304,-228,-227,]),'VOLATILE':([0,1,2,3,4,5,6,7,10,11,12,13,14,16,17,18,19,20,21,22,23,25,26,27,29,30,33,34,35,36,38,39,42,43,44,45,46,47,48,49,50,52,53,54,55,56,58,59,60,61,62,63,65,68,71,75,76,80,81,82,85,86,87,89,90,93,95,96,97,98,99,100,101,103,105,109,111,117,118,119,123,124,128,129,142,147,172,174,177,180,181,182,184,185,186,187,188,189,190,191,192,198,200,205,206,211,219,220,223,229,231,233,239,240,241,267,269,275,278,279,280,282,284,285,286,289,291,298,300,301,302,303,305,308,312,313,314,315,316,317,318,328,332,340,341,342,350,422,424,425,427,428,432,435,437,438,439,442,443,446,448,449,453,454,459,460,496,497,500,505,506,510,511,512,514,515,536,554,555,557,558,575,576,578,579,],[58,58,-113,-128,58,-124,-110,-106,-104,-107,-125,-105,-64,-60,-67,-99,-66,-109,58,-120,-115,-65,-102,58,-131,-108,-238,-111,58,-122,-63,-129,-29,-121,-116,-62,-112,-70,-52,-123,-117,58,58,-119,58,-114,-130,58,-118,-71,-103,58,-131,-96,-98,58,-131,-95,-101,-97,58,-53,58,58,-88,58,58,-147,-335,-146,58,-167,-166,58,-182,-100,-126,58,58,-87,-61,-72,58,58,-144,-142,58,58,58,-73,58,-89,58,58,58,-149,-159,-160,-156,-336,58,-183,-30,58,58,58,58,58,-74,58,58,58,58,-174,-175,58,-143,-140,58,-141,-145,58,-76,-79,-82,-75,-77,58,-81,-215,-214,-80,-216,-78,-127,58,-153,58,-151,-148,-157,-168,-69,-36,-35,58,58,58,-234,-233,58,-231,-217,-230,-81,-84,-218,-152,-150,-158,-170,-169,-31,-34,58,58,-229,-232,-221,-83,-219,-68,-33,-32,58,58,-220,-225,-224,-222,-84,-226,-223,-228,-227,]),'PPPRAGMA':([0,14,16,17,19,25,38,45,47,59,61,97,99,119,123,124,180,181,184,185,186,188,189,190,191,223,284,285,286,289,291,298,300,301,302,303,305,307,308,313,315,318,332,424,425,428,429,432,435,437,438,439,440,446,496,497,500,502,505,506,510,535,536,537,539,554,555,557,558,569,574,575,576,577,578,579,],[61,-64,-60,-67,-66,-65,-63,-62,-70,61,-71,-335,61,-87,-61,-72,-73,61,61,61,61,-159,-160,-156,-336,-74,-76,-79,-82,-75,-77,61,-81,-215,-214,-80,-216,61,-78,61,61,-157,-69,-234,-233,-231,61,-217,-230,61,-84,-218,61,-158,-229,-232,-221,61,-83,-219,-68,61,-220,61,61,-225,-224,-222,-84,61,61,-226,-223,61,-228,-227,]),'INLINE':([0,1,2,3,4,5,6,7,10,11,12,13,14,16,17,18,19,20,21,22,23,25,26,27,29,30,33,34,36,38,39,42,43,44,45,46,47,48,49,50,52,53,54,55,56,58,59,60,61,62,63,65,68,71,75,76,80,81,82,86,87,89,90,93,96,97,98,100,101,109,111,118,119,123,124,129,180,181,182,187,191,200,211,223,240,241,278,284,285,286,289,291,298,300,301,302,303,305,308,312,314,316,317,328,332,340,341,342,350,422,424,425,427,428,432,435,437,438,439,442,443,448,449,453,454,460,496,497,500,505,506,510,511,512,536,554,555,557,558,575,576,578,579,],[30,30,-113,-128,30,-124,-110,-106,-104,-107,-125,-105,-64,-60,-67,-99,-66,-109,30,-120,-115,-65,-102,30,-131,-108,-238,-111,-122,-63,-129,-29,-121,-116,-62,-112,-70,-52,-123,-117,30,30,-119,30,-114,-130,30,-118,-71,-103,30,-131,-96,-98,30,-131,-95,-101,-97,-53,30,30,-88,30,-147,-335,-146,-167,-166,-100,-126,30,-87,-61,-72,30,-73,30,-89,-149,-336,-30,30,-74,-174,-175,30,-76,-79,-82,-75,-77,30,-81,-215,-214,-80,-216,-78,-127,-153,-151,-148,-168,-69,-36,-35,30,30,30,-234,-233,30,-231,-217,-230,-81,-84,-218,-152,-150,-170,-169,-31,-34,30,-229,-232,-221,-83,-219,-68,-33,-32,-220,-225,-224,-222,-84,-226,-223,-228,-227,]),'INT_CONST_BIN':([3,39,58,61,76,85,97,103,105,106,116,117,119,124,128,131,135,137,146,149,150,151,165,171,173,174,175,181,191,198,201,204,205,206,218,219,220,227,229,231,233,242,243,244,245,246,247,248,249,250,251,252,253,254,255,256,257,258,259,260,265,266,282,284,285,286,289,290,291,297,298,300,301,302,303,305,307,308,319,329,332,336,338,339,352,353,354,357,358,359,360,361,362,363,364,365,366,367,368,369,373,375,377,412,413,419,421,424,425,427,428,429,430,432,434,435,437,438,439,440,441,447,459,472,475,477,481,484,488,494,496,497,499,500,502,505,506,510,513,514,515,521,522,533,535,536,537,538,539,541,542,543,549,550,553,554,555,557,558,566,569,574,575,576,577,578,579,],[-128,-129,-130,-71,-131,164,-335,-28,-182,-27,164,-337,-87,-72,-337,164,-286,-285,164,164,-283,-287,-288,164,-284,164,164,164,-336,-183,164,164,-28,-337,164,-28,-337,-337,164,164,164,164,164,164,164,164,164,164,164,164,164,164,164,164,164,164,164,164,164,164,164,164,-337,-76,-79,-82,-75,164,-77,164,164,-81,-215,-214,-80,-216,164,-78,164,164,-69,-284,164,164,-284,164,164,-244,-247,-245,-241,-242,-246,-248,164,-250,-251,-243,-249,-12,164,164,-11,164,164,164,164,-234,-233,164,-231,164,164,-217,164,-230,164,-84,-218,164,164,164,-337,-337,-198,164,164,164,-337,-284,-229,-232,164,-221,164,-83,-219,-68,164,-28,-337,164,-11,164,164,-220,164,164,164,-284,164,164,164,-337,164,-225,-224,-222,-84,164,164,164,-226,-223,164,-228,-227,]),'DO':([61,97,119,124,181,191,284,285,286,289,291,298,300,301,302,303,305,307,308,332,424,425,428,429,432,435,437,438,439,440,496,497,500,502,505,506,510,535,536,537,539,554,555,557,558,569,574,575,576,577,578,579,],[-71,-335,-87,-72,307,-336,-76,-79,-82,-75,-77,307,-81,-215,-214,-80,-216,307,-78,-69,-234,-233,-231,307,-217,-230,307,-84,-218,307,-229,-232,-221,307,-83,-219,-68,307,-220,307,307,-225,-224,-222,-84,307,307,-226,-223,307,-228,-227,]),'LNOT':([3,39,58,61,76,85,97,103,105,106,116,117,119,124,128,131,135,137,146,149,150,151,165,171,173,174,175,181,191,198,201,204,205,206,218,219,220,227,229,231,233,242,243,244,245,246,247,248,249,250,251,252,253,254,255,256,257,258,259,260,265,266,282,284,285,286,289,290,291,297,298,300,301,302,303,305,307,308,319,329,332,336,338,339,352,353,354,357,358,359,360,361,362,363,364,365,366,367,368,369,373,375,377,412,413,419,421,424,425,427,428,429,430,432,434,435,437,438,439,440,441,447,459,472,475,477,481,484,488,494,496,497,499,500,502,505,506,510,513,514,515,521,522,533,535,536,537,538,539,541,542,543,549,550,553,554,555,557,558,566,569,574,575,576,577,578,579,],[-128,-129,-130,-71,-131,165,-335,-28,-182,-27,165,-337,-87,-72,-337,165,-286,-285,165,165,-283,-287,-288,165,-284,165,165,165,-336,-183,165,165,-28,-337,165,-28,-337,-337,165,165,165,165,165,165,165,165,165,165,165,165,165,165,165,165,165,165,165,165,165,165,165,165,-337,-76,-79,-82,-75,165,-77,165,165,-81,-215,-214,-80,-216,165,-78,165,165,-69,-284,165,165,-284,165,165,-244,-247,-245,-241,-242,-246,-248,165,-250,-251,-243,-249,-12,165,165,-11,165,165,165,165,-234,-233,165,-231,165,165,-217,165,-230,165,-84,-218,165,165,165,-337,-337,-198,165,165,165,-337,-284,-229,-232,165,-221,165,-83,-219,-68,165,-28,-337,165,-11,165,165,-220,165,165,165,-284,165,165,165,-337,165,-225,-224,-222,-84,165,165,165,-226,-223,165,-228,-227,]),'CONST':([0,1,2,3,4,5,6,7,10,11,12,13,14,16,17,18,19,20,21,22,23,25,26,27,29,30,33,34,35,36,38,39,42,43,44,45,46,47,48,49,50,52,53,54,55,56,58,59,60,61,62,63,65,68,71,75,76,80,81,82,85,86,87,89,90,93,95,96,97,98,99,100,101,103,105,109,111,117,118,119,123,124,128,129,142,147,172,174,177,180,181,182,184,185,186,187,188,189,190,191,192,198,200,205,206,211,219,220,223,229,231,233,239,240,241,267,269,275,278,279,280,282,284,285,286,289,291,298,300,301,302,303,305,308,312,313,314,315,316,317,318,328,332,340,341,342,350,422,424,425,427,428,432,435,437,438,439,442,443,446,448,449,453,454,459,460,496,497,500,505,506,510,511,512,514,515,536,554,555,557,558,575,576,578,579,],[3,3,-113,-128,3,-124,-110,-106,-104,-107,-125,-105,-64,-60,-67,-99,-66,-109,3,-120,-115,-65,-102,3,-131,-108,-238,-111,3,-122,-63,-129,-29,-121,-116,-62,-112,-70,-52,-123,-117,3,3,-119,3,-114,-130,3,-118,-71,-103,3,-131,-96,-98,3,-131,-95,-101,-97,3,-53,3,3,-88,3,3,-147,-335,-146,3,-167,-166,3,-182,-100,-126,3,3,-87,-61,-72,3,3,-144,-142,3,3,3,-73,3,-89,3,3,3,-149,-159,-160,-156,-336,3,-183,-30,3,3,3,3,3,-74,3,3,3,3,-174,-175,3,-143,-140,3,-141,-145,3,-76,-79,-82,-75,-77,3,-81,-215,-214,-80,-216,-78,-127,3,-153,3,-151,-148,-157,-168,-69,-36,-35,3,3,3,-234,-233,3,-231,-217,-230,-81,-84,-218,-152,-150,-158,-170,-169,-31,-34,3,3,-229,-232,-221,-83,-219,-68,-33,-32,3,3,-220,-225,-224,-222,-84,-226,-223,-228,-227,]),'LSHIFT':([132,133,134,136,138,139,140,141,143,144,145,148,152,153,154,156,158,160,161,162,163,164,166,167,168,169,176,191,224,230,232,234,235,236,237,238,261,263,268,273,310,383,384,385,386,387,388,389,390,391,392,393,394,395,396,397,398,400,401,402,403,404,405,407,411,478,480,482,483,487,547,551,565,],[-317,-321,-318,-303,-324,-330,-313,-319,-301,-274,-314,-327,-325,-304,-322,-302,-255,-315,-289,244,-328,-316,-329,-320,-276,-323,-312,-336,-274,-326,-280,-277,-334,-332,-331,-333,-298,-297,-279,-278,-312,-261,244,-262,-260,244,244,244,-259,244,244,-257,-256,244,244,244,244,244,-258,-296,-295,-294,-293,-292,-305,-281,-282,-290,-291,-275,-306,-299,-300,]),'LOR':([132,133,134,136,138,139,140,141,143,144,145,148,152,153,154,156,158,160,161,162,163,164,166,167,168,169,176,191,224,230,232,234,235,236,237,238,261,263,268,273,310,383,384,385,386,387,388,389,390,391,392,393,394,395,396,397,398,400,401,402,403,404,405,407,411,478,480,482,483,487,547,551,565,],[-317,-321,-318,-303,-324,-330,-313,-319,-301,-274,-314,-327,-325,-304,-322,-302,-255,-315,-289,243,-328,-316,-329,-320,-276,-323,-312,-336,-274,-326,-280,-277,-334,-332,-331,-333,-298,-297,-279,-278,-312,-261,-273,-262,-260,-264,-268,-263,-259,-266,-271,-257,-256,-265,-272,-267,-269,-270,-258,-296,-295,-294,-293,-292,-305,-281,-282,-290,-291,-275,-306,-299,-300,]),'CHAR_CONST':([3,39,58,61,76,85,97,103,105,106,116,117,119,124,128,131,135,137,146,149,150,151,165,171,173,174,175,181,191,198,201,204,205,206,218,219,220,227,229,231,233,242,243,244,245,246,247,248,249,250,251,252,253,254,255,256,257,258,259,260,265,266,282,284,285,286,289,290,291,297,298,300,301,302,303,305,307,308,319,329,332,336,338,339,352,353,354,357,358,359,360,361,362,363,364,365,366,367,368,369,373,375,377,412,413,419,421,424,425,427,428,429,430,432,434,435,437,438,439,440,441,447,459,472,475,477,481,484,488,494,496,497,499,500,502,505,506,510,513,514,515,521,522,533,535,536,537,538,539,541,542,543,549,550,553,554,555,557,558,566,569,574,575,576,577,578,579,],[-128,-129,-130,-71,-131,167,-335,-28,-182,-27,167,-337,-87,-72,-337,167,-286,-285,167,167,-283,-287,-288,167,-284,167,167,167,-336,-183,167,167,-28,-337,167,-28,-337,-337,167,167,167,167,167,167,167,167,167,167,167,167,167,167,167,167,167,167,167,167,167,167,167,167,-337,-76,-79,-82,-75,167,-77,167,167,-81,-215,-214,-80,-216,167,-78,167,167,-69,-284,167,167,-284,167,167,-244,-247,-245,-241,-242,-246,-248,167,-250,-251,-243,-249,-12,167,167,-11,167,167,167,167,-234,-233,167,-231,167,167,-217,167,-230,167,-84,-218,167,167,167,-337,-337,-198,167,167,167,-337,-284,-229,-232,167,-221,167,-83,-219,-68,167,-28,-337,167,-11,167,167,-220,167,167,167,-284,167,167,167,-337,167,-225,-224,-222,-84,167,167,167,-226,-223,167,-228,-227,]),'U16STRING_LITERAL':([3,39,58,61,76,85,97,103,105,106,116,117,119,124,128,131,135,137,139,146,148,149,150,151,153,163,165,166,171,173,174,175,181,191,198,201,204,205,206,218,219,220,227,229,231,233,235,236,237,238,242,243,244,245,246,247,248,249,250,251,252,253,254,255,256,257,258,259,260,265,266,282,284,285,286,289,290,291,297,298,300,301,302,303,305,307,308,319,329,332,336,338,339,352,353,354,357,358,359,360,361,362,363,364,365,366,367,368,369,373,375,377,412,413,419,421,424,425,427,428,429,430,432,434,435,437,438,439,440,441,447,459,472,475,477,481,484,488,494,496,497,499,500,502,505,506,510,513,514,515,521,522,533,535,536,537,538,539,541,542,543,549,550,553,554,555,557,558,566,569,574,575,576,577,578,579,],[-128,-129,-130,-71,-131,166,-335,-28,-182,-27,166,-337,-87,-72,-337,166,-286,-285,-330,166,-327,166,-283,-287,238,-328,-288,-329,166,-284,166,166,166,-336,-183,166,166,-28,-337,166,-28,-337,-337,166,166,166,-334,-332,-331,-333,166,166,166,166,166,166,166,166,166,166,166,166,166,166,166,166,166,166,166,166,166,-337,-76,-79,-82,-75,166,-77,166,166,-81,-215,-214,-80,-216,166,-78,166,166,-69,-284,166,166,-284,166,166,-244,-247,-245,-241,-242,-246,-248,166,-250,-251,-243,-249,-12,166,166,-11,166,166,166,166,-234,-233,166,-231,166,166,-217,166,-230,166,-84,-218,166,166,166,-337,-337,-198,166,166,166,-337,-284,-229,-232,166,-221,166,-83,-219,-68,166,-28,-337,166,-11,166,166,-220,166,166,166,-284,166,166,166,-337,166,-225,-224,-222,-84,166,166,166,-226,-223,166,-228,-227,]),'RBRACE':([61,97,99,119,124,132,133,134,136,138,139,140,141,143,144,145,148,152,153,154,156,158,160,161,162,163,164,166,167,168,169,176,178,181,184,185,186,188,189,190,191,195,196,197,224,225,227,228,230,232,234,235,236,237,238,261,263,268,273,284,285,286,289,291,298,300,301,302,303,305,306,308,309,313,315,318,325,326,327,332,370,374,377,383,384,385,386,387,388,389,390,391,392,393,394,395,396,397,398,400,401,402,403,404,405,407,411,424,425,428,432,435,437,438,439,446,450,451,468,469,472,473,476,478,480,482,483,487,496,497,500,505,506,510,523,524,528,536,546,547,550,551,554,555,557,558,565,575,576,578,579,],[-71,-335,191,-87,-72,-317,-321,-318,-303,-324,-330,-313,-319,-301,-274,-314,-327,-325,-304,-322,-302,-255,-315,-289,-253,-328,-316,-329,-320,-276,-323,-312,-252,-337,191,191,191,-159,-160,-156,-336,-171,191,-176,-274,-239,-337,-193,-326,-280,-277,-334,-332,-331,-333,-298,-297,-279,-278,-76,-79,-82,-75,-77,-6,-81,-215,-214,-80,-216,-5,-78,191,191,191,-157,191,191,-172,-69,191,-22,-21,-261,-273,-262,-260,-264,-268,-263,-259,-266,-271,-257,-256,-265,-272,-267,-269,-270,-258,-296,-295,-294,-293,-292,-305,-234,-233,-231,-217,-230,-81,-84,-218,-158,-173,-177,-240,-194,191,-196,-237,-281,-282,-290,-291,-275,-229,-232,-221,-83,-219,-68,-195,-254,191,-220,-197,-306,191,-299,-225,-224,-222,-84,-300,-226,-223,-228,-227,]),'_BOOL':([0,1,2,3,4,5,6,7,10,11,12,13,14,16,17,18,19,20,21,22,23,25,26,27,29,30,33,34,36,38,39,40,42,43,44,45,46,47,48,49,50,52,53,54,55,56,58,59,60,61,62,63,64,65,66,67,68,71,75,76,80,81,82,85,86,87,89,90,91,93,94,95,96,97,98,99,100,101,105,109,111,118,119,120,121,122,123,124,129,142,147,172,174,177,180,181,182,184,185,186,187,188,189,190,191,192,198,200,211,214,223,229,231,233,239,240,241,267,269,275,278,279,280,284,285,286,289,291,298,300,301,302,303,305,308,312,313,314,315,316,317,318,328,332,340,341,342,350,422,424,425,427,428,432,435,437,438,439,442,443,446,448,449,453,454,460,496,497,500,505,506,510,511,512,536,554,555,557,558,575,576,578,579,],[34,-337,-113,-128,34,-124,-110,-106,-104,-107,-125,-105,-64,-60,-67,-99,-66,-109,34,-120,-115,-65,-102,-126,-131,-108,-238,-111,-122,-63,-129,34,-29,-121,-116,-62,-112,-70,-52,-123,-117,-337,-337,-119,-337,-114,-130,34,-118,-71,-103,-337,-9,-131,-91,-10,-96,-98,34,-131,-95,-101,-97,34,-53,-126,34,-88,34,34,-93,34,-147,-335,-146,34,-167,-166,-182,-100,-126,34,-87,-90,-94,-92,-61,-72,34,-144,-142,34,34,34,-73,34,-89,34,34,34,-149,-159,-160,-156,-336,34,-183,-30,34,34,-74,34,34,34,34,-174,-175,34,-143,-140,34,-141,-145,-76,-79,-82,-75,-77,34,-81,-215,-214,-80,-216,-78,-127,34,-153,34,-151,-148,-157,-168,-69,-36,-35,34,34,34,-234,-233,34,-231,-217,-230,-81,-84,-218,-152,-150,-158,-170,-169,-31,-34,34,-229,-232,-221,-83,-219,-68,-33,-32,-220,-225,-224,-222,-84,-226,-223,-228,-227,]),'LE':([132,133,134,136,138,139,140,141,143,144,145,148,152,153,154,156,158,160,161,162,163,164,166,167,168,169,176,191,224,230,232,234,235,236,237,238,261,263,268,273,310,383,384,385,386,387,388,389,390,391,392,393,394,395,396,397,398,400,401,402,403,404,405,407,411,478,480,482,483,487,547,551,565,],[-317,-321,-318,-303,-324,-330,-313,-319,-301,-274,-314,-327,-325,-304,-322,-302,-255,-315,-289,246,-328,-316,-329,-320,-276,-323,-312,-336,-274,-326,-280,-277,-334,-332,-331,-333,-298,-297,-279,-278,-312,-261,246,-262,-260,-264,246,-263,-259,-266,246,-257,-256,-265,246,246,246,246,-258,-296,-295,-294,-293,-292,-305,-281,-282,-290,-291,-275,-306,-299,-300,]),'SEMI':([0,1,2,3,4,5,6,7,10,11,12,13,14,16,17,18,19,20,22,23,25,26,27,29,30,33,34,36,38,39,40,42,43,44,45,46,47,48,49,50,51,52,53,54,55,56,58,59,60,61,62,63,64,65,66,67,68,70,71,73,74,75,76,77,78,79,80,81,82,83,84,86,87,89,91,94,96,97,98,99,100,101,108,109,110,111,113,114,115,119,120,121,122,123,124,127,132,133,134,136,138,139,140,141,142,143,144,145,147,148,152,153,154,156,158,160,161,162,163,164,166,167,168,169,176,178,179,180,181,184,185,186,187,188,189,190,191,192,200,216,217,223,224,225,226,228,230,232,234,235,236,237,238,240,241,261,263,268,269,272,273,275,279,280,284,285,286,288,289,290,291,293,294,298,300,301,302,303,304,305,306,307,308,310,312,313,314,315,316,317,318,320,321,322,323,324,328,330,331,332,340,341,355,356,383,384,385,386,387,388,389,390,391,392,393,394,395,396,397,398,400,401,402,403,404,405,407,411,423,424,425,426,427,428,429,432,433,435,437,438,439,440,442,443,444,446,448,449,453,454,464,465,468,469,476,478,480,482,483,486,487,496,497,498,499,500,502,505,506,508,509,510,511,512,518,519,523,524,533,534,535,536,537,539,547,551,552,554,555,557,558,565,568,569,574,575,576,577,578,579,],[19,-337,-113,-128,-337,-124,-110,-106,-104,-107,-125,-105,-64,-60,-67,-99,-66,-109,-120,-115,-65,-102,-126,-131,-108,-238,-111,-122,-63,-129,-337,-29,-121,-116,-62,-112,-70,-52,-123,-117,119,-337,-337,-119,-337,-114,-130,19,-118,-71,-103,-337,-9,-131,-91,-10,-96,-20,-98,-54,-179,-178,-131,-37,-134,-85,-95,-101,-97,-19,-132,-53,-126,-337,-337,-93,-147,-335,-146,188,-167,-166,-136,-100,-138,-126,-16,-86,-15,-87,-90,-94,-92,-61,-72,-55,-317,-321,-318,-303,-324,-330,-313,-319,-144,-301,-274,-314,-142,-327,-325,-304,-322,-302,-255,-315,-289,-253,-328,-316,-329,-320,-276,-323,-312,-252,-178,-73,-337,188,188,188,-149,-159,-160,-156,-336,-337,-30,-38,-133,-74,-274,-239,-135,-193,-326,-280,-277,-334,-332,-331,-333,-174,-175,-298,-297,-279,-143,-235,-278,-140,-141,-145,-76,-79,-82,424,-75,425,-77,428,-14,-337,-81,-215,-214,-80,435,-216,-13,-337,-78,-312,-127,188,-153,188,-151,-148,-157,-26,-25,446,-161,-163,-168,-139,-137,-69,-36,-35,-44,-43,-261,-273,-262,-260,-264,-268,-263,-259,-266,-271,-257,-256,-265,-272,-267,-269,-270,-258,-296,-295,-294,-293,-292,-305,496,-234,-233,497,-337,-231,-337,-217,-13,-230,-81,-84,-218,-337,-152,-150,-165,-158,-170,-169,-31,-34,-39,-42,-240,-194,-237,-281,-282,-290,-291,-236,-275,-229,-232,533,-337,-221,-337,-83,-219,-162,-164,-68,-33,-32,-41,-40,-195,-254,-337,553,-337,-220,-337,-337,-306,-299,566,-225,-224,-222,-84,-300,575,-337,-337,-226,-223,-337,-228,-227,]),'_THREAD_LOCAL':([0,1,2,3,4,5,6,7,10,11,12,13,14,16,17,18,19,20,21,22,23,25,26,27,29,30,33,34,36,38,39,42,43,44,45,46,47,48,49,50,52,53,54,55,56,58,59,60,61,62,63,65,68,71,75,76,80,81,82,86,87,89,90,93,96,97,98,100,101,109,111,118,119,123,124,129,180,181,182,187,191,200,211,223,240,241,278,284,285,286,289,291,298,300,301,302,303,305,308,312,314,316,317,328,332,340,341,342,350,422,424,425,427,428,432,435,437,438,439,442,443,448,449,453,454,460,496,497,500,505,506,510,511,512,536,554,555,557,558,575,576,578,579,],[11,11,-113,-128,11,-124,-110,-106,-104,-107,-125,-105,-64,-60,-67,-99,-66,-109,11,-120,-115,-65,-102,11,-131,-108,-238,-111,-122,-63,-129,-29,-121,-116,-62,-112,-70,-52,-123,-117,11,11,-119,11,-114,-130,11,-118,-71,-103,11,-131,-96,-98,11,-131,-95,-101,-97,-53,11,11,-88,11,-147,-335,-146,-167,-166,-100,-126,11,-87,-61,-72,11,-73,11,-89,-149,-336,-30,11,-74,-174,-175,11,-76,-79,-82,-75,-77,11,-81,-215,-214,-80,-216,-78,-127,-153,-151,-148,-168,-69,-36,-35,11,11,11,-234,-233,11,-231,-217,-230,-81,-84,-218,-152,-150,-170,-169,-31,-34,11,-229,-232,-221,-83,-219,-68,-33,-32,-220,-225,-224,-222,-84,-226,-223,-228,-227,]),'LT':([132,133,134,136,138,139,140,141,143,144,145,148,152,153,154,156,158,160,161,162,163,164,166,167,168,169,176,191,224,230,232,234,235,236,237,238,261,263,268,273,310,383,384,385,386,387,388,389,390,391,392,393,394,395,396,397,398,400,401,402,403,404,405,407,411,478,480,482,483,487,547,551,565,],[-317,-321,-318,-303,-324,-330,-313,-319,-301,-274,-314,-327,-325,-304,-322,-302,-255,-315,-289,248,-328,-316,-329,-320,-276,-323,-312,-336,-274,-326,-280,-277,-334,-332,-331,-333,-298,-297,-279,-278,-312,-261,248,-262,-260,-264,248,-263,-259,-266,248,-257,-256,-265,248,248,248,248,-258,-296,-295,-294,-293,-292,-305,-281,-282,-290,-291,-275,-306,-299,-300,]),'COMMA':([2,3,5,6,7,10,11,12,13,18,20,22,23,26,27,30,33,34,35,36,39,42,43,44,46,48,49,50,54,56,58,60,62,68,70,71,73,74,75,76,77,78,80,81,82,84,86,96,98,100,101,103,104,105,106,108,109,110,111,113,127,132,133,134,136,138,139,140,141,142,143,144,145,147,148,152,153,154,156,158,160,161,162,163,164,166,167,168,169,176,177,178,179,187,191,195,196,197,198,199,200,203,210,211,212,213,215,216,217,224,225,226,228,230,232,234,235,236,237,238,240,241,261,263,268,269,270,272,273,274,275,276,277,279,280,281,283,294,310,312,314,316,317,320,323,324,325,326,327,328,330,331,340,341,343,344,345,346,347,348,355,356,374,383,384,385,386,387,388,389,390,391,392,393,394,395,396,397,398,399,400,401,402,403,404,405,406,407,408,409,410,411,414,426,442,443,444,448,449,450,451,453,454,458,461,463,464,465,468,469,473,476,478,480,482,483,486,487,489,490,492,501,503,507,508,509,511,512,518,519,523,524,525,528,529,530,531,532,544,545,546,547,551,556,559,560,564,565,570,571,],[-113,-128,-124,-110,-106,-104,-107,-125,-105,-99,-109,-120,-115,-102,-126,-108,-238,-111,-337,-122,-129,-29,-121,-116,-112,-52,-123,-117,-119,-114,-130,-118,-103,-96,126,-98,-54,-179,-178,-131,-37,-134,-95,-101,-97,-132,-53,-147,-146,-167,-166,-28,-180,-182,-27,-136,-100,-138,-126,202,-55,-317,-321,-318,-303,-324,-330,-313,-319,-144,-301,-274,-314,-142,-327,-325,-304,-322,-302,-255,-315,-289,-253,-328,-316,-329,-320,-276,-323,-312,-337,-252,-178,-149,-336,-171,327,-176,-183,-181,-30,333,-186,-337,349,350,-191,-38,-133,-274,-239,-135,-193,-326,-280,-277,-334,-332,-331,-333,-174,-175,-298,-297,-279,-143,412,-235,-278,-203,-140,-204,-1,-141,-145,-2,-206,412,-312,-127,-153,-151,-148,445,-161,-163,327,327,-172,-168,-139,-137,-36,-35,-190,-204,-56,-188,-45,-189,-44,-43,472,-261,-273,-262,-260,-264,-268,-263,-259,-266,-271,-257,-256,-265,-272,-267,-269,412,-270,-258,-296,-295,-294,-293,412,-292,-310,484,485,-305,-205,412,-152,-150,-165,-170,-169,-173,-177,-31,-34,-57,-192,-187,-39,-42,-240,-194,-196,-237,-281,-282,-290,-291,-236,-275,-213,-207,-211,412,412,412,-162,-164,-33,-32,-41,-40,-195,-254,-311,550,-209,-208,-210,-212,-51,-50,-197,-306,-299,412,-46,-49,412,-300,-48,-47,]),'U16CHAR_CONST':([3,39,58,61,76,85,97,103,105,106,116,117,119,124,128,131,135,137,146,149,150,151,165,171,173,174,175,181,191,198,201,204,205,206,218,219,220,227,229,231,233,242,243,244,245,246,247,248,249,250,251,252,253,254,255,256,257,258,259,260,265,266,282,284,285,286,289,290,291,297,298,300,301,302,303,305,307,308,319,329,332,336,338,339,352,353,354,357,358,359,360,361,362,363,364,365,366,367,368,369,373,375,377,412,413,419,421,424,425,427,428,429,430,432,434,435,437,438,439,440,441,447,459,472,475,477,481,484,488,494,496,497,499,500,502,505,506,510,513,514,515,521,522,533,535,536,537,538,539,541,542,543,549,550,553,554,555,557,558,566,569,574,575,576,577,578,579,],[-128,-129,-130,-71,-131,169,-335,-28,-182,-27,169,-337,-87,-72,-337,169,-286,-285,169,169,-283,-287,-288,169,-284,169,169,169,-336,-183,169,169,-28,-337,169,-28,-337,-337,169,169,169,169,169,169,169,169,169,169,169,169,169,169,169,169,169,169,169,169,169,169,169,169,-337,-76,-79,-82,-75,169,-77,169,169,-81,-215,-214,-80,-216,169,-78,169,169,-69,-284,169,169,-284,169,169,-244,-247,-245,-241,-242,-246,-248,169,-250,-251,-243,-249,-12,169,169,-11,169,169,169,169,-234,-233,169,-231,169,169,-217,169,-230,169,-84,-218,169,169,169,-337,-337,-198,169,169,169,-337,-284,-229,-232,169,-221,169,-83,-219,-68,169,-28,-337,169,-11,169,169,-220,169,169,169,-284,169,169,169,-337,169,-225,-224,-222,-84,169,169,169,-226,-223,169,-228,-227,]),'OFFSETOF':([3,39,58,61,76,85,97,103,105,106,116,117,119,124,128,131,135,137,146,149,150,151,165,171,173,174,175,181,191,198,201,204,205,206,218,219,220,227,229,231,233,242,243,244,245,246,247,248,249,250,251,252,253,254,255,256,257,258,259,260,265,266,282,284,285,286,289,290,291,297,298,300,301,302,303,305,307,308,319,329,332,336,338,339,352,353,354,357,358,359,360,361,362,363,364,365,366,367,368,369,373,375,377,412,413,419,421,424,425,427,428,429,430,432,434,435,437,438,439,440,441,447,459,472,475,477,481,484,488,494,496,497,499,500,502,505,506,510,513,514,515,521,522,533,535,536,537,538,539,541,542,543,549,550,553,554,555,557,558,566,569,574,575,576,577,578,579,],[-128,-129,-130,-71,-131,170,-335,-28,-182,-27,170,-337,-87,-72,-337,170,-286,-285,170,170,-283,-287,-288,170,-284,170,170,170,-336,-183,170,170,-28,-337,170,-28,-337,-337,170,170,170,170,170,170,170,170,170,170,170,170,170,170,170,170,170,170,170,170,170,170,170,170,-337,-76,-79,-82,-75,170,-77,170,170,-81,-215,-214,-80,-216,170,-78,170,170,-69,-284,170,170,-284,170,170,-244,-247,-245,-241,-242,-246,-248,170,-250,-251,-243,-249,-12,170,170,-11,170,170,170,170,-234,-233,170,-231,170,170,-217,170,-230,170,-84,-218,170,170,170,-337,-337,-198,170,170,170,-337,-284,-229,-232,170,-221,170,-83,-219,-68,170,-28,-337,170,-11,170,170,-220,170,170,170,-284,170,170,170,-337,170,-225,-224,-222,-84,170,170,170,-226,-223,170,-228,-227,]),'_ATOMIC':([0,1,2,3,4,5,6,7,10,11,12,13,14,16,17,18,19,20,21,22,23,25,26,27,29,30,33,34,35,36,38,39,40,42,43,44,45,46,47,48,49,50,52,53,54,55,56,58,59,60,61,62,63,64,65,66,67,68,71,75,76,80,81,82,85,86,87,89,90,91,93,94,95,96,97,98,99,100,101,103,105,109,111,117,118,119,120,121,122,123,124,128,129,142,147,172,174,177,180,181,182,184,185,186,187,188,189,190,191,192,198,200,205,206,211,214,219,220,223,229,231,233,239,240,241,267,269,275,278,279,280,282,284,285,286,289,291,298,300,301,302,303,305,308,312,313,314,315,316,317,318,328,332,340,341,342,350,422,424,425,427,428,432,435,437,438,439,442,443,446,448,449,453,454,459,460,496,497,500,505,506,510,511,512,514,515,536,554,555,557,558,575,576,578,579,],[29,65,-113,-128,76,-124,-110,-106,-104,-107,-125,-105,-64,-60,-67,-99,-66,-109,65,-120,-115,-65,-102,65,-131,-108,-238,-111,76,-122,-63,-129,112,-29,-121,-116,-62,-112,-70,-52,-123,-117,65,65,-119,65,-114,-130,29,-118,-71,-103,65,-9,-131,-91,-10,-96,-98,65,-131,-95,-101,-97,29,-53,65,76,-88,112,65,-93,29,-147,-335,-146,29,-167,-166,76,-182,-100,-126,76,29,-87,-90,-94,-92,-61,-72,76,29,-144,-142,65,29,76,-73,65,-89,29,29,29,-149,-159,-160,-156,-336,76,-183,-30,76,76,76,112,76,76,-74,29,29,29,29,-174,-175,29,-143,-140,29,-141,-145,76,-76,-79,-82,-75,-77,65,-81,-215,-214,-80,-216,-78,-127,29,-153,29,-151,-148,-157,-168,-69,-36,-35,29,29,29,-234,-233,65,-231,-217,-230,-81,-84,-218,-152,-150,-158,-170,-169,-31,-34,76,29,-229,-232,-221,-83,-219,-68,-33,-32,76,76,-220,-225,-224,-222,-84,-226,-223,-228,-227,]),'TYPEDEF':([0,1,2,3,4,5,6,7,10,11,12,13,14,16,17,18,19,20,21,22,23,25,26,27,29,30,33,34,36,38,39,42,43,44,45,46,47,48,49,50,52,53,54,55,56,58,59,60,61,62,63,65,68,71,75,76,80,81,82,86,87,89,90,93,96,97,98,100,101,109,111,118,119,123,124,129,180,181,182,187,191,200,211,223,240,241,278,284,285,286,289,291,298,300,301,302,303,305,308,312,314,316,317,328,332,340,341,342,350,422,424,425,427,428,432,435,437,438,439,442,443,448,449,453,454,460,496,497,500,505,506,510,511,512,536,554,555,557,558,575,576,578,579,],[7,7,-113,-128,7,-124,-110,-106,-104,-107,-125,-105,-64,-60,-67,-99,-66,-109,7,-120,-115,-65,-102,7,-131,-108,-238,-111,-122,-63,-129,-29,-121,-116,-62,-112,-70,-52,-123,-117,7,7,-119,7,-114,-130,7,-118,-71,-103,7,-131,-96,-98,7,-131,-95,-101,-97,-53,7,7,-88,7,-147,-335,-146,-167,-166,-100,-126,7,-87,-61,-72,7,-73,7,-89,-149,-336,-30,7,-74,-174,-175,7,-76,-79,-82,-75,-77,7,-81,-215,-214,-80,-216,-78,-127,-153,-151,-148,-168,-69,-36,-35,7,7,7,-234,-233,7,-231,-217,-230,-81,-84,-218,-152,-150,-170,-169,-31,-34,7,-229,-232,-221,-83,-219,-68,-33,-32,-220,-225,-224,-222,-84,-226,-223,-228,-227,]),'XOR':([132,133,134,136,138,139,140,141,143,144,145,148,152,153,154,156,158,160,161,162,163,164,166,167,168,169,176,191,224,230,232,234,235,236,237,238,261,263,268,273,310,383,384,385,386,387,388,389,390,391,392,393,394,395,396,397,398,400,401,402,403,404,405,407,411,478,480,482,483,487,547,551,565,],[-317,-321,-318,-303,-324,-330,-313,-319,-301,-274,-314,-327,-325,-304,-322,-302,-255,-315,-289,251,-328,-316,-329,-320,-276,-323,-312,-336,-274,-326,-280,-277,-334,-332,-331,-333,-298,-297,-279,-278,-312,-261,251,-262,-260,-264,-268,-263,-259,-266,-271,-257,-256,-265,251,-267,-269,251,-258,-296,-295,-294,-293,-292,-305,-281,-282,-290,-291,-275,-306,-299,-300,]),'AUTO':([0,1,2,3,4,5,6,7,10,11,12,13,14,16,17,18,19,20,21,22,23,25,26,27,29,30,33,34,36,38,39,42,43,44,45,46,47,48,49,50,52,53,54,55,56,58,59,60,61,62,63,65,68,71,75,76,80,81,82,86,87,89,90,93,96,97,98,100,101,109,111,118,119,123,124,129,180,181,182,187,191,200,211,223,240,241,278,284,285,286,289,291,298,300,301,302,303,305,308,312,314,316,317,328,332,340,341,342,350,422,424,425,427,428,432,435,437,438,439,442,443,448,449,453,454,460,496,497,500,505,506,510,511,512,536,554,555,557,558,575,576,578,579,],[26,26,-113,-128,26,-124,-110,-106,-104,-107,-125,-105,-64,-60,-67,-99,-66,-109,26,-120,-115,-65,-102,26,-131,-108,-238,-111,-122,-63,-129,-29,-121,-116,-62,-112,-70,-52,-123,-117,26,26,-119,26,-114,-130,26,-118,-71,-103,26,-131,-96,-98,26,-131,-95,-101,-97,-53,26,26,-88,26,-147,-335,-146,-167,-166,-100,-126,26,-87,-61,-72,26,-73,26,-89,-149,-336,-30,26,-74,-174,-175,26,-76,-79,-82,-75,-77,26,-81,-215,-214,-80,-216,-78,-127,-153,-151,-148,-168,-69,-36,-35,26,26,26,-234,-233,26,-231,-217,-230,-81,-84,-218,-152,-150,-170,-169,-31,-34,26,-229,-232,-221,-83,-219,-68,-33,-32,-220,-225,-224,-222,-84,-226,-223,-228,-227,]),'DIVEQUAL':([132,133,134,136,138,139,140,141,143,144,145,148,152,153,154,156,160,161,163,164,166,167,168,169,176,191,224,230,232,234,235,236,237,238,261,263,268,273,310,402,403,404,405,407,411,478,480,482,483,487,547,551,565,],[-317,-321,-318,-303,-324,-330,-313,-319,-301,-274,-314,-327,-325,-304,-322,-302,-315,-289,-328,-316,-329,-320,-276,-323,-312,-336,357,-326,-280,-277,-334,-332,-331,-333,-298,-297,-279,-278,-312,-296,-295,-294,-293,-292,-305,-281,-282,-290,-291,-275,-306,-299,-300,]),'TIMES':([0,1,2,3,4,5,6,7,10,11,12,13,14,16,17,18,19,20,22,23,25,26,27,29,30,33,34,35,36,37,38,39,40,43,44,45,46,47,49,50,52,53,54,55,56,58,59,60,61,62,63,64,65,66,67,68,69,71,76,80,81,82,85,87,89,91,94,96,97,98,100,101,103,104,105,106,109,111,116,117,119,120,121,122,123,124,126,128,131,132,133,134,135,136,137,138,139,140,141,142,143,144,145,146,147,148,149,150,151,152,153,154,156,158,160,161,162,163,164,165,166,167,168,169,171,173,174,175,176,177,180,181,187,191,192,198,201,202,204,205,206,211,218,219,220,223,224,227,229,230,231,232,233,234,235,236,237,238,240,241,242,243,244,245,246,247,248,249,250,251,252,253,254,255,256,257,258,259,260,261,263,265,266,268,269,273,275,278,279,280,282,284,285,286,289,290,291,297,298,300,301,302,303,305,307,308,310,312,314,316,317,319,328,329,332,336,338,339,342,352,353,354,357,358,359,360,361,362,363,364,365,366,367,368,369,373,375,377,383,384,385,386,387,388,389,390,391,392,393,394,395,396,397,398,400,401,402,403,404,405,407,411,412,413,419,421,424,425,427,428,429,430,432,434,435,437,438,439,440,441,442,443,445,447,448,449,459,472,475,477,478,480,481,482,483,484,487,488,494,496,497,499,500,502,505,506,510,513,514,515,521,522,533,535,536,537,538,539,541,542,543,547,549,550,551,553,554,555,557,558,565,566,569,574,575,576,577,578,579,],[35,-337,-113,-128,35,-124,-110,-106,-104,-107,-125,-105,-64,-60,-67,-99,-66,-109,-120,-115,-65,-102,-126,-131,-108,-238,-111,-337,-122,35,-63,-129,35,-121,-116,-62,-112,-70,-123,-117,-337,-337,-119,-337,-114,-130,35,-118,-71,-103,-337,-9,-131,-91,-10,-96,35,-98,-131,-95,-101,-97,173,-126,35,35,-93,-147,-335,-146,-167,-166,-28,35,-182,-27,-100,-126,173,-337,-87,-90,-94,-92,-61,-72,35,-337,173,-317,-321,-318,-286,-303,-285,-324,-330,-313,-319,-144,-301,-274,-314,173,-142,-327,173,-283,-287,-325,-304,-322,-302,-255,-315,-289,253,-328,-316,-288,-329,-320,-276,-323,173,-284,173,173,-312,35,-73,173,-149,-336,35,-183,173,35,336,-28,-337,35,352,-28,-337,-74,-274,-337,173,-326,173,-280,173,-277,-334,-332,-331,-333,-174,-175,173,173,173,173,173,173,173,173,173,173,173,173,173,173,173,173,173,173,173,-298,-297,173,173,-279,-143,-278,-140,35,-141,-145,420,-76,-79,-82,-75,173,-77,173,173,-81,-215,-214,-80,-216,173,-78,-312,-127,-153,-151,-148,173,-168,173,-69,-284,173,173,35,-284,173,173,-244,-247,-245,-241,-242,-246,-248,173,-250,-251,-243,-249,-12,173,173,-11,253,253,253,253,253,253,253,253,253,253,-257,-256,253,253,253,253,253,-258,-296,-295,-294,-293,-292,-305,173,173,173,494,-234,-233,173,-231,173,173,-217,173,-230,173,-84,-218,173,173,-152,-150,35,173,-170,-169,-337,-337,-198,173,-281,-282,173,-290,-291,173,-275,-337,-284,-229,-232,173,-221,173,-83,-219,-68,541,-28,-337,173,-11,173,173,-220,173,173,173,-284,173,173,-306,173,-337,-299,173,-225,-224,-222,-84,-300,173,173,173,-226,-223,173,-228,-227,]),'LPAREN':([0,1,2,3,4,5,6,7,8,10,11,12,13,14,15,16,17,18,19,20,22,23,25,26,27,29,30,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,48,49,50,52,53,54,55,56,58,59,60,61,62,63,64,65,66,67,68,69,71,72,73,76,77,80,81,82,85,86,87,89,91,94,96,97,98,100,101,103,104,105,106,109,111,112,116,117,119,120,121,122,123,124,126,127,128,131,132,133,134,135,136,137,138,139,140,141,142,143,145,146,147,148,149,150,151,152,153,154,155,156,160,161,163,164,165,166,167,168,169,170,171,173,174,175,176,177,180,181,187,191,192,198,199,200,201,202,204,205,206,211,216,218,219,220,223,227,229,230,231,233,235,236,237,238,240,241,242,243,244,245,246,247,248,249,250,251,252,253,254,255,256,257,258,259,260,261,263,265,266,269,275,276,278,279,280,282,283,284,285,286,289,290,291,292,296,297,298,299,300,301,302,303,305,307,308,310,311,312,314,316,317,319,328,329,332,336,338,339,340,341,342,344,345,347,352,353,354,355,356,357,358,359,360,361,362,363,364,365,366,367,368,369,373,375,377,402,403,404,405,407,411,412,413,414,419,421,424,425,427,428,429,430,432,434,435,437,438,439,440,441,442,443,445,447,448,449,453,454,457,458,459,464,465,472,475,477,481,482,483,484,488,489,490,492,494,496,497,499,500,502,504,505,506,510,511,512,513,514,515,518,519,521,522,529,530,531,532,533,535,536,537,538,539,541,542,543,544,545,547,549,550,551,553,554,555,557,558,559,560,565,566,569,570,571,574,575,576,577,578,579,],[37,-337,-113,-128,69,-124,-110,-106,85,-104,-107,-125,-105,-64,37,-60,-67,-99,-66,-109,-120,-115,-65,-102,-126,95,-108,-238,-111,-337,-122,37,-63,-129,37,116,-29,-121,-116,-62,-112,-70,118,-123,-117,-337,-337,-119,-337,-114,-130,37,-118,-71,-103,-337,-9,95,-91,-10,-96,69,-98,69,129,-131,-37,-95,-101,-97,174,118,-126,69,37,-93,-147,-335,-146,-167,-166,-28,-180,-182,-27,-100,-126,95,174,-337,-87,-90,-94,-92,-61,-72,69,129,-337,229,-317,-321,-318,-286,-303,-285,-324,-330,-313,-319,-144,-301,-314,231,-142,-327,233,-283,-287,-325,-304,-322,239,-302,-315,-289,-328,-316,-288,-329,-320,266,-323,267,174,-284,229,233,-312,278,-73,229,-149,-336,69,-183,-181,-30,229,69,229,-28,-337,342,-38,229,-28,-337,-74,-337,229,-326,229,229,-334,-332,-331,-333,-174,-175,174,174,174,174,174,174,174,174,174,174,174,174,174,174,174,174,229,174,174,-298,-297,229,229,-143,-140,278,278,-141,-145,-337,422,-76,-79,-82,-75,229,-77,427,430,174,229,434,-81,-215,-214,-80,-216,229,-78,-312,441,-127,-153,-151,-148,174,-168,174,-69,-284,229,229,-36,-35,342,342,460,-45,-284,229,229,-44,-43,-244,-247,-245,-241,-242,-246,-248,229,-250,-251,-243,-249,-12,174,229,-11,-296,-295,-294,-293,-292,-305,229,174,422,229,229,-234,-233,229,-231,229,229,-217,229,-230,229,-84,-218,229,229,-152,-150,69,174,-170,-169,-31,-34,342,460,-337,-39,-42,-337,-198,174,174,-290,-291,229,-337,-213,-207,-211,-284,-229,-232,229,-221,229,538,-83,-219,-68,-33,-32,229,-28,-337,-41,-40,229,-11,-209,-208,-210,-212,229,229,-220,229,229,229,-284,229,229,-51,-50,-306,229,-337,-299,229,-225,-224,-222,-84,-46,-49,-300,229,229,-48,-47,229,-226,-223,229,-228,-227,]),'MINUSMINUS':([3,39,58,61,76,85,97,103,105,106,116,117,119,124,128,131,132,133,134,135,136,137,138,139,140,141,143,145,146,148,149,150,151,152,153,154,156,160,161,163,164,165,166,167,168,169,171,173,174,175,176,181,191,198,201,204,205,206,218,219,220,227,229,230,231,233,235,236,237,238,242,243,244,245,246,247,248,249,250,251,252,253,254,255,256,257,258,259,260,261,263,265,266,282,284,285,286,289,290,291,297,298,300,301,302,303,305,307,308,310,319,329,332,336,338,339,352,353,354,357,358,359,360,361,362,363,364,365,366,367,368,369,373,375,377,402,403,404,405,407,411,412,413,419,421,424,425,427,428,429,430,432,434,435,437,438,439,440,441,447,459,472,475,477,481,482,483,484,488,494,496,497,499,500,502,505,506,510,513,514,515,521,522,533,535,536,537,538,539,541,542,543,547,549,550,551,553,554,555,557,558,565,566,569,574,575,576,577,578,579,],[-128,-129,-130,-71,-131,175,-335,-28,-182,-27,175,-337,-87,-72,-337,175,-317,-321,-318,-286,-303,-285,-324,-330,-313,-319,-301,-314,175,-327,175,-283,-287,-325,-304,-322,-302,-315,-289,-328,-316,-288,-329,-320,261,-323,175,-284,175,175,-312,175,-336,-183,175,175,-28,-337,175,-28,-337,-337,175,-326,175,175,-334,-332,-331,-333,175,175,175,175,175,175,175,175,175,175,175,175,175,175,175,175,175,175,175,-298,-297,175,175,-337,-76,-79,-82,-75,175,-77,175,175,-81,-215,-214,-80,-216,175,-78,-312,175,175,-69,-284,175,175,-284,175,175,-244,-247,-245,-241,-242,-246,-248,175,-250,-251,-243,-249,-12,175,175,-11,-296,-295,-294,-293,-292,-305,175,175,175,175,-234,-233,175,-231,175,175,-217,175,-230,175,-84,-218,175,175,175,-337,-337,-198,175,175,-290,-291,175,-337,-284,-229,-232,175,-221,175,-83,-219,-68,175,-28,-337,175,-11,175,175,-220,175,175,175,-284,175,175,-306,175,-337,-299,175,-225,-224,-222,-84,-300,175,175,175,-226,-223,175,-228,-227,]),'ID':([0,1,2,3,4,5,6,7,10,11,12,13,14,15,16,17,18,19,20,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,43,44,45,46,47,49,50,52,53,54,55,56,58,59,60,61,62,63,64,65,66,67,68,69,71,72,76,80,81,82,85,87,89,91,94,96,97,98,100,101,102,103,104,105,106,109,111,116,117,118,119,120,121,122,123,124,126,128,129,131,135,137,142,146,147,149,150,151,165,171,173,174,175,180,181,187,191,192,193,194,198,199,201,202,204,205,206,211,218,219,220,223,227,229,231,233,240,241,242,243,244,245,246,247,248,249,250,251,252,253,254,255,256,257,258,259,260,262,264,265,266,269,275,279,280,282,284,285,286,287,289,290,291,297,298,300,301,302,303,305,307,308,312,314,316,317,319,327,328,329,332,336,338,339,342,344,349,352,353,354,357,358,359,360,361,362,363,364,365,366,367,368,369,372,373,375,377,412,413,419,421,424,425,427,428,429,430,432,434,435,437,438,439,440,441,442,443,445,447,448,449,457,459,460,472,475,477,481,484,485,488,494,496,497,499,500,502,505,506,510,513,514,515,521,522,533,535,536,537,538,539,541,542,543,548,549,550,553,554,555,557,558,566,569,574,575,576,577,578,579,],[42,-337,-113,-128,42,-124,-110,-106,-104,-107,-125,-105,-64,42,-60,-67,-99,-66,-109,-120,-115,-154,-65,-102,-126,-155,-131,-108,98,101,-238,-111,-337,-122,42,-63,-129,42,-121,-116,-62,-112,-70,-123,-117,-337,-337,-119,-337,-114,-130,42,-118,-71,-103,-337,-9,-131,-91,-10,-96,42,-98,42,-131,-95,-101,-97,176,-126,42,42,-93,-147,-335,-146,-167,-166,197,-28,-180,-182,-27,-100,-126,176,-337,176,-87,-90,-94,-92,-61,-72,42,-337,176,176,-286,-285,-144,176,-142,176,-283,-287,-288,176,-284,176,176,-73,310,-149,-336,42,197,197,-183,-181,176,42,176,-28,-337,42,176,-28,-337,-74,-337,176,176,176,-174,-175,176,176,176,176,176,176,176,176,176,176,176,176,176,176,176,176,176,176,176,403,405,176,176,-143,-140,-141,-145,-337,-76,-79,-82,423,-75,176,-77,176,310,-81,-215,-214,-80,-216,310,-78,-127,-153,-151,-148,176,197,-168,176,-69,-284,176,176,42,42,176,-284,176,176,-244,-247,-245,-241,-242,-246,-248,176,-250,-251,-243,-249,-12,176,176,176,-11,176,176,176,176,-234,-233,176,-231,310,176,-217,176,-230,310,-84,-218,310,176,-152,-150,42,176,-170,-169,42,-337,176,-337,-198,176,176,176,176,-337,-284,-229,-232,176,-221,310,-83,-219,-68,176,-28,-337,176,-11,176,310,-220,310,176,310,-284,176,176,176,176,-337,176,-225,-224,-222,-84,176,310,310,-226,-223,310,-228,-227,]),'IF':([61,97,119,124,181,191,284,285,286,289,291,298,300,301,302,303,305,307,308,332,424,425,428,429,432,435,437,438,439,440,496,497,500,502,505,506,510,535,536,537,539,554,555,557,558,569,574,575,576,577,578,579,],[-71,-335,-87,-72,311,-336,-76,-79,-82,-75,-77,311,-81,-215,-214,-80,-216,311,-78,-69,-234,-233,-231,311,-217,-230,311,-84,-218,311,-229,-232,-221,311,-83,-219,-68,311,-220,311,311,-225,-224,-222,-84,311,311,-226,-223,311,-228,-227,]),'STRING_LITERAL':([3,39,58,61,76,85,97,103,105,106,116,117,119,124,128,131,135,136,137,146,149,150,151,152,165,171,173,174,175,181,191,198,201,204,205,206,218,219,220,227,229,230,231,233,242,243,244,245,246,247,248,249,250,251,252,253,254,255,256,257,258,259,260,265,266,282,284,285,286,289,290,291,297,298,300,301,302,303,305,307,308,319,329,332,333,336,338,339,352,353,354,357,358,359,360,361,362,363,364,365,366,367,368,369,373,375,377,412,413,419,421,424,425,427,428,429,430,432,434,435,437,438,439,440,441,447,452,459,472,475,477,481,484,488,494,496,497,499,500,502,505,506,510,513,514,515,521,522,533,535,536,537,538,539,541,542,543,549,550,553,554,555,557,558,566,569,574,575,576,577,578,579,],[-128,-129,-130,-71,-131,152,-335,-28,-182,-27,152,-337,-87,-72,-337,152,-286,230,-285,152,152,-283,-287,-325,-288,152,-284,152,152,152,-336,-183,152,152,-28,-337,152,-28,-337,-337,152,-326,152,152,152,152,152,152,152,152,152,152,152,152,152,152,152,152,152,152,152,152,152,152,152,-337,-76,-79,-82,-75,152,-77,152,152,-81,-215,-214,-80,-216,152,-78,152,152,-69,152,-284,152,152,-284,152,152,-244,-247,-245,-241,-242,-246,-248,152,-250,-251,-243,-249,-12,152,152,-11,152,152,152,152,-234,-233,152,-231,152,152,-217,152,-230,152,-84,-218,152,152,152,230,-337,-337,-198,152,152,152,-337,-284,-229,-232,152,-221,152,-83,-219,-68,152,-28,-337,152,-11,152,152,-220,152,152,152,-284,152,152,152,-337,152,-225,-224,-222,-84,152,152,152,-226,-223,152,-228,-227,]),'FLOAT':([0,1,2,3,4,5,6,7,10,11,12,13,14,16,17,18,19,20,21,22,23,25,26,27,29,30,33,34,36,38,39,40,42,43,44,45,46,47,48,49,50,52,53,54,55,56,58,59,60,61,62,63,64,65,66,67,68,71,75,76,80,81,82,85,86,87,89,90,91,93,94,95,96,97,98,99,100,101,105,109,111,118,119,120,121,122,123,124,129,142,147,172,174,177,180,181,182,184,185,186,187,188,189,190,191,192,198,200,211,214,223,229,231,233,239,240,241,267,269,275,278,279,280,284,285,286,289,291,298,300,301,302,303,305,308,312,313,314,315,316,317,318,328,332,340,341,342,350,422,424,425,427,428,432,435,437,438,439,442,443,446,448,449,453,454,460,496,497,500,505,506,510,511,512,536,554,555,557,558,575,576,578,579,],[44,-337,-113,-128,44,-124,-110,-106,-104,-107,-125,-105,-64,-60,-67,-99,-66,-109,44,-120,-115,-65,-102,-126,-131,-108,-238,-111,-122,-63,-129,44,-29,-121,-116,-62,-112,-70,-52,-123,-117,-337,-337,-119,-337,-114,-130,44,-118,-71,-103,-337,-9,-131,-91,-10,-96,-98,44,-131,-95,-101,-97,44,-53,-126,44,-88,44,44,-93,44,-147,-335,-146,44,-167,-166,-182,-100,-126,44,-87,-90,-94,-92,-61,-72,44,-144,-142,44,44,44,-73,44,-89,44,44,44,-149,-159,-160,-156,-336,44,-183,-30,44,44,-74,44,44,44,44,-174,-175,44,-143,-140,44,-141,-145,-76,-79,-82,-75,-77,44,-81,-215,-214,-80,-216,-78,-127,44,-153,44,-151,-148,-157,-168,-69,-36,-35,44,44,44,-234,-233,44,-231,-217,-230,-81,-84,-218,-152,-150,-158,-170,-169,-31,-34,44,-229,-232,-221,-83,-219,-68,-33,-32,-220,-225,-224,-222,-84,-226,-223,-228,-227,]),'XOREQUAL':([132,133,134,136,138,139,140,141,143,144,145,148,152,153,154,156,160,161,163,164,166,167,168,169,176,191,224,230,232,234,235,236,237,238,261,263,268,273,310,402,403,404,405,407,411,478,480,482,483,487,547,551,565,],[-317,-321,-318,-303,-324,-330,-313,-319,-301,-274,-314,-327,-325,-304,-322,-302,-315,-289,-328,-316,-329,-320,-276,-323,-312,-336,361,-326,-280,-277,-334,-332,-331,-333,-298,-297,-279,-278,-312,-296,-295,-294,-293,-292,-305,-281,-282,-290,-291,-275,-306,-299,-300,]),'LSHIFTEQUAL':([132,133,134,136,138,139,140,141,143,144,145,148,152,153,154,156,160,161,163,164,166,167,168,169,176,191,224,230,232,234,235,236,237,238,261,263,268,273,310,402,403,404,405,407,411,478,480,482,483,487,547,551,565,],[-317,-321,-318,-303,-324,-330,-313,-319,-301,-274,-314,-327,-325,-304,-322,-302,-315,-289,-328,-316,-329,-320,-276,-323,-312,-336,363,-326,-280,-277,-334,-332,-331,-333,-298,-297,-279,-278,-312,-296,-295,-294,-293,-292,-305,-281,-282,-290,-291,-275,-306,-299,-300,]),'RBRACKET':([3,39,58,76,103,105,106,117,128,132,133,134,136,138,139,140,141,143,144,145,148,152,153,154,156,158,160,161,162,163,164,166,167,168,169,176,178,191,198,204,205,218,219,224,225,230,232,234,235,236,237,238,261,263,268,272,273,282,334,335,336,337,351,352,383,384,385,386,387,388,389,390,391,392,393,394,395,396,397,398,400,401,402,403,404,405,406,407,411,419,420,421,455,456,459,466,467,468,471,476,478,480,482,483,486,487,491,493,494,513,514,524,540,541,547,551,561,562,564,565,],[-128,-129,-130,-131,-28,-182,-27,-337,-337,-317,-321,-318,-303,-324,-330,-313,-319,-301,-274,-314,-327,-325,-304,-322,-302,-255,-315,-289,-253,-328,-316,-329,-320,-276,-323,-312,-252,-336,-183,-337,-28,-337,-28,-274,-239,-326,-280,-277,-334,-332,-331,-333,-298,-297,-279,-235,-278,-337,453,-4,454,-3,464,465,-261,-273,-262,-260,-264,-268,-263,-259,-266,-271,-257,-256,-265,-272,-267,-269,-270,-258,-296,-295,-294,-293,482,-292,-305,-337,492,-337,511,512,-337,518,519,-240,520,-237,-281,-282,-290,-291,-236,-275,529,530,531,-337,-28,-254,559,560,-306,-299,570,571,572,-300,]),} + +_lr_action = {} +for _k, _v in _lr_action_items.items(): + for _x,_y in zip(_v[0],_v[1]): + if not _x in _lr_action: _lr_action[_x] = {} + _lr_action[_x][_k] = _y +del _lr_action_items + +_lr_goto_items = {'expression_statement':([181,298,307,429,437,440,502,535,537,539,569,574,577,],[284,284,284,284,284,284,284,284,284,284,284,284,284,]),'struct_or_union_specifier':([0,21,40,59,75,85,91,93,95,99,118,129,172,174,181,184,185,186,214,229,231,233,239,267,278,298,313,315,342,350,422,427,460,],[5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,]),'init_declarator_list':([4,89,],[70,70,]),'init_declarator_list_opt':([4,89,],[79,79,]),'iteration_statement':([181,298,307,429,437,440,502,535,537,539,569,574,577,],[285,285,285,285,285,285,285,285,285,285,285,285,285,]),'static_assert':([0,59,181,298,307,429,437,440,502,535,537,539,569,574,577,],[17,17,286,286,286,286,286,286,286,286,286,286,286,286,286,]),'unified_string_literal':([85,116,131,146,149,171,174,175,181,201,204,218,229,231,233,242,243,244,245,246,247,248,249,250,251,252,253,254,255,256,257,258,259,260,265,266,290,297,298,307,319,329,333,338,339,353,354,364,373,375,412,413,419,421,427,429,430,434,437,440,441,447,477,481,484,499,502,513,521,533,535,537,538,539,542,543,549,553,566,569,574,577,],[136,136,136,136,136,136,136,136,136,136,136,136,136,136,136,136,136,136,136,136,136,136,136,136,136,136,136,136,136,136,136,136,136,136,136,136,136,136,136,136,136,136,452,136,136,136,136,136,136,136,136,136,136,136,136,136,136,136,136,136,136,136,136,136,136,136,136,136,136,136,136,136,136,136,136,136,136,136,136,136,136,136,]),'assignment_expression_opt':([204,218,419,421,513,],[334,351,491,493,540,]),'brace_open':([31,32,92,96,98,100,101,130,131,181,201,229,298,307,375,413,429,437,440,477,478,479,502,521,535,537,539,569,574,577,],[99,102,181,184,185,193,194,181,227,181,227,181,181,181,227,488,181,181,181,488,488,488,181,227,181,181,181,181,181,181,]),'enumerator':([102,193,194,327,],[195,195,195,450,]),'typeid_noparen_declarator':([211,],[348,]),'type_qualifier_list_opt':([35,117,128,206,220,282,459,515,],[104,204,218,339,354,419,513,543,]),'declaration_specifiers_no_type_opt':([1,27,52,53,55,63,87,],[66,94,120,121,122,94,94,]),'expression_opt':([181,298,307,427,429,437,440,499,502,533,535,537,539,553,566,569,574,577,],[288,288,288,498,288,288,288,534,288,552,288,288,288,567,573,288,288,288,]),'designation':([227,472,488,550,],[369,369,369,369,]),'parameter_list':([118,129,278,342,422,460,],[213,213,213,213,213,213,]),'alignment_specifier':([0,1,4,21,27,52,53,55,59,63,75,85,87,89,93,95,99,118,129,174,177,181,184,185,186,192,211,229,231,233,239,267,278,298,313,315,342,350,422,427,460,],[53,53,81,53,53,53,53,53,53,53,53,142,53,81,53,142,142,53,53,142,280,53,142,142,142,280,81,142,142,142,142,142,53,53,142,142,53,53,53,53,53,]),'labeled_statement':([181,298,307,429,437,440,502,535,537,539,569,574,577,],[289,289,289,289,289,289,289,289,289,289,289,289,289,]),'abstract_declarator':([177,211,278,342,],[281,281,418,418,]),'translation_unit':([0,],[59,]),'init_declarator':([4,89,126,202,],[84,84,217,331,]),'direct_abstract_declarator':([177,211,276,278,342,344,457,],[283,283,414,283,283,414,414,]),'designator_list':([227,472,488,550,],[376,376,376,376,]),'identifier':([85,116,118,129,131,146,149,171,174,175,181,201,204,218,229,231,233,242,243,244,245,246,247,248,249,250,251,252,253,254,255,256,257,258,259,260,265,266,290,297,298,307,319,329,338,339,349,353,354,364,372,373,375,412,413,419,421,427,429,430,434,437,440,441,447,460,477,481,484,485,499,502,513,521,533,535,537,538,539,542,543,548,549,553,566,569,574,577,],[143,143,215,215,143,143,143,143,143,143,143,143,143,143,143,143,143,143,143,143,143,143,143,143,143,143,143,143,143,143,143,143,143,143,143,143,143,143,143,143,143,143,143,143,143,143,461,143,143,143,470,143,143,143,143,143,143,143,143,143,143,143,143,143,143,215,143,143,143,527,143,143,143,143,143,143,143,143,143,143,143,563,143,143,143,143,143,143,]),'offsetof_member_designator':([485,],[526,]),'unary_expression':([85,116,131,146,149,171,174,175,181,201,204,218,229,231,233,242,243,244,245,246,247,248,249,250,251,252,253,254,255,256,257,258,259,260,265,266,290,297,298,307,319,329,338,339,353,354,364,373,375,412,413,419,421,427,429,430,434,437,440,441,447,477,481,484,499,502,513,521,533,535,537,538,539,542,543,549,553,566,569,574,577,],[144,144,224,232,234,144,224,273,224,224,224,224,224,224,224,144,144,144,144,144,144,144,144,144,144,144,144,144,144,144,144,224,144,144,224,224,224,144,224,224,144,144,224,224,224,224,224,144,224,224,144,224,224,224,224,224,224,224,224,224,144,144,144,224,224,224,224,224,224,224,224,224,224,224,224,224,224,224,224,224,224,]),'abstract_declarator_opt':([177,211,],[274,343,]),'initializer':([131,201,375,521,],[226,330,473,546,]),'direct_id_declarator':([0,4,15,37,40,59,69,72,89,91,126,192,202,211,342,344,445,457,],[48,48,86,48,48,48,48,86,48,48,48,48,48,48,48,86,48,86,]),'struct_declaration_list':([99,184,185,],[186,313,315,]),'pp_directive':([0,59,],[14,14,]),'declaration_list':([21,75,],[93,93,]),'id_init_declarator':([40,91,],[108,108,]),'type_specifier':([0,21,40,59,75,85,91,93,95,99,118,129,172,174,181,184,185,186,214,229,231,233,239,267,278,298,313,315,342,350,422,427,460,],[18,18,109,18,18,147,109,18,147,147,18,18,269,147,18,147,147,147,109,147,147,147,147,147,18,18,147,147,18,18,18,18,18,]),'compound_statement':([92,130,181,229,298,307,429,437,440,502,535,537,539,569,574,577,],[180,223,291,378,291,291,291,291,291,291,291,291,291,291,291,291,]),'pointer':([0,4,37,40,59,69,89,91,104,126,177,192,202,211,278,342,445,],[15,72,15,15,15,72,72,15,199,72,276,72,72,344,276,457,72,]),'typeid_declarator':([4,69,89,126,192,202,445,],[74,125,74,74,74,74,74,]),'id_init_declarator_list':([40,91,],[113,113,]),'declarator':([4,89,126,192,202,445,],[78,78,78,324,78,324,]),'argument_expression_list':([266,],[409,]),'struct_declarator_list_opt':([192,],[322,]),'block_item_list':([181,],[298,]),'parameter_type_list_opt':([278,342,422,],[417,417,495,]),'struct_declarator':([192,445,],[323,508,]),'type_qualifier':([0,1,4,21,27,35,52,53,55,59,63,75,85,87,89,93,95,99,103,117,118,128,129,172,174,177,181,184,185,186,192,205,206,211,219,220,229,231,233,239,267,278,282,298,313,315,342,350,422,427,459,460,514,515,],[52,52,80,52,52,105,52,52,52,52,52,52,105,52,80,52,105,105,198,105,52,105,52,198,105,279,52,105,105,105,279,198,105,80,198,105,105,105,105,105,105,52,105,52,105,105,52,52,52,52,105,52,198,105,]),'assignment_operator':([224,],[364,]),'expression':([174,181,229,231,233,258,265,290,298,307,427,429,430,434,437,440,441,499,502,533,535,537,538,539,549,553,566,569,574,577,],[270,294,270,270,270,399,406,426,294,294,294,294,501,503,294,294,507,294,294,294,294,294,556,294,564,294,294,294,294,294,]),'storage_class_specifier':([0,1,4,21,27,52,53,55,59,63,75,87,89,93,118,129,181,211,278,298,342,350,422,427,460,],[1,1,68,1,1,1,1,1,1,1,1,1,68,1,1,1,1,68,1,1,1,1,1,1,1,]),'unified_wstring_literal':([85,116,131,146,149,171,174,175,181,201,204,218,229,231,233,242,243,244,245,246,247,248,249,250,251,252,253,254,255,256,257,258,259,260,265,266,290,297,298,307,319,329,338,339,353,354,364,373,375,412,413,419,421,427,429,430,434,437,440,441,447,477,481,484,499,502,513,521,533,535,537,538,539,542,543,549,553,566,569,574,577,],[153,153,153,153,153,153,153,153,153,153,153,153,153,153,153,153,153,153,153,153,153,153,153,153,153,153,153,153,153,153,153,153,153,153,153,153,153,153,153,153,153,153,153,153,153,153,153,153,153,153,153,153,153,153,153,153,153,153,153,153,153,153,153,153,153,153,153,153,153,153,153,153,153,153,153,153,153,153,153,153,153,]),'translation_unit_or_empty':([0,],[9,]),'initializer_list_opt':([227,],[370,]),'brace_close':([99,184,185,186,196,309,313,315,325,326,370,472,528,550,],[187,314,316,317,328,439,442,443,448,449,469,523,551,565,]),'direct_typeid_declarator':([4,69,72,89,126,192,202,445,],[73,73,127,73,73,73,73,73,]),'external_declaration':([0,59,],[16,123,]),'pragmacomp_or_statement':([307,429,440,502,535,537,539,569,574,577,],[436,500,506,536,554,555,557,576,578,579,]),'type_name':([85,95,174,229,231,233,239,267,],[157,183,271,379,380,381,382,410,]),'typedef_name':([0,21,40,59,75,85,91,93,95,99,118,129,172,174,181,184,185,186,214,229,231,233,239,267,278,298,313,315,342,350,422,427,460,],[36,36,36,36,36,36,36,36,36,36,36,36,36,36,36,36,36,36,36,36,36,36,36,36,36,36,36,36,36,36,36,36,36,]),'pppragma_directive':([0,59,99,181,184,185,186,298,307,313,315,429,437,440,502,535,537,539,569,574,577,],[25,25,189,300,189,189,189,300,437,189,189,437,300,437,437,437,437,437,437,437,437,]),'statement':([181,298,307,429,437,440,502,535,537,539,569,574,577,],[301,301,438,438,505,438,438,438,438,558,438,438,438,]),'cast_expression':([85,116,131,171,174,181,201,204,218,229,231,233,242,243,244,245,246,247,248,249,250,251,252,253,254,255,256,257,258,259,260,265,266,290,297,298,307,319,329,338,339,353,354,364,373,375,412,413,419,421,427,429,430,434,437,440,441,447,477,481,484,499,502,513,521,533,535,537,538,539,542,543,549,553,566,569,574,577,],[158,158,158,268,158,158,158,158,158,158,158,158,158,158,158,158,158,158,158,158,158,158,158,158,158,158,158,158,158,158,158,158,158,158,158,158,158,158,158,158,158,158,158,158,158,158,158,487,158,158,158,158,158,158,158,158,158,158,487,158,158,158,158,158,158,158,158,158,158,158,158,158,158,158,158,158,158,158,]),'atomic_specifier':([0,1,21,27,40,52,53,55,59,63,75,85,87,91,93,95,99,118,129,172,174,181,184,185,186,214,229,231,233,239,267,278,298,313,315,342,350,422,427,460,],[27,63,87,63,111,63,63,63,27,63,87,111,63,111,87,111,111,27,27,111,111,87,111,111,111,111,111,111,111,111,111,27,87,111,111,27,27,27,87,27,]),'struct_declarator_list':([192,],[320,]),'empty':([0,1,4,21,27,35,40,52,53,55,63,75,87,89,91,117,118,128,129,177,181,192,204,206,211,218,220,227,278,282,298,307,342,419,421,422,427,429,437,440,459,460,472,488,499,502,513,515,533,535,537,539,550,553,566,569,574,577,],[57,64,83,88,64,106,115,64,64,64,64,88,64,83,115,106,208,106,208,277,306,321,337,106,277,337,106,377,415,106,433,433,415,337,337,415,433,433,433,433,106,208,522,522,433,433,337,106,433,433,433,433,522,433,433,433,433,433,]),'parameter_declaration':([118,129,278,342,350,422,460,],[210,210,210,210,463,210,210,]),'primary_expression':([85,116,131,146,149,171,174,175,181,201,204,218,229,231,233,242,243,244,245,246,247,248,249,250,251,252,253,254,255,256,257,258,259,260,265,266,290,297,298,307,319,329,338,339,353,354,364,373,375,412,413,419,421,427,429,430,434,437,440,441,447,477,481,484,499,502,513,521,533,535,537,538,539,542,543,549,553,566,569,574,577,],[161,161,161,161,161,161,161,161,161,161,161,161,161,161,161,161,161,161,161,161,161,161,161,161,161,161,161,161,161,161,161,161,161,161,161,161,161,161,161,161,161,161,161,161,161,161,161,161,161,161,161,161,161,161,161,161,161,161,161,161,161,161,161,161,161,161,161,161,161,161,161,161,161,161,161,161,161,161,161,161,161,]),'declaration':([0,21,59,75,93,181,298,427,],[38,90,38,90,182,302,302,499,]),'declaration_specifiers_no_type':([0,1,21,27,52,53,55,59,63,75,87,93,118,129,181,278,298,342,350,422,427,460,],[40,67,91,67,67,67,67,40,67,91,67,91,214,214,91,214,91,214,214,214,91,214,]),'jump_statement':([181,298,307,429,437,440,502,535,537,539,569,574,577,],[303,303,303,303,303,303,303,303,303,303,303,303,303,]),'enumerator_list':([102,193,194,],[196,325,326,]),'block_item':([181,298,],[305,432,]),'constant_expression':([85,116,297,319,329,373,447,],[159,203,431,444,451,471,509,]),'identifier_list_opt':([118,129,460,],[207,221,516,]),'constant':([85,116,131,146,149,171,174,175,181,201,204,218,229,231,233,242,243,244,245,246,247,248,249,250,251,252,253,254,255,256,257,258,259,260,265,266,290,297,298,307,319,329,338,339,353,354,364,373,375,412,413,419,421,427,429,430,434,437,440,441,447,477,481,484,499,502,513,521,533,535,537,538,539,542,543,549,553,566,569,574,577,],[156,156,156,156,156,156,156,156,156,156,156,156,156,156,156,156,156,156,156,156,156,156,156,156,156,156,156,156,156,156,156,156,156,156,156,156,156,156,156,156,156,156,156,156,156,156,156,156,156,156,156,156,156,156,156,156,156,156,156,156,156,156,156,156,156,156,156,156,156,156,156,156,156,156,156,156,156,156,156,156,156,]),'type_specifier_no_typeid':([0,4,21,40,59,75,85,89,91,93,95,99,118,129,172,174,177,181,184,185,186,192,211,214,229,231,233,239,267,278,298,313,315,342,350,422,427,460,],[12,71,12,12,12,12,12,71,12,12,12,12,12,12,12,12,275,12,12,12,12,275,71,12,12,12,12,12,12,12,12,12,12,12,12,12,12,12,]),'struct_declaration':([99,184,185,186,313,315,],[190,190,190,318,318,318,]),'direct_typeid_noparen_declarator':([211,344,],[345,458,]),'id_declarator':([0,4,37,40,59,69,89,91,126,192,202,211,342,445,],[21,75,107,110,21,107,179,110,179,179,179,346,107,179,]),'selection_statement':([181,298,307,429,437,440,502,535,537,539,569,574,577,],[308,308,308,308,308,308,308,308,308,308,308,308,308,]),'postfix_expression':([85,116,131,146,149,171,174,175,181,201,204,218,229,231,233,242,243,244,245,246,247,248,249,250,251,252,253,254,255,256,257,258,259,260,265,266,290,297,298,307,319,329,338,339,353,354,364,373,375,412,413,419,421,427,429,430,434,437,440,441,447,477,481,484,499,502,513,521,533,535,537,538,539,542,543,549,553,566,569,574,577,],[168,168,168,168,168,168,168,168,168,168,168,168,168,168,168,168,168,168,168,168,168,168,168,168,168,168,168,168,168,168,168,168,168,168,168,168,168,168,168,168,168,168,168,168,168,168,168,168,168,168,168,168,168,168,168,168,168,168,168,168,168,168,168,168,168,168,168,168,168,168,168,168,168,168,168,168,168,168,168,168,168,]),'initializer_list':([227,488,],[374,528,]),'unary_operator':([85,116,131,146,149,171,174,175,181,201,204,218,229,231,233,242,243,244,245,246,247,248,249,250,251,252,253,254,255,256,257,258,259,260,265,266,290,297,298,307,319,329,338,339,353,354,364,373,375,412,413,419,421,427,429,430,434,437,440,441,447,477,481,484,499,502,513,521,533,535,537,538,539,542,543,549,553,566,569,574,577,],[171,171,171,171,171,171,171,171,171,171,171,171,171,171,171,171,171,171,171,171,171,171,171,171,171,171,171,171,171,171,171,171,171,171,171,171,171,171,171,171,171,171,171,171,171,171,171,171,171,171,171,171,171,171,171,171,171,171,171,171,171,171,171,171,171,171,171,171,171,171,171,171,171,171,171,171,171,171,171,171,171,]),'struct_or_union':([0,21,40,59,75,85,91,93,95,99,118,129,172,174,181,184,185,186,214,229,231,233,239,267,278,298,313,315,342,350,422,427,460,],[31,31,31,31,31,31,31,31,31,31,31,31,31,31,31,31,31,31,31,31,31,31,31,31,31,31,31,31,31,31,31,31,31,]),'block_item_list_opt':([181,],[309,]),'assignment_expression':([131,174,181,201,204,218,229,231,233,258,265,266,290,298,307,338,339,353,354,364,375,412,419,421,427,429,430,434,437,440,441,484,499,502,513,521,533,535,537,538,539,542,543,549,553,566,569,574,577,],[228,272,272,228,335,335,272,272,272,272,272,408,272,272,272,455,456,466,467,468,228,486,335,335,272,272,272,272,272,272,272,525,272,272,335,228,272,272,272,272,272,561,562,272,272,272,272,272,272,]),'designation_opt':([227,472,488,550,],[375,521,375,521,]),'parameter_type_list':([118,129,278,342,422,460,],[209,222,416,416,416,517,]),'type_qualifier_list':([35,85,95,99,117,128,174,184,185,186,206,220,229,231,233,239,267,282,313,315,459,515,],[103,172,172,172,205,219,172,172,172,172,103,103,172,172,172,172,172,103,172,172,514,103,]),'designator':([227,376,472,488,550,],[371,474,371,371,371,]),'id_init_declarator_list_opt':([40,91,],[114,114,]),'declaration_specifiers':([0,21,59,75,93,118,129,181,278,298,342,350,422,427,460,],[4,89,4,89,89,211,211,89,211,89,211,211,211,89,211,]),'identifier_list':([118,129,460,],[212,212,212,]),'declaration_list_opt':([21,75,],[92,130,]),'function_definition':([0,59,],[45,45,]),'binary_expression':([85,116,131,174,181,201,204,218,229,231,233,242,243,244,245,246,247,248,249,250,251,252,253,254,255,256,257,258,259,260,265,266,290,297,298,307,319,329,338,339,353,354,364,373,375,412,419,421,427,429,430,434,437,440,441,447,481,484,499,502,513,521,533,535,537,538,539,542,543,549,553,566,569,574,577,],[162,162,162,162,162,162,162,162,162,162,162,383,384,385,386,387,388,389,390,391,392,393,394,395,396,397,398,162,400,401,162,162,162,162,162,162,162,162,162,162,162,162,162,162,162,162,162,162,162,162,162,162,162,162,162,162,162,162,162,162,162,162,162,162,162,162,162,162,162,162,162,162,162,162,162,]),'enum_specifier':([0,21,40,59,75,85,91,93,95,99,118,129,172,174,181,184,185,186,214,229,231,233,239,267,278,298,313,315,342,350,422,427,460,],[49,49,49,49,49,49,49,49,49,49,49,49,49,49,49,49,49,49,49,49,49,49,49,49,49,49,49,49,49,49,49,49,49,]),'decl_body':([0,21,59,75,93,181,298,427,],[51,51,51,51,51,51,51,51,]),'function_specifier':([0,1,4,21,27,52,53,55,59,63,75,87,89,93,118,129,181,211,278,298,342,350,422,427,460,],[55,55,82,55,55,55,55,55,55,55,55,55,82,55,55,55,55,82,55,55,55,55,55,55,55,]),'specifier_qualifier_list':([85,95,99,174,184,185,186,229,231,233,239,267,313,315,],[177,177,192,177,192,192,192,177,177,177,177,177,192,192,]),'conditional_expression':([85,116,131,174,181,201,204,218,229,231,233,258,265,266,290,297,298,307,319,329,338,339,353,354,364,373,375,412,419,421,427,429,430,434,437,440,441,447,481,484,499,502,513,521,533,535,537,538,539,542,543,549,553,566,569,574,577,],[178,178,225,225,225,225,225,225,225,225,225,225,225,225,225,178,225,225,178,178,225,225,225,225,225,178,225,225,225,225,225,225,225,225,225,225,225,178,524,225,225,225,225,225,225,225,225,225,225,225,225,225,225,225,225,225,225,]),} + +_lr_goto = {} +for _k, _v in _lr_goto_items.items(): + for _x, _y in zip(_v[0], _v[1]): + if not _x in _lr_goto: _lr_goto[_x] = {} + _lr_goto[_x][_k] = _y +del _lr_goto_items +_lr_productions = [ + ("S' -> translation_unit_or_empty","S'",1,None,None,None), + ('abstract_declarator_opt -> empty','abstract_declarator_opt',1,'p_abstract_declarator_opt','plyparser.py',43), + ('abstract_declarator_opt -> abstract_declarator','abstract_declarator_opt',1,'p_abstract_declarator_opt','plyparser.py',44), + ('assignment_expression_opt -> empty','assignment_expression_opt',1,'p_assignment_expression_opt','plyparser.py',43), + ('assignment_expression_opt -> assignment_expression','assignment_expression_opt',1,'p_assignment_expression_opt','plyparser.py',44), + ('block_item_list_opt -> empty','block_item_list_opt',1,'p_block_item_list_opt','plyparser.py',43), + ('block_item_list_opt -> block_item_list','block_item_list_opt',1,'p_block_item_list_opt','plyparser.py',44), + ('declaration_list_opt -> empty','declaration_list_opt',1,'p_declaration_list_opt','plyparser.py',43), + ('declaration_list_opt -> declaration_list','declaration_list_opt',1,'p_declaration_list_opt','plyparser.py',44), + ('declaration_specifiers_no_type_opt -> empty','declaration_specifiers_no_type_opt',1,'p_declaration_specifiers_no_type_opt','plyparser.py',43), + ('declaration_specifiers_no_type_opt -> declaration_specifiers_no_type','declaration_specifiers_no_type_opt',1,'p_declaration_specifiers_no_type_opt','plyparser.py',44), + ('designation_opt -> empty','designation_opt',1,'p_designation_opt','plyparser.py',43), + ('designation_opt -> designation','designation_opt',1,'p_designation_opt','plyparser.py',44), + ('expression_opt -> empty','expression_opt',1,'p_expression_opt','plyparser.py',43), + ('expression_opt -> expression','expression_opt',1,'p_expression_opt','plyparser.py',44), + ('id_init_declarator_list_opt -> empty','id_init_declarator_list_opt',1,'p_id_init_declarator_list_opt','plyparser.py',43), + ('id_init_declarator_list_opt -> id_init_declarator_list','id_init_declarator_list_opt',1,'p_id_init_declarator_list_opt','plyparser.py',44), + ('identifier_list_opt -> empty','identifier_list_opt',1,'p_identifier_list_opt','plyparser.py',43), + ('identifier_list_opt -> identifier_list','identifier_list_opt',1,'p_identifier_list_opt','plyparser.py',44), + ('init_declarator_list_opt -> empty','init_declarator_list_opt',1,'p_init_declarator_list_opt','plyparser.py',43), + ('init_declarator_list_opt -> init_declarator_list','init_declarator_list_opt',1,'p_init_declarator_list_opt','plyparser.py',44), + ('initializer_list_opt -> empty','initializer_list_opt',1,'p_initializer_list_opt','plyparser.py',43), + ('initializer_list_opt -> initializer_list','initializer_list_opt',1,'p_initializer_list_opt','plyparser.py',44), + ('parameter_type_list_opt -> empty','parameter_type_list_opt',1,'p_parameter_type_list_opt','plyparser.py',43), + ('parameter_type_list_opt -> parameter_type_list','parameter_type_list_opt',1,'p_parameter_type_list_opt','plyparser.py',44), + ('struct_declarator_list_opt -> empty','struct_declarator_list_opt',1,'p_struct_declarator_list_opt','plyparser.py',43), + ('struct_declarator_list_opt -> struct_declarator_list','struct_declarator_list_opt',1,'p_struct_declarator_list_opt','plyparser.py',44), + ('type_qualifier_list_opt -> empty','type_qualifier_list_opt',1,'p_type_qualifier_list_opt','plyparser.py',43), + ('type_qualifier_list_opt -> type_qualifier_list','type_qualifier_list_opt',1,'p_type_qualifier_list_opt','plyparser.py',44), + ('direct_id_declarator -> ID','direct_id_declarator',1,'p_direct_id_declarator_1','plyparser.py',126), + ('direct_id_declarator -> LPAREN id_declarator RPAREN','direct_id_declarator',3,'p_direct_id_declarator_2','plyparser.py',126), + ('direct_id_declarator -> direct_id_declarator LBRACKET type_qualifier_list_opt assignment_expression_opt RBRACKET','direct_id_declarator',5,'p_direct_id_declarator_3','plyparser.py',126), + ('direct_id_declarator -> direct_id_declarator LBRACKET STATIC type_qualifier_list_opt assignment_expression RBRACKET','direct_id_declarator',6,'p_direct_id_declarator_4','plyparser.py',126), + ('direct_id_declarator -> direct_id_declarator LBRACKET type_qualifier_list STATIC assignment_expression RBRACKET','direct_id_declarator',6,'p_direct_id_declarator_4','plyparser.py',127), + ('direct_id_declarator -> direct_id_declarator LBRACKET type_qualifier_list_opt TIMES RBRACKET','direct_id_declarator',5,'p_direct_id_declarator_5','plyparser.py',126), + ('direct_id_declarator -> direct_id_declarator LPAREN parameter_type_list RPAREN','direct_id_declarator',4,'p_direct_id_declarator_6','plyparser.py',126), + ('direct_id_declarator -> direct_id_declarator LPAREN identifier_list_opt RPAREN','direct_id_declarator',4,'p_direct_id_declarator_6','plyparser.py',127), + ('direct_typeid_declarator -> TYPEID','direct_typeid_declarator',1,'p_direct_typeid_declarator_1','plyparser.py',126), + ('direct_typeid_declarator -> LPAREN typeid_declarator RPAREN','direct_typeid_declarator',3,'p_direct_typeid_declarator_2','plyparser.py',126), + ('direct_typeid_declarator -> direct_typeid_declarator LBRACKET type_qualifier_list_opt assignment_expression_opt RBRACKET','direct_typeid_declarator',5,'p_direct_typeid_declarator_3','plyparser.py',126), + ('direct_typeid_declarator -> direct_typeid_declarator LBRACKET STATIC type_qualifier_list_opt assignment_expression RBRACKET','direct_typeid_declarator',6,'p_direct_typeid_declarator_4','plyparser.py',126), + ('direct_typeid_declarator -> direct_typeid_declarator LBRACKET type_qualifier_list STATIC assignment_expression RBRACKET','direct_typeid_declarator',6,'p_direct_typeid_declarator_4','plyparser.py',127), + ('direct_typeid_declarator -> direct_typeid_declarator LBRACKET type_qualifier_list_opt TIMES RBRACKET','direct_typeid_declarator',5,'p_direct_typeid_declarator_5','plyparser.py',126), + ('direct_typeid_declarator -> direct_typeid_declarator LPAREN parameter_type_list RPAREN','direct_typeid_declarator',4,'p_direct_typeid_declarator_6','plyparser.py',126), + ('direct_typeid_declarator -> direct_typeid_declarator LPAREN identifier_list_opt RPAREN','direct_typeid_declarator',4,'p_direct_typeid_declarator_6','plyparser.py',127), + ('direct_typeid_noparen_declarator -> TYPEID','direct_typeid_noparen_declarator',1,'p_direct_typeid_noparen_declarator_1','plyparser.py',126), + ('direct_typeid_noparen_declarator -> direct_typeid_noparen_declarator LBRACKET type_qualifier_list_opt assignment_expression_opt RBRACKET','direct_typeid_noparen_declarator',5,'p_direct_typeid_noparen_declarator_3','plyparser.py',126), + ('direct_typeid_noparen_declarator -> direct_typeid_noparen_declarator LBRACKET STATIC type_qualifier_list_opt assignment_expression RBRACKET','direct_typeid_noparen_declarator',6,'p_direct_typeid_noparen_declarator_4','plyparser.py',126), + ('direct_typeid_noparen_declarator -> direct_typeid_noparen_declarator LBRACKET type_qualifier_list STATIC assignment_expression RBRACKET','direct_typeid_noparen_declarator',6,'p_direct_typeid_noparen_declarator_4','plyparser.py',127), + ('direct_typeid_noparen_declarator -> direct_typeid_noparen_declarator LBRACKET type_qualifier_list_opt TIMES RBRACKET','direct_typeid_noparen_declarator',5,'p_direct_typeid_noparen_declarator_5','plyparser.py',126), + ('direct_typeid_noparen_declarator -> direct_typeid_noparen_declarator LPAREN parameter_type_list RPAREN','direct_typeid_noparen_declarator',4,'p_direct_typeid_noparen_declarator_6','plyparser.py',126), + ('direct_typeid_noparen_declarator -> direct_typeid_noparen_declarator LPAREN identifier_list_opt RPAREN','direct_typeid_noparen_declarator',4,'p_direct_typeid_noparen_declarator_6','plyparser.py',127), + ('id_declarator -> direct_id_declarator','id_declarator',1,'p_id_declarator_1','plyparser.py',126), + ('id_declarator -> pointer direct_id_declarator','id_declarator',2,'p_id_declarator_2','plyparser.py',126), + ('typeid_declarator -> direct_typeid_declarator','typeid_declarator',1,'p_typeid_declarator_1','plyparser.py',126), + ('typeid_declarator -> pointer direct_typeid_declarator','typeid_declarator',2,'p_typeid_declarator_2','plyparser.py',126), + ('typeid_noparen_declarator -> direct_typeid_noparen_declarator','typeid_noparen_declarator',1,'p_typeid_noparen_declarator_1','plyparser.py',126), + ('typeid_noparen_declarator -> pointer direct_typeid_noparen_declarator','typeid_noparen_declarator',2,'p_typeid_noparen_declarator_2','plyparser.py',126), + ('translation_unit_or_empty -> translation_unit','translation_unit_or_empty',1,'p_translation_unit_or_empty','c_parser.py',509), + ('translation_unit_or_empty -> empty','translation_unit_or_empty',1,'p_translation_unit_or_empty','c_parser.py',510), + ('translation_unit -> external_declaration','translation_unit',1,'p_translation_unit_1','c_parser.py',518), + ('translation_unit -> translation_unit external_declaration','translation_unit',2,'p_translation_unit_2','c_parser.py',524), + ('external_declaration -> function_definition','external_declaration',1,'p_external_declaration_1','c_parser.py',534), + ('external_declaration -> declaration','external_declaration',1,'p_external_declaration_2','c_parser.py',539), + ('external_declaration -> pp_directive','external_declaration',1,'p_external_declaration_3','c_parser.py',544), + ('external_declaration -> pppragma_directive','external_declaration',1,'p_external_declaration_3','c_parser.py',545), + ('external_declaration -> SEMI','external_declaration',1,'p_external_declaration_4','c_parser.py',550), + ('external_declaration -> static_assert','external_declaration',1,'p_external_declaration_5','c_parser.py',555), + ('static_assert -> _STATIC_ASSERT LPAREN constant_expression COMMA unified_string_literal RPAREN','static_assert',6,'p_static_assert_declaration','c_parser.py',560), + ('static_assert -> _STATIC_ASSERT LPAREN constant_expression RPAREN','static_assert',4,'p_static_assert_declaration','c_parser.py',561), + ('pp_directive -> PPHASH','pp_directive',1,'p_pp_directive','c_parser.py',569), + ('pppragma_directive -> PPPRAGMA','pppragma_directive',1,'p_pppragma_directive','c_parser.py',575), + ('pppragma_directive -> PPPRAGMA PPPRAGMASTR','pppragma_directive',2,'p_pppragma_directive','c_parser.py',576), + ('function_definition -> id_declarator declaration_list_opt compound_statement','function_definition',3,'p_function_definition_1','c_parser.py',586), + ('function_definition -> declaration_specifiers id_declarator declaration_list_opt compound_statement','function_definition',4,'p_function_definition_2','c_parser.py',604), + ('statement -> labeled_statement','statement',1,'p_statement','c_parser.py',619), + ('statement -> expression_statement','statement',1,'p_statement','c_parser.py',620), + ('statement -> compound_statement','statement',1,'p_statement','c_parser.py',621), + ('statement -> selection_statement','statement',1,'p_statement','c_parser.py',622), + ('statement -> iteration_statement','statement',1,'p_statement','c_parser.py',623), + ('statement -> jump_statement','statement',1,'p_statement','c_parser.py',624), + ('statement -> pppragma_directive','statement',1,'p_statement','c_parser.py',625), + ('statement -> static_assert','statement',1,'p_statement','c_parser.py',626), + ('pragmacomp_or_statement -> pppragma_directive statement','pragmacomp_or_statement',2,'p_pragmacomp_or_statement','c_parser.py',674), + ('pragmacomp_or_statement -> statement','pragmacomp_or_statement',1,'p_pragmacomp_or_statement','c_parser.py',675), + ('decl_body -> declaration_specifiers init_declarator_list_opt','decl_body',2,'p_decl_body','c_parser.py',694), + ('decl_body -> declaration_specifiers_no_type id_init_declarator_list_opt','decl_body',2,'p_decl_body','c_parser.py',695), + ('declaration -> decl_body SEMI','declaration',2,'p_declaration','c_parser.py',755), + ('declaration_list -> declaration','declaration_list',1,'p_declaration_list','c_parser.py',764), + ('declaration_list -> declaration_list declaration','declaration_list',2,'p_declaration_list','c_parser.py',765), + ('declaration_specifiers_no_type -> type_qualifier declaration_specifiers_no_type_opt','declaration_specifiers_no_type',2,'p_declaration_specifiers_no_type_1','c_parser.py',775), + ('declaration_specifiers_no_type -> storage_class_specifier declaration_specifiers_no_type_opt','declaration_specifiers_no_type',2,'p_declaration_specifiers_no_type_2','c_parser.py',780), + ('declaration_specifiers_no_type -> function_specifier declaration_specifiers_no_type_opt','declaration_specifiers_no_type',2,'p_declaration_specifiers_no_type_3','c_parser.py',785), + ('declaration_specifiers_no_type -> atomic_specifier declaration_specifiers_no_type_opt','declaration_specifiers_no_type',2,'p_declaration_specifiers_no_type_4','c_parser.py',792), + ('declaration_specifiers_no_type -> alignment_specifier declaration_specifiers_no_type_opt','declaration_specifiers_no_type',2,'p_declaration_specifiers_no_type_5','c_parser.py',797), + ('declaration_specifiers -> declaration_specifiers type_qualifier','declaration_specifiers',2,'p_declaration_specifiers_1','c_parser.py',802), + ('declaration_specifiers -> declaration_specifiers storage_class_specifier','declaration_specifiers',2,'p_declaration_specifiers_2','c_parser.py',807), + ('declaration_specifiers -> declaration_specifiers function_specifier','declaration_specifiers',2,'p_declaration_specifiers_3','c_parser.py',812), + ('declaration_specifiers -> declaration_specifiers type_specifier_no_typeid','declaration_specifiers',2,'p_declaration_specifiers_4','c_parser.py',817), + ('declaration_specifiers -> type_specifier','declaration_specifiers',1,'p_declaration_specifiers_5','c_parser.py',822), + ('declaration_specifiers -> declaration_specifiers_no_type type_specifier','declaration_specifiers',2,'p_declaration_specifiers_6','c_parser.py',827), + ('declaration_specifiers -> declaration_specifiers alignment_specifier','declaration_specifiers',2,'p_declaration_specifiers_7','c_parser.py',832), + ('storage_class_specifier -> AUTO','storage_class_specifier',1,'p_storage_class_specifier','c_parser.py',837), + ('storage_class_specifier -> REGISTER','storage_class_specifier',1,'p_storage_class_specifier','c_parser.py',838), + ('storage_class_specifier -> STATIC','storage_class_specifier',1,'p_storage_class_specifier','c_parser.py',839), + ('storage_class_specifier -> EXTERN','storage_class_specifier',1,'p_storage_class_specifier','c_parser.py',840), + ('storage_class_specifier -> TYPEDEF','storage_class_specifier',1,'p_storage_class_specifier','c_parser.py',841), + ('storage_class_specifier -> _THREAD_LOCAL','storage_class_specifier',1,'p_storage_class_specifier','c_parser.py',842), + ('function_specifier -> INLINE','function_specifier',1,'p_function_specifier','c_parser.py',847), + ('function_specifier -> _NORETURN','function_specifier',1,'p_function_specifier','c_parser.py',848), + ('type_specifier_no_typeid -> VOID','type_specifier_no_typeid',1,'p_type_specifier_no_typeid','c_parser.py',853), + ('type_specifier_no_typeid -> _BOOL','type_specifier_no_typeid',1,'p_type_specifier_no_typeid','c_parser.py',854), + ('type_specifier_no_typeid -> CHAR','type_specifier_no_typeid',1,'p_type_specifier_no_typeid','c_parser.py',855), + ('type_specifier_no_typeid -> SHORT','type_specifier_no_typeid',1,'p_type_specifier_no_typeid','c_parser.py',856), + ('type_specifier_no_typeid -> INT','type_specifier_no_typeid',1,'p_type_specifier_no_typeid','c_parser.py',857), + ('type_specifier_no_typeid -> LONG','type_specifier_no_typeid',1,'p_type_specifier_no_typeid','c_parser.py',858), + ('type_specifier_no_typeid -> FLOAT','type_specifier_no_typeid',1,'p_type_specifier_no_typeid','c_parser.py',859), + ('type_specifier_no_typeid -> DOUBLE','type_specifier_no_typeid',1,'p_type_specifier_no_typeid','c_parser.py',860), + ('type_specifier_no_typeid -> _COMPLEX','type_specifier_no_typeid',1,'p_type_specifier_no_typeid','c_parser.py',861), + ('type_specifier_no_typeid -> SIGNED','type_specifier_no_typeid',1,'p_type_specifier_no_typeid','c_parser.py',862), + ('type_specifier_no_typeid -> UNSIGNED','type_specifier_no_typeid',1,'p_type_specifier_no_typeid','c_parser.py',863), + ('type_specifier_no_typeid -> __INT128','type_specifier_no_typeid',1,'p_type_specifier_no_typeid','c_parser.py',864), + ('type_specifier -> typedef_name','type_specifier',1,'p_type_specifier','c_parser.py',869), + ('type_specifier -> enum_specifier','type_specifier',1,'p_type_specifier','c_parser.py',870), + ('type_specifier -> struct_or_union_specifier','type_specifier',1,'p_type_specifier','c_parser.py',871), + ('type_specifier -> type_specifier_no_typeid','type_specifier',1,'p_type_specifier','c_parser.py',872), + ('type_specifier -> atomic_specifier','type_specifier',1,'p_type_specifier','c_parser.py',873), + ('atomic_specifier -> _ATOMIC LPAREN type_name RPAREN','atomic_specifier',4,'p_atomic_specifier','c_parser.py',879), + ('type_qualifier -> CONST','type_qualifier',1,'p_type_qualifier','c_parser.py',886), + ('type_qualifier -> RESTRICT','type_qualifier',1,'p_type_qualifier','c_parser.py',887), + ('type_qualifier -> VOLATILE','type_qualifier',1,'p_type_qualifier','c_parser.py',888), + ('type_qualifier -> _ATOMIC','type_qualifier',1,'p_type_qualifier','c_parser.py',889), + ('init_declarator_list -> init_declarator','init_declarator_list',1,'p_init_declarator_list','c_parser.py',894), + ('init_declarator_list -> init_declarator_list COMMA init_declarator','init_declarator_list',3,'p_init_declarator_list','c_parser.py',895), + ('init_declarator -> declarator','init_declarator',1,'p_init_declarator','c_parser.py',903), + ('init_declarator -> declarator EQUALS initializer','init_declarator',3,'p_init_declarator','c_parser.py',904), + ('id_init_declarator_list -> id_init_declarator','id_init_declarator_list',1,'p_id_init_declarator_list','c_parser.py',909), + ('id_init_declarator_list -> id_init_declarator_list COMMA init_declarator','id_init_declarator_list',3,'p_id_init_declarator_list','c_parser.py',910), + ('id_init_declarator -> id_declarator','id_init_declarator',1,'p_id_init_declarator','c_parser.py',915), + ('id_init_declarator -> id_declarator EQUALS initializer','id_init_declarator',3,'p_id_init_declarator','c_parser.py',916), + ('specifier_qualifier_list -> specifier_qualifier_list type_specifier_no_typeid','specifier_qualifier_list',2,'p_specifier_qualifier_list_1','c_parser.py',923), + ('specifier_qualifier_list -> specifier_qualifier_list type_qualifier','specifier_qualifier_list',2,'p_specifier_qualifier_list_2','c_parser.py',928), + ('specifier_qualifier_list -> type_specifier','specifier_qualifier_list',1,'p_specifier_qualifier_list_3','c_parser.py',933), + ('specifier_qualifier_list -> type_qualifier_list type_specifier','specifier_qualifier_list',2,'p_specifier_qualifier_list_4','c_parser.py',938), + ('specifier_qualifier_list -> alignment_specifier','specifier_qualifier_list',1,'p_specifier_qualifier_list_5','c_parser.py',943), + ('specifier_qualifier_list -> specifier_qualifier_list alignment_specifier','specifier_qualifier_list',2,'p_specifier_qualifier_list_6','c_parser.py',948), + ('struct_or_union_specifier -> struct_or_union ID','struct_or_union_specifier',2,'p_struct_or_union_specifier_1','c_parser.py',956), + ('struct_or_union_specifier -> struct_or_union TYPEID','struct_or_union_specifier',2,'p_struct_or_union_specifier_1','c_parser.py',957), + ('struct_or_union_specifier -> struct_or_union brace_open struct_declaration_list brace_close','struct_or_union_specifier',4,'p_struct_or_union_specifier_2','c_parser.py',967), + ('struct_or_union_specifier -> struct_or_union brace_open brace_close','struct_or_union_specifier',3,'p_struct_or_union_specifier_2','c_parser.py',968), + ('struct_or_union_specifier -> struct_or_union ID brace_open struct_declaration_list brace_close','struct_or_union_specifier',5,'p_struct_or_union_specifier_3','c_parser.py',985), + ('struct_or_union_specifier -> struct_or_union ID brace_open brace_close','struct_or_union_specifier',4,'p_struct_or_union_specifier_3','c_parser.py',986), + ('struct_or_union_specifier -> struct_or_union TYPEID brace_open struct_declaration_list brace_close','struct_or_union_specifier',5,'p_struct_or_union_specifier_3','c_parser.py',987), + ('struct_or_union_specifier -> struct_or_union TYPEID brace_open brace_close','struct_or_union_specifier',4,'p_struct_or_union_specifier_3','c_parser.py',988), + ('struct_or_union -> STRUCT','struct_or_union',1,'p_struct_or_union','c_parser.py',1004), + ('struct_or_union -> UNION','struct_or_union',1,'p_struct_or_union','c_parser.py',1005), + ('struct_declaration_list -> struct_declaration','struct_declaration_list',1,'p_struct_declaration_list','c_parser.py',1012), + ('struct_declaration_list -> struct_declaration_list struct_declaration','struct_declaration_list',2,'p_struct_declaration_list','c_parser.py',1013), + ('struct_declaration -> specifier_qualifier_list struct_declarator_list_opt SEMI','struct_declaration',3,'p_struct_declaration_1','c_parser.py',1021), + ('struct_declaration -> SEMI','struct_declaration',1,'p_struct_declaration_2','c_parser.py',1059), + ('struct_declaration -> pppragma_directive','struct_declaration',1,'p_struct_declaration_3','c_parser.py',1064), + ('struct_declarator_list -> struct_declarator','struct_declarator_list',1,'p_struct_declarator_list','c_parser.py',1069), + ('struct_declarator_list -> struct_declarator_list COMMA struct_declarator','struct_declarator_list',3,'p_struct_declarator_list','c_parser.py',1070), + ('struct_declarator -> declarator','struct_declarator',1,'p_struct_declarator_1','c_parser.py',1078), + ('struct_declarator -> declarator COLON constant_expression','struct_declarator',3,'p_struct_declarator_2','c_parser.py',1083), + ('struct_declarator -> COLON constant_expression','struct_declarator',2,'p_struct_declarator_2','c_parser.py',1084), + ('enum_specifier -> ENUM ID','enum_specifier',2,'p_enum_specifier_1','c_parser.py',1092), + ('enum_specifier -> ENUM TYPEID','enum_specifier',2,'p_enum_specifier_1','c_parser.py',1093), + ('enum_specifier -> ENUM brace_open enumerator_list brace_close','enum_specifier',4,'p_enum_specifier_2','c_parser.py',1098), + ('enum_specifier -> ENUM ID brace_open enumerator_list brace_close','enum_specifier',5,'p_enum_specifier_3','c_parser.py',1103), + ('enum_specifier -> ENUM TYPEID brace_open enumerator_list brace_close','enum_specifier',5,'p_enum_specifier_3','c_parser.py',1104), + ('enumerator_list -> enumerator','enumerator_list',1,'p_enumerator_list','c_parser.py',1109), + ('enumerator_list -> enumerator_list COMMA','enumerator_list',2,'p_enumerator_list','c_parser.py',1110), + ('enumerator_list -> enumerator_list COMMA enumerator','enumerator_list',3,'p_enumerator_list','c_parser.py',1111), + ('alignment_specifier -> _ALIGNAS LPAREN type_name RPAREN','alignment_specifier',4,'p_alignment_specifier','c_parser.py',1122), + ('alignment_specifier -> _ALIGNAS LPAREN constant_expression RPAREN','alignment_specifier',4,'p_alignment_specifier','c_parser.py',1123), + ('enumerator -> ID','enumerator',1,'p_enumerator','c_parser.py',1128), + ('enumerator -> ID EQUALS constant_expression','enumerator',3,'p_enumerator','c_parser.py',1129), + ('declarator -> id_declarator','declarator',1,'p_declarator','c_parser.py',1144), + ('declarator -> typeid_declarator','declarator',1,'p_declarator','c_parser.py',1145), + ('pointer -> TIMES type_qualifier_list_opt','pointer',2,'p_pointer','c_parser.py',1257), + ('pointer -> TIMES type_qualifier_list_opt pointer','pointer',3,'p_pointer','c_parser.py',1258), + ('type_qualifier_list -> type_qualifier','type_qualifier_list',1,'p_type_qualifier_list','c_parser.py',1287), + ('type_qualifier_list -> type_qualifier_list type_qualifier','type_qualifier_list',2,'p_type_qualifier_list','c_parser.py',1288), + ('parameter_type_list -> parameter_list','parameter_type_list',1,'p_parameter_type_list','c_parser.py',1293), + ('parameter_type_list -> parameter_list COMMA ELLIPSIS','parameter_type_list',3,'p_parameter_type_list','c_parser.py',1294), + ('parameter_list -> parameter_declaration','parameter_list',1,'p_parameter_list','c_parser.py',1302), + ('parameter_list -> parameter_list COMMA parameter_declaration','parameter_list',3,'p_parameter_list','c_parser.py',1303), + ('parameter_declaration -> declaration_specifiers id_declarator','parameter_declaration',2,'p_parameter_declaration_1','c_parser.py',1322), + ('parameter_declaration -> declaration_specifiers typeid_noparen_declarator','parameter_declaration',2,'p_parameter_declaration_1','c_parser.py',1323), + ('parameter_declaration -> declaration_specifiers abstract_declarator_opt','parameter_declaration',2,'p_parameter_declaration_2','c_parser.py',1334), + ('identifier_list -> identifier','identifier_list',1,'p_identifier_list','c_parser.py',1366), + ('identifier_list -> identifier_list COMMA identifier','identifier_list',3,'p_identifier_list','c_parser.py',1367), + ('initializer -> assignment_expression','initializer',1,'p_initializer_1','c_parser.py',1376), + ('initializer -> brace_open initializer_list_opt brace_close','initializer',3,'p_initializer_2','c_parser.py',1381), + ('initializer -> brace_open initializer_list COMMA brace_close','initializer',4,'p_initializer_2','c_parser.py',1382), + ('initializer_list -> designation_opt initializer','initializer_list',2,'p_initializer_list','c_parser.py',1390), + ('initializer_list -> initializer_list COMMA designation_opt initializer','initializer_list',4,'p_initializer_list','c_parser.py',1391), + ('designation -> designator_list EQUALS','designation',2,'p_designation','c_parser.py',1402), + ('designator_list -> designator','designator_list',1,'p_designator_list','c_parser.py',1410), + ('designator_list -> designator_list designator','designator_list',2,'p_designator_list','c_parser.py',1411), + ('designator -> LBRACKET constant_expression RBRACKET','designator',3,'p_designator','c_parser.py',1416), + ('designator -> PERIOD identifier','designator',2,'p_designator','c_parser.py',1417), + ('type_name -> specifier_qualifier_list abstract_declarator_opt','type_name',2,'p_type_name','c_parser.py',1422), + ('abstract_declarator -> pointer','abstract_declarator',1,'p_abstract_declarator_1','c_parser.py',1434), + ('abstract_declarator -> pointer direct_abstract_declarator','abstract_declarator',2,'p_abstract_declarator_2','c_parser.py',1442), + ('abstract_declarator -> direct_abstract_declarator','abstract_declarator',1,'p_abstract_declarator_3','c_parser.py',1447), + ('direct_abstract_declarator -> LPAREN abstract_declarator RPAREN','direct_abstract_declarator',3,'p_direct_abstract_declarator_1','c_parser.py',1457), + ('direct_abstract_declarator -> direct_abstract_declarator LBRACKET assignment_expression_opt RBRACKET','direct_abstract_declarator',4,'p_direct_abstract_declarator_2','c_parser.py',1461), + ('direct_abstract_declarator -> LBRACKET type_qualifier_list_opt assignment_expression_opt RBRACKET','direct_abstract_declarator',4,'p_direct_abstract_declarator_3','c_parser.py',1472), + ('direct_abstract_declarator -> direct_abstract_declarator LBRACKET TIMES RBRACKET','direct_abstract_declarator',4,'p_direct_abstract_declarator_4','c_parser.py',1482), + ('direct_abstract_declarator -> LBRACKET TIMES RBRACKET','direct_abstract_declarator',3,'p_direct_abstract_declarator_5','c_parser.py',1493), + ('direct_abstract_declarator -> direct_abstract_declarator LPAREN parameter_type_list_opt RPAREN','direct_abstract_declarator',4,'p_direct_abstract_declarator_6','c_parser.py',1502), + ('direct_abstract_declarator -> LPAREN parameter_type_list_opt RPAREN','direct_abstract_declarator',3,'p_direct_abstract_declarator_7','c_parser.py',1512), + ('block_item -> declaration','block_item',1,'p_block_item','c_parser.py',1523), + ('block_item -> statement','block_item',1,'p_block_item','c_parser.py',1524), + ('block_item_list -> block_item','block_item_list',1,'p_block_item_list','c_parser.py',1531), + ('block_item_list -> block_item_list block_item','block_item_list',2,'p_block_item_list','c_parser.py',1532), + ('compound_statement -> brace_open block_item_list_opt brace_close','compound_statement',3,'p_compound_statement_1','c_parser.py',1538), + ('labeled_statement -> ID COLON pragmacomp_or_statement','labeled_statement',3,'p_labeled_statement_1','c_parser.py',1544), + ('labeled_statement -> CASE constant_expression COLON pragmacomp_or_statement','labeled_statement',4,'p_labeled_statement_2','c_parser.py',1548), + ('labeled_statement -> DEFAULT COLON pragmacomp_or_statement','labeled_statement',3,'p_labeled_statement_3','c_parser.py',1552), + ('selection_statement -> IF LPAREN expression RPAREN pragmacomp_or_statement','selection_statement',5,'p_selection_statement_1','c_parser.py',1556), + ('selection_statement -> IF LPAREN expression RPAREN statement ELSE pragmacomp_or_statement','selection_statement',7,'p_selection_statement_2','c_parser.py',1560), + ('selection_statement -> SWITCH LPAREN expression RPAREN pragmacomp_or_statement','selection_statement',5,'p_selection_statement_3','c_parser.py',1564), + ('iteration_statement -> WHILE LPAREN expression RPAREN pragmacomp_or_statement','iteration_statement',5,'p_iteration_statement_1','c_parser.py',1569), + ('iteration_statement -> DO pragmacomp_or_statement WHILE LPAREN expression RPAREN SEMI','iteration_statement',7,'p_iteration_statement_2','c_parser.py',1573), + ('iteration_statement -> FOR LPAREN expression_opt SEMI expression_opt SEMI expression_opt RPAREN pragmacomp_or_statement','iteration_statement',9,'p_iteration_statement_3','c_parser.py',1577), + ('iteration_statement -> FOR LPAREN declaration expression_opt SEMI expression_opt RPAREN pragmacomp_or_statement','iteration_statement',8,'p_iteration_statement_4','c_parser.py',1581), + ('jump_statement -> GOTO ID SEMI','jump_statement',3,'p_jump_statement_1','c_parser.py',1586), + ('jump_statement -> BREAK SEMI','jump_statement',2,'p_jump_statement_2','c_parser.py',1590), + ('jump_statement -> CONTINUE SEMI','jump_statement',2,'p_jump_statement_3','c_parser.py',1594), + ('jump_statement -> RETURN expression SEMI','jump_statement',3,'p_jump_statement_4','c_parser.py',1598), + ('jump_statement -> RETURN SEMI','jump_statement',2,'p_jump_statement_4','c_parser.py',1599), + ('expression_statement -> expression_opt SEMI','expression_statement',2,'p_expression_statement','c_parser.py',1604), + ('expression -> assignment_expression','expression',1,'p_expression','c_parser.py',1611), + ('expression -> expression COMMA assignment_expression','expression',3,'p_expression','c_parser.py',1612), + ('assignment_expression -> LPAREN compound_statement RPAREN','assignment_expression',3,'p_parenthesized_compound_expression','c_parser.py',1624), + ('typedef_name -> TYPEID','typedef_name',1,'p_typedef_name','c_parser.py',1628), + ('assignment_expression -> conditional_expression','assignment_expression',1,'p_assignment_expression','c_parser.py',1632), + ('assignment_expression -> unary_expression assignment_operator assignment_expression','assignment_expression',3,'p_assignment_expression','c_parser.py',1633), + ('assignment_operator -> EQUALS','assignment_operator',1,'p_assignment_operator','c_parser.py',1646), + ('assignment_operator -> XOREQUAL','assignment_operator',1,'p_assignment_operator','c_parser.py',1647), + ('assignment_operator -> TIMESEQUAL','assignment_operator',1,'p_assignment_operator','c_parser.py',1648), + ('assignment_operator -> DIVEQUAL','assignment_operator',1,'p_assignment_operator','c_parser.py',1649), + ('assignment_operator -> MODEQUAL','assignment_operator',1,'p_assignment_operator','c_parser.py',1650), + ('assignment_operator -> PLUSEQUAL','assignment_operator',1,'p_assignment_operator','c_parser.py',1651), + ('assignment_operator -> MINUSEQUAL','assignment_operator',1,'p_assignment_operator','c_parser.py',1652), + ('assignment_operator -> LSHIFTEQUAL','assignment_operator',1,'p_assignment_operator','c_parser.py',1653), + ('assignment_operator -> RSHIFTEQUAL','assignment_operator',1,'p_assignment_operator','c_parser.py',1654), + ('assignment_operator -> ANDEQUAL','assignment_operator',1,'p_assignment_operator','c_parser.py',1655), + ('assignment_operator -> OREQUAL','assignment_operator',1,'p_assignment_operator','c_parser.py',1656), + ('constant_expression -> conditional_expression','constant_expression',1,'p_constant_expression','c_parser.py',1661), + ('conditional_expression -> binary_expression','conditional_expression',1,'p_conditional_expression','c_parser.py',1665), + ('conditional_expression -> binary_expression CONDOP expression COLON conditional_expression','conditional_expression',5,'p_conditional_expression','c_parser.py',1666), + ('binary_expression -> cast_expression','binary_expression',1,'p_binary_expression','c_parser.py',1674), + ('binary_expression -> binary_expression TIMES binary_expression','binary_expression',3,'p_binary_expression','c_parser.py',1675), + ('binary_expression -> binary_expression DIVIDE binary_expression','binary_expression',3,'p_binary_expression','c_parser.py',1676), + ('binary_expression -> binary_expression MOD binary_expression','binary_expression',3,'p_binary_expression','c_parser.py',1677), + ('binary_expression -> binary_expression PLUS binary_expression','binary_expression',3,'p_binary_expression','c_parser.py',1678), + ('binary_expression -> binary_expression MINUS binary_expression','binary_expression',3,'p_binary_expression','c_parser.py',1679), + ('binary_expression -> binary_expression RSHIFT binary_expression','binary_expression',3,'p_binary_expression','c_parser.py',1680), + ('binary_expression -> binary_expression LSHIFT binary_expression','binary_expression',3,'p_binary_expression','c_parser.py',1681), + ('binary_expression -> binary_expression LT binary_expression','binary_expression',3,'p_binary_expression','c_parser.py',1682), + ('binary_expression -> binary_expression LE binary_expression','binary_expression',3,'p_binary_expression','c_parser.py',1683), + ('binary_expression -> binary_expression GE binary_expression','binary_expression',3,'p_binary_expression','c_parser.py',1684), + ('binary_expression -> binary_expression GT binary_expression','binary_expression',3,'p_binary_expression','c_parser.py',1685), + ('binary_expression -> binary_expression EQ binary_expression','binary_expression',3,'p_binary_expression','c_parser.py',1686), + ('binary_expression -> binary_expression NE binary_expression','binary_expression',3,'p_binary_expression','c_parser.py',1687), + ('binary_expression -> binary_expression AND binary_expression','binary_expression',3,'p_binary_expression','c_parser.py',1688), + ('binary_expression -> binary_expression OR binary_expression','binary_expression',3,'p_binary_expression','c_parser.py',1689), + ('binary_expression -> binary_expression XOR binary_expression','binary_expression',3,'p_binary_expression','c_parser.py',1690), + ('binary_expression -> binary_expression LAND binary_expression','binary_expression',3,'p_binary_expression','c_parser.py',1691), + ('binary_expression -> binary_expression LOR binary_expression','binary_expression',3,'p_binary_expression','c_parser.py',1692), + ('cast_expression -> unary_expression','cast_expression',1,'p_cast_expression_1','c_parser.py',1700), + ('cast_expression -> LPAREN type_name RPAREN cast_expression','cast_expression',4,'p_cast_expression_2','c_parser.py',1704), + ('unary_expression -> postfix_expression','unary_expression',1,'p_unary_expression_1','c_parser.py',1708), + ('unary_expression -> PLUSPLUS unary_expression','unary_expression',2,'p_unary_expression_2','c_parser.py',1712), + ('unary_expression -> MINUSMINUS unary_expression','unary_expression',2,'p_unary_expression_2','c_parser.py',1713), + ('unary_expression -> unary_operator cast_expression','unary_expression',2,'p_unary_expression_2','c_parser.py',1714), + ('unary_expression -> SIZEOF unary_expression','unary_expression',2,'p_unary_expression_3','c_parser.py',1719), + ('unary_expression -> SIZEOF LPAREN type_name RPAREN','unary_expression',4,'p_unary_expression_3','c_parser.py',1720), + ('unary_expression -> _ALIGNOF LPAREN type_name RPAREN','unary_expression',4,'p_unary_expression_3','c_parser.py',1721), + ('unary_operator -> AND','unary_operator',1,'p_unary_operator','c_parser.py',1729), + ('unary_operator -> TIMES','unary_operator',1,'p_unary_operator','c_parser.py',1730), + ('unary_operator -> PLUS','unary_operator',1,'p_unary_operator','c_parser.py',1731), + ('unary_operator -> MINUS','unary_operator',1,'p_unary_operator','c_parser.py',1732), + ('unary_operator -> NOT','unary_operator',1,'p_unary_operator','c_parser.py',1733), + ('unary_operator -> LNOT','unary_operator',1,'p_unary_operator','c_parser.py',1734), + ('postfix_expression -> primary_expression','postfix_expression',1,'p_postfix_expression_1','c_parser.py',1739), + ('postfix_expression -> postfix_expression LBRACKET expression RBRACKET','postfix_expression',4,'p_postfix_expression_2','c_parser.py',1743), + ('postfix_expression -> postfix_expression LPAREN argument_expression_list RPAREN','postfix_expression',4,'p_postfix_expression_3','c_parser.py',1747), + ('postfix_expression -> postfix_expression LPAREN RPAREN','postfix_expression',3,'p_postfix_expression_3','c_parser.py',1748), + ('postfix_expression -> postfix_expression PERIOD ID','postfix_expression',3,'p_postfix_expression_4','c_parser.py',1753), + ('postfix_expression -> postfix_expression PERIOD TYPEID','postfix_expression',3,'p_postfix_expression_4','c_parser.py',1754), + ('postfix_expression -> postfix_expression ARROW ID','postfix_expression',3,'p_postfix_expression_4','c_parser.py',1755), + ('postfix_expression -> postfix_expression ARROW TYPEID','postfix_expression',3,'p_postfix_expression_4','c_parser.py',1756), + ('postfix_expression -> postfix_expression PLUSPLUS','postfix_expression',2,'p_postfix_expression_5','c_parser.py',1762), + ('postfix_expression -> postfix_expression MINUSMINUS','postfix_expression',2,'p_postfix_expression_5','c_parser.py',1763), + ('postfix_expression -> LPAREN type_name RPAREN brace_open initializer_list brace_close','postfix_expression',6,'p_postfix_expression_6','c_parser.py',1768), + ('postfix_expression -> LPAREN type_name RPAREN brace_open initializer_list COMMA brace_close','postfix_expression',7,'p_postfix_expression_6','c_parser.py',1769), + ('primary_expression -> identifier','primary_expression',1,'p_primary_expression_1','c_parser.py',1774), + ('primary_expression -> constant','primary_expression',1,'p_primary_expression_2','c_parser.py',1778), + ('primary_expression -> unified_string_literal','primary_expression',1,'p_primary_expression_3','c_parser.py',1782), + ('primary_expression -> unified_wstring_literal','primary_expression',1,'p_primary_expression_3','c_parser.py',1783), + ('primary_expression -> LPAREN expression RPAREN','primary_expression',3,'p_primary_expression_4','c_parser.py',1788), + ('primary_expression -> OFFSETOF LPAREN type_name COMMA offsetof_member_designator RPAREN','primary_expression',6,'p_primary_expression_5','c_parser.py',1792), + ('offsetof_member_designator -> identifier','offsetof_member_designator',1,'p_offsetof_member_designator','c_parser.py',1800), + ('offsetof_member_designator -> offsetof_member_designator PERIOD identifier','offsetof_member_designator',3,'p_offsetof_member_designator','c_parser.py',1801), + ('offsetof_member_designator -> offsetof_member_designator LBRACKET expression RBRACKET','offsetof_member_designator',4,'p_offsetof_member_designator','c_parser.py',1802), + ('argument_expression_list -> assignment_expression','argument_expression_list',1,'p_argument_expression_list','c_parser.py',1814), + ('argument_expression_list -> argument_expression_list COMMA assignment_expression','argument_expression_list',3,'p_argument_expression_list','c_parser.py',1815), + ('identifier -> ID','identifier',1,'p_identifier','c_parser.py',1824), + ('constant -> INT_CONST_DEC','constant',1,'p_constant_1','c_parser.py',1828), + ('constant -> INT_CONST_OCT','constant',1,'p_constant_1','c_parser.py',1829), + ('constant -> INT_CONST_HEX','constant',1,'p_constant_1','c_parser.py',1830), + ('constant -> INT_CONST_BIN','constant',1,'p_constant_1','c_parser.py',1831), + ('constant -> INT_CONST_CHAR','constant',1,'p_constant_1','c_parser.py',1832), + ('constant -> FLOAT_CONST','constant',1,'p_constant_2','c_parser.py',1851), + ('constant -> HEX_FLOAT_CONST','constant',1,'p_constant_2','c_parser.py',1852), + ('constant -> CHAR_CONST','constant',1,'p_constant_3','c_parser.py',1868), + ('constant -> WCHAR_CONST','constant',1,'p_constant_3','c_parser.py',1869), + ('constant -> U8CHAR_CONST','constant',1,'p_constant_3','c_parser.py',1870), + ('constant -> U16CHAR_CONST','constant',1,'p_constant_3','c_parser.py',1871), + ('constant -> U32CHAR_CONST','constant',1,'p_constant_3','c_parser.py',1872), + ('unified_string_literal -> STRING_LITERAL','unified_string_literal',1,'p_unified_string_literal','c_parser.py',1883), + ('unified_string_literal -> unified_string_literal STRING_LITERAL','unified_string_literal',2,'p_unified_string_literal','c_parser.py',1884), + ('unified_wstring_literal -> WSTRING_LITERAL','unified_wstring_literal',1,'p_unified_wstring_literal','c_parser.py',1894), + ('unified_wstring_literal -> U8STRING_LITERAL','unified_wstring_literal',1,'p_unified_wstring_literal','c_parser.py',1895), + ('unified_wstring_literal -> U16STRING_LITERAL','unified_wstring_literal',1,'p_unified_wstring_literal','c_parser.py',1896), + ('unified_wstring_literal -> U32STRING_LITERAL','unified_wstring_literal',1,'p_unified_wstring_literal','c_parser.py',1897), + ('unified_wstring_literal -> unified_wstring_literal WSTRING_LITERAL','unified_wstring_literal',2,'p_unified_wstring_literal','c_parser.py',1898), + ('unified_wstring_literal -> unified_wstring_literal U8STRING_LITERAL','unified_wstring_literal',2,'p_unified_wstring_literal','c_parser.py',1899), + ('unified_wstring_literal -> unified_wstring_literal U16STRING_LITERAL','unified_wstring_literal',2,'p_unified_wstring_literal','c_parser.py',1900), + ('unified_wstring_literal -> unified_wstring_literal U32STRING_LITERAL','unified_wstring_literal',2,'p_unified_wstring_literal','c_parser.py',1901), + ('brace_open -> LBRACE','brace_open',1,'p_brace_open','c_parser.py',1911), + ('brace_close -> RBRACE','brace_close',1,'p_brace_close','c_parser.py',1917), + ('empty -> ','empty',0,'p_empty','c_parser.py',1923), +] diff --git a/.venv/Lib/site-packages/python_dotenv-1.0.1.dist-info/INSTALLER b/.venv/Lib/site-packages/python_dotenv-1.0.1.dist-info/INSTALLER new file mode 100644 index 00000000..a1b589e3 --- /dev/null +++ b/.venv/Lib/site-packages/python_dotenv-1.0.1.dist-info/INSTALLER @@ -0,0 +1 @@ +pip diff --git a/.venv/Lib/site-packages/python_dotenv-1.0.1.dist-info/LICENSE b/.venv/Lib/site-packages/python_dotenv-1.0.1.dist-info/LICENSE new file mode 100644 index 00000000..3a971190 --- /dev/null +++ b/.venv/Lib/site-packages/python_dotenv-1.0.1.dist-info/LICENSE @@ -0,0 +1,27 @@ +Copyright (c) 2014, Saurabh Kumar (python-dotenv), 2013, Ted Tieken (django-dotenv-rw), 2013, Jacob Kaplan-Moss (django-dotenv) + +Redistribution and use in source and binary forms, with or without modification, +are permitted provided that the following conditions are met: + +- Redistributions of source code must retain the above copyright notice, + this list of conditions and the following disclaimer. + +- Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +- Neither the name of django-dotenv nor the names of its contributors + may be used to endorse or promote products derived from this software + without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF +LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING +NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/.venv/Lib/site-packages/python_dotenv-1.0.1.dist-info/METADATA b/.venv/Lib/site-packages/python_dotenv-1.0.1.dist-info/METADATA new file mode 100644 index 00000000..b9af7fe6 --- /dev/null +++ b/.venv/Lib/site-packages/python_dotenv-1.0.1.dist-info/METADATA @@ -0,0 +1,692 @@ +Metadata-Version: 2.1 +Name: python-dotenv +Version: 1.0.1 +Summary: Read key-value pairs from a .env file and set them as environment variables +Home-page: https://github.com/theskumar/python-dotenv +Author: Saurabh Kumar +Author-email: me+github@saurabh-kumar.com +License: BSD-3-Clause +Keywords: environment variables,deployments,settings,env,dotenv,configurations,python +Classifier: Development Status :: 5 - Production/Stable +Classifier: Programming Language :: Python +Classifier: Programming Language :: Python :: 3 +Classifier: Programming Language :: Python :: 3.8 +Classifier: Programming Language :: Python :: 3.9 +Classifier: Programming Language :: Python :: 3.10 +Classifier: Programming Language :: Python :: 3.11 +Classifier: Programming Language :: Python :: 3.12 +Classifier: Programming Language :: Python :: Implementation :: PyPy +Classifier: Intended Audience :: Developers +Classifier: Intended Audience :: System Administrators +Classifier: License :: OSI Approved :: BSD License +Classifier: Operating System :: OS Independent +Classifier: Topic :: System :: Systems Administration +Classifier: Topic :: Utilities +Classifier: Environment :: Web Environment +Requires-Python: >=3.8 +Description-Content-Type: text/markdown +License-File: LICENSE +Provides-Extra: cli +Requires-Dist: click >=5.0 ; extra == 'cli' + +# python-dotenv + +[![Build Status][build_status_badge]][build_status_link] +[![PyPI version][pypi_badge]][pypi_link] + +Python-dotenv reads key-value pairs from a `.env` file and can set them as environment +variables. It helps in the development of applications following the +[12-factor](https://12factor.net/) principles. + +- [Getting Started](#getting-started) +- [Other Use Cases](#other-use-cases) + * [Load configuration without altering the environment](#load-configuration-without-altering-the-environment) + * [Parse configuration as a stream](#parse-configuration-as-a-stream) + * [Load .env files in IPython](#load-env-files-in-ipython) +- [Command-line Interface](#command-line-interface) +- [File format](#file-format) + * [Multiline values](#multiline-values) + * [Variable expansion](#variable-expansion) +- [Related Projects](#related-projects) +- [Acknowledgements](#acknowledgements) + +## Getting Started + +```shell +pip install python-dotenv +``` + +If your application takes its configuration from environment variables, like a 12-factor +application, launching it in development is not very practical because you have to set +those environment variables yourself. + +To help you with that, you can add Python-dotenv to your application to make it load the +configuration from a `.env` file when it is present (e.g. in development) while remaining +configurable via the environment: + +```python +from dotenv import load_dotenv + +load_dotenv() # take environment variables from .env. + +# Code of your application, which uses environment variables (e.g. from `os.environ` or +# `os.getenv`) as if they came from the actual environment. +``` + +By default, `load_dotenv` doesn't override existing environment variables. + +To configure the development environment, add a `.env` in the root directory of your +project: + +``` +. +├── .env +└── foo.py +``` + +The syntax of `.env` files supported by python-dotenv is similar to that of Bash: + +```bash +# Development settings +DOMAIN=example.org +ADMIN_EMAIL=admin@${DOMAIN} +ROOT_URL=${DOMAIN}/app +``` + +If you use variables in values, ensure they are surrounded with `{` and `}`, like +`${DOMAIN}`, as bare variables such as `$DOMAIN` are not expanded. + +You will probably want to add `.env` to your `.gitignore`, especially if it contains +secrets like a password. + +See the section "File format" below for more information about what you can write in a +`.env` file. + +## Other Use Cases + +### Load configuration without altering the environment + +The function `dotenv_values` works more or less the same way as `load_dotenv`, except it +doesn't touch the environment, it just returns a `dict` with the values parsed from the +`.env` file. + +```python +from dotenv import dotenv_values + +config = dotenv_values(".env") # config = {"USER": "foo", "EMAIL": "foo@example.org"} +``` + +This notably enables advanced configuration management: + +```python +import os +from dotenv import dotenv_values + +config = { + **dotenv_values(".env.shared"), # load shared development variables + **dotenv_values(".env.secret"), # load sensitive variables + **os.environ, # override loaded values with environment variables +} +``` + +### Parse configuration as a stream + +`load_dotenv` and `dotenv_values` accept [streams][python_streams] via their `stream` +argument. It is thus possible to load the variables from sources other than the +filesystem (e.g. the network). + +```python +from io import StringIO + +from dotenv import load_dotenv + +config = StringIO("USER=foo\nEMAIL=foo@example.org") +load_dotenv(stream=config) +``` + +### Load .env files in IPython + +You can use dotenv in IPython. By default, it will use `find_dotenv` to search for a +`.env` file: + +```python +%load_ext dotenv +%dotenv +``` + +You can also specify a path: + +```python +%dotenv relative/or/absolute/path/to/.env +``` + +Optional flags: + +- `-o` to override existing variables. +- `-v` for increased verbosity. + +## Command-line Interface + +A CLI interface `dotenv` is also included, which helps you manipulate the `.env` file +without manually opening it. + +```shell +$ pip install "python-dotenv[cli]" +$ dotenv set USER foo +$ dotenv set EMAIL foo@example.org +$ dotenv list +USER=foo +EMAIL=foo@example.org +$ dotenv list --format=json +{ + "USER": "foo", + "EMAIL": "foo@example.org" +} +$ dotenv run -- python foo.py +``` + +Run `dotenv --help` for more information about the options and subcommands. + +## File format + +The format is not formally specified and still improves over time. That being said, +`.env` files should mostly look like Bash files. + +Keys can be unquoted or single-quoted. Values can be unquoted, single- or double-quoted. +Spaces before and after keys, equal signs, and values are ignored. Values can be followed +by a comment. Lines can start with the `export` directive, which does not affect their +interpretation. + +Allowed escape sequences: + +- in single-quoted values: `\\`, `\'` +- in double-quoted values: `\\`, `\'`, `\"`, `\a`, `\b`, `\f`, `\n`, `\r`, `\t`, `\v` + +### Multiline values + +It is possible for single- or double-quoted values to span multiple lines. The following +examples are equivalent: + +```bash +FOO="first line +second line" +``` + +```bash +FOO="first line\nsecond line" +``` + +### Variable without a value + +A variable can have no value: + +```bash +FOO +``` + +It results in `dotenv_values` associating that variable name with the value `None` (e.g. +`{"FOO": None}`. `load_dotenv`, on the other hand, simply ignores such variables. + +This shouldn't be confused with `FOO=`, in which case the variable is associated with the +empty string. + +### Variable expansion + +Python-dotenv can interpolate variables using POSIX variable expansion. + +With `load_dotenv(override=True)` or `dotenv_values()`, the value of a variable is the +first of the values defined in the following list: + +- Value of that variable in the `.env` file. +- Value of that variable in the environment. +- Default value, if provided. +- Empty string. + +With `load_dotenv(override=False)`, the value of a variable is the first of the values +defined in the following list: + +- Value of that variable in the environment. +- Value of that variable in the `.env` file. +- Default value, if provided. +- Empty string. + +## Related Projects + +- [Honcho](https://github.com/nickstenning/honcho) - For managing + Procfile-based applications. +- [django-dotenv](https://github.com/jpadilla/django-dotenv) +- [django-environ](https://github.com/joke2k/django-environ) +- [django-environ-2](https://github.com/sergeyklay/django-environ-2) +- [django-configuration](https://github.com/jezdez/django-configurations) +- [dump-env](https://github.com/sobolevn/dump-env) +- [environs](https://github.com/sloria/environs) +- [dynaconf](https://github.com/rochacbruno/dynaconf) +- [parse_it](https://github.com/naorlivne/parse_it) +- [python-decouple](https://github.com/HBNetwork/python-decouple) + +## Acknowledgements + +This project is currently maintained by [Saurabh Kumar](https://saurabh-kumar.com) and +[Bertrand Bonnefoy-Claudet](https://github.com/bbc2) and would not have been possible +without the support of these [awesome +people](https://github.com/theskumar/python-dotenv/graphs/contributors). + +[build_status_badge]: https://github.com/theskumar/python-dotenv/actions/workflows/test.yml/badge.svg +[build_status_link]: https://github.com/theskumar/python-dotenv/actions/workflows/test.yml +[pypi_badge]: https://badge.fury.io/py/python-dotenv.svg +[pypi_link]: https://badge.fury.io/py/python-dotenv +[python_streams]: https://docs.python.org/3/library/io.html + +# Changelog + +All notable changes to this project will be documented in this file. + +The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this +project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). + +## [1.0.1] - 2024-01-23 + +**Fixed** + +* Gracefully handle code which has been imported from a zipfile ([#456] by [@samwyma]) +* Allow modules using load_dotenv to be reloaded when launched in a separate thread ([#497] by [@freddyaboulton]) +* Fix file not closed after deletion, handle error in the rewrite function ([#469] by [@Qwerty-133]) + +**Misc** +* Use pathlib.Path in tests ([#466] by [@eumiro]) +* Fix year in release date in changelog.md ([#454] by [@jankislinger]) +* Use https in README links ([#474] by [@Nicals]) + +## [1.0.0] - 2023-02-24 + +**Fixed** + +* Drop support for python 3.7, add python 3.12-dev (#449 by [@theskumar]) +* Handle situations where the cwd does not exist. (#446 by [@jctanner]) + +## [0.21.1] - 2023-01-21 + +**Added** + +* Use Python 3.11 non-beta in CI (#438 by [@bbc2]) +* Modernize variables code (#434 by [@Nougat-Waffle]) +* Modernize main.py and parser.py code (#435 by [@Nougat-Waffle]) +* Improve conciseness of cli.py and __init__.py (#439 by [@Nougat-Waffle]) +* Improve error message for `get` and `list` commands when env file can't be opened (#441 by [@bbc2]) +* Updated License to align with BSD OSI template (#433 by [@lsmith77]) + + +**Fixed** + +* Fix Out-of-scope error when "dest" variable is undefined (#413 by [@theGOTOguy]) +* Fix IPython test warning about deprecated `magic` (#440 by [@bbc2]) +* Fix type hint for dotenv_path var, add StrPath alias (#432 by [@eaf]) + +## [0.21.0] - 2022-09-03 + +**Added** + +* CLI: add support for invocations via 'python -m'. (#395 by [@theskumar]) +* `load_dotenv` function now returns `False`. (#388 by [@larsks]) +* CLI: add --format= option to list command. (#407 by [@sammck]) + +**Fixed** + +* Drop Python 3.5 and 3.6 and upgrade GA (#393 by [@eggplants]) +* Use `open` instead of `io.open`. (#389 by [@rabinadk1]) +* Improve documentation for variables without a value (#390 by [@bbc2]) +* Add `parse_it` to Related Projects (#410 by [@naorlivne]) +* Update README.md (#415 by [@harveer07]) +* Improve documentation with direct use of MkDocs (#398 by [@bbc2]) + +## [0.20.0] - 2022-03-24 + +**Added** + +- Add `encoding` (`Optional[str]`) parameter to `get_key`, `set_key` and `unset_key`. + (#379 by [@bbc2]) + +**Fixed** + +- Use dict to specify the `entry_points` parameter of `setuptools.setup` (#376 by + [@mgorny]). +- Don't build universal wheels (#387 by [@bbc2]). + +## [0.19.2] - 2021-11-11 + +**Fixed** + +- In `set_key`, add missing newline character before new entry if necessary. (#361 by + [@bbc2]) + +## [0.19.1] - 2021-08-09 + +**Added** + +- Add support for Python 3.10. (#359 by [@theskumar]) + +## [0.19.0] - 2021-07-24 + +**Changed** + +- Require Python 3.5 or a later version. Python 2 and 3.4 are no longer supported. (#341 + by [@bbc2]). + +**Added** + +- The `dotenv_path` argument of `set_key` and `unset_key` now has a type of `Union[str, + os.PathLike]` instead of just `os.PathLike` (#347 by [@bbc2]). +- The `stream` argument of `load_dotenv` and `dotenv_values` can now be a text stream + (`IO[str]`), which includes values like `io.StringIO("foo")` and `open("file.env", + "r")` (#348 by [@bbc2]). + +## [0.18.0] - 2021-06-20 + +**Changed** + +- Raise `ValueError` if `quote_mode` isn't one of `always`, `auto` or `never` in + `set_key` (#330 by [@bbc2]). +- When writing a value to a .env file with `set_key` or `dotenv set ` (#330 + by [@bbc2]): + - Use single quotes instead of double quotes. + - Don't strip surrounding quotes. + - In `auto` mode, don't add quotes if the value is only made of alphanumeric characters + (as determined by `string.isalnum`). + +## [0.17.1] - 2021-04-29 + +**Fixed** + +- Fixed tests for build environments relying on `PYTHONPATH` (#318 by [@befeleme]). + +## [0.17.0] - 2021-04-02 + +**Changed** + +- Make `dotenv get ` only show the value, not `key=value` (#313 by [@bbc2]). + +**Added** + +- Add `--override`/`--no-override` option to `dotenv run` (#312 by [@zueve] and [@bbc2]). + +## [0.16.0] - 2021-03-27 + +**Changed** + +- The default value of the `encoding` parameter for `load_dotenv` and `dotenv_values` is + now `"utf-8"` instead of `None` (#306 by [@bbc2]). +- Fix resolution order in variable expansion with `override=False` (#287 by [@bbc2]). + +## [0.15.0] - 2020-10-28 + +**Added** + +- Add `--export` option to `set` to make it prepend the binding with `export` (#270 by + [@jadutter]). + +**Changed** + +- Make `set` command create the `.env` file in the current directory if no `.env` file was + found (#270 by [@jadutter]). + +**Fixed** + +- Fix potentially empty expanded value for duplicate key (#260 by [@bbc2]). +- Fix import error on Python 3.5.0 and 3.5.1 (#267 by [@gongqingkui]). +- Fix parsing of unquoted values containing several adjacent space or tab characters + (#277 by [@bbc2], review by [@x-yuri]). + +## [0.14.0] - 2020-07-03 + +**Changed** + +- Privilege definition in file over the environment in variable expansion (#256 by + [@elbehery95]). + +**Fixed** + +- Improve error message for when file isn't found (#245 by [@snobu]). +- Use HTTPS URL in package meta data (#251 by [@ekohl]). + +## [0.13.0] - 2020-04-16 + +**Added** + +- Add support for a Bash-like default value in variable expansion (#248 by [@bbc2]). + +## [0.12.0] - 2020-02-28 + +**Changed** + +- Use current working directory to find `.env` when bundled by PyInstaller (#213 by + [@gergelyk]). + +**Fixed** + +- Fix escaping of quoted values written by `set_key` (#236 by [@bbc2]). +- Fix `dotenv run` crashing on environment variables without values (#237 by [@yannham]). +- Remove warning when last line is empty (#238 by [@bbc2]). + +## [0.11.0] - 2020-02-07 + +**Added** + +- Add `interpolate` argument to `load_dotenv` and `dotenv_values` to disable interpolation + (#232 by [@ulyssessouza]). + +**Changed** + +- Use logging instead of warnings (#231 by [@bbc2]). + +**Fixed** + +- Fix installation in non-UTF-8 environments (#225 by [@altendky]). +- Fix PyPI classifiers (#228 by [@bbc2]). + +## [0.10.5] - 2020-01-19 + +**Fixed** + +- Fix handling of malformed lines and lines without a value (#222 by [@bbc2]): + - Don't print warning when key has no value. + - Reject more malformed lines (e.g. "A: B", "a='b',c"). +- Fix handling of lines with just a comment (#224 by [@bbc2]). + +## [0.10.4] - 2020-01-17 + +**Added** + +- Make typing optional (#179 by [@techalchemy]). +- Print a warning on malformed line (#211 by [@bbc2]). +- Support keys without a value (#220 by [@ulyssessouza]). + +## 0.10.3 + +- Improve interactive mode detection ([@andrewsmith])([#183]). +- Refactor parser to fix parsing inconsistencies ([@bbc2])([#170]). + - Interpret escapes as control characters only in double-quoted strings. + - Interpret `#` as start of comment only if preceded by whitespace. + +## 0.10.2 + +- Add type hints and expose them to users ([@qnighy])([#172]) +- `load_dotenv` and `dotenv_values` now accept an `encoding` parameter, defaults to `None` + ([@theskumar])([@earlbread])([#161]) +- Fix `str`/`unicode` inconsistency in Python 2: values are always `str` now. ([@bbc2])([#121]) +- Fix Unicode error in Python 2, introduced in 0.10.0. ([@bbc2])([#176]) + +## 0.10.1 +- Fix parsing of variable without a value ([@asyncee])([@bbc2])([#158]) + +## 0.10.0 + +- Add support for UTF-8 in unquoted values ([@bbc2])([#148]) +- Add support for trailing comments ([@bbc2])([#148]) +- Add backslashes support in values ([@bbc2])([#148]) +- Add support for newlines in values ([@bbc2])([#148]) +- Force environment variables to str with Python2 on Windows ([@greyli]) +- Drop Python 3.3 support ([@greyli]) +- Fix stderr/-out/-in redirection ([@venthur]) + + +## 0.9.0 + +- Add `--version` parameter to cli ([@venthur]) +- Enable loading from current directory ([@cjauvin]) +- Add 'dotenv run' command for calling arbitrary shell script with .env ([@venthur]) + +## 0.8.1 + +- Add tests for docs ([@Flimm]) +- Make 'cli' support optional. Use `pip install python-dotenv[cli]`. ([@theskumar]) + +## 0.8.0 + +- `set_key` and `unset_key` only modified the affected file instead of + parsing and re-writing file, this causes comments and other file + entact as it is. +- Add support for `export` prefix in the line. +- Internal refractoring ([@theskumar]) +- Allow `load_dotenv` and `dotenv_values` to work with `StringIO())` ([@alanjds])([@theskumar])([#78]) + +## 0.7.1 + +- Remove hard dependency on iPython ([@theskumar]) + +## 0.7.0 + +- Add support to override system environment variable via .env. + ([@milonimrod](https://github.com/milonimrod)) + ([\#63](https://github.com/theskumar/python-dotenv/issues/63)) +- Disable ".env not found" warning by default + ([@maxkoryukov](https://github.com/maxkoryukov)) + ([\#57](https://github.com/theskumar/python-dotenv/issues/57)) + +## 0.6.5 + +- Add support for special characters `\`. + ([@pjona](https://github.com/pjona)) + ([\#60](https://github.com/theskumar/python-dotenv/issues/60)) + +## 0.6.4 + +- Fix issue with single quotes ([@Flimm]) + ([\#52](https://github.com/theskumar/python-dotenv/issues/52)) + +## 0.6.3 + +- Handle unicode exception in setup.py + ([\#46](https://github.com/theskumar/python-dotenv/issues/46)) + +## 0.6.2 + +- Fix dotenv list command ([@ticosax](https://github.com/ticosax)) +- Add iPython Support + ([@tillahoffmann](https://github.com/tillahoffmann)) + +## 0.6.0 + +- Drop support for Python 2.6 +- Handle escaped characters and newlines in quoted values. (Thanks + [@iameugenejo](https://github.com/iameugenejo)) +- Remove any spaces around unquoted key/value. (Thanks + [@paulochf](https://github.com/paulochf)) +- Added POSIX variable expansion. (Thanks + [@hugochinchilla](https://github.com/hugochinchilla)) + +## 0.5.1 + +- Fix find\_dotenv - it now start search from the file where this + function is called from. + +## 0.5.0 + +- Add `find_dotenv` method that will try to find a `.env` file. + (Thanks [@isms](https://github.com/isms)) + +## 0.4.0 + +- cli: Added `-q/--quote` option to control the behaviour of quotes + around values in `.env`. (Thanks + [@hugochinchilla](https://github.com/hugochinchilla)). +- Improved test coverage. + +[#78]: https://github.com/theskumar/python-dotenv/issues/78 +[#121]: https://github.com/theskumar/python-dotenv/issues/121 +[#148]: https://github.com/theskumar/python-dotenv/issues/148 +[#158]: https://github.com/theskumar/python-dotenv/issues/158 +[#170]: https://github.com/theskumar/python-dotenv/issues/170 +[#172]: https://github.com/theskumar/python-dotenv/issues/172 +[#176]: https://github.com/theskumar/python-dotenv/issues/176 +[#183]: https://github.com/theskumar/python-dotenv/issues/183 +[#359]: https://github.com/theskumar/python-dotenv/issues/359 +[#469]: https://github.com/theskumar/python-dotenv/issues/469 +[#456]: https://github.com/theskumar/python-dotenv/issues/456 +[#466]: https://github.com/theskumar/python-dotenv/issues/466 +[#454]: https://github.com/theskumar/python-dotenv/issues/454 +[#474]: https://github.com/theskumar/python-dotenv/issues/474 + +[@alanjds]: https://github.com/alanjds +[@altendky]: https://github.com/altendky +[@andrewsmith]: https://github.com/andrewsmith +[@asyncee]: https://github.com/asyncee +[@bbc2]: https://github.com/bbc2 +[@befeleme]: https://github.com/befeleme +[@cjauvin]: https://github.com/cjauvin +[@eaf]: https://github.com/eaf +[@earlbread]: https://github.com/earlbread +[@eggplants]: https://github.com/@eggplants +[@ekohl]: https://github.com/ekohl +[@elbehery95]: https://github.com/elbehery95 +[@eumiro]: https://github.com/eumiro +[@Flimm]: https://github.com/Flimm +[@freddyaboulton]: https://github.com/freddyaboulton +[@gergelyk]: https://github.com/gergelyk +[@gongqingkui]: https://github.com/gongqingkui +[@greyli]: https://github.com/greyli +[@harveer07]: https://github.com/@harveer07 +[@jadutter]: https://github.com/jadutter +[@jankislinger]: https://github.com/jankislinger +[@jctanner]: https://github.com/jctanner +[@larsks]: https://github.com/@larsks +[@lsmith77]: https://github.com/lsmith77 +[@mgorny]: https://github.com/mgorny +[@naorlivne]: https://github.com/@naorlivne +[@Nicals]: https://github.com/Nicals +[@Nougat-Waffle]: https://github.com/Nougat-Waffle +[@qnighy]: https://github.com/qnighy +[@Qwerty-133]: https://github.com/Qwerty-133 +[@rabinadk1]: https://github.com/@rabinadk1 +[@sammck]: https://github.com/@sammck +[@samwyma]: https://github.com/samwyma +[@snobu]: https://github.com/snobu +[@techalchemy]: https://github.com/techalchemy +[@theGOTOguy]: https://github.com/theGOTOguy +[@theskumar]: https://github.com/theskumar +[@ulyssessouza]: https://github.com/ulyssessouza +[@venthur]: https://github.com/venthur +[@x-yuri]: https://github.com/x-yuri +[@yannham]: https://github.com/yannham +[@zueve]: https://github.com/zueve + + +[Unreleased]: https://github.com/theskumar/python-dotenv/compare/v1.0.1...HEAD +[1.0.1]: https://github.com/theskumar/python-dotenv/compare/v1.0.0...v1.0.1 +[1.0.0]: https://github.com/theskumar/python-dotenv/compare/v0.21.0...v1.0.0 +[0.21.1]: https://github.com/theskumar/python-dotenv/compare/v0.21.0...v0.21.1 +[0.21.0]: https://github.com/theskumar/python-dotenv/compare/v0.20.0...v0.21.0 +[0.20.0]: https://github.com/theskumar/python-dotenv/compare/v0.19.2...v0.20.0 +[0.19.2]: https://github.com/theskumar/python-dotenv/compare/v0.19.1...v0.19.2 +[0.19.1]: https://github.com/theskumar/python-dotenv/compare/v0.19.0...v0.19.1 +[0.19.0]: https://github.com/theskumar/python-dotenv/compare/v0.18.0...v0.19.0 +[0.18.0]: https://github.com/theskumar/python-dotenv/compare/v0.17.1...v0.18.0 +[0.17.1]: https://github.com/theskumar/python-dotenv/compare/v0.17.0...v0.17.1 +[0.17.0]: https://github.com/theskumar/python-dotenv/compare/v0.16.0...v0.17.0 +[0.16.0]: https://github.com/theskumar/python-dotenv/compare/v0.15.0...v0.16.0 +[0.15.0]: https://github.com/theskumar/python-dotenv/compare/v0.14.0...v0.15.0 +[0.14.0]: https://github.com/theskumar/python-dotenv/compare/v0.13.0...v0.14.0 +[0.13.0]: https://github.com/theskumar/python-dotenv/compare/v0.12.0...v0.13.0 +[0.12.0]: https://github.com/theskumar/python-dotenv/compare/v0.11.0...v0.12.0 +[0.11.0]: https://github.com/theskumar/python-dotenv/compare/v0.10.5...v0.11.0 +[0.10.5]: https://github.com/theskumar/python-dotenv/compare/v0.10.4...v0.10.5 +[0.10.4]: https://github.com/theskumar/python-dotenv/compare/v0.10.3...v0.10.4 diff --git a/.venv/Lib/site-packages/python_dotenv-1.0.1.dist-info/RECORD b/.venv/Lib/site-packages/python_dotenv-1.0.1.dist-info/RECORD new file mode 100644 index 00000000..a63bcb9d --- /dev/null +++ b/.venv/Lib/site-packages/python_dotenv-1.0.1.dist-info/RECORD @@ -0,0 +1,26 @@ +../../Scripts/dotenv.exe,sha256=YRCcm6xHhO6_Iiv70DTSA7aqz5sdqV3_qn_5sRFlj30,108410 +dotenv/__init__.py,sha256=WBU5SfSiKAhS3hzu17ykNuuwbuwyDCX91Szv4vUeOuM,1292 +dotenv/__main__.py,sha256=N0RhLG7nHIqtlJHwwepIo-zbJPNx9sewCCRGY528h_4,129 +dotenv/__pycache__/__init__.cpython-311.pyc,, +dotenv/__pycache__/__main__.cpython-311.pyc,, +dotenv/__pycache__/cli.cpython-311.pyc,, +dotenv/__pycache__/ipython.cpython-311.pyc,, +dotenv/__pycache__/main.cpython-311.pyc,, +dotenv/__pycache__/parser.cpython-311.pyc,, +dotenv/__pycache__/variables.cpython-311.pyc,, +dotenv/__pycache__/version.cpython-311.pyc,, +dotenv/cli.py,sha256=_ttQuR9Yl4k1PT53ByISkDjJ3kO_N_LzIDZzZ95uXEk,5809 +dotenv/ipython.py,sha256=avI6aez_RxnBptYgchIquF2TSgKI-GOhY3ppiu3VuWE,1303 +dotenv/main.py,sha256=GV7Ki6JYPDa-xy2ZXHKqER-bRvKa7qqh0G0OwffYJr8,12098 +dotenv/parser.py,sha256=QgU5HwMwM2wMqt0vz6dHTJ4nzPmwqRqvi4MSyeVifgU,5186 +dotenv/py.typed,sha256=8PjyZ1aVoQpRVvt71muvuq5qE-jTFZkK-GLHkhdebmc,26 +dotenv/variables.py,sha256=CD0qXOvvpB3q5RpBQMD9qX6vHX7SyW-SuiwGMFSlt08,2348 +dotenv/version.py,sha256=d4QHYmS_30j0hPN8NmNPnQ_Z0TphDRbu4MtQj9cT9e8,22 +python_dotenv-1.0.1.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4 +python_dotenv-1.0.1.dist-info/LICENSE,sha256=gGGbcEnwjIFoOtDgHwjyV6hAZS3XHugxRtNmWMfSwrk,1556 +python_dotenv-1.0.1.dist-info/METADATA,sha256=fCkcTEUG3zknbuN1BK8e0PPCIgvPBLk-LneK0mRDM_s,23170 +python_dotenv-1.0.1.dist-info/RECORD,, +python_dotenv-1.0.1.dist-info/REQUESTED,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +python_dotenv-1.0.1.dist-info/WHEEL,sha256=oiQVh_5PnQM0E3gPdiz09WCNmwiHDMaGer_elqB3coM,92 +python_dotenv-1.0.1.dist-info/entry_points.txt,sha256=yRl1rCbswb1nQTQ_gZRlCw5QfabztUGnfGWLhlXFNdI,47 +python_dotenv-1.0.1.dist-info/top_level.txt,sha256=eyqUH4SHJNr6ahOYlxIunTr4XinE8Z5ajWLdrK3r0D8,7 diff --git a/.venv/Lib/site-packages/python_dotenv-1.0.1.dist-info/REQUESTED b/.venv/Lib/site-packages/python_dotenv-1.0.1.dist-info/REQUESTED new file mode 100644 index 00000000..e69de29b diff --git a/.venv/Lib/site-packages/python_dotenv-1.0.1.dist-info/WHEEL b/.venv/Lib/site-packages/python_dotenv-1.0.1.dist-info/WHEEL new file mode 100644 index 00000000..98c0d20b --- /dev/null +++ b/.venv/Lib/site-packages/python_dotenv-1.0.1.dist-info/WHEEL @@ -0,0 +1,5 @@ +Wheel-Version: 1.0 +Generator: bdist_wheel (0.42.0) +Root-Is-Purelib: true +Tag: py3-none-any + diff --git a/.venv/Lib/site-packages/python_dotenv-1.0.1.dist-info/entry_points.txt b/.venv/Lib/site-packages/python_dotenv-1.0.1.dist-info/entry_points.txt new file mode 100644 index 00000000..0a868232 --- /dev/null +++ b/.venv/Lib/site-packages/python_dotenv-1.0.1.dist-info/entry_points.txt @@ -0,0 +1,2 @@ +[console_scripts] +dotenv = dotenv.__main__:cli diff --git a/.venv/Lib/site-packages/python_dotenv-1.0.1.dist-info/top_level.txt b/.venv/Lib/site-packages/python_dotenv-1.0.1.dist-info/top_level.txt new file mode 100644 index 00000000..fe7c01aa --- /dev/null +++ b/.venv/Lib/site-packages/python_dotenv-1.0.1.dist-info/top_level.txt @@ -0,0 +1 @@ +dotenv diff --git a/.venv/Lib/site-packages/redis-5.0.1.dist-info/INSTALLER b/.venv/Lib/site-packages/redis-5.0.1.dist-info/INSTALLER new file mode 100644 index 00000000..a1b589e3 --- /dev/null +++ b/.venv/Lib/site-packages/redis-5.0.1.dist-info/INSTALLER @@ -0,0 +1 @@ +pip diff --git a/.venv/Lib/site-packages/redis-5.0.1.dist-info/LICENSE b/.venv/Lib/site-packages/redis-5.0.1.dist-info/LICENSE new file mode 100644 index 00000000..8509ccd6 --- /dev/null +++ b/.venv/Lib/site-packages/redis-5.0.1.dist-info/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2022-2023, Redis, inc. + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/.venv/Lib/site-packages/redis-5.0.1.dist-info/METADATA b/.venv/Lib/site-packages/redis-5.0.1.dist-info/METADATA new file mode 100644 index 00000000..98242bc0 --- /dev/null +++ b/.venv/Lib/site-packages/redis-5.0.1.dist-info/METADATA @@ -0,0 +1,203 @@ +Metadata-Version: 2.1 +Name: redis +Version: 5.0.1 +Summary: Python client for Redis database and key-value store +Home-page: https://github.com/redis/redis-py +Author: Redis Inc. +Author-email: oss@redis.com +License: MIT +Project-URL: Documentation, https://redis.readthedocs.io/en/latest/ +Project-URL: Changes, https://github.com/redis/redis-py/releases +Project-URL: Code, https://github.com/redis/redis-py +Project-URL: Issue tracker, https://github.com/redis/redis-py/issues +Keywords: Redis,key-value store,database +Platform: UNKNOWN +Classifier: Development Status :: 5 - Production/Stable +Classifier: Environment :: Console +Classifier: Intended Audience :: Developers +Classifier: License :: OSI Approved :: MIT License +Classifier: Operating System :: OS Independent +Classifier: Programming Language :: Python +Classifier: Programming Language :: Python :: 3 +Classifier: Programming Language :: Python :: 3 :: Only +Classifier: Programming Language :: Python :: 3.7 +Classifier: Programming Language :: Python :: 3.8 +Classifier: Programming Language :: Python :: 3.9 +Classifier: Programming Language :: Python :: 3.10 +Classifier: Programming Language :: Python :: 3.11 +Classifier: Programming Language :: Python :: Implementation :: CPython +Classifier: Programming Language :: Python :: Implementation :: PyPy +Requires-Python: >=3.7 +Description-Content-Type: text/markdown +License-File: LICENSE +Requires-Dist: async-timeout >=4.0.2 ; python_full_version <= "3.11.2" +Requires-Dist: importlib-metadata >=1.0 ; python_version < "3.8" +Requires-Dist: typing-extensions ; python_version < "3.8" +Provides-Extra: hiredis +Requires-Dist: hiredis >=1.0.0 ; extra == 'hiredis' +Provides-Extra: ocsp +Requires-Dist: cryptography >=36.0.1 ; extra == 'ocsp' +Requires-Dist: pyopenssl ==20.0.1 ; extra == 'ocsp' +Requires-Dist: requests >=2.26.0 ; extra == 'ocsp' + +# redis-py + +The Python interface to the Redis key-value store. + +[![CI](https://github.com/redis/redis-py/workflows/CI/badge.svg?branch=master)](https://github.com/redis/redis-py/actions?query=workflow%3ACI+branch%3Amaster) +[![docs](https://readthedocs.org/projects/redis/badge/?version=stable&style=flat)](https://redis-py.readthedocs.io/en/stable/) +[![MIT licensed](https://img.shields.io/badge/license-MIT-blue.svg)](./LICENSE) +[![pypi](https://badge.fury.io/py/redis.svg)](https://pypi.org/project/redis/) +[![pre-release](https://img.shields.io/github/v/release/redis/redis-py?include_prereleases&label=latest-prerelease)](https://github.com/redis/redis-py/releases) +[![codecov](https://codecov.io/gh/redis/redis-py/branch/master/graph/badge.svg?token=yenl5fzxxr)](https://codecov.io/gh/redis/redis-py) + +[Installation](#installation) | [Usage](#usage) | [Advanced Topics](#advanced-topics) | [Contributing](https://github.com/redis/redis-py/blob/master/CONTRIBUTING.md) + +--------------------------------------------- + +**Note: ** redis-py 5.0 will be the last version of redis-py to support Python 3.7, as it has reached [end of life](https://devguide.python.org/versions/). redis-py 5.1 will support Python 3.8+. + +--------------------------------------------- + +## Installation + +Start a redis via docker: + +``` bash +docker run -p 6379:6379 -it redis/redis-stack:latest +``` + +To install redis-py, simply: + +``` bash +$ pip install redis +``` + +For faster performance, install redis with hiredis support, this provides a compiled response parser, and *for most cases* requires zero code changes. +By default, if hiredis >= 1.0 is available, redis-py will attempt to use it for response parsing. + +``` bash +$ pip install "redis[hiredis]" +``` + +Looking for a high-level library to handle object mapping? See [redis-om-python](https://github.com/redis/redis-om-python)! + +## Supported Redis Versions + +The most recent version of this library supports redis version [5.0](https://github.com/redis/redis/blob/5.0/00-RELEASENOTES), [6.0](https://github.com/redis/redis/blob/6.0/00-RELEASENOTES), [6.2](https://github.com/redis/redis/blob/6.2/00-RELEASENOTES), [7.0](https://github.com/redis/redis/blob/7.0/00-RELEASENOTES) and [7.2](https://github.com/redis/redis/blob/7.2/00-RELEASENOTES). + +The table below highlights version compatibility of the most-recent library versions and redis versions. + +| Library version | Supported redis versions | +|-----------------|-------------------| +| 3.5.3 | <= 6.2 Family of releases | +| >= 4.5.0 | Version 5.0 to 7.0 | +| >= 5.0.0 | Version 5.0 to current | + + +## Usage + +### Basic Example + +``` python +>>> import redis +>>> r = redis.Redis(host='localhost', port=6379, db=0) +>>> r.set('foo', 'bar') +True +>>> r.get('foo') +b'bar' +``` + +The above code connects to localhost on port 6379, sets a value in Redis, and retrieves it. All responses are returned as bytes in Python, to receive decoded strings, set *decode_responses=True*. For this, and more connection options, see [these examples](https://redis.readthedocs.io/en/stable/examples.html). + + +#### RESP3 Support +To enable support for RESP3, ensure you have at least version 5.0 of the client, and change your connection object to include *protocol=3* + +``` python +>>> import redis +>>> r = redis.Redis(host='localhost', port=6379, db=0, protocol=3) +``` + +### Connection Pools + +By default, redis-py uses a connection pool to manage connections. Each instance of a Redis class receives its own connection pool. You can however define your own [redis.ConnectionPool](https://redis.readthedocs.io/en/stable/connections.html#connection-pools). + +``` python +>>> pool = redis.ConnectionPool(host='localhost', port=6379, db=0) +>>> r = redis.Redis(connection_pool=pool) +``` + +Alternatively, you might want to look at [Async connections](https://redis.readthedocs.io/en/stable/examples/asyncio_examples.html), or [Cluster connections](https://redis.readthedocs.io/en/stable/connections.html#cluster-client), or even [Async Cluster connections](https://redis.readthedocs.io/en/stable/connections.html#async-cluster-client). + +### Redis Commands + +There is built-in support for all of the [out-of-the-box Redis commands](https://redis.io/commands). They are exposed using the raw Redis command names (`HSET`, `HGETALL`, etc.) except where a word (i.e. del) is reserved by the language. The complete set of commands can be found [here](https://github.com/redis/redis-py/tree/master/redis/commands), or [the documentation](https://redis.readthedocs.io/en/stable/commands.html). + +## Advanced Topics + +The [official Redis command documentation](https://redis.io/commands) +does a great job of explaining each command in detail. redis-py attempts +to adhere to the official command syntax. There are a few exceptions: + +- **MULTI/EXEC**: These are implemented as part of the Pipeline class. + The pipeline is wrapped with the MULTI and EXEC statements by + default when it is executed, which can be disabled by specifying + transaction=False. See more about Pipelines below. + +- **SUBSCRIBE/LISTEN**: Similar to pipelines, PubSub is implemented as + a separate class as it places the underlying connection in a state + where it can\'t execute non-pubsub commands. Calling the pubsub + method from the Redis client will return a PubSub instance where you + can subscribe to channels and listen for messages. You can only call + PUBLISH from the Redis client (see [this comment on issue + #151](https://github.com/redis/redis-py/issues/151#issuecomment-1545015) + for details). + +For more details, please see the documentation on [advanced topics page](https://redis.readthedocs.io/en/stable/advanced_features.html). + +### Pipelines + +The following is a basic example of a [Redis pipeline](https://redis.io/docs/manual/pipelining/), a method to optimize round-trip calls, by batching Redis commands, and receiving their results as a list. + + +``` python +>>> pipe = r.pipeline() +>>> pipe.set('foo', 5) +>>> pipe.set('bar', 18.5) +>>> pipe.set('blee', "hello world!") +>>> pipe.execute() +[True, True, True] +``` + +### PubSub + +The following example shows how to utilize [Redis Pub/Sub](https://redis.io/docs/manual/pubsub/) to subscribe to specific channels. + +``` python +>>> r = redis.Redis(...) +>>> p = r.pubsub() +>>> p.subscribe('my-first-channel', 'my-second-channel', ...) +>>> p.get_message() +{'pattern': None, 'type': 'subscribe', 'channel': b'my-second-channel', 'data': 1} +``` + + +-------------------------- + +### Author + +redis-py is developed and maintained by [Redis Inc](https://redis.com). It can be found [here]( +https://github.com/redis/redis-py), or downloaded from [pypi](https://pypi.org/project/redis/). + +Special thanks to: + +- Andy McCurdy () the original author of redis-py. +- Ludovico Magnocavallo, author of the original Python Redis client, + from which some of the socket code is still used. +- Alexander Solovyov for ideas on the generic response callback + system. +- Paul Hubbard for initial packaging support. + +[![Redis](./docs/logo-redis.png)](https://www.redis.com) + diff --git a/.venv/Lib/site-packages/redis-5.0.1.dist-info/RECORD b/.venv/Lib/site-packages/redis-5.0.1.dist-info/RECORD new file mode 100644 index 00000000..f6d8c78e --- /dev/null +++ b/.venv/Lib/site-packages/redis-5.0.1.dist-info/RECORD @@ -0,0 +1,147 @@ +redis-5.0.1.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4 +redis-5.0.1.dist-info/LICENSE,sha256=pXslClvwPXr-VbdAYzE_Ktt7ANVGwKsUmok5gzP-PMg,1074 +redis-5.0.1.dist-info/METADATA,sha256=xLwWid1Pns_mCEX6qn3qtFxtf7pphgPFPWOwEg5LWrQ,8910 +redis-5.0.1.dist-info/RECORD,, +redis-5.0.1.dist-info/WHEEL,sha256=yQN5g4mg4AybRjkgi-9yy4iQEFibGQmlz78Pik5Or-A,92 +redis-5.0.1.dist-info/top_level.txt,sha256=OMAefszlde6ZoOtlM35AWzpRIrwtcqAMHGlRit-w2-4,6 +redis/__init__.py,sha256=PthSOEfXKlYV9xBgroOnO2tJD7uu0BWwvztgsKUvK48,2110 +redis/__pycache__/__init__.cpython-311.pyc,, +redis/__pycache__/backoff.cpython-311.pyc,, +redis/__pycache__/client.cpython-311.pyc,, +redis/__pycache__/cluster.cpython-311.pyc,, +redis/__pycache__/compat.cpython-311.pyc,, +redis/__pycache__/connection.cpython-311.pyc,, +redis/__pycache__/crc.cpython-311.pyc,, +redis/__pycache__/credentials.cpython-311.pyc,, +redis/__pycache__/exceptions.cpython-311.pyc,, +redis/__pycache__/lock.cpython-311.pyc,, +redis/__pycache__/ocsp.cpython-311.pyc,, +redis/__pycache__/retry.cpython-311.pyc,, +redis/__pycache__/sentinel.cpython-311.pyc,, +redis/__pycache__/typing.cpython-311.pyc,, +redis/__pycache__/utils.cpython-311.pyc,, +redis/_parsers/__init__.py,sha256=qkfgV2X9iyvQAvbLdSelwgz0dCk9SGAosCvuZC9-qDc,550 +redis/_parsers/__pycache__/__init__.cpython-311.pyc,, +redis/_parsers/__pycache__/base.cpython-311.pyc,, +redis/_parsers/__pycache__/commands.cpython-311.pyc,, +redis/_parsers/__pycache__/encoders.cpython-311.pyc,, +redis/_parsers/__pycache__/helpers.cpython-311.pyc,, +redis/_parsers/__pycache__/hiredis.cpython-311.pyc,, +redis/_parsers/__pycache__/resp2.cpython-311.pyc,, +redis/_parsers/__pycache__/resp3.cpython-311.pyc,, +redis/_parsers/__pycache__/socket.cpython-311.pyc,, +redis/_parsers/base.py,sha256=95SoPNwt4xJQB-ONIjxsR46n4EHnxnmkv9f0ReZSIR0,7480 +redis/_parsers/commands.py,sha256=pmR4hl4u93UvCmeDgePHFc6pWDr4slrKEvCsdMmtj_M,11052 +redis/_parsers/encoders.py,sha256=X0jvTp-E4TZUlZxV5LJJ88TuVrF1vly5tuC0xjxGaSc,1734 +redis/_parsers/helpers.py,sha256=xcRjjns6uQPb2pp0AOlOK9LhMJL4ofyEMFqVA7CwzsE,27947 +redis/_parsers/hiredis.py,sha256=X8yk0ElEEjHlhUgjs9fdHSOijlxYtunTrTJSLzkGrvQ,7581 +redis/_parsers/resp2.py,sha256=f22kH-_ZP2iNtOn6xOe65MSy_fJpu8OEn1u_hgeeojI,4813 +redis/_parsers/resp3.py,sha256=rXDA0R-wjCj2vyGaaWEf50NXN7UFBzefRnK3NGzWz2E,9657 +redis/_parsers/socket.py,sha256=CKD8QW_wFSNlIZzxlbNduaGpiv0I8wBcsGuAIojDfJg,5403 +redis/asyncio/__init__.py,sha256=uoDD8XYVi0Kj6mcufYwLDUTQXmBRx7a0bhKF9stZr7I,1489 +redis/asyncio/__pycache__/__init__.cpython-311.pyc,, +redis/asyncio/__pycache__/client.cpython-311.pyc,, +redis/asyncio/__pycache__/cluster.cpython-311.pyc,, +redis/asyncio/__pycache__/connection.cpython-311.pyc,, +redis/asyncio/__pycache__/lock.cpython-311.pyc,, +redis/asyncio/__pycache__/retry.cpython-311.pyc,, +redis/asyncio/__pycache__/sentinel.cpython-311.pyc,, +redis/asyncio/__pycache__/utils.cpython-311.pyc,, +redis/asyncio/client.py,sha256=BYurDT13lsw0N3a8sLqQFl00tFFolpET7_EujLw2Nbc,58826 +redis/asyncio/cluster.py,sha256=a0Za2icr03ytjF_WVohDMvEZejixUdVMhpsKWeMxYHY,63076 +redis/asyncio/connection.py,sha256=ZwClasZ2x0SQY90gDZvraFIx2lhGPnDm-xUUPPsb424,43426 +redis/asyncio/lock.py,sha256=lLasXEO2E1CskhX5ZZoaSGpmwZP1Q782R3HAUNG3wD4,11967 +redis/asyncio/retry.py,sha256=SnPPOlo5gcyIFtkC4DY7HFvmDgUaILsJ3DeHioogdB8,2219 +redis/asyncio/sentinel.py,sha256=sTVJCbi1KtIbHJc3fkHRZb_LGav_UtCAq-ipxltkGsE,14198 +redis/asyncio/utils.py,sha256=Yxc5YQumhLjtDDwCS4mgxI6yy2Z21AzLlFxVbxCohic,704 +redis/backoff.py,sha256=x-sAjV7u4MmdOjFZSZ8RnUnCaQtPhCBbGNBgICvCW3I,2966 +redis/client.py,sha256=IkqYEPg2WA35jBjPCpEgcKcVW3Hx8lm89j_IQ2dnoOw,57514 +redis/cluster.py,sha256=HcH2YM057xpWMQhGYBLWv5l9yrb7hzcSuPXXbqJl_DY,92754 +redis/commands/__init__.py,sha256=cTUH-MGvaLYS0WuoytyqtN1wniw2A1KbkUXcpvOSY3I,576 +redis/commands/__pycache__/__init__.cpython-311.pyc,, +redis/commands/__pycache__/cluster.cpython-311.pyc,, +redis/commands/__pycache__/core.cpython-311.pyc,, +redis/commands/__pycache__/helpers.cpython-311.pyc,, +redis/commands/__pycache__/redismodules.cpython-311.pyc,, +redis/commands/__pycache__/sentinel.cpython-311.pyc,, +redis/commands/bf/__init__.py,sha256=ESmQXH4p9Dp37tNCwQGDiF_BHDEaKnXSF7ZfASEqkFY,8027 +redis/commands/bf/__pycache__/__init__.cpython-311.pyc,, +redis/commands/bf/__pycache__/commands.cpython-311.pyc,, +redis/commands/bf/__pycache__/info.cpython-311.pyc,, +redis/commands/bf/commands.py,sha256=kVWUatdS0zLcu8-fVIqLLQBU5u8fJWIOCVUD3fqYVp0,21462 +redis/commands/bf/info.py,sha256=tpE4hv1zApxoOgyV9_8BEDZcl4Wf6tS1dSvtlxV7uTE,3395 +redis/commands/cluster.py,sha256=5BDwdeUnWVWOalF5fHD12HPQeDq_rc2vhuCI3sChrYE,31562 +redis/commands/core.py,sha256=2WM9nZ3f0Xqny8o5yucORe0fLRItJO4SWU68W5Wr1mw,223552 +redis/commands/graph/__init__.py,sha256=NmklyOuzIa20yEWrhnKQxgQlaXKYkcwBkGHpvQyo5J8,7237 +redis/commands/graph/__pycache__/__init__.cpython-311.pyc,, +redis/commands/graph/__pycache__/commands.cpython-311.pyc,, +redis/commands/graph/__pycache__/edge.cpython-311.pyc,, +redis/commands/graph/__pycache__/exceptions.cpython-311.pyc,, +redis/commands/graph/__pycache__/execution_plan.cpython-311.pyc,, +redis/commands/graph/__pycache__/node.cpython-311.pyc,, +redis/commands/graph/__pycache__/path.cpython-311.pyc,, +redis/commands/graph/__pycache__/query_result.cpython-311.pyc,, +redis/commands/graph/commands.py,sha256=rLGV58ZJKEf6yxzk1oD3IwiS03lP6bpbo0249pFI0OY,10379 +redis/commands/graph/edge.py,sha256=_TljVB4a1pPS9pb8_Cvw8rclbBOOI__-fY9fybU4djQ,2460 +redis/commands/graph/exceptions.py,sha256=kRDBsYLgwIaM4vqioO_Bp_ugWvjfqCH7DIv4Gpc9HCM,107 +redis/commands/graph/execution_plan.py,sha256=Pxr8_zhPWT_EdZSgGrbiWw8wFL6q5JF7O-Z6Xzm55iw,6742 +redis/commands/graph/node.py,sha256=Pasfsl5dF6WqT9KCNFAKKwGubyK_2ORCoAQE4VtnXkQ,2400 +redis/commands/graph/path.py,sha256=m6Gz4DYfMIQ8VReDLHlnQw_KI2rVdepWYk_AU0_x_GM,2080 +redis/commands/graph/query_result.py,sha256=GTEnBE0rAiUk4JquaxcVKdL1kzSMDWW5ky-iFTvRN84,17040 +redis/commands/helpers.py,sha256=WgfhdH3NCBW2Vqg-9PcP2EIKwzBkzb5CeqfdnPm2tTQ,4531 +redis/commands/json/__init__.py,sha256=llpDQz2kBNnJyfQfuh0-2oY-knMb6gAS0ADtPmaTKsM,4854 +redis/commands/json/__pycache__/__init__.cpython-311.pyc,, +redis/commands/json/__pycache__/_util.cpython-311.pyc,, +redis/commands/json/__pycache__/commands.cpython-311.pyc,, +redis/commands/json/__pycache__/decoders.cpython-311.pyc,, +redis/commands/json/__pycache__/path.cpython-311.pyc,, +redis/commands/json/_util.py,sha256=b_VQTh10FyLl8BtREfJfDagOJCyd6wTQQs8g63pi5GI,116 +redis/commands/json/commands.py,sha256=9P3NBFyWuRxWer5i__NtJx7oJZNnTOisfrHGhwaRfoA,15603 +redis/commands/json/decoders.py,sha256=a_IoMV_wgeJyUifD4P6HTcM9s6FhricwmzQcZRmc-Gw,1411 +redis/commands/json/path.py,sha256=0zaO6_q_FVMk1Bkhkb7Wcr8AF2Tfr69VhkKy1IBVhpA,393 +redis/commands/redismodules.py,sha256=7TfVzLj319mhsA6WEybsOdIPk4pC-1hScJg3H5hv3T4,2454 +redis/commands/search/__init__.py,sha256=happQFVF0j7P87p7LQsUK5AK0kuem9cA-xvVRdQWpos,5744 +redis/commands/search/__pycache__/__init__.cpython-311.pyc,, +redis/commands/search/__pycache__/_util.cpython-311.pyc,, +redis/commands/search/__pycache__/aggregation.cpython-311.pyc,, +redis/commands/search/__pycache__/commands.cpython-311.pyc,, +redis/commands/search/__pycache__/document.cpython-311.pyc,, +redis/commands/search/__pycache__/field.cpython-311.pyc,, +redis/commands/search/__pycache__/indexDefinition.cpython-311.pyc,, +redis/commands/search/__pycache__/query.cpython-311.pyc,, +redis/commands/search/__pycache__/querystring.cpython-311.pyc,, +redis/commands/search/__pycache__/reducers.cpython-311.pyc,, +redis/commands/search/__pycache__/result.cpython-311.pyc,, +redis/commands/search/__pycache__/suggestion.cpython-311.pyc,, +redis/commands/search/_util.py,sha256=VAguSwh_3dNtJwNU6Vle2CNdPE10_NUkPffD7GWFX48,193 +redis/commands/search/aggregation.py,sha256=8yQ1P31Qiy29xehlmN2ToCh73e-MHmOg_y0_UXfQDS8,10772 +redis/commands/search/commands.py,sha256=dpSMZ7hXjbAlrUL4h5GX6BtP4WibQZCO6Ylfo8qkAF0,36751 +redis/commands/search/document.py,sha256=g2R-PRgq-jN33_GLXzavvse4cpIHBMfjPfPK7tnE9Gc,413 +redis/commands/search/field.py,sha256=WxtOHgtm9S82_C0nzeT7fHRrWPkGflJnSXQRIiaVJmU,4518 +redis/commands/search/indexDefinition.py,sha256=VL2CMzjxN0HEIaTn88evnHX1fCEmytbik4vAmiiYSC8,2489 +redis/commands/search/query.py,sha256=blBcgFnurT9rkg4gI6j14EekWU_J9e_aDlryVCCWDjM,11564 +redis/commands/search/querystring.py,sha256=dE577kOqkCErNgO-IXI4xFVHI8kQE-JiH5ZRI_CKjHE,7597 +redis/commands/search/reducers.py,sha256=Scceylx8BjyqS-TJOdhNW63n6tecL9ojt4U5Sqho5UY,4220 +redis/commands/search/result.py,sha256=4H7LnOVWScti7WO2XYxjhiTu3QNIt2pZHO1eptXZDBk,2149 +redis/commands/search/suggestion.py,sha256=V_re6suDCoNc0ETn_P1t51FeK4pCamPwxZRxCY8jscE,1612 +redis/commands/sentinel.py,sha256=hRcIQ9x9nEkdcCsJzo6Ves6vk-3tsfQqfJTT_v3oLY0,4110 +redis/commands/timeseries/__init__.py,sha256=gkz6wshEzzQQryBOnrAqqQzttS-AHfXmuN_H1J38EbM,3459 +redis/commands/timeseries/__pycache__/__init__.cpython-311.pyc,, +redis/commands/timeseries/__pycache__/commands.cpython-311.pyc,, +redis/commands/timeseries/__pycache__/info.cpython-311.pyc,, +redis/commands/timeseries/__pycache__/utils.cpython-311.pyc,, +redis/commands/timeseries/commands.py,sha256=bFdk-609CnL-dTqMU5yQEiY-UCjVpLknHGDENQ2t-1U,33438 +redis/commands/timeseries/info.py,sha256=5deBInBtLPb3ZrVoSB4EhWkRPkSIW5Qd_98rMDnutnk,3207 +redis/commands/timeseries/utils.py,sha256=o7q7Fe1wgpdTLKyGY8Qi2VV6XKEBprhzmPdrFz3OIvo,1309 +redis/compat.py,sha256=tr-t9oHdeosrK3TvZySaLvP3ZlGqTZQaXtlTqiqp_8I,242 +redis/connection.py,sha256=fxHl5icHS3Mk2AhHeSGxcpMcY5aeHmq5589g2XyI_xg,50524 +redis/crc.py,sha256=Z3kXFtkY2LdgefnQMud1xr4vG5UYvA9LCMqNMX1ywu4,729 +redis/credentials.py,sha256=6VvFeReFp6vernGIWlIVOm8OmbNgoFYdd1wgsjZTnlk,738 +redis/exceptions.py,sha256=AzWeYEpVR1koUddMgvz0WZxmPX_jyksagoRf8FSSWKA,5103 +redis/lock.py,sha256=CwB_qo7ADDGSt_JqjQKSL1nKDCwdb-ASJsAlv0JO6mA,11564 +redis/ocsp.py,sha256=WwiGby6yZYR0D3lgnnQYmPKy-UAgYqGXi6A4jDBZGL4,11450 +redis/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +redis/retry.py,sha256=Ssp9s2hhDfyRs0rCRCaTgRtLR7NAYO5QMw4QflourGo,1817 +redis/sentinel.py,sha256=CErsD-c3mYFnXDttCY1OvpyUdfKcyD5F9Jv9Fd3iHuU,14175 +redis/typing.py,sha256=wjyihEjyGiJrigcs0-zhy7K-MzVy7uLidjszNdPHMug,2212 +redis/utils.py,sha256=87p7ImnihyIhiaqalVYh9Qq9JeaVwi_Y4GBzNaHAXJg,3381 diff --git a/.venv/Lib/site-packages/redis-5.0.1.dist-info/WHEEL b/.venv/Lib/site-packages/redis-5.0.1.dist-info/WHEEL new file mode 100644 index 00000000..7e688737 --- /dev/null +++ b/.venv/Lib/site-packages/redis-5.0.1.dist-info/WHEEL @@ -0,0 +1,5 @@ +Wheel-Version: 1.0 +Generator: bdist_wheel (0.41.2) +Root-Is-Purelib: true +Tag: py3-none-any + diff --git a/.venv/Lib/site-packages/redis-5.0.1.dist-info/top_level.txt b/.venv/Lib/site-packages/redis-5.0.1.dist-info/top_level.txt new file mode 100644 index 00000000..7800f0fa --- /dev/null +++ b/.venv/Lib/site-packages/redis-5.0.1.dist-info/top_level.txt @@ -0,0 +1 @@ +redis diff --git a/.venv/Lib/site-packages/redis/__init__.py b/.venv/Lib/site-packages/redis/__init__.py new file mode 100644 index 00000000..495d2d99 --- /dev/null +++ b/.venv/Lib/site-packages/redis/__init__.py @@ -0,0 +1,94 @@ +import sys + +from redis import asyncio # noqa +from redis.backoff import default_backoff +from redis.client import Redis, StrictRedis +from redis.cluster import RedisCluster +from redis.connection import ( + BlockingConnectionPool, + Connection, + ConnectionPool, + SSLConnection, + UnixDomainSocketConnection, +) +from redis.credentials import CredentialProvider, UsernamePasswordCredentialProvider +from redis.exceptions import ( + AuthenticationError, + AuthenticationWrongNumberOfArgsError, + BusyLoadingError, + ChildDeadlockedError, + ConnectionError, + DataError, + InvalidResponse, + OutOfMemoryError, + PubSubError, + ReadOnlyError, + RedisError, + ResponseError, + TimeoutError, + WatchError, +) +from redis.sentinel import ( + Sentinel, + SentinelConnectionPool, + SentinelManagedConnection, + SentinelManagedSSLConnection, +) +from redis.utils import from_url + +if sys.version_info >= (3, 8): + from importlib import metadata +else: + import importlib_metadata as metadata + + +def int_or_str(value): + try: + return int(value) + except ValueError: + return value + + +try: + __version__ = metadata.version("redis") +except metadata.PackageNotFoundError: + __version__ = "99.99.99" + + +try: + VERSION = tuple(map(int_or_str, __version__.split("."))) +except AttributeError: + VERSION = tuple([99, 99, 99]) + +__all__ = [ + "AuthenticationError", + "AuthenticationWrongNumberOfArgsError", + "BlockingConnectionPool", + "BusyLoadingError", + "ChildDeadlockedError", + "Connection", + "ConnectionError", + "ConnectionPool", + "CredentialProvider", + "DataError", + "from_url", + "default_backoff", + "InvalidResponse", + "OutOfMemoryError", + "PubSubError", + "ReadOnlyError", + "Redis", + "RedisCluster", + "RedisError", + "ResponseError", + "Sentinel", + "SentinelConnectionPool", + "SentinelManagedConnection", + "SentinelManagedSSLConnection", + "SSLConnection", + "UsernamePasswordCredentialProvider", + "StrictRedis", + "TimeoutError", + "UnixDomainSocketConnection", + "WatchError", +] diff --git a/.venv/Lib/site-packages/redis/__pycache__/__init__.cpython-311.pyc b/.venv/Lib/site-packages/redis/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 00000000..78155189 Binary files /dev/null and b/.venv/Lib/site-packages/redis/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/redis/__pycache__/backoff.cpython-311.pyc b/.venv/Lib/site-packages/redis/__pycache__/backoff.cpython-311.pyc new file mode 100644 index 00000000..14997d30 Binary files /dev/null and b/.venv/Lib/site-packages/redis/__pycache__/backoff.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/redis/__pycache__/client.cpython-311.pyc b/.venv/Lib/site-packages/redis/__pycache__/client.cpython-311.pyc new file mode 100644 index 00000000..5129a5ac Binary files /dev/null and b/.venv/Lib/site-packages/redis/__pycache__/client.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/redis/__pycache__/cluster.cpython-311.pyc b/.venv/Lib/site-packages/redis/__pycache__/cluster.cpython-311.pyc new file mode 100644 index 00000000..4c26f835 Binary files /dev/null and b/.venv/Lib/site-packages/redis/__pycache__/cluster.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/redis/__pycache__/compat.cpython-311.pyc b/.venv/Lib/site-packages/redis/__pycache__/compat.cpython-311.pyc new file mode 100644 index 00000000..ff20bc50 Binary files /dev/null and b/.venv/Lib/site-packages/redis/__pycache__/compat.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/redis/__pycache__/connection.cpython-311.pyc b/.venv/Lib/site-packages/redis/__pycache__/connection.cpython-311.pyc new file mode 100644 index 00000000..eaad95c1 Binary files /dev/null and b/.venv/Lib/site-packages/redis/__pycache__/connection.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/redis/__pycache__/crc.cpython-311.pyc b/.venv/Lib/site-packages/redis/__pycache__/crc.cpython-311.pyc new file mode 100644 index 00000000..c5791b85 Binary files /dev/null and b/.venv/Lib/site-packages/redis/__pycache__/crc.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/redis/__pycache__/credentials.cpython-311.pyc b/.venv/Lib/site-packages/redis/__pycache__/credentials.cpython-311.pyc new file mode 100644 index 00000000..82c6063d Binary files /dev/null and b/.venv/Lib/site-packages/redis/__pycache__/credentials.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/redis/__pycache__/exceptions.cpython-311.pyc b/.venv/Lib/site-packages/redis/__pycache__/exceptions.cpython-311.pyc new file mode 100644 index 00000000..ae8f83a0 Binary files /dev/null and b/.venv/Lib/site-packages/redis/__pycache__/exceptions.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/redis/__pycache__/lock.cpython-311.pyc b/.venv/Lib/site-packages/redis/__pycache__/lock.cpython-311.pyc new file mode 100644 index 00000000..896adde5 Binary files /dev/null and b/.venv/Lib/site-packages/redis/__pycache__/lock.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/redis/__pycache__/ocsp.cpython-311.pyc b/.venv/Lib/site-packages/redis/__pycache__/ocsp.cpython-311.pyc new file mode 100644 index 00000000..8367d045 Binary files /dev/null and b/.venv/Lib/site-packages/redis/__pycache__/ocsp.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/redis/__pycache__/retry.cpython-311.pyc b/.venv/Lib/site-packages/redis/__pycache__/retry.cpython-311.pyc new file mode 100644 index 00000000..526add44 Binary files /dev/null and b/.venv/Lib/site-packages/redis/__pycache__/retry.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/redis/__pycache__/sentinel.cpython-311.pyc b/.venv/Lib/site-packages/redis/__pycache__/sentinel.cpython-311.pyc new file mode 100644 index 00000000..a2170c61 Binary files /dev/null and b/.venv/Lib/site-packages/redis/__pycache__/sentinel.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/redis/__pycache__/typing.cpython-311.pyc b/.venv/Lib/site-packages/redis/__pycache__/typing.cpython-311.pyc new file mode 100644 index 00000000..0a019156 Binary files /dev/null and b/.venv/Lib/site-packages/redis/__pycache__/typing.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/redis/__pycache__/utils.cpython-311.pyc b/.venv/Lib/site-packages/redis/__pycache__/utils.cpython-311.pyc new file mode 100644 index 00000000..864b62c3 Binary files /dev/null and b/.venv/Lib/site-packages/redis/__pycache__/utils.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/redis/_parsers/__init__.py b/.venv/Lib/site-packages/redis/_parsers/__init__.py new file mode 100644 index 00000000..6cc32e3c --- /dev/null +++ b/.venv/Lib/site-packages/redis/_parsers/__init__.py @@ -0,0 +1,20 @@ +from .base import BaseParser, _AsyncRESPBase +from .commands import AsyncCommandsParser, CommandsParser +from .encoders import Encoder +from .hiredis import _AsyncHiredisParser, _HiredisParser +from .resp2 import _AsyncRESP2Parser, _RESP2Parser +from .resp3 import _AsyncRESP3Parser, _RESP3Parser + +__all__ = [ + "AsyncCommandsParser", + "_AsyncHiredisParser", + "_AsyncRESPBase", + "_AsyncRESP2Parser", + "_AsyncRESP3Parser", + "CommandsParser", + "Encoder", + "BaseParser", + "_HiredisParser", + "_RESP2Parser", + "_RESP3Parser", +] diff --git a/.venv/Lib/site-packages/redis/_parsers/__pycache__/__init__.cpython-311.pyc b/.venv/Lib/site-packages/redis/_parsers/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 00000000..e79e8aba Binary files /dev/null and b/.venv/Lib/site-packages/redis/_parsers/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/redis/_parsers/__pycache__/base.cpython-311.pyc b/.venv/Lib/site-packages/redis/_parsers/__pycache__/base.cpython-311.pyc new file mode 100644 index 00000000..a75cac1b Binary files /dev/null and b/.venv/Lib/site-packages/redis/_parsers/__pycache__/base.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/redis/_parsers/__pycache__/commands.cpython-311.pyc b/.venv/Lib/site-packages/redis/_parsers/__pycache__/commands.cpython-311.pyc new file mode 100644 index 00000000..d5f50bfb Binary files /dev/null and b/.venv/Lib/site-packages/redis/_parsers/__pycache__/commands.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/redis/_parsers/__pycache__/encoders.cpython-311.pyc b/.venv/Lib/site-packages/redis/_parsers/__pycache__/encoders.cpython-311.pyc new file mode 100644 index 00000000..80d82a03 Binary files /dev/null and b/.venv/Lib/site-packages/redis/_parsers/__pycache__/encoders.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/redis/_parsers/__pycache__/helpers.cpython-311.pyc b/.venv/Lib/site-packages/redis/_parsers/__pycache__/helpers.cpython-311.pyc new file mode 100644 index 00000000..2fa674d7 Binary files /dev/null and b/.venv/Lib/site-packages/redis/_parsers/__pycache__/helpers.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/redis/_parsers/__pycache__/hiredis.cpython-311.pyc b/.venv/Lib/site-packages/redis/_parsers/__pycache__/hiredis.cpython-311.pyc new file mode 100644 index 00000000..0b8dbd8d Binary files /dev/null and b/.venv/Lib/site-packages/redis/_parsers/__pycache__/hiredis.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/redis/_parsers/__pycache__/resp2.cpython-311.pyc b/.venv/Lib/site-packages/redis/_parsers/__pycache__/resp2.cpython-311.pyc new file mode 100644 index 00000000..b49e96ff Binary files /dev/null and b/.venv/Lib/site-packages/redis/_parsers/__pycache__/resp2.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/redis/_parsers/__pycache__/resp3.cpython-311.pyc b/.venv/Lib/site-packages/redis/_parsers/__pycache__/resp3.cpython-311.pyc new file mode 100644 index 00000000..16274bd6 Binary files /dev/null and b/.venv/Lib/site-packages/redis/_parsers/__pycache__/resp3.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/redis/_parsers/__pycache__/socket.cpython-311.pyc b/.venv/Lib/site-packages/redis/_parsers/__pycache__/socket.cpython-311.pyc new file mode 100644 index 00000000..d84a315d Binary files /dev/null and b/.venv/Lib/site-packages/redis/_parsers/__pycache__/socket.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/redis/_parsers/base.py b/.venv/Lib/site-packages/redis/_parsers/base.py new file mode 100644 index 00000000..8e59249b --- /dev/null +++ b/.venv/Lib/site-packages/redis/_parsers/base.py @@ -0,0 +1,225 @@ +import sys +from abc import ABC +from asyncio import IncompleteReadError, StreamReader, TimeoutError +from typing import List, Optional, Union + +if sys.version_info.major >= 3 and sys.version_info.minor >= 11: + from asyncio import timeout as async_timeout +else: + from async_timeout import timeout as async_timeout + +from ..exceptions import ( + AuthenticationError, + AuthenticationWrongNumberOfArgsError, + BusyLoadingError, + ConnectionError, + ExecAbortError, + ModuleError, + NoPermissionError, + NoScriptError, + OutOfMemoryError, + ReadOnlyError, + RedisError, + ResponseError, +) +from ..typing import EncodableT +from .encoders import Encoder +from .socket import SERVER_CLOSED_CONNECTION_ERROR, SocketBuffer + +MODULE_LOAD_ERROR = "Error loading the extension. " "Please check the server logs." +NO_SUCH_MODULE_ERROR = "Error unloading module: no such module with that name" +MODULE_UNLOAD_NOT_POSSIBLE_ERROR = "Error unloading module: operation not " "possible." +MODULE_EXPORTS_DATA_TYPES_ERROR = ( + "Error unloading module: the module " + "exports one or more module-side data " + "types, can't unload" +) +# user send an AUTH cmd to a server without authorization configured +NO_AUTH_SET_ERROR = { + # Redis >= 6.0 + "AUTH called without any password " + "configured for the default user. Are you sure " + "your configuration is correct?": AuthenticationError, + # Redis < 6.0 + "Client sent AUTH, but no password is set": AuthenticationError, +} + + +class BaseParser(ABC): + EXCEPTION_CLASSES = { + "ERR": { + "max number of clients reached": ConnectionError, + "invalid password": AuthenticationError, + # some Redis server versions report invalid command syntax + # in lowercase + "wrong number of arguments " + "for 'auth' command": AuthenticationWrongNumberOfArgsError, + # some Redis server versions report invalid command syntax + # in uppercase + "wrong number of arguments " + "for 'AUTH' command": AuthenticationWrongNumberOfArgsError, + MODULE_LOAD_ERROR: ModuleError, + MODULE_EXPORTS_DATA_TYPES_ERROR: ModuleError, + NO_SUCH_MODULE_ERROR: ModuleError, + MODULE_UNLOAD_NOT_POSSIBLE_ERROR: ModuleError, + **NO_AUTH_SET_ERROR, + }, + "OOM": OutOfMemoryError, + "WRONGPASS": AuthenticationError, + "EXECABORT": ExecAbortError, + "LOADING": BusyLoadingError, + "NOSCRIPT": NoScriptError, + "READONLY": ReadOnlyError, + "NOAUTH": AuthenticationError, + "NOPERM": NoPermissionError, + } + + @classmethod + def parse_error(cls, response): + "Parse an error response" + error_code = response.split(" ")[0] + if error_code in cls.EXCEPTION_CLASSES: + response = response[len(error_code) + 1 :] + exception_class = cls.EXCEPTION_CLASSES[error_code] + if isinstance(exception_class, dict): + exception_class = exception_class.get(response, ResponseError) + return exception_class(response) + return ResponseError(response) + + def on_disconnect(self): + raise NotImplementedError() + + def on_connect(self, connection): + raise NotImplementedError() + + +class _RESPBase(BaseParser): + """Base class for sync-based resp parsing""" + + def __init__(self, socket_read_size): + self.socket_read_size = socket_read_size + self.encoder = None + self._sock = None + self._buffer = None + + def __del__(self): + try: + self.on_disconnect() + except Exception: + pass + + def on_connect(self, connection): + "Called when the socket connects" + self._sock = connection._sock + self._buffer = SocketBuffer( + self._sock, self.socket_read_size, connection.socket_timeout + ) + self.encoder = connection.encoder + + def on_disconnect(self): + "Called when the socket disconnects" + self._sock = None + if self._buffer is not None: + self._buffer.close() + self._buffer = None + self.encoder = None + + def can_read(self, timeout): + return self._buffer and self._buffer.can_read(timeout) + + +class AsyncBaseParser(BaseParser): + """Base parsing class for the python-backed async parser""" + + __slots__ = "_stream", "_read_size" + + def __init__(self, socket_read_size: int): + self._stream: Optional[StreamReader] = None + self._read_size = socket_read_size + + async def can_read_destructive(self) -> bool: + raise NotImplementedError() + + async def read_response( + self, disable_decoding: bool = False + ) -> Union[EncodableT, ResponseError, None, List[EncodableT]]: + raise NotImplementedError() + + +class _AsyncRESPBase(AsyncBaseParser): + """Base class for async resp parsing""" + + __slots__ = AsyncBaseParser.__slots__ + ("encoder", "_buffer", "_pos", "_chunks") + + def __init__(self, socket_read_size: int): + super().__init__(socket_read_size) + self.encoder: Optional[Encoder] = None + self._buffer = b"" + self._chunks = [] + self._pos = 0 + + def _clear(self): + self._buffer = b"" + self._chunks.clear() + + def on_connect(self, connection): + """Called when the stream connects""" + self._stream = connection._reader + if self._stream is None: + raise RedisError("Buffer is closed.") + self.encoder = connection.encoder + self._clear() + self._connected = True + + def on_disconnect(self): + """Called when the stream disconnects""" + self._connected = False + + async def can_read_destructive(self) -> bool: + if not self._connected: + raise RedisError("Buffer is closed.") + if self._buffer: + return True + try: + async with async_timeout(0): + return await self._stream.read(1) + except TimeoutError: + return False + + async def _read(self, length: int) -> bytes: + """ + Read `length` bytes of data. These are assumed to be followed + by a '\r\n' terminator which is subsequently discarded. + """ + want = length + 2 + end = self._pos + want + if len(self._buffer) >= end: + result = self._buffer[self._pos : end - 2] + else: + tail = self._buffer[self._pos :] + try: + data = await self._stream.readexactly(want - len(tail)) + except IncompleteReadError as error: + raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) from error + result = (tail + data)[:-2] + self._chunks.append(data) + self._pos += want + return result + + async def _readline(self) -> bytes: + """ + read an unknown number of bytes up to the next '\r\n' + line separator, which is discarded. + """ + found = self._buffer.find(b"\r\n", self._pos) + if found >= 0: + result = self._buffer[self._pos : found] + else: + tail = self._buffer[self._pos :] + data = await self._stream.readline() + if not data.endswith(b"\r\n"): + raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) + result = (tail + data)[:-2] + self._chunks.append(data) + self._pos += len(result) + 2 + return result diff --git a/.venv/Lib/site-packages/redis/_parsers/commands.py b/.venv/Lib/site-packages/redis/_parsers/commands.py new file mode 100644 index 00000000..b5109252 --- /dev/null +++ b/.venv/Lib/site-packages/redis/_parsers/commands.py @@ -0,0 +1,281 @@ +from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union + +from redis.exceptions import RedisError, ResponseError +from redis.utils import str_if_bytes + +if TYPE_CHECKING: + from redis.asyncio.cluster import ClusterNode + + +class AbstractCommandsParser: + def _get_pubsub_keys(self, *args): + """ + Get the keys from pubsub command. + Although PubSub commands have predetermined key locations, they are not + supported in the 'COMMAND's output, so the key positions are hardcoded + in this method + """ + if len(args) < 2: + # The command has no keys in it + return None + args = [str_if_bytes(arg) for arg in args] + command = args[0].upper() + keys = None + if command == "PUBSUB": + # the second argument is a part of the command name, e.g. + # ['PUBSUB', 'NUMSUB', 'foo']. + pubsub_type = args[1].upper() + if pubsub_type in ["CHANNELS", "NUMSUB", "SHARDCHANNELS", "SHARDNUMSUB"]: + keys = args[2:] + elif command in ["SUBSCRIBE", "PSUBSCRIBE", "UNSUBSCRIBE", "PUNSUBSCRIBE"]: + # format example: + # SUBSCRIBE channel [channel ...] + keys = list(args[1:]) + elif command in ["PUBLISH", "SPUBLISH"]: + # format example: + # PUBLISH channel message + keys = [args[1]] + return keys + + def parse_subcommand(self, command, **options): + cmd_dict = {} + cmd_name = str_if_bytes(command[0]) + cmd_dict["name"] = cmd_name + cmd_dict["arity"] = int(command[1]) + cmd_dict["flags"] = [str_if_bytes(flag) for flag in command[2]] + cmd_dict["first_key_pos"] = command[3] + cmd_dict["last_key_pos"] = command[4] + cmd_dict["step_count"] = command[5] + if len(command) > 7: + cmd_dict["tips"] = command[7] + cmd_dict["key_specifications"] = command[8] + cmd_dict["subcommands"] = command[9] + return cmd_dict + + +class CommandsParser(AbstractCommandsParser): + """ + Parses Redis commands to get command keys. + COMMAND output is used to determine key locations. + Commands that do not have a predefined key location are flagged with + 'movablekeys', and these commands' keys are determined by the command + 'COMMAND GETKEYS'. + """ + + def __init__(self, redis_connection): + self.commands = {} + self.initialize(redis_connection) + + def initialize(self, r): + commands = r.command() + uppercase_commands = [] + for cmd in commands: + if any(x.isupper() for x in cmd): + uppercase_commands.append(cmd) + for cmd in uppercase_commands: + commands[cmd.lower()] = commands.pop(cmd) + self.commands = commands + + # As soon as this PR is merged into Redis, we should reimplement + # our logic to use COMMAND INFO changes to determine the key positions + # https://github.com/redis/redis/pull/8324 + def get_keys(self, redis_conn, *args): + """ + Get the keys from the passed command. + + NOTE: Due to a bug in redis<7.0, this function does not work properly + for EVAL or EVALSHA when the `numkeys` arg is 0. + - issue: https://github.com/redis/redis/issues/9493 + - fix: https://github.com/redis/redis/pull/9733 + + So, don't use this function with EVAL or EVALSHA. + """ + if len(args) < 2: + # The command has no keys in it + return None + + cmd_name = args[0].lower() + if cmd_name not in self.commands: + # try to split the command name and to take only the main command, + # e.g. 'memory' for 'memory usage' + cmd_name_split = cmd_name.split() + cmd_name = cmd_name_split[0] + if cmd_name in self.commands: + # save the splitted command to args + args = cmd_name_split + list(args[1:]) + else: + # We'll try to reinitialize the commands cache, if the engine + # version has changed, the commands may not be current + self.initialize(redis_conn) + if cmd_name not in self.commands: + raise RedisError( + f"{cmd_name.upper()} command doesn't exist in Redis commands" + ) + + command = self.commands.get(cmd_name) + if "movablekeys" in command["flags"]: + keys = self._get_moveable_keys(redis_conn, *args) + elif "pubsub" in command["flags"] or command["name"] == "pubsub": + keys = self._get_pubsub_keys(*args) + else: + if ( + command["step_count"] == 0 + and command["first_key_pos"] == 0 + and command["last_key_pos"] == 0 + ): + is_subcmd = False + if "subcommands" in command: + subcmd_name = f"{cmd_name}|{args[1].lower()}" + for subcmd in command["subcommands"]: + if str_if_bytes(subcmd[0]) == subcmd_name: + command = self.parse_subcommand(subcmd) + is_subcmd = True + + # The command doesn't have keys in it + if not is_subcmd: + return None + last_key_pos = command["last_key_pos"] + if last_key_pos < 0: + last_key_pos = len(args) - abs(last_key_pos) + keys_pos = list( + range(command["first_key_pos"], last_key_pos + 1, command["step_count"]) + ) + keys = [args[pos] for pos in keys_pos] + + return keys + + def _get_moveable_keys(self, redis_conn, *args): + """ + NOTE: Due to a bug in redis<7.0, this function does not work properly + for EVAL or EVALSHA when the `numkeys` arg is 0. + - issue: https://github.com/redis/redis/issues/9493 + - fix: https://github.com/redis/redis/pull/9733 + + So, don't use this function with EVAL or EVALSHA. + """ + # The command name should be splitted into separate arguments, + # e.g. 'MEMORY USAGE' will be splitted into ['MEMORY', 'USAGE'] + pieces = args[0].split() + list(args[1:]) + try: + keys = redis_conn.execute_command("COMMAND GETKEYS", *pieces) + except ResponseError as e: + message = e.__str__() + if ( + "Invalid arguments" in message + or "The command has no key arguments" in message + ): + return None + else: + raise e + return keys + + +class AsyncCommandsParser(AbstractCommandsParser): + """ + Parses Redis commands to get command keys. + + COMMAND output is used to determine key locations. + Commands that do not have a predefined key location are flagged with 'movablekeys', + and these commands' keys are determined by the command 'COMMAND GETKEYS'. + + NOTE: Due to a bug in redis<7.0, this does not work properly + for EVAL or EVALSHA when the `numkeys` arg is 0. + - issue: https://github.com/redis/redis/issues/9493 + - fix: https://github.com/redis/redis/pull/9733 + + So, don't use this with EVAL or EVALSHA. + """ + + __slots__ = ("commands", "node") + + def __init__(self) -> None: + self.commands: Dict[str, Union[int, Dict[str, Any]]] = {} + + async def initialize(self, node: Optional["ClusterNode"] = None) -> None: + if node: + self.node = node + + commands = await self.node.execute_command("COMMAND") + self.commands = {cmd.lower(): command for cmd, command in commands.items()} + + # As soon as this PR is merged into Redis, we should reimplement + # our logic to use COMMAND INFO changes to determine the key positions + # https://github.com/redis/redis/pull/8324 + async def get_keys(self, *args: Any) -> Optional[Tuple[str, ...]]: + """ + Get the keys from the passed command. + + NOTE: Due to a bug in redis<7.0, this function does not work properly + for EVAL or EVALSHA when the `numkeys` arg is 0. + - issue: https://github.com/redis/redis/issues/9493 + - fix: https://github.com/redis/redis/pull/9733 + + So, don't use this function with EVAL or EVALSHA. + """ + if len(args) < 2: + # The command has no keys in it + return None + + cmd_name = args[0].lower() + if cmd_name not in self.commands: + # try to split the command name and to take only the main command, + # e.g. 'memory' for 'memory usage' + cmd_name_split = cmd_name.split() + cmd_name = cmd_name_split[0] + if cmd_name in self.commands: + # save the splitted command to args + args = cmd_name_split + list(args[1:]) + else: + # We'll try to reinitialize the commands cache, if the engine + # version has changed, the commands may not be current + await self.initialize() + if cmd_name not in self.commands: + raise RedisError( + f"{cmd_name.upper()} command doesn't exist in Redis commands" + ) + + command = self.commands.get(cmd_name) + if "movablekeys" in command["flags"]: + keys = await self._get_moveable_keys(*args) + elif "pubsub" in command["flags"] or command["name"] == "pubsub": + keys = self._get_pubsub_keys(*args) + else: + if ( + command["step_count"] == 0 + and command["first_key_pos"] == 0 + and command["last_key_pos"] == 0 + ): + is_subcmd = False + if "subcommands" in command: + subcmd_name = f"{cmd_name}|{args[1].lower()}" + for subcmd in command["subcommands"]: + if str_if_bytes(subcmd[0]) == subcmd_name: + command = self.parse_subcommand(subcmd) + is_subcmd = True + + # The command doesn't have keys in it + if not is_subcmd: + return None + last_key_pos = command["last_key_pos"] + if last_key_pos < 0: + last_key_pos = len(args) - abs(last_key_pos) + keys_pos = list( + range(command["first_key_pos"], last_key_pos + 1, command["step_count"]) + ) + keys = [args[pos] for pos in keys_pos] + + return keys + + async def _get_moveable_keys(self, *args: Any) -> Optional[Tuple[str, ...]]: + try: + keys = await self.node.execute_command("COMMAND GETKEYS", *args) + except ResponseError as e: + message = e.__str__() + if ( + "Invalid arguments" in message + or "The command has no key arguments" in message + ): + return None + else: + raise e + return keys diff --git a/.venv/Lib/site-packages/redis/_parsers/encoders.py b/.venv/Lib/site-packages/redis/_parsers/encoders.py new file mode 100644 index 00000000..6fdf0ad8 --- /dev/null +++ b/.venv/Lib/site-packages/redis/_parsers/encoders.py @@ -0,0 +1,44 @@ +from ..exceptions import DataError + + +class Encoder: + "Encode strings to bytes-like and decode bytes-like to strings" + + __slots__ = "encoding", "encoding_errors", "decode_responses" + + def __init__(self, encoding, encoding_errors, decode_responses): + self.encoding = encoding + self.encoding_errors = encoding_errors + self.decode_responses = decode_responses + + def encode(self, value): + "Return a bytestring or bytes-like representation of the value" + if isinstance(value, (bytes, memoryview)): + return value + elif isinstance(value, bool): + # special case bool since it is a subclass of int + raise DataError( + "Invalid input of type: 'bool'. Convert to a " + "bytes, string, int or float first." + ) + elif isinstance(value, (int, float)): + value = repr(value).encode() + elif not isinstance(value, str): + # a value we don't know how to deal with. throw an error + typename = type(value).__name__ + raise DataError( + f"Invalid input of type: '{typename}'. " + f"Convert to a bytes, string, int or float first." + ) + if isinstance(value, str): + value = value.encode(self.encoding, self.encoding_errors) + return value + + def decode(self, value, force=False): + "Return a unicode string from the bytes-like representation" + if self.decode_responses or force: + if isinstance(value, memoryview): + value = value.tobytes() + if isinstance(value, bytes): + value = value.decode(self.encoding, self.encoding_errors) + return value diff --git a/.venv/Lib/site-packages/redis/_parsers/helpers.py b/.venv/Lib/site-packages/redis/_parsers/helpers.py new file mode 100644 index 00000000..fb5da831 --- /dev/null +++ b/.venv/Lib/site-packages/redis/_parsers/helpers.py @@ -0,0 +1,852 @@ +import datetime + +from redis.utils import str_if_bytes + + +def timestamp_to_datetime(response): + "Converts a unix timestamp to a Python datetime object" + if not response: + return None + try: + response = int(response) + except ValueError: + return None + return datetime.datetime.fromtimestamp(response) + + +def parse_debug_object(response): + "Parse the results of Redis's DEBUG OBJECT command into a Python dict" + # The 'type' of the object is the first item in the response, but isn't + # prefixed with a name + response = str_if_bytes(response) + response = "type:" + response + response = dict(kv.split(":") for kv in response.split()) + + # parse some expected int values from the string response + # note: this cmd isn't spec'd so these may not appear in all redis versions + int_fields = ("refcount", "serializedlength", "lru", "lru_seconds_idle") + for field in int_fields: + if field in response: + response[field] = int(response[field]) + + return response + + +def parse_info(response): + """Parse the result of Redis's INFO command into a Python dict""" + info = {} + response = str_if_bytes(response) + + def get_value(value): + if "," not in value or "=" not in value: + try: + if "." in value: + return float(value) + else: + return int(value) + except ValueError: + return value + else: + sub_dict = {} + for item in value.split(","): + k, v = item.rsplit("=", 1) + sub_dict[k] = get_value(v) + return sub_dict + + for line in response.splitlines(): + if line and not line.startswith("#"): + if line.find(":") != -1: + # Split, the info fields keys and values. + # Note that the value may contain ':'. but the 'host:' + # pseudo-command is the only case where the key contains ':' + key, value = line.split(":", 1) + if key == "cmdstat_host": + key, value = line.rsplit(":", 1) + + if key == "module": + # Hardcode a list for key 'modules' since there could be + # multiple lines that started with 'module' + info.setdefault("modules", []).append(get_value(value)) + else: + info[key] = get_value(value) + else: + # if the line isn't splittable, append it to the "__raw__" key + info.setdefault("__raw__", []).append(line) + + return info + + +def parse_memory_stats(response, **kwargs): + """Parse the results of MEMORY STATS""" + stats = pairs_to_dict(response, decode_keys=True, decode_string_values=True) + for key, value in stats.items(): + if key.startswith("db."): + stats[key] = pairs_to_dict( + value, decode_keys=True, decode_string_values=True + ) + return stats + + +SENTINEL_STATE_TYPES = { + "can-failover-its-master": int, + "config-epoch": int, + "down-after-milliseconds": int, + "failover-timeout": int, + "info-refresh": int, + "last-hello-message": int, + "last-ok-ping-reply": int, + "last-ping-reply": int, + "last-ping-sent": int, + "master-link-down-time": int, + "master-port": int, + "num-other-sentinels": int, + "num-slaves": int, + "o-down-time": int, + "pending-commands": int, + "parallel-syncs": int, + "port": int, + "quorum": int, + "role-reported-time": int, + "s-down-time": int, + "slave-priority": int, + "slave-repl-offset": int, + "voted-leader-epoch": int, +} + + +def parse_sentinel_state(item): + result = pairs_to_dict_typed(item, SENTINEL_STATE_TYPES) + flags = set(result["flags"].split(",")) + for name, flag in ( + ("is_master", "master"), + ("is_slave", "slave"), + ("is_sdown", "s_down"), + ("is_odown", "o_down"), + ("is_sentinel", "sentinel"), + ("is_disconnected", "disconnected"), + ("is_master_down", "master_down"), + ): + result[name] = flag in flags + return result + + +def parse_sentinel_master(response): + return parse_sentinel_state(map(str_if_bytes, response)) + + +def parse_sentinel_state_resp3(response): + result = {} + for key in response: + try: + value = SENTINEL_STATE_TYPES[key](str_if_bytes(response[key])) + result[str_if_bytes(key)] = value + except Exception: + result[str_if_bytes(key)] = response[str_if_bytes(key)] + flags = set(result["flags"].split(",")) + result["flags"] = flags + return result + + +def parse_sentinel_masters(response): + result = {} + for item in response: + state = parse_sentinel_state(map(str_if_bytes, item)) + result[state["name"]] = state + return result + + +def parse_sentinel_masters_resp3(response): + return [parse_sentinel_state(master) for master in response] + + +def parse_sentinel_slaves_and_sentinels(response): + return [parse_sentinel_state(map(str_if_bytes, item)) for item in response] + + +def parse_sentinel_slaves_and_sentinels_resp3(response): + return [parse_sentinel_state_resp3(item) for item in response] + + +def parse_sentinel_get_master(response): + return response and (response[0], int(response[1])) or None + + +def pairs_to_dict(response, decode_keys=False, decode_string_values=False): + """Create a dict given a list of key/value pairs""" + if response is None: + return {} + if decode_keys or decode_string_values: + # the iter form is faster, but I don't know how to make that work + # with a str_if_bytes() map + keys = response[::2] + if decode_keys: + keys = map(str_if_bytes, keys) + values = response[1::2] + if decode_string_values: + values = map(str_if_bytes, values) + return dict(zip(keys, values)) + else: + it = iter(response) + return dict(zip(it, it)) + + +def pairs_to_dict_typed(response, type_info): + it = iter(response) + result = {} + for key, value in zip(it, it): + if key in type_info: + try: + value = type_info[key](value) + except Exception: + # if for some reason the value can't be coerced, just use + # the string value + pass + result[key] = value + return result + + +def zset_score_pairs(response, **options): + """ + If ``withscores`` is specified in the options, return the response as + a list of (value, score) pairs + """ + if not response or not options.get("withscores"): + return response + score_cast_func = options.get("score_cast_func", float) + it = iter(response) + return list(zip(it, map(score_cast_func, it))) + + +def sort_return_tuples(response, **options): + """ + If ``groups`` is specified, return the response as a list of + n-element tuples with n being the value found in options['groups'] + """ + if not response or not options.get("groups"): + return response + n = options["groups"] + return list(zip(*[response[i::n] for i in range(n)])) + + +def parse_stream_list(response): + if response is None: + return None + data = [] + for r in response: + if r is not None: + data.append((r[0], pairs_to_dict(r[1]))) + else: + data.append((None, None)) + return data + + +def pairs_to_dict_with_str_keys(response): + return pairs_to_dict(response, decode_keys=True) + + +def parse_list_of_dicts(response): + return list(map(pairs_to_dict_with_str_keys, response)) + + +def parse_xclaim(response, **options): + if options.get("parse_justid", False): + return response + return parse_stream_list(response) + + +def parse_xautoclaim(response, **options): + if options.get("parse_justid", False): + return response[1] + response[1] = parse_stream_list(response[1]) + return response + + +def parse_xinfo_stream(response, **options): + if isinstance(response, list): + data = pairs_to_dict(response, decode_keys=True) + else: + data = {str_if_bytes(k): v for k, v in response.items()} + if not options.get("full", False): + first = data.get("first-entry") + if first is not None: + data["first-entry"] = (first[0], pairs_to_dict(first[1])) + last = data["last-entry"] + if last is not None: + data["last-entry"] = (last[0], pairs_to_dict(last[1])) + else: + data["entries"] = {_id: pairs_to_dict(entry) for _id, entry in data["entries"]} + if isinstance(data["groups"][0], list): + data["groups"] = [ + pairs_to_dict(group, decode_keys=True) for group in data["groups"] + ] + else: + data["groups"] = [ + {str_if_bytes(k): v for k, v in group.items()} + for group in data["groups"] + ] + return data + + +def parse_xread(response): + if response is None: + return [] + return [[r[0], parse_stream_list(r[1])] for r in response] + + +def parse_xread_resp3(response): + if response is None: + return {} + return {key: [parse_stream_list(value)] for key, value in response.items()} + + +def parse_xpending(response, **options): + if options.get("parse_detail", False): + return parse_xpending_range(response) + consumers = [{"name": n, "pending": int(p)} for n, p in response[3] or []] + return { + "pending": response[0], + "min": response[1], + "max": response[2], + "consumers": consumers, + } + + +def parse_xpending_range(response): + k = ("message_id", "consumer", "time_since_delivered", "times_delivered") + return [dict(zip(k, r)) for r in response] + + +def float_or_none(response): + if response is None: + return None + return float(response) + + +def bool_ok(response): + return str_if_bytes(response) == "OK" + + +def parse_zadd(response, **options): + if response is None: + return None + if options.get("as_score"): + return float(response) + return int(response) + + +def parse_client_list(response, **options): + clients = [] + for c in str_if_bytes(response).splitlines(): + # Values might contain '=' + clients.append(dict(pair.split("=", 1) for pair in c.split(" "))) + return clients + + +def parse_config_get(response, **options): + response = [str_if_bytes(i) if i is not None else None for i in response] + return response and pairs_to_dict(response) or {} + + +def parse_scan(response, **options): + cursor, r = response + return int(cursor), r + + +def parse_hscan(response, **options): + cursor, r = response + return int(cursor), r and pairs_to_dict(r) or {} + + +def parse_zscan(response, **options): + score_cast_func = options.get("score_cast_func", float) + cursor, r = response + it = iter(r) + return int(cursor), list(zip(it, map(score_cast_func, it))) + + +def parse_zmscore(response, **options): + # zmscore: list of scores (double precision floating point number) or nil + return [float(score) if score is not None else None for score in response] + + +def parse_slowlog_get(response, **options): + space = " " if options.get("decode_responses", False) else b" " + + def parse_item(item): + result = {"id": item[0], "start_time": int(item[1]), "duration": int(item[2])} + # Redis Enterprise injects another entry at index [3], which has + # the complexity info (i.e. the value N in case the command has + # an O(N) complexity) instead of the command. + if isinstance(item[3], list): + result["command"] = space.join(item[3]) + result["client_address"] = item[4] + result["client_name"] = item[5] + else: + result["complexity"] = item[3] + result["command"] = space.join(item[4]) + result["client_address"] = item[5] + result["client_name"] = item[6] + return result + + return [parse_item(item) for item in response] + + +def parse_stralgo(response, **options): + """ + Parse the response from `STRALGO` command. + Without modifiers the returned value is string. + When LEN is given the command returns the length of the result + (i.e integer). + When IDX is given the command returns a dictionary with the LCS + length and all the ranges in both the strings, start and end + offset for each string, where there are matches. + When WITHMATCHLEN is given, each array representing a match will + also have the length of the match at the beginning of the array. + """ + if options.get("len", False): + return int(response) + if options.get("idx", False): + if options.get("withmatchlen", False): + matches = [ + [(int(match[-1]))] + list(map(tuple, match[:-1])) + for match in response[1] + ] + else: + matches = [list(map(tuple, match)) for match in response[1]] + return { + str_if_bytes(response[0]): matches, + str_if_bytes(response[2]): int(response[3]), + } + return str_if_bytes(response) + + +def parse_cluster_info(response, **options): + response = str_if_bytes(response) + return dict(line.split(":") for line in response.splitlines() if line) + + +def _parse_node_line(line): + line_items = line.split(" ") + node_id, addr, flags, master_id, ping, pong, epoch, connected = line.split(" ")[:8] + addr = addr.split("@")[0] + node_dict = { + "node_id": node_id, + "flags": flags, + "master_id": master_id, + "last_ping_sent": ping, + "last_pong_rcvd": pong, + "epoch": epoch, + "slots": [], + "migrations": [], + "connected": True if connected == "connected" else False, + } + if len(line_items) >= 9: + slots, migrations = _parse_slots(line_items[8:]) + node_dict["slots"], node_dict["migrations"] = slots, migrations + return addr, node_dict + + +def _parse_slots(slot_ranges): + slots, migrations = [], [] + for s_range in slot_ranges: + if "->-" in s_range: + slot_id, dst_node_id = s_range[1:-1].split("->-", 1) + migrations.append( + {"slot": slot_id, "node_id": dst_node_id, "state": "migrating"} + ) + elif "-<-" in s_range: + slot_id, src_node_id = s_range[1:-1].split("-<-", 1) + migrations.append( + {"slot": slot_id, "node_id": src_node_id, "state": "importing"} + ) + else: + s_range = [sl for sl in s_range.split("-")] + slots.append(s_range) + + return slots, migrations + + +def parse_cluster_nodes(response, **options): + """ + @see: https://redis.io/commands/cluster-nodes # string / bytes + @see: https://redis.io/commands/cluster-replicas # list of string / bytes + """ + if isinstance(response, (str, bytes)): + response = response.splitlines() + return dict(_parse_node_line(str_if_bytes(node)) for node in response) + + +def parse_geosearch_generic(response, **options): + """ + Parse the response of 'GEOSEARCH', GEORADIUS' and 'GEORADIUSBYMEMBER' + commands according to 'withdist', 'withhash' and 'withcoord' labels. + """ + try: + if options["store"] or options["store_dist"]: + # `store` and `store_dist` cant be combined + # with other command arguments. + # relevant to 'GEORADIUS' and 'GEORADIUSBYMEMBER' + return response + except KeyError: # it means the command was sent via execute_command + return response + + if type(response) != list: + response_list = [response] + else: + response_list = response + + if not options["withdist"] and not options["withcoord"] and not options["withhash"]: + # just a bunch of places + return response_list + + cast = { + "withdist": float, + "withcoord": lambda ll: (float(ll[0]), float(ll[1])), + "withhash": int, + } + + # zip all output results with each casting function to get + # the properly native Python value. + f = [lambda x: x] + f += [cast[o] for o in ["withdist", "withhash", "withcoord"] if options[o]] + return [list(map(lambda fv: fv[0](fv[1]), zip(f, r))) for r in response_list] + + +def parse_command(response, **options): + commands = {} + for command in response: + cmd_dict = {} + cmd_name = str_if_bytes(command[0]) + cmd_dict["name"] = cmd_name + cmd_dict["arity"] = int(command[1]) + cmd_dict["flags"] = [str_if_bytes(flag) for flag in command[2]] + cmd_dict["first_key_pos"] = command[3] + cmd_dict["last_key_pos"] = command[4] + cmd_dict["step_count"] = command[5] + if len(command) > 7: + cmd_dict["tips"] = command[7] + cmd_dict["key_specifications"] = command[8] + cmd_dict["subcommands"] = command[9] + commands[cmd_name] = cmd_dict + return commands + + +def parse_command_resp3(response, **options): + commands = {} + for command in response: + cmd_dict = {} + cmd_name = str_if_bytes(command[0]) + cmd_dict["name"] = cmd_name + cmd_dict["arity"] = command[1] + cmd_dict["flags"] = {str_if_bytes(flag) for flag in command[2]} + cmd_dict["first_key_pos"] = command[3] + cmd_dict["last_key_pos"] = command[4] + cmd_dict["step_count"] = command[5] + cmd_dict["acl_categories"] = command[6] + if len(command) > 7: + cmd_dict["tips"] = command[7] + cmd_dict["key_specifications"] = command[8] + cmd_dict["subcommands"] = command[9] + + commands[cmd_name] = cmd_dict + return commands + + +def parse_pubsub_numsub(response, **options): + return list(zip(response[0::2], response[1::2])) + + +def parse_client_kill(response, **options): + if isinstance(response, int): + return response + return str_if_bytes(response) == "OK" + + +def parse_acl_getuser(response, **options): + if response is None: + return None + if isinstance(response, list): + data = pairs_to_dict(response, decode_keys=True) + else: + data = {str_if_bytes(key): value for key, value in response.items()} + + # convert everything but user-defined data in 'keys' to native strings + data["flags"] = list(map(str_if_bytes, data["flags"])) + data["passwords"] = list(map(str_if_bytes, data["passwords"])) + data["commands"] = str_if_bytes(data["commands"]) + if isinstance(data["keys"], str) or isinstance(data["keys"], bytes): + data["keys"] = list(str_if_bytes(data["keys"]).split(" ")) + if data["keys"] == [""]: + data["keys"] = [] + if "channels" in data: + if isinstance(data["channels"], str) or isinstance(data["channels"], bytes): + data["channels"] = list(str_if_bytes(data["channels"]).split(" ")) + if data["channels"] == [""]: + data["channels"] = [] + if "selectors" in data: + if data["selectors"] != [] and isinstance(data["selectors"][0], list): + data["selectors"] = [ + list(map(str_if_bytes, selector)) for selector in data["selectors"] + ] + elif data["selectors"] != []: + data["selectors"] = [ + {str_if_bytes(k): str_if_bytes(v) for k, v in selector.items()} + for selector in data["selectors"] + ] + + # split 'commands' into separate 'categories' and 'commands' lists + commands, categories = [], [] + for command in data["commands"].split(" "): + categories.append(command) if "@" in command else commands.append(command) + + data["commands"] = commands + data["categories"] = categories + data["enabled"] = "on" in data["flags"] + return data + + +def parse_acl_log(response, **options): + if response is None: + return None + if isinstance(response, list): + data = [] + for log in response: + log_data = pairs_to_dict(log, True, True) + client_info = log_data.get("client-info", "") + log_data["client-info"] = parse_client_info(client_info) + + # float() is lossy comparing to the "double" in C + log_data["age-seconds"] = float(log_data["age-seconds"]) + data.append(log_data) + else: + data = bool_ok(response) + return data + + +def parse_client_info(value): + """ + Parsing client-info in ACL Log in following format. + "key1=value1 key2=value2 key3=value3" + """ + client_info = {} + for info in str_if_bytes(value).strip().split(): + key, value = info.split("=") + client_info[key] = value + + # Those fields are defined as int in networking.c + for int_key in { + "id", + "age", + "idle", + "db", + "sub", + "psub", + "multi", + "qbuf", + "qbuf-free", + "obl", + "argv-mem", + "oll", + "omem", + "tot-mem", + }: + client_info[int_key] = int(client_info[int_key]) + return client_info + + +def parse_set_result(response, **options): + """ + Handle SET result since GET argument is available since Redis 6.2. + Parsing SET result into: + - BOOL + - String when GET argument is used + """ + if options.get("get"): + # Redis will return a getCommand result. + # See `setGenericCommand` in t_string.c + return response + return response and str_if_bytes(response) == "OK" + + +def string_keys_to_dict(key_string, callback): + return dict.fromkeys(key_string.split(), callback) + + +_RedisCallbacks = { + **string_keys_to_dict( + "AUTH COPY EXPIRE EXPIREAT HEXISTS HMSET MOVE MSETNX PERSIST PSETEX " + "PEXPIRE PEXPIREAT RENAMENX SETEX SETNX SMOVE", + bool, + ), + **string_keys_to_dict("HINCRBYFLOAT INCRBYFLOAT", float), + **string_keys_to_dict( + "ASKING FLUSHALL FLUSHDB LSET LTRIM MSET PFMERGE READONLY READWRITE " + "RENAME SAVE SELECT SHUTDOWN SLAVEOF SWAPDB WATCH UNWATCH", + bool_ok, + ), + **string_keys_to_dict("XREAD XREADGROUP", parse_xread), + **string_keys_to_dict( + "GEORADIUS GEORADIUSBYMEMBER GEOSEARCH", + parse_geosearch_generic, + ), + **string_keys_to_dict("XRANGE XREVRANGE", parse_stream_list), + "ACL GETUSER": parse_acl_getuser, + "ACL LOAD": bool_ok, + "ACL LOG": parse_acl_log, + "ACL SETUSER": bool_ok, + "ACL SAVE": bool_ok, + "CLIENT INFO": parse_client_info, + "CLIENT KILL": parse_client_kill, + "CLIENT LIST": parse_client_list, + "CLIENT PAUSE": bool_ok, + "CLIENT SETINFO": bool_ok, + "CLIENT SETNAME": bool_ok, + "CLIENT UNBLOCK": bool, + "CLUSTER ADDSLOTS": bool_ok, + "CLUSTER ADDSLOTSRANGE": bool_ok, + "CLUSTER DELSLOTS": bool_ok, + "CLUSTER DELSLOTSRANGE": bool_ok, + "CLUSTER FAILOVER": bool_ok, + "CLUSTER FORGET": bool_ok, + "CLUSTER INFO": parse_cluster_info, + "CLUSTER MEET": bool_ok, + "CLUSTER NODES": parse_cluster_nodes, + "CLUSTER REPLICAS": parse_cluster_nodes, + "CLUSTER REPLICATE": bool_ok, + "CLUSTER RESET": bool_ok, + "CLUSTER SAVECONFIG": bool_ok, + "CLUSTER SET-CONFIG-EPOCH": bool_ok, + "CLUSTER SETSLOT": bool_ok, + "CLUSTER SLAVES": parse_cluster_nodes, + "COMMAND": parse_command, + "CONFIG RESETSTAT": bool_ok, + "CONFIG SET": bool_ok, + "FUNCTION DELETE": bool_ok, + "FUNCTION FLUSH": bool_ok, + "FUNCTION RESTORE": bool_ok, + "GEODIST": float_or_none, + "HSCAN": parse_hscan, + "INFO": parse_info, + "LASTSAVE": timestamp_to_datetime, + "MEMORY PURGE": bool_ok, + "MODULE LOAD": bool, + "MODULE UNLOAD": bool, + "PING": lambda r: str_if_bytes(r) == "PONG", + "PUBSUB NUMSUB": parse_pubsub_numsub, + "PUBSUB SHARDNUMSUB": parse_pubsub_numsub, + "QUIT": bool_ok, + "SET": parse_set_result, + "SCAN": parse_scan, + "SCRIPT EXISTS": lambda r: list(map(bool, r)), + "SCRIPT FLUSH": bool_ok, + "SCRIPT KILL": bool_ok, + "SCRIPT LOAD": str_if_bytes, + "SENTINEL CKQUORUM": bool_ok, + "SENTINEL FAILOVER": bool_ok, + "SENTINEL FLUSHCONFIG": bool_ok, + "SENTINEL GET-MASTER-ADDR-BY-NAME": parse_sentinel_get_master, + "SENTINEL MONITOR": bool_ok, + "SENTINEL RESET": bool_ok, + "SENTINEL REMOVE": bool_ok, + "SENTINEL SET": bool_ok, + "SLOWLOG GET": parse_slowlog_get, + "SLOWLOG RESET": bool_ok, + "SORT": sort_return_tuples, + "SSCAN": parse_scan, + "TIME": lambda x: (int(x[0]), int(x[1])), + "XAUTOCLAIM": parse_xautoclaim, + "XCLAIM": parse_xclaim, + "XGROUP CREATE": bool_ok, + "XGROUP DESTROY": bool, + "XGROUP SETID": bool_ok, + "XINFO STREAM": parse_xinfo_stream, + "XPENDING": parse_xpending, + "ZSCAN": parse_zscan, +} + + +_RedisCallbacksRESP2 = { + **string_keys_to_dict( + "SDIFF SINTER SMEMBERS SUNION", lambda r: r and set(r) or set() + ), + **string_keys_to_dict( + "ZDIFF ZINTER ZPOPMAX ZPOPMIN ZRANGE ZRANGEBYSCORE ZRANK ZREVRANGE " + "ZREVRANGEBYSCORE ZREVRANK ZUNION", + zset_score_pairs, + ), + **string_keys_to_dict("ZINCRBY ZSCORE", float_or_none), + **string_keys_to_dict("BGREWRITEAOF BGSAVE", lambda r: True), + **string_keys_to_dict("BLPOP BRPOP", lambda r: r and tuple(r) or None), + **string_keys_to_dict( + "BZPOPMAX BZPOPMIN", lambda r: r and (r[0], r[1], float(r[2])) or None + ), + "ACL CAT": lambda r: list(map(str_if_bytes, r)), + "ACL GENPASS": str_if_bytes, + "ACL HELP": lambda r: list(map(str_if_bytes, r)), + "ACL LIST": lambda r: list(map(str_if_bytes, r)), + "ACL USERS": lambda r: list(map(str_if_bytes, r)), + "ACL WHOAMI": str_if_bytes, + "CLIENT GETNAME": str_if_bytes, + "CLIENT TRACKINGINFO": lambda r: list(map(str_if_bytes, r)), + "CLUSTER GETKEYSINSLOT": lambda r: list(map(str_if_bytes, r)), + "COMMAND GETKEYS": lambda r: list(map(str_if_bytes, r)), + "CONFIG GET": parse_config_get, + "DEBUG OBJECT": parse_debug_object, + "GEOHASH": lambda r: list(map(str_if_bytes, r)), + "GEOPOS": lambda r: list( + map(lambda ll: (float(ll[0]), float(ll[1])) if ll is not None else None, r) + ), + "HGETALL": lambda r: r and pairs_to_dict(r) or {}, + "MEMORY STATS": parse_memory_stats, + "MODULE LIST": lambda r: [pairs_to_dict(m) for m in r], + "RESET": str_if_bytes, + "SENTINEL MASTER": parse_sentinel_master, + "SENTINEL MASTERS": parse_sentinel_masters, + "SENTINEL SENTINELS": parse_sentinel_slaves_and_sentinels, + "SENTINEL SLAVES": parse_sentinel_slaves_and_sentinels, + "STRALGO": parse_stralgo, + "XINFO CONSUMERS": parse_list_of_dicts, + "XINFO GROUPS": parse_list_of_dicts, + "ZADD": parse_zadd, + "ZMSCORE": parse_zmscore, +} + + +_RedisCallbacksRESP3 = { + **string_keys_to_dict( + "ZRANGE ZINTER ZPOPMAX ZPOPMIN ZRANGEBYSCORE ZREVRANGE ZREVRANGEBYSCORE " + "ZUNION HGETALL XREADGROUP", + lambda r, **kwargs: r, + ), + **string_keys_to_dict("XREAD XREADGROUP", parse_xread_resp3), + "ACL LOG": lambda r: [ + {str_if_bytes(key): str_if_bytes(value) for key, value in x.items()} for x in r + ] + if isinstance(r, list) + else bool_ok(r), + "COMMAND": parse_command_resp3, + "CONFIG GET": lambda r: { + str_if_bytes(key) + if key is not None + else None: str_if_bytes(value) + if value is not None + else None + for key, value in r.items() + }, + "MEMORY STATS": lambda r: {str_if_bytes(key): value for key, value in r.items()}, + "SENTINEL MASTER": parse_sentinel_state_resp3, + "SENTINEL MASTERS": parse_sentinel_masters_resp3, + "SENTINEL SENTINELS": parse_sentinel_slaves_and_sentinels_resp3, + "SENTINEL SLAVES": parse_sentinel_slaves_and_sentinels_resp3, + "STRALGO": lambda r, **options: { + str_if_bytes(key): str_if_bytes(value) for key, value in r.items() + } + if isinstance(r, dict) + else str_if_bytes(r), + "XINFO CONSUMERS": lambda r: [ + {str_if_bytes(key): value for key, value in x.items()} for x in r + ], + "XINFO GROUPS": lambda r: [ + {str_if_bytes(key): value for key, value in d.items()} for d in r + ], +} diff --git a/.venv/Lib/site-packages/redis/_parsers/hiredis.py b/.venv/Lib/site-packages/redis/_parsers/hiredis.py new file mode 100644 index 00000000..b3247b71 --- /dev/null +++ b/.venv/Lib/site-packages/redis/_parsers/hiredis.py @@ -0,0 +1,217 @@ +import asyncio +import socket +import sys +from typing import Callable, List, Optional, Union + +if sys.version_info.major >= 3 and sys.version_info.minor >= 11: + from asyncio import timeout as async_timeout +else: + from async_timeout import timeout as async_timeout + +from redis.compat import TypedDict + +from ..exceptions import ConnectionError, InvalidResponse, RedisError +from ..typing import EncodableT +from ..utils import HIREDIS_AVAILABLE +from .base import AsyncBaseParser, BaseParser +from .socket import ( + NONBLOCKING_EXCEPTION_ERROR_NUMBERS, + NONBLOCKING_EXCEPTIONS, + SENTINEL, + SERVER_CLOSED_CONNECTION_ERROR, +) + + +class _HiredisReaderArgs(TypedDict, total=False): + protocolError: Callable[[str], Exception] + replyError: Callable[[str], Exception] + encoding: Optional[str] + errors: Optional[str] + + +class _HiredisParser(BaseParser): + "Parser class for connections using Hiredis" + + def __init__(self, socket_read_size): + if not HIREDIS_AVAILABLE: + raise RedisError("Hiredis is not installed") + self.socket_read_size = socket_read_size + self._buffer = bytearray(socket_read_size) + + def __del__(self): + try: + self.on_disconnect() + except Exception: + pass + + def on_connect(self, connection, **kwargs): + import hiredis + + self._sock = connection._sock + self._socket_timeout = connection.socket_timeout + kwargs = { + "protocolError": InvalidResponse, + "replyError": self.parse_error, + "errors": connection.encoder.encoding_errors, + } + + if connection.encoder.decode_responses: + kwargs["encoding"] = connection.encoder.encoding + self._reader = hiredis.Reader(**kwargs) + self._next_response = False + + def on_disconnect(self): + self._sock = None + self._reader = None + self._next_response = False + + def can_read(self, timeout): + if not self._reader: + raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) + + if self._next_response is False: + self._next_response = self._reader.gets() + if self._next_response is False: + return self.read_from_socket(timeout=timeout, raise_on_timeout=False) + return True + + def read_from_socket(self, timeout=SENTINEL, raise_on_timeout=True): + sock = self._sock + custom_timeout = timeout is not SENTINEL + try: + if custom_timeout: + sock.settimeout(timeout) + bufflen = self._sock.recv_into(self._buffer) + if bufflen == 0: + raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) + self._reader.feed(self._buffer, 0, bufflen) + # data was read from the socket and added to the buffer. + # return True to indicate that data was read. + return True + except socket.timeout: + if raise_on_timeout: + raise TimeoutError("Timeout reading from socket") + return False + except NONBLOCKING_EXCEPTIONS as ex: + # if we're in nonblocking mode and the recv raises a + # blocking error, simply return False indicating that + # there's no data to be read. otherwise raise the + # original exception. + allowed = NONBLOCKING_EXCEPTION_ERROR_NUMBERS.get(ex.__class__, -1) + if not raise_on_timeout and ex.errno == allowed: + return False + raise ConnectionError(f"Error while reading from socket: {ex.args}") + finally: + if custom_timeout: + sock.settimeout(self._socket_timeout) + + def read_response(self, disable_decoding=False): + if not self._reader: + raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) + + # _next_response might be cached from a can_read() call + if self._next_response is not False: + response = self._next_response + self._next_response = False + return response + + if disable_decoding: + response = self._reader.gets(False) + else: + response = self._reader.gets() + + while response is False: + self.read_from_socket() + if disable_decoding: + response = self._reader.gets(False) + else: + response = self._reader.gets() + # if the response is a ConnectionError or the response is a list and + # the first item is a ConnectionError, raise it as something bad + # happened + if isinstance(response, ConnectionError): + raise response + elif ( + isinstance(response, list) + and response + and isinstance(response[0], ConnectionError) + ): + raise response[0] + return response + + +class _AsyncHiredisParser(AsyncBaseParser): + """Async implementation of parser class for connections using Hiredis""" + + __slots__ = ("_reader",) + + def __init__(self, socket_read_size: int): + if not HIREDIS_AVAILABLE: + raise RedisError("Hiredis is not available.") + super().__init__(socket_read_size=socket_read_size) + self._reader = None + + def on_connect(self, connection): + import hiredis + + self._stream = connection._reader + kwargs: _HiredisReaderArgs = { + "protocolError": InvalidResponse, + "replyError": self.parse_error, + } + if connection.encoder.decode_responses: + kwargs["encoding"] = connection.encoder.encoding + kwargs["errors"] = connection.encoder.encoding_errors + + self._reader = hiredis.Reader(**kwargs) + self._connected = True + + def on_disconnect(self): + self._connected = False + + async def can_read_destructive(self): + if not self._connected: + raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) + if self._reader.gets(): + return True + try: + async with async_timeout(0): + return await self.read_from_socket() + except asyncio.TimeoutError: + return False + + async def read_from_socket(self): + buffer = await self._stream.read(self._read_size) + if not buffer or not isinstance(buffer, bytes): + raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) from None + self._reader.feed(buffer) + # data was read from the socket and added to the buffer. + # return True to indicate that data was read. + return True + + async def read_response( + self, disable_decoding: bool = False + ) -> Union[EncodableT, List[EncodableT]]: + # If `on_disconnect()` has been called, prohibit any more reads + # even if they could happen because data might be present. + # We still allow reads in progress to finish + if not self._connected: + raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) from None + + response = self._reader.gets() + while response is False: + await self.read_from_socket() + response = self._reader.gets() + + # if the response is a ConnectionError or the response is a list and + # the first item is a ConnectionError, raise it as something bad + # happened + if isinstance(response, ConnectionError): + raise response + elif ( + isinstance(response, list) + and response + and isinstance(response[0], ConnectionError) + ): + raise response[0] + return response diff --git a/.venv/Lib/site-packages/redis/_parsers/resp2.py b/.venv/Lib/site-packages/redis/_parsers/resp2.py new file mode 100644 index 00000000..d5adc1a8 --- /dev/null +++ b/.venv/Lib/site-packages/redis/_parsers/resp2.py @@ -0,0 +1,132 @@ +from typing import Any, Union + +from ..exceptions import ConnectionError, InvalidResponse, ResponseError +from ..typing import EncodableT +from .base import _AsyncRESPBase, _RESPBase +from .socket import SERVER_CLOSED_CONNECTION_ERROR + + +class _RESP2Parser(_RESPBase): + """RESP2 protocol implementation""" + + def read_response(self, disable_decoding=False): + pos = self._buffer.get_pos() if self._buffer else None + try: + result = self._read_response(disable_decoding=disable_decoding) + except BaseException: + if self._buffer: + self._buffer.rewind(pos) + raise + else: + self._buffer.purge() + return result + + def _read_response(self, disable_decoding=False): + raw = self._buffer.readline() + if not raw: + raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) + + byte, response = raw[:1], raw[1:] + + # server returned an error + if byte == b"-": + response = response.decode("utf-8", errors="replace") + error = self.parse_error(response) + # if the error is a ConnectionError, raise immediately so the user + # is notified + if isinstance(error, ConnectionError): + raise error + # otherwise, we're dealing with a ResponseError that might belong + # inside a pipeline response. the connection's read_response() + # and/or the pipeline's execute() will raise this error if + # necessary, so just return the exception instance here. + return error + # single value + elif byte == b"+": + pass + # int value + elif byte == b":": + return int(response) + # bulk response + elif byte == b"$" and response == b"-1": + return None + elif byte == b"$": + response = self._buffer.read(int(response)) + # multi-bulk response + elif byte == b"*" and response == b"-1": + return None + elif byte == b"*": + response = [ + self._read_response(disable_decoding=disable_decoding) + for i in range(int(response)) + ] + else: + raise InvalidResponse(f"Protocol Error: {raw!r}") + + if disable_decoding is False: + response = self.encoder.decode(response) + return response + + +class _AsyncRESP2Parser(_AsyncRESPBase): + """Async class for the RESP2 protocol""" + + async def read_response(self, disable_decoding: bool = False): + if not self._connected: + raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) + if self._chunks: + # augment parsing buffer with previously read data + self._buffer += b"".join(self._chunks) + self._chunks.clear() + self._pos = 0 + response = await self._read_response(disable_decoding=disable_decoding) + # Successfully parsing a response allows us to clear our parsing buffer + self._clear() + return response + + async def _read_response( + self, disable_decoding: bool = False + ) -> Union[EncodableT, ResponseError, None]: + raw = await self._readline() + response: Any + byte, response = raw[:1], raw[1:] + + # server returned an error + if byte == b"-": + response = response.decode("utf-8", errors="replace") + error = self.parse_error(response) + # if the error is a ConnectionError, raise immediately so the user + # is notified + if isinstance(error, ConnectionError): + self._clear() # Successful parse + raise error + # otherwise, we're dealing with a ResponseError that might belong + # inside a pipeline response. the connection's read_response() + # and/or the pipeline's execute() will raise this error if + # necessary, so just return the exception instance here. + return error + # single value + elif byte == b"+": + pass + # int value + elif byte == b":": + return int(response) + # bulk response + elif byte == b"$" and response == b"-1": + return None + elif byte == b"$": + response = await self._read(int(response)) + # multi-bulk response + elif byte == b"*" and response == b"-1": + return None + elif byte == b"*": + response = [ + (await self._read_response(disable_decoding)) + for _ in range(int(response)) # noqa + ] + else: + raise InvalidResponse(f"Protocol Error: {raw!r}") + + if disable_decoding is False: + response = self.encoder.decode(response) + return response diff --git a/.venv/Lib/site-packages/redis/_parsers/resp3.py b/.venv/Lib/site-packages/redis/_parsers/resp3.py new file mode 100644 index 00000000..ad766a8f --- /dev/null +++ b/.venv/Lib/site-packages/redis/_parsers/resp3.py @@ -0,0 +1,259 @@ +from logging import getLogger +from typing import Any, Union + +from ..exceptions import ConnectionError, InvalidResponse, ResponseError +from ..typing import EncodableT +from .base import _AsyncRESPBase, _RESPBase +from .socket import SERVER_CLOSED_CONNECTION_ERROR + + +class _RESP3Parser(_RESPBase): + """RESP3 protocol implementation""" + + def __init__(self, socket_read_size): + super().__init__(socket_read_size) + self.push_handler_func = self.handle_push_response + + def handle_push_response(self, response): + logger = getLogger("push_response") + logger.info("Push response: " + str(response)) + return response + + def read_response(self, disable_decoding=False, push_request=False): + pos = self._buffer.get_pos() if self._buffer else None + try: + result = self._read_response( + disable_decoding=disable_decoding, push_request=push_request + ) + except BaseException: + if self._buffer: + self._buffer.rewind(pos) + raise + else: + self._buffer.purge() + return result + + def _read_response(self, disable_decoding=False, push_request=False): + raw = self._buffer.readline() + if not raw: + raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) + + byte, response = raw[:1], raw[1:] + + # server returned an error + if byte in (b"-", b"!"): + if byte == b"!": + response = self._buffer.read(int(response)) + response = response.decode("utf-8", errors="replace") + error = self.parse_error(response) + # if the error is a ConnectionError, raise immediately so the user + # is notified + if isinstance(error, ConnectionError): + raise error + # otherwise, we're dealing with a ResponseError that might belong + # inside a pipeline response. the connection's read_response() + # and/or the pipeline's execute() will raise this error if + # necessary, so just return the exception instance here. + return error + # single value + elif byte == b"+": + pass + # null value + elif byte == b"_": + return None + # int and big int values + elif byte in (b":", b"("): + return int(response) + # double value + elif byte == b",": + return float(response) + # bool value + elif byte == b"#": + return response == b"t" + # bulk response + elif byte == b"$": + response = self._buffer.read(int(response)) + # verbatim string response + elif byte == b"=": + response = self._buffer.read(int(response))[4:] + # array response + elif byte == b"*": + response = [ + self._read_response(disable_decoding=disable_decoding) + for _ in range(int(response)) + ] + # set response + elif byte == b"~": + # redis can return unhashable types (like dict) in a set, + # so we need to first convert to a list, and then try to convert it to a set + response = [ + self._read_response(disable_decoding=disable_decoding) + for _ in range(int(response)) + ] + try: + response = set(response) + except TypeError: + pass + # map response + elif byte == b"%": + # we use this approach and not dict comprehension here + # because this dict comprehension fails in python 3.7 + resp_dict = {} + for _ in range(int(response)): + key = self._read_response(disable_decoding=disable_decoding) + resp_dict[key] = self._read_response( + disable_decoding=disable_decoding, push_request=push_request + ) + response = resp_dict + # push response + elif byte == b">": + response = [ + self._read_response( + disable_decoding=disable_decoding, push_request=push_request + ) + for _ in range(int(response)) + ] + res = self.push_handler_func(response) + if not push_request: + return self._read_response( + disable_decoding=disable_decoding, push_request=push_request + ) + else: + return res + else: + raise InvalidResponse(f"Protocol Error: {raw!r}") + + if isinstance(response, bytes) and disable_decoding is False: + response = self.encoder.decode(response) + return response + + def set_push_handler(self, push_handler_func): + self.push_handler_func = push_handler_func + + +class _AsyncRESP3Parser(_AsyncRESPBase): + def __init__(self, socket_read_size): + super().__init__(socket_read_size) + self.push_handler_func = self.handle_push_response + + def handle_push_response(self, response): + logger = getLogger("push_response") + logger.info("Push response: " + str(response)) + return response + + async def read_response( + self, disable_decoding: bool = False, push_request: bool = False + ): + if self._chunks: + # augment parsing buffer with previously read data + self._buffer += b"".join(self._chunks) + self._chunks.clear() + self._pos = 0 + response = await self._read_response( + disable_decoding=disable_decoding, push_request=push_request + ) + # Successfully parsing a response allows us to clear our parsing buffer + self._clear() + return response + + async def _read_response( + self, disable_decoding: bool = False, push_request: bool = False + ) -> Union[EncodableT, ResponseError, None]: + if not self._stream or not self.encoder: + raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) + raw = await self._readline() + response: Any + byte, response = raw[:1], raw[1:] + + # if byte not in (b"-", b"+", b":", b"$", b"*"): + # raise InvalidResponse(f"Protocol Error: {raw!r}") + + # server returned an error + if byte in (b"-", b"!"): + if byte == b"!": + response = await self._read(int(response)) + response = response.decode("utf-8", errors="replace") + error = self.parse_error(response) + # if the error is a ConnectionError, raise immediately so the user + # is notified + if isinstance(error, ConnectionError): + self._clear() # Successful parse + raise error + # otherwise, we're dealing with a ResponseError that might belong + # inside a pipeline response. the connection's read_response() + # and/or the pipeline's execute() will raise this error if + # necessary, so just return the exception instance here. + return error + # single value + elif byte == b"+": + pass + # null value + elif byte == b"_": + return None + # int and big int values + elif byte in (b":", b"("): + return int(response) + # double value + elif byte == b",": + return float(response) + # bool value + elif byte == b"#": + return response == b"t" + # bulk response + elif byte == b"$": + response = await self._read(int(response)) + # verbatim string response + elif byte == b"=": + response = (await self._read(int(response)))[4:] + # array response + elif byte == b"*": + response = [ + (await self._read_response(disable_decoding=disable_decoding)) + for _ in range(int(response)) + ] + # set response + elif byte == b"~": + # redis can return unhashable types (like dict) in a set, + # so we need to first convert to a list, and then try to convert it to a set + response = [ + (await self._read_response(disable_decoding=disable_decoding)) + for _ in range(int(response)) + ] + try: + response = set(response) + except TypeError: + pass + # map response + elif byte == b"%": + response = { + (await self._read_response(disable_decoding=disable_decoding)): ( + await self._read_response(disable_decoding=disable_decoding) + ) + for _ in range(int(response)) + } + # push response + elif byte == b">": + response = [ + ( + await self._read_response( + disable_decoding=disable_decoding, push_request=push_request + ) + ) + for _ in range(int(response)) + ] + res = self.push_handler_func(response) + if not push_request: + return await self._read_response( + disable_decoding=disable_decoding, push_request=push_request + ) + else: + return res + else: + raise InvalidResponse(f"Protocol Error: {raw!r}") + + if isinstance(response, bytes) and disable_decoding is False: + response = self.encoder.decode(response) + return response + + def set_push_handler(self, push_handler_func): + self.push_handler_func = push_handler_func diff --git a/.venv/Lib/site-packages/redis/_parsers/socket.py b/.venv/Lib/site-packages/redis/_parsers/socket.py new file mode 100644 index 00000000..8147243b --- /dev/null +++ b/.venv/Lib/site-packages/redis/_parsers/socket.py @@ -0,0 +1,162 @@ +import errno +import io +import socket +from io import SEEK_END +from typing import Optional, Union + +from ..exceptions import ConnectionError, TimeoutError +from ..utils import SSL_AVAILABLE + +NONBLOCKING_EXCEPTION_ERROR_NUMBERS = {BlockingIOError: errno.EWOULDBLOCK} + +if SSL_AVAILABLE: + import ssl + + if hasattr(ssl, "SSLWantReadError"): + NONBLOCKING_EXCEPTION_ERROR_NUMBERS[ssl.SSLWantReadError] = 2 + NONBLOCKING_EXCEPTION_ERROR_NUMBERS[ssl.SSLWantWriteError] = 2 + else: + NONBLOCKING_EXCEPTION_ERROR_NUMBERS[ssl.SSLError] = 2 + +NONBLOCKING_EXCEPTIONS = tuple(NONBLOCKING_EXCEPTION_ERROR_NUMBERS.keys()) + +SERVER_CLOSED_CONNECTION_ERROR = "Connection closed by server." +SENTINEL = object() + +SYM_CRLF = b"\r\n" + + +class SocketBuffer: + def __init__( + self, socket: socket.socket, socket_read_size: int, socket_timeout: float + ): + self._sock = socket + self.socket_read_size = socket_read_size + self.socket_timeout = socket_timeout + self._buffer = io.BytesIO() + + def unread_bytes(self) -> int: + """ + Remaining unread length of buffer + """ + pos = self._buffer.tell() + end = self._buffer.seek(0, SEEK_END) + self._buffer.seek(pos) + return end - pos + + def _read_from_socket( + self, + length: Optional[int] = None, + timeout: Union[float, object] = SENTINEL, + raise_on_timeout: Optional[bool] = True, + ) -> bool: + sock = self._sock + socket_read_size = self.socket_read_size + marker = 0 + custom_timeout = timeout is not SENTINEL + + buf = self._buffer + current_pos = buf.tell() + buf.seek(0, SEEK_END) + if custom_timeout: + sock.settimeout(timeout) + try: + while True: + data = self._sock.recv(socket_read_size) + # an empty string indicates the server shutdown the socket + if isinstance(data, bytes) and len(data) == 0: + raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) + buf.write(data) + data_length = len(data) + marker += data_length + + if length is not None and length > marker: + continue + return True + except socket.timeout: + if raise_on_timeout: + raise TimeoutError("Timeout reading from socket") + return False + except NONBLOCKING_EXCEPTIONS as ex: + # if we're in nonblocking mode and the recv raises a + # blocking error, simply return False indicating that + # there's no data to be read. otherwise raise the + # original exception. + allowed = NONBLOCKING_EXCEPTION_ERROR_NUMBERS.get(ex.__class__, -1) + if not raise_on_timeout and ex.errno == allowed: + return False + raise ConnectionError(f"Error while reading from socket: {ex.args}") + finally: + buf.seek(current_pos) + if custom_timeout: + sock.settimeout(self.socket_timeout) + + def can_read(self, timeout: float) -> bool: + return bool(self.unread_bytes()) or self._read_from_socket( + timeout=timeout, raise_on_timeout=False + ) + + def read(self, length: int) -> bytes: + length = length + 2 # make sure to read the \r\n terminator + # BufferIO will return less than requested if buffer is short + data = self._buffer.read(length) + missing = length - len(data) + if missing: + # fill up the buffer and read the remainder + self._read_from_socket(missing) + data += self._buffer.read(missing) + return data[:-2] + + def readline(self) -> bytes: + buf = self._buffer + data = buf.readline() + while not data.endswith(SYM_CRLF): + # there's more data in the socket that we need + self._read_from_socket() + data += buf.readline() + + return data[:-2] + + def get_pos(self) -> int: + """ + Get current read position + """ + return self._buffer.tell() + + def rewind(self, pos: int) -> None: + """ + Rewind the buffer to a specific position, to re-start reading + """ + self._buffer.seek(pos) + + def purge(self) -> None: + """ + After a successful read, purge the read part of buffer + """ + unread = self.unread_bytes() + + # Only if we have read all of the buffer do we truncate, to + # reduce the amount of memory thrashing. This heuristic + # can be changed or removed later. + if unread > 0: + return + + if unread > 0: + # move unread data to the front + view = self._buffer.getbuffer() + view[:unread] = view[-unread:] + self._buffer.truncate(unread) + self._buffer.seek(0) + + def close(self) -> None: + try: + self._buffer.close() + except Exception: + # issue #633 suggests the purge/close somehow raised a + # BadFileDescriptor error. Perhaps the client ran out of + # memory or something else? It's probably OK to ignore + # any error being raised from purge/close since we're + # removing the reference to the instance below. + pass + self._buffer = None + self._sock = None diff --git a/.venv/Lib/site-packages/redis/asyncio/__init__.py b/.venv/Lib/site-packages/redis/asyncio/__init__.py new file mode 100644 index 00000000..3545ab44 --- /dev/null +++ b/.venv/Lib/site-packages/redis/asyncio/__init__.py @@ -0,0 +1,64 @@ +from redis.asyncio.client import Redis, StrictRedis +from redis.asyncio.cluster import RedisCluster +from redis.asyncio.connection import ( + BlockingConnectionPool, + Connection, + ConnectionPool, + SSLConnection, + UnixDomainSocketConnection, +) +from redis.asyncio.sentinel import ( + Sentinel, + SentinelConnectionPool, + SentinelManagedConnection, + SentinelManagedSSLConnection, +) +from redis.asyncio.utils import from_url +from redis.backoff import default_backoff +from redis.exceptions import ( + AuthenticationError, + AuthenticationWrongNumberOfArgsError, + BusyLoadingError, + ChildDeadlockedError, + ConnectionError, + DataError, + InvalidResponse, + OutOfMemoryError, + PubSubError, + ReadOnlyError, + RedisError, + ResponseError, + TimeoutError, + WatchError, +) + +__all__ = [ + "AuthenticationError", + "AuthenticationWrongNumberOfArgsError", + "BlockingConnectionPool", + "BusyLoadingError", + "ChildDeadlockedError", + "Connection", + "ConnectionError", + "ConnectionPool", + "DataError", + "from_url", + "default_backoff", + "InvalidResponse", + "PubSubError", + "OutOfMemoryError", + "ReadOnlyError", + "Redis", + "RedisCluster", + "RedisError", + "ResponseError", + "Sentinel", + "SentinelConnectionPool", + "SentinelManagedConnection", + "SentinelManagedSSLConnection", + "SSLConnection", + "StrictRedis", + "TimeoutError", + "UnixDomainSocketConnection", + "WatchError", +] diff --git a/.venv/Lib/site-packages/redis/asyncio/__pycache__/__init__.cpython-311.pyc b/.venv/Lib/site-packages/redis/asyncio/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 00000000..53424d16 Binary files /dev/null and b/.venv/Lib/site-packages/redis/asyncio/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/redis/asyncio/__pycache__/client.cpython-311.pyc b/.venv/Lib/site-packages/redis/asyncio/__pycache__/client.cpython-311.pyc new file mode 100644 index 00000000..f5cea005 Binary files /dev/null and b/.venv/Lib/site-packages/redis/asyncio/__pycache__/client.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/redis/asyncio/__pycache__/cluster.cpython-311.pyc b/.venv/Lib/site-packages/redis/asyncio/__pycache__/cluster.cpython-311.pyc new file mode 100644 index 00000000..1b61c3bd Binary files /dev/null and b/.venv/Lib/site-packages/redis/asyncio/__pycache__/cluster.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/redis/asyncio/__pycache__/connection.cpython-311.pyc b/.venv/Lib/site-packages/redis/asyncio/__pycache__/connection.cpython-311.pyc new file mode 100644 index 00000000..2f343005 Binary files /dev/null and b/.venv/Lib/site-packages/redis/asyncio/__pycache__/connection.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/redis/asyncio/__pycache__/lock.cpython-311.pyc b/.venv/Lib/site-packages/redis/asyncio/__pycache__/lock.cpython-311.pyc new file mode 100644 index 00000000..09b348f2 Binary files /dev/null and b/.venv/Lib/site-packages/redis/asyncio/__pycache__/lock.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/redis/asyncio/__pycache__/retry.cpython-311.pyc b/.venv/Lib/site-packages/redis/asyncio/__pycache__/retry.cpython-311.pyc new file mode 100644 index 00000000..1b391e98 Binary files /dev/null and b/.venv/Lib/site-packages/redis/asyncio/__pycache__/retry.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/redis/asyncio/__pycache__/sentinel.cpython-311.pyc b/.venv/Lib/site-packages/redis/asyncio/__pycache__/sentinel.cpython-311.pyc new file mode 100644 index 00000000..aebcfc5b Binary files /dev/null and b/.venv/Lib/site-packages/redis/asyncio/__pycache__/sentinel.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/redis/asyncio/__pycache__/utils.cpython-311.pyc b/.venv/Lib/site-packages/redis/asyncio/__pycache__/utils.cpython-311.pyc new file mode 100644 index 00000000..181cab59 Binary files /dev/null and b/.venv/Lib/site-packages/redis/asyncio/__pycache__/utils.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/redis/asyncio/client.py b/.venv/Lib/site-packages/redis/asyncio/client.py new file mode 100644 index 00000000..e4d2e776 --- /dev/null +++ b/.venv/Lib/site-packages/redis/asyncio/client.py @@ -0,0 +1,1533 @@ +import asyncio +import copy +import inspect +import re +import warnings +from typing import ( + TYPE_CHECKING, + Any, + AsyncIterator, + Awaitable, + Callable, + Dict, + Iterable, + List, + Mapping, + MutableMapping, + Optional, + Set, + Tuple, + Type, + TypeVar, + Union, + cast, +) + +from redis._parsers.helpers import ( + _RedisCallbacks, + _RedisCallbacksRESP2, + _RedisCallbacksRESP3, + bool_ok, +) +from redis.asyncio.connection import ( + Connection, + ConnectionPool, + SSLConnection, + UnixDomainSocketConnection, +) +from redis.asyncio.lock import Lock +from redis.asyncio.retry import Retry +from redis.client import ( + EMPTY_RESPONSE, + NEVER_DECODE, + AbstractRedis, + CaseInsensitiveDict, +) +from redis.commands import ( + AsyncCoreCommands, + AsyncRedisModuleCommands, + AsyncSentinelCommands, + list_or_args, +) +from redis.compat import Protocol, TypedDict +from redis.credentials import CredentialProvider +from redis.exceptions import ( + ConnectionError, + ExecAbortError, + PubSubError, + RedisError, + ResponseError, + TimeoutError, + WatchError, +) +from redis.typing import ChannelT, EncodableT, KeyT +from redis.utils import ( + HIREDIS_AVAILABLE, + _set_info_logger, + deprecated_function, + get_lib_version, + safe_str, + str_if_bytes, +) + +PubSubHandler = Callable[[Dict[str, str]], Awaitable[None]] +_KeyT = TypeVar("_KeyT", bound=KeyT) +_ArgT = TypeVar("_ArgT", KeyT, EncodableT) +_RedisT = TypeVar("_RedisT", bound="Redis") +_NormalizeKeysT = TypeVar("_NormalizeKeysT", bound=Mapping[ChannelT, object]) +if TYPE_CHECKING: + from redis.commands.core import Script + + +class ResponseCallbackProtocol(Protocol): + def __call__(self, response: Any, **kwargs): + ... + + +class AsyncResponseCallbackProtocol(Protocol): + async def __call__(self, response: Any, **kwargs): + ... + + +ResponseCallbackT = Union[ResponseCallbackProtocol, AsyncResponseCallbackProtocol] + + +class Redis( + AbstractRedis, AsyncRedisModuleCommands, AsyncCoreCommands, AsyncSentinelCommands +): + """ + Implementation of the Redis protocol. + + This abstract class provides a Python interface to all Redis commands + and an implementation of the Redis protocol. + + Pipelines derive from this, implementing how + the commands are sent and received to the Redis server. Based on + configuration, an instance will either use a ConnectionPool, or + Connection object to talk to redis. + """ + + response_callbacks: MutableMapping[Union[str, bytes], ResponseCallbackT] + + @classmethod + def from_url( + cls, + url: str, + single_connection_client: bool = False, + auto_close_connection_pool: Optional[bool] = None, + **kwargs, + ): + """ + Return a Redis client object configured from the given URL + + For example:: + + redis://[[username]:[password]]@localhost:6379/0 + rediss://[[username]:[password]]@localhost:6379/0 + unix://[username@]/path/to/socket.sock?db=0[&password=password] + + Three URL schemes are supported: + + - `redis://` creates a TCP socket connection. See more at: + + - `rediss://` creates a SSL wrapped TCP socket connection. See more at: + + - ``unix://``: creates a Unix Domain Socket connection. + + The username, password, hostname, path and all querystring values + are passed through urllib.parse.unquote in order to replace any + percent-encoded values with their corresponding characters. + + There are several ways to specify a database number. The first value + found will be used: + + 1. A ``db`` querystring option, e.g. redis://localhost?db=0 + + 2. If using the redis:// or rediss:// schemes, the path argument + of the url, e.g. redis://localhost/0 + + 3. A ``db`` keyword argument to this function. + + If none of these options are specified, the default db=0 is used. + + All querystring options are cast to their appropriate Python types. + Boolean arguments can be specified with string values "True"/"False" + or "Yes"/"No". Values that cannot be properly cast cause a + ``ValueError`` to be raised. Once parsed, the querystring arguments + and keyword arguments are passed to the ``ConnectionPool``'s + class initializer. In the case of conflicting arguments, querystring + arguments always win. + + """ + connection_pool = ConnectionPool.from_url(url, **kwargs) + client = cls( + connection_pool=connection_pool, + single_connection_client=single_connection_client, + ) + if auto_close_connection_pool is not None: + warnings.warn( + DeprecationWarning( + '"auto_close_connection_pool" is deprecated ' + "since version 5.0.0. " + "Please create a ConnectionPool explicitly and " + "provide to the Redis() constructor instead." + ) + ) + else: + auto_close_connection_pool = True + client.auto_close_connection_pool = auto_close_connection_pool + return client + + @classmethod + def from_pool( + cls: Type["Redis"], + connection_pool: ConnectionPool, + ) -> "Redis": + """ + Return a Redis client from the given connection pool. + The Redis client will take ownership of the connection pool and + close it when the Redis client is closed. + """ + client = cls( + connection_pool=connection_pool, + ) + client.auto_close_connection_pool = True + return client + + def __init__( + self, + *, + host: str = "localhost", + port: int = 6379, + db: Union[str, int] = 0, + password: Optional[str] = None, + socket_timeout: Optional[float] = None, + socket_connect_timeout: Optional[float] = None, + socket_keepalive: Optional[bool] = None, + socket_keepalive_options: Optional[Mapping[int, Union[int, bytes]]] = None, + connection_pool: Optional[ConnectionPool] = None, + unix_socket_path: Optional[str] = None, + encoding: str = "utf-8", + encoding_errors: str = "strict", + decode_responses: bool = False, + retry_on_timeout: bool = False, + retry_on_error: Optional[list] = None, + ssl: bool = False, + ssl_keyfile: Optional[str] = None, + ssl_certfile: Optional[str] = None, + ssl_cert_reqs: str = "required", + ssl_ca_certs: Optional[str] = None, + ssl_ca_data: Optional[str] = None, + ssl_check_hostname: bool = False, + max_connections: Optional[int] = None, + single_connection_client: bool = False, + health_check_interval: int = 0, + client_name: Optional[str] = None, + lib_name: Optional[str] = "redis-py", + lib_version: Optional[str] = get_lib_version(), + username: Optional[str] = None, + retry: Optional[Retry] = None, + auto_close_connection_pool: Optional[bool] = None, + redis_connect_func=None, + credential_provider: Optional[CredentialProvider] = None, + protocol: Optional[int] = 2, + ): + """ + Initialize a new Redis client. + To specify a retry policy for specific errors, first set + `retry_on_error` to a list of the error/s to retry on, then set + `retry` to a valid `Retry` object. + To retry on TimeoutError, `retry_on_timeout` can also be set to `True`. + """ + kwargs: Dict[str, Any] + # auto_close_connection_pool only has an effect if connection_pool is + # None. It is assumed that if connection_pool is not None, the user + # wants to manage the connection pool themselves. + if auto_close_connection_pool is not None: + warnings.warn( + DeprecationWarning( + '"auto_close_connection_pool" is deprecated ' + "since version 5.0.0. " + "Please create a ConnectionPool explicitly and " + "provide to the Redis() constructor instead." + ) + ) + else: + auto_close_connection_pool = True + + if not connection_pool: + # Create internal connection pool, expected to be closed by Redis instance + if not retry_on_error: + retry_on_error = [] + if retry_on_timeout is True: + retry_on_error.append(TimeoutError) + kwargs = { + "db": db, + "username": username, + "password": password, + "credential_provider": credential_provider, + "socket_timeout": socket_timeout, + "encoding": encoding, + "encoding_errors": encoding_errors, + "decode_responses": decode_responses, + "retry_on_timeout": retry_on_timeout, + "retry_on_error": retry_on_error, + "retry": copy.deepcopy(retry), + "max_connections": max_connections, + "health_check_interval": health_check_interval, + "client_name": client_name, + "lib_name": lib_name, + "lib_version": lib_version, + "redis_connect_func": redis_connect_func, + "protocol": protocol, + } + # based on input, setup appropriate connection args + if unix_socket_path is not None: + kwargs.update( + { + "path": unix_socket_path, + "connection_class": UnixDomainSocketConnection, + } + ) + else: + # TCP specific options + kwargs.update( + { + "host": host, + "port": port, + "socket_connect_timeout": socket_connect_timeout, + "socket_keepalive": socket_keepalive, + "socket_keepalive_options": socket_keepalive_options, + } + ) + + if ssl: + kwargs.update( + { + "connection_class": SSLConnection, + "ssl_keyfile": ssl_keyfile, + "ssl_certfile": ssl_certfile, + "ssl_cert_reqs": ssl_cert_reqs, + "ssl_ca_certs": ssl_ca_certs, + "ssl_ca_data": ssl_ca_data, + "ssl_check_hostname": ssl_check_hostname, + } + ) + # This arg only used if no pool is passed in + self.auto_close_connection_pool = auto_close_connection_pool + connection_pool = ConnectionPool(**kwargs) + else: + # If a pool is passed in, do not close it + self.auto_close_connection_pool = False + + self.connection_pool = connection_pool + self.single_connection_client = single_connection_client + self.connection: Optional[Connection] = None + + self.response_callbacks = CaseInsensitiveDict(_RedisCallbacks) + + if self.connection_pool.connection_kwargs.get("protocol") in ["3", 3]: + self.response_callbacks.update(_RedisCallbacksRESP3) + else: + self.response_callbacks.update(_RedisCallbacksRESP2) + + # If using a single connection client, we need to lock creation-of and use-of + # the client in order to avoid race conditions such as using asyncio.gather + # on a set of redis commands + self._single_conn_lock = asyncio.Lock() + + def __repr__(self): + return f"{self.__class__.__name__}<{self.connection_pool!r}>" + + def __await__(self): + return self.initialize().__await__() + + async def initialize(self: _RedisT) -> _RedisT: + if self.single_connection_client: + async with self._single_conn_lock: + if self.connection is None: + self.connection = await self.connection_pool.get_connection("_") + return self + + def set_response_callback(self, command: str, callback: ResponseCallbackT): + """Set a custom Response Callback""" + self.response_callbacks[command] = callback + + def get_encoder(self): + """Get the connection pool's encoder""" + return self.connection_pool.get_encoder() + + def get_connection_kwargs(self): + """Get the connection's key-word arguments""" + return self.connection_pool.connection_kwargs + + def get_retry(self) -> Optional["Retry"]: + return self.get_connection_kwargs().get("retry") + + def set_retry(self, retry: "Retry") -> None: + self.get_connection_kwargs().update({"retry": retry}) + self.connection_pool.set_retry(retry) + + def load_external_module(self, funcname, func): + """ + This function can be used to add externally defined redis modules, + and their namespaces to the redis client. + + funcname - A string containing the name of the function to create + func - The function, being added to this class. + + ex: Assume that one has a custom redis module named foomod that + creates command named 'foo.dothing' and 'foo.anotherthing' in redis. + To load function functions into this namespace: + + from redis import Redis + from foomodule import F + r = Redis() + r.load_external_module("foo", F) + r.foo().dothing('your', 'arguments') + + For a concrete example see the reimport of the redisjson module in + tests/test_connection.py::test_loading_external_modules + """ + setattr(self, funcname, func) + + def pipeline( + self, transaction: bool = True, shard_hint: Optional[str] = None + ) -> "Pipeline": + """ + Return a new pipeline object that can queue multiple commands for + later execution. ``transaction`` indicates whether all commands + should be executed atomically. Apart from making a group of operations + atomic, pipelines are useful for reducing the back-and-forth overhead + between the client and server. + """ + return Pipeline( + self.connection_pool, self.response_callbacks, transaction, shard_hint + ) + + async def transaction( + self, + func: Callable[["Pipeline"], Union[Any, Awaitable[Any]]], + *watches: KeyT, + shard_hint: Optional[str] = None, + value_from_callable: bool = False, + watch_delay: Optional[float] = None, + ): + """ + Convenience method for executing the callable `func` as a transaction + while watching all keys specified in `watches`. The 'func' callable + should expect a single argument which is a Pipeline object. + """ + pipe: Pipeline + async with self.pipeline(True, shard_hint) as pipe: + while True: + try: + if watches: + await pipe.watch(*watches) + func_value = func(pipe) + if inspect.isawaitable(func_value): + func_value = await func_value + exec_value = await pipe.execute() + return func_value if value_from_callable else exec_value + except WatchError: + if watch_delay is not None and watch_delay > 0: + await asyncio.sleep(watch_delay) + continue + + def lock( + self, + name: KeyT, + timeout: Optional[float] = None, + sleep: float = 0.1, + blocking: bool = True, + blocking_timeout: Optional[float] = None, + lock_class: Optional[Type[Lock]] = None, + thread_local: bool = True, + ) -> Lock: + """ + Return a new Lock object using key ``name`` that mimics + the behavior of threading.Lock. + + If specified, ``timeout`` indicates a maximum life for the lock. + By default, it will remain locked until release() is called. + + ``sleep`` indicates the amount of time to sleep per loop iteration + when the lock is in blocking mode and another client is currently + holding the lock. + + ``blocking`` indicates whether calling ``acquire`` should block until + the lock has been acquired or to fail immediately, causing ``acquire`` + to return False and the lock not being acquired. Defaults to True. + Note this value can be overridden by passing a ``blocking`` + argument to ``acquire``. + + ``blocking_timeout`` indicates the maximum amount of time in seconds to + spend trying to acquire the lock. A value of ``None`` indicates + continue trying forever. ``blocking_timeout`` can be specified as a + float or integer, both representing the number of seconds to wait. + + ``lock_class`` forces the specified lock implementation. Note that as + of redis-py 3.0, the only lock class we implement is ``Lock`` (which is + a Lua-based lock). So, it's unlikely you'll need this parameter, unless + you have created your own custom lock class. + + ``thread_local`` indicates whether the lock token is placed in + thread-local storage. By default, the token is placed in thread local + storage so that a thread only sees its token, not a token set by + another thread. Consider the following timeline: + + time: 0, thread-1 acquires `my-lock`, with a timeout of 5 seconds. + thread-1 sets the token to "abc" + time: 1, thread-2 blocks trying to acquire `my-lock` using the + Lock instance. + time: 5, thread-1 has not yet completed. redis expires the lock + key. + time: 5, thread-2 acquired `my-lock` now that it's available. + thread-2 sets the token to "xyz" + time: 6, thread-1 finishes its work and calls release(). if the + token is *not* stored in thread local storage, then + thread-1 would see the token value as "xyz" and would be + able to successfully release the thread-2's lock. + + In some use cases it's necessary to disable thread local storage. For + example, if you have code where one thread acquires a lock and passes + that lock instance to a worker thread to release later. If thread + local storage isn't disabled in this case, the worker thread won't see + the token set by the thread that acquired the lock. Our assumption + is that these cases aren't common and as such default to using + thread local storage.""" + if lock_class is None: + lock_class = Lock + return lock_class( + self, + name, + timeout=timeout, + sleep=sleep, + blocking=blocking, + blocking_timeout=blocking_timeout, + thread_local=thread_local, + ) + + def pubsub(self, **kwargs) -> "PubSub": + """ + Return a Publish/Subscribe object. With this object, you can + subscribe to channels and listen for messages that get published to + them. + """ + return PubSub(self.connection_pool, **kwargs) + + def monitor(self) -> "Monitor": + return Monitor(self.connection_pool) + + def client(self) -> "Redis": + return self.__class__( + connection_pool=self.connection_pool, single_connection_client=True + ) + + async def __aenter__(self: _RedisT) -> _RedisT: + return await self.initialize() + + async def __aexit__(self, exc_type, exc_value, traceback): + await self.aclose() + + _DEL_MESSAGE = "Unclosed Redis client" + + # passing _warnings and _grl as argument default since they may be gone + # by the time __del__ is called at shutdown + def __del__( + self, + _warn: Any = warnings.warn, + _grl: Any = asyncio.get_running_loop, + ) -> None: + if hasattr(self, "connection") and (self.connection is not None): + _warn(f"Unclosed client session {self!r}", ResourceWarning, source=self) + try: + context = {"client": self, "message": self._DEL_MESSAGE} + _grl().call_exception_handler(context) + except RuntimeError: + pass + + async def aclose(self, close_connection_pool: Optional[bool] = None) -> None: + """ + Closes Redis client connection + + :param close_connection_pool: decides whether to close the connection pool used + by this Redis client, overriding Redis.auto_close_connection_pool. By default, + let Redis.auto_close_connection_pool decide whether to close the connection + pool. + """ + conn = self.connection + if conn: + self.connection = None + await self.connection_pool.release(conn) + if close_connection_pool or ( + close_connection_pool is None and self.auto_close_connection_pool + ): + await self.connection_pool.disconnect() + + @deprecated_function(version="5.0.0", reason="Use aclose() instead", name="close") + async def close(self, close_connection_pool: Optional[bool] = None) -> None: + """ + Alias for aclose(), for backwards compatibility + """ + await self.aclose(close_connection_pool) + + async def _send_command_parse_response(self, conn, command_name, *args, **options): + """ + Send a command and parse the response + """ + await conn.send_command(*args) + return await self.parse_response(conn, command_name, **options) + + async def _disconnect_raise(self, conn: Connection, error: Exception): + """ + Close the connection and raise an exception + if retry_on_error is not set or the error + is not one of the specified error types + """ + await conn.disconnect() + if ( + conn.retry_on_error is None + or isinstance(error, tuple(conn.retry_on_error)) is False + ): + raise error + + # COMMAND EXECUTION AND PROTOCOL PARSING + async def execute_command(self, *args, **options): + """Execute a command and return a parsed response""" + await self.initialize() + pool = self.connection_pool + command_name = args[0] + conn = self.connection or await pool.get_connection(command_name, **options) + + if self.single_connection_client: + await self._single_conn_lock.acquire() + try: + return await conn.retry.call_with_retry( + lambda: self._send_command_parse_response( + conn, command_name, *args, **options + ), + lambda error: self._disconnect_raise(conn, error), + ) + finally: + if self.single_connection_client: + self._single_conn_lock.release() + if not self.connection: + await pool.release(conn) + + async def parse_response( + self, connection: Connection, command_name: Union[str, bytes], **options + ): + """Parses a response from the Redis server""" + try: + if NEVER_DECODE in options: + response = await connection.read_response(disable_decoding=True) + options.pop(NEVER_DECODE) + else: + response = await connection.read_response() + except ResponseError: + if EMPTY_RESPONSE in options: + return options[EMPTY_RESPONSE] + raise + + if EMPTY_RESPONSE in options: + options.pop(EMPTY_RESPONSE) + + if command_name in self.response_callbacks: + # Mypy bug: https://github.com/python/mypy/issues/10977 + command_name = cast(str, command_name) + retval = self.response_callbacks[command_name](response, **options) + return await retval if inspect.isawaitable(retval) else retval + return response + + +StrictRedis = Redis + + +class MonitorCommandInfo(TypedDict): + time: float + db: int + client_address: str + client_port: str + client_type: str + command: str + + +class Monitor: + """ + Monitor is useful for handling the MONITOR command to the redis server. + next_command() method returns one command from monitor + listen() method yields commands from monitor. + """ + + monitor_re = re.compile(r"\[(\d+) (.*?)\] (.*)") + command_re = re.compile(r'"(.*?)(? MonitorCommandInfo: + """Parse the response from a monitor command""" + await self.connect() + response = await self.connection.read_response() + if isinstance(response, bytes): + response = self.connection.encoder.decode(response, force=True) + command_time, command_data = response.split(" ", 1) + m = self.monitor_re.match(command_data) + db_id, client_info, command = m.groups() + command = " ".join(self.command_re.findall(command)) + # Redis escapes double quotes because each piece of the command + # string is surrounded by double quotes. We don't have that + # requirement so remove the escaping and leave the quote. + command = command.replace('\\"', '"') + + if client_info == "lua": + client_address = "lua" + client_port = "" + client_type = "lua" + elif client_info.startswith("unix"): + client_address = "unix" + client_port = client_info[5:] + client_type = "unix" + else: + # use rsplit as ipv6 addresses contain colons + client_address, client_port = client_info.rsplit(":", 1) + client_type = "tcp" + return { + "time": float(command_time), + "db": int(db_id), + "client_address": client_address, + "client_port": client_port, + "client_type": client_type, + "command": command, + } + + async def listen(self) -> AsyncIterator[MonitorCommandInfo]: + """Listen for commands coming to the server.""" + while True: + yield await self.next_command() + + +class PubSub: + """ + PubSub provides publish, subscribe and listen support to Redis channels. + + After subscribing to one or more channels, the listen() method will block + until a message arrives on one of the subscribed channels. That message + will be returned and it's safe to start listening again. + """ + + PUBLISH_MESSAGE_TYPES = ("message", "pmessage") + UNSUBSCRIBE_MESSAGE_TYPES = ("unsubscribe", "punsubscribe") + HEALTH_CHECK_MESSAGE = "redis-py-health-check" + + def __init__( + self, + connection_pool: ConnectionPool, + shard_hint: Optional[str] = None, + ignore_subscribe_messages: bool = False, + encoder=None, + push_handler_func: Optional[Callable] = None, + ): + self.connection_pool = connection_pool + self.shard_hint = shard_hint + self.ignore_subscribe_messages = ignore_subscribe_messages + self.connection = None + # we need to know the encoding options for this connection in order + # to lookup channel and pattern names for callback handlers. + self.encoder = encoder + self.push_handler_func = push_handler_func + if self.encoder is None: + self.encoder = self.connection_pool.get_encoder() + if self.encoder.decode_responses: + self.health_check_response = [ + ["pong", self.HEALTH_CHECK_MESSAGE], + self.HEALTH_CHECK_MESSAGE, + ] + else: + self.health_check_response = [ + [b"pong", self.encoder.encode(self.HEALTH_CHECK_MESSAGE)], + self.encoder.encode(self.HEALTH_CHECK_MESSAGE), + ] + if self.push_handler_func is None: + _set_info_logger() + self.channels = {} + self.pending_unsubscribe_channels = set() + self.patterns = {} + self.pending_unsubscribe_patterns = set() + self._lock = asyncio.Lock() + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc_value, traceback): + await self.aclose() + + def __del__(self): + if self.connection: + self.connection._deregister_connect_callback(self.on_connect) + + async def aclose(self): + # In case a connection property does not yet exist + # (due to a crash earlier in the Redis() constructor), return + # immediately as there is nothing to clean-up. + if not hasattr(self, "connection"): + return + async with self._lock: + if self.connection: + await self.connection.disconnect() + self.connection._deregister_connect_callback(self.on_connect) + await self.connection_pool.release(self.connection) + self.connection = None + self.channels = {} + self.pending_unsubscribe_channels = set() + self.patterns = {} + self.pending_unsubscribe_patterns = set() + + @deprecated_function(version="5.0.0", reason="Use aclose() instead", name="close") + async def close(self) -> None: + """Alias for aclose(), for backwards compatibility""" + await self.aclose() + + @deprecated_function(version="5.0.0", reason="Use aclose() instead", name="reset") + async def reset(self) -> None: + """Alias for aclose(), for backwards compatibility""" + await self.aclose() + + async def on_connect(self, connection: Connection): + """Re-subscribe to any channels and patterns previously subscribed to""" + # NOTE: for python3, we can't pass bytestrings as keyword arguments + # so we need to decode channel/pattern names back to unicode strings + # before passing them to [p]subscribe. + self.pending_unsubscribe_channels.clear() + self.pending_unsubscribe_patterns.clear() + if self.channels: + channels = {} + for k, v in self.channels.items(): + channels[self.encoder.decode(k, force=True)] = v + await self.subscribe(**channels) + if self.patterns: + patterns = {} + for k, v in self.patterns.items(): + patterns[self.encoder.decode(k, force=True)] = v + await self.psubscribe(**patterns) + + @property + def subscribed(self): + """Indicates if there are subscriptions to any channels or patterns""" + return bool(self.channels or self.patterns) + + async def execute_command(self, *args: EncodableT): + """Execute a publish/subscribe command""" + + # NOTE: don't parse the response in this function -- it could pull a + # legitimate message off the stack if the connection is already + # subscribed to one or more channels + + await self.connect() + connection = self.connection + kwargs = {"check_health": not self.subscribed} + await self._execute(connection, connection.send_command, *args, **kwargs) + + async def connect(self): + """ + Ensure that the PubSub is connected + """ + if self.connection is None: + self.connection = await self.connection_pool.get_connection( + "pubsub", self.shard_hint + ) + # register a callback that re-subscribes to any channels we + # were listening to when we were disconnected + self.connection._register_connect_callback(self.on_connect) + else: + await self.connection.connect() + if self.push_handler_func is not None and not HIREDIS_AVAILABLE: + self.connection._parser.set_push_handler(self.push_handler_func) + + async def _disconnect_raise_connect(self, conn, error): + """ + Close the connection and raise an exception + if retry_on_timeout is not set or the error + is not a TimeoutError. Otherwise, try to reconnect + """ + await conn.disconnect() + if not (conn.retry_on_timeout and isinstance(error, TimeoutError)): + raise error + await conn.connect() + + async def _execute(self, conn, command, *args, **kwargs): + """ + Connect manually upon disconnection. If the Redis server is down, + this will fail and raise a ConnectionError as desired. + After reconnection, the ``on_connect`` callback should have been + called by the # connection to resubscribe us to any channels and + patterns we were previously listening to + """ + return await conn.retry.call_with_retry( + lambda: command(*args, **kwargs), + lambda error: self._disconnect_raise_connect(conn, error), + ) + + async def parse_response(self, block: bool = True, timeout: float = 0): + """Parse the response from a publish/subscribe command""" + conn = self.connection + if conn is None: + raise RuntimeError( + "pubsub connection not set: " + "did you forget to call subscribe() or psubscribe()?" + ) + + await self.check_health() + + if not conn.is_connected: + await conn.connect() + + read_timeout = None if block else timeout + response = await self._execute( + conn, + conn.read_response, + timeout=read_timeout, + disconnect_on_error=False, + push_request=True, + ) + + if conn.health_check_interval and response in self.health_check_response: + # ignore the health check message as user might not expect it + return None + return response + + async def check_health(self): + conn = self.connection + if conn is None: + raise RuntimeError( + "pubsub connection not set: " + "did you forget to call subscribe() or psubscribe()?" + ) + + if ( + conn.health_check_interval + and asyncio.get_running_loop().time() > conn.next_health_check + ): + await conn.send_command( + "PING", self.HEALTH_CHECK_MESSAGE, check_health=False + ) + + def _normalize_keys(self, data: _NormalizeKeysT) -> _NormalizeKeysT: + """ + normalize channel/pattern names to be either bytes or strings + based on whether responses are automatically decoded. this saves us + from coercing the value for each message coming in. + """ + encode = self.encoder.encode + decode = self.encoder.decode + return {decode(encode(k)): v for k, v in data.items()} # type: ignore[return-value] # noqa: E501 + + async def psubscribe(self, *args: ChannelT, **kwargs: PubSubHandler): + """ + Subscribe to channel patterns. Patterns supplied as keyword arguments + expect a pattern name as the key and a callable as the value. A + pattern's callable will be invoked automatically when a message is + received on that pattern rather than producing a message via + ``listen()``. + """ + parsed_args = list_or_args((args[0],), args[1:]) if args else args + new_patterns: Dict[ChannelT, PubSubHandler] = dict.fromkeys(parsed_args) + # Mypy bug: https://github.com/python/mypy/issues/10970 + new_patterns.update(kwargs) # type: ignore[arg-type] + ret_val = await self.execute_command("PSUBSCRIBE", *new_patterns.keys()) + # update the patterns dict AFTER we send the command. we don't want to + # subscribe twice to these patterns, once for the command and again + # for the reconnection. + new_patterns = self._normalize_keys(new_patterns) + self.patterns.update(new_patterns) + self.pending_unsubscribe_patterns.difference_update(new_patterns) + return ret_val + + def punsubscribe(self, *args: ChannelT) -> Awaitable: + """ + Unsubscribe from the supplied patterns. If empty, unsubscribe from + all patterns. + """ + patterns: Iterable[ChannelT] + if args: + parsed_args = list_or_args((args[0],), args[1:]) + patterns = self._normalize_keys(dict.fromkeys(parsed_args)).keys() + else: + parsed_args = [] + patterns = self.patterns + self.pending_unsubscribe_patterns.update(patterns) + return self.execute_command("PUNSUBSCRIBE", *parsed_args) + + async def subscribe(self, *args: ChannelT, **kwargs: Callable): + """ + Subscribe to channels. Channels supplied as keyword arguments expect + a channel name as the key and a callable as the value. A channel's + callable will be invoked automatically when a message is received on + that channel rather than producing a message via ``listen()`` or + ``get_message()``. + """ + parsed_args = list_or_args((args[0],), args[1:]) if args else () + new_channels = dict.fromkeys(parsed_args) + # Mypy bug: https://github.com/python/mypy/issues/10970 + new_channels.update(kwargs) # type: ignore[arg-type] + ret_val = await self.execute_command("SUBSCRIBE", *new_channels.keys()) + # update the channels dict AFTER we send the command. we don't want to + # subscribe twice to these channels, once for the command and again + # for the reconnection. + new_channels = self._normalize_keys(new_channels) + self.channels.update(new_channels) + self.pending_unsubscribe_channels.difference_update(new_channels) + return ret_val + + def unsubscribe(self, *args) -> Awaitable: + """ + Unsubscribe from the supplied channels. If empty, unsubscribe from + all channels + """ + if args: + parsed_args = list_or_args(args[0], args[1:]) + channels = self._normalize_keys(dict.fromkeys(parsed_args)) + else: + parsed_args = [] + channels = self.channels + self.pending_unsubscribe_channels.update(channels) + return self.execute_command("UNSUBSCRIBE", *parsed_args) + + async def listen(self) -> AsyncIterator: + """Listen for messages on channels this client has been subscribed to""" + while self.subscribed: + response = await self.handle_message(await self.parse_response(block=True)) + if response is not None: + yield response + + async def get_message( + self, ignore_subscribe_messages: bool = False, timeout: Optional[float] = 0.0 + ): + """ + Get the next message if one is available, otherwise None. + + If timeout is specified, the system will wait for `timeout` seconds + before returning. Timeout should be specified as a floating point + number or None to wait indefinitely. + """ + response = await self.parse_response(block=(timeout is None), timeout=timeout) + if response: + return await self.handle_message(response, ignore_subscribe_messages) + return None + + def ping(self, message=None) -> Awaitable: + """ + Ping the Redis server + """ + args = ["PING", message] if message is not None else ["PING"] + return self.execute_command(*args) + + async def handle_message(self, response, ignore_subscribe_messages=False): + """ + Parses a pub/sub message. If the channel or pattern was subscribed to + with a message handler, the handler is invoked instead of a parsed + message being returned. + """ + if response is None: + return None + if isinstance(response, bytes): + response = [b"pong", response] if response != b"PONG" else [b"pong", b""] + message_type = str_if_bytes(response[0]) + if message_type == "pmessage": + message = { + "type": message_type, + "pattern": response[1], + "channel": response[2], + "data": response[3], + } + elif message_type == "pong": + message = { + "type": message_type, + "pattern": None, + "channel": None, + "data": response[1], + } + else: + message = { + "type": message_type, + "pattern": None, + "channel": response[1], + "data": response[2], + } + + # if this is an unsubscribe message, remove it from memory + if message_type in self.UNSUBSCRIBE_MESSAGE_TYPES: + if message_type == "punsubscribe": + pattern = response[1] + if pattern in self.pending_unsubscribe_patterns: + self.pending_unsubscribe_patterns.remove(pattern) + self.patterns.pop(pattern, None) + else: + channel = response[1] + if channel in self.pending_unsubscribe_channels: + self.pending_unsubscribe_channels.remove(channel) + self.channels.pop(channel, None) + + if message_type in self.PUBLISH_MESSAGE_TYPES: + # if there's a message handler, invoke it + if message_type == "pmessage": + handler = self.patterns.get(message["pattern"], None) + else: + handler = self.channels.get(message["channel"], None) + if handler: + if inspect.iscoroutinefunction(handler): + await handler(message) + else: + handler(message) + return None + elif message_type != "pong": + # this is a subscribe/unsubscribe message. ignore if we don't + # want them + if ignore_subscribe_messages or self.ignore_subscribe_messages: + return None + + return message + + async def run( + self, + *, + exception_handler: Optional["PSWorkerThreadExcHandlerT"] = None, + poll_timeout: float = 1.0, + ) -> None: + """Process pub/sub messages using registered callbacks. + + This is the equivalent of :py:meth:`redis.PubSub.run_in_thread` in + redis-py, but it is a coroutine. To launch it as a separate task, use + ``asyncio.create_task``: + + >>> task = asyncio.create_task(pubsub.run()) + + To shut it down, use asyncio cancellation: + + >>> task.cancel() + >>> await task + """ + for channel, handler in self.channels.items(): + if handler is None: + raise PubSubError(f"Channel: '{channel}' has no handler registered") + for pattern, handler in self.patterns.items(): + if handler is None: + raise PubSubError(f"Pattern: '{pattern}' has no handler registered") + + await self.connect() + while True: + try: + await self.get_message( + ignore_subscribe_messages=True, timeout=poll_timeout + ) + except asyncio.CancelledError: + raise + except BaseException as e: + if exception_handler is None: + raise + res = exception_handler(e, self) + if inspect.isawaitable(res): + await res + # Ensure that other tasks on the event loop get a chance to run + # if we didn't have to block for I/O anywhere. + await asyncio.sleep(0) + + +class PubsubWorkerExceptionHandler(Protocol): + def __call__(self, e: BaseException, pubsub: PubSub): + ... + + +class AsyncPubsubWorkerExceptionHandler(Protocol): + async def __call__(self, e: BaseException, pubsub: PubSub): + ... + + +PSWorkerThreadExcHandlerT = Union[ + PubsubWorkerExceptionHandler, AsyncPubsubWorkerExceptionHandler +] + + +CommandT = Tuple[Tuple[Union[str, bytes], ...], Mapping[str, Any]] +CommandStackT = List[CommandT] + + +class Pipeline(Redis): # lgtm [py/init-calls-subclass] + """ + Pipelines provide a way to transmit multiple commands to the Redis server + in one transmission. This is convenient for batch processing, such as + saving all the values in a list to Redis. + + All commands executed within a pipeline are wrapped with MULTI and EXEC + calls. This guarantees all commands executed in the pipeline will be + executed atomically. + + Any command raising an exception does *not* halt the execution of + subsequent commands in the pipeline. Instead, the exception is caught + and its instance is placed into the response list returned by execute(). + Code iterating over the response list should be able to deal with an + instance of an exception as a potential value. In general, these will be + ResponseError exceptions, such as those raised when issuing a command + on a key of a different datatype. + """ + + UNWATCH_COMMANDS = {"DISCARD", "EXEC", "UNWATCH"} + + def __init__( + self, + connection_pool: ConnectionPool, + response_callbacks: MutableMapping[Union[str, bytes], ResponseCallbackT], + transaction: bool, + shard_hint: Optional[str], + ): + self.connection_pool = connection_pool + self.connection = None + self.response_callbacks = response_callbacks + self.is_transaction = transaction + self.shard_hint = shard_hint + self.watching = False + self.command_stack: CommandStackT = [] + self.scripts: Set["Script"] = set() + self.explicit_transaction = False + + async def __aenter__(self: _RedisT) -> _RedisT: + return self + + async def __aexit__(self, exc_type, exc_value, traceback): + await self.reset() + + def __await__(self): + return self._async_self().__await__() + + _DEL_MESSAGE = "Unclosed Pipeline client" + + def __len__(self): + return len(self.command_stack) + + def __bool__(self): + """Pipeline instances should always evaluate to True""" + return True + + async def _async_self(self): + return self + + async def reset(self): + self.command_stack = [] + self.scripts = set() + # make sure to reset the connection state in the event that we were + # watching something + if self.watching and self.connection: + try: + # call this manually since our unwatch or + # immediate_execute_command methods can call reset() + await self.connection.send_command("UNWATCH") + await self.connection.read_response() + except ConnectionError: + # disconnect will also remove any previous WATCHes + if self.connection: + await self.connection.disconnect() + # clean up the other instance attributes + self.watching = False + self.explicit_transaction = False + # we can safely return the connection to the pool here since we're + # sure we're no longer WATCHing anything + if self.connection: + await self.connection_pool.release(self.connection) + self.connection = None + + async def aclose(self) -> None: + """Alias for reset(), a standard method name for cleanup""" + await self.reset() + + def multi(self): + """ + Start a transactional block of the pipeline after WATCH commands + are issued. End the transactional block with `execute`. + """ + if self.explicit_transaction: + raise RedisError("Cannot issue nested calls to MULTI") + if self.command_stack: + raise RedisError( + "Commands without an initial WATCH have already been issued" + ) + self.explicit_transaction = True + + def execute_command( + self, *args, **kwargs + ) -> Union["Pipeline", Awaitable["Pipeline"]]: + if (self.watching or args[0] == "WATCH") and not self.explicit_transaction: + return self.immediate_execute_command(*args, **kwargs) + return self.pipeline_execute_command(*args, **kwargs) + + async def _disconnect_reset_raise(self, conn, error): + """ + Close the connection, reset watching state and + raise an exception if we were watching, + retry_on_timeout is not set, + or the error is not a TimeoutError + """ + await conn.disconnect() + # if we were already watching a variable, the watch is no longer + # valid since this connection has died. raise a WatchError, which + # indicates the user should retry this transaction. + if self.watching: + await self.aclose() + raise WatchError( + "A ConnectionError occurred on while watching one or more keys" + ) + # if retry_on_timeout is not set, or the error is not + # a TimeoutError, raise it + if not (conn.retry_on_timeout and isinstance(error, TimeoutError)): + await self.aclose() + raise + + async def immediate_execute_command(self, *args, **options): + """ + Execute a command immediately, but don't auto-retry on a + ConnectionError if we're already WATCHing a variable. Used when + issuing WATCH or subsequent commands retrieving their values but before + MULTI is called. + """ + command_name = args[0] + conn = self.connection + # if this is the first call, we need a connection + if not conn: + conn = await self.connection_pool.get_connection( + command_name, self.shard_hint + ) + self.connection = conn + + return await conn.retry.call_with_retry( + lambda: self._send_command_parse_response( + conn, command_name, *args, **options + ), + lambda error: self._disconnect_reset_raise(conn, error), + ) + + def pipeline_execute_command(self, *args, **options): + """ + Stage a command to be executed when execute() is next called + + Returns the current Pipeline object back so commands can be + chained together, such as: + + pipe = pipe.set('foo', 'bar').incr('baz').decr('bang') + + At some other point, you can then run: pipe.execute(), + which will execute all commands queued in the pipe. + """ + self.command_stack.append((args, options)) + return self + + async def _execute_transaction( # noqa: C901 + self, connection: Connection, commands: CommandStackT, raise_on_error + ): + pre: CommandT = (("MULTI",), {}) + post: CommandT = (("EXEC",), {}) + cmds = (pre, *commands, post) + all_cmds = connection.pack_commands( + args for args, options in cmds if EMPTY_RESPONSE not in options + ) + await connection.send_packed_command(all_cmds) + errors = [] + + # parse off the response for MULTI + # NOTE: we need to handle ResponseErrors here and continue + # so that we read all the additional command messages from + # the socket + try: + await self.parse_response(connection, "_") + except ResponseError as err: + errors.append((0, err)) + + # and all the other commands + for i, command in enumerate(commands): + if EMPTY_RESPONSE in command[1]: + errors.append((i, command[1][EMPTY_RESPONSE])) + else: + try: + await self.parse_response(connection, "_") + except ResponseError as err: + self.annotate_exception(err, i + 1, command[0]) + errors.append((i, err)) + + # parse the EXEC. + try: + response = await self.parse_response(connection, "_") + except ExecAbortError as err: + if errors: + raise errors[0][1] from err + raise + + # EXEC clears any watched keys + self.watching = False + + if response is None: + raise WatchError("Watched variable changed.") from None + + # put any parse errors into the response + for i, e in errors: + response.insert(i, e) + + if len(response) != len(commands): + if self.connection: + await self.connection.disconnect() + raise ResponseError( + "Wrong number of response items from pipeline execution" + ) from None + + # find any errors in the response and raise if necessary + if raise_on_error: + self.raise_first_error(commands, response) + + # We have to run response callbacks manually + data = [] + for r, cmd in zip(response, commands): + if not isinstance(r, Exception): + args, options = cmd + command_name = args[0] + if command_name in self.response_callbacks: + r = self.response_callbacks[command_name](r, **options) + if inspect.isawaitable(r): + r = await r + data.append(r) + return data + + async def _execute_pipeline( + self, connection: Connection, commands: CommandStackT, raise_on_error: bool + ): + # build up all commands into a single request to increase network perf + all_cmds = connection.pack_commands([args for args, _ in commands]) + await connection.send_packed_command(all_cmds) + + response = [] + for args, options in commands: + try: + response.append( + await self.parse_response(connection, args[0], **options) + ) + except ResponseError as e: + response.append(e) + + if raise_on_error: + self.raise_first_error(commands, response) + return response + + def raise_first_error(self, commands: CommandStackT, response: Iterable[Any]): + for i, r in enumerate(response): + if isinstance(r, ResponseError): + self.annotate_exception(r, i + 1, commands[i][0]) + raise r + + def annotate_exception( + self, exception: Exception, number: int, command: Iterable[object] + ) -> None: + cmd = " ".join(map(safe_str, command)) + msg = f"Command # {number} ({cmd}) of pipeline caused error: {exception.args}" + exception.args = (msg,) + exception.args[1:] + + async def parse_response( + self, connection: Connection, command_name: Union[str, bytes], **options + ): + result = await super().parse_response(connection, command_name, **options) + if command_name in self.UNWATCH_COMMANDS: + self.watching = False + elif command_name == "WATCH": + self.watching = True + return result + + async def load_scripts(self): + # make sure all scripts that are about to be run on this pipeline exist + scripts = list(self.scripts) + immediate = self.immediate_execute_command + shas = [s.sha for s in scripts] + # we can't use the normal script_* methods because they would just + # get buffered in the pipeline. + exists = await immediate("SCRIPT EXISTS", *shas) + if not all(exists): + for s, exist in zip(scripts, exists): + if not exist: + s.sha = await immediate("SCRIPT LOAD", s.script) + + async def _disconnect_raise_reset(self, conn: Connection, error: Exception): + """ + Close the connection, raise an exception if we were watching, + and raise an exception if retry_on_timeout is not set, + or the error is not a TimeoutError + """ + await conn.disconnect() + # if we were watching a variable, the watch is no longer valid + # since this connection has died. raise a WatchError, which + # indicates the user should retry this transaction. + if self.watching: + raise WatchError( + "A ConnectionError occurred on while watching one or more keys" + ) + # if retry_on_timeout is not set, or the error is not + # a TimeoutError, raise it + if not (conn.retry_on_timeout and isinstance(error, TimeoutError)): + await self.reset() + raise + + async def execute(self, raise_on_error: bool = True): + """Execute all the commands in the current pipeline""" + stack = self.command_stack + if not stack and not self.watching: + return [] + if self.scripts: + await self.load_scripts() + if self.is_transaction or self.explicit_transaction: + execute = self._execute_transaction + else: + execute = self._execute_pipeline + + conn = self.connection + if not conn: + conn = await self.connection_pool.get_connection("MULTI", self.shard_hint) + # assign to self.connection so reset() releases the connection + # back to the pool after we're done + self.connection = conn + conn = cast(Connection, conn) + + try: + return await conn.retry.call_with_retry( + lambda: execute(conn, stack, raise_on_error), + lambda error: self._disconnect_raise_reset(conn, error), + ) + finally: + await self.reset() + + async def discard(self): + """Flushes all previously queued commands + See: https://redis.io/commands/DISCARD + """ + await self.execute_command("DISCARD") + + async def watch(self, *names: KeyT): + """Watches the values at keys ``names``""" + if self.explicit_transaction: + raise RedisError("Cannot issue a WATCH after a MULTI") + return await self.execute_command("WATCH", *names) + + async def unwatch(self): + """Unwatches all previously specified keys""" + return self.watching and await self.execute_command("UNWATCH") or True diff --git a/.venv/Lib/site-packages/redis/asyncio/cluster.py b/.venv/Lib/site-packages/redis/asyncio/cluster.py new file mode 100644 index 00000000..636144a9 --- /dev/null +++ b/.venv/Lib/site-packages/redis/asyncio/cluster.py @@ -0,0 +1,1620 @@ +import asyncio +import collections +import random +import socket +import warnings +from typing import ( + Any, + Callable, + Deque, + Dict, + Generator, + List, + Mapping, + Optional, + Tuple, + Type, + TypeVar, + Union, +) + +from redis._parsers import AsyncCommandsParser, Encoder +from redis._parsers.helpers import ( + _RedisCallbacks, + _RedisCallbacksRESP2, + _RedisCallbacksRESP3, +) +from redis.asyncio.client import ResponseCallbackT +from redis.asyncio.connection import Connection, DefaultParser, SSLConnection, parse_url +from redis.asyncio.lock import Lock +from redis.asyncio.retry import Retry +from redis.backoff import default_backoff +from redis.client import EMPTY_RESPONSE, NEVER_DECODE, AbstractRedis +from redis.cluster import ( + PIPELINE_BLOCKED_COMMANDS, + PRIMARY, + REPLICA, + SLOT_ID, + AbstractRedisCluster, + LoadBalancer, + block_pipeline_command, + get_node_name, + parse_cluster_slots, +) +from redis.commands import READ_COMMANDS, AsyncRedisClusterCommands +from redis.crc import REDIS_CLUSTER_HASH_SLOTS, key_slot +from redis.credentials import CredentialProvider +from redis.exceptions import ( + AskError, + BusyLoadingError, + ClusterCrossSlotError, + ClusterDownError, + ClusterError, + ConnectionError, + DataError, + MasterDownError, + MaxConnectionsError, + MovedError, + RedisClusterException, + ResponseError, + SlotNotCoveredError, + TimeoutError, + TryAgainError, +) +from redis.typing import AnyKeyT, EncodableT, KeyT +from redis.utils import ( + deprecated_function, + dict_merge, + get_lib_version, + safe_str, + str_if_bytes, +) + +TargetNodesT = TypeVar( + "TargetNodesT", str, "ClusterNode", List["ClusterNode"], Dict[Any, "ClusterNode"] +) + + +class ClusterParser(DefaultParser): + EXCEPTION_CLASSES = dict_merge( + DefaultParser.EXCEPTION_CLASSES, + { + "ASK": AskError, + "CLUSTERDOWN": ClusterDownError, + "CROSSSLOT": ClusterCrossSlotError, + "MASTERDOWN": MasterDownError, + "MOVED": MovedError, + "TRYAGAIN": TryAgainError, + }, + ) + + +class RedisCluster(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommands): + """ + Create a new RedisCluster client. + + Pass one of parameters: + + - `host` & `port` + - `startup_nodes` + + | Use ``await`` :meth:`initialize` to find cluster nodes & create connections. + | Use ``await`` :meth:`close` to disconnect connections & close client. + + Many commands support the target_nodes kwarg. It can be one of the + :attr:`NODE_FLAGS`: + + - :attr:`PRIMARIES` + - :attr:`REPLICAS` + - :attr:`ALL_NODES` + - :attr:`RANDOM` + - :attr:`DEFAULT_NODE` + + Note: This client is not thread/process/fork safe. + + :param host: + | Can be used to point to a startup node + :param port: + | Port used if **host** is provided + :param startup_nodes: + | :class:`~.ClusterNode` to used as a startup node + :param require_full_coverage: + | When set to ``False``: the client will not require a full coverage of + the slots. However, if not all slots are covered, and at least one node + has ``cluster-require-full-coverage`` set to ``yes``, the server will throw + a :class:`~.ClusterDownError` for some key-based commands. + | When set to ``True``: all slots must be covered to construct the cluster + client. If not all slots are covered, :class:`~.RedisClusterException` will be + thrown. + | See: + https://redis.io/docs/manual/scaling/#redis-cluster-configuration-parameters + :param read_from_replicas: + | Enable read from replicas in READONLY mode. You can read possibly stale data. + When set to true, read commands will be assigned between the primary and + its replications in a Round-Robin manner. + :param reinitialize_steps: + | Specifies the number of MOVED errors that need to occur before reinitializing + the whole cluster topology. If a MOVED error occurs and the cluster does not + need to be reinitialized on this current error handling, only the MOVED slot + will be patched with the redirected node. + To reinitialize the cluster on every MOVED error, set reinitialize_steps to 1. + To avoid reinitializing the cluster on moved errors, set reinitialize_steps to + 0. + :param cluster_error_retry_attempts: + | Number of times to retry before raising an error when :class:`~.TimeoutError` + or :class:`~.ConnectionError` or :class:`~.ClusterDownError` are encountered + :param connection_error_retry_attempts: + | Number of times to retry before reinitializing when :class:`~.TimeoutError` + or :class:`~.ConnectionError` are encountered. + The default backoff strategy will be set if Retry object is not passed (see + default_backoff in backoff.py). To change it, pass a custom Retry object + using the "retry" keyword. + :param max_connections: + | Maximum number of connections per node. If there are no free connections & the + maximum number of connections are already created, a + :class:`~.MaxConnectionsError` is raised. This error may be retried as defined + by :attr:`connection_error_retry_attempts` + :param address_remap: + | An optional callable which, when provided with an internal network + address of a node, e.g. a `(host, port)` tuple, will return the address + where the node is reachable. This can be used to map the addresses at + which the nodes _think_ they are, to addresses at which a client may + reach them, such as when they sit behind a proxy. + + | Rest of the arguments will be passed to the + :class:`~redis.asyncio.connection.Connection` instances when created + + :raises RedisClusterException: + if any arguments are invalid or unknown. Eg: + + - `db` != 0 or None + - `path` argument for unix socket connection + - none of the `host`/`port` & `startup_nodes` were provided + + """ + + @classmethod + def from_url(cls, url: str, **kwargs: Any) -> "RedisCluster": + """ + Return a Redis client object configured from the given URL. + + For example:: + + redis://[[username]:[password]]@localhost:6379/0 + rediss://[[username]:[password]]@localhost:6379/0 + + Three URL schemes are supported: + + - `redis://` creates a TCP socket connection. See more at: + + - `rediss://` creates a SSL wrapped TCP socket connection. See more at: + + + The username, password, hostname, path and all querystring values are passed + through ``urllib.parse.unquote`` in order to replace any percent-encoded values + with their corresponding characters. + + All querystring options are cast to their appropriate Python types. Boolean + arguments can be specified with string values "True"/"False" or "Yes"/"No". + Values that cannot be properly cast cause a ``ValueError`` to be raised. Once + parsed, the querystring arguments and keyword arguments are passed to + :class:`~redis.asyncio.connection.Connection` when created. + In the case of conflicting arguments, querystring arguments are used. + """ + kwargs.update(parse_url(url)) + if kwargs.pop("connection_class", None) is SSLConnection: + kwargs["ssl"] = True + return cls(**kwargs) + + __slots__ = ( + "_initialize", + "_lock", + "cluster_error_retry_attempts", + "command_flags", + "commands_parser", + "connection_error_retry_attempts", + "connection_kwargs", + "encoder", + "node_flags", + "nodes_manager", + "read_from_replicas", + "reinitialize_counter", + "reinitialize_steps", + "response_callbacks", + "result_callbacks", + ) + + def __init__( + self, + host: Optional[str] = None, + port: Union[str, int] = 6379, + # Cluster related kwargs + startup_nodes: Optional[List["ClusterNode"]] = None, + require_full_coverage: bool = True, + read_from_replicas: bool = False, + reinitialize_steps: int = 5, + cluster_error_retry_attempts: int = 3, + connection_error_retry_attempts: int = 3, + max_connections: int = 2**31, + # Client related kwargs + db: Union[str, int] = 0, + path: Optional[str] = None, + credential_provider: Optional[CredentialProvider] = None, + username: Optional[str] = None, + password: Optional[str] = None, + client_name: Optional[str] = None, + lib_name: Optional[str] = "redis-py", + lib_version: Optional[str] = get_lib_version(), + # Encoding related kwargs + encoding: str = "utf-8", + encoding_errors: str = "strict", + decode_responses: bool = False, + # Connection related kwargs + health_check_interval: float = 0, + socket_connect_timeout: Optional[float] = None, + socket_keepalive: bool = False, + socket_keepalive_options: Optional[Mapping[int, Union[int, bytes]]] = None, + socket_timeout: Optional[float] = None, + retry: Optional["Retry"] = None, + retry_on_error: Optional[List[Type[Exception]]] = None, + # SSL related kwargs + ssl: bool = False, + ssl_ca_certs: Optional[str] = None, + ssl_ca_data: Optional[str] = None, + ssl_cert_reqs: str = "required", + ssl_certfile: Optional[str] = None, + ssl_check_hostname: bool = False, + ssl_keyfile: Optional[str] = None, + protocol: Optional[int] = 2, + address_remap: Optional[Callable[[str, int], Tuple[str, int]]] = None, + ) -> None: + if db: + raise RedisClusterException( + "Argument 'db' must be 0 or None in cluster mode" + ) + + if path: + raise RedisClusterException( + "Unix domain socket is not supported in cluster mode" + ) + + if (not host or not port) and not startup_nodes: + raise RedisClusterException( + "RedisCluster requires at least one node to discover the cluster.\n" + "Please provide one of the following or use RedisCluster.from_url:\n" + ' - host and port: RedisCluster(host="localhost", port=6379)\n' + " - startup_nodes: RedisCluster(startup_nodes=[" + 'ClusterNode("localhost", 6379), ClusterNode("localhost", 6380)])' + ) + + kwargs: Dict[str, Any] = { + "max_connections": max_connections, + "connection_class": Connection, + "parser_class": ClusterParser, + # Client related kwargs + "credential_provider": credential_provider, + "username": username, + "password": password, + "client_name": client_name, + "lib_name": lib_name, + "lib_version": lib_version, + # Encoding related kwargs + "encoding": encoding, + "encoding_errors": encoding_errors, + "decode_responses": decode_responses, + # Connection related kwargs + "health_check_interval": health_check_interval, + "socket_connect_timeout": socket_connect_timeout, + "socket_keepalive": socket_keepalive, + "socket_keepalive_options": socket_keepalive_options, + "socket_timeout": socket_timeout, + "retry": retry, + "protocol": protocol, + } + + if ssl: + # SSL related kwargs + kwargs.update( + { + "connection_class": SSLConnection, + "ssl_ca_certs": ssl_ca_certs, + "ssl_ca_data": ssl_ca_data, + "ssl_cert_reqs": ssl_cert_reqs, + "ssl_certfile": ssl_certfile, + "ssl_check_hostname": ssl_check_hostname, + "ssl_keyfile": ssl_keyfile, + } + ) + + if read_from_replicas: + # Call our on_connect function to configure READONLY mode + kwargs["redis_connect_func"] = self.on_connect + + self.retry = retry + if retry or retry_on_error or connection_error_retry_attempts > 0: + # Set a retry object for all cluster nodes + self.retry = retry or Retry( + default_backoff(), connection_error_retry_attempts + ) + if not retry_on_error: + # Default errors for retrying + retry_on_error = [ConnectionError, TimeoutError] + self.retry.update_supported_errors(retry_on_error) + kwargs.update({"retry": self.retry}) + + kwargs["response_callbacks"] = _RedisCallbacks.copy() + if kwargs.get("protocol") in ["3", 3]: + kwargs["response_callbacks"].update(_RedisCallbacksRESP3) + else: + kwargs["response_callbacks"].update(_RedisCallbacksRESP2) + self.connection_kwargs = kwargs + + if startup_nodes: + passed_nodes = [] + for node in startup_nodes: + passed_nodes.append( + ClusterNode(node.host, node.port, **self.connection_kwargs) + ) + startup_nodes = passed_nodes + else: + startup_nodes = [] + if host and port: + startup_nodes.append(ClusterNode(host, port, **self.connection_kwargs)) + + self.nodes_manager = NodesManager( + startup_nodes, + require_full_coverage, + kwargs, + address_remap=address_remap, + ) + self.encoder = Encoder(encoding, encoding_errors, decode_responses) + self.read_from_replicas = read_from_replicas + self.reinitialize_steps = reinitialize_steps + self.cluster_error_retry_attempts = cluster_error_retry_attempts + self.connection_error_retry_attempts = connection_error_retry_attempts + self.reinitialize_counter = 0 + self.commands_parser = AsyncCommandsParser() + self.node_flags = self.__class__.NODE_FLAGS.copy() + self.command_flags = self.__class__.COMMAND_FLAGS.copy() + self.response_callbacks = kwargs["response_callbacks"] + self.result_callbacks = self.__class__.RESULT_CALLBACKS.copy() + self.result_callbacks[ + "CLUSTER SLOTS" + ] = lambda cmd, res, **kwargs: parse_cluster_slots( + list(res.values())[0], **kwargs + ) + + self._initialize = True + self._lock: Optional[asyncio.Lock] = None + + async def initialize(self) -> "RedisCluster": + """Get all nodes from startup nodes & creates connections if not initialized.""" + if self._initialize: + if not self._lock: + self._lock = asyncio.Lock() + async with self._lock: + if self._initialize: + try: + await self.nodes_manager.initialize() + await self.commands_parser.initialize( + self.nodes_manager.default_node + ) + self._initialize = False + except BaseException: + await self.nodes_manager.aclose() + await self.nodes_manager.aclose("startup_nodes") + raise + return self + + async def aclose(self) -> None: + """Close all connections & client if initialized.""" + if not self._initialize: + if not self._lock: + self._lock = asyncio.Lock() + async with self._lock: + if not self._initialize: + self._initialize = True + await self.nodes_manager.aclose() + await self.nodes_manager.aclose("startup_nodes") + + @deprecated_function(version="5.0.0", reason="Use aclose() instead", name="close") + async def close(self) -> None: + """alias for aclose() for backwards compatibility""" + await self.aclose() + + async def __aenter__(self) -> "RedisCluster": + return await self.initialize() + + async def __aexit__(self, exc_type: None, exc_value: None, traceback: None) -> None: + await self.aclose() + + def __await__(self) -> Generator[Any, None, "RedisCluster"]: + return self.initialize().__await__() + + _DEL_MESSAGE = "Unclosed RedisCluster client" + + def __del__( + self, + _warn: Any = warnings.warn, + _grl: Any = asyncio.get_running_loop, + ) -> None: + if hasattr(self, "_initialize") and not self._initialize: + _warn(f"{self._DEL_MESSAGE} {self!r}", ResourceWarning, source=self) + try: + context = {"client": self, "message": self._DEL_MESSAGE} + _grl().call_exception_handler(context) + except RuntimeError: + pass + + async def on_connect(self, connection: Connection) -> None: + await connection.on_connect() + + # Sending READONLY command to server to configure connection as + # readonly. Since each cluster node may change its server type due + # to a failover, we should establish a READONLY connection + # regardless of the server type. If this is a primary connection, + # READONLY would not affect executing write commands. + await connection.send_command("READONLY") + if str_if_bytes(await connection.read_response()) != "OK": + raise ConnectionError("READONLY command failed") + + def get_nodes(self) -> List["ClusterNode"]: + """Get all nodes of the cluster.""" + return list(self.nodes_manager.nodes_cache.values()) + + def get_primaries(self) -> List["ClusterNode"]: + """Get the primary nodes of the cluster.""" + return self.nodes_manager.get_nodes_by_server_type(PRIMARY) + + def get_replicas(self) -> List["ClusterNode"]: + """Get the replica nodes of the cluster.""" + return self.nodes_manager.get_nodes_by_server_type(REPLICA) + + def get_random_node(self) -> "ClusterNode": + """Get a random node of the cluster.""" + return random.choice(list(self.nodes_manager.nodes_cache.values())) + + def get_default_node(self) -> "ClusterNode": + """Get the default node of the client.""" + return self.nodes_manager.default_node + + def set_default_node(self, node: "ClusterNode") -> None: + """ + Set the default node of the client. + + :raises DataError: if None is passed or node does not exist in cluster. + """ + if not node or not self.get_node(node_name=node.name): + raise DataError("The requested node does not exist in the cluster.") + + self.nodes_manager.default_node = node + + def get_node( + self, + host: Optional[str] = None, + port: Optional[int] = None, + node_name: Optional[str] = None, + ) -> Optional["ClusterNode"]: + """Get node by (host, port) or node_name.""" + return self.nodes_manager.get_node(host, port, node_name) + + def get_node_from_key( + self, key: str, replica: bool = False + ) -> Optional["ClusterNode"]: + """ + Get the cluster node corresponding to the provided key. + + :param key: + :param replica: + | Indicates if a replica should be returned + | + None will returned if no replica holds this key + + :raises SlotNotCoveredError: if the key is not covered by any slot. + """ + slot = self.keyslot(key) + slot_cache = self.nodes_manager.slots_cache.get(slot) + if not slot_cache: + raise SlotNotCoveredError(f'Slot "{slot}" is not covered by the cluster.') + + if replica: + if len(self.nodes_manager.slots_cache[slot]) < 2: + return None + node_idx = 1 + else: + node_idx = 0 + + return slot_cache[node_idx] + + def keyslot(self, key: EncodableT) -> int: + """ + Find the keyslot for a given key. + + See: https://redis.io/docs/manual/scaling/#redis-cluster-data-sharding + """ + return key_slot(self.encoder.encode(key)) + + def get_encoder(self) -> Encoder: + """Get the encoder object of the client.""" + return self.encoder + + def get_connection_kwargs(self) -> Dict[str, Optional[Any]]: + """Get the kwargs passed to :class:`~redis.asyncio.connection.Connection`.""" + return self.connection_kwargs + + def get_retry(self) -> Optional["Retry"]: + return self.retry + + def set_retry(self, retry: "Retry") -> None: + self.retry = retry + for node in self.get_nodes(): + node.connection_kwargs.update({"retry": retry}) + for conn in node._connections: + conn.retry = retry + + def set_response_callback(self, command: str, callback: ResponseCallbackT) -> None: + """Set a custom response callback.""" + self.response_callbacks[command] = callback + + async def _determine_nodes( + self, command: str, *args: Any, node_flag: Optional[str] = None + ) -> List["ClusterNode"]: + # Determine which nodes should be executed the command on. + # Returns a list of target nodes. + if not node_flag: + # get the nodes group for this command if it was predefined + node_flag = self.command_flags.get(command) + + if node_flag in self.node_flags: + if node_flag == self.__class__.DEFAULT_NODE: + # return the cluster's default node + return [self.nodes_manager.default_node] + if node_flag == self.__class__.PRIMARIES: + # return all primaries + return self.nodes_manager.get_nodes_by_server_type(PRIMARY) + if node_flag == self.__class__.REPLICAS: + # return all replicas + return self.nodes_manager.get_nodes_by_server_type(REPLICA) + if node_flag == self.__class__.ALL_NODES: + # return all nodes + return list(self.nodes_manager.nodes_cache.values()) + if node_flag == self.__class__.RANDOM: + # return a random node + return [random.choice(list(self.nodes_manager.nodes_cache.values()))] + + # get the node that holds the key's slot + return [ + self.nodes_manager.get_node_from_slot( + await self._determine_slot(command, *args), + self.read_from_replicas and command in READ_COMMANDS, + ) + ] + + async def _determine_slot(self, command: str, *args: Any) -> int: + if self.command_flags.get(command) == SLOT_ID: + # The command contains the slot ID + return int(args[0]) + + # Get the keys in the command + + # EVAL and EVALSHA are common enough that it's wasteful to go to the + # redis server to parse the keys. Besides, there is a bug in redis<7.0 + # where `self._get_command_keys()` fails anyway. So, we special case + # EVAL/EVALSHA. + # - issue: https://github.com/redis/redis/issues/9493 + # - fix: https://github.com/redis/redis/pull/9733 + if command.upper() in ("EVAL", "EVALSHA"): + # command syntax: EVAL "script body" num_keys ... + if len(args) < 2: + raise RedisClusterException( + f"Invalid args in command: {command, *args}" + ) + keys = args[2 : 2 + int(args[1])] + # if there are 0 keys, that means the script can be run on any node + # so we can just return a random slot + if not keys: + return random.randrange(0, REDIS_CLUSTER_HASH_SLOTS) + else: + keys = await self.commands_parser.get_keys(command, *args) + if not keys: + # FCALL can call a function with 0 keys, that means the function + # can be run on any node so we can just return a random slot + if command.upper() in ("FCALL", "FCALL_RO"): + return random.randrange(0, REDIS_CLUSTER_HASH_SLOTS) + raise RedisClusterException( + "No way to dispatch this command to Redis Cluster. " + "Missing key.\nYou can execute the command by specifying " + f"target nodes.\nCommand: {args}" + ) + + # single key command + if len(keys) == 1: + return self.keyslot(keys[0]) + + # multi-key command; we need to make sure all keys are mapped to + # the same slot + slots = {self.keyslot(key) for key in keys} + if len(slots) != 1: + raise RedisClusterException( + f"{command} - all keys must map to the same key slot" + ) + + return slots.pop() + + def _is_node_flag(self, target_nodes: Any) -> bool: + return isinstance(target_nodes, str) and target_nodes in self.node_flags + + def _parse_target_nodes(self, target_nodes: Any) -> List["ClusterNode"]: + if isinstance(target_nodes, list): + nodes = target_nodes + elif isinstance(target_nodes, ClusterNode): + # Supports passing a single ClusterNode as a variable + nodes = [target_nodes] + elif isinstance(target_nodes, dict): + # Supports dictionaries of the format {node_name: node}. + # It enables to execute commands with multi nodes as follows: + # rc.cluster_save_config(rc.get_primaries()) + nodes = list(target_nodes.values()) + else: + raise TypeError( + "target_nodes type can be one of the following: " + "node_flag (PRIMARIES, REPLICAS, RANDOM, ALL_NODES)," + "ClusterNode, list, or dict. " + f"The passed type is {type(target_nodes)}" + ) + return nodes + + async def execute_command(self, *args: EncodableT, **kwargs: Any) -> Any: + """ + Execute a raw command on the appropriate cluster node or target_nodes. + + It will retry the command as specified by :attr:`cluster_error_retry_attempts` & + then raise an exception. + + :param args: + | Raw command args + :param kwargs: + + - target_nodes: :attr:`NODE_FLAGS` or :class:`~.ClusterNode` + or List[:class:`~.ClusterNode`] or Dict[Any, :class:`~.ClusterNode`] + - Rest of the kwargs are passed to the Redis connection + + :raises RedisClusterException: if target_nodes is not provided & the command + can't be mapped to a slot + """ + command = args[0] + target_nodes = [] + target_nodes_specified = False + retry_attempts = self.cluster_error_retry_attempts + + passed_targets = kwargs.pop("target_nodes", None) + if passed_targets and not self._is_node_flag(passed_targets): + target_nodes = self._parse_target_nodes(passed_targets) + target_nodes_specified = True + retry_attempts = 0 + + # Add one for the first execution + execute_attempts = 1 + retry_attempts + for _ in range(execute_attempts): + if self._initialize: + await self.initialize() + if ( + len(target_nodes) == 1 + and target_nodes[0] == self.get_default_node() + ): + # Replace the default cluster node + self.replace_default_node() + try: + if not target_nodes_specified: + # Determine the nodes to execute the command on + target_nodes = await self._determine_nodes( + *args, node_flag=passed_targets + ) + if not target_nodes: + raise RedisClusterException( + f"No targets were found to execute {args} command on" + ) + + if len(target_nodes) == 1: + # Return the processed result + ret = await self._execute_command(target_nodes[0], *args, **kwargs) + if command in self.result_callbacks: + return self.result_callbacks[command]( + command, {target_nodes[0].name: ret}, **kwargs + ) + return ret + else: + keys = [node.name for node in target_nodes] + values = await asyncio.gather( + *( + asyncio.create_task( + self._execute_command(node, *args, **kwargs) + ) + for node in target_nodes + ) + ) + if command in self.result_callbacks: + return self.result_callbacks[command]( + command, dict(zip(keys, values)), **kwargs + ) + return dict(zip(keys, values)) + except Exception as e: + if retry_attempts > 0 and type(e) in self.__class__.ERRORS_ALLOW_RETRY: + # The nodes and slots cache were should be reinitialized. + # Try again with the new cluster setup. + retry_attempts -= 1 + continue + else: + # raise the exception + raise e + + async def _execute_command( + self, target_node: "ClusterNode", *args: Union[KeyT, EncodableT], **kwargs: Any + ) -> Any: + asking = moved = False + redirect_addr = None + ttl = self.RedisClusterRequestTTL + + while ttl > 0: + ttl -= 1 + try: + if asking: + target_node = self.get_node(node_name=redirect_addr) + await target_node.execute_command("ASKING") + asking = False + elif moved: + # MOVED occurred and the slots cache was updated, + # refresh the target node + slot = await self._determine_slot(*args) + target_node = self.nodes_manager.get_node_from_slot( + slot, self.read_from_replicas and args[0] in READ_COMMANDS + ) + moved = False + + return await target_node.execute_command(*args, **kwargs) + except (BusyLoadingError, MaxConnectionsError): + raise + except (ConnectionError, TimeoutError): + # Connection retries are being handled in the node's + # Retry object. + # Remove the failed node from the startup nodes before we try + # to reinitialize the cluster + self.nodes_manager.startup_nodes.pop(target_node.name, None) + # Hard force of reinitialize of the node/slots setup + # and try again with the new setup + await self.aclose() + raise + except ClusterDownError: + # ClusterDownError can occur during a failover and to get + # self-healed, we will try to reinitialize the cluster layout + # and retry executing the command + await self.aclose() + await asyncio.sleep(0.25) + raise + except MovedError as e: + # First, we will try to patch the slots/nodes cache with the + # redirected node output and try again. If MovedError exceeds + # 'reinitialize_steps' number of times, we will force + # reinitializing the tables, and then try again. + # 'reinitialize_steps' counter will increase faster when + # the same client object is shared between multiple threads. To + # reduce the frequency you can set this variable in the + # RedisCluster constructor. + self.reinitialize_counter += 1 + if ( + self.reinitialize_steps + and self.reinitialize_counter % self.reinitialize_steps == 0 + ): + await self.aclose() + # Reset the counter + self.reinitialize_counter = 0 + else: + self.nodes_manager._moved_exception = e + moved = True + except AskError as e: + redirect_addr = get_node_name(host=e.host, port=e.port) + asking = True + except TryAgainError: + if ttl < self.RedisClusterRequestTTL / 2: + await asyncio.sleep(0.05) + + raise ClusterError("TTL exhausted.") + + def pipeline( + self, transaction: Optional[Any] = None, shard_hint: Optional[Any] = None + ) -> "ClusterPipeline": + """ + Create & return a new :class:`~.ClusterPipeline` object. + + Cluster implementation of pipeline does not support transaction or shard_hint. + + :raises RedisClusterException: if transaction or shard_hint are truthy values + """ + if shard_hint: + raise RedisClusterException("shard_hint is deprecated in cluster mode") + + if transaction: + raise RedisClusterException("transaction is deprecated in cluster mode") + + return ClusterPipeline(self) + + def lock( + self, + name: KeyT, + timeout: Optional[float] = None, + sleep: float = 0.1, + blocking: bool = True, + blocking_timeout: Optional[float] = None, + lock_class: Optional[Type[Lock]] = None, + thread_local: bool = True, + ) -> Lock: + """ + Return a new Lock object using key ``name`` that mimics + the behavior of threading.Lock. + + If specified, ``timeout`` indicates a maximum life for the lock. + By default, it will remain locked until release() is called. + + ``sleep`` indicates the amount of time to sleep per loop iteration + when the lock is in blocking mode and another client is currently + holding the lock. + + ``blocking`` indicates whether calling ``acquire`` should block until + the lock has been acquired or to fail immediately, causing ``acquire`` + to return False and the lock not being acquired. Defaults to True. + Note this value can be overridden by passing a ``blocking`` + argument to ``acquire``. + + ``blocking_timeout`` indicates the maximum amount of time in seconds to + spend trying to acquire the lock. A value of ``None`` indicates + continue trying forever. ``blocking_timeout`` can be specified as a + float or integer, both representing the number of seconds to wait. + + ``lock_class`` forces the specified lock implementation. Note that as + of redis-py 3.0, the only lock class we implement is ``Lock`` (which is + a Lua-based lock). So, it's unlikely you'll need this parameter, unless + you have created your own custom lock class. + + ``thread_local`` indicates whether the lock token is placed in + thread-local storage. By default, the token is placed in thread local + storage so that a thread only sees its token, not a token set by + another thread. Consider the following timeline: + + time: 0, thread-1 acquires `my-lock`, with a timeout of 5 seconds. + thread-1 sets the token to "abc" + time: 1, thread-2 blocks trying to acquire `my-lock` using the + Lock instance. + time: 5, thread-1 has not yet completed. redis expires the lock + key. + time: 5, thread-2 acquired `my-lock` now that it's available. + thread-2 sets the token to "xyz" + time: 6, thread-1 finishes its work and calls release(). if the + token is *not* stored in thread local storage, then + thread-1 would see the token value as "xyz" and would be + able to successfully release the thread-2's lock. + + In some use cases it's necessary to disable thread local storage. For + example, if you have code where one thread acquires a lock and passes + that lock instance to a worker thread to release later. If thread + local storage isn't disabled in this case, the worker thread won't see + the token set by the thread that acquired the lock. Our assumption + is that these cases aren't common and as such default to using + thread local storage.""" + if lock_class is None: + lock_class = Lock + return lock_class( + self, + name, + timeout=timeout, + sleep=sleep, + blocking=blocking, + blocking_timeout=blocking_timeout, + thread_local=thread_local, + ) + + +class ClusterNode: + """ + Create a new ClusterNode. + + Each ClusterNode manages multiple :class:`~redis.asyncio.connection.Connection` + objects for the (host, port). + """ + + __slots__ = ( + "_connections", + "_free", + "connection_class", + "connection_kwargs", + "host", + "max_connections", + "name", + "port", + "response_callbacks", + "server_type", + ) + + def __init__( + self, + host: str, + port: Union[str, int], + server_type: Optional[str] = None, + *, + max_connections: int = 2**31, + connection_class: Type[Connection] = Connection, + **connection_kwargs: Any, + ) -> None: + if host == "localhost": + host = socket.gethostbyname(host) + + connection_kwargs["host"] = host + connection_kwargs["port"] = port + self.host = host + self.port = port + self.name = get_node_name(host, port) + self.server_type = server_type + + self.max_connections = max_connections + self.connection_class = connection_class + self.connection_kwargs = connection_kwargs + self.response_callbacks = connection_kwargs.pop("response_callbacks", {}) + + self._connections: List[Connection] = [] + self._free: Deque[Connection] = collections.deque(maxlen=self.max_connections) + + def __repr__(self) -> str: + return ( + f"[host={self.host}, port={self.port}, " + f"name={self.name}, server_type={self.server_type}]" + ) + + def __eq__(self, obj: Any) -> bool: + return isinstance(obj, ClusterNode) and obj.name == self.name + + _DEL_MESSAGE = "Unclosed ClusterNode object" + + def __del__( + self, + _warn: Any = warnings.warn, + _grl: Any = asyncio.get_running_loop, + ) -> None: + for connection in self._connections: + if connection.is_connected: + _warn(f"{self._DEL_MESSAGE} {self!r}", ResourceWarning, source=self) + + try: + context = {"client": self, "message": self._DEL_MESSAGE} + _grl().call_exception_handler(context) + except RuntimeError: + pass + break + + async def disconnect(self) -> None: + ret = await asyncio.gather( + *( + asyncio.create_task(connection.disconnect()) + for connection in self._connections + ), + return_exceptions=True, + ) + exc = next((res for res in ret if isinstance(res, Exception)), None) + if exc: + raise exc + + def acquire_connection(self) -> Connection: + try: + return self._free.popleft() + except IndexError: + if len(self._connections) < self.max_connections: + connection = self.connection_class(**self.connection_kwargs) + self._connections.append(connection) + return connection + + raise MaxConnectionsError() + + async def parse_response( + self, connection: Connection, command: str, **kwargs: Any + ) -> Any: + try: + if NEVER_DECODE in kwargs: + response = await connection.read_response(disable_decoding=True) + kwargs.pop(NEVER_DECODE) + else: + response = await connection.read_response() + except ResponseError: + if EMPTY_RESPONSE in kwargs: + return kwargs[EMPTY_RESPONSE] + raise + + if EMPTY_RESPONSE in kwargs: + kwargs.pop(EMPTY_RESPONSE) + + # Return response + if command in self.response_callbacks: + return self.response_callbacks[command](response, **kwargs) + + return response + + async def execute_command(self, *args: Any, **kwargs: Any) -> Any: + # Acquire connection + connection = self.acquire_connection() + + # Execute command + await connection.send_packed_command(connection.pack_command(*args), False) + + # Read response + try: + return await self.parse_response(connection, args[0], **kwargs) + finally: + # Release connection + self._free.append(connection) + + async def execute_pipeline(self, commands: List["PipelineCommand"]) -> bool: + # Acquire connection + connection = self.acquire_connection() + + # Execute command + await connection.send_packed_command( + connection.pack_commands(cmd.args for cmd in commands), False + ) + + # Read responses + ret = False + for cmd in commands: + try: + cmd.result = await self.parse_response( + connection, cmd.args[0], **cmd.kwargs + ) + except Exception as e: + cmd.result = e + ret = True + + # Release connection + self._free.append(connection) + + return ret + + +class NodesManager: + __slots__ = ( + "_moved_exception", + "connection_kwargs", + "default_node", + "nodes_cache", + "read_load_balancer", + "require_full_coverage", + "slots_cache", + "startup_nodes", + "address_remap", + ) + + def __init__( + self, + startup_nodes: List["ClusterNode"], + require_full_coverage: bool, + connection_kwargs: Dict[str, Any], + address_remap: Optional[Callable[[str, int], Tuple[str, int]]] = None, + ) -> None: + self.startup_nodes = {node.name: node for node in startup_nodes} + self.require_full_coverage = require_full_coverage + self.connection_kwargs = connection_kwargs + self.address_remap = address_remap + + self.default_node: "ClusterNode" = None + self.nodes_cache: Dict[str, "ClusterNode"] = {} + self.slots_cache: Dict[int, List["ClusterNode"]] = {} + self.read_load_balancer = LoadBalancer() + self._moved_exception: MovedError = None + + def get_node( + self, + host: Optional[str] = None, + port: Optional[int] = None, + node_name: Optional[str] = None, + ) -> Optional["ClusterNode"]: + if host and port: + # the user passed host and port + if host == "localhost": + host = socket.gethostbyname(host) + return self.nodes_cache.get(get_node_name(host=host, port=port)) + elif node_name: + return self.nodes_cache.get(node_name) + else: + raise DataError( + "get_node requires one of the following: " + "1. node name " + "2. host and port" + ) + + def set_nodes( + self, + old: Dict[str, "ClusterNode"], + new: Dict[str, "ClusterNode"], + remove_old: bool = False, + ) -> None: + if remove_old: + for name in list(old.keys()): + if name not in new: + task = asyncio.create_task(old.pop(name).disconnect()) # noqa + + for name, node in new.items(): + if name in old: + if old[name] is node: + continue + task = asyncio.create_task(old[name].disconnect()) # noqa + old[name] = node + + def _update_moved_slots(self) -> None: + e = self._moved_exception + redirected_node = self.get_node(host=e.host, port=e.port) + if redirected_node: + # The node already exists + if redirected_node.server_type != PRIMARY: + # Update the node's server type + redirected_node.server_type = PRIMARY + else: + # This is a new node, we will add it to the nodes cache + redirected_node = ClusterNode( + e.host, e.port, PRIMARY, **self.connection_kwargs + ) + self.set_nodes(self.nodes_cache, {redirected_node.name: redirected_node}) + if redirected_node in self.slots_cache[e.slot_id]: + # The MOVED error resulted from a failover, and the new slot owner + # had previously been a replica. + old_primary = self.slots_cache[e.slot_id][0] + # Update the old primary to be a replica and add it to the end of + # the slot's node list + old_primary.server_type = REPLICA + self.slots_cache[e.slot_id].append(old_primary) + # Remove the old replica, which is now a primary, from the slot's + # node list + self.slots_cache[e.slot_id].remove(redirected_node) + # Override the old primary with the new one + self.slots_cache[e.slot_id][0] = redirected_node + if self.default_node == old_primary: + # Update the default node with the new primary + self.default_node = redirected_node + else: + # The new slot owner is a new server, or a server from a different + # shard. We need to remove all current nodes from the slot's list + # (including replications) and add just the new node. + self.slots_cache[e.slot_id] = [redirected_node] + # Reset moved_exception + self._moved_exception = None + + def get_node_from_slot( + self, slot: int, read_from_replicas: bool = False + ) -> "ClusterNode": + if self._moved_exception: + self._update_moved_slots() + + try: + if read_from_replicas: + # get the server index in a Round-Robin manner + primary_name = self.slots_cache[slot][0].name + node_idx = self.read_load_balancer.get_server_index( + primary_name, len(self.slots_cache[slot]) + ) + return self.slots_cache[slot][node_idx] + return self.slots_cache[slot][0] + except (IndexError, TypeError): + raise SlotNotCoveredError( + f'Slot "{slot}" not covered by the cluster. ' + f'"require_full_coverage={self.require_full_coverage}"' + ) + + def get_nodes_by_server_type(self, server_type: str) -> List["ClusterNode"]: + return [ + node + for node in self.nodes_cache.values() + if node.server_type == server_type + ] + + async def initialize(self) -> None: + self.read_load_balancer.reset() + tmp_nodes_cache: Dict[str, "ClusterNode"] = {} + tmp_slots: Dict[int, List["ClusterNode"]] = {} + disagreements = [] + startup_nodes_reachable = False + fully_covered = False + exception = None + for startup_node in self.startup_nodes.values(): + try: + # Make sure cluster mode is enabled on this node + if not (await startup_node.execute_command("INFO")).get( + "cluster_enabled" + ): + raise RedisClusterException( + "Cluster mode is not enabled on this node" + ) + cluster_slots = await startup_node.execute_command("CLUSTER SLOTS") + startup_nodes_reachable = True + except Exception as e: + # Try the next startup node. + # The exception is saved and raised only if we have no more nodes. + exception = e + continue + + # CLUSTER SLOTS command results in the following output: + # [[slot_section[from_slot,to_slot,master,replica1,...,replicaN]]] + # where each node contains the following list: [IP, port, node_id] + # Therefore, cluster_slots[0][2][0] will be the IP address of the + # primary node of the first slot section. + # If there's only one server in the cluster, its ``host`` is '' + # Fix it to the host in startup_nodes + if ( + len(cluster_slots) == 1 + and not cluster_slots[0][2][0] + and len(self.startup_nodes) == 1 + ): + cluster_slots[0][2][0] = startup_node.host + + for slot in cluster_slots: + for i in range(2, len(slot)): + slot[i] = [str_if_bytes(val) for val in slot[i]] + primary_node = slot[2] + host = primary_node[0] + if host == "": + host = startup_node.host + port = int(primary_node[1]) + host, port = self.remap_host_port(host, port) + + target_node = tmp_nodes_cache.get(get_node_name(host, port)) + if not target_node: + target_node = ClusterNode( + host, port, PRIMARY, **self.connection_kwargs + ) + # add this node to the nodes cache + tmp_nodes_cache[target_node.name] = target_node + + for i in range(int(slot[0]), int(slot[1]) + 1): + if i not in tmp_slots: + tmp_slots[i] = [] + tmp_slots[i].append(target_node) + replica_nodes = [slot[j] for j in range(3, len(slot))] + + for replica_node in replica_nodes: + host = replica_node[0] + port = replica_node[1] + host, port = self.remap_host_port(host, port) + + target_replica_node = tmp_nodes_cache.get( + get_node_name(host, port) + ) + if not target_replica_node: + target_replica_node = ClusterNode( + host, port, REPLICA, **self.connection_kwargs + ) + tmp_slots[i].append(target_replica_node) + # add this node to the nodes cache + tmp_nodes_cache[ + target_replica_node.name + ] = target_replica_node + else: + # Validate that 2 nodes want to use the same slot cache + # setup + tmp_slot = tmp_slots[i][0] + if tmp_slot.name != target_node.name: + disagreements.append( + f"{tmp_slot.name} vs {target_node.name} on slot: {i}" + ) + + if len(disagreements) > 5: + raise RedisClusterException( + f"startup_nodes could not agree on a valid " + f'slots cache: {", ".join(disagreements)}' + ) + + # Validate if all slots are covered or if we should try next startup node + fully_covered = True + for i in range(REDIS_CLUSTER_HASH_SLOTS): + if i not in tmp_slots: + fully_covered = False + break + if fully_covered: + break + + if not startup_nodes_reachable: + raise RedisClusterException( + f"Redis Cluster cannot be connected. Please provide at least " + f"one reachable node: {str(exception)}" + ) from exception + + # Check if the slots are not fully covered + if not fully_covered and self.require_full_coverage: + # Despite the requirement that the slots be covered, there + # isn't a full coverage + raise RedisClusterException( + f"All slots are not covered after query all startup_nodes. " + f"{len(tmp_slots)} of {REDIS_CLUSTER_HASH_SLOTS} " + f"covered..." + ) + + # Set the tmp variables to the real variables + self.slots_cache = tmp_slots + self.set_nodes(self.nodes_cache, tmp_nodes_cache, remove_old=True) + # Populate the startup nodes with all discovered nodes + self.set_nodes(self.startup_nodes, self.nodes_cache, remove_old=True) + + # Set the default node + self.default_node = self.get_nodes_by_server_type(PRIMARY)[0] + # If initialize was called after a MovedError, clear it + self._moved_exception = None + + async def aclose(self, attr: str = "nodes_cache") -> None: + self.default_node = None + await asyncio.gather( + *( + asyncio.create_task(node.disconnect()) + for node in getattr(self, attr).values() + ) + ) + + def remap_host_port(self, host: str, port: int) -> Tuple[str, int]: + """ + Remap the host and port returned from the cluster to a different + internal value. Useful if the client is not connecting directly + to the cluster. + """ + if self.address_remap: + return self.address_remap((host, port)) + return host, port + + +class ClusterPipeline(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommands): + """ + Create a new ClusterPipeline object. + + Usage:: + + result = await ( + rc.pipeline() + .set("A", 1) + .get("A") + .hset("K", "F", "V") + .hgetall("K") + .mset_nonatomic({"A": 2, "B": 3}) + .get("A") + .get("B") + .delete("A", "B", "K") + .execute() + ) + # result = [True, "1", 1, {"F": "V"}, True, True, "2", "3", 1, 1, 1] + + Note: For commands `DELETE`, `EXISTS`, `TOUCH`, `UNLINK`, `mset_nonatomic`, which + are split across multiple nodes, you'll get multiple results for them in the array. + + Retryable errors: + - :class:`~.ClusterDownError` + - :class:`~.ConnectionError` + - :class:`~.TimeoutError` + + Redirection errors: + - :class:`~.TryAgainError` + - :class:`~.MovedError` + - :class:`~.AskError` + + :param client: + | Existing :class:`~.RedisCluster` client + """ + + __slots__ = ("_command_stack", "_client") + + def __init__(self, client: RedisCluster) -> None: + self._client = client + + self._command_stack: List["PipelineCommand"] = [] + + async def initialize(self) -> "ClusterPipeline": + if self._client._initialize: + await self._client.initialize() + self._command_stack = [] + return self + + async def __aenter__(self) -> "ClusterPipeline": + return await self.initialize() + + async def __aexit__(self, exc_type: None, exc_value: None, traceback: None) -> None: + self._command_stack = [] + + def __await__(self) -> Generator[Any, None, "ClusterPipeline"]: + return self.initialize().__await__() + + def __enter__(self) -> "ClusterPipeline": + self._command_stack = [] + return self + + def __exit__(self, exc_type: None, exc_value: None, traceback: None) -> None: + self._command_stack = [] + + def __bool__(self) -> bool: + return bool(self._command_stack) + + def __len__(self) -> int: + return len(self._command_stack) + + def execute_command( + self, *args: Union[KeyT, EncodableT], **kwargs: Any + ) -> "ClusterPipeline": + """ + Append a raw command to the pipeline. + + :param args: + | Raw command args + :param kwargs: + + - target_nodes: :attr:`NODE_FLAGS` or :class:`~.ClusterNode` + or List[:class:`~.ClusterNode`] or Dict[Any, :class:`~.ClusterNode`] + - Rest of the kwargs are passed to the Redis connection + """ + self._command_stack.append( + PipelineCommand(len(self._command_stack), *args, **kwargs) + ) + return self + + async def execute( + self, raise_on_error: bool = True, allow_redirections: bool = True + ) -> List[Any]: + """ + Execute the pipeline. + + It will retry the commands as specified by :attr:`cluster_error_retry_attempts` + & then raise an exception. + + :param raise_on_error: + | Raise the first error if there are any errors + :param allow_redirections: + | Whether to retry each failed command individually in case of redirection + errors + + :raises RedisClusterException: if target_nodes is not provided & the command + can't be mapped to a slot + """ + if not self._command_stack: + return [] + + try: + for _ in range(self._client.cluster_error_retry_attempts): + if self._client._initialize: + await self._client.initialize() + + try: + return await self._execute( + self._client, + self._command_stack, + raise_on_error=raise_on_error, + allow_redirections=allow_redirections, + ) + except BaseException as e: + if type(e) in self.__class__.ERRORS_ALLOW_RETRY: + # Try again with the new cluster setup. + exception = e + await self._client.aclose() + await asyncio.sleep(0.25) + else: + # All other errors should be raised. + raise + + # If it fails the configured number of times then raise an exception + raise exception + finally: + self._command_stack = [] + + async def _execute( + self, + client: "RedisCluster", + stack: List["PipelineCommand"], + raise_on_error: bool = True, + allow_redirections: bool = True, + ) -> List[Any]: + todo = [ + cmd for cmd in stack if not cmd.result or isinstance(cmd.result, Exception) + ] + + nodes = {} + for cmd in todo: + passed_targets = cmd.kwargs.pop("target_nodes", None) + if passed_targets and not client._is_node_flag(passed_targets): + target_nodes = client._parse_target_nodes(passed_targets) + else: + target_nodes = await client._determine_nodes( + *cmd.args, node_flag=passed_targets + ) + if not target_nodes: + raise RedisClusterException( + f"No targets were found to execute {cmd.args} command on" + ) + if len(target_nodes) > 1: + raise RedisClusterException(f"Too many targets for command {cmd.args}") + node = target_nodes[0] + if node.name not in nodes: + nodes[node.name] = (node, []) + nodes[node.name][1].append(cmd) + + errors = await asyncio.gather( + *( + asyncio.create_task(node[0].execute_pipeline(node[1])) + for node in nodes.values() + ) + ) + + if any(errors): + if allow_redirections: + # send each errored command individually + for cmd in todo: + if isinstance(cmd.result, (TryAgainError, MovedError, AskError)): + try: + cmd.result = await client.execute_command( + *cmd.args, **cmd.kwargs + ) + except Exception as e: + cmd.result = e + + if raise_on_error: + for cmd in todo: + result = cmd.result + if isinstance(result, Exception): + command = " ".join(map(safe_str, cmd.args)) + msg = ( + f"Command # {cmd.position + 1} ({command}) of pipeline " + f"caused error: {result.args}" + ) + result.args = (msg,) + result.args[1:] + raise result + + default_node = nodes.get(client.get_default_node().name) + if default_node is not None: + # This pipeline execution used the default node, check if we need + # to replace it. + # Note: when the error is raised we'll reset the default node in the + # caller function. + for cmd in default_node[1]: + # Check if it has a command that failed with a relevant + # exception + if type(cmd.result) in self.__class__.ERRORS_ALLOW_RETRY: + client.replace_default_node() + break + + return [cmd.result for cmd in stack] + + def _split_command_across_slots( + self, command: str, *keys: KeyT + ) -> "ClusterPipeline": + for slot_keys in self._client._partition_keys_by_slot(keys).values(): + self.execute_command(command, *slot_keys) + + return self + + def mset_nonatomic( + self, mapping: Mapping[AnyKeyT, EncodableT] + ) -> "ClusterPipeline": + encoder = self._client.encoder + + slots_pairs = {} + for pair in mapping.items(): + slot = key_slot(encoder.encode(pair[0])) + slots_pairs.setdefault(slot, []).extend(pair) + + for pairs in slots_pairs.values(): + self.execute_command("MSET", *pairs) + + return self + + +for command in PIPELINE_BLOCKED_COMMANDS: + command = command.replace(" ", "_").lower() + if command == "mset_nonatomic": + continue + + setattr(ClusterPipeline, command, block_pipeline_command(command)) + + +class PipelineCommand: + def __init__(self, position: int, *args: Any, **kwargs: Any) -> None: + self.args = args + self.kwargs = kwargs + self.position = position + self.result: Union[Any, Exception] = None + + def __repr__(self) -> str: + return f"[{self.position}] {self.args} ({self.kwargs})" diff --git a/.venv/Lib/site-packages/redis/asyncio/connection.py b/.venv/Lib/site-packages/redis/asyncio/connection.py new file mode 100644 index 00000000..65fa5864 --- /dev/null +++ b/.venv/Lib/site-packages/redis/asyncio/connection.py @@ -0,0 +1,1180 @@ +import asyncio +import copy +import enum +import inspect +import socket +import ssl +import sys +import weakref +from abc import abstractmethod +from itertools import chain +from types import MappingProxyType +from typing import ( + Any, + Callable, + Iterable, + List, + Mapping, + Optional, + Set, + Tuple, + Type, + TypeVar, + Union, +) +from urllib.parse import ParseResult, parse_qs, unquote, urlparse + +# the functionality is available in 3.11.x but has a major issue before +# 3.11.3. See https://github.com/redis/redis-py/issues/2633 +if sys.version_info >= (3, 11, 3): + from asyncio import timeout as async_timeout +else: + from async_timeout import timeout as async_timeout + +from redis.asyncio.retry import Retry +from redis.backoff import NoBackoff +from redis.compat import Protocol, TypedDict +from redis.connection import DEFAULT_RESP_VERSION +from redis.credentials import CredentialProvider, UsernamePasswordCredentialProvider +from redis.exceptions import ( + AuthenticationError, + AuthenticationWrongNumberOfArgsError, + ConnectionError, + DataError, + RedisError, + ResponseError, + TimeoutError, +) +from redis.typing import EncodableT +from redis.utils import HIREDIS_AVAILABLE, get_lib_version, str_if_bytes + +from .._parsers import ( + BaseParser, + Encoder, + _AsyncHiredisParser, + _AsyncRESP2Parser, + _AsyncRESP3Parser, +) + +SYM_STAR = b"*" +SYM_DOLLAR = b"$" +SYM_CRLF = b"\r\n" +SYM_LF = b"\n" +SYM_EMPTY = b"" + + +class _Sentinel(enum.Enum): + sentinel = object() + + +SENTINEL = _Sentinel.sentinel + + +DefaultParser: Type[Union[_AsyncRESP2Parser, _AsyncRESP3Parser, _AsyncHiredisParser]] +if HIREDIS_AVAILABLE: + DefaultParser = _AsyncHiredisParser +else: + DefaultParser = _AsyncRESP2Parser + + +class ConnectCallbackProtocol(Protocol): + def __call__(self, connection: "AbstractConnection"): + ... + + +class AsyncConnectCallbackProtocol(Protocol): + async def __call__(self, connection: "AbstractConnection"): + ... + + +ConnectCallbackT = Union[ConnectCallbackProtocol, AsyncConnectCallbackProtocol] + + +class AbstractConnection: + """Manages communication to and from a Redis server""" + + __slots__ = ( + "db", + "username", + "client_name", + "lib_name", + "lib_version", + "credential_provider", + "password", + "socket_timeout", + "socket_connect_timeout", + "redis_connect_func", + "retry_on_timeout", + "retry_on_error", + "health_check_interval", + "next_health_check", + "last_active_at", + "encoder", + "ssl_context", + "protocol", + "_reader", + "_writer", + "_parser", + "_connect_callbacks", + "_buffer_cutoff", + "_lock", + "_socket_read_size", + "__dict__", + ) + + def __init__( + self, + *, + db: Union[str, int] = 0, + password: Optional[str] = None, + socket_timeout: Optional[float] = None, + socket_connect_timeout: Optional[float] = None, + retry_on_timeout: bool = False, + retry_on_error: Union[list, _Sentinel] = SENTINEL, + encoding: str = "utf-8", + encoding_errors: str = "strict", + decode_responses: bool = False, + parser_class: Type[BaseParser] = DefaultParser, + socket_read_size: int = 65536, + health_check_interval: float = 0, + client_name: Optional[str] = None, + lib_name: Optional[str] = "redis-py", + lib_version: Optional[str] = get_lib_version(), + username: Optional[str] = None, + retry: Optional[Retry] = None, + redis_connect_func: Optional[ConnectCallbackT] = None, + encoder_class: Type[Encoder] = Encoder, + credential_provider: Optional[CredentialProvider] = None, + protocol: Optional[int] = 2, + ): + if (username or password) and credential_provider is not None: + raise DataError( + "'username' and 'password' cannot be passed along with 'credential_" + "provider'. Please provide only one of the following arguments: \n" + "1. 'password' and (optional) 'username'\n" + "2. 'credential_provider'" + ) + self.db = db + self.client_name = client_name + self.lib_name = lib_name + self.lib_version = lib_version + self.credential_provider = credential_provider + self.password = password + self.username = username + self.socket_timeout = socket_timeout + if socket_connect_timeout is None: + socket_connect_timeout = socket_timeout + self.socket_connect_timeout = socket_connect_timeout + self.retry_on_timeout = retry_on_timeout + if retry_on_error is SENTINEL: + retry_on_error = [] + if retry_on_timeout: + retry_on_error.append(TimeoutError) + retry_on_error.append(socket.timeout) + retry_on_error.append(asyncio.TimeoutError) + self.retry_on_error = retry_on_error + if retry or retry_on_error: + if not retry: + self.retry = Retry(NoBackoff(), 1) + else: + # deep-copy the Retry object as it is mutable + self.retry = copy.deepcopy(retry) + # Update the retry's supported errors with the specified errors + self.retry.update_supported_errors(retry_on_error) + else: + self.retry = Retry(NoBackoff(), 0) + self.health_check_interval = health_check_interval + self.next_health_check: float = -1 + self.encoder = encoder_class(encoding, encoding_errors, decode_responses) + self.redis_connect_func = redis_connect_func + self._reader: Optional[asyncio.StreamReader] = None + self._writer: Optional[asyncio.StreamWriter] = None + self._socket_read_size = socket_read_size + self.set_parser(parser_class) + self._connect_callbacks: List[weakref.WeakMethod[ConnectCallbackT]] = [] + self._buffer_cutoff = 6000 + try: + p = int(protocol) + except TypeError: + p = DEFAULT_RESP_VERSION + except ValueError: + raise ConnectionError("protocol must be an integer") + finally: + if p < 2 or p > 3: + raise ConnectionError("protocol must be either 2 or 3") + self.protocol = protocol + + def __repr__(self): + repr_args = ",".join((f"{k}={v}" for k, v in self.repr_pieces())) + return f"{self.__class__.__name__}<{repr_args}>" + + @abstractmethod + def repr_pieces(self): + pass + + @property + def is_connected(self): + return self._reader is not None and self._writer is not None + + def _register_connect_callback(self, callback): + wm = weakref.WeakMethod(callback) + if wm not in self._connect_callbacks: + self._connect_callbacks.append(wm) + + def _deregister_connect_callback(self, callback): + try: + self._connect_callbacks.remove(weakref.WeakMethod(callback)) + except ValueError: + pass + + def set_parser(self, parser_class: Type[BaseParser]) -> None: + """ + Creates a new instance of parser_class with socket size: + _socket_read_size and assigns it to the parser for the connection + :param parser_class: The required parser class + """ + self._parser = parser_class(socket_read_size=self._socket_read_size) + + async def connect(self): + """Connects to the Redis server if not already connected""" + if self.is_connected: + return + try: + await self.retry.call_with_retry( + lambda: self._connect(), lambda error: self.disconnect() + ) + except asyncio.CancelledError: + raise # in 3.7 and earlier, this is an Exception, not BaseException + except (socket.timeout, asyncio.TimeoutError): + raise TimeoutError("Timeout connecting to server") + except OSError as e: + raise ConnectionError(self._error_message(e)) + except Exception as exc: + raise ConnectionError(exc) from exc + + try: + if not self.redis_connect_func: + # Use the default on_connect function + await self.on_connect() + else: + # Use the passed function redis_connect_func + await self.redis_connect_func(self) if asyncio.iscoroutinefunction( + self.redis_connect_func + ) else self.redis_connect_func(self) + except RedisError: + # clean up after any error in on_connect + await self.disconnect() + raise + + # run any user callbacks. right now the only internal callback + # is for pubsub channel/pattern resubscription + # first, remove any dead weakrefs + self._connect_callbacks = [ref for ref in self._connect_callbacks if ref()] + for ref in self._connect_callbacks: + callback = ref() + task = callback(self) + if task and inspect.isawaitable(task): + await task + + @abstractmethod + async def _connect(self): + pass + + @abstractmethod + def _host_error(self) -> str: + pass + + @abstractmethod + def _error_message(self, exception: BaseException) -> str: + pass + + async def on_connect(self) -> None: + """Initialize the connection, authenticate and select a database""" + self._parser.on_connect(self) + parser = self._parser + + auth_args = None + # if credential provider or username and/or password are set, authenticate + if self.credential_provider or (self.username or self.password): + cred_provider = ( + self.credential_provider + or UsernamePasswordCredentialProvider(self.username, self.password) + ) + auth_args = cred_provider.get_credentials() + # if resp version is specified and we have auth args, + # we need to send them via HELLO + if auth_args and self.protocol not in [2, "2"]: + if isinstance(self._parser, _AsyncRESP2Parser): + self.set_parser(_AsyncRESP3Parser) + # update cluster exception classes + self._parser.EXCEPTION_CLASSES = parser.EXCEPTION_CLASSES + self._parser.on_connect(self) + if len(auth_args) == 1: + auth_args = ["default", auth_args[0]] + await self.send_command("HELLO", self.protocol, "AUTH", *auth_args) + response = await self.read_response() + if response.get(b"proto") != int(self.protocol) and response.get( + "proto" + ) != int(self.protocol): + raise ConnectionError("Invalid RESP version") + # avoid checking health here -- PING will fail if we try + # to check the health prior to the AUTH + elif auth_args: + await self.send_command("AUTH", *auth_args, check_health=False) + + try: + auth_response = await self.read_response() + except AuthenticationWrongNumberOfArgsError: + # a username and password were specified but the Redis + # server seems to be < 6.0.0 which expects a single password + # arg. retry auth with just the password. + # https://github.com/andymccurdy/redis-py/issues/1274 + await self.send_command("AUTH", auth_args[-1], check_health=False) + auth_response = await self.read_response() + + if str_if_bytes(auth_response) != "OK": + raise AuthenticationError("Invalid Username or Password") + + # if resp version is specified, switch to it + elif self.protocol not in [2, "2"]: + if isinstance(self._parser, _AsyncRESP2Parser): + self.set_parser(_AsyncRESP3Parser) + # update cluster exception classes + self._parser.EXCEPTION_CLASSES = parser.EXCEPTION_CLASSES + self._parser.on_connect(self) + await self.send_command("HELLO", self.protocol) + response = await self.read_response() + # if response.get(b"proto") != self.protocol and response.get( + # "proto" + # ) != self.protocol: + # raise ConnectionError("Invalid RESP version") + + # if a client_name is given, set it + if self.client_name: + await self.send_command("CLIENT", "SETNAME", self.client_name) + if str_if_bytes(await self.read_response()) != "OK": + raise ConnectionError("Error setting client name") + + # set the library name and version, pipeline for lower startup latency + if self.lib_name: + await self.send_command("CLIENT", "SETINFO", "LIB-NAME", self.lib_name) + if self.lib_version: + await self.send_command("CLIENT", "SETINFO", "LIB-VER", self.lib_version) + # if a database is specified, switch to it. Also pipeline this + if self.db: + await self.send_command("SELECT", self.db) + + # read responses from pipeline + for _ in (sent for sent in (self.lib_name, self.lib_version) if sent): + try: + await self.read_response() + except ResponseError: + pass + + if self.db: + if str_if_bytes(await self.read_response()) != "OK": + raise ConnectionError("Invalid Database") + + async def disconnect(self, nowait: bool = False) -> None: + """Disconnects from the Redis server""" + try: + async with async_timeout(self.socket_connect_timeout): + self._parser.on_disconnect() + if not self.is_connected: + return + try: + self._writer.close() # type: ignore[union-attr] + # wait for close to finish, except when handling errors and + # forcefully disconnecting. + if not nowait: + await self._writer.wait_closed() # type: ignore[union-attr] + except OSError: + pass + finally: + self._reader = None + self._writer = None + except asyncio.TimeoutError: + raise TimeoutError( + f"Timed out closing connection after {self.socket_connect_timeout}" + ) from None + + async def _send_ping(self): + """Send PING, expect PONG in return""" + await self.send_command("PING", check_health=False) + if str_if_bytes(await self.read_response()) != "PONG": + raise ConnectionError("Bad response from PING health check") + + async def _ping_failed(self, error): + """Function to call when PING fails""" + await self.disconnect() + + async def check_health(self): + """Check the health of the connection with a PING/PONG""" + if ( + self.health_check_interval + and asyncio.get_running_loop().time() > self.next_health_check + ): + await self.retry.call_with_retry(self._send_ping, self._ping_failed) + + async def _send_packed_command(self, command: Iterable[bytes]) -> None: + self._writer.writelines(command) + await self._writer.drain() + + async def send_packed_command( + self, command: Union[bytes, str, Iterable[bytes]], check_health: bool = True + ) -> None: + if not self.is_connected: + await self.connect() + elif check_health: + await self.check_health() + + try: + if isinstance(command, str): + command = command.encode() + if isinstance(command, bytes): + command = [command] + if self.socket_timeout: + await asyncio.wait_for( + self._send_packed_command(command), self.socket_timeout + ) + else: + self._writer.writelines(command) + await self._writer.drain() + except asyncio.TimeoutError: + await self.disconnect(nowait=True) + raise TimeoutError("Timeout writing to socket") from None + except OSError as e: + await self.disconnect(nowait=True) + if len(e.args) == 1: + err_no, errmsg = "UNKNOWN", e.args[0] + else: + err_no = e.args[0] + errmsg = e.args[1] + raise ConnectionError( + f"Error {err_no} while writing to socket. {errmsg}." + ) from e + except BaseException: + # BaseExceptions can be raised when a socket send operation is not + # finished, e.g. due to a timeout. Ideally, a caller could then re-try + # to send un-sent data. However, the send_packed_command() API + # does not support it so there is no point in keeping the connection open. + await self.disconnect(nowait=True) + raise + + async def send_command(self, *args: Any, **kwargs: Any) -> None: + """Pack and send a command to the Redis server""" + await self.send_packed_command( + self.pack_command(*args), check_health=kwargs.get("check_health", True) + ) + + async def can_read_destructive(self): + """Poll the socket to see if there's data that can be read.""" + try: + return await self._parser.can_read_destructive() + except OSError as e: + await self.disconnect(nowait=True) + host_error = self._host_error() + raise ConnectionError(f"Error while reading from {host_error}: {e.args}") + + async def read_response( + self, + disable_decoding: bool = False, + timeout: Optional[float] = None, + *, + disconnect_on_error: bool = True, + push_request: Optional[bool] = False, + ): + """Read the response from a previously sent command""" + read_timeout = timeout if timeout is not None else self.socket_timeout + host_error = self._host_error() + try: + if ( + read_timeout is not None + and self.protocol in ["3", 3] + and not HIREDIS_AVAILABLE + ): + async with async_timeout(read_timeout): + response = await self._parser.read_response( + disable_decoding=disable_decoding, push_request=push_request + ) + elif read_timeout is not None: + async with async_timeout(read_timeout): + response = await self._parser.read_response( + disable_decoding=disable_decoding + ) + elif self.protocol in ["3", 3] and not HIREDIS_AVAILABLE: + response = await self._parser.read_response( + disable_decoding=disable_decoding, push_request=push_request + ) + else: + response = await self._parser.read_response( + disable_decoding=disable_decoding + ) + except asyncio.TimeoutError: + if timeout is not None: + # user requested timeout, return None. Operation can be retried + return None + # it was a self.socket_timeout error. + if disconnect_on_error: + await self.disconnect(nowait=True) + raise TimeoutError(f"Timeout reading from {host_error}") + except OSError as e: + if disconnect_on_error: + await self.disconnect(nowait=True) + raise ConnectionError(f"Error while reading from {host_error} : {e.args}") + except BaseException: + # Also by default close in case of BaseException. A lot of code + # relies on this behaviour when doing Command/Response pairs. + # See #1128. + if disconnect_on_error: + await self.disconnect(nowait=True) + raise + + if self.health_check_interval: + next_time = asyncio.get_running_loop().time() + self.health_check_interval + self.next_health_check = next_time + + if isinstance(response, ResponseError): + raise response from None + return response + + def pack_command(self, *args: EncodableT) -> List[bytes]: + """Pack a series of arguments into the Redis protocol""" + output = [] + # the client might have included 1 or more literal arguments in + # the command name, e.g., 'CONFIG GET'. The Redis server expects these + # arguments to be sent separately, so split the first argument + # manually. These arguments should be bytestrings so that they are + # not encoded. + assert not isinstance(args[0], float) + if isinstance(args[0], str): + args = tuple(args[0].encode().split()) + args[1:] + elif b" " in args[0]: + args = tuple(args[0].split()) + args[1:] + + buff = SYM_EMPTY.join((SYM_STAR, str(len(args)).encode(), SYM_CRLF)) + + buffer_cutoff = self._buffer_cutoff + for arg in map(self.encoder.encode, args): + # to avoid large string mallocs, chunk the command into the + # output list if we're sending large values or memoryviews + arg_length = len(arg) + if ( + len(buff) > buffer_cutoff + or arg_length > buffer_cutoff + or isinstance(arg, memoryview) + ): + buff = SYM_EMPTY.join( + (buff, SYM_DOLLAR, str(arg_length).encode(), SYM_CRLF) + ) + output.append(buff) + output.append(arg) + buff = SYM_CRLF + else: + buff = SYM_EMPTY.join( + ( + buff, + SYM_DOLLAR, + str(arg_length).encode(), + SYM_CRLF, + arg, + SYM_CRLF, + ) + ) + output.append(buff) + return output + + def pack_commands(self, commands: Iterable[Iterable[EncodableT]]) -> List[bytes]: + """Pack multiple commands into the Redis protocol""" + output: List[bytes] = [] + pieces: List[bytes] = [] + buffer_length = 0 + buffer_cutoff = self._buffer_cutoff + + for cmd in commands: + for chunk in self.pack_command(*cmd): + chunklen = len(chunk) + if ( + buffer_length > buffer_cutoff + or chunklen > buffer_cutoff + or isinstance(chunk, memoryview) + ): + if pieces: + output.append(SYM_EMPTY.join(pieces)) + buffer_length = 0 + pieces = [] + + if chunklen > buffer_cutoff or isinstance(chunk, memoryview): + output.append(chunk) + else: + pieces.append(chunk) + buffer_length += chunklen + + if pieces: + output.append(SYM_EMPTY.join(pieces)) + return output + + +class Connection(AbstractConnection): + "Manages TCP communication to and from a Redis server" + + def __init__( + self, + *, + host: str = "localhost", + port: Union[str, int] = 6379, + socket_keepalive: bool = False, + socket_keepalive_options: Optional[Mapping[int, Union[int, bytes]]] = None, + socket_type: int = 0, + **kwargs, + ): + self.host = host + self.port = int(port) + self.socket_keepalive = socket_keepalive + self.socket_keepalive_options = socket_keepalive_options or {} + self.socket_type = socket_type + super().__init__(**kwargs) + + def repr_pieces(self): + pieces = [("host", self.host), ("port", self.port), ("db", self.db)] + if self.client_name: + pieces.append(("client_name", self.client_name)) + return pieces + + def _connection_arguments(self) -> Mapping: + return {"host": self.host, "port": self.port} + + async def _connect(self): + """Create a TCP socket connection""" + async with async_timeout(self.socket_connect_timeout): + reader, writer = await asyncio.open_connection( + **self._connection_arguments() + ) + self._reader = reader + self._writer = writer + sock = writer.transport.get_extra_info("socket") + if sock: + sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) + try: + # TCP_KEEPALIVE + if self.socket_keepalive: + sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1) + for k, v in self.socket_keepalive_options.items(): + sock.setsockopt(socket.SOL_TCP, k, v) + + except (OSError, TypeError): + # `socket_keepalive_options` might contain invalid options + # causing an error. Do not leave the connection open. + writer.close() + raise + + def _host_error(self) -> str: + return f"{self.host}:{self.port}" + + def _error_message(self, exception: BaseException) -> str: + # args for socket.error can either be (errno, "message") + # or just "message" + + host_error = self._host_error() + + if not exception.args: + # asyncio has a bug where on Connection reset by peer, the + # exception is not instanciated, so args is empty. This is the + # workaround. + # See: https://github.com/redis/redis-py/issues/2237 + # See: https://github.com/python/cpython/issues/94061 + return f"Error connecting to {host_error}. Connection reset by peer" + elif len(exception.args) == 1: + return f"Error connecting to {host_error}. {exception.args[0]}." + else: + return ( + f"Error {exception.args[0]} connecting to {host_error}. " + f"{exception.args[0]}." + ) + + +class SSLConnection(Connection): + """Manages SSL connections to and from the Redis server(s). + This class extends the Connection class, adding SSL functionality, and making + use of ssl.SSLContext (https://docs.python.org/3/library/ssl.html#ssl.SSLContext) + """ + + def __init__( + self, + ssl_keyfile: Optional[str] = None, + ssl_certfile: Optional[str] = None, + ssl_cert_reqs: str = "required", + ssl_ca_certs: Optional[str] = None, + ssl_ca_data: Optional[str] = None, + ssl_check_hostname: bool = False, + **kwargs, + ): + self.ssl_context: RedisSSLContext = RedisSSLContext( + keyfile=ssl_keyfile, + certfile=ssl_certfile, + cert_reqs=ssl_cert_reqs, + ca_certs=ssl_ca_certs, + ca_data=ssl_ca_data, + check_hostname=ssl_check_hostname, + ) + super().__init__(**kwargs) + + def _connection_arguments(self) -> Mapping: + kwargs = super()._connection_arguments() + kwargs["ssl"] = self.ssl_context.get() + return kwargs + + @property + def keyfile(self): + return self.ssl_context.keyfile + + @property + def certfile(self): + return self.ssl_context.certfile + + @property + def cert_reqs(self): + return self.ssl_context.cert_reqs + + @property + def ca_certs(self): + return self.ssl_context.ca_certs + + @property + def ca_data(self): + return self.ssl_context.ca_data + + @property + def check_hostname(self): + return self.ssl_context.check_hostname + + +class RedisSSLContext: + __slots__ = ( + "keyfile", + "certfile", + "cert_reqs", + "ca_certs", + "ca_data", + "context", + "check_hostname", + ) + + def __init__( + self, + keyfile: Optional[str] = None, + certfile: Optional[str] = None, + cert_reqs: Optional[str] = None, + ca_certs: Optional[str] = None, + ca_data: Optional[str] = None, + check_hostname: bool = False, + ): + self.keyfile = keyfile + self.certfile = certfile + if cert_reqs is None: + self.cert_reqs = ssl.CERT_NONE + elif isinstance(cert_reqs, str): + CERT_REQS = { + "none": ssl.CERT_NONE, + "optional": ssl.CERT_OPTIONAL, + "required": ssl.CERT_REQUIRED, + } + if cert_reqs not in CERT_REQS: + raise RedisError( + f"Invalid SSL Certificate Requirements Flag: {cert_reqs}" + ) + self.cert_reqs = CERT_REQS[cert_reqs] + self.ca_certs = ca_certs + self.ca_data = ca_data + self.check_hostname = check_hostname + self.context: Optional[ssl.SSLContext] = None + + def get(self) -> ssl.SSLContext: + if not self.context: + context = ssl.create_default_context() + context.check_hostname = self.check_hostname + context.verify_mode = self.cert_reqs + if self.certfile and self.keyfile: + context.load_cert_chain(certfile=self.certfile, keyfile=self.keyfile) + if self.ca_certs or self.ca_data: + context.load_verify_locations(cafile=self.ca_certs, cadata=self.ca_data) + self.context = context + return self.context + + +class UnixDomainSocketConnection(AbstractConnection): + "Manages UDS communication to and from a Redis server" + + def __init__(self, *, path: str = "", **kwargs): + self.path = path + super().__init__(**kwargs) + + def repr_pieces(self) -> Iterable[Tuple[str, Union[str, int]]]: + pieces = [("path", self.path), ("db", self.db)] + if self.client_name: + pieces.append(("client_name", self.client_name)) + return pieces + + async def _connect(self): + async with async_timeout(self.socket_connect_timeout): + reader, writer = await asyncio.open_unix_connection(path=self.path) + self._reader = reader + self._writer = writer + await self.on_connect() + + def _host_error(self) -> str: + return self.path + + def _error_message(self, exception: BaseException) -> str: + # args for socket.error can either be (errno, "message") + # or just "message" + host_error = self._host_error() + if len(exception.args) == 1: + return ( + f"Error connecting to unix socket: {host_error}. {exception.args[0]}." + ) + else: + return ( + f"Error {exception.args[0]} connecting to unix socket: " + f"{host_error}. {exception.args[1]}." + ) + + +FALSE_STRINGS = ("0", "F", "FALSE", "N", "NO") + + +def to_bool(value) -> Optional[bool]: + if value is None or value == "": + return None + if isinstance(value, str) and value.upper() in FALSE_STRINGS: + return False + return bool(value) + + +URL_QUERY_ARGUMENT_PARSERS: Mapping[str, Callable[..., object]] = MappingProxyType( + { + "db": int, + "socket_timeout": float, + "socket_connect_timeout": float, + "socket_keepalive": to_bool, + "retry_on_timeout": to_bool, + "max_connections": int, + "health_check_interval": int, + "ssl_check_hostname": to_bool, + } +) + + +class ConnectKwargs(TypedDict, total=False): + username: str + password: str + connection_class: Type[AbstractConnection] + host: str + port: int + db: int + path: str + + +def parse_url(url: str) -> ConnectKwargs: + parsed: ParseResult = urlparse(url) + kwargs: ConnectKwargs = {} + + for name, value_list in parse_qs(parsed.query).items(): + if value_list and len(value_list) > 0: + value = unquote(value_list[0]) + parser = URL_QUERY_ARGUMENT_PARSERS.get(name) + if parser: + try: + kwargs[name] = parser(value) + except (TypeError, ValueError): + raise ValueError(f"Invalid value for `{name}` in connection URL.") + else: + kwargs[name] = value + + if parsed.username: + kwargs["username"] = unquote(parsed.username) + if parsed.password: + kwargs["password"] = unquote(parsed.password) + + # We only support redis://, rediss:// and unix:// schemes. + if parsed.scheme == "unix": + if parsed.path: + kwargs["path"] = unquote(parsed.path) + kwargs["connection_class"] = UnixDomainSocketConnection + + elif parsed.scheme in ("redis", "rediss"): + if parsed.hostname: + kwargs["host"] = unquote(parsed.hostname) + if parsed.port: + kwargs["port"] = int(parsed.port) + + # If there's a path argument, use it as the db argument if a + # querystring value wasn't specified + if parsed.path and "db" not in kwargs: + try: + kwargs["db"] = int(unquote(parsed.path).replace("/", "")) + except (AttributeError, ValueError): + pass + + if parsed.scheme == "rediss": + kwargs["connection_class"] = SSLConnection + else: + valid_schemes = "redis://, rediss://, unix://" + raise ValueError( + f"Redis URL must specify one of the following schemes ({valid_schemes})" + ) + + return kwargs + + +_CP = TypeVar("_CP", bound="ConnectionPool") + + +class ConnectionPool: + """ + Create a connection pool. ``If max_connections`` is set, then this + object raises :py:class:`~redis.ConnectionError` when the pool's + limit is reached. + + By default, TCP connections are created unless ``connection_class`` + is specified. Use :py:class:`~redis.UnixDomainSocketConnection` for + unix sockets. + + Any additional keyword arguments are passed to the constructor of + ``connection_class``. + """ + + @classmethod + def from_url(cls: Type[_CP], url: str, **kwargs) -> _CP: + """ + Return a connection pool configured from the given URL. + + For example:: + + redis://[[username]:[password]]@localhost:6379/0 + rediss://[[username]:[password]]@localhost:6379/0 + unix://[username@]/path/to/socket.sock?db=0[&password=password] + + Three URL schemes are supported: + + - `redis://` creates a TCP socket connection. See more at: + + - `rediss://` creates a SSL wrapped TCP socket connection. See more at: + + - ``unix://``: creates a Unix Domain Socket connection. + + The username, password, hostname, path and all querystring values + are passed through urllib.parse.unquote in order to replace any + percent-encoded values with their corresponding characters. + + There are several ways to specify a database number. The first value + found will be used: + + 1. A ``db`` querystring option, e.g. redis://localhost?db=0 + + 2. If using the redis:// or rediss:// schemes, the path argument + of the url, e.g. redis://localhost/0 + + 3. A ``db`` keyword argument to this function. + + If none of these options are specified, the default db=0 is used. + + All querystring options are cast to their appropriate Python types. + Boolean arguments can be specified with string values "True"/"False" + or "Yes"/"No". Values that cannot be properly cast cause a + ``ValueError`` to be raised. Once parsed, the querystring arguments + and keyword arguments are passed to the ``ConnectionPool``'s + class initializer. In the case of conflicting arguments, querystring + arguments always win. + """ + url_options = parse_url(url) + kwargs.update(url_options) + return cls(**kwargs) + + def __init__( + self, + connection_class: Type[AbstractConnection] = Connection, + max_connections: Optional[int] = None, + **connection_kwargs, + ): + max_connections = max_connections or 2**31 + if not isinstance(max_connections, int) or max_connections < 0: + raise ValueError('"max_connections" must be a positive integer') + + self.connection_class = connection_class + self.connection_kwargs = connection_kwargs + self.max_connections = max_connections + + self._available_connections: List[AbstractConnection] = [] + self._in_use_connections: Set[AbstractConnection] = set() + self.encoder_class = self.connection_kwargs.get("encoder_class", Encoder) + + def __repr__(self): + return ( + f"{self.__class__.__name__}" + f"<{self.connection_class(**self.connection_kwargs)!r}>" + ) + + def reset(self): + self._available_connections = [] + self._in_use_connections = set() + + def can_get_connection(self) -> bool: + """Return True if a connection can be retrieved from the pool.""" + return ( + self._available_connections + or len(self._in_use_connections) < self.max_connections + ) + + async def get_connection(self, command_name, *keys, **options): + """Get a connection from the pool""" + try: + connection = self._available_connections.pop() + except IndexError: + if len(self._in_use_connections) >= self.max_connections: + raise ConnectionError("Too many connections") from None + connection = self.make_connection() + self._in_use_connections.add(connection) + + try: + await self.ensure_connection(connection) + except BaseException: + await self.release(connection) + raise + + return connection + + def get_encoder(self): + """Return an encoder based on encoding settings""" + kwargs = self.connection_kwargs + return self.encoder_class( + encoding=kwargs.get("encoding", "utf-8"), + encoding_errors=kwargs.get("encoding_errors", "strict"), + decode_responses=kwargs.get("decode_responses", False), + ) + + def make_connection(self): + """Create a new connection. Can be overridden by child classes.""" + return self.connection_class(**self.connection_kwargs) + + async def ensure_connection(self, connection: AbstractConnection): + """Ensure that the connection object is connected and valid""" + await connection.connect() + # connections that the pool provides should be ready to send + # a command. if not, the connection was either returned to the + # pool before all data has been read or the socket has been + # closed. either way, reconnect and verify everything is good. + try: + if await connection.can_read_destructive(): + raise ConnectionError("Connection has data") from None + except (ConnectionError, OSError): + await connection.disconnect() + await connection.connect() + if await connection.can_read_destructive(): + raise ConnectionError("Connection not ready") from None + + async def release(self, connection: AbstractConnection): + """Releases the connection back to the pool""" + # Connections should always be returned to the correct pool, + # not doing so is an error that will cause an exception here. + self._in_use_connections.remove(connection) + self._available_connections.append(connection) + + async def disconnect(self, inuse_connections: bool = True): + """ + Disconnects connections in the pool + + If ``inuse_connections`` is True, disconnect connections that are + current in use, potentially by other tasks. Otherwise only disconnect + connections that are idle in the pool. + """ + if inuse_connections: + connections: Iterable[AbstractConnection] = chain( + self._available_connections, self._in_use_connections + ) + else: + connections = self._available_connections + resp = await asyncio.gather( + *(connection.disconnect() for connection in connections), + return_exceptions=True, + ) + exc = next((r for r in resp if isinstance(r, BaseException)), None) + if exc: + raise exc + + async def aclose(self) -> None: + """Close the pool, disconnecting all connections""" + await self.disconnect() + + def set_retry(self, retry: "Retry") -> None: + for conn in self._available_connections: + conn.retry = retry + for conn in self._in_use_connections: + conn.retry = retry + + +class BlockingConnectionPool(ConnectionPool): + """ + A blocking connection pool:: + + >>> from redis.asyncio import Redis, BlockingConnectionPool + >>> client = Redis.from_pool(BlockingConnectionPool()) + + It performs the same function as the default + :py:class:`~redis.asyncio.ConnectionPool` implementation, in that, + it maintains a pool of reusable connections that can be shared by + multiple async redis clients. + + The difference is that, in the event that a client tries to get a + connection from the pool when all of connections are in use, rather than + raising a :py:class:`~redis.ConnectionError` (as the default + :py:class:`~redis.asyncio.ConnectionPool` implementation does), it + makes blocks the current `Task` for a specified number of seconds until + a connection becomes available. + + Use ``max_connections`` to increase / decrease the pool size:: + + >>> pool = BlockingConnectionPool(max_connections=10) + + Use ``timeout`` to tell it either how many seconds to wait for a connection + to become available, or to block forever: + + >>> # Block forever. + >>> pool = BlockingConnectionPool(timeout=None) + + >>> # Raise a ``ConnectionError`` after five seconds if a connection is + >>> # not available. + >>> pool = BlockingConnectionPool(timeout=5) + """ + + def __init__( + self, + max_connections: int = 50, + timeout: Optional[int] = 20, + connection_class: Type[AbstractConnection] = Connection, + queue_class: Type[asyncio.Queue] = asyncio.LifoQueue, # deprecated + **connection_kwargs, + ): + super().__init__( + connection_class=connection_class, + max_connections=max_connections, + **connection_kwargs, + ) + self._condition = asyncio.Condition() + self.timeout = timeout + + async def get_connection(self, command_name, *keys, **options): + """Gets a connection from the pool, blocking until one is available""" + try: + async with async_timeout(self.timeout): + async with self._condition: + await self._condition.wait_for(self.can_get_connection) + return await super().get_connection(command_name, *keys, **options) + except asyncio.TimeoutError as err: + raise ConnectionError("No connection available.") from err + + async def release(self, connection: AbstractConnection): + """Releases the connection back to the pool.""" + async with self._condition: + await super().release(connection) + self._condition.notify() diff --git a/.venv/Lib/site-packages/redis/asyncio/lock.py b/.venv/Lib/site-packages/redis/asyncio/lock.py new file mode 100644 index 00000000..e1d11a88 --- /dev/null +++ b/.venv/Lib/site-packages/redis/asyncio/lock.py @@ -0,0 +1,313 @@ +import asyncio +import threading +import uuid +from types import SimpleNamespace +from typing import TYPE_CHECKING, Awaitable, Optional, Union + +from redis.exceptions import LockError, LockNotOwnedError + +if TYPE_CHECKING: + from redis.asyncio import Redis, RedisCluster + + +class Lock: + """ + A shared, distributed Lock. Using Redis for locking allows the Lock + to be shared across processes and/or machines. + + It's left to the user to resolve deadlock issues and make sure + multiple clients play nicely together. + """ + + lua_release = None + lua_extend = None + lua_reacquire = None + + # KEYS[1] - lock name + # ARGV[1] - token + # return 1 if the lock was released, otherwise 0 + LUA_RELEASE_SCRIPT = """ + local token = redis.call('get', KEYS[1]) + if not token or token ~= ARGV[1] then + return 0 + end + redis.call('del', KEYS[1]) + return 1 + """ + + # KEYS[1] - lock name + # ARGV[1] - token + # ARGV[2] - additional milliseconds + # ARGV[3] - "0" if the additional time should be added to the lock's + # existing ttl or "1" if the existing ttl should be replaced + # return 1 if the locks time was extended, otherwise 0 + LUA_EXTEND_SCRIPT = """ + local token = redis.call('get', KEYS[1]) + if not token or token ~= ARGV[1] then + return 0 + end + local expiration = redis.call('pttl', KEYS[1]) + if not expiration then + expiration = 0 + end + if expiration < 0 then + return 0 + end + + local newttl = ARGV[2] + if ARGV[3] == "0" then + newttl = ARGV[2] + expiration + end + redis.call('pexpire', KEYS[1], newttl) + return 1 + """ + + # KEYS[1] - lock name + # ARGV[1] - token + # ARGV[2] - milliseconds + # return 1 if the locks time was reacquired, otherwise 0 + LUA_REACQUIRE_SCRIPT = """ + local token = redis.call('get', KEYS[1]) + if not token or token ~= ARGV[1] then + return 0 + end + redis.call('pexpire', KEYS[1], ARGV[2]) + return 1 + """ + + def __init__( + self, + redis: Union["Redis", "RedisCluster"], + name: Union[str, bytes, memoryview], + timeout: Optional[float] = None, + sleep: float = 0.1, + blocking: bool = True, + blocking_timeout: Optional[float] = None, + thread_local: bool = True, + ): + """ + Create a new Lock instance named ``name`` using the Redis client + supplied by ``redis``. + + ``timeout`` indicates a maximum life for the lock in seconds. + By default, it will remain locked until release() is called. + ``timeout`` can be specified as a float or integer, both representing + the number of seconds to wait. + + ``sleep`` indicates the amount of time to sleep in seconds per loop + iteration when the lock is in blocking mode and another client is + currently holding the lock. + + ``blocking`` indicates whether calling ``acquire`` should block until + the lock has been acquired or to fail immediately, causing ``acquire`` + to return False and the lock not being acquired. Defaults to True. + Note this value can be overridden by passing a ``blocking`` + argument to ``acquire``. + + ``blocking_timeout`` indicates the maximum amount of time in seconds to + spend trying to acquire the lock. A value of ``None`` indicates + continue trying forever. ``blocking_timeout`` can be specified as a + float or integer, both representing the number of seconds to wait. + + ``thread_local`` indicates whether the lock token is placed in + thread-local storage. By default, the token is placed in thread local + storage so that a thread only sees its token, not a token set by + another thread. Consider the following timeline: + + time: 0, thread-1 acquires `my-lock`, with a timeout of 5 seconds. + thread-1 sets the token to "abc" + time: 1, thread-2 blocks trying to acquire `my-lock` using the + Lock instance. + time: 5, thread-1 has not yet completed. redis expires the lock + key. + time: 5, thread-2 acquired `my-lock` now that it's available. + thread-2 sets the token to "xyz" + time: 6, thread-1 finishes its work and calls release(). if the + token is *not* stored in thread local storage, then + thread-1 would see the token value as "xyz" and would be + able to successfully release the thread-2's lock. + + In some use cases it's necessary to disable thread local storage. For + example, if you have code where one thread acquires a lock and passes + that lock instance to a worker thread to release later. If thread + local storage isn't disabled in this case, the worker thread won't see + the token set by the thread that acquired the lock. Our assumption + is that these cases aren't common and as such default to using + thread local storage. + """ + self.redis = redis + self.name = name + self.timeout = timeout + self.sleep = sleep + self.blocking = blocking + self.blocking_timeout = blocking_timeout + self.thread_local = bool(thread_local) + self.local = threading.local() if self.thread_local else SimpleNamespace() + self.local.token = None + self.register_scripts() + + def register_scripts(self): + cls = self.__class__ + client = self.redis + if cls.lua_release is None: + cls.lua_release = client.register_script(cls.LUA_RELEASE_SCRIPT) + if cls.lua_extend is None: + cls.lua_extend = client.register_script(cls.LUA_EXTEND_SCRIPT) + if cls.lua_reacquire is None: + cls.lua_reacquire = client.register_script(cls.LUA_REACQUIRE_SCRIPT) + + async def __aenter__(self): + if await self.acquire(): + return self + raise LockError("Unable to acquire lock within the time specified") + + async def __aexit__(self, exc_type, exc_value, traceback): + await self.release() + + async def acquire( + self, + blocking: Optional[bool] = None, + blocking_timeout: Optional[float] = None, + token: Optional[Union[str, bytes]] = None, + ): + """ + Use Redis to hold a shared, distributed lock named ``name``. + Returns True once the lock is acquired. + + If ``blocking`` is False, always return immediately. If the lock + was acquired, return True, otherwise return False. + + ``blocking_timeout`` specifies the maximum number of seconds to + wait trying to acquire the lock. + + ``token`` specifies the token value to be used. If provided, token + must be a bytes object or a string that can be encoded to a bytes + object with the default encoding. If a token isn't specified, a UUID + will be generated. + """ + sleep = self.sleep + if token is None: + token = uuid.uuid1().hex.encode() + else: + try: + encoder = self.redis.connection_pool.get_encoder() + except AttributeError: + # Cluster + encoder = self.redis.get_encoder() + token = encoder.encode(token) + if blocking is None: + blocking = self.blocking + if blocking_timeout is None: + blocking_timeout = self.blocking_timeout + stop_trying_at = None + if blocking_timeout is not None: + stop_trying_at = asyncio.get_running_loop().time() + blocking_timeout + while True: + if await self.do_acquire(token): + self.local.token = token + return True + if not blocking: + return False + next_try_at = asyncio.get_running_loop().time() + sleep + if stop_trying_at is not None and next_try_at > stop_trying_at: + return False + await asyncio.sleep(sleep) + + async def do_acquire(self, token: Union[str, bytes]) -> bool: + if self.timeout: + # convert to milliseconds + timeout = int(self.timeout * 1000) + else: + timeout = None + if await self.redis.set(self.name, token, nx=True, px=timeout): + return True + return False + + async def locked(self) -> bool: + """ + Returns True if this key is locked by any process, otherwise False. + """ + return await self.redis.get(self.name) is not None + + async def owned(self) -> bool: + """ + Returns True if this key is locked by this lock, otherwise False. + """ + stored_token = await self.redis.get(self.name) + # need to always compare bytes to bytes + # TODO: this can be simplified when the context manager is finished + if stored_token and not isinstance(stored_token, bytes): + try: + encoder = self.redis.connection_pool.get_encoder() + except AttributeError: + # Cluster + encoder = self.redis.get_encoder() + stored_token = encoder.encode(stored_token) + return self.local.token is not None and stored_token == self.local.token + + def release(self) -> Awaitable[None]: + """Releases the already acquired lock""" + expected_token = self.local.token + if expected_token is None: + raise LockError("Cannot release an unlocked lock") + self.local.token = None + return self.do_release(expected_token) + + async def do_release(self, expected_token: bytes) -> None: + if not bool( + await self.lua_release( + keys=[self.name], args=[expected_token], client=self.redis + ) + ): + raise LockNotOwnedError("Cannot release a lock that's no longer owned") + + def extend( + self, additional_time: float, replace_ttl: bool = False + ) -> Awaitable[bool]: + """ + Adds more time to an already acquired lock. + + ``additional_time`` can be specified as an integer or a float, both + representing the number of seconds to add. + + ``replace_ttl`` if False (the default), add `additional_time` to + the lock's existing ttl. If True, replace the lock's ttl with + `additional_time`. + """ + if self.local.token is None: + raise LockError("Cannot extend an unlocked lock") + if self.timeout is None: + raise LockError("Cannot extend a lock with no timeout") + return self.do_extend(additional_time, replace_ttl) + + async def do_extend(self, additional_time, replace_ttl) -> bool: + additional_time = int(additional_time * 1000) + if not bool( + await self.lua_extend( + keys=[self.name], + args=[self.local.token, additional_time, replace_ttl and "1" or "0"], + client=self.redis, + ) + ): + raise LockNotOwnedError("Cannot extend a lock that's no longer owned") + return True + + def reacquire(self) -> Awaitable[bool]: + """ + Resets a TTL of an already acquired lock back to a timeout value. + """ + if self.local.token is None: + raise LockError("Cannot reacquire an unlocked lock") + if self.timeout is None: + raise LockError("Cannot reacquire a lock with no timeout") + return self.do_reacquire() + + async def do_reacquire(self) -> bool: + timeout = int(self.timeout * 1000) + if not bool( + await self.lua_reacquire( + keys=[self.name], args=[self.local.token, timeout], client=self.redis + ) + ): + raise LockNotOwnedError("Cannot reacquire a lock that's no longer owned") + return True diff --git a/.venv/Lib/site-packages/redis/asyncio/retry.py b/.venv/Lib/site-packages/redis/asyncio/retry.py new file mode 100644 index 00000000..7c5e3b0e --- /dev/null +++ b/.venv/Lib/site-packages/redis/asyncio/retry.py @@ -0,0 +1,67 @@ +from asyncio import sleep +from typing import TYPE_CHECKING, Any, Awaitable, Callable, Tuple, Type, TypeVar + +from redis.exceptions import ConnectionError, RedisError, TimeoutError + +if TYPE_CHECKING: + from redis.backoff import AbstractBackoff + + +T = TypeVar("T") + + +class Retry: + """Retry a specific number of times after a failure""" + + __slots__ = "_backoff", "_retries", "_supported_errors" + + def __init__( + self, + backoff: "AbstractBackoff", + retries: int, + supported_errors: Tuple[Type[RedisError], ...] = ( + ConnectionError, + TimeoutError, + ), + ): + """ + Initialize a `Retry` object with a `Backoff` object + that retries a maximum of `retries` times. + `retries` can be negative to retry forever. + You can specify the types of supported errors which trigger + a retry with the `supported_errors` parameter. + """ + self._backoff = backoff + self._retries = retries + self._supported_errors = supported_errors + + def update_supported_errors(self, specified_errors: list): + """ + Updates the supported errors with the specified error types + """ + self._supported_errors = tuple( + set(self._supported_errors + tuple(specified_errors)) + ) + + async def call_with_retry( + self, do: Callable[[], Awaitable[T]], fail: Callable[[RedisError], Any] + ) -> T: + """ + Execute an operation that might fail and returns its result, or + raise the exception that was thrown depending on the `Backoff` object. + `do`: the operation to call. Expects no argument. + `fail`: the failure handler, expects the last error that was thrown + """ + self._backoff.reset() + failures = 0 + while True: + try: + return await do() + except self._supported_errors as error: + failures += 1 + await fail(error) + if self._retries >= 0 and failures > self._retries: + raise error + backoff = self._backoff.compute(failures) + if backoff > 0: + await sleep(backoff) diff --git a/.venv/Lib/site-packages/redis/asyncio/sentinel.py b/.venv/Lib/site-packages/redis/asyncio/sentinel.py new file mode 100644 index 00000000..6834fb19 --- /dev/null +++ b/.venv/Lib/site-packages/redis/asyncio/sentinel.py @@ -0,0 +1,375 @@ +import asyncio +import random +import weakref +from typing import AsyncIterator, Iterable, Mapping, Optional, Sequence, Tuple, Type + +from redis.asyncio.client import Redis +from redis.asyncio.connection import ( + Connection, + ConnectionPool, + EncodableT, + SSLConnection, +) +from redis.commands import AsyncSentinelCommands +from redis.exceptions import ConnectionError, ReadOnlyError, ResponseError, TimeoutError +from redis.utils import str_if_bytes + + +class MasterNotFoundError(ConnectionError): + pass + + +class SlaveNotFoundError(ConnectionError): + pass + + +class SentinelManagedConnection(Connection): + def __init__(self, **kwargs): + self.connection_pool = kwargs.pop("connection_pool") + super().__init__(**kwargs) + + def __repr__(self): + pool = self.connection_pool + s = f"{self.__class__.__name__}" + + async def connect_to(self, address): + self.host, self.port = address + await super().connect() + if self.connection_pool.check_connection: + await self.send_command("PING") + if str_if_bytes(await self.read_response()) != "PONG": + raise ConnectionError("PING failed") + + async def _connect_retry(self): + if self._reader: + return # already connected + if self.connection_pool.is_master: + await self.connect_to(await self.connection_pool.get_master_address()) + else: + async for slave in self.connection_pool.rotate_slaves(): + try: + return await self.connect_to(slave) + except ConnectionError: + continue + raise SlaveNotFoundError # Never be here + + async def connect(self): + return await self.retry.call_with_retry( + self._connect_retry, + lambda error: asyncio.sleep(0), + ) + + async def read_response( + self, + disable_decoding: bool = False, + timeout: Optional[float] = None, + *, + disconnect_on_error: Optional[float] = True, + push_request: Optional[bool] = False, + ): + try: + return await super().read_response( + disable_decoding=disable_decoding, + timeout=timeout, + disconnect_on_error=disconnect_on_error, + push_request=push_request, + ) + except ReadOnlyError: + if self.connection_pool.is_master: + # When talking to a master, a ReadOnlyError when likely + # indicates that the previous master that we're still connected + # to has been demoted to a slave and there's a new master. + # calling disconnect will force the connection to re-query + # sentinel during the next connect() attempt. + await self.disconnect() + raise ConnectionError("The previous master is now a slave") + raise + + +class SentinelManagedSSLConnection(SentinelManagedConnection, SSLConnection): + pass + + +class SentinelConnectionPool(ConnectionPool): + """ + Sentinel backed connection pool. + + If ``check_connection`` flag is set to True, SentinelManagedConnection + sends a PING command right after establishing the connection. + """ + + def __init__(self, service_name, sentinel_manager, **kwargs): + kwargs["connection_class"] = kwargs.get( + "connection_class", + SentinelManagedSSLConnection + if kwargs.pop("ssl", False) + else SentinelManagedConnection, + ) + self.is_master = kwargs.pop("is_master", True) + self.check_connection = kwargs.pop("check_connection", False) + super().__init__(**kwargs) + self.connection_kwargs["connection_pool"] = weakref.proxy(self) + self.service_name = service_name + self.sentinel_manager = sentinel_manager + self.master_address = None + self.slave_rr_counter = None + + def __repr__(self): + return ( + f"{self.__class__.__name__}" + f"" + ) + + def reset(self): + super().reset() + self.master_address = None + self.slave_rr_counter = None + + def owns_connection(self, connection: Connection): + check = not self.is_master or ( + self.is_master and self.master_address == (connection.host, connection.port) + ) + return check and super().owns_connection(connection) + + async def get_master_address(self): + master_address = await self.sentinel_manager.discover_master(self.service_name) + if self.is_master: + if self.master_address != master_address: + self.master_address = master_address + # disconnect any idle connections so that they reconnect + # to the new master the next time that they are used. + await self.disconnect(inuse_connections=False) + return master_address + + async def rotate_slaves(self) -> AsyncIterator: + """Round-robin slave balancer""" + slaves = await self.sentinel_manager.discover_slaves(self.service_name) + if slaves: + if self.slave_rr_counter is None: + self.slave_rr_counter = random.randint(0, len(slaves) - 1) + for _ in range(len(slaves)): + self.slave_rr_counter = (self.slave_rr_counter + 1) % len(slaves) + slave = slaves[self.slave_rr_counter] + yield slave + # Fallback to the master connection + try: + yield await self.get_master_address() + except MasterNotFoundError: + pass + raise SlaveNotFoundError(f"No slave found for {self.service_name!r}") + + +class Sentinel(AsyncSentinelCommands): + """ + Redis Sentinel cluster client + + >>> from redis.sentinel import Sentinel + >>> sentinel = Sentinel([('localhost', 26379)], socket_timeout=0.1) + >>> master = sentinel.master_for('mymaster', socket_timeout=0.1) + >>> await master.set('foo', 'bar') + >>> slave = sentinel.slave_for('mymaster', socket_timeout=0.1) + >>> await slave.get('foo') + b'bar' + + ``sentinels`` is a list of sentinel nodes. Each node is represented by + a pair (hostname, port). + + ``min_other_sentinels`` defined a minimum number of peers for a sentinel. + When querying a sentinel, if it doesn't meet this threshold, responses + from that sentinel won't be considered valid. + + ``sentinel_kwargs`` is a dictionary of connection arguments used when + connecting to sentinel instances. Any argument that can be passed to + a normal Redis connection can be specified here. If ``sentinel_kwargs`` is + not specified, any socket_timeout and socket_keepalive options specified + in ``connection_kwargs`` will be used. + + ``connection_kwargs`` are keyword arguments that will be used when + establishing a connection to a Redis server. + """ + + def __init__( + self, + sentinels, + min_other_sentinels=0, + sentinel_kwargs=None, + **connection_kwargs, + ): + # if sentinel_kwargs isn't defined, use the socket_* options from + # connection_kwargs + if sentinel_kwargs is None: + sentinel_kwargs = { + k: v for k, v in connection_kwargs.items() if k.startswith("socket_") + } + self.sentinel_kwargs = sentinel_kwargs + + self.sentinels = [ + Redis(host=hostname, port=port, **self.sentinel_kwargs) + for hostname, port in sentinels + ] + self.min_other_sentinels = min_other_sentinels + self.connection_kwargs = connection_kwargs + + async def execute_command(self, *args, **kwargs): + """ + Execute Sentinel command in sentinel nodes. + once - If set to True, then execute the resulting command on a single + node at random, rather than across the entire sentinel cluster. + """ + once = bool(kwargs.get("once", False)) + if "once" in kwargs.keys(): + kwargs.pop("once") + + if once: + await random.choice(self.sentinels).execute_command(*args, **kwargs) + else: + tasks = [ + asyncio.Task(sentinel.execute_command(*args, **kwargs)) + for sentinel in self.sentinels + ] + await asyncio.gather(*tasks) + return True + + def __repr__(self): + sentinel_addresses = [] + for sentinel in self.sentinels: + sentinel_addresses.append( + f"{sentinel.connection_pool.connection_kwargs['host']}:" + f"{sentinel.connection_pool.connection_kwargs['port']}" + ) + return f"{self.__class__.__name__}" + + def check_master_state(self, state: dict, service_name: str) -> bool: + if not state["is_master"] or state["is_sdown"] or state["is_odown"]: + return False + # Check if our sentinel doesn't see other nodes + if state["num-other-sentinels"] < self.min_other_sentinels: + return False + return True + + async def discover_master(self, service_name: str): + """ + Asks sentinel servers for the Redis master's address corresponding + to the service labeled ``service_name``. + + Returns a pair (address, port) or raises MasterNotFoundError if no + master is found. + """ + collected_errors = list() + for sentinel_no, sentinel in enumerate(self.sentinels): + try: + masters = await sentinel.sentinel_masters() + except (ConnectionError, TimeoutError) as e: + collected_errors.append(f"{sentinel} - {e!r}") + continue + state = masters.get(service_name) + if state and self.check_master_state(state, service_name): + # Put this sentinel at the top of the list + self.sentinels[0], self.sentinels[sentinel_no] = ( + sentinel, + self.sentinels[0], + ) + return state["ip"], state["port"] + + error_info = "" + if len(collected_errors) > 0: + error_info = f" : {', '.join(collected_errors)}" + raise MasterNotFoundError(f"No master found for {service_name!r}{error_info}") + + def filter_slaves( + self, slaves: Iterable[Mapping] + ) -> Sequence[Tuple[EncodableT, EncodableT]]: + """Remove slaves that are in an ODOWN or SDOWN state""" + slaves_alive = [] + for slave in slaves: + if slave["is_odown"] or slave["is_sdown"]: + continue + slaves_alive.append((slave["ip"], slave["port"])) + return slaves_alive + + async def discover_slaves( + self, service_name: str + ) -> Sequence[Tuple[EncodableT, EncodableT]]: + """Returns a list of alive slaves for service ``service_name``""" + for sentinel in self.sentinels: + try: + slaves = await sentinel.sentinel_slaves(service_name) + except (ConnectionError, ResponseError, TimeoutError): + continue + slaves = self.filter_slaves(slaves) + if slaves: + return slaves + return [] + + def master_for( + self, + service_name: str, + redis_class: Type[Redis] = Redis, + connection_pool_class: Type[SentinelConnectionPool] = SentinelConnectionPool, + **kwargs, + ): + """ + Returns a redis client instance for the ``service_name`` master. + + A :py:class:`~redis.sentinel.SentinelConnectionPool` class is + used to retrieve the master's address before establishing a new + connection. + + NOTE: If the master's address has changed, any cached connections to + the old master are closed. + + By default clients will be a :py:class:`~redis.Redis` instance. + Specify a different class to the ``redis_class`` argument if you + desire something different. + + The ``connection_pool_class`` specifies the connection pool to + use. The :py:class:`~redis.sentinel.SentinelConnectionPool` + will be used by default. + + All other keyword arguments are merged with any connection_kwargs + passed to this class and passed to the connection pool as keyword + arguments to be used to initialize Redis connections. + """ + kwargs["is_master"] = True + connection_kwargs = dict(self.connection_kwargs) + connection_kwargs.update(kwargs) + + connection_pool = connection_pool_class(service_name, self, **connection_kwargs) + # The Redis object "owns" the pool + return redis_class.from_pool(connection_pool) + + def slave_for( + self, + service_name: str, + redis_class: Type[Redis] = Redis, + connection_pool_class: Type[SentinelConnectionPool] = SentinelConnectionPool, + **kwargs, + ): + """ + Returns redis client instance for the ``service_name`` slave(s). + + A SentinelConnectionPool class is used to retrieve the slave's + address before establishing a new connection. + + By default clients will be a :py:class:`~redis.Redis` instance. + Specify a different class to the ``redis_class`` argument if you + desire something different. + + The ``connection_pool_class`` specifies the connection pool to use. + The SentinelConnectionPool will be used by default. + + All other keyword arguments are merged with any connection_kwargs + passed to this class and passed to the connection pool as keyword + arguments to be used to initialize Redis connections. + """ + kwargs["is_master"] = False + connection_kwargs = dict(self.connection_kwargs) + connection_kwargs.update(kwargs) + + connection_pool = connection_pool_class(service_name, self, **connection_kwargs) + # The Redis object "owns" the pool + return redis_class.from_pool(connection_pool) diff --git a/.venv/Lib/site-packages/redis/asyncio/utils.py b/.venv/Lib/site-packages/redis/asyncio/utils.py new file mode 100644 index 00000000..5a55b36a --- /dev/null +++ b/.venv/Lib/site-packages/redis/asyncio/utils.py @@ -0,0 +1,28 @@ +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from redis.asyncio.client import Pipeline, Redis + + +def from_url(url, **kwargs): + """ + Returns an active Redis client generated from the given database URL. + + Will attempt to extract the database id from the path url fragment, if + none is provided. + """ + from redis.asyncio.client import Redis + + return Redis.from_url(url, **kwargs) + + +class pipeline: + def __init__(self, redis_obj: "Redis"): + self.p: "Pipeline" = redis_obj.pipeline() + + async def __aenter__(self) -> "Pipeline": + return self.p + + async def __aexit__(self, exc_type, exc_value, traceback): + await self.p.execute() + del self.p diff --git a/.venv/Lib/site-packages/redis/backoff.py b/.venv/Lib/site-packages/redis/backoff.py new file mode 100644 index 00000000..c62e760b --- /dev/null +++ b/.venv/Lib/site-packages/redis/backoff.py @@ -0,0 +1,114 @@ +import random +from abc import ABC, abstractmethod + +# Maximum backoff between each retry in seconds +DEFAULT_CAP = 0.512 +# Minimum backoff between each retry in seconds +DEFAULT_BASE = 0.008 + + +class AbstractBackoff(ABC): + """Backoff interface""" + + def reset(self): + """ + Reset internal state before an operation. + `reset` is called once at the beginning of + every call to `Retry.call_with_retry` + """ + pass + + @abstractmethod + def compute(self, failures): + """Compute backoff in seconds upon failure""" + pass + + +class ConstantBackoff(AbstractBackoff): + """Constant backoff upon failure""" + + def __init__(self, backoff): + """`backoff`: backoff time in seconds""" + self._backoff = backoff + + def compute(self, failures): + return self._backoff + + +class NoBackoff(ConstantBackoff): + """No backoff upon failure""" + + def __init__(self): + super().__init__(0) + + +class ExponentialBackoff(AbstractBackoff): + """Exponential backoff upon failure""" + + def __init__(self, cap=DEFAULT_CAP, base=DEFAULT_BASE): + """ + `cap`: maximum backoff time in seconds + `base`: base backoff time in seconds + """ + self._cap = cap + self._base = base + + def compute(self, failures): + return min(self._cap, self._base * 2**failures) + + +class FullJitterBackoff(AbstractBackoff): + """Full jitter backoff upon failure""" + + def __init__(self, cap=DEFAULT_CAP, base=DEFAULT_BASE): + """ + `cap`: maximum backoff time in seconds + `base`: base backoff time in seconds + """ + self._cap = cap + self._base = base + + def compute(self, failures): + return random.uniform(0, min(self._cap, self._base * 2**failures)) + + +class EqualJitterBackoff(AbstractBackoff): + """Equal jitter backoff upon failure""" + + def __init__(self, cap=DEFAULT_CAP, base=DEFAULT_BASE): + """ + `cap`: maximum backoff time in seconds + `base`: base backoff time in seconds + """ + self._cap = cap + self._base = base + + def compute(self, failures): + temp = min(self._cap, self._base * 2**failures) / 2 + return temp + random.uniform(0, temp) + + +class DecorrelatedJitterBackoff(AbstractBackoff): + """Decorrelated jitter backoff upon failure""" + + def __init__(self, cap=DEFAULT_CAP, base=DEFAULT_BASE): + """ + `cap`: maximum backoff time in seconds + `base`: base backoff time in seconds + """ + self._cap = cap + self._base = base + self._previous_backoff = 0 + + def reset(self): + self._previous_backoff = 0 + + def compute(self, failures): + max_backoff = max(self._base, self._previous_backoff * 3) + temp = random.uniform(self._base, max_backoff) + self._previous_backoff = min(self._cap, temp) + return self._previous_backoff + + +def default_backoff(): + return EqualJitterBackoff() diff --git a/.venv/Lib/site-packages/redis/client.py b/.venv/Lib/site-packages/redis/client.py new file mode 100644 index 00000000..49231435 --- /dev/null +++ b/.venv/Lib/site-packages/redis/client.py @@ -0,0 +1,1500 @@ +import copy +import re +import threading +import time +import warnings +from itertools import chain +from typing import Any, Callable, Dict, List, Optional, Type, Union + +from redis._parsers.encoders import Encoder +from redis._parsers.helpers import ( + _RedisCallbacks, + _RedisCallbacksRESP2, + _RedisCallbacksRESP3, + bool_ok, +) +from redis.commands import ( + CoreCommands, + RedisModuleCommands, + SentinelCommands, + list_or_args, +) +from redis.connection import ConnectionPool, SSLConnection, UnixDomainSocketConnection +from redis.credentials import CredentialProvider +from redis.exceptions import ( + ConnectionError, + ExecAbortError, + PubSubError, + RedisError, + ResponseError, + TimeoutError, + WatchError, +) +from redis.lock import Lock +from redis.retry import Retry +from redis.utils import ( + HIREDIS_AVAILABLE, + _set_info_logger, + get_lib_version, + safe_str, + str_if_bytes, +) + +SYM_EMPTY = b"" +EMPTY_RESPONSE = "EMPTY_RESPONSE" + +# some responses (ie. dump) are binary, and just meant to never be decoded +NEVER_DECODE = "NEVER_DECODE" + + +class CaseInsensitiveDict(dict): + "Case insensitive dict implementation. Assumes string keys only." + + def __init__(self, data: Dict[str, str]) -> None: + for k, v in data.items(): + self[k.upper()] = v + + def __contains__(self, k): + return super().__contains__(k.upper()) + + def __delitem__(self, k): + super().__delitem__(k.upper()) + + def __getitem__(self, k): + return super().__getitem__(k.upper()) + + def get(self, k, default=None): + return super().get(k.upper(), default) + + def __setitem__(self, k, v): + super().__setitem__(k.upper(), v) + + def update(self, data): + data = CaseInsensitiveDict(data) + super().update(data) + + +class AbstractRedis: + pass + + +class Redis(RedisModuleCommands, CoreCommands, SentinelCommands): + """ + Implementation of the Redis protocol. + + This abstract class provides a Python interface to all Redis commands + and an implementation of the Redis protocol. + + Pipelines derive from this, implementing how + the commands are sent and received to the Redis server. Based on + configuration, an instance will either use a ConnectionPool, or + Connection object to talk to redis. + + It is not safe to pass PubSub or Pipeline objects between threads. + """ + + @classmethod + def from_url(cls, url: str, **kwargs) -> None: + """ + Return a Redis client object configured from the given URL + + For example:: + + redis://[[username]:[password]]@localhost:6379/0 + rediss://[[username]:[password]]@localhost:6379/0 + unix://[username@]/path/to/socket.sock?db=0[&password=password] + + Three URL schemes are supported: + + - `redis://` creates a TCP socket connection. See more at: + + - `rediss://` creates a SSL wrapped TCP socket connection. See more at: + + - ``unix://``: creates a Unix Domain Socket connection. + + The username, password, hostname, path and all querystring values + are passed through urllib.parse.unquote in order to replace any + percent-encoded values with their corresponding characters. + + There are several ways to specify a database number. The first value + found will be used: + + 1. A ``db`` querystring option, e.g. redis://localhost?db=0 + 2. If using the redis:// or rediss:// schemes, the path argument + of the url, e.g. redis://localhost/0 + 3. A ``db`` keyword argument to this function. + + If none of these options are specified, the default db=0 is used. + + All querystring options are cast to their appropriate Python types. + Boolean arguments can be specified with string values "True"/"False" + or "Yes"/"No". Values that cannot be properly cast cause a + ``ValueError`` to be raised. Once parsed, the querystring arguments + and keyword arguments are passed to the ``ConnectionPool``'s + class initializer. In the case of conflicting arguments, querystring + arguments always win. + + """ + single_connection_client = kwargs.pop("single_connection_client", False) + connection_pool = ConnectionPool.from_url(url, **kwargs) + client = cls( + connection_pool=connection_pool, + single_connection_client=single_connection_client, + ) + client.auto_close_connection_pool = True + return client + + @classmethod + def from_pool( + cls: Type["Redis"], + connection_pool: ConnectionPool, + ) -> "Redis": + """ + Return a Redis client from the given connection pool. + The Redis client will take ownership of the connection pool and + close it when the Redis client is closed. + """ + client = cls( + connection_pool=connection_pool, + ) + client.auto_close_connection_pool = True + return client + + def __init__( + self, + host="localhost", + port=6379, + db=0, + password=None, + socket_timeout=None, + socket_connect_timeout=None, + socket_keepalive=None, + socket_keepalive_options=None, + connection_pool=None, + unix_socket_path=None, + encoding="utf-8", + encoding_errors="strict", + charset=None, + errors=None, + decode_responses=False, + retry_on_timeout=False, + retry_on_error=None, + ssl=False, + ssl_keyfile=None, + ssl_certfile=None, + ssl_cert_reqs="required", + ssl_ca_certs=None, + ssl_ca_path=None, + ssl_ca_data=None, + ssl_check_hostname=False, + ssl_password=None, + ssl_validate_ocsp=False, + ssl_validate_ocsp_stapled=False, + ssl_ocsp_context=None, + ssl_ocsp_expected_cert=None, + max_connections=None, + single_connection_client=False, + health_check_interval=0, + client_name=None, + lib_name="redis-py", + lib_version=get_lib_version(), + username=None, + retry=None, + redis_connect_func=None, + credential_provider: Optional[CredentialProvider] = None, + protocol: Optional[int] = 2, + ) -> None: + """ + Initialize a new Redis client. + To specify a retry policy for specific errors, first set + `retry_on_error` to a list of the error/s to retry on, then set + `retry` to a valid `Retry` object. + To retry on TimeoutError, `retry_on_timeout` can also be set to `True`. + + Args: + + single_connection_client: + if `True`, connection pool is not used. In that case `Redis` + instance use is not thread safe. + """ + if not connection_pool: + if charset is not None: + warnings.warn( + DeprecationWarning( + '"charset" is deprecated. Use "encoding" instead' + ) + ) + encoding = charset + if errors is not None: + warnings.warn( + DeprecationWarning( + '"errors" is deprecated. Use "encoding_errors" instead' + ) + ) + encoding_errors = errors + if not retry_on_error: + retry_on_error = [] + if retry_on_timeout is True: + retry_on_error.append(TimeoutError) + kwargs = { + "db": db, + "username": username, + "password": password, + "socket_timeout": socket_timeout, + "encoding": encoding, + "encoding_errors": encoding_errors, + "decode_responses": decode_responses, + "retry_on_error": retry_on_error, + "retry": copy.deepcopy(retry), + "max_connections": max_connections, + "health_check_interval": health_check_interval, + "client_name": client_name, + "lib_name": lib_name, + "lib_version": lib_version, + "redis_connect_func": redis_connect_func, + "credential_provider": credential_provider, + "protocol": protocol, + } + # based on input, setup appropriate connection args + if unix_socket_path is not None: + kwargs.update( + { + "path": unix_socket_path, + "connection_class": UnixDomainSocketConnection, + } + ) + else: + # TCP specific options + kwargs.update( + { + "host": host, + "port": port, + "socket_connect_timeout": socket_connect_timeout, + "socket_keepalive": socket_keepalive, + "socket_keepalive_options": socket_keepalive_options, + } + ) + + if ssl: + kwargs.update( + { + "connection_class": SSLConnection, + "ssl_keyfile": ssl_keyfile, + "ssl_certfile": ssl_certfile, + "ssl_cert_reqs": ssl_cert_reqs, + "ssl_ca_certs": ssl_ca_certs, + "ssl_ca_data": ssl_ca_data, + "ssl_check_hostname": ssl_check_hostname, + "ssl_password": ssl_password, + "ssl_ca_path": ssl_ca_path, + "ssl_validate_ocsp_stapled": ssl_validate_ocsp_stapled, + "ssl_validate_ocsp": ssl_validate_ocsp, + "ssl_ocsp_context": ssl_ocsp_context, + "ssl_ocsp_expected_cert": ssl_ocsp_expected_cert, + } + ) + connection_pool = ConnectionPool(**kwargs) + self.auto_close_connection_pool = True + else: + self.auto_close_connection_pool = False + + self.connection_pool = connection_pool + self.connection = None + if single_connection_client: + self.connection = self.connection_pool.get_connection("_") + + self.response_callbacks = CaseInsensitiveDict(_RedisCallbacks) + + if self.connection_pool.connection_kwargs.get("protocol") in ["3", 3]: + self.response_callbacks.update(_RedisCallbacksRESP3) + else: + self.response_callbacks.update(_RedisCallbacksRESP2) + + def __repr__(self) -> str: + return f"{type(self).__name__}<{repr(self.connection_pool)}>" + + def get_encoder(self) -> "Encoder": + """Get the connection pool's encoder""" + return self.connection_pool.get_encoder() + + def get_connection_kwargs(self) -> Dict: + """Get the connection's key-word arguments""" + return self.connection_pool.connection_kwargs + + def get_retry(self) -> Optional["Retry"]: + return self.get_connection_kwargs().get("retry") + + def set_retry(self, retry: "Retry") -> None: + self.get_connection_kwargs().update({"retry": retry}) + self.connection_pool.set_retry(retry) + + def set_response_callback(self, command: str, callback: Callable) -> None: + """Set a custom Response Callback""" + self.response_callbacks[command] = callback + + def load_external_module(self, funcname, func) -> None: + """ + This function can be used to add externally defined redis modules, + and their namespaces to the redis client. + + funcname - A string containing the name of the function to create + func - The function, being added to this class. + + ex: Assume that one has a custom redis module named foomod that + creates command named 'foo.dothing' and 'foo.anotherthing' in redis. + To load function functions into this namespace: + + from redis import Redis + from foomodule import F + r = Redis() + r.load_external_module("foo", F) + r.foo().dothing('your', 'arguments') + + For a concrete example see the reimport of the redisjson module in + tests/test_connection.py::test_loading_external_modules + """ + setattr(self, funcname, func) + + def pipeline(self, transaction=True, shard_hint=None) -> "Pipeline": + """ + Return a new pipeline object that can queue multiple commands for + later execution. ``transaction`` indicates whether all commands + should be executed atomically. Apart from making a group of operations + atomic, pipelines are useful for reducing the back-and-forth overhead + between the client and server. + """ + return Pipeline( + self.connection_pool, self.response_callbacks, transaction, shard_hint + ) + + def transaction( + self, func: Callable[["Pipeline"], None], *watches, **kwargs + ) -> None: + """ + Convenience method for executing the callable `func` as a transaction + while watching all keys specified in `watches`. The 'func' callable + should expect a single argument which is a Pipeline object. + """ + shard_hint = kwargs.pop("shard_hint", None) + value_from_callable = kwargs.pop("value_from_callable", False) + watch_delay = kwargs.pop("watch_delay", None) + with self.pipeline(True, shard_hint) as pipe: + while True: + try: + if watches: + pipe.watch(*watches) + func_value = func(pipe) + exec_value = pipe.execute() + return func_value if value_from_callable else exec_value + except WatchError: + if watch_delay is not None and watch_delay > 0: + time.sleep(watch_delay) + continue + + def lock( + self, + name: str, + timeout: Optional[float] = None, + sleep: float = 0.1, + blocking: bool = True, + blocking_timeout: Optional[float] = None, + lock_class: Union[None, Any] = None, + thread_local: bool = True, + ): + """ + Return a new Lock object using key ``name`` that mimics + the behavior of threading.Lock. + + If specified, ``timeout`` indicates a maximum life for the lock. + By default, it will remain locked until release() is called. + + ``sleep`` indicates the amount of time to sleep per loop iteration + when the lock is in blocking mode and another client is currently + holding the lock. + + ``blocking`` indicates whether calling ``acquire`` should block until + the lock has been acquired or to fail immediately, causing ``acquire`` + to return False and the lock not being acquired. Defaults to True. + Note this value can be overridden by passing a ``blocking`` + argument to ``acquire``. + + ``blocking_timeout`` indicates the maximum amount of time in seconds to + spend trying to acquire the lock. A value of ``None`` indicates + continue trying forever. ``blocking_timeout`` can be specified as a + float or integer, both representing the number of seconds to wait. + + ``lock_class`` forces the specified lock implementation. Note that as + of redis-py 3.0, the only lock class we implement is ``Lock`` (which is + a Lua-based lock). So, it's unlikely you'll need this parameter, unless + you have created your own custom lock class. + + ``thread_local`` indicates whether the lock token is placed in + thread-local storage. By default, the token is placed in thread local + storage so that a thread only sees its token, not a token set by + another thread. Consider the following timeline: + + time: 0, thread-1 acquires `my-lock`, with a timeout of 5 seconds. + thread-1 sets the token to "abc" + time: 1, thread-2 blocks trying to acquire `my-lock` using the + Lock instance. + time: 5, thread-1 has not yet completed. redis expires the lock + key. + time: 5, thread-2 acquired `my-lock` now that it's available. + thread-2 sets the token to "xyz" + time: 6, thread-1 finishes its work and calls release(). if the + token is *not* stored in thread local storage, then + thread-1 would see the token value as "xyz" and would be + able to successfully release the thread-2's lock. + + In some use cases it's necessary to disable thread local storage. For + example, if you have code where one thread acquires a lock and passes + that lock instance to a worker thread to release later. If thread + local storage isn't disabled in this case, the worker thread won't see + the token set by the thread that acquired the lock. Our assumption + is that these cases aren't common and as such default to using + thread local storage.""" + if lock_class is None: + lock_class = Lock + return lock_class( + self, + name, + timeout=timeout, + sleep=sleep, + blocking=blocking, + blocking_timeout=blocking_timeout, + thread_local=thread_local, + ) + + def pubsub(self, **kwargs): + """ + Return a Publish/Subscribe object. With this object, you can + subscribe to channels and listen for messages that get published to + them. + """ + return PubSub(self.connection_pool, **kwargs) + + def monitor(self): + return Monitor(self.connection_pool) + + def client(self): + return self.__class__( + connection_pool=self.connection_pool, single_connection_client=True + ) + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, traceback): + self.close() + + def __del__(self): + self.close() + + def close(self): + # In case a connection property does not yet exist + # (due to a crash earlier in the Redis() constructor), return + # immediately as there is nothing to clean-up. + if not hasattr(self, "connection"): + return + + conn = self.connection + if conn: + self.connection = None + self.connection_pool.release(conn) + + if self.auto_close_connection_pool: + self.connection_pool.disconnect() + + def _send_command_parse_response(self, conn, command_name, *args, **options): + """ + Send a command and parse the response + """ + conn.send_command(*args) + return self.parse_response(conn, command_name, **options) + + def _disconnect_raise(self, conn, error): + """ + Close the connection and raise an exception + if retry_on_error is not set or the error + is not one of the specified error types + """ + conn.disconnect() + if ( + conn.retry_on_error is None + or isinstance(error, tuple(conn.retry_on_error)) is False + ): + raise error + + # COMMAND EXECUTION AND PROTOCOL PARSING + def execute_command(self, *args, **options): + """Execute a command and return a parsed response""" + pool = self.connection_pool + command_name = args[0] + conn = self.connection or pool.get_connection(command_name, **options) + + try: + return conn.retry.call_with_retry( + lambda: self._send_command_parse_response( + conn, command_name, *args, **options + ), + lambda error: self._disconnect_raise(conn, error), + ) + finally: + if not self.connection: + pool.release(conn) + + def parse_response(self, connection, command_name, **options): + """Parses a response from the Redis server""" + try: + if NEVER_DECODE in options: + response = connection.read_response(disable_decoding=True) + options.pop(NEVER_DECODE) + else: + response = connection.read_response() + except ResponseError: + if EMPTY_RESPONSE in options: + return options[EMPTY_RESPONSE] + raise + + if EMPTY_RESPONSE in options: + options.pop(EMPTY_RESPONSE) + + if command_name in self.response_callbacks: + return self.response_callbacks[command_name](response, **options) + return response + + +StrictRedis = Redis + + +class Monitor: + """ + Monitor is useful for handling the MONITOR command to the redis server. + next_command() method returns one command from monitor + listen() method yields commands from monitor. + """ + + monitor_re = re.compile(r"\[(\d+) (.*?)\] (.*)") + command_re = re.compile(r'"(.*?)(? "PubSub": + return self + + def __exit__(self, exc_type, exc_value, traceback) -> None: + self.reset() + + def __del__(self) -> None: + try: + # if this object went out of scope prior to shutting down + # subscriptions, close the connection manually before + # returning it to the connection pool + self.reset() + except Exception: + pass + + def reset(self) -> None: + if self.connection: + self.connection.disconnect() + self.connection._deregister_connect_callback(self.on_connect) + self.connection_pool.release(self.connection) + self.connection = None + self.health_check_response_counter = 0 + self.channels = {} + self.pending_unsubscribe_channels = set() + self.shard_channels = {} + self.pending_unsubscribe_shard_channels = set() + self.patterns = {} + self.pending_unsubscribe_patterns = set() + self.subscribed_event.clear() + + def close(self) -> None: + self.reset() + + def on_connect(self, connection) -> None: + "Re-subscribe to any channels and patterns previously subscribed to" + # NOTE: for python3, we can't pass bytestrings as keyword arguments + # so we need to decode channel/pattern names back to unicode strings + # before passing them to [p]subscribe. + self.pending_unsubscribe_channels.clear() + self.pending_unsubscribe_patterns.clear() + self.pending_unsubscribe_shard_channels.clear() + if self.channels: + channels = { + self.encoder.decode(k, force=True): v for k, v in self.channels.items() + } + self.subscribe(**channels) + if self.patterns: + patterns = { + self.encoder.decode(k, force=True): v for k, v in self.patterns.items() + } + self.psubscribe(**patterns) + if self.shard_channels: + shard_channels = { + self.encoder.decode(k, force=True): v + for k, v in self.shard_channels.items() + } + self.ssubscribe(**shard_channels) + + @property + def subscribed(self) -> bool: + """Indicates if there are subscriptions to any channels or patterns""" + return self.subscribed_event.is_set() + + def execute_command(self, *args): + """Execute a publish/subscribe command""" + + # NOTE: don't parse the response in this function -- it could pull a + # legitimate message off the stack if the connection is already + # subscribed to one or more channels + + if self.connection is None: + self.connection = self.connection_pool.get_connection( + "pubsub", self.shard_hint + ) + # register a callback that re-subscribes to any channels we + # were listening to when we were disconnected + self.connection._register_connect_callback(self.on_connect) + if self.push_handler_func is not None and not HIREDIS_AVAILABLE: + self.connection._parser.set_push_handler(self.push_handler_func) + connection = self.connection + kwargs = {"check_health": not self.subscribed} + if not self.subscribed: + self.clean_health_check_responses() + self._execute(connection, connection.send_command, *args, **kwargs) + + def clean_health_check_responses(self) -> None: + """ + If any health check responses are present, clean them + """ + ttl = 10 + conn = self.connection + while self.health_check_response_counter > 0 and ttl > 0: + if self._execute(conn, conn.can_read, timeout=conn.socket_timeout): + response = self._execute(conn, conn.read_response) + if self.is_health_check_response(response): + self.health_check_response_counter -= 1 + else: + raise PubSubError( + "A non health check response was cleaned by " + "execute_command: {0}".format(response) + ) + ttl -= 1 + + def _disconnect_raise_connect(self, conn, error) -> None: + """ + Close the connection and raise an exception + if retry_on_timeout is not set or the error + is not a TimeoutError. Otherwise, try to reconnect + """ + conn.disconnect() + if not (conn.retry_on_timeout and isinstance(error, TimeoutError)): + raise error + conn.connect() + + def _execute(self, conn, command, *args, **kwargs): + """ + Connect manually upon disconnection. If the Redis server is down, + this will fail and raise a ConnectionError as desired. + After reconnection, the ``on_connect`` callback should have been + called by the # connection to resubscribe us to any channels and + patterns we were previously listening to + """ + return conn.retry.call_with_retry( + lambda: command(*args, **kwargs), + lambda error: self._disconnect_raise_connect(conn, error), + ) + + def parse_response(self, block=True, timeout=0): + """Parse the response from a publish/subscribe command""" + conn = self.connection + if conn is None: + raise RuntimeError( + "pubsub connection not set: " + "did you forget to call subscribe() or psubscribe()?" + ) + + self.check_health() + + def try_read(): + if not block: + if not conn.can_read(timeout=timeout): + return None + else: + conn.connect() + return conn.read_response(disconnect_on_error=False, push_request=True) + + response = self._execute(conn, try_read) + + if self.is_health_check_response(response): + # ignore the health check message as user might not expect it + self.health_check_response_counter -= 1 + return None + return response + + def is_health_check_response(self, response) -> bool: + """ + Check if the response is a health check response. + If there are no subscriptions redis responds to PING command with a + bulk response, instead of a multi-bulk with "pong" and the response. + """ + return response in [ + self.health_check_response, # If there was a subscription + self.health_check_response_b, # If there wasn't + ] + + def check_health(self) -> None: + conn = self.connection + if conn is None: + raise RuntimeError( + "pubsub connection not set: " + "did you forget to call subscribe() or psubscribe()?" + ) + + if conn.health_check_interval and time.time() > conn.next_health_check: + conn.send_command("PING", self.HEALTH_CHECK_MESSAGE, check_health=False) + self.health_check_response_counter += 1 + + def _normalize_keys(self, data) -> Dict: + """ + normalize channel/pattern names to be either bytes or strings + based on whether responses are automatically decoded. this saves us + from coercing the value for each message coming in. + """ + encode = self.encoder.encode + decode = self.encoder.decode + return {decode(encode(k)): v for k, v in data.items()} + + def psubscribe(self, *args, **kwargs): + """ + Subscribe to channel patterns. Patterns supplied as keyword arguments + expect a pattern name as the key and a callable as the value. A + pattern's callable will be invoked automatically when a message is + received on that pattern rather than producing a message via + ``listen()``. + """ + if args: + args = list_or_args(args[0], args[1:]) + new_patterns = dict.fromkeys(args) + new_patterns.update(kwargs) + ret_val = self.execute_command("PSUBSCRIBE", *new_patterns.keys()) + # update the patterns dict AFTER we send the command. we don't want to + # subscribe twice to these patterns, once for the command and again + # for the reconnection. + new_patterns = self._normalize_keys(new_patterns) + self.patterns.update(new_patterns) + if not self.subscribed: + # Set the subscribed_event flag to True + self.subscribed_event.set() + # Clear the health check counter + self.health_check_response_counter = 0 + self.pending_unsubscribe_patterns.difference_update(new_patterns) + return ret_val + + def punsubscribe(self, *args): + """ + Unsubscribe from the supplied patterns. If empty, unsubscribe from + all patterns. + """ + if args: + args = list_or_args(args[0], args[1:]) + patterns = self._normalize_keys(dict.fromkeys(args)) + else: + patterns = self.patterns + self.pending_unsubscribe_patterns.update(patterns) + return self.execute_command("PUNSUBSCRIBE", *args) + + def subscribe(self, *args, **kwargs): + """ + Subscribe to channels. Channels supplied as keyword arguments expect + a channel name as the key and a callable as the value. A channel's + callable will be invoked automatically when a message is received on + that channel rather than producing a message via ``listen()`` or + ``get_message()``. + """ + if args: + args = list_or_args(args[0], args[1:]) + new_channels = dict.fromkeys(args) + new_channels.update(kwargs) + ret_val = self.execute_command("SUBSCRIBE", *new_channels.keys()) + # update the channels dict AFTER we send the command. we don't want to + # subscribe twice to these channels, once for the command and again + # for the reconnection. + new_channels = self._normalize_keys(new_channels) + self.channels.update(new_channels) + if not self.subscribed: + # Set the subscribed_event flag to True + self.subscribed_event.set() + # Clear the health check counter + self.health_check_response_counter = 0 + self.pending_unsubscribe_channels.difference_update(new_channels) + return ret_val + + def unsubscribe(self, *args): + """ + Unsubscribe from the supplied channels. If empty, unsubscribe from + all channels + """ + if args: + args = list_or_args(args[0], args[1:]) + channels = self._normalize_keys(dict.fromkeys(args)) + else: + channels = self.channels + self.pending_unsubscribe_channels.update(channels) + return self.execute_command("UNSUBSCRIBE", *args) + + def ssubscribe(self, *args, target_node=None, **kwargs): + """ + Subscribes the client to the specified shard channels. + Channels supplied as keyword arguments expect a channel name as the key + and a callable as the value. A channel's callable will be invoked automatically + when a message is received on that channel rather than producing a message via + ``listen()`` or ``get_sharded_message()``. + """ + if args: + args = list_or_args(args[0], args[1:]) + new_s_channels = dict.fromkeys(args) + new_s_channels.update(kwargs) + ret_val = self.execute_command("SSUBSCRIBE", *new_s_channels.keys()) + # update the s_channels dict AFTER we send the command. we don't want to + # subscribe twice to these channels, once for the command and again + # for the reconnection. + new_s_channels = self._normalize_keys(new_s_channels) + self.shard_channels.update(new_s_channels) + if not self.subscribed: + # Set the subscribed_event flag to True + self.subscribed_event.set() + # Clear the health check counter + self.health_check_response_counter = 0 + self.pending_unsubscribe_shard_channels.difference_update(new_s_channels) + return ret_val + + def sunsubscribe(self, *args, target_node=None): + """ + Unsubscribe from the supplied shard_channels. If empty, unsubscribe from + all shard_channels + """ + if args: + args = list_or_args(args[0], args[1:]) + s_channels = self._normalize_keys(dict.fromkeys(args)) + else: + s_channels = self.shard_channels + self.pending_unsubscribe_shard_channels.update(s_channels) + return self.execute_command("SUNSUBSCRIBE", *args) + + def listen(self): + "Listen for messages on channels this client has been subscribed to" + while self.subscribed: + response = self.handle_message(self.parse_response(block=True)) + if response is not None: + yield response + + def get_message( + self, ignore_subscribe_messages: bool = False, timeout: float = 0.0 + ): + """ + Get the next message if one is available, otherwise None. + + If timeout is specified, the system will wait for `timeout` seconds + before returning. Timeout should be specified as a floating point + number, or None, to wait indefinitely. + """ + if not self.subscribed: + # Wait for subscription + start_time = time.time() + if self.subscribed_event.wait(timeout) is True: + # The connection was subscribed during the timeout time frame. + # The timeout should be adjusted based on the time spent + # waiting for the subscription + time_spent = time.time() - start_time + timeout = max(0.0, timeout - time_spent) + else: + # The connection isn't subscribed to any channels or patterns, + # so no messages are available + return None + + response = self.parse_response(block=(timeout is None), timeout=timeout) + if response: + return self.handle_message(response, ignore_subscribe_messages) + return None + + get_sharded_message = get_message + + def ping(self, message: Union[str, None] = None) -> bool: + """ + Ping the Redis server + """ + args = ["PING", message] if message is not None else ["PING"] + return self.execute_command(*args) + + def handle_message(self, response, ignore_subscribe_messages=False): + """ + Parses a pub/sub message. If the channel or pattern was subscribed to + with a message handler, the handler is invoked instead of a parsed + message being returned. + """ + if response is None: + return None + if isinstance(response, bytes): + response = [b"pong", response] if response != b"PONG" else [b"pong", b""] + message_type = str_if_bytes(response[0]) + if message_type == "pmessage": + message = { + "type": message_type, + "pattern": response[1], + "channel": response[2], + "data": response[3], + } + elif message_type == "pong": + message = { + "type": message_type, + "pattern": None, + "channel": None, + "data": response[1], + } + else: + message = { + "type": message_type, + "pattern": None, + "channel": response[1], + "data": response[2], + } + + # if this is an unsubscribe message, remove it from memory + if message_type in self.UNSUBSCRIBE_MESSAGE_TYPES: + if message_type == "punsubscribe": + pattern = response[1] + if pattern in self.pending_unsubscribe_patterns: + self.pending_unsubscribe_patterns.remove(pattern) + self.patterns.pop(pattern, None) + elif message_type == "sunsubscribe": + s_channel = response[1] + if s_channel in self.pending_unsubscribe_shard_channels: + self.pending_unsubscribe_shard_channels.remove(s_channel) + self.shard_channels.pop(s_channel, None) + else: + channel = response[1] + if channel in self.pending_unsubscribe_channels: + self.pending_unsubscribe_channels.remove(channel) + self.channels.pop(channel, None) + if not self.channels and not self.patterns and not self.shard_channels: + # There are no subscriptions anymore, set subscribed_event flag + # to false + self.subscribed_event.clear() + + if message_type in self.PUBLISH_MESSAGE_TYPES: + # if there's a message handler, invoke it + if message_type == "pmessage": + handler = self.patterns.get(message["pattern"], None) + elif message_type == "smessage": + handler = self.shard_channels.get(message["channel"], None) + else: + handler = self.channels.get(message["channel"], None) + if handler: + handler(message) + return None + elif message_type != "pong": + # this is a subscribe/unsubscribe message. ignore if we don't + # want them + if ignore_subscribe_messages or self.ignore_subscribe_messages: + return None + + return message + + def run_in_thread( + self, + sleep_time: int = 0, + daemon: bool = False, + exception_handler: Optional[Callable] = None, + ) -> "PubSubWorkerThread": + for channel, handler in self.channels.items(): + if handler is None: + raise PubSubError(f"Channel: '{channel}' has no handler registered") + for pattern, handler in self.patterns.items(): + if handler is None: + raise PubSubError(f"Pattern: '{pattern}' has no handler registered") + for s_channel, handler in self.shard_channels.items(): + if handler is None: + raise PubSubError( + f"Shard Channel: '{s_channel}' has no handler registered" + ) + + thread = PubSubWorkerThread( + self, sleep_time, daemon=daemon, exception_handler=exception_handler + ) + thread.start() + return thread + + +class PubSubWorkerThread(threading.Thread): + def __init__( + self, + pubsub, + sleep_time: float, + daemon: bool = False, + exception_handler: Union[ + Callable[[Exception, "PubSub", "PubSubWorkerThread"], None], None + ] = None, + ): + super().__init__() + self.daemon = daemon + self.pubsub = pubsub + self.sleep_time = sleep_time + self.exception_handler = exception_handler + self._running = threading.Event() + + def run(self) -> None: + if self._running.is_set(): + return + self._running.set() + pubsub = self.pubsub + sleep_time = self.sleep_time + while self._running.is_set(): + try: + pubsub.get_message(ignore_subscribe_messages=True, timeout=sleep_time) + except BaseException as e: + if self.exception_handler is None: + raise + self.exception_handler(e, pubsub, self) + pubsub.close() + + def stop(self) -> None: + # trip the flag so the run loop exits. the run loop will + # close the pubsub connection, which disconnects the socket + # and returns the connection to the pool. + self._running.clear() + + +class Pipeline(Redis): + """ + Pipelines provide a way to transmit multiple commands to the Redis server + in one transmission. This is convenient for batch processing, such as + saving all the values in a list to Redis. + + All commands executed within a pipeline are wrapped with MULTI and EXEC + calls. This guarantees all commands executed in the pipeline will be + executed atomically. + + Any command raising an exception does *not* halt the execution of + subsequent commands in the pipeline. Instead, the exception is caught + and its instance is placed into the response list returned by execute(). + Code iterating over the response list should be able to deal with an + instance of an exception as a potential value. In general, these will be + ResponseError exceptions, such as those raised when issuing a command + on a key of a different datatype. + """ + + UNWATCH_COMMANDS = {"DISCARD", "EXEC", "UNWATCH"} + + def __init__(self, connection_pool, response_callbacks, transaction, shard_hint): + self.connection_pool = connection_pool + self.connection = None + self.response_callbacks = response_callbacks + self.transaction = transaction + self.shard_hint = shard_hint + + self.watching = False + self.reset() + + def __enter__(self) -> "Pipeline": + return self + + def __exit__(self, exc_type, exc_value, traceback): + self.reset() + + def __del__(self): + try: + self.reset() + except Exception: + pass + + def __len__(self) -> int: + return len(self.command_stack) + + def __bool__(self) -> bool: + """Pipeline instances should always evaluate to True""" + return True + + def reset(self) -> None: + self.command_stack = [] + self.scripts = set() + # make sure to reset the connection state in the event that we were + # watching something + if self.watching and self.connection: + try: + # call this manually since our unwatch or + # immediate_execute_command methods can call reset() + self.connection.send_command("UNWATCH") + self.connection.read_response() + except ConnectionError: + # disconnect will also remove any previous WATCHes + self.connection.disconnect() + # clean up the other instance attributes + self.watching = False + self.explicit_transaction = False + # we can safely return the connection to the pool here since we're + # sure we're no longer WATCHing anything + if self.connection: + self.connection_pool.release(self.connection) + self.connection = None + + def close(self) -> None: + """Close the pipeline""" + self.reset() + + def multi(self) -> None: + """ + Start a transactional block of the pipeline after WATCH commands + are issued. End the transactional block with `execute`. + """ + if self.explicit_transaction: + raise RedisError("Cannot issue nested calls to MULTI") + if self.command_stack: + raise RedisError( + "Commands without an initial WATCH have already been issued" + ) + self.explicit_transaction = True + + def execute_command(self, *args, **kwargs): + if (self.watching or args[0] == "WATCH") and not self.explicit_transaction: + return self.immediate_execute_command(*args, **kwargs) + return self.pipeline_execute_command(*args, **kwargs) + + def _disconnect_reset_raise(self, conn, error) -> None: + """ + Close the connection, reset watching state and + raise an exception if we were watching, + retry_on_timeout is not set, + or the error is not a TimeoutError + """ + conn.disconnect() + # if we were already watching a variable, the watch is no longer + # valid since this connection has died. raise a WatchError, which + # indicates the user should retry this transaction. + if self.watching: + self.reset() + raise WatchError( + "A ConnectionError occurred on while watching one or more keys" + ) + # if retry_on_timeout is not set, or the error is not + # a TimeoutError, raise it + if not (conn.retry_on_timeout and isinstance(error, TimeoutError)): + self.reset() + raise + + def immediate_execute_command(self, *args, **options): + """ + Execute a command immediately, but don't auto-retry on a + ConnectionError if we're already WATCHing a variable. Used when + issuing WATCH or subsequent commands retrieving their values but before + MULTI is called. + """ + command_name = args[0] + conn = self.connection + # if this is the first call, we need a connection + if not conn: + conn = self.connection_pool.get_connection(command_name, self.shard_hint) + self.connection = conn + + return conn.retry.call_with_retry( + lambda: self._send_command_parse_response( + conn, command_name, *args, **options + ), + lambda error: self._disconnect_reset_raise(conn, error), + ) + + def pipeline_execute_command(self, *args, **options) -> "Pipeline": + """ + Stage a command to be executed when execute() is next called + + Returns the current Pipeline object back so commands can be + chained together, such as: + + pipe = pipe.set('foo', 'bar').incr('baz').decr('bang') + + At some other point, you can then run: pipe.execute(), + which will execute all commands queued in the pipe. + """ + self.command_stack.append((args, options)) + return self + + def _execute_transaction(self, connection, commands, raise_on_error) -> List: + cmds = chain([(("MULTI",), {})], commands, [(("EXEC",), {})]) + all_cmds = connection.pack_commands( + [args for args, options in cmds if EMPTY_RESPONSE not in options] + ) + connection.send_packed_command(all_cmds) + errors = [] + + # parse off the response for MULTI + # NOTE: we need to handle ResponseErrors here and continue + # so that we read all the additional command messages from + # the socket + try: + self.parse_response(connection, "_") + except ResponseError as e: + errors.append((0, e)) + + # and all the other commands + for i, command in enumerate(commands): + if EMPTY_RESPONSE in command[1]: + errors.append((i, command[1][EMPTY_RESPONSE])) + else: + try: + self.parse_response(connection, "_") + except ResponseError as e: + self.annotate_exception(e, i + 1, command[0]) + errors.append((i, e)) + + # parse the EXEC. + try: + response = self.parse_response(connection, "_") + except ExecAbortError: + if errors: + raise errors[0][1] + raise + + # EXEC clears any watched keys + self.watching = False + + if response is None: + raise WatchError("Watched variable changed.") + + # put any parse errors into the response + for i, e in errors: + response.insert(i, e) + + if len(response) != len(commands): + self.connection.disconnect() + raise ResponseError( + "Wrong number of response items from pipeline execution" + ) + + # find any errors in the response and raise if necessary + if raise_on_error: + self.raise_first_error(commands, response) + + # We have to run response callbacks manually + data = [] + for r, cmd in zip(response, commands): + if not isinstance(r, Exception): + args, options = cmd + command_name = args[0] + if command_name in self.response_callbacks: + r = self.response_callbacks[command_name](r, **options) + data.append(r) + return data + + def _execute_pipeline(self, connection, commands, raise_on_error): + # build up all commands into a single request to increase network perf + all_cmds = connection.pack_commands([args for args, _ in commands]) + connection.send_packed_command(all_cmds) + + response = [] + for args, options in commands: + try: + response.append(self.parse_response(connection, args[0], **options)) + except ResponseError as e: + response.append(e) + + if raise_on_error: + self.raise_first_error(commands, response) + return response + + def raise_first_error(self, commands, response): + for i, r in enumerate(response): + if isinstance(r, ResponseError): + self.annotate_exception(r, i + 1, commands[i][0]) + raise r + + def annotate_exception(self, exception, number, command): + cmd = " ".join(map(safe_str, command)) + msg = ( + f"Command # {number} ({cmd}) of pipeline " + f"caused error: {exception.args[0]}" + ) + exception.args = (msg,) + exception.args[1:] + + def parse_response(self, connection, command_name, **options): + result = Redis.parse_response(self, connection, command_name, **options) + if command_name in self.UNWATCH_COMMANDS: + self.watching = False + elif command_name == "WATCH": + self.watching = True + return result + + def load_scripts(self): + # make sure all scripts that are about to be run on this pipeline exist + scripts = list(self.scripts) + immediate = self.immediate_execute_command + shas = [s.sha for s in scripts] + # we can't use the normal script_* methods because they would just + # get buffered in the pipeline. + exists = immediate("SCRIPT EXISTS", *shas) + if not all(exists): + for s, exist in zip(scripts, exists): + if not exist: + s.sha = immediate("SCRIPT LOAD", s.script) + + def _disconnect_raise_reset(self, conn: Redis, error: Exception) -> None: + """ + Close the connection, raise an exception if we were watching, + and raise an exception if TimeoutError is not part of retry_on_error, + or the error is not a TimeoutError + """ + conn.disconnect() + # if we were watching a variable, the watch is no longer valid + # since this connection has died. raise a WatchError, which + # indicates the user should retry this transaction. + if self.watching: + raise WatchError( + "A ConnectionError occurred on while watching one or more keys" + ) + # if TimeoutError is not part of retry_on_error, or the error + # is not a TimeoutError, raise it + if not ( + TimeoutError in conn.retry_on_error and isinstance(error, TimeoutError) + ): + self.reset() + raise error + + def execute(self, raise_on_error=True): + """Execute all the commands in the current pipeline""" + stack = self.command_stack + if not stack and not self.watching: + return [] + if self.scripts: + self.load_scripts() + if self.transaction or self.explicit_transaction: + execute = self._execute_transaction + else: + execute = self._execute_pipeline + + conn = self.connection + if not conn: + conn = self.connection_pool.get_connection("MULTI", self.shard_hint) + # assign to self.connection so reset() releases the connection + # back to the pool after we're done + self.connection = conn + + try: + return conn.retry.call_with_retry( + lambda: execute(conn, stack, raise_on_error), + lambda error: self._disconnect_raise_reset(conn, error), + ) + finally: + self.reset() + + def discard(self): + """ + Flushes all previously queued commands + See: https://redis.io/commands/DISCARD + """ + self.execute_command("DISCARD") + + def watch(self, *names): + """Watches the values at keys ``names``""" + if self.explicit_transaction: + raise RedisError("Cannot issue a WATCH after a MULTI") + return self.execute_command("WATCH", *names) + + def unwatch(self) -> bool: + """Unwatches all previously specified keys""" + return self.watching and self.execute_command("UNWATCH") or True diff --git a/.venv/Lib/site-packages/redis/cluster.py b/.venv/Lib/site-packages/redis/cluster.py new file mode 100644 index 00000000..873d586c --- /dev/null +++ b/.venv/Lib/site-packages/redis/cluster.py @@ -0,0 +1,2486 @@ +import random +import socket +import sys +import threading +import time +from collections import OrderedDict +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +from redis._parsers import CommandsParser, Encoder +from redis._parsers.helpers import parse_scan +from redis.backoff import default_backoff +from redis.client import CaseInsensitiveDict, PubSub, Redis +from redis.commands import READ_COMMANDS, RedisClusterCommands +from redis.commands.helpers import list_or_args +from redis.connection import ConnectionPool, DefaultParser, parse_url +from redis.crc import REDIS_CLUSTER_HASH_SLOTS, key_slot +from redis.exceptions import ( + AskError, + AuthenticationError, + ClusterCrossSlotError, + ClusterDownError, + ClusterError, + ConnectionError, + DataError, + MasterDownError, + MovedError, + RedisClusterException, + RedisError, + ResponseError, + SlotNotCoveredError, + TimeoutError, + TryAgainError, +) +from redis.lock import Lock +from redis.retry import Retry +from redis.utils import ( + HIREDIS_AVAILABLE, + dict_merge, + list_keys_to_dict, + merge_result, + safe_str, + str_if_bytes, +) + + +def get_node_name(host: str, port: Union[str, int]) -> str: + return f"{host}:{port}" + + +def get_connection(redis_node, *args, **options): + return redis_node.connection or redis_node.connection_pool.get_connection( + args[0], **options + ) + + +def parse_scan_result(command, res, **options): + cursors = {} + ret = [] + for node_name, response in res.items(): + cursor, r = parse_scan(response, **options) + cursors[node_name] = cursor + ret += r + + return cursors, ret + + +def parse_pubsub_numsub(command, res, **options): + numsub_d = OrderedDict() + for numsub_tups in res.values(): + for channel, numsubbed in numsub_tups: + try: + numsub_d[channel] += numsubbed + except KeyError: + numsub_d[channel] = numsubbed + + ret_numsub = [(channel, numsub) for channel, numsub in numsub_d.items()] + return ret_numsub + + +def parse_cluster_slots( + resp: Any, **options: Any +) -> Dict[Tuple[int, int], Dict[str, Any]]: + current_host = options.get("current_host", "") + + def fix_server(*args: Any) -> Tuple[str, Any]: + return str_if_bytes(args[0]) or current_host, args[1] + + slots = {} + for slot in resp: + start, end, primary = slot[:3] + replicas = slot[3:] + slots[start, end] = { + "primary": fix_server(*primary), + "replicas": [fix_server(*replica) for replica in replicas], + } + + return slots + + +def parse_cluster_shards(resp, **options): + """ + Parse CLUSTER SHARDS response. + """ + if isinstance(resp[0], dict): + return resp + shards = [] + for x in resp: + shard = {"slots": [], "nodes": []} + for i in range(0, len(x[1]), 2): + shard["slots"].append((x[1][i], (x[1][i + 1]))) + nodes = x[3] + for node in nodes: + dict_node = {} + for i in range(0, len(node), 2): + dict_node[node[i]] = node[i + 1] + shard["nodes"].append(dict_node) + shards.append(shard) + + return shards + + +def parse_cluster_myshardid(resp, **options): + """ + Parse CLUSTER MYSHARDID response. + """ + return resp.decode("utf-8") + + +PRIMARY = "primary" +REPLICA = "replica" +SLOT_ID = "slot-id" + +REDIS_ALLOWED_KEYS = ( + "charset", + "connection_class", + "connection_pool", + "connection_pool_class", + "client_name", + "credential_provider", + "db", + "decode_responses", + "encoding", + "encoding_errors", + "errors", + "host", + "lib_name", + "lib_version", + "max_connections", + "nodes_flag", + "redis_connect_func", + "password", + "port", + "queue_class", + "retry", + "retry_on_timeout", + "protocol", + "socket_connect_timeout", + "socket_keepalive", + "socket_keepalive_options", + "socket_timeout", + "ssl", + "ssl_ca_certs", + "ssl_ca_data", + "ssl_certfile", + "ssl_cert_reqs", + "ssl_keyfile", + "ssl_password", + "unix_socket_path", + "username", +) +KWARGS_DISABLED_KEYS = ("host", "port") + + +def cleanup_kwargs(**kwargs): + """ + Remove unsupported or disabled keys from kwargs + """ + connection_kwargs = { + k: v + for k, v in kwargs.items() + if k in REDIS_ALLOWED_KEYS and k not in KWARGS_DISABLED_KEYS + } + + return connection_kwargs + + +class ClusterParser(DefaultParser): + EXCEPTION_CLASSES = dict_merge( + DefaultParser.EXCEPTION_CLASSES, + { + "ASK": AskError, + "TRYAGAIN": TryAgainError, + "MOVED": MovedError, + "CLUSTERDOWN": ClusterDownError, + "CROSSSLOT": ClusterCrossSlotError, + "MASTERDOWN": MasterDownError, + }, + ) + + +class AbstractRedisCluster: + RedisClusterRequestTTL = 16 + + PRIMARIES = "primaries" + REPLICAS = "replicas" + ALL_NODES = "all" + RANDOM = "random" + DEFAULT_NODE = "default-node" + + NODE_FLAGS = {PRIMARIES, REPLICAS, ALL_NODES, RANDOM, DEFAULT_NODE} + + COMMAND_FLAGS = dict_merge( + list_keys_to_dict( + [ + "ACL CAT", + "ACL DELUSER", + "ACL DRYRUN", + "ACL GENPASS", + "ACL GETUSER", + "ACL HELP", + "ACL LIST", + "ACL LOG", + "ACL LOAD", + "ACL SAVE", + "ACL SETUSER", + "ACL USERS", + "ACL WHOAMI", + "AUTH", + "CLIENT LIST", + "CLIENT SETINFO", + "CLIENT SETNAME", + "CLIENT GETNAME", + "CONFIG SET", + "CONFIG REWRITE", + "CONFIG RESETSTAT", + "TIME", + "PUBSUB CHANNELS", + "PUBSUB NUMPAT", + "PUBSUB NUMSUB", + "PUBSUB SHARDCHANNELS", + "PUBSUB SHARDNUMSUB", + "PING", + "INFO", + "SHUTDOWN", + "KEYS", + "DBSIZE", + "BGSAVE", + "SLOWLOG GET", + "SLOWLOG LEN", + "SLOWLOG RESET", + "WAIT", + "WAITAOF", + "SAVE", + "MEMORY PURGE", + "MEMORY MALLOC-STATS", + "MEMORY STATS", + "LASTSAVE", + "CLIENT TRACKINGINFO", + "CLIENT PAUSE", + "CLIENT UNPAUSE", + "CLIENT UNBLOCK", + "CLIENT ID", + "CLIENT REPLY", + "CLIENT GETREDIR", + "CLIENT INFO", + "CLIENT KILL", + "READONLY", + "CLUSTER INFO", + "CLUSTER MEET", + "CLUSTER MYSHARDID", + "CLUSTER NODES", + "CLUSTER REPLICAS", + "CLUSTER RESET", + "CLUSTER SET-CONFIG-EPOCH", + "CLUSTER SLOTS", + "CLUSTER SHARDS", + "CLUSTER COUNT-FAILURE-REPORTS", + "CLUSTER KEYSLOT", + "COMMAND", + "COMMAND COUNT", + "COMMAND LIST", + "COMMAND GETKEYS", + "CONFIG GET", + "DEBUG", + "RANDOMKEY", + "READONLY", + "READWRITE", + "TIME", + "TFUNCTION LOAD", + "TFUNCTION DELETE", + "TFUNCTION LIST", + "TFCALL", + "TFCALLASYNC", + "GRAPH.CONFIG", + "LATENCY HISTORY", + "LATENCY LATEST", + "LATENCY RESET", + "MODULE LIST", + "MODULE LOAD", + "MODULE UNLOAD", + "MODULE LOADEX", + ], + DEFAULT_NODE, + ), + list_keys_to_dict( + [ + "FLUSHALL", + "FLUSHDB", + "FUNCTION DELETE", + "FUNCTION FLUSH", + "FUNCTION LIST", + "FUNCTION LOAD", + "FUNCTION RESTORE", + "REDISGEARS_2.REFRESHCLUSTER", + "SCAN", + "SCRIPT EXISTS", + "SCRIPT FLUSH", + "SCRIPT LOAD", + ], + PRIMARIES, + ), + list_keys_to_dict(["FUNCTION DUMP"], RANDOM), + list_keys_to_dict( + [ + "CLUSTER COUNTKEYSINSLOT", + "CLUSTER DELSLOTS", + "CLUSTER DELSLOTSRANGE", + "CLUSTER GETKEYSINSLOT", + "CLUSTER SETSLOT", + ], + SLOT_ID, + ), + ) + + SEARCH_COMMANDS = ( + [ + "FT.CREATE", + "FT.SEARCH", + "FT.AGGREGATE", + "FT.EXPLAIN", + "FT.EXPLAINCLI", + "FT,PROFILE", + "FT.ALTER", + "FT.DROPINDEX", + "FT.ALIASADD", + "FT.ALIASUPDATE", + "FT.ALIASDEL", + "FT.TAGVALS", + "FT.SUGADD", + "FT.SUGGET", + "FT.SUGDEL", + "FT.SUGLEN", + "FT.SYNUPDATE", + "FT.SYNDUMP", + "FT.SPELLCHECK", + "FT.DICTADD", + "FT.DICTDEL", + "FT.DICTDUMP", + "FT.INFO", + "FT._LIST", + "FT.CONFIG", + "FT.ADD", + "FT.DEL", + "FT.DROP", + "FT.GET", + "FT.MGET", + "FT.SYNADD", + ], + ) + + CLUSTER_COMMANDS_RESPONSE_CALLBACKS = { + "CLUSTER SLOTS": parse_cluster_slots, + "CLUSTER SHARDS": parse_cluster_shards, + "CLUSTER MYSHARDID": parse_cluster_myshardid, + } + + RESULT_CALLBACKS = dict_merge( + list_keys_to_dict(["PUBSUB NUMSUB", "PUBSUB SHARDNUMSUB"], parse_pubsub_numsub), + list_keys_to_dict( + ["PUBSUB NUMPAT"], lambda command, res: sum(list(res.values())) + ), + list_keys_to_dict( + ["KEYS", "PUBSUB CHANNELS", "PUBSUB SHARDCHANNELS"], merge_result + ), + list_keys_to_dict( + [ + "PING", + "CONFIG SET", + "CONFIG REWRITE", + "CONFIG RESETSTAT", + "CLIENT SETNAME", + "BGSAVE", + "SLOWLOG RESET", + "SAVE", + "MEMORY PURGE", + "CLIENT PAUSE", + "CLIENT UNPAUSE", + ], + lambda command, res: all(res.values()) if isinstance(res, dict) else res, + ), + list_keys_to_dict( + ["DBSIZE", "WAIT"], + lambda command, res: sum(res.values()) if isinstance(res, dict) else res, + ), + list_keys_to_dict( + ["CLIENT UNBLOCK"], lambda command, res: 1 if sum(res.values()) > 0 else 0 + ), + list_keys_to_dict(["SCAN"], parse_scan_result), + list_keys_to_dict( + ["SCRIPT LOAD"], lambda command, res: list(res.values()).pop() + ), + list_keys_to_dict( + ["SCRIPT EXISTS"], lambda command, res: [all(k) for k in zip(*res.values())] + ), + list_keys_to_dict(["SCRIPT FLUSH"], lambda command, res: all(res.values())), + ) + + ERRORS_ALLOW_RETRY = (ConnectionError, TimeoutError, ClusterDownError) + + def replace_default_node(self, target_node: "ClusterNode" = None) -> None: + """Replace the default cluster node. + A random cluster node will be chosen if target_node isn't passed, and primaries + will be prioritized. The default node will not be changed if there are no other + nodes in the cluster. + + Args: + target_node (ClusterNode, optional): Target node to replace the default + node. Defaults to None. + """ + if target_node: + self.nodes_manager.default_node = target_node + else: + curr_node = self.get_default_node() + primaries = [node for node in self.get_primaries() if node != curr_node] + if primaries: + # Choose a primary if the cluster contains different primaries + self.nodes_manager.default_node = random.choice(primaries) + else: + # Otherwise, hoose a primary if the cluster contains different primaries + replicas = [node for node in self.get_replicas() if node != curr_node] + if replicas: + self.nodes_manager.default_node = random.choice(replicas) + + +class RedisCluster(AbstractRedisCluster, RedisClusterCommands): + @classmethod + def from_url(cls, url, **kwargs): + """ + Return a Redis client object configured from the given URL + + For example:: + + redis://[[username]:[password]]@localhost:6379/0 + rediss://[[username]:[password]]@localhost:6379/0 + unix://[username@]/path/to/socket.sock?db=0[&password=password] + + Three URL schemes are supported: + + - `redis://` creates a TCP socket connection. See more at: + + - `rediss://` creates a SSL wrapped TCP socket connection. See more at: + + - ``unix://``: creates a Unix Domain Socket connection. + + The username, password, hostname, path and all querystring values + are passed through urllib.parse.unquote in order to replace any + percent-encoded values with their corresponding characters. + + There are several ways to specify a database number. The first value + found will be used: + + 1. A ``db`` querystring option, e.g. redis://localhost?db=0 + 2. If using the redis:// or rediss:// schemes, the path argument + of the url, e.g. redis://localhost/0 + 3. A ``db`` keyword argument to this function. + + If none of these options are specified, the default db=0 is used. + + All querystring options are cast to their appropriate Python types. + Boolean arguments can be specified with string values "True"/"False" + or "Yes"/"No". Values that cannot be properly cast cause a + ``ValueError`` to be raised. Once parsed, the querystring arguments + and keyword arguments are passed to the ``ConnectionPool``'s + class initializer. In the case of conflicting arguments, querystring + arguments always win. + + """ + return cls(url=url, **kwargs) + + def __init__( + self, + host: Optional[str] = None, + port: int = 6379, + startup_nodes: Optional[List["ClusterNode"]] = None, + cluster_error_retry_attempts: int = 3, + retry: Optional["Retry"] = None, + require_full_coverage: bool = False, + reinitialize_steps: int = 5, + read_from_replicas: bool = False, + dynamic_startup_nodes: bool = True, + url: Optional[str] = None, + address_remap: Optional[Callable[[str, int], Tuple[str, int]]] = None, + **kwargs, + ): + """ + Initialize a new RedisCluster client. + + :param startup_nodes: + List of nodes from which initial bootstrapping can be done + :param host: + Can be used to point to a startup node + :param port: + Can be used to point to a startup node + :param require_full_coverage: + When set to False (default value): the client will not require a + full coverage of the slots. However, if not all slots are covered, + and at least one node has 'cluster-require-full-coverage' set to + 'yes,' the server will throw a ClusterDownError for some key-based + commands. See - + https://redis.io/topics/cluster-tutorial#redis-cluster-configuration-parameters + When set to True: all slots must be covered to construct the + cluster client. If not all slots are covered, RedisClusterException + will be thrown. + :param read_from_replicas: + Enable read from replicas in READONLY mode. You can read possibly + stale data. + When set to true, read commands will be assigned between the + primary and its replications in a Round-Robin manner. + :param dynamic_startup_nodes: + Set the RedisCluster's startup nodes to all of the discovered nodes. + If true (default value), the cluster's discovered nodes will be used to + determine the cluster nodes-slots mapping in the next topology refresh. + It will remove the initial passed startup nodes if their endpoints aren't + listed in the CLUSTER SLOTS output. + If you use dynamic DNS endpoints for startup nodes but CLUSTER SLOTS lists + specific IP addresses, it is best to set it to false. + :param cluster_error_retry_attempts: + Number of times to retry before raising an error when + :class:`~.TimeoutError` or :class:`~.ConnectionError` or + :class:`~.ClusterDownError` are encountered + :param reinitialize_steps: + Specifies the number of MOVED errors that need to occur before + reinitializing the whole cluster topology. If a MOVED error occurs + and the cluster does not need to be reinitialized on this current + error handling, only the MOVED slot will be patched with the + redirected node. + To reinitialize the cluster on every MOVED error, set + reinitialize_steps to 1. + To avoid reinitializing the cluster on moved errors, set + reinitialize_steps to 0. + :param address_remap: + An optional callable which, when provided with an internal network + address of a node, e.g. a `(host, port)` tuple, will return the address + where the node is reachable. This can be used to map the addresses at + which the nodes _think_ they are, to addresses at which a client may + reach them, such as when they sit behind a proxy. + + :**kwargs: + Extra arguments that will be sent into Redis instance when created + (See Official redis-py doc for supported kwargs + [https://github.com/andymccurdy/redis-py/blob/master/redis/client.py]) + Some kwargs are not supported and will raise a + RedisClusterException: + - db (Redis do not support database SELECT in cluster mode) + """ + if startup_nodes is None: + startup_nodes = [] + + if "db" in kwargs: + # Argument 'db' is not possible to use in cluster mode + raise RedisClusterException( + "Argument 'db' is not possible to use in cluster mode" + ) + + # Get the startup node/s + from_url = False + if url is not None: + from_url = True + url_options = parse_url(url) + if "path" in url_options: + raise RedisClusterException( + "RedisCluster does not currently support Unix Domain " + "Socket connections" + ) + if "db" in url_options and url_options["db"] != 0: + # Argument 'db' is not possible to use in cluster mode + raise RedisClusterException( + "A ``db`` querystring option can only be 0 in cluster mode" + ) + kwargs.update(url_options) + host = kwargs.get("host") + port = kwargs.get("port", port) + startup_nodes.append(ClusterNode(host, port)) + elif host is not None and port is not None: + startup_nodes.append(ClusterNode(host, port)) + elif len(startup_nodes) == 0: + # No startup node was provided + raise RedisClusterException( + "RedisCluster requires at least one node to discover the " + "cluster. Please provide one of the followings:\n" + "1. host and port, for example:\n" + " RedisCluster(host='localhost', port=6379)\n" + "2. list of startup nodes, for example:\n" + " RedisCluster(startup_nodes=[ClusterNode('localhost', 6379)," + " ClusterNode('localhost', 6378)])" + ) + # Update the connection arguments + # Whenever a new connection is established, RedisCluster's on_connect + # method should be run + # If the user passed on_connect function we'll save it and run it + # inside the RedisCluster.on_connect() function + self.user_on_connect_func = kwargs.pop("redis_connect_func", None) + kwargs.update({"redis_connect_func": self.on_connect}) + kwargs = cleanup_kwargs(**kwargs) + if retry: + self.retry = retry + kwargs.update({"retry": self.retry}) + else: + kwargs.update({"retry": Retry(default_backoff(), 0)}) + + self.encoder = Encoder( + kwargs.get("encoding", "utf-8"), + kwargs.get("encoding_errors", "strict"), + kwargs.get("decode_responses", False), + ) + self.cluster_error_retry_attempts = cluster_error_retry_attempts + self.command_flags = self.__class__.COMMAND_FLAGS.copy() + self.node_flags = self.__class__.NODE_FLAGS.copy() + self.read_from_replicas = read_from_replicas + self.reinitialize_counter = 0 + self.reinitialize_steps = reinitialize_steps + self.nodes_manager = NodesManager( + startup_nodes=startup_nodes, + from_url=from_url, + require_full_coverage=require_full_coverage, + dynamic_startup_nodes=dynamic_startup_nodes, + address_remap=address_remap, + **kwargs, + ) + + self.cluster_response_callbacks = CaseInsensitiveDict( + self.__class__.CLUSTER_COMMANDS_RESPONSE_CALLBACKS + ) + self.result_callbacks = CaseInsensitiveDict(self.__class__.RESULT_CALLBACKS) + self.commands_parser = CommandsParser(self) + self._lock = threading.Lock() + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, traceback): + self.close() + + def __del__(self): + self.close() + + def disconnect_connection_pools(self): + for node in self.get_nodes(): + if node.redis_connection: + try: + node.redis_connection.connection_pool.disconnect() + except OSError: + # Client was already disconnected. do nothing + pass + + def on_connect(self, connection): + """ + Initialize the connection, authenticate and select a database and send + READONLY if it is set during object initialization. + """ + connection.set_parser(ClusterParser) + connection.on_connect() + + if self.read_from_replicas: + # Sending READONLY command to server to configure connection as + # readonly. Since each cluster node may change its server type due + # to a failover, we should establish a READONLY connection + # regardless of the server type. If this is a primary connection, + # READONLY would not affect executing write commands. + connection.send_command("READONLY") + if str_if_bytes(connection.read_response()) != "OK": + raise ConnectionError("READONLY command failed") + + if self.user_on_connect_func is not None: + self.user_on_connect_func(connection) + + def get_redis_connection(self, node): + if not node.redis_connection: + with self._lock: + if not node.redis_connection: + self.nodes_manager.create_redis_connections([node]) + return node.redis_connection + + def get_node(self, host=None, port=None, node_name=None): + return self.nodes_manager.get_node(host, port, node_name) + + def get_primaries(self): + return self.nodes_manager.get_nodes_by_server_type(PRIMARY) + + def get_replicas(self): + return self.nodes_manager.get_nodes_by_server_type(REPLICA) + + def get_random_node(self): + return random.choice(list(self.nodes_manager.nodes_cache.values())) + + def get_nodes(self): + return list(self.nodes_manager.nodes_cache.values()) + + def get_node_from_key(self, key, replica=False): + """ + Get the node that holds the key's slot. + If replica set to True but the slot doesn't have any replicas, None is + returned. + """ + slot = self.keyslot(key) + slot_cache = self.nodes_manager.slots_cache.get(slot) + if slot_cache is None or len(slot_cache) == 0: + raise SlotNotCoveredError(f'Slot "{slot}" is not covered by the cluster.') + if replica and len(self.nodes_manager.slots_cache[slot]) < 2: + return None + elif replica: + node_idx = 1 + else: + # primary + node_idx = 0 + + return slot_cache[node_idx] + + def get_default_node(self): + """ + Get the cluster's default node + """ + return self.nodes_manager.default_node + + def set_default_node(self, node): + """ + Set the default node of the cluster. + :param node: 'ClusterNode' + :return True if the default node was set, else False + """ + if node is None or self.get_node(node_name=node.name) is None: + return False + self.nodes_manager.default_node = node + return True + + def get_retry(self) -> Optional["Retry"]: + return self.retry + + def set_retry(self, retry: "Retry") -> None: + self.retry = retry + for node in self.get_nodes(): + node.redis_connection.set_retry(retry) + + def monitor(self, target_node=None): + """ + Returns a Monitor object for the specified target node. + The default cluster node will be selected if no target node was + specified. + Monitor is useful for handling the MONITOR command to the redis server. + next_command() method returns one command from monitor + listen() method yields commands from monitor. + """ + if target_node is None: + target_node = self.get_default_node() + if target_node.redis_connection is None: + raise RedisClusterException( + f"Cluster Node {target_node.name} has no redis_connection" + ) + return target_node.redis_connection.monitor() + + def pubsub(self, node=None, host=None, port=None, **kwargs): + """ + Allows passing a ClusterNode, or host&port, to get a pubsub instance + connected to the specified node + """ + return ClusterPubSub(self, node=node, host=host, port=port, **kwargs) + + def pipeline(self, transaction=None, shard_hint=None): + """ + Cluster impl: + Pipelines do not work in cluster mode the same way they + do in normal mode. Create a clone of this object so + that simulating pipelines will work correctly. Each + command will be called directly when used and + when calling execute() will only return the result stack. + """ + if shard_hint: + raise RedisClusterException("shard_hint is deprecated in cluster mode") + + if transaction: + raise RedisClusterException("transaction is deprecated in cluster mode") + + return ClusterPipeline( + nodes_manager=self.nodes_manager, + commands_parser=self.commands_parser, + startup_nodes=self.nodes_manager.startup_nodes, + result_callbacks=self.result_callbacks, + cluster_response_callbacks=self.cluster_response_callbacks, + cluster_error_retry_attempts=self.cluster_error_retry_attempts, + read_from_replicas=self.read_from_replicas, + reinitialize_steps=self.reinitialize_steps, + lock=self._lock, + ) + + def lock( + self, + name, + timeout=None, + sleep=0.1, + blocking=True, + blocking_timeout=None, + lock_class=None, + thread_local=True, + ): + """ + Return a new Lock object using key ``name`` that mimics + the behavior of threading.Lock. + + If specified, ``timeout`` indicates a maximum life for the lock. + By default, it will remain locked until release() is called. + + ``sleep`` indicates the amount of time to sleep per loop iteration + when the lock is in blocking mode and another client is currently + holding the lock. + + ``blocking`` indicates whether calling ``acquire`` should block until + the lock has been acquired or to fail immediately, causing ``acquire`` + to return False and the lock not being acquired. Defaults to True. + Note this value can be overridden by passing a ``blocking`` + argument to ``acquire``. + + ``blocking_timeout`` indicates the maximum amount of time in seconds to + spend trying to acquire the lock. A value of ``None`` indicates + continue trying forever. ``blocking_timeout`` can be specified as a + float or integer, both representing the number of seconds to wait. + + ``lock_class`` forces the specified lock implementation. Note that as + of redis-py 3.0, the only lock class we implement is ``Lock`` (which is + a Lua-based lock). So, it's unlikely you'll need this parameter, unless + you have created your own custom lock class. + + ``thread_local`` indicates whether the lock token is placed in + thread-local storage. By default, the token is placed in thread local + storage so that a thread only sees its token, not a token set by + another thread. Consider the following timeline: + + time: 0, thread-1 acquires `my-lock`, with a timeout of 5 seconds. + thread-1 sets the token to "abc" + time: 1, thread-2 blocks trying to acquire `my-lock` using the + Lock instance. + time: 5, thread-1 has not yet completed. redis expires the lock + key. + time: 5, thread-2 acquired `my-lock` now that it's available. + thread-2 sets the token to "xyz" + time: 6, thread-1 finishes its work and calls release(). if the + token is *not* stored in thread local storage, then + thread-1 would see the token value as "xyz" and would be + able to successfully release the thread-2's lock. + + In some use cases it's necessary to disable thread local storage. For + example, if you have code where one thread acquires a lock and passes + that lock instance to a worker thread to release later. If thread + local storage isn't disabled in this case, the worker thread won't see + the token set by the thread that acquired the lock. Our assumption + is that these cases aren't common and as such default to using + thread local storage.""" + if lock_class is None: + lock_class = Lock + return lock_class( + self, + name, + timeout=timeout, + sleep=sleep, + blocking=blocking, + blocking_timeout=blocking_timeout, + thread_local=thread_local, + ) + + def set_response_callback(self, command, callback): + """Set a custom Response Callback""" + self.cluster_response_callbacks[command] = callback + + def _determine_nodes(self, *args, **kwargs) -> List["ClusterNode"]: + # Determine which nodes should be executed the command on. + # Returns a list of target nodes. + command = args[0].upper() + if len(args) >= 2 and f"{args[0]} {args[1]}".upper() in self.command_flags: + command = f"{args[0]} {args[1]}".upper() + + nodes_flag = kwargs.pop("nodes_flag", None) + if nodes_flag is not None: + # nodes flag passed by the user + command_flag = nodes_flag + else: + # get the nodes group for this command if it was predefined + command_flag = self.command_flags.get(command) + if command_flag == self.__class__.RANDOM: + # return a random node + return [self.get_random_node()] + elif command_flag == self.__class__.PRIMARIES: + # return all primaries + return self.get_primaries() + elif command_flag == self.__class__.REPLICAS: + # return all replicas + return self.get_replicas() + elif command_flag == self.__class__.ALL_NODES: + # return all nodes + return self.get_nodes() + elif command_flag == self.__class__.DEFAULT_NODE: + # return the cluster's default node + return [self.nodes_manager.default_node] + elif command in self.__class__.SEARCH_COMMANDS[0]: + return [self.nodes_manager.default_node] + else: + # get the node that holds the key's slot + slot = self.determine_slot(*args) + node = self.nodes_manager.get_node_from_slot( + slot, self.read_from_replicas and command in READ_COMMANDS + ) + return [node] + + def _should_reinitialized(self): + # To reinitialize the cluster on every MOVED error, + # set reinitialize_steps to 1. + # To avoid reinitializing the cluster on moved errors, set + # reinitialize_steps to 0. + if self.reinitialize_steps == 0: + return False + else: + return self.reinitialize_counter % self.reinitialize_steps == 0 + + def keyslot(self, key): + """ + Calculate keyslot for a given key. + See Keys distribution model in https://redis.io/topics/cluster-spec + """ + k = self.encoder.encode(key) + return key_slot(k) + + def _get_command_keys(self, *args): + """ + Get the keys in the command. If the command has no keys in in, None is + returned. + + NOTE: Due to a bug in redis<7.0, this function does not work properly + for EVAL or EVALSHA when the `numkeys` arg is 0. + - issue: https://github.com/redis/redis/issues/9493 + - fix: https://github.com/redis/redis/pull/9733 + + So, don't use this function with EVAL or EVALSHA. + """ + redis_conn = self.get_default_node().redis_connection + return self.commands_parser.get_keys(redis_conn, *args) + + def determine_slot(self, *args): + """ + Figure out what slot to use based on args. + + Raises a RedisClusterException if there's a missing key and we can't + determine what slots to map the command to; or, if the keys don't + all map to the same key slot. + """ + command = args[0] + if self.command_flags.get(command) == SLOT_ID: + # The command contains the slot ID + return args[1] + + # Get the keys in the command + + # EVAL and EVALSHA are common enough that it's wasteful to go to the + # redis server to parse the keys. Besides, there is a bug in redis<7.0 + # where `self._get_command_keys()` fails anyway. So, we special case + # EVAL/EVALSHA. + if command.upper() in ("EVAL", "EVALSHA"): + # command syntax: EVAL "script body" num_keys ... + if len(args) <= 2: + raise RedisClusterException(f"Invalid args in command: {args}") + num_actual_keys = int(args[2]) + eval_keys = args[3 : 3 + num_actual_keys] + # if there are 0 keys, that means the script can be run on any node + # so we can just return a random slot + if len(eval_keys) == 0: + return random.randrange(0, REDIS_CLUSTER_HASH_SLOTS) + keys = eval_keys + else: + keys = self._get_command_keys(*args) + if keys is None or len(keys) == 0: + # FCALL can call a function with 0 keys, that means the function + # can be run on any node so we can just return a random slot + if command.upper() in ("FCALL", "FCALL_RO"): + return random.randrange(0, REDIS_CLUSTER_HASH_SLOTS) + raise RedisClusterException( + "No way to dispatch this command to Redis Cluster. " + "Missing key.\nYou can execute the command by specifying " + f"target nodes.\nCommand: {args}" + ) + + # single key command + if len(keys) == 1: + return self.keyslot(keys[0]) + + # multi-key command; we need to make sure all keys are mapped to + # the same slot + slots = {self.keyslot(key) for key in keys} + if len(slots) != 1: + raise RedisClusterException( + f"{command} - all keys must map to the same key slot" + ) + + return slots.pop() + + def get_encoder(self): + """ + Get the connections' encoder + """ + return self.encoder + + def get_connection_kwargs(self): + """ + Get the connections' key-word arguments + """ + return self.nodes_manager.connection_kwargs + + def _is_nodes_flag(self, target_nodes): + return isinstance(target_nodes, str) and target_nodes in self.node_flags + + def _parse_target_nodes(self, target_nodes): + if isinstance(target_nodes, list): + nodes = target_nodes + elif isinstance(target_nodes, ClusterNode): + # Supports passing a single ClusterNode as a variable + nodes = [target_nodes] + elif isinstance(target_nodes, dict): + # Supports dictionaries of the format {node_name: node}. + # It enables to execute commands with multi nodes as follows: + # rc.cluster_save_config(rc.get_primaries()) + nodes = target_nodes.values() + else: + raise TypeError( + "target_nodes type can be one of the following: " + "node_flag (PRIMARIES, REPLICAS, RANDOM, ALL_NODES)," + "ClusterNode, list, or dict. " + f"The passed type is {type(target_nodes)}" + ) + return nodes + + def execute_command(self, *args, **kwargs): + """ + Wrapper for ERRORS_ALLOW_RETRY error handling. + + It will try the number of times specified by the config option + "self.cluster_error_retry_attempts" which defaults to 3 unless manually + configured. + + If it reaches the number of times, the command will raise the exception + + Key argument :target_nodes: can be passed with the following types: + nodes_flag: PRIMARIES, REPLICAS, ALL_NODES, RANDOM + ClusterNode + list + dict + """ + target_nodes_specified = False + is_default_node = False + target_nodes = None + passed_targets = kwargs.pop("target_nodes", None) + if passed_targets is not None and not self._is_nodes_flag(passed_targets): + target_nodes = self._parse_target_nodes(passed_targets) + target_nodes_specified = True + # If an error that allows retrying was thrown, the nodes and slots + # cache were reinitialized. We will retry executing the command with + # the updated cluster setup only when the target nodes can be + # determined again with the new cache tables. Therefore, when target + # nodes were passed to this function, we cannot retry the command + # execution since the nodes may not be valid anymore after the tables + # were reinitialized. So in case of passed target nodes, + # retry_attempts will be set to 0. + retry_attempts = ( + 0 if target_nodes_specified else self.cluster_error_retry_attempts + ) + # Add one for the first execution + execute_attempts = 1 + retry_attempts + for _ in range(execute_attempts): + try: + res = {} + if not target_nodes_specified: + # Determine the nodes to execute the command on + target_nodes = self._determine_nodes( + *args, **kwargs, nodes_flag=passed_targets + ) + if not target_nodes: + raise RedisClusterException( + f"No targets were found to execute {args} command on" + ) + if ( + len(target_nodes) == 1 + and target_nodes[0] == self.get_default_node() + ): + is_default_node = True + for node in target_nodes: + res[node.name] = self._execute_command(node, *args, **kwargs) + # Return the processed result + return self._process_result(args[0], res, **kwargs) + except Exception as e: + if retry_attempts > 0 and type(e) in self.__class__.ERRORS_ALLOW_RETRY: + if is_default_node: + # Replace the default cluster node + self.replace_default_node() + # The nodes and slots cache were reinitialized. + # Try again with the new cluster setup. + retry_attempts -= 1 + continue + else: + # raise the exception + raise e + + def _execute_command(self, target_node, *args, **kwargs): + """ + Send a command to a node in the cluster + """ + command = args[0] + redis_node = None + connection = None + redirect_addr = None + asking = False + moved = False + ttl = int(self.RedisClusterRequestTTL) + + while ttl > 0: + ttl -= 1 + try: + if asking: + target_node = self.get_node(node_name=redirect_addr) + elif moved: + # MOVED occurred and the slots cache was updated, + # refresh the target node + slot = self.determine_slot(*args) + target_node = self.nodes_manager.get_node_from_slot( + slot, self.read_from_replicas and command in READ_COMMANDS + ) + moved = False + + redis_node = self.get_redis_connection(target_node) + connection = get_connection(redis_node, *args, **kwargs) + if asking: + connection.send_command("ASKING") + redis_node.parse_response(connection, "ASKING", **kwargs) + asking = False + + connection.send_command(*args) + response = redis_node.parse_response(connection, command, **kwargs) + if command in self.cluster_response_callbacks: + response = self.cluster_response_callbacks[command]( + response, **kwargs + ) + return response + except AuthenticationError: + raise + except (ConnectionError, TimeoutError) as e: + # Connection retries are being handled in the node's + # Retry object. + # ConnectionError can also be raised if we couldn't get a + # connection from the pool before timing out, so check that + # this is an actual connection before attempting to disconnect. + if connection is not None: + connection.disconnect() + + # Remove the failed node from the startup nodes before we try + # to reinitialize the cluster + self.nodes_manager.startup_nodes.pop(target_node.name, None) + # Reset the cluster node's connection + target_node.redis_connection = None + self.nodes_manager.initialize() + raise e + except MovedError as e: + # First, we will try to patch the slots/nodes cache with the + # redirected node output and try again. If MovedError exceeds + # 'reinitialize_steps' number of times, we will force + # reinitializing the tables, and then try again. + # 'reinitialize_steps' counter will increase faster when + # the same client object is shared between multiple threads. To + # reduce the frequency you can set this variable in the + # RedisCluster constructor. + self.reinitialize_counter += 1 + if self._should_reinitialized(): + self.nodes_manager.initialize() + # Reset the counter + self.reinitialize_counter = 0 + else: + self.nodes_manager.update_moved_exception(e) + moved = True + except TryAgainError: + if ttl < self.RedisClusterRequestTTL / 2: + time.sleep(0.05) + except AskError as e: + redirect_addr = get_node_name(host=e.host, port=e.port) + asking = True + except ClusterDownError as e: + # ClusterDownError can occur during a failover and to get + # self-healed, we will try to reinitialize the cluster layout + # and retry executing the command + time.sleep(0.25) + self.nodes_manager.initialize() + raise e + except ResponseError: + raise + except Exception as e: + if connection: + connection.disconnect() + raise e + finally: + if connection is not None: + redis_node.connection_pool.release(connection) + + raise ClusterError("TTL exhausted.") + + def close(self): + try: + with self._lock: + if self.nodes_manager: + self.nodes_manager.close() + except AttributeError: + # RedisCluster's __init__ can fail before nodes_manager is set + pass + + def _process_result(self, command, res, **kwargs): + """ + Process the result of the executed command. + The function would return a dict or a single value. + + :type command: str + :type res: dict + + `res` should be in the following format: + Dict + """ + if command in self.result_callbacks: + return self.result_callbacks[command](command, res, **kwargs) + elif len(res) == 1: + # When we execute the command on a single node, we can + # remove the dictionary and return a single response + return list(res.values())[0] + else: + return res + + def load_external_module(self, funcname, func): + """ + This function can be used to add externally defined redis modules, + and their namespaces to the redis client. + + ``funcname`` - A string containing the name of the function to create + ``func`` - The function, being added to this class. + """ + setattr(self, funcname, func) + + +class ClusterNode: + def __init__(self, host, port, server_type=None, redis_connection=None): + if host == "localhost": + host = socket.gethostbyname(host) + + self.host = host + self.port = port + self.name = get_node_name(host, port) + self.server_type = server_type + self.redis_connection = redis_connection + + def __repr__(self): + return ( + f"[host={self.host}," + f"port={self.port}," + f"name={self.name}," + f"server_type={self.server_type}," + f"redis_connection={self.redis_connection}]" + ) + + def __eq__(self, obj): + return isinstance(obj, ClusterNode) and obj.name == self.name + + def __del__(self): + if self.redis_connection is not None: + self.redis_connection.close() + + +class LoadBalancer: + """ + Round-Robin Load Balancing + """ + + def __init__(self, start_index: int = 0) -> None: + self.primary_to_idx = {} + self.start_index = start_index + + def get_server_index(self, primary: str, list_size: int) -> int: + server_index = self.primary_to_idx.setdefault(primary, self.start_index) + # Update the index + self.primary_to_idx[primary] = (server_index + 1) % list_size + return server_index + + def reset(self) -> None: + self.primary_to_idx.clear() + + +class NodesManager: + def __init__( + self, + startup_nodes, + from_url=False, + require_full_coverage=False, + lock=None, + dynamic_startup_nodes=True, + connection_pool_class=ConnectionPool, + address_remap: Optional[Callable[[str, int], Tuple[str, int]]] = None, + **kwargs, + ): + self.nodes_cache = {} + self.slots_cache = {} + self.startup_nodes = {} + self.default_node = None + self.populate_startup_nodes(startup_nodes) + self.from_url = from_url + self._require_full_coverage = require_full_coverage + self._dynamic_startup_nodes = dynamic_startup_nodes + self.connection_pool_class = connection_pool_class + self.address_remap = address_remap + self._moved_exception = None + self.connection_kwargs = kwargs + self.read_load_balancer = LoadBalancer() + if lock is None: + lock = threading.Lock() + self._lock = lock + self.initialize() + + def get_node(self, host=None, port=None, node_name=None): + """ + Get the requested node from the cluster's nodes. + nodes. + :return: ClusterNode if the node exists, else None + """ + if host and port: + # the user passed host and port + if host == "localhost": + host = socket.gethostbyname(host) + return self.nodes_cache.get(get_node_name(host=host, port=port)) + elif node_name: + return self.nodes_cache.get(node_name) + else: + return None + + def update_moved_exception(self, exception): + self._moved_exception = exception + + def _update_moved_slots(self): + """ + Update the slot's node with the redirected one + """ + e = self._moved_exception + redirected_node = self.get_node(host=e.host, port=e.port) + if redirected_node is not None: + # The node already exists + if redirected_node.server_type is not PRIMARY: + # Update the node's server type + redirected_node.server_type = PRIMARY + else: + # This is a new node, we will add it to the nodes cache + redirected_node = ClusterNode(e.host, e.port, PRIMARY) + self.nodes_cache[redirected_node.name] = redirected_node + if redirected_node in self.slots_cache[e.slot_id]: + # The MOVED error resulted from a failover, and the new slot owner + # had previously been a replica. + old_primary = self.slots_cache[e.slot_id][0] + # Update the old primary to be a replica and add it to the end of + # the slot's node list + old_primary.server_type = REPLICA + self.slots_cache[e.slot_id].append(old_primary) + # Remove the old replica, which is now a primary, from the slot's + # node list + self.slots_cache[e.slot_id].remove(redirected_node) + # Override the old primary with the new one + self.slots_cache[e.slot_id][0] = redirected_node + if self.default_node == old_primary: + # Update the default node with the new primary + self.default_node = redirected_node + else: + # The new slot owner is a new server, or a server from a different + # shard. We need to remove all current nodes from the slot's list + # (including replications) and add just the new node. + self.slots_cache[e.slot_id] = [redirected_node] + # Reset moved_exception + self._moved_exception = None + + def get_node_from_slot(self, slot, read_from_replicas=False, server_type=None): + """ + Gets a node that servers this hash slot + """ + if self._moved_exception: + with self._lock: + if self._moved_exception: + self._update_moved_slots() + + if self.slots_cache.get(slot) is None or len(self.slots_cache[slot]) == 0: + raise SlotNotCoveredError( + f'Slot "{slot}" not covered by the cluster. ' + f'"require_full_coverage={self._require_full_coverage}"' + ) + + if read_from_replicas is True: + # get the server index in a Round-Robin manner + primary_name = self.slots_cache[slot][0].name + node_idx = self.read_load_balancer.get_server_index( + primary_name, len(self.slots_cache[slot]) + ) + elif ( + server_type is None + or server_type == PRIMARY + or len(self.slots_cache[slot]) == 1 + ): + # return a primary + node_idx = 0 + else: + # return a replica + # randomly choose one of the replicas + node_idx = random.randint(1, len(self.slots_cache[slot]) - 1) + + return self.slots_cache[slot][node_idx] + + def get_nodes_by_server_type(self, server_type): + """ + Get all nodes with the specified server type + :param server_type: 'primary' or 'replica' + :return: list of ClusterNode + """ + return [ + node + for node in self.nodes_cache.values() + if node.server_type == server_type + ] + + def populate_startup_nodes(self, nodes): + """ + Populate all startup nodes and filters out any duplicates + """ + for n in nodes: + self.startup_nodes[n.name] = n + + def check_slots_coverage(self, slots_cache): + # Validate if all slots are covered or if we should try next + # startup node + for i in range(0, REDIS_CLUSTER_HASH_SLOTS): + if i not in slots_cache: + return False + return True + + def create_redis_connections(self, nodes): + """ + This function will create a redis connection to all nodes in :nodes: + """ + for node in nodes: + if node.redis_connection is None: + node.redis_connection = self.create_redis_node( + host=node.host, port=node.port, **self.connection_kwargs + ) + + def create_redis_node(self, host, port, **kwargs): + if self.from_url: + # Create a redis node with a costumed connection pool + kwargs.update({"host": host}) + kwargs.update({"port": port}) + r = Redis(connection_pool=self.connection_pool_class(**kwargs)) + else: + r = Redis(host=host, port=port, **kwargs) + return r + + def _get_or_create_cluster_node(self, host, port, role, tmp_nodes_cache): + node_name = get_node_name(host, port) + # check if we already have this node in the tmp_nodes_cache + target_node = tmp_nodes_cache.get(node_name) + if target_node is None: + # before creating a new cluster node, check if the cluster node already + # exists in the current nodes cache and has a valid connection so we can + # reuse it + target_node = self.nodes_cache.get(node_name) + if target_node is None or target_node.redis_connection is None: + # create new cluster node for this cluster + target_node = ClusterNode(host, port, role) + if target_node.server_type != role: + target_node.server_type = role + + return target_node + + def initialize(self): + """ + Initializes the nodes cache, slots cache and redis connections. + :startup_nodes: + Responsible for discovering other nodes in the cluster + """ + self.reset() + tmp_nodes_cache = {} + tmp_slots = {} + disagreements = [] + startup_nodes_reachable = False + fully_covered = False + kwargs = self.connection_kwargs + exception = None + for startup_node in self.startup_nodes.values(): + try: + if startup_node.redis_connection: + r = startup_node.redis_connection + else: + # Create a new Redis connection + r = self.create_redis_node( + startup_node.host, startup_node.port, **kwargs + ) + self.startup_nodes[startup_node.name].redis_connection = r + # Make sure cluster mode is enabled on this node + if bool(r.info().get("cluster_enabled")) is False: + raise RedisClusterException( + "Cluster mode is not enabled on this node" + ) + cluster_slots = str_if_bytes(r.execute_command("CLUSTER SLOTS")) + startup_nodes_reachable = True + except Exception as e: + # Try the next startup node. + # The exception is saved and raised only if we have no more nodes. + exception = e + continue + + # CLUSTER SLOTS command results in the following output: + # [[slot_section[from_slot,to_slot,master,replica1,...,replicaN]]] + # where each node contains the following list: [IP, port, node_id] + # Therefore, cluster_slots[0][2][0] will be the IP address of the + # primary node of the first slot section. + # If there's only one server in the cluster, its ``host`` is '' + # Fix it to the host in startup_nodes + if ( + len(cluster_slots) == 1 + and len(cluster_slots[0][2][0]) == 0 + and len(self.startup_nodes) == 1 + ): + cluster_slots[0][2][0] = startup_node.host + + for slot in cluster_slots: + primary_node = slot[2] + host = str_if_bytes(primary_node[0]) + if host == "": + host = startup_node.host + port = int(primary_node[1]) + host, port = self.remap_host_port(host, port) + + target_node = self._get_or_create_cluster_node( + host, port, PRIMARY, tmp_nodes_cache + ) + # add this node to the nodes cache + tmp_nodes_cache[target_node.name] = target_node + + for i in range(int(slot[0]), int(slot[1]) + 1): + if i not in tmp_slots: + tmp_slots[i] = [] + tmp_slots[i].append(target_node) + replica_nodes = [slot[j] for j in range(3, len(slot))] + + for replica_node in replica_nodes: + host = str_if_bytes(replica_node[0]) + port = replica_node[1] + host, port = self.remap_host_port(host, port) + + target_replica_node = self._get_or_create_cluster_node( + host, port, REPLICA, tmp_nodes_cache + ) + tmp_slots[i].append(target_replica_node) + # add this node to the nodes cache + tmp_nodes_cache[ + target_replica_node.name + ] = target_replica_node + else: + # Validate that 2 nodes want to use the same slot cache + # setup + tmp_slot = tmp_slots[i][0] + if tmp_slot.name != target_node.name: + disagreements.append( + f"{tmp_slot.name} vs {target_node.name} on slot: {i}" + ) + + if len(disagreements) > 5: + raise RedisClusterException( + f"startup_nodes could not agree on a valid " + f'slots cache: {", ".join(disagreements)}' + ) + + fully_covered = self.check_slots_coverage(tmp_slots) + if fully_covered: + # Don't need to continue to the next startup node if all + # slots are covered + break + + if not startup_nodes_reachable: + raise RedisClusterException( + f"Redis Cluster cannot be connected. Please provide at least " + f"one reachable node: {str(exception)}" + ) from exception + + # Create Redis connections to all nodes + self.create_redis_connections(list(tmp_nodes_cache.values())) + + # Check if the slots are not fully covered + if not fully_covered and self._require_full_coverage: + # Despite the requirement that the slots be covered, there + # isn't a full coverage + raise RedisClusterException( + f"All slots are not covered after query all startup_nodes. " + f"{len(tmp_slots)} of {REDIS_CLUSTER_HASH_SLOTS} " + f"covered..." + ) + + # Set the tmp variables to the real variables + self.nodes_cache = tmp_nodes_cache + self.slots_cache = tmp_slots + # Set the default node + self.default_node = self.get_nodes_by_server_type(PRIMARY)[0] + if self._dynamic_startup_nodes: + # Populate the startup nodes with all discovered nodes + self.startup_nodes = tmp_nodes_cache + # If initialize was called after a MovedError, clear it + self._moved_exception = None + + def close(self): + self.default_node = None + for node in self.nodes_cache.values(): + if node.redis_connection: + node.redis_connection.close() + + def reset(self): + try: + self.read_load_balancer.reset() + except TypeError: + # The read_load_balancer is None, do nothing + pass + + def remap_host_port(self, host: str, port: int) -> Tuple[str, int]: + """ + Remap the host and port returned from the cluster to a different + internal value. Useful if the client is not connecting directly + to the cluster. + """ + if self.address_remap: + return self.address_remap((host, port)) + return host, port + + +class ClusterPubSub(PubSub): + """ + Wrapper for PubSub class. + + IMPORTANT: before using ClusterPubSub, read about the known limitations + with pubsub in Cluster mode and learn how to workaround them: + https://redis-py-cluster.readthedocs.io/en/stable/pubsub.html + """ + + def __init__( + self, + redis_cluster, + node=None, + host=None, + port=None, + push_handler_func=None, + **kwargs, + ): + """ + When a pubsub instance is created without specifying a node, a single + node will be transparently chosen for the pubsub connection on the + first command execution. The node will be determined by: + 1. Hashing the channel name in the request to find its keyslot + 2. Selecting a node that handles the keyslot: If read_from_replicas is + set to true, a replica can be selected. + + :type redis_cluster: RedisCluster + :type node: ClusterNode + :type host: str + :type port: int + """ + self.node = None + self.set_pubsub_node(redis_cluster, node, host, port) + connection_pool = ( + None + if self.node is None + else redis_cluster.get_redis_connection(self.node).connection_pool + ) + self.cluster = redis_cluster + self.node_pubsub_mapping = {} + self._pubsubs_generator = self._pubsubs_generator() + super().__init__( + connection_pool=connection_pool, + encoder=redis_cluster.encoder, + push_handler_func=push_handler_func, + **kwargs, + ) + + def set_pubsub_node(self, cluster, node=None, host=None, port=None): + """ + The pubsub node will be set according to the passed node, host and port + When none of the node, host, or port are specified - the node is set + to None and will be determined by the keyslot of the channel in the + first command to be executed. + RedisClusterException will be thrown if the passed node does not exist + in the cluster. + If host is passed without port, or vice versa, a DataError will be + thrown. + :type cluster: RedisCluster + :type node: ClusterNode + :type host: str + :type port: int + """ + if node is not None: + # node is passed by the user + self._raise_on_invalid_node(cluster, node, node.host, node.port) + pubsub_node = node + elif host is not None and port is not None: + # host and port passed by the user + node = cluster.get_node(host=host, port=port) + self._raise_on_invalid_node(cluster, node, host, port) + pubsub_node = node + elif any([host, port]) is True: + # only 'host' or 'port' passed + raise DataError("Passing a host requires passing a port, and vice versa") + else: + # nothing passed by the user. set node to None + pubsub_node = None + + self.node = pubsub_node + + def get_pubsub_node(self): + """ + Get the node that is being used as the pubsub connection + """ + return self.node + + def _raise_on_invalid_node(self, redis_cluster, node, host, port): + """ + Raise a RedisClusterException if the node is None or doesn't exist in + the cluster. + """ + if node is None or redis_cluster.get_node(node_name=node.name) is None: + raise RedisClusterException( + f"Node {host}:{port} doesn't exist in the cluster" + ) + + def execute_command(self, *args): + """ + Execute a subscribe/unsubscribe command. + + Taken code from redis-py and tweak to make it work within a cluster. + """ + # NOTE: don't parse the response in this function -- it could pull a + # legitimate message off the stack if the connection is already + # subscribed to one or more channels + + if self.connection is None: + if self.connection_pool is None: + if len(args) > 1: + # Hash the first channel and get one of the nodes holding + # this slot + channel = args[1] + slot = self.cluster.keyslot(channel) + node = self.cluster.nodes_manager.get_node_from_slot( + slot, self.cluster.read_from_replicas + ) + else: + # Get a random node + node = self.cluster.get_random_node() + self.node = node + redis_connection = self.cluster.get_redis_connection(node) + self.connection_pool = redis_connection.connection_pool + self.connection = self.connection_pool.get_connection( + "pubsub", self.shard_hint + ) + # register a callback that re-subscribes to any channels we + # were listening to when we were disconnected + self.connection._register_connect_callback(self.on_connect) + if self.push_handler_func is not None and not HIREDIS_AVAILABLE: + self.connection._parser.set_push_handler(self.push_handler_func) + connection = self.connection + self._execute(connection, connection.send_command, *args) + + def _get_node_pubsub(self, node): + try: + return self.node_pubsub_mapping[node.name] + except KeyError: + pubsub = node.redis_connection.pubsub( + push_handler_func=self.push_handler_func + ) + self.node_pubsub_mapping[node.name] = pubsub + return pubsub + + def _sharded_message_generator(self): + for _ in range(len(self.node_pubsub_mapping)): + pubsub = next(self._pubsubs_generator) + message = pubsub.get_message() + if message is not None: + return message + return None + + def _pubsubs_generator(self): + while True: + for pubsub in self.node_pubsub_mapping.values(): + yield pubsub + + def get_sharded_message( + self, ignore_subscribe_messages=False, timeout=0.0, target_node=None + ): + if target_node: + message = self.node_pubsub_mapping[target_node.name].get_message( + ignore_subscribe_messages=ignore_subscribe_messages, timeout=timeout + ) + else: + message = self._sharded_message_generator() + if message is None: + return None + elif str_if_bytes(message["type"]) == "sunsubscribe": + if message["channel"] in self.pending_unsubscribe_shard_channels: + self.pending_unsubscribe_shard_channels.remove(message["channel"]) + self.shard_channels.pop(message["channel"], None) + node = self.cluster.get_node_from_key(message["channel"]) + if self.node_pubsub_mapping[node.name].subscribed is False: + self.node_pubsub_mapping.pop(node.name) + if not self.channels and not self.patterns and not self.shard_channels: + # There are no subscriptions anymore, set subscribed_event flag + # to false + self.subscribed_event.clear() + if self.ignore_subscribe_messages or ignore_subscribe_messages: + return None + return message + + def ssubscribe(self, *args, **kwargs): + if args: + args = list_or_args(args[0], args[1:]) + s_channels = dict.fromkeys(args) + s_channels.update(kwargs) + for s_channel, handler in s_channels.items(): + node = self.cluster.get_node_from_key(s_channel) + pubsub = self._get_node_pubsub(node) + if handler: + pubsub.ssubscribe(**{s_channel: handler}) + else: + pubsub.ssubscribe(s_channel) + self.shard_channels.update(pubsub.shard_channels) + self.pending_unsubscribe_shard_channels.difference_update( + self._normalize_keys({s_channel: None}) + ) + if pubsub.subscribed and not self.subscribed: + self.subscribed_event.set() + self.health_check_response_counter = 0 + + def sunsubscribe(self, *args): + if args: + args = list_or_args(args[0], args[1:]) + else: + args = self.shard_channels + + for s_channel in args: + node = self.cluster.get_node_from_key(s_channel) + p = self._get_node_pubsub(node) + p.sunsubscribe(s_channel) + self.pending_unsubscribe_shard_channels.update( + p.pending_unsubscribe_shard_channels + ) + + def get_redis_connection(self): + """ + Get the Redis connection of the pubsub connected node. + """ + if self.node is not None: + return self.node.redis_connection + + def disconnect(self): + """ + Disconnect the pubsub connection. + """ + if self.connection: + self.connection.disconnect() + for pubsub in self.node_pubsub_mapping.values(): + pubsub.connection.disconnect() + + +class ClusterPipeline(RedisCluster): + """ + Support for Redis pipeline + in cluster mode + """ + + ERRORS_ALLOW_RETRY = ( + ConnectionError, + TimeoutError, + MovedError, + AskError, + TryAgainError, + ) + + def __init__( + self, + nodes_manager: "NodesManager", + commands_parser: "CommandsParser", + result_callbacks: Optional[Dict[str, Callable]] = None, + cluster_response_callbacks: Optional[Dict[str, Callable]] = None, + startup_nodes: Optional[List["ClusterNode"]] = None, + read_from_replicas: bool = False, + cluster_error_retry_attempts: int = 3, + reinitialize_steps: int = 5, + lock=None, + **kwargs, + ): + """ """ + self.command_stack = [] + self.nodes_manager = nodes_manager + self.commands_parser = commands_parser + self.refresh_table_asap = False + self.result_callbacks = ( + result_callbacks or self.__class__.RESULT_CALLBACKS.copy() + ) + self.startup_nodes = startup_nodes if startup_nodes else [] + self.read_from_replicas = read_from_replicas + self.command_flags = self.__class__.COMMAND_FLAGS.copy() + self.cluster_response_callbacks = cluster_response_callbacks + self.cluster_error_retry_attempts = cluster_error_retry_attempts + self.reinitialize_counter = 0 + self.reinitialize_steps = reinitialize_steps + self.encoder = Encoder( + kwargs.get("encoding", "utf-8"), + kwargs.get("encoding_errors", "strict"), + kwargs.get("decode_responses", False), + ) + if lock is None: + lock = threading.Lock() + self._lock = lock + + def __repr__(self): + """ """ + return f"{type(self).__name__}" + + def __enter__(self): + """ """ + return self + + def __exit__(self, exc_type, exc_value, traceback): + """ """ + self.reset() + + def __del__(self): + try: + self.reset() + except Exception: + pass + + def __len__(self): + """ """ + return len(self.command_stack) + + def __bool__(self): + "Pipeline instances should always evaluate to True on Python 3+" + return True + + def execute_command(self, *args, **kwargs): + """ + Wrapper function for pipeline_execute_command + """ + return self.pipeline_execute_command(*args, **kwargs) + + def pipeline_execute_command(self, *args, **options): + """ + Appends the executed command to the pipeline's command stack + """ + self.command_stack.append( + PipelineCommand(args, options, len(self.command_stack)) + ) + return self + + def raise_first_error(self, stack): + """ + Raise the first exception on the stack + """ + for c in stack: + r = c.result + if isinstance(r, Exception): + self.annotate_exception(r, c.position + 1, c.args) + raise r + + def annotate_exception(self, exception, number, command): + """ + Provides extra context to the exception prior to it being handled + """ + cmd = " ".join(map(safe_str, command)) + msg = ( + f"Command # {number} ({cmd}) of pipeline " + f"caused error: {exception.args[0]}" + ) + exception.args = (msg,) + exception.args[1:] + + def execute(self, raise_on_error=True): + """ + Execute all the commands in the current pipeline + """ + stack = self.command_stack + try: + return self.send_cluster_commands(stack, raise_on_error) + finally: + self.reset() + + def reset(self): + """ + Reset back to empty pipeline. + """ + self.command_stack = [] + + self.scripts = set() + + # TODO: Implement + # make sure to reset the connection state in the event that we were + # watching something + # if self.watching and self.connection: + # try: + # # call this manually since our unwatch or + # # immediate_execute_command methods can call reset() + # self.connection.send_command('UNWATCH') + # self.connection.read_response() + # except ConnectionError: + # # disconnect will also remove any previous WATCHes + # self.connection.disconnect() + + # clean up the other instance attributes + self.watching = False + self.explicit_transaction = False + + # TODO: Implement + # we can safely return the connection to the pool here since we're + # sure we're no longer WATCHing anything + # if self.connection: + # self.connection_pool.release(self.connection) + # self.connection = None + + def send_cluster_commands( + self, stack, raise_on_error=True, allow_redirections=True + ): + """ + Wrapper for CLUSTERDOWN error handling. + + If the cluster reports it is down it is assumed that: + - connection_pool was disconnected + - connection_pool was reseted + - refereh_table_asap set to True + + It will try the number of times specified by + the config option "self.cluster_error_retry_attempts" + which defaults to 3 unless manually configured. + + If it reaches the number of times, the command will + raises ClusterDownException. + """ + if not stack: + return [] + retry_attempts = self.cluster_error_retry_attempts + while True: + try: + return self._send_cluster_commands( + stack, + raise_on_error=raise_on_error, + allow_redirections=allow_redirections, + ) + except (ClusterDownError, ConnectionError) as e: + if retry_attempts > 0: + # Try again with the new cluster setup. All other errors + # should be raised. + retry_attempts -= 1 + pass + else: + raise e + + def _send_cluster_commands( + self, stack, raise_on_error=True, allow_redirections=True + ): + """ + Send a bunch of cluster commands to the redis cluster. + + `allow_redirections` If the pipeline should follow + `ASK` & `MOVED` responses automatically. If set + to false it will raise RedisClusterException. + """ + # the first time sending the commands we send all of + # the commands that were queued up. + # if we have to run through it again, we only retry + # the commands that failed. + attempt = sorted(stack, key=lambda x: x.position) + is_default_node = False + # build a list of node objects based on node names we need to + nodes = {} + + # as we move through each command that still needs to be processed, + # we figure out the slot number that command maps to, then from + # the slot determine the node. + for c in attempt: + while True: + # refer to our internal node -> slot table that + # tells us where a given command should route to. + # (it might be possible we have a cached node that no longer + # exists in the cluster, which is why we do this in a loop) + passed_targets = c.options.pop("target_nodes", None) + if passed_targets and not self._is_nodes_flag(passed_targets): + target_nodes = self._parse_target_nodes(passed_targets) + else: + target_nodes = self._determine_nodes( + *c.args, node_flag=passed_targets + ) + if not target_nodes: + raise RedisClusterException( + f"No targets were found to execute {c.args} command on" + ) + if len(target_nodes) > 1: + raise RedisClusterException( + f"Too many targets for command {c.args}" + ) + + node = target_nodes[0] + if node == self.get_default_node(): + is_default_node = True + + # now that we know the name of the node + # ( it's just a string in the form of host:port ) + # we can build a list of commands for each node. + node_name = node.name + if node_name not in nodes: + redis_node = self.get_redis_connection(node) + try: + connection = get_connection(redis_node, c.args) + except ConnectionError: + # Connection retries are being handled in the node's + # Retry object. Reinitialize the node -> slot table. + self.nodes_manager.initialize() + if is_default_node: + self.replace_default_node() + raise + nodes[node_name] = NodeCommands( + redis_node.parse_response, + redis_node.connection_pool, + connection, + ) + nodes[node_name].append(c) + break + + # send the commands in sequence. + # we write to all the open sockets for each node first, + # before reading anything + # this allows us to flush all the requests out across the + # network essentially in parallel + # so that we can read them all in parallel as they come back. + # we dont' multiplex on the sockets as they come available, + # but that shouldn't make too much difference. + node_commands = nodes.values() + for n in node_commands: + n.write() + + for n in node_commands: + n.read() + + # release all of the redis connections we allocated earlier + # back into the connection pool. + # we used to do this step as part of a try/finally block, + # but it is really dangerous to + # release connections back into the pool if for some + # reason the socket has data still left in it + # from a previous operation. The write and + # read operations already have try/catch around them for + # all known types of errors including connection + # and socket level errors. + # So if we hit an exception, something really bad + # happened and putting any oF + # these connections back into the pool is a very bad idea. + # the socket might have unread buffer still sitting in it, + # and then the next time we read from it we pass the + # buffered result back from a previous command and + # every single request after to that connection will always get + # a mismatched result. + for n in nodes.values(): + n.connection_pool.release(n.connection) + + # if the response isn't an exception it is a + # valid response from the node + # we're all done with that command, YAY! + # if we have more commands to attempt, we've run into problems. + # collect all the commands we are allowed to retry. + # (MOVED, ASK, or connection errors or timeout errors) + attempt = sorted( + ( + c + for c in attempt + if isinstance(c.result, ClusterPipeline.ERRORS_ALLOW_RETRY) + ), + key=lambda x: x.position, + ) + if attempt and allow_redirections: + # RETRY MAGIC HAPPENS HERE! + # send these remaing commands one at a time using `execute_command` + # in the main client. This keeps our retry logic + # in one place mostly, + # and allows us to be more confident in correctness of behavior. + # at this point any speed gains from pipelining have been lost + # anyway, so we might as well make the best + # attempt to get the correct behavior. + # + # The client command will handle retries for each + # individual command sequentially as we pass each + # one into `execute_command`. Any exceptions + # that bubble out should only appear once all + # retries have been exhausted. + # + # If a lot of commands have failed, we'll be setting the + # flag to rebuild the slots table from scratch. + # So MOVED errors should correct themselves fairly quickly. + self.reinitialize_counter += 1 + if self._should_reinitialized(): + self.nodes_manager.initialize() + if is_default_node: + self.replace_default_node() + for c in attempt: + try: + # send each command individually like we + # do in the main client. + c.result = super().execute_command(*c.args, **c.options) + except RedisError as e: + c.result = e + + # turn the response back into a simple flat array that corresponds + # to the sequence of commands issued in the stack in pipeline.execute() + response = [] + for c in sorted(stack, key=lambda x: x.position): + if c.args[0] in self.cluster_response_callbacks: + c.result = self.cluster_response_callbacks[c.args[0]]( + c.result, **c.options + ) + response.append(c.result) + + if raise_on_error: + self.raise_first_error(stack) + + return response + + def _fail_on_redirect(self, allow_redirections): + """ """ + if not allow_redirections: + raise RedisClusterException( + "ASK & MOVED redirection not allowed in this pipeline" + ) + + def exists(self, *keys): + return self.execute_command("EXISTS", *keys) + + def eval(self): + """ """ + raise RedisClusterException("method eval() is not implemented") + + def multi(self): + """ """ + raise RedisClusterException("method multi() is not implemented") + + def immediate_execute_command(self, *args, **options): + """ """ + raise RedisClusterException( + "method immediate_execute_command() is not implemented" + ) + + def _execute_transaction(self, *args, **kwargs): + """ """ + raise RedisClusterException("method _execute_transaction() is not implemented") + + def load_scripts(self): + """ """ + raise RedisClusterException("method load_scripts() is not implemented") + + def watch(self, *names): + """ """ + raise RedisClusterException("method watch() is not implemented") + + def unwatch(self): + """ """ + raise RedisClusterException("method unwatch() is not implemented") + + def script_load_for_pipeline(self, *args, **kwargs): + """ """ + raise RedisClusterException( + "method script_load_for_pipeline() is not implemented" + ) + + def delete(self, *names): + """ + "Delete a key specified by ``names``" + """ + if len(names) != 1: + raise RedisClusterException( + "deleting multiple keys is not implemented in pipeline command" + ) + + return self.execute_command("DEL", names[0]) + + def unlink(self, *names): + """ + "Unlink a key specified by ``names``" + """ + if len(names) != 1: + raise RedisClusterException( + "unlinking multiple keys is not implemented in pipeline command" + ) + + return self.execute_command("UNLINK", names[0]) + + +def block_pipeline_command(name: str) -> Callable[..., Any]: + """ + Prints error because some pipelined commands should + be blocked when running in cluster-mode + """ + + def inner(*args, **kwargs): + raise RedisClusterException( + f"ERROR: Calling pipelined function {name} is blocked " + f"when running redis in cluster mode..." + ) + + return inner + + +# Blocked pipeline commands +PIPELINE_BLOCKED_COMMANDS = ( + "BGREWRITEAOF", + "BGSAVE", + "BITOP", + "BRPOPLPUSH", + "CLIENT GETNAME", + "CLIENT KILL", + "CLIENT LIST", + "CLIENT SETNAME", + "CLIENT", + "CONFIG GET", + "CONFIG RESETSTAT", + "CONFIG REWRITE", + "CONFIG SET", + "CONFIG", + "DBSIZE", + "ECHO", + "EVALSHA", + "FLUSHALL", + "FLUSHDB", + "INFO", + "KEYS", + "LASTSAVE", + "MGET", + "MGET NONATOMIC", + "MOVE", + "MSET", + "MSET NONATOMIC", + "MSETNX", + "PFCOUNT", + "PFMERGE", + "PING", + "PUBLISH", + "RANDOMKEY", + "READONLY", + "READWRITE", + "RENAME", + "RENAMENX", + "RPOPLPUSH", + "SAVE", + "SCAN", + "SCRIPT EXISTS", + "SCRIPT FLUSH", + "SCRIPT KILL", + "SCRIPT LOAD", + "SCRIPT", + "SDIFF", + "SDIFFSTORE", + "SENTINEL GET MASTER ADDR BY NAME", + "SENTINEL MASTER", + "SENTINEL MASTERS", + "SENTINEL MONITOR", + "SENTINEL REMOVE", + "SENTINEL SENTINELS", + "SENTINEL SET", + "SENTINEL SLAVES", + "SENTINEL", + "SHUTDOWN", + "SINTER", + "SINTERSTORE", + "SLAVEOF", + "SLOWLOG GET", + "SLOWLOG LEN", + "SLOWLOG RESET", + "SLOWLOG", + "SMOVE", + "SORT", + "SUNION", + "SUNIONSTORE", + "TIME", +) +for command in PIPELINE_BLOCKED_COMMANDS: + command = command.replace(" ", "_").lower() + + setattr(ClusterPipeline, command, block_pipeline_command(command)) + + +class PipelineCommand: + """ """ + + def __init__(self, args, options=None, position=None): + self.args = args + if options is None: + options = {} + self.options = options + self.position = position + self.result = None + self.node = None + self.asking = False + + +class NodeCommands: + """ """ + + def __init__(self, parse_response, connection_pool, connection): + """ """ + self.parse_response = parse_response + self.connection_pool = connection_pool + self.connection = connection + self.commands = [] + + def append(self, c): + """ """ + self.commands.append(c) + + def write(self): + """ + Code borrowed from Redis so it can be fixed + """ + connection = self.connection + commands = self.commands + + # We are going to clobber the commands with the write, so go ahead + # and ensure that nothing is sitting there from a previous run. + for c in commands: + c.result = None + + # build up all commands into a single request to increase network perf + # send all the commands and catch connection and timeout errors. + try: + connection.send_packed_command( + connection.pack_commands([c.args for c in commands]) + ) + except (ConnectionError, TimeoutError) as e: + for c in commands: + c.result = e + + def read(self): + """ """ + connection = self.connection + for c in self.commands: + # if there is a result on this command, + # it means we ran into an exception + # like a connection error. Trying to parse + # a response on a connection that + # is no longer open will result in a + # connection error raised by redis-py. + # but redis-py doesn't check in parse_response + # that the sock object is + # still set and if you try to + # read from a closed connection, it will + # result in an AttributeError because + # it will do a readline() call on None. + # This can have all kinds of nasty side-effects. + # Treating this case as a connection error + # is fine because it will dump + # the connection object back into the + # pool and on the next write, it will + # explicitly open the connection and all will be well. + if c.result is None: + try: + c.result = self.parse_response(connection, c.args[0], **c.options) + except (ConnectionError, TimeoutError) as e: + for c in self.commands: + c.result = e + return + except RedisError: + c.result = sys.exc_info()[1] diff --git a/.venv/Lib/site-packages/redis/commands/__init__.py b/.venv/Lib/site-packages/redis/commands/__init__.py new file mode 100644 index 00000000..a94d9764 --- /dev/null +++ b/.venv/Lib/site-packages/redis/commands/__init__.py @@ -0,0 +1,18 @@ +from .cluster import READ_COMMANDS, AsyncRedisClusterCommands, RedisClusterCommands +from .core import AsyncCoreCommands, CoreCommands +from .helpers import list_or_args +from .redismodules import AsyncRedisModuleCommands, RedisModuleCommands +from .sentinel import AsyncSentinelCommands, SentinelCommands + +__all__ = [ + "AsyncCoreCommands", + "AsyncRedisClusterCommands", + "AsyncRedisModuleCommands", + "AsyncSentinelCommands", + "CoreCommands", + "READ_COMMANDS", + "RedisClusterCommands", + "RedisModuleCommands", + "SentinelCommands", + "list_or_args", +] diff --git a/.venv/Lib/site-packages/redis/commands/__pycache__/__init__.cpython-311.pyc b/.venv/Lib/site-packages/redis/commands/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 00000000..7e1bfc79 Binary files /dev/null and b/.venv/Lib/site-packages/redis/commands/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/redis/commands/__pycache__/cluster.cpython-311.pyc b/.venv/Lib/site-packages/redis/commands/__pycache__/cluster.cpython-311.pyc new file mode 100644 index 00000000..c1638fe3 Binary files /dev/null and b/.venv/Lib/site-packages/redis/commands/__pycache__/cluster.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/redis/commands/__pycache__/core.cpython-311.pyc b/.venv/Lib/site-packages/redis/commands/__pycache__/core.cpython-311.pyc new file mode 100644 index 00000000..2bcbe180 Binary files /dev/null and b/.venv/Lib/site-packages/redis/commands/__pycache__/core.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/redis/commands/__pycache__/helpers.cpython-311.pyc b/.venv/Lib/site-packages/redis/commands/__pycache__/helpers.cpython-311.pyc new file mode 100644 index 00000000..bb62bef2 Binary files /dev/null and b/.venv/Lib/site-packages/redis/commands/__pycache__/helpers.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/redis/commands/__pycache__/redismodules.cpython-311.pyc b/.venv/Lib/site-packages/redis/commands/__pycache__/redismodules.cpython-311.pyc new file mode 100644 index 00000000..2d8ab6a7 Binary files /dev/null and b/.venv/Lib/site-packages/redis/commands/__pycache__/redismodules.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/redis/commands/__pycache__/sentinel.cpython-311.pyc b/.venv/Lib/site-packages/redis/commands/__pycache__/sentinel.cpython-311.pyc new file mode 100644 index 00000000..b875aa16 Binary files /dev/null and b/.venv/Lib/site-packages/redis/commands/__pycache__/sentinel.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/redis/commands/bf/__init__.py b/.venv/Lib/site-packages/redis/commands/bf/__init__.py new file mode 100644 index 00000000..959358f8 --- /dev/null +++ b/.venv/Lib/site-packages/redis/commands/bf/__init__.py @@ -0,0 +1,253 @@ +from redis._parsers.helpers import bool_ok + +from ..helpers import get_protocol_version, parse_to_list +from .commands import * # noqa +from .info import BFInfo, CFInfo, CMSInfo, TDigestInfo, TopKInfo + + +class AbstractBloom(object): + """ + The client allows to interact with RedisBloom and use all of + it's functionality. + + - BF for Bloom Filter + - CF for Cuckoo Filter + - CMS for Count-Min Sketch + - TOPK for TopK Data Structure + - TDIGEST for estimate rank statistics + """ + + @staticmethod + def append_items(params, items): + """Append ITEMS to params.""" + params.extend(["ITEMS"]) + params += items + + @staticmethod + def append_error(params, error): + """Append ERROR to params.""" + if error is not None: + params.extend(["ERROR", error]) + + @staticmethod + def append_capacity(params, capacity): + """Append CAPACITY to params.""" + if capacity is not None: + params.extend(["CAPACITY", capacity]) + + @staticmethod + def append_expansion(params, expansion): + """Append EXPANSION to params.""" + if expansion is not None: + params.extend(["EXPANSION", expansion]) + + @staticmethod + def append_no_scale(params, noScale): + """Append NONSCALING tag to params.""" + if noScale is not None: + params.extend(["NONSCALING"]) + + @staticmethod + def append_weights(params, weights): + """Append WEIGHTS to params.""" + if len(weights) > 0: + params.append("WEIGHTS") + params += weights + + @staticmethod + def append_no_create(params, noCreate): + """Append NOCREATE tag to params.""" + if noCreate is not None: + params.extend(["NOCREATE"]) + + @staticmethod + def append_items_and_increments(params, items, increments): + """Append pairs of items and increments to params.""" + for i in range(len(items)): + params.append(items[i]) + params.append(increments[i]) + + @staticmethod + def append_values_and_weights(params, items, weights): + """Append pairs of items and weights to params.""" + for i in range(len(items)): + params.append(items[i]) + params.append(weights[i]) + + @staticmethod + def append_max_iterations(params, max_iterations): + """Append MAXITERATIONS to params.""" + if max_iterations is not None: + params.extend(["MAXITERATIONS", max_iterations]) + + @staticmethod + def append_bucket_size(params, bucket_size): + """Append BUCKETSIZE to params.""" + if bucket_size is not None: + params.extend(["BUCKETSIZE", bucket_size]) + + +class CMSBloom(CMSCommands, AbstractBloom): + def __init__(self, client, **kwargs): + """Create a new RedisBloom client.""" + # Set the module commands' callbacks + _MODULE_CALLBACKS = { + CMS_INITBYDIM: bool_ok, + CMS_INITBYPROB: bool_ok, + # CMS_INCRBY: spaceHolder, + # CMS_QUERY: spaceHolder, + CMS_MERGE: bool_ok, + } + + _RESP2_MODULE_CALLBACKS = { + CMS_INFO: CMSInfo, + } + _RESP3_MODULE_CALLBACKS = {} + + self.client = client + self.commandmixin = CMSCommands + self.execute_command = client.execute_command + + if get_protocol_version(self.client) in ["3", 3]: + _MODULE_CALLBACKS.update(_RESP3_MODULE_CALLBACKS) + else: + _MODULE_CALLBACKS.update(_RESP2_MODULE_CALLBACKS) + + for k, v in _MODULE_CALLBACKS.items(): + self.client.set_response_callback(k, v) + + +class TOPKBloom(TOPKCommands, AbstractBloom): + def __init__(self, client, **kwargs): + """Create a new RedisBloom client.""" + # Set the module commands' callbacks + _MODULE_CALLBACKS = { + TOPK_RESERVE: bool_ok, + # TOPK_QUERY: spaceHolder, + # TOPK_COUNT: spaceHolder, + } + + _RESP2_MODULE_CALLBACKS = { + TOPK_ADD: parse_to_list, + TOPK_INCRBY: parse_to_list, + TOPK_INFO: TopKInfo, + TOPK_LIST: parse_to_list, + } + _RESP3_MODULE_CALLBACKS = {} + + self.client = client + self.commandmixin = TOPKCommands + self.execute_command = client.execute_command + + if get_protocol_version(self.client) in ["3", 3]: + _MODULE_CALLBACKS.update(_RESP3_MODULE_CALLBACKS) + else: + _MODULE_CALLBACKS.update(_RESP2_MODULE_CALLBACKS) + + for k, v in _MODULE_CALLBACKS.items(): + self.client.set_response_callback(k, v) + + +class CFBloom(CFCommands, AbstractBloom): + def __init__(self, client, **kwargs): + """Create a new RedisBloom client.""" + # Set the module commands' callbacks + _MODULE_CALLBACKS = { + CF_RESERVE: bool_ok, + # CF_ADD: spaceHolder, + # CF_ADDNX: spaceHolder, + # CF_INSERT: spaceHolder, + # CF_INSERTNX: spaceHolder, + # CF_EXISTS: spaceHolder, + # CF_DEL: spaceHolder, + # CF_COUNT: spaceHolder, + # CF_SCANDUMP: spaceHolder, + # CF_LOADCHUNK: spaceHolder, + } + + _RESP2_MODULE_CALLBACKS = { + CF_INFO: CFInfo, + } + _RESP3_MODULE_CALLBACKS = {} + + self.client = client + self.commandmixin = CFCommands + self.execute_command = client.execute_command + + if get_protocol_version(self.client) in ["3", 3]: + _MODULE_CALLBACKS.update(_RESP3_MODULE_CALLBACKS) + else: + _MODULE_CALLBACKS.update(_RESP2_MODULE_CALLBACKS) + + for k, v in _MODULE_CALLBACKS.items(): + self.client.set_response_callback(k, v) + + +class TDigestBloom(TDigestCommands, AbstractBloom): + def __init__(self, client, **kwargs): + """Create a new RedisBloom client.""" + # Set the module commands' callbacks + _MODULE_CALLBACKS = { + TDIGEST_CREATE: bool_ok, + # TDIGEST_RESET: bool_ok, + # TDIGEST_ADD: spaceHolder, + # TDIGEST_MERGE: spaceHolder, + } + + _RESP2_MODULE_CALLBACKS = { + TDIGEST_BYRANK: parse_to_list, + TDIGEST_BYREVRANK: parse_to_list, + TDIGEST_CDF: parse_to_list, + TDIGEST_INFO: TDigestInfo, + TDIGEST_MIN: float, + TDIGEST_MAX: float, + TDIGEST_TRIMMED_MEAN: float, + TDIGEST_QUANTILE: parse_to_list, + } + _RESP3_MODULE_CALLBACKS = {} + + self.client = client + self.commandmixin = TDigestCommands + self.execute_command = client.execute_command + + if get_protocol_version(self.client) in ["3", 3]: + _MODULE_CALLBACKS.update(_RESP3_MODULE_CALLBACKS) + else: + _MODULE_CALLBACKS.update(_RESP2_MODULE_CALLBACKS) + + for k, v in _MODULE_CALLBACKS.items(): + self.client.set_response_callback(k, v) + + +class BFBloom(BFCommands, AbstractBloom): + def __init__(self, client, **kwargs): + """Create a new RedisBloom client.""" + # Set the module commands' callbacks + _MODULE_CALLBACKS = { + BF_RESERVE: bool_ok, + # BF_ADD: spaceHolder, + # BF_MADD: spaceHolder, + # BF_INSERT: spaceHolder, + # BF_EXISTS: spaceHolder, + # BF_MEXISTS: spaceHolder, + # BF_SCANDUMP: spaceHolder, + # BF_LOADCHUNK: spaceHolder, + # BF_CARD: spaceHolder, + } + + _RESP2_MODULE_CALLBACKS = { + BF_INFO: BFInfo, + } + _RESP3_MODULE_CALLBACKS = {} + + self.client = client + self.commandmixin = BFCommands + self.execute_command = client.execute_command + + if get_protocol_version(self.client) in ["3", 3]: + _MODULE_CALLBACKS.update(_RESP3_MODULE_CALLBACKS) + else: + _MODULE_CALLBACKS.update(_RESP2_MODULE_CALLBACKS) + + for k, v in _MODULE_CALLBACKS.items(): + self.client.set_response_callback(k, v) diff --git a/.venv/Lib/site-packages/redis/commands/bf/__pycache__/__init__.cpython-311.pyc b/.venv/Lib/site-packages/redis/commands/bf/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 00000000..352d33c5 Binary files /dev/null and b/.venv/Lib/site-packages/redis/commands/bf/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/redis/commands/bf/__pycache__/commands.cpython-311.pyc b/.venv/Lib/site-packages/redis/commands/bf/__pycache__/commands.cpython-311.pyc new file mode 100644 index 00000000..ee4fa6bc Binary files /dev/null and b/.venv/Lib/site-packages/redis/commands/bf/__pycache__/commands.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/redis/commands/bf/__pycache__/info.cpython-311.pyc b/.venv/Lib/site-packages/redis/commands/bf/__pycache__/info.cpython-311.pyc new file mode 100644 index 00000000..a90459ea Binary files /dev/null and b/.venv/Lib/site-packages/redis/commands/bf/__pycache__/info.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/redis/commands/bf/commands.py b/.venv/Lib/site-packages/redis/commands/bf/commands.py new file mode 100644 index 00000000..447f8445 --- /dev/null +++ b/.venv/Lib/site-packages/redis/commands/bf/commands.py @@ -0,0 +1,542 @@ +from redis.client import NEVER_DECODE +from redis.exceptions import ModuleError +from redis.utils import HIREDIS_AVAILABLE, deprecated_function + +BF_RESERVE = "BF.RESERVE" +BF_ADD = "BF.ADD" +BF_MADD = "BF.MADD" +BF_INSERT = "BF.INSERT" +BF_EXISTS = "BF.EXISTS" +BF_MEXISTS = "BF.MEXISTS" +BF_SCANDUMP = "BF.SCANDUMP" +BF_LOADCHUNK = "BF.LOADCHUNK" +BF_INFO = "BF.INFO" +BF_CARD = "BF.CARD" + +CF_RESERVE = "CF.RESERVE" +CF_ADD = "CF.ADD" +CF_ADDNX = "CF.ADDNX" +CF_INSERT = "CF.INSERT" +CF_INSERTNX = "CF.INSERTNX" +CF_EXISTS = "CF.EXISTS" +CF_MEXISTS = "CF.MEXISTS" +CF_DEL = "CF.DEL" +CF_COUNT = "CF.COUNT" +CF_SCANDUMP = "CF.SCANDUMP" +CF_LOADCHUNK = "CF.LOADCHUNK" +CF_INFO = "CF.INFO" + +CMS_INITBYDIM = "CMS.INITBYDIM" +CMS_INITBYPROB = "CMS.INITBYPROB" +CMS_INCRBY = "CMS.INCRBY" +CMS_QUERY = "CMS.QUERY" +CMS_MERGE = "CMS.MERGE" +CMS_INFO = "CMS.INFO" + +TOPK_RESERVE = "TOPK.RESERVE" +TOPK_ADD = "TOPK.ADD" +TOPK_INCRBY = "TOPK.INCRBY" +TOPK_QUERY = "TOPK.QUERY" +TOPK_COUNT = "TOPK.COUNT" +TOPK_LIST = "TOPK.LIST" +TOPK_INFO = "TOPK.INFO" + +TDIGEST_CREATE = "TDIGEST.CREATE" +TDIGEST_RESET = "TDIGEST.RESET" +TDIGEST_ADD = "TDIGEST.ADD" +TDIGEST_MERGE = "TDIGEST.MERGE" +TDIGEST_CDF = "TDIGEST.CDF" +TDIGEST_QUANTILE = "TDIGEST.QUANTILE" +TDIGEST_MIN = "TDIGEST.MIN" +TDIGEST_MAX = "TDIGEST.MAX" +TDIGEST_INFO = "TDIGEST.INFO" +TDIGEST_TRIMMED_MEAN = "TDIGEST.TRIMMED_MEAN" +TDIGEST_RANK = "TDIGEST.RANK" +TDIGEST_REVRANK = "TDIGEST.REVRANK" +TDIGEST_BYRANK = "TDIGEST.BYRANK" +TDIGEST_BYREVRANK = "TDIGEST.BYREVRANK" + + +class BFCommands: + """Bloom Filter commands.""" + + def create(self, key, errorRate, capacity, expansion=None, noScale=None): + """ + Create a new Bloom Filter `key` with desired probability of false positives + `errorRate` expected entries to be inserted as `capacity`. + Default expansion value is 2. By default, filter is auto-scaling. + For more information see `BF.RESERVE `_. + """ # noqa + params = [key, errorRate, capacity] + self.append_expansion(params, expansion) + self.append_no_scale(params, noScale) + return self.execute_command(BF_RESERVE, *params) + + reserve = create + + def add(self, key, item): + """ + Add to a Bloom Filter `key` an `item`. + For more information see `BF.ADD `_. + """ # noqa + return self.execute_command(BF_ADD, key, item) + + def madd(self, key, *items): + """ + Add to a Bloom Filter `key` multiple `items`. + For more information see `BF.MADD `_. + """ # noqa + return self.execute_command(BF_MADD, key, *items) + + def insert( + self, + key, + items, + capacity=None, + error=None, + noCreate=None, + expansion=None, + noScale=None, + ): + """ + Add to a Bloom Filter `key` multiple `items`. + + If `nocreate` remain `None` and `key` does not exist, a new Bloom Filter + `key` will be created with desired probability of false positives `errorRate` + and expected entries to be inserted as `size`. + For more information see `BF.INSERT `_. + """ # noqa + params = [key] + self.append_capacity(params, capacity) + self.append_error(params, error) + self.append_expansion(params, expansion) + self.append_no_create(params, noCreate) + self.append_no_scale(params, noScale) + self.append_items(params, items) + + return self.execute_command(BF_INSERT, *params) + + def exists(self, key, item): + """ + Check whether an `item` exists in Bloom Filter `key`. + For more information see `BF.EXISTS `_. + """ # noqa + return self.execute_command(BF_EXISTS, key, item) + + def mexists(self, key, *items): + """ + Check whether `items` exist in Bloom Filter `key`. + For more information see `BF.MEXISTS `_. + """ # noqa + return self.execute_command(BF_MEXISTS, key, *items) + + def scandump(self, key, iter): + """ + Begin an incremental save of the bloom filter `key`. + + This is useful for large bloom filters which cannot fit into the normal SAVE and RESTORE model. + The first time this command is called, the value of `iter` should be 0. + This command will return successive (iter, data) pairs until (0, NULL) to indicate completion. + For more information see `BF.SCANDUMP `_. + """ # noqa + if HIREDIS_AVAILABLE: + raise ModuleError("This command cannot be used when hiredis is available.") + + params = [key, iter] + options = {} + options[NEVER_DECODE] = [] + return self.execute_command(BF_SCANDUMP, *params, **options) + + def loadchunk(self, key, iter, data): + """ + Restore a filter previously saved using SCANDUMP. + + See the SCANDUMP command for example usage. + This command will overwrite any bloom filter stored under key. + Ensure that the bloom filter will not be modified between invocations. + For more information see `BF.LOADCHUNK `_. + """ # noqa + return self.execute_command(BF_LOADCHUNK, key, iter, data) + + def info(self, key): + """ + Return capacity, size, number of filters, number of items inserted, and expansion rate. + For more information see `BF.INFO `_. + """ # noqa + return self.execute_command(BF_INFO, key) + + def card(self, key): + """ + Returns the cardinality of a Bloom filter - number of items that were added to a Bloom filter and detected as unique + (items that caused at least one bit to be set in at least one sub-filter). + For more information see `BF.CARD `_. + """ # noqa + return self.execute_command(BF_CARD, key) + + +class CFCommands: + """Cuckoo Filter commands.""" + + def create( + self, key, capacity, expansion=None, bucket_size=None, max_iterations=None + ): + """ + Create a new Cuckoo Filter `key` an initial `capacity` items. + For more information see `CF.RESERVE `_. + """ # noqa + params = [key, capacity] + self.append_expansion(params, expansion) + self.append_bucket_size(params, bucket_size) + self.append_max_iterations(params, max_iterations) + return self.execute_command(CF_RESERVE, *params) + + reserve = create + + def add(self, key, item): + """ + Add an `item` to a Cuckoo Filter `key`. + For more information see `CF.ADD `_. + """ # noqa + return self.execute_command(CF_ADD, key, item) + + def addnx(self, key, item): + """ + Add an `item` to a Cuckoo Filter `key` only if item does not yet exist. + Command might be slower that `add`. + For more information see `CF.ADDNX `_. + """ # noqa + return self.execute_command(CF_ADDNX, key, item) + + def insert(self, key, items, capacity=None, nocreate=None): + """ + Add multiple `items` to a Cuckoo Filter `key`, allowing the filter + to be created with a custom `capacity` if it does not yet exist. + `items` must be provided as a list. + For more information see `CF.INSERT `_. + """ # noqa + params = [key] + self.append_capacity(params, capacity) + self.append_no_create(params, nocreate) + self.append_items(params, items) + return self.execute_command(CF_INSERT, *params) + + def insertnx(self, key, items, capacity=None, nocreate=None): + """ + Add multiple `items` to a Cuckoo Filter `key` only if they do not exist yet, + allowing the filter to be created with a custom `capacity` if it does not yet exist. + `items` must be provided as a list. + For more information see `CF.INSERTNX `_. + """ # noqa + params = [key] + self.append_capacity(params, capacity) + self.append_no_create(params, nocreate) + self.append_items(params, items) + return self.execute_command(CF_INSERTNX, *params) + + def exists(self, key, item): + """ + Check whether an `item` exists in Cuckoo Filter `key`. + For more information see `CF.EXISTS `_. + """ # noqa + return self.execute_command(CF_EXISTS, key, item) + + def mexists(self, key, *items): + """ + Check whether an `items` exist in Cuckoo Filter `key`. + For more information see `CF.MEXISTS `_. + """ # noqa + return self.execute_command(CF_MEXISTS, key, *items) + + def delete(self, key, item): + """ + Delete `item` from `key`. + For more information see `CF.DEL `_. + """ # noqa + return self.execute_command(CF_DEL, key, item) + + def count(self, key, item): + """ + Return the number of times an `item` may be in the `key`. + For more information see `CF.COUNT `_. + """ # noqa + return self.execute_command(CF_COUNT, key, item) + + def scandump(self, key, iter): + """ + Begin an incremental save of the Cuckoo filter `key`. + + This is useful for large Cuckoo filters which cannot fit into the normal + SAVE and RESTORE model. + The first time this command is called, the value of `iter` should be 0. + This command will return successive (iter, data) pairs until + (0, NULL) to indicate completion. + For more information see `CF.SCANDUMP `_. + """ # noqa + return self.execute_command(CF_SCANDUMP, key, iter) + + def loadchunk(self, key, iter, data): + """ + Restore a filter previously saved using SCANDUMP. See the SCANDUMP command for example usage. + + This command will overwrite any Cuckoo filter stored under key. + Ensure that the Cuckoo filter will not be modified between invocations. + For more information see `CF.LOADCHUNK `_. + """ # noqa + return self.execute_command(CF_LOADCHUNK, key, iter, data) + + def info(self, key): + """ + Return size, number of buckets, number of filter, number of items inserted, + number of items deleted, bucket size, expansion rate, and max iteration. + For more information see `CF.INFO `_. + """ # noqa + return self.execute_command(CF_INFO, key) + + +class TOPKCommands: + """TOP-k Filter commands.""" + + def reserve(self, key, k, width, depth, decay): + """ + Create a new Top-K Filter `key` with desired probability of false + positives `errorRate` expected entries to be inserted as `size`. + For more information see `TOPK.RESERVE `_. + """ # noqa + return self.execute_command(TOPK_RESERVE, key, k, width, depth, decay) + + def add(self, key, *items): + """ + Add one `item` or more to a Top-K Filter `key`. + For more information see `TOPK.ADD `_. + """ # noqa + return self.execute_command(TOPK_ADD, key, *items) + + def incrby(self, key, items, increments): + """ + Add/increase `items` to a Top-K Sketch `key` by ''increments''. + Both `items` and `increments` are lists. + For more information see `TOPK.INCRBY `_. + + Example: + + >>> topkincrby('A', ['foo'], [1]) + """ # noqa + params = [key] + self.append_items_and_increments(params, items, increments) + return self.execute_command(TOPK_INCRBY, *params) + + def query(self, key, *items): + """ + Check whether one `item` or more is a Top-K item at `key`. + For more information see `TOPK.QUERY `_. + """ # noqa + return self.execute_command(TOPK_QUERY, key, *items) + + @deprecated_function(version="4.4.0", reason="deprecated since redisbloom 2.4.0") + def count(self, key, *items): + """ + Return count for one `item` or more from `key`. + For more information see `TOPK.COUNT `_. + """ # noqa + return self.execute_command(TOPK_COUNT, key, *items) + + def list(self, key, withcount=False): + """ + Return full list of items in Top-K list of `key`. + If `withcount` set to True, return full list of items + with probabilistic count in Top-K list of `key`. + For more information see `TOPK.LIST `_. + """ # noqa + params = [key] + if withcount: + params.append("WITHCOUNT") + return self.execute_command(TOPK_LIST, *params) + + def info(self, key): + """ + Return k, width, depth and decay values of `key`. + For more information see `TOPK.INFO `_. + """ # noqa + return self.execute_command(TOPK_INFO, key) + + +class TDigestCommands: + def create(self, key, compression=100): + """ + Allocate the memory and initialize the t-digest. + For more information see `TDIGEST.CREATE `_. + """ # noqa + return self.execute_command(TDIGEST_CREATE, key, "COMPRESSION", compression) + + def reset(self, key): + """ + Reset the sketch `key` to zero - empty out the sketch and re-initialize it. + For more information see `TDIGEST.RESET `_. + """ # noqa + return self.execute_command(TDIGEST_RESET, key) + + def add(self, key, values): + """ + Adds one or more observations to a t-digest sketch `key`. + + For more information see `TDIGEST.ADD `_. + """ # noqa + return self.execute_command(TDIGEST_ADD, key, *values) + + def merge(self, destination_key, num_keys, *keys, compression=None, override=False): + """ + Merges all of the values from `keys` to 'destination-key' sketch. + It is mandatory to provide the `num_keys` before passing the input keys and + the other (optional) arguments. + If `destination_key` already exists its values are merged with the input keys. + If you wish to override the destination key contents use the `OVERRIDE` parameter. + + For more information see `TDIGEST.MERGE `_. + """ # noqa + params = [destination_key, num_keys, *keys] + if compression is not None: + params.extend(["COMPRESSION", compression]) + if override: + params.append("OVERRIDE") + return self.execute_command(TDIGEST_MERGE, *params) + + def min(self, key): + """ + Return minimum value from the sketch `key`. Will return DBL_MAX if the sketch is empty. + For more information see `TDIGEST.MIN `_. + """ # noqa + return self.execute_command(TDIGEST_MIN, key) + + def max(self, key): + """ + Return maximum value from the sketch `key`. Will return DBL_MIN if the sketch is empty. + For more information see `TDIGEST.MAX `_. + """ # noqa + return self.execute_command(TDIGEST_MAX, key) + + def quantile(self, key, quantile, *quantiles): + """ + Returns estimates of one or more cutoffs such that a specified fraction of the + observations added to this t-digest would be less than or equal to each of the + specified cutoffs. (Multiple quantiles can be returned with one call) + For more information see `TDIGEST.QUANTILE `_. + """ # noqa + return self.execute_command(TDIGEST_QUANTILE, key, quantile, *quantiles) + + def cdf(self, key, value, *values): + """ + Return double fraction of all points added which are <= value. + For more information see `TDIGEST.CDF `_. + """ # noqa + return self.execute_command(TDIGEST_CDF, key, value, *values) + + def info(self, key): + """ + Return Compression, Capacity, Merged Nodes, Unmerged Nodes, Merged Weight, Unmerged Weight + and Total Compressions. + For more information see `TDIGEST.INFO `_. + """ # noqa + return self.execute_command(TDIGEST_INFO, key) + + def trimmed_mean(self, key, low_cut_quantile, high_cut_quantile): + """ + Return mean value from the sketch, excluding observation values outside + the low and high cutoff quantiles. + For more information see `TDIGEST.TRIMMED_MEAN `_. + """ # noqa + return self.execute_command( + TDIGEST_TRIMMED_MEAN, key, low_cut_quantile, high_cut_quantile + ) + + def rank(self, key, value, *values): + """ + Retrieve the estimated rank of value (the number of observations in the sketch + that are smaller than value + half the number of observations that are equal to value). + + For more information see `TDIGEST.RANK `_. + """ # noqa + return self.execute_command(TDIGEST_RANK, key, value, *values) + + def revrank(self, key, value, *values): + """ + Retrieve the estimated rank of value (the number of observations in the sketch + that are larger than value + half the number of observations that are equal to value). + + For more information see `TDIGEST.REVRANK `_. + """ # noqa + return self.execute_command(TDIGEST_REVRANK, key, value, *values) + + def byrank(self, key, rank, *ranks): + """ + Retrieve an estimation of the value with the given rank. + + For more information see `TDIGEST.BY_RANK `_. + """ # noqa + return self.execute_command(TDIGEST_BYRANK, key, rank, *ranks) + + def byrevrank(self, key, rank, *ranks): + """ + Retrieve an estimation of the value with the given reverse rank. + + For more information see `TDIGEST.BY_REVRANK `_. + """ # noqa + return self.execute_command(TDIGEST_BYREVRANK, key, rank, *ranks) + + +class CMSCommands: + """Count-Min Sketch Commands""" + + def initbydim(self, key, width, depth): + """ + Initialize a Count-Min Sketch `key` to dimensions (`width`, `depth`) specified by user. + For more information see `CMS.INITBYDIM `_. + """ # noqa + return self.execute_command(CMS_INITBYDIM, key, width, depth) + + def initbyprob(self, key, error, probability): + """ + Initialize a Count-Min Sketch `key` to characteristics (`error`, `probability`) specified by user. + For more information see `CMS.INITBYPROB `_. + """ # noqa + return self.execute_command(CMS_INITBYPROB, key, error, probability) + + def incrby(self, key, items, increments): + """ + Add/increase `items` to a Count-Min Sketch `key` by ''increments''. + Both `items` and `increments` are lists. + For more information see `CMS.INCRBY `_. + + Example: + + >>> cmsincrby('A', ['foo'], [1]) + """ # noqa + params = [key] + self.append_items_and_increments(params, items, increments) + return self.execute_command(CMS_INCRBY, *params) + + def query(self, key, *items): + """ + Return count for an `item` from `key`. Multiple items can be queried with one call. + For more information see `CMS.QUERY `_. + """ # noqa + return self.execute_command(CMS_QUERY, key, *items) + + def merge(self, destKey, numKeys, srcKeys, weights=[]): + """ + Merge `numKeys` of sketches into `destKey`. Sketches specified in `srcKeys`. + All sketches must have identical width and depth. + `Weights` can be used to multiply certain sketches. Default weight is 1. + Both `srcKeys` and `weights` are lists. + For more information see `CMS.MERGE `_. + """ # noqa + params = [destKey, numKeys] + params += srcKeys + self.append_weights(params, weights) + return self.execute_command(CMS_MERGE, *params) + + def info(self, key): + """ + Return width, depth and total count of the sketch. + For more information see `CMS.INFO `_. + """ # noqa + return self.execute_command(CMS_INFO, key) diff --git a/.venv/Lib/site-packages/redis/commands/bf/info.py b/.venv/Lib/site-packages/redis/commands/bf/info.py new file mode 100644 index 00000000..e1f02086 --- /dev/null +++ b/.venv/Lib/site-packages/redis/commands/bf/info.py @@ -0,0 +1,120 @@ +from ..helpers import nativestr + + +class BFInfo(object): + capacity = None + size = None + filterNum = None + insertedNum = None + expansionRate = None + + def __init__(self, args): + response = dict(zip(map(nativestr, args[::2]), args[1::2])) + self.capacity = response["Capacity"] + self.size = response["Size"] + self.filterNum = response["Number of filters"] + self.insertedNum = response["Number of items inserted"] + self.expansionRate = response["Expansion rate"] + + def get(self, item): + try: + return self.__getitem__(item) + except AttributeError: + return None + + def __getitem__(self, item): + return getattr(self, item) + + +class CFInfo(object): + size = None + bucketNum = None + filterNum = None + insertedNum = None + deletedNum = None + bucketSize = None + expansionRate = None + maxIteration = None + + def __init__(self, args): + response = dict(zip(map(nativestr, args[::2]), args[1::2])) + self.size = response["Size"] + self.bucketNum = response["Number of buckets"] + self.filterNum = response["Number of filters"] + self.insertedNum = response["Number of items inserted"] + self.deletedNum = response["Number of items deleted"] + self.bucketSize = response["Bucket size"] + self.expansionRate = response["Expansion rate"] + self.maxIteration = response["Max iterations"] + + def get(self, item): + try: + return self.__getitem__(item) + except AttributeError: + return None + + def __getitem__(self, item): + return getattr(self, item) + + +class CMSInfo(object): + width = None + depth = None + count = None + + def __init__(self, args): + response = dict(zip(map(nativestr, args[::2]), args[1::2])) + self.width = response["width"] + self.depth = response["depth"] + self.count = response["count"] + + def __getitem__(self, item): + return getattr(self, item) + + +class TopKInfo(object): + k = None + width = None + depth = None + decay = None + + def __init__(self, args): + response = dict(zip(map(nativestr, args[::2]), args[1::2])) + self.k = response["k"] + self.width = response["width"] + self.depth = response["depth"] + self.decay = response["decay"] + + def __getitem__(self, item): + return getattr(self, item) + + +class TDigestInfo(object): + compression = None + capacity = None + merged_nodes = None + unmerged_nodes = None + merged_weight = None + unmerged_weight = None + total_compressions = None + memory_usage = None + + def __init__(self, args): + response = dict(zip(map(nativestr, args[::2]), args[1::2])) + self.compression = response["Compression"] + self.capacity = response["Capacity"] + self.merged_nodes = response["Merged nodes"] + self.unmerged_nodes = response["Unmerged nodes"] + self.merged_weight = response["Merged weight"] + self.unmerged_weight = response["Unmerged weight"] + self.total_compressions = response["Total compressions"] + self.memory_usage = response["Memory usage"] + + def get(self, item): + try: + return self.__getitem__(item) + except AttributeError: + return None + + def __getitem__(self, item): + return getattr(self, item) diff --git a/.venv/Lib/site-packages/redis/commands/cluster.py b/.venv/Lib/site-packages/redis/commands/cluster.py new file mode 100644 index 00000000..14b87414 --- /dev/null +++ b/.venv/Lib/site-packages/redis/commands/cluster.py @@ -0,0 +1,928 @@ +import asyncio +from typing import ( + TYPE_CHECKING, + Any, + AsyncIterator, + Dict, + Iterable, + Iterator, + List, + Mapping, + NoReturn, + Optional, + Union, +) + +from redis.compat import Literal +from redis.crc import key_slot +from redis.exceptions import RedisClusterException, RedisError +from redis.typing import ( + AnyKeyT, + ClusterCommandsProtocol, + EncodableT, + KeysT, + KeyT, + PatternT, +) + +from .core import ( + ACLCommands, + AsyncACLCommands, + AsyncDataAccessCommands, + AsyncFunctionCommands, + AsyncGearsCommands, + AsyncManagementCommands, + AsyncModuleCommands, + AsyncScriptCommands, + DataAccessCommands, + FunctionCommands, + GearsCommands, + ManagementCommands, + ModuleCommands, + PubSubCommands, + ResponseT, + ScriptCommands, +) +from .helpers import list_or_args +from .redismodules import RedisModuleCommands + +if TYPE_CHECKING: + from redis.asyncio.cluster import TargetNodesT + +# Not complete, but covers the major ones +# https://redis.io/commands +READ_COMMANDS = frozenset( + [ + "BITCOUNT", + "BITPOS", + "EVAL_RO", + "EVALSHA_RO", + "EXISTS", + "GEODIST", + "GEOHASH", + "GEOPOS", + "GEORADIUS", + "GEORADIUSBYMEMBER", + "GET", + "GETBIT", + "GETRANGE", + "HEXISTS", + "HGET", + "HGETALL", + "HKEYS", + "HLEN", + "HMGET", + "HSTRLEN", + "HVALS", + "KEYS", + "LINDEX", + "LLEN", + "LRANGE", + "MGET", + "PTTL", + "RANDOMKEY", + "SCARD", + "SDIFF", + "SINTER", + "SISMEMBER", + "SMEMBERS", + "SRANDMEMBER", + "STRLEN", + "SUNION", + "TTL", + "ZCARD", + "ZCOUNT", + "ZRANGE", + "ZSCORE", + ] +) + + +class ClusterMultiKeyCommands(ClusterCommandsProtocol): + """ + A class containing commands that handle more than one key + """ + + def _partition_keys_by_slot(self, keys: Iterable[KeyT]) -> Dict[int, List[KeyT]]: + """Split keys into a dictionary that maps a slot to a list of keys.""" + + slots_to_keys = {} + for key in keys: + slot = key_slot(self.encoder.encode(key)) + slots_to_keys.setdefault(slot, []).append(key) + + return slots_to_keys + + def _partition_pairs_by_slot( + self, mapping: Mapping[AnyKeyT, EncodableT] + ) -> Dict[int, List[EncodableT]]: + """Split pairs into a dictionary that maps a slot to a list of pairs.""" + + slots_to_pairs = {} + for pair in mapping.items(): + slot = key_slot(self.encoder.encode(pair[0])) + slots_to_pairs.setdefault(slot, []).extend(pair) + + return slots_to_pairs + + def _execute_pipeline_by_slot( + self, command: str, slots_to_args: Mapping[int, Iterable[EncodableT]] + ) -> List[Any]: + read_from_replicas = self.read_from_replicas and command in READ_COMMANDS + pipe = self.pipeline() + [ + pipe.execute_command( + command, + *slot_args, + target_nodes=[ + self.nodes_manager.get_node_from_slot(slot, read_from_replicas) + ], + ) + for slot, slot_args in slots_to_args.items() + ] + return pipe.execute() + + def _reorder_keys_by_command( + self, + keys: Iterable[KeyT], + slots_to_args: Mapping[int, Iterable[EncodableT]], + responses: Iterable[Any], + ) -> List[Any]: + results = { + k: v + for slot_values, response in zip(slots_to_args.values(), responses) + for k, v in zip(slot_values, response) + } + return [results[key] for key in keys] + + def mget_nonatomic(self, keys: KeysT, *args: KeyT) -> List[Optional[Any]]: + """ + Splits the keys into different slots and then calls MGET + for the keys of every slot. This operation will not be atomic + if keys belong to more than one slot. + + Returns a list of values ordered identically to ``keys`` + + For more information see https://redis.io/commands/mget + """ + + # Concatenate all keys into a list + keys = list_or_args(keys, args) + + # Split keys into slots + slots_to_keys = self._partition_keys_by_slot(keys) + + # Execute commands using a pipeline + res = self._execute_pipeline_by_slot("MGET", slots_to_keys) + + # Reorder keys in the order the user provided & return + return self._reorder_keys_by_command(keys, slots_to_keys, res) + + def mset_nonatomic(self, mapping: Mapping[AnyKeyT, EncodableT]) -> List[bool]: + """ + Sets key/values based on a mapping. Mapping is a dictionary of + key/value pairs. Both keys and values should be strings or types that + can be cast to a string via str(). + + Splits the keys into different slots and then calls MSET + for the keys of every slot. This operation will not be atomic + if keys belong to more than one slot. + + For more information see https://redis.io/commands/mset + """ + + # Partition the keys by slot + slots_to_pairs = self._partition_pairs_by_slot(mapping) + + # Execute commands using a pipeline & return list of replies + return self._execute_pipeline_by_slot("MSET", slots_to_pairs) + + def _split_command_across_slots(self, command: str, *keys: KeyT) -> int: + """ + Runs the given command once for the keys + of each slot. Returns the sum of the return values. + """ + + # Partition the keys by slot + slots_to_keys = self._partition_keys_by_slot(keys) + + # Sum up the reply from each command + return sum(self._execute_pipeline_by_slot(command, slots_to_keys)) + + def exists(self, *keys: KeyT) -> ResponseT: + """ + Returns the number of ``names`` that exist in the + whole cluster. The keys are first split up into slots + and then an EXISTS command is sent for every slot + + For more information see https://redis.io/commands/exists + """ + return self._split_command_across_slots("EXISTS", *keys) + + def delete(self, *keys: KeyT) -> ResponseT: + """ + Deletes the given keys in the cluster. + The keys are first split up into slots + and then an DEL command is sent for every slot + + Non-existant keys are ignored. + Returns the number of keys that were deleted. + + For more information see https://redis.io/commands/del + """ + return self._split_command_across_slots("DEL", *keys) + + def touch(self, *keys: KeyT) -> ResponseT: + """ + Updates the last access time of given keys across the + cluster. + + The keys are first split up into slots + and then an TOUCH command is sent for every slot + + Non-existant keys are ignored. + Returns the number of keys that were touched. + + For more information see https://redis.io/commands/touch + """ + return self._split_command_across_slots("TOUCH", *keys) + + def unlink(self, *keys: KeyT) -> ResponseT: + """ + Remove the specified keys in a different thread. + + The keys are first split up into slots + and then an TOUCH command is sent for every slot + + Non-existant keys are ignored. + Returns the number of keys that were unlinked. + + For more information see https://redis.io/commands/unlink + """ + return self._split_command_across_slots("UNLINK", *keys) + + +class AsyncClusterMultiKeyCommands(ClusterMultiKeyCommands): + """ + A class containing commands that handle more than one key + """ + + async def mget_nonatomic(self, keys: KeysT, *args: KeyT) -> List[Optional[Any]]: + """ + Splits the keys into different slots and then calls MGET + for the keys of every slot. This operation will not be atomic + if keys belong to more than one slot. + + Returns a list of values ordered identically to ``keys`` + + For more information see https://redis.io/commands/mget + """ + + # Concatenate all keys into a list + keys = list_or_args(keys, args) + + # Split keys into slots + slots_to_keys = self._partition_keys_by_slot(keys) + + # Execute commands using a pipeline + res = await self._execute_pipeline_by_slot("MGET", slots_to_keys) + + # Reorder keys in the order the user provided & return + return self._reorder_keys_by_command(keys, slots_to_keys, res) + + async def mset_nonatomic(self, mapping: Mapping[AnyKeyT, EncodableT]) -> List[bool]: + """ + Sets key/values based on a mapping. Mapping is a dictionary of + key/value pairs. Both keys and values should be strings or types that + can be cast to a string via str(). + + Splits the keys into different slots and then calls MSET + for the keys of every slot. This operation will not be atomic + if keys belong to more than one slot. + + For more information see https://redis.io/commands/mset + """ + + # Partition the keys by slot + slots_to_pairs = self._partition_pairs_by_slot(mapping) + + # Execute commands using a pipeline & return list of replies + return await self._execute_pipeline_by_slot("MSET", slots_to_pairs) + + async def _split_command_across_slots(self, command: str, *keys: KeyT) -> int: + """ + Runs the given command once for the keys + of each slot. Returns the sum of the return values. + """ + + # Partition the keys by slot + slots_to_keys = self._partition_keys_by_slot(keys) + + # Sum up the reply from each command + return sum(await self._execute_pipeline_by_slot(command, slots_to_keys)) + + async def _execute_pipeline_by_slot( + self, command: str, slots_to_args: Mapping[int, Iterable[EncodableT]] + ) -> List[Any]: + if self._initialize: + await self.initialize() + read_from_replicas = self.read_from_replicas and command in READ_COMMANDS + pipe = self.pipeline() + [ + pipe.execute_command( + command, + *slot_args, + target_nodes=[ + self.nodes_manager.get_node_from_slot(slot, read_from_replicas) + ], + ) + for slot, slot_args in slots_to_args.items() + ] + return await pipe.execute() + + +class ClusterManagementCommands(ManagementCommands): + """ + A class for Redis Cluster management commands + + The class inherits from Redis's core ManagementCommands class and do the + required adjustments to work with cluster mode + """ + + def slaveof(self, *args, **kwargs) -> NoReturn: + """ + Make the server a replica of another instance, or promote it as master. + + For more information see https://redis.io/commands/slaveof + """ + raise RedisClusterException("SLAVEOF is not supported in cluster mode") + + def replicaof(self, *args, **kwargs) -> NoReturn: + """ + Make the server a replica of another instance, or promote it as master. + + For more information see https://redis.io/commands/replicaof + """ + raise RedisClusterException("REPLICAOF is not supported in cluster mode") + + def swapdb(self, *args, **kwargs) -> NoReturn: + """ + Swaps two Redis databases. + + For more information see https://redis.io/commands/swapdb + """ + raise RedisClusterException("SWAPDB is not supported in cluster mode") + + def cluster_myid(self, target_node: "TargetNodesT") -> ResponseT: + """ + Returns the node's id. + + :target_node: 'ClusterNode' + The node to execute the command on + + For more information check https://redis.io/commands/cluster-myid/ + """ + return self.execute_command("CLUSTER MYID", target_nodes=target_node) + + def cluster_addslots( + self, target_node: "TargetNodesT", *slots: EncodableT + ) -> ResponseT: + """ + Assign new hash slots to receiving node. Sends to specified node. + + :target_node: 'ClusterNode' + The node to execute the command on + + For more information see https://redis.io/commands/cluster-addslots + """ + return self.execute_command( + "CLUSTER ADDSLOTS", *slots, target_nodes=target_node + ) + + def cluster_addslotsrange( + self, target_node: "TargetNodesT", *slots: EncodableT + ) -> ResponseT: + """ + Similar to the CLUSTER ADDSLOTS command. + The difference between the two commands is that ADDSLOTS takes a list of slots + to assign to the node, while ADDSLOTSRANGE takes a list of slot ranges + (specified by start and end slots) to assign to the node. + + :target_node: 'ClusterNode' + The node to execute the command on + + For more information see https://redis.io/commands/cluster-addslotsrange + """ + return self.execute_command( + "CLUSTER ADDSLOTSRANGE", *slots, target_nodes=target_node + ) + + def cluster_countkeysinslot(self, slot_id: int) -> ResponseT: + """ + Return the number of local keys in the specified hash slot + Send to node based on specified slot_id + + For more information see https://redis.io/commands/cluster-countkeysinslot + """ + return self.execute_command("CLUSTER COUNTKEYSINSLOT", slot_id) + + def cluster_count_failure_report(self, node_id: str) -> ResponseT: + """ + Return the number of failure reports active for a given node + Sends to a random node + + For more information see https://redis.io/commands/cluster-count-failure-reports + """ + return self.execute_command("CLUSTER COUNT-FAILURE-REPORTS", node_id) + + def cluster_delslots(self, *slots: EncodableT) -> List[bool]: + """ + Set hash slots as unbound in the cluster. + It determines by it self what node the slot is in and sends it there + + Returns a list of the results for each processed slot. + + For more information see https://redis.io/commands/cluster-delslots + """ + return [self.execute_command("CLUSTER DELSLOTS", slot) for slot in slots] + + def cluster_delslotsrange(self, *slots: EncodableT) -> ResponseT: + """ + Similar to the CLUSTER DELSLOTS command. + The difference is that CLUSTER DELSLOTS takes a list of hash slots to remove + from the node, while CLUSTER DELSLOTSRANGE takes a list of slot ranges to remove + from the node. + + For more information see https://redis.io/commands/cluster-delslotsrange + """ + return self.execute_command("CLUSTER DELSLOTSRANGE", *slots) + + def cluster_failover( + self, target_node: "TargetNodesT", option: Optional[str] = None + ) -> ResponseT: + """ + Forces a slave to perform a manual failover of its master + Sends to specified node + + :target_node: 'ClusterNode' + The node to execute the command on + + For more information see https://redis.io/commands/cluster-failover + """ + if option: + if option.upper() not in ["FORCE", "TAKEOVER"]: + raise RedisError( + f"Invalid option for CLUSTER FAILOVER command: {option}" + ) + else: + return self.execute_command( + "CLUSTER FAILOVER", option, target_nodes=target_node + ) + else: + return self.execute_command("CLUSTER FAILOVER", target_nodes=target_node) + + def cluster_info(self, target_nodes: Optional["TargetNodesT"] = None) -> ResponseT: + """ + Provides info about Redis Cluster node state. + The command will be sent to a random node in the cluster if no target + node is specified. + + For more information see https://redis.io/commands/cluster-info + """ + return self.execute_command("CLUSTER INFO", target_nodes=target_nodes) + + def cluster_keyslot(self, key: str) -> ResponseT: + """ + Returns the hash slot of the specified key + Sends to random node in the cluster + + For more information see https://redis.io/commands/cluster-keyslot + """ + return self.execute_command("CLUSTER KEYSLOT", key) + + def cluster_meet( + self, host: str, port: int, target_nodes: Optional["TargetNodesT"] = None + ) -> ResponseT: + """ + Force a node cluster to handshake with another node. + Sends to specified node. + + For more information see https://redis.io/commands/cluster-meet + """ + return self.execute_command( + "CLUSTER MEET", host, port, target_nodes=target_nodes + ) + + def cluster_nodes(self) -> ResponseT: + """ + Get Cluster config for the node. + Sends to random node in the cluster + + For more information see https://redis.io/commands/cluster-nodes + """ + return self.execute_command("CLUSTER NODES") + + def cluster_replicate( + self, target_nodes: "TargetNodesT", node_id: str + ) -> ResponseT: + """ + Reconfigure a node as a slave of the specified master node + + For more information see https://redis.io/commands/cluster-replicate + """ + return self.execute_command( + "CLUSTER REPLICATE", node_id, target_nodes=target_nodes + ) + + def cluster_reset( + self, soft: bool = True, target_nodes: Optional["TargetNodesT"] = None + ) -> ResponseT: + """ + Reset a Redis Cluster node + + If 'soft' is True then it will send 'SOFT' argument + If 'soft' is False then it will send 'HARD' argument + + For more information see https://redis.io/commands/cluster-reset + """ + return self.execute_command( + "CLUSTER RESET", b"SOFT" if soft else b"HARD", target_nodes=target_nodes + ) + + def cluster_save_config( + self, target_nodes: Optional["TargetNodesT"] = None + ) -> ResponseT: + """ + Forces the node to save cluster state on disk + + For more information see https://redis.io/commands/cluster-saveconfig + """ + return self.execute_command("CLUSTER SAVECONFIG", target_nodes=target_nodes) + + def cluster_get_keys_in_slot(self, slot: int, num_keys: int) -> ResponseT: + """ + Returns the number of keys in the specified cluster slot + + For more information see https://redis.io/commands/cluster-getkeysinslot + """ + return self.execute_command("CLUSTER GETKEYSINSLOT", slot, num_keys) + + def cluster_set_config_epoch( + self, epoch: int, target_nodes: Optional["TargetNodesT"] = None + ) -> ResponseT: + """ + Set the configuration epoch in a new node + + For more information see https://redis.io/commands/cluster-set-config-epoch + """ + return self.execute_command( + "CLUSTER SET-CONFIG-EPOCH", epoch, target_nodes=target_nodes + ) + + def cluster_setslot( + self, target_node: "TargetNodesT", node_id: str, slot_id: int, state: str + ) -> ResponseT: + """ + Bind an hash slot to a specific node + + :target_node: 'ClusterNode' + The node to execute the command on + + For more information see https://redis.io/commands/cluster-setslot + """ + if state.upper() in ("IMPORTING", "NODE", "MIGRATING"): + return self.execute_command( + "CLUSTER SETSLOT", slot_id, state, node_id, target_nodes=target_node + ) + elif state.upper() == "STABLE": + raise RedisError('For "stable" state please use ' "cluster_setslot_stable") + else: + raise RedisError(f"Invalid slot state: {state}") + + def cluster_setslot_stable(self, slot_id: int) -> ResponseT: + """ + Clears migrating / importing state from the slot. + It determines by it self what node the slot is in and sends it there. + + For more information see https://redis.io/commands/cluster-setslot + """ + return self.execute_command("CLUSTER SETSLOT", slot_id, "STABLE") + + def cluster_replicas( + self, node_id: str, target_nodes: Optional["TargetNodesT"] = None + ) -> ResponseT: + """ + Provides a list of replica nodes replicating from the specified primary + target node. + + For more information see https://redis.io/commands/cluster-replicas + """ + return self.execute_command( + "CLUSTER REPLICAS", node_id, target_nodes=target_nodes + ) + + def cluster_slots(self, target_nodes: Optional["TargetNodesT"] = None) -> ResponseT: + """ + Get array of Cluster slot to node mappings + + For more information see https://redis.io/commands/cluster-slots + """ + return self.execute_command("CLUSTER SLOTS", target_nodes=target_nodes) + + def cluster_shards(self, target_nodes=None): + """ + Returns details about the shards of the cluster. + + For more information see https://redis.io/commands/cluster-shards + """ + return self.execute_command("CLUSTER SHARDS", target_nodes=target_nodes) + + def cluster_myshardid(self, target_nodes=None): + """ + Returns the shard ID of the node. + + For more information see https://redis.io/commands/cluster-myshardid/ + """ + return self.execute_command("CLUSTER MYSHARDID", target_nodes=target_nodes) + + def cluster_links(self, target_node: "TargetNodesT") -> ResponseT: + """ + Each node in a Redis Cluster maintains a pair of long-lived TCP link with each + peer in the cluster: One for sending outbound messages towards the peer and one + for receiving inbound messages from the peer. + + This command outputs information of all such peer links as an array. + + For more information see https://redis.io/commands/cluster-links + """ + return self.execute_command("CLUSTER LINKS", target_nodes=target_node) + + def cluster_flushslots(self, target_nodes: Optional["TargetNodesT"] = None) -> None: + raise NotImplementedError( + "CLUSTER FLUSHSLOTS is intentionally not implemented in the client." + ) + + def cluster_bumpepoch(self, target_nodes: Optional["TargetNodesT"] = None) -> None: + raise NotImplementedError( + "CLUSTER BUMPEPOCH is intentionally not implemented in the client." + ) + + def readonly(self, target_nodes: Optional["TargetNodesT"] = None) -> ResponseT: + """ + Enables read queries. + The command will be sent to the default cluster node if target_nodes is + not specified. + + For more information see https://redis.io/commands/readonly + """ + if target_nodes == "replicas" or target_nodes == "all": + # read_from_replicas will only be enabled if the READONLY command + # is sent to all replicas + self.read_from_replicas = True + return self.execute_command("READONLY", target_nodes=target_nodes) + + def readwrite(self, target_nodes: Optional["TargetNodesT"] = None) -> ResponseT: + """ + Disables read queries. + The command will be sent to the default cluster node if target_nodes is + not specified. + + For more information see https://redis.io/commands/readwrite + """ + # Reset read from replicas flag + self.read_from_replicas = False + return self.execute_command("READWRITE", target_nodes=target_nodes) + + def gears_refresh_cluster(self, **kwargs) -> ResponseT: + """ + On an OSS cluster, before executing any gears function, you must call this command. # noqa + """ + return self.execute_command("REDISGEARS_2.REFRESHCLUSTER", **kwargs) + + +class AsyncClusterManagementCommands( + ClusterManagementCommands, AsyncManagementCommands +): + """ + A class for Redis Cluster management commands + + The class inherits from Redis's core ManagementCommands class and do the + required adjustments to work with cluster mode + """ + + async def cluster_delslots(self, *slots: EncodableT) -> List[bool]: + """ + Set hash slots as unbound in the cluster. + It determines by it self what node the slot is in and sends it there + + Returns a list of the results for each processed slot. + + For more information see https://redis.io/commands/cluster-delslots + """ + return await asyncio.gather( + *( + asyncio.create_task(self.execute_command("CLUSTER DELSLOTS", slot)) + for slot in slots + ) + ) + + +class ClusterDataAccessCommands(DataAccessCommands): + """ + A class for Redis Cluster Data Access Commands + + The class inherits from Redis's core DataAccessCommand class and do the + required adjustments to work with cluster mode + """ + + def stralgo( + self, + algo: Literal["LCS"], + value1: KeyT, + value2: KeyT, + specific_argument: Union[Literal["strings"], Literal["keys"]] = "strings", + len: bool = False, + idx: bool = False, + minmatchlen: Optional[int] = None, + withmatchlen: bool = False, + **kwargs, + ) -> ResponseT: + """ + Implements complex algorithms that operate on strings. + Right now the only algorithm implemented is the LCS algorithm + (longest common substring). However new algorithms could be + implemented in the future. + + ``algo`` Right now must be LCS + ``value1`` and ``value2`` Can be two strings or two keys + ``specific_argument`` Specifying if the arguments to the algorithm + will be keys or strings. strings is the default. + ``len`` Returns just the len of the match. + ``idx`` Returns the match positions in each string. + ``minmatchlen`` Restrict the list of matches to the ones of a given + minimal length. Can be provided only when ``idx`` set to True. + ``withmatchlen`` Returns the matches with the len of the match. + Can be provided only when ``idx`` set to True. + + For more information see https://redis.io/commands/stralgo + """ + target_nodes = kwargs.pop("target_nodes", None) + if specific_argument == "strings" and target_nodes is None: + target_nodes = "default-node" + kwargs.update({"target_nodes": target_nodes}) + return super().stralgo( + algo, + value1, + value2, + specific_argument, + len, + idx, + minmatchlen, + withmatchlen, + **kwargs, + ) + + def scan_iter( + self, + match: Optional[PatternT] = None, + count: Optional[int] = None, + _type: Optional[str] = None, + **kwargs, + ) -> Iterator: + # Do the first query with cursor=0 for all nodes + cursors, data = self.scan(match=match, count=count, _type=_type, **kwargs) + yield from data + + cursors = {name: cursor for name, cursor in cursors.items() if cursor != 0} + if cursors: + # Get nodes by name + nodes = {name: self.get_node(node_name=name) for name in cursors.keys()} + + # Iterate over each node till its cursor is 0 + kwargs.pop("target_nodes", None) + while cursors: + for name, cursor in cursors.items(): + cur, data = self.scan( + cursor=cursor, + match=match, + count=count, + _type=_type, + target_nodes=nodes[name], + **kwargs, + ) + yield from data + cursors[name] = cur[name] + + cursors = { + name: cursor for name, cursor in cursors.items() if cursor != 0 + } + + +class AsyncClusterDataAccessCommands( + ClusterDataAccessCommands, AsyncDataAccessCommands +): + """ + A class for Redis Cluster Data Access Commands + + The class inherits from Redis's core DataAccessCommand class and do the + required adjustments to work with cluster mode + """ + + async def scan_iter( + self, + match: Optional[PatternT] = None, + count: Optional[int] = None, + _type: Optional[str] = None, + **kwargs, + ) -> AsyncIterator: + # Do the first query with cursor=0 for all nodes + cursors, data = await self.scan(match=match, count=count, _type=_type, **kwargs) + for value in data: + yield value + + cursors = {name: cursor for name, cursor in cursors.items() if cursor != 0} + if cursors: + # Get nodes by name + nodes = {name: self.get_node(node_name=name) for name in cursors.keys()} + + # Iterate over each node till its cursor is 0 + kwargs.pop("target_nodes", None) + while cursors: + for name, cursor in cursors.items(): + cur, data = await self.scan( + cursor=cursor, + match=match, + count=count, + _type=_type, + target_nodes=nodes[name], + **kwargs, + ) + for value in data: + yield value + cursors[name] = cur[name] + + cursors = { + name: cursor for name, cursor in cursors.items() if cursor != 0 + } + + +class RedisClusterCommands( + ClusterMultiKeyCommands, + ClusterManagementCommands, + ACLCommands, + PubSubCommands, + ClusterDataAccessCommands, + ScriptCommands, + FunctionCommands, + GearsCommands, + ModuleCommands, + RedisModuleCommands, +): + """ + A class for all Redis Cluster commands + + For key-based commands, the target node(s) will be internally determined + by the keys' hash slot. + Non-key-based commands can be executed with the 'target_nodes' argument to + target specific nodes. By default, if target_nodes is not specified, the + command will be executed on the default cluster node. + + :param :target_nodes: type can be one of the followings: + - nodes flag: ALL_NODES, PRIMARIES, REPLICAS, RANDOM + - 'ClusterNode' + - 'list(ClusterNodes)' + - 'dict(any:clusterNodes)' + + for example: + r.cluster_info(target_nodes=RedisCluster.ALL_NODES) + """ + + +class AsyncRedisClusterCommands( + AsyncClusterMultiKeyCommands, + AsyncClusterManagementCommands, + AsyncACLCommands, + AsyncClusterDataAccessCommands, + AsyncScriptCommands, + AsyncFunctionCommands, + AsyncGearsCommands, + AsyncModuleCommands, +): + """ + A class for all Redis Cluster commands + + For key-based commands, the target node(s) will be internally determined + by the keys' hash slot. + Non-key-based commands can be executed with the 'target_nodes' argument to + target specific nodes. By default, if target_nodes is not specified, the + command will be executed on the default cluster node. + + :param :target_nodes: type can be one of the followings: + - nodes flag: ALL_NODES, PRIMARIES, REPLICAS, RANDOM + - 'ClusterNode' + - 'list(ClusterNodes)' + - 'dict(any:clusterNodes)' + + for example: + r.cluster_info(target_nodes=RedisCluster.ALL_NODES) + """ diff --git a/.venv/Lib/site-packages/redis/commands/core.py b/.venv/Lib/site-packages/redis/commands/core.py new file mode 100644 index 00000000..e73553e4 --- /dev/null +++ b/.venv/Lib/site-packages/redis/commands/core.py @@ -0,0 +1,6305 @@ +# from __future__ import annotations + +import datetime +import hashlib +import warnings +from typing import ( + TYPE_CHECKING, + Any, + AsyncIterator, + Awaitable, + Callable, + Dict, + Iterable, + Iterator, + List, + Mapping, + Optional, + Sequence, + Set, + Tuple, + Union, +) + +from redis.compat import Literal +from redis.exceptions import ConnectionError, DataError, NoScriptError, RedisError +from redis.typing import ( + AbsExpiryT, + AnyKeyT, + BitfieldOffsetT, + ChannelT, + CommandsProtocol, + ConsumerT, + EncodableT, + ExpiryT, + FieldT, + GroupT, + KeysT, + KeyT, + PatternT, + ScriptTextT, + StreamIdT, + TimeoutSecT, + ZScoreBoundT, +) + +from .helpers import list_or_args + +if TYPE_CHECKING: + from redis.asyncio.client import Redis as AsyncRedis + from redis.client import Redis + +ResponseT = Union[Awaitable, Any] + + +class ACLCommands(CommandsProtocol): + """ + Redis Access Control List (ACL) commands. + see: https://redis.io/topics/acl + """ + + def acl_cat(self, category: Union[str, None] = None, **kwargs) -> ResponseT: + """ + Returns a list of categories or commands within a category. + + If ``category`` is not supplied, returns a list of all categories. + If ``category`` is supplied, returns a list of all commands within + that category. + + For more information see https://redis.io/commands/acl-cat + """ + pieces: list[EncodableT] = [category] if category else [] + return self.execute_command("ACL CAT", *pieces, **kwargs) + + def acl_dryrun(self, username, *args, **kwargs): + """ + Simulate the execution of a given command by a given ``username``. + + For more information see https://redis.io/commands/acl-dryrun + """ + return self.execute_command("ACL DRYRUN", username, *args, **kwargs) + + def acl_deluser(self, *username: str, **kwargs) -> ResponseT: + """ + Delete the ACL for the specified ``username``s + + For more information see https://redis.io/commands/acl-deluser + """ + return self.execute_command("ACL DELUSER", *username, **kwargs) + + def acl_genpass(self, bits: Union[int, None] = None, **kwargs) -> ResponseT: + """Generate a random password value. + If ``bits`` is supplied then use this number of bits, rounded to + the next multiple of 4. + See: https://redis.io/commands/acl-genpass + """ + pieces = [] + if bits is not None: + try: + b = int(bits) + if b < 0 or b > 4096: + raise ValueError + except ValueError: + raise DataError( + "genpass optionally accepts a bits argument, between 0 and 4096." + ) + return self.execute_command("ACL GENPASS", *pieces, **kwargs) + + def acl_getuser(self, username: str, **kwargs) -> ResponseT: + """ + Get the ACL details for the specified ``username``. + + If ``username`` does not exist, return None + + For more information see https://redis.io/commands/acl-getuser + """ + return self.execute_command("ACL GETUSER", username, **kwargs) + + def acl_help(self, **kwargs) -> ResponseT: + """The ACL HELP command returns helpful text describing + the different subcommands. + + For more information see https://redis.io/commands/acl-help + """ + return self.execute_command("ACL HELP", **kwargs) + + def acl_list(self, **kwargs) -> ResponseT: + """ + Return a list of all ACLs on the server + + For more information see https://redis.io/commands/acl-list + """ + return self.execute_command("ACL LIST", **kwargs) + + def acl_log(self, count: Union[int, None] = None, **kwargs) -> ResponseT: + """ + Get ACL logs as a list. + :param int count: Get logs[0:count]. + :rtype: List. + + For more information see https://redis.io/commands/acl-log + """ + args = [] + if count is not None: + if not isinstance(count, int): + raise DataError("ACL LOG count must be an integer") + args.append(count) + + return self.execute_command("ACL LOG", *args, **kwargs) + + def acl_log_reset(self, **kwargs) -> ResponseT: + """ + Reset ACL logs. + :rtype: Boolean. + + For more information see https://redis.io/commands/acl-log + """ + args = [b"RESET"] + return self.execute_command("ACL LOG", *args, **kwargs) + + def acl_load(self, **kwargs) -> ResponseT: + """ + Load ACL rules from the configured ``aclfile``. + + Note that the server must be configured with the ``aclfile`` + directive to be able to load ACL rules from an aclfile. + + For more information see https://redis.io/commands/acl-load + """ + return self.execute_command("ACL LOAD", **kwargs) + + def acl_save(self, **kwargs) -> ResponseT: + """ + Save ACL rules to the configured ``aclfile``. + + Note that the server must be configured with the ``aclfile`` + directive to be able to save ACL rules to an aclfile. + + For more information see https://redis.io/commands/acl-save + """ + return self.execute_command("ACL SAVE", **kwargs) + + def acl_setuser( + self, + username: str, + enabled: bool = False, + nopass: bool = False, + passwords: Union[str, Iterable[str], None] = None, + hashed_passwords: Union[str, Iterable[str], None] = None, + categories: Optional[Iterable[str]] = None, + commands: Optional[Iterable[str]] = None, + keys: Optional[Iterable[KeyT]] = None, + channels: Optional[Iterable[ChannelT]] = None, + selectors: Optional[Iterable[Tuple[str, KeyT]]] = None, + reset: bool = False, + reset_keys: bool = False, + reset_channels: bool = False, + reset_passwords: bool = False, + **kwargs, + ) -> ResponseT: + """ + Create or update an ACL user. + + Create or update the ACL for ``username``. If the user already exists, + the existing ACL is completely overwritten and replaced with the + specified values. + + ``enabled`` is a boolean indicating whether the user should be allowed + to authenticate or not. Defaults to ``False``. + + ``nopass`` is a boolean indicating whether the can authenticate without + a password. This cannot be True if ``passwords`` are also specified. + + ``passwords`` if specified is a list of plain text passwords + to add to or remove from the user. Each password must be prefixed with + a '+' to add or a '-' to remove. For convenience, the value of + ``passwords`` can be a simple prefixed string when adding or + removing a single password. + + ``hashed_passwords`` if specified is a list of SHA-256 hashed passwords + to add to or remove from the user. Each hashed password must be + prefixed with a '+' to add or a '-' to remove. For convenience, + the value of ``hashed_passwords`` can be a simple prefixed string when + adding or removing a single password. + + ``categories`` if specified is a list of strings representing category + permissions. Each string must be prefixed with either a '+' to add the + category permission or a '-' to remove the category permission. + + ``commands`` if specified is a list of strings representing command + permissions. Each string must be prefixed with either a '+' to add the + command permission or a '-' to remove the command permission. + + ``keys`` if specified is a list of key patterns to grant the user + access to. Keys patterns allow '*' to support wildcard matching. For + example, '*' grants access to all keys while 'cache:*' grants access + to all keys that are prefixed with 'cache:'. ``keys`` should not be + prefixed with a '~'. + + ``reset`` is a boolean indicating whether the user should be fully + reset prior to applying the new ACL. Setting this to True will + remove all existing passwords, flags and privileges from the user and + then apply the specified rules. If this is False, the user's existing + passwords, flags and privileges will be kept and any new specified + rules will be applied on top. + + ``reset_keys`` is a boolean indicating whether the user's key + permissions should be reset prior to applying any new key permissions + specified in ``keys``. If this is False, the user's existing + key permissions will be kept and any new specified key permissions + will be applied on top. + + ``reset_channels`` is a boolean indicating whether the user's channel + permissions should be reset prior to applying any new channel permissions + specified in ``channels``.If this is False, the user's existing + channel permissions will be kept and any new specified channel permissions + will be applied on top. + + ``reset_passwords`` is a boolean indicating whether to remove all + existing passwords and the 'nopass' flag from the user prior to + applying any new passwords specified in 'passwords' or + 'hashed_passwords'. If this is False, the user's existing passwords + and 'nopass' status will be kept and any new specified passwords + or hashed_passwords will be applied on top. + + For more information see https://redis.io/commands/acl-setuser + """ + encoder = self.get_encoder() + pieces: List[EncodableT] = [username] + + if reset: + pieces.append(b"reset") + + if reset_keys: + pieces.append(b"resetkeys") + + if reset_channels: + pieces.append(b"resetchannels") + + if reset_passwords: + pieces.append(b"resetpass") + + if enabled: + pieces.append(b"on") + else: + pieces.append(b"off") + + if (passwords or hashed_passwords) and nopass: + raise DataError( + "Cannot set 'nopass' and supply 'passwords' or 'hashed_passwords'" + ) + + if passwords: + # as most users will have only one password, allow remove_passwords + # to be specified as a simple string or a list + passwords = list_or_args(passwords, []) + for i, password in enumerate(passwords): + password = encoder.encode(password) + if password.startswith(b"+"): + pieces.append(b">%s" % password[1:]) + elif password.startswith(b"-"): + pieces.append(b"<%s" % password[1:]) + else: + raise DataError( + f"Password {i} must be prefixed with a " + f'"+" to add or a "-" to remove' + ) + + if hashed_passwords: + # as most users will have only one password, allow remove_passwords + # to be specified as a simple string or a list + hashed_passwords = list_or_args(hashed_passwords, []) + for i, hashed_password in enumerate(hashed_passwords): + hashed_password = encoder.encode(hashed_password) + if hashed_password.startswith(b"+"): + pieces.append(b"#%s" % hashed_password[1:]) + elif hashed_password.startswith(b"-"): + pieces.append(b"!%s" % hashed_password[1:]) + else: + raise DataError( + f"Hashed password {i} must be prefixed with a " + f'"+" to add or a "-" to remove' + ) + + if nopass: + pieces.append(b"nopass") + + if categories: + for category in categories: + category = encoder.encode(category) + # categories can be prefixed with one of (+@, +, -@, -) + if category.startswith(b"+@"): + pieces.append(category) + elif category.startswith(b"+"): + pieces.append(b"+@%s" % category[1:]) + elif category.startswith(b"-@"): + pieces.append(category) + elif category.startswith(b"-"): + pieces.append(b"-@%s" % category[1:]) + else: + raise DataError( + f'Category "{encoder.decode(category, force=True)}" ' + 'must be prefixed with "+" or "-"' + ) + if commands: + for cmd in commands: + cmd = encoder.encode(cmd) + if not cmd.startswith(b"+") and not cmd.startswith(b"-"): + raise DataError( + f'Command "{encoder.decode(cmd, force=True)}" ' + 'must be prefixed with "+" or "-"' + ) + pieces.append(cmd) + + if keys: + for key in keys: + key = encoder.encode(key) + if not key.startswith(b"%") and not key.startswith(b"~"): + key = b"~%s" % key + pieces.append(key) + + if channels: + for channel in channels: + channel = encoder.encode(channel) + pieces.append(b"&%s" % channel) + + if selectors: + for cmd, key in selectors: + cmd = encoder.encode(cmd) + if not cmd.startswith(b"+") and not cmd.startswith(b"-"): + raise DataError( + f'Command "{encoder.decode(cmd, force=True)}" ' + 'must be prefixed with "+" or "-"' + ) + + key = encoder.encode(key) + if not key.startswith(b"%") and not key.startswith(b"~"): + key = b"~%s" % key + + pieces.append(b"(%s %s)" % (cmd, key)) + + return self.execute_command("ACL SETUSER", *pieces, **kwargs) + + def acl_users(self, **kwargs) -> ResponseT: + """Returns a list of all registered users on the server. + + For more information see https://redis.io/commands/acl-users + """ + return self.execute_command("ACL USERS", **kwargs) + + def acl_whoami(self, **kwargs) -> ResponseT: + """Get the username for the current connection + + For more information see https://redis.io/commands/acl-whoami + """ + return self.execute_command("ACL WHOAMI", **kwargs) + + +AsyncACLCommands = ACLCommands + + +class ManagementCommands(CommandsProtocol): + """ + Redis management commands + """ + + def auth(self, password: str, username: Optional[str] = None, **kwargs): + """ + Authenticates the user. If you do not pass username, Redis will try to + authenticate for the "default" user. If you do pass username, it will + authenticate for the given user. + For more information see https://redis.io/commands/auth + """ + pieces = [] + if username is not None: + pieces.append(username) + pieces.append(password) + return self.execute_command("AUTH", *pieces, **kwargs) + + def bgrewriteaof(self, **kwargs): + """Tell the Redis server to rewrite the AOF file from data in memory. + + For more information see https://redis.io/commands/bgrewriteaof + """ + return self.execute_command("BGREWRITEAOF", **kwargs) + + def bgsave(self, schedule: bool = True, **kwargs) -> ResponseT: + """ + Tell the Redis server to save its data to disk. Unlike save(), + this method is asynchronous and returns immediately. + + For more information see https://redis.io/commands/bgsave + """ + pieces = [] + if schedule: + pieces.append("SCHEDULE") + return self.execute_command("BGSAVE", *pieces, **kwargs) + + def role(self) -> ResponseT: + """ + Provide information on the role of a Redis instance in + the context of replication, by returning if the instance + is currently a master, slave, or sentinel. + + For more information see https://redis.io/commands/role + """ + return self.execute_command("ROLE") + + def client_kill(self, address: str, **kwargs) -> ResponseT: + """Disconnects the client at ``address`` (ip:port) + + For more information see https://redis.io/commands/client-kill + """ + return self.execute_command("CLIENT KILL", address, **kwargs) + + def client_kill_filter( + self, + _id: Union[str, None] = None, + _type: Union[str, None] = None, + addr: Union[str, None] = None, + skipme: Union[bool, None] = None, + laddr: Union[bool, None] = None, + user: str = None, + **kwargs, + ) -> ResponseT: + """ + Disconnects client(s) using a variety of filter options + :param _id: Kills a client by its unique ID field + :param _type: Kills a client by type where type is one of 'normal', + 'master', 'slave' or 'pubsub' + :param addr: Kills a client by its 'address:port' + :param skipme: If True, then the client calling the command + will not get killed even if it is identified by one of the filter + options. If skipme is not provided, the server defaults to skipme=True + :param laddr: Kills a client by its 'local (bind) address:port' + :param user: Kills a client for a specific user name + """ + args = [] + if _type is not None: + client_types = ("normal", "master", "slave", "pubsub") + if str(_type).lower() not in client_types: + raise DataError(f"CLIENT KILL type must be one of {client_types!r}") + args.extend((b"TYPE", _type)) + if skipme is not None: + if not isinstance(skipme, bool): + raise DataError("CLIENT KILL skipme must be a bool") + if skipme: + args.extend((b"SKIPME", b"YES")) + else: + args.extend((b"SKIPME", b"NO")) + if _id is not None: + args.extend((b"ID", _id)) + if addr is not None: + args.extend((b"ADDR", addr)) + if laddr is not None: + args.extend((b"LADDR", laddr)) + if user is not None: + args.extend((b"USER", user)) + if not args: + raise DataError( + "CLIENT KILL ... ... " + " must specify at least one filter" + ) + return self.execute_command("CLIENT KILL", *args, **kwargs) + + def client_info(self, **kwargs) -> ResponseT: + """ + Returns information and statistics about the current + client connection. + + For more information see https://redis.io/commands/client-info + """ + return self.execute_command("CLIENT INFO", **kwargs) + + def client_list( + self, _type: Union[str, None] = None, client_id: List[EncodableT] = [], **kwargs + ) -> ResponseT: + """ + Returns a list of currently connected clients. + If type of client specified, only that type will be returned. + + :param _type: optional. one of the client types (normal, master, + replica, pubsub) + :param client_id: optional. a list of client ids + + For more information see https://redis.io/commands/client-list + """ + args = [] + if _type is not None: + client_types = ("normal", "master", "replica", "pubsub") + if str(_type).lower() not in client_types: + raise DataError(f"CLIENT LIST _type must be one of {client_types!r}") + args.append(b"TYPE") + args.append(_type) + if not isinstance(client_id, list): + raise DataError("client_id must be a list") + if client_id: + args.append(b"ID") + args.append(" ".join(client_id)) + return self.execute_command("CLIENT LIST", *args, **kwargs) + + def client_getname(self, **kwargs) -> ResponseT: + """ + Returns the current connection name + + For more information see https://redis.io/commands/client-getname + """ + return self.execute_command("CLIENT GETNAME", **kwargs) + + def client_getredir(self, **kwargs) -> ResponseT: + """ + Returns the ID (an integer) of the client to whom we are + redirecting tracking notifications. + + see: https://redis.io/commands/client-getredir + """ + return self.execute_command("CLIENT GETREDIR", **kwargs) + + def client_reply( + self, reply: Union[Literal["ON"], Literal["OFF"], Literal["SKIP"]], **kwargs + ) -> ResponseT: + """ + Enable and disable redis server replies. + + ``reply`` Must be ON OFF or SKIP, + ON - The default most with server replies to commands + OFF - Disable server responses to commands + SKIP - Skip the response of the immediately following command. + + Note: When setting OFF or SKIP replies, you will need a client object + with a timeout specified in seconds, and will need to catch the + TimeoutError. + The test_client_reply unit test illustrates this, and + conftest.py has a client with a timeout. + + See https://redis.io/commands/client-reply + """ + replies = ["ON", "OFF", "SKIP"] + if reply not in replies: + raise DataError(f"CLIENT REPLY must be one of {replies!r}") + return self.execute_command("CLIENT REPLY", reply, **kwargs) + + def client_id(self, **kwargs) -> ResponseT: + """ + Returns the current connection id + + For more information see https://redis.io/commands/client-id + """ + return self.execute_command("CLIENT ID", **kwargs) + + def client_tracking_on( + self, + clientid: Union[int, None] = None, + prefix: Sequence[KeyT] = [], + bcast: bool = False, + optin: bool = False, + optout: bool = False, + noloop: bool = False, + ) -> ResponseT: + """ + Turn on the tracking mode. + For more information about the options look at client_tracking func. + + See https://redis.io/commands/client-tracking + """ + return self.client_tracking( + True, clientid, prefix, bcast, optin, optout, noloop + ) + + def client_tracking_off( + self, + clientid: Union[int, None] = None, + prefix: Sequence[KeyT] = [], + bcast: bool = False, + optin: bool = False, + optout: bool = False, + noloop: bool = False, + ) -> ResponseT: + """ + Turn off the tracking mode. + For more information about the options look at client_tracking func. + + See https://redis.io/commands/client-tracking + """ + return self.client_tracking( + False, clientid, prefix, bcast, optin, optout, noloop + ) + + def client_tracking( + self, + on: bool = True, + clientid: Union[int, None] = None, + prefix: Sequence[KeyT] = [], + bcast: bool = False, + optin: bool = False, + optout: bool = False, + noloop: bool = False, + **kwargs, + ) -> ResponseT: + """ + Enables the tracking feature of the Redis server, that is used + for server assisted client side caching. + + ``on`` indicate for tracking on or tracking off. The dafualt is on. + + ``clientid`` send invalidation messages to the connection with + the specified ID. + + ``bcast`` enable tracking in broadcasting mode. In this mode + invalidation messages are reported for all the prefixes + specified, regardless of the keys requested by the connection. + + ``optin`` when broadcasting is NOT active, normally don't track + keys in read only commands, unless they are called immediately + after a CLIENT CACHING yes command. + + ``optout`` when broadcasting is NOT active, normally track keys in + read only commands, unless they are called immediately after a + CLIENT CACHING no command. + + ``noloop`` don't send notifications about keys modified by this + connection itself. + + ``prefix`` for broadcasting, register a given key prefix, so that + notifications will be provided only for keys starting with this string. + + See https://redis.io/commands/client-tracking + """ + + if len(prefix) != 0 and bcast is False: + raise DataError("Prefix can only be used with bcast") + + pieces = ["ON"] if on else ["OFF"] + if clientid is not None: + pieces.extend(["REDIRECT", clientid]) + for p in prefix: + pieces.extend(["PREFIX", p]) + if bcast: + pieces.append("BCAST") + if optin: + pieces.append("OPTIN") + if optout: + pieces.append("OPTOUT") + if noloop: + pieces.append("NOLOOP") + + return self.execute_command("CLIENT TRACKING", *pieces) + + def client_trackinginfo(self, **kwargs) -> ResponseT: + """ + Returns the information about the current client connection's + use of the server assisted client side cache. + + See https://redis.io/commands/client-trackinginfo + """ + return self.execute_command("CLIENT TRACKINGINFO", **kwargs) + + def client_setname(self, name: str, **kwargs) -> ResponseT: + """ + Sets the current connection name + + For more information see https://redis.io/commands/client-setname + + .. note:: + This method sets client name only for **current** connection. + + If you want to set a common name for all connections managed + by this client, use ``client_name`` constructor argument. + """ + return self.execute_command("CLIENT SETNAME", name, **kwargs) + + def client_setinfo(self, attr: str, value: str, **kwargs) -> ResponseT: + """ + Sets the current connection library name or version + For mor information see https://redis.io/commands/client-setinfo + """ + return self.execute_command("CLIENT SETINFO", attr, value, **kwargs) + + def client_unblock( + self, client_id: int, error: bool = False, **kwargs + ) -> ResponseT: + """ + Unblocks a connection by its client id. + If ``error`` is True, unblocks the client with a special error message. + If ``error`` is False (default), the client is unblocked using the + regular timeout mechanism. + + For more information see https://redis.io/commands/client-unblock + """ + args = ["CLIENT UNBLOCK", int(client_id)] + if error: + args.append(b"ERROR") + return self.execute_command(*args, **kwargs) + + def client_pause(self, timeout: int, all: bool = True, **kwargs) -> ResponseT: + """ + Suspend all the Redis clients for the specified amount of time. + + + For more information see https://redis.io/commands/client-pause + + :param timeout: milliseconds to pause clients + :param all: If true (default) all client commands are blocked. + otherwise, clients are only blocked if they attempt to execute + a write command. + For the WRITE mode, some commands have special behavior: + EVAL/EVALSHA: Will block client for all scripts. + PUBLISH: Will block client. + PFCOUNT: Will block client. + WAIT: Acknowledgments will be delayed, so this command will + appear blocked. + """ + args = ["CLIENT PAUSE", str(timeout)] + if not isinstance(timeout, int): + raise DataError("CLIENT PAUSE timeout must be an integer") + if not all: + args.append("WRITE") + return self.execute_command(*args, **kwargs) + + def client_unpause(self, **kwargs) -> ResponseT: + """ + Unpause all redis clients + + For more information see https://redis.io/commands/client-unpause + """ + return self.execute_command("CLIENT UNPAUSE", **kwargs) + + def client_no_evict(self, mode: str) -> Union[Awaitable[str], str]: + """ + Sets the client eviction mode for the current connection. + + For more information see https://redis.io/commands/client-no-evict + """ + return self.execute_command("CLIENT NO-EVICT", mode) + + def client_no_touch(self, mode: str) -> Union[Awaitable[str], str]: + """ + # The command controls whether commands sent by the client will alter + # the LRU/LFU of the keys they access. + # When turned on, the current client will not change LFU/LRU stats, + # unless it sends the TOUCH command. + + For more information see https://redis.io/commands/client-no-touch + """ + return self.execute_command("CLIENT NO-TOUCH", mode) + + def command(self, **kwargs): + """ + Returns dict reply of details about all Redis commands. + + For more information see https://redis.io/commands/command + """ + return self.execute_command("COMMAND", **kwargs) + + def command_info(self, **kwargs) -> None: + raise NotImplementedError( + "COMMAND INFO is intentionally not implemented in the client." + ) + + def command_count(self, **kwargs) -> ResponseT: + return self.execute_command("COMMAND COUNT", **kwargs) + + def command_list( + self, + module: Optional[str] = None, + category: Optional[str] = None, + pattern: Optional[str] = None, + ) -> ResponseT: + """ + Return an array of the server's command names. + You can use one of the following filters: + ``module``: get the commands that belong to the module + ``category``: get the commands in the ACL category + ``pattern``: get the commands that match the given pattern + + For more information see https://redis.io/commands/command-list/ + """ + pieces = [] + if module is not None: + pieces.extend(["MODULE", module]) + if category is not None: + pieces.extend(["ACLCAT", category]) + if pattern is not None: + pieces.extend(["PATTERN", pattern]) + + if pieces: + pieces.insert(0, "FILTERBY") + + return self.execute_command("COMMAND LIST", *pieces) + + def command_getkeysandflags(self, *args: List[str]) -> List[Union[str, List[str]]]: + """ + Returns array of keys from a full Redis command and their usage flags. + + For more information see https://redis.io/commands/command-getkeysandflags + """ + return self.execute_command("COMMAND GETKEYSANDFLAGS", *args) + + def command_docs(self, *args): + """ + This function throws a NotImplementedError since it is intentionally + not supported. + """ + raise NotImplementedError( + "COMMAND DOCS is intentionally not implemented in the client." + ) + + def config_get( + self, pattern: PatternT = "*", *args: List[PatternT], **kwargs + ) -> ResponseT: + """ + Return a dictionary of configuration based on the ``pattern`` + + For more information see https://redis.io/commands/config-get + """ + return self.execute_command("CONFIG GET", pattern, *args, **kwargs) + + def config_set( + self, + name: KeyT, + value: EncodableT, + *args: List[Union[KeyT, EncodableT]], + **kwargs, + ) -> ResponseT: + """Set config item ``name`` with ``value`` + + For more information see https://redis.io/commands/config-set + """ + return self.execute_command("CONFIG SET", name, value, *args, **kwargs) + + def config_resetstat(self, **kwargs) -> ResponseT: + """ + Reset runtime statistics + + For more information see https://redis.io/commands/config-resetstat + """ + return self.execute_command("CONFIG RESETSTAT", **kwargs) + + def config_rewrite(self, **kwargs) -> ResponseT: + """ + Rewrite config file with the minimal change to reflect running config. + + For more information see https://redis.io/commands/config-rewrite + """ + return self.execute_command("CONFIG REWRITE", **kwargs) + + def dbsize(self, **kwargs) -> ResponseT: + """ + Returns the number of keys in the current database + + For more information see https://redis.io/commands/dbsize + """ + return self.execute_command("DBSIZE", **kwargs) + + def debug_object(self, key: KeyT, **kwargs) -> ResponseT: + """ + Returns version specific meta information about a given key + + For more information see https://redis.io/commands/debug-object + """ + return self.execute_command("DEBUG OBJECT", key, **kwargs) + + def debug_segfault(self, **kwargs) -> None: + raise NotImplementedError( + """ + DEBUG SEGFAULT is intentionally not implemented in the client. + + For more information see https://redis.io/commands/debug-segfault + """ + ) + + def echo(self, value: EncodableT, **kwargs) -> ResponseT: + """ + Echo the string back from the server + + For more information see https://redis.io/commands/echo + """ + return self.execute_command("ECHO", value, **kwargs) + + def flushall(self, asynchronous: bool = False, **kwargs) -> ResponseT: + """ + Delete all keys in all databases on the current host. + + ``asynchronous`` indicates whether the operation is + executed asynchronously by the server. + + For more information see https://redis.io/commands/flushall + """ + args = [] + if asynchronous: + args.append(b"ASYNC") + return self.execute_command("FLUSHALL", *args, **kwargs) + + def flushdb(self, asynchronous: bool = False, **kwargs) -> ResponseT: + """ + Delete all keys in the current database. + + ``asynchronous`` indicates whether the operation is + executed asynchronously by the server. + + For more information see https://redis.io/commands/flushdb + """ + args = [] + if asynchronous: + args.append(b"ASYNC") + return self.execute_command("FLUSHDB", *args, **kwargs) + + def sync(self) -> ResponseT: + """ + Initiates a replication stream from the master. + + For more information see https://redis.io/commands/sync + """ + from redis.client import NEVER_DECODE + + options = {} + options[NEVER_DECODE] = [] + return self.execute_command("SYNC", **options) + + def psync(self, replicationid: str, offset: int): + """ + Initiates a replication stream from the master. + Newer version for `sync`. + + For more information see https://redis.io/commands/sync + """ + from redis.client import NEVER_DECODE + + options = {} + options[NEVER_DECODE] = [] + return self.execute_command("PSYNC", replicationid, offset, **options) + + def swapdb(self, first: int, second: int, **kwargs) -> ResponseT: + """ + Swap two databases + + For more information see https://redis.io/commands/swapdb + """ + return self.execute_command("SWAPDB", first, second, **kwargs) + + def select(self, index: int, **kwargs) -> ResponseT: + """Select the Redis logical database at index. + + See: https://redis.io/commands/select + """ + return self.execute_command("SELECT", index, **kwargs) + + def info( + self, section: Union[str, None] = None, *args: List[str], **kwargs + ) -> ResponseT: + """ + Returns a dictionary containing information about the Redis server + + The ``section`` option can be used to select a specific section + of information + + The section option is not supported by older versions of Redis Server, + and will generate ResponseError + + For more information see https://redis.io/commands/info + """ + if section is None: + return self.execute_command("INFO", **kwargs) + else: + return self.execute_command("INFO", section, *args, **kwargs) + + def lastsave(self, **kwargs) -> ResponseT: + """ + Return a Python datetime object representing the last time the + Redis database was saved to disk + + For more information see https://redis.io/commands/lastsave + """ + return self.execute_command("LASTSAVE", **kwargs) + + def latency_doctor(self): + """Raise a NotImplementedError, as the client will not support LATENCY DOCTOR. + This funcion is best used within the redis-cli. + + For more information see https://redis.io/commands/latency-doctor + """ + raise NotImplementedError( + """ + LATENCY DOCTOR is intentionally not implemented in the client. + + For more information see https://redis.io/commands/latency-doctor + """ + ) + + def latency_graph(self): + """Raise a NotImplementedError, as the client will not support LATENCY GRAPH. + This funcion is best used within the redis-cli. + + For more information see https://redis.io/commands/latency-graph. + """ + raise NotImplementedError( + """ + LATENCY GRAPH is intentionally not implemented in the client. + + For more information see https://redis.io/commands/latency-graph + """ + ) + + def lolwut(self, *version_numbers: Union[str, float], **kwargs) -> ResponseT: + """ + Get the Redis version and a piece of generative computer art + + See: https://redis.io/commands/lolwut + """ + if version_numbers: + return self.execute_command("LOLWUT VERSION", *version_numbers, **kwargs) + else: + return self.execute_command("LOLWUT", **kwargs) + + def reset(self) -> ResponseT: + """Perform a full reset on the connection's server side contenxt. + + See: https://redis.io/commands/reset + """ + return self.execute_command("RESET") + + def migrate( + self, + host: str, + port: int, + keys: KeysT, + destination_db: int, + timeout: int, + copy: bool = False, + replace: bool = False, + auth: Union[str, None] = None, + **kwargs, + ) -> ResponseT: + """ + Migrate 1 or more keys from the current Redis server to a different + server specified by the ``host``, ``port`` and ``destination_db``. + + The ``timeout``, specified in milliseconds, indicates the maximum + time the connection between the two servers can be idle before the + command is interrupted. + + If ``copy`` is True, the specified ``keys`` are NOT deleted from + the source server. + + If ``replace`` is True, this operation will overwrite the keys + on the destination server if they exist. + + If ``auth`` is specified, authenticate to the destination server with + the password provided. + + For more information see https://redis.io/commands/migrate + """ + keys = list_or_args(keys, []) + if not keys: + raise DataError("MIGRATE requires at least one key") + pieces = [] + if copy: + pieces.append(b"COPY") + if replace: + pieces.append(b"REPLACE") + if auth: + pieces.append(b"AUTH") + pieces.append(auth) + pieces.append(b"KEYS") + pieces.extend(keys) + return self.execute_command( + "MIGRATE", host, port, "", destination_db, timeout, *pieces, **kwargs + ) + + def object(self, infotype: str, key: KeyT, **kwargs) -> ResponseT: + """ + Return the encoding, idletime, or refcount about the key + """ + return self.execute_command( + "OBJECT", infotype, key, infotype=infotype, **kwargs + ) + + def memory_doctor(self, **kwargs) -> None: + raise NotImplementedError( + """ + MEMORY DOCTOR is intentionally not implemented in the client. + + For more information see https://redis.io/commands/memory-doctor + """ + ) + + def memory_help(self, **kwargs) -> None: + raise NotImplementedError( + """ + MEMORY HELP is intentionally not implemented in the client. + + For more information see https://redis.io/commands/memory-help + """ + ) + + def memory_stats(self, **kwargs) -> ResponseT: + """ + Return a dictionary of memory stats + + For more information see https://redis.io/commands/memory-stats + """ + return self.execute_command("MEMORY STATS", **kwargs) + + def memory_malloc_stats(self, **kwargs) -> ResponseT: + """ + Return an internal statistics report from the memory allocator. + + See: https://redis.io/commands/memory-malloc-stats + """ + return self.execute_command("MEMORY MALLOC-STATS", **kwargs) + + def memory_usage( + self, key: KeyT, samples: Union[int, None] = None, **kwargs + ) -> ResponseT: + """ + Return the total memory usage for key, its value and associated + administrative overheads. + + For nested data structures, ``samples`` is the number of elements to + sample. If left unspecified, the server's default is 5. Use 0 to sample + all elements. + + For more information see https://redis.io/commands/memory-usage + """ + args = [] + if isinstance(samples, int): + args.extend([b"SAMPLES", samples]) + return self.execute_command("MEMORY USAGE", key, *args, **kwargs) + + def memory_purge(self, **kwargs) -> ResponseT: + """ + Attempts to purge dirty pages for reclamation by allocator + + For more information see https://redis.io/commands/memory-purge + """ + return self.execute_command("MEMORY PURGE", **kwargs) + + def latency_histogram(self, *args): + """ + This function throws a NotImplementedError since it is intentionally + not supported. + """ + raise NotImplementedError( + "LATENCY HISTOGRAM is intentionally not implemented in the client." + ) + + def latency_history(self, event: str) -> ResponseT: + """ + Returns the raw data of the ``event``'s latency spikes time series. + + For more information see https://redis.io/commands/latency-history + """ + return self.execute_command("LATENCY HISTORY", event) + + def latency_latest(self) -> ResponseT: + """ + Reports the latest latency events logged. + + For more information see https://redis.io/commands/latency-latest + """ + return self.execute_command("LATENCY LATEST") + + def latency_reset(self, *events: str) -> ResponseT: + """ + Resets the latency spikes time series of all, or only some, events. + + For more information see https://redis.io/commands/latency-reset + """ + return self.execute_command("LATENCY RESET", *events) + + def ping(self, **kwargs) -> ResponseT: + """ + Ping the Redis server + + For more information see https://redis.io/commands/ping + """ + return self.execute_command("PING", **kwargs) + + def quit(self, **kwargs) -> ResponseT: + """ + Ask the server to close the connection. + + For more information see https://redis.io/commands/quit + """ + return self.execute_command("QUIT", **kwargs) + + def replicaof(self, *args, **kwargs) -> ResponseT: + """ + Update the replication settings of a redis replica, on the fly. + + Examples of valid arguments include: + + NO ONE (set no replication) + host port (set to the host and port of a redis server) + + For more information see https://redis.io/commands/replicaof + """ + return self.execute_command("REPLICAOF", *args, **kwargs) + + def save(self, **kwargs) -> ResponseT: + """ + Tell the Redis server to save its data to disk, + blocking until the save is complete + + For more information see https://redis.io/commands/save + """ + return self.execute_command("SAVE", **kwargs) + + def shutdown( + self, + save: bool = False, + nosave: bool = False, + now: bool = False, + force: bool = False, + abort: bool = False, + **kwargs, + ) -> None: + """Shutdown the Redis server. If Redis has persistence configured, + data will be flushed before shutdown. + It is possible to specify modifiers to alter the behavior of the command: + ``save`` will force a DB saving operation even if no save points are configured. + ``nosave`` will prevent a DB saving operation even if one or more save points + are configured. + ``now`` skips waiting for lagging replicas, i.e. it bypasses the first step in + the shutdown sequence. + ``force`` ignores any errors that would normally prevent the server from exiting + ``abort`` cancels an ongoing shutdown and cannot be combined with other flags. + + For more information see https://redis.io/commands/shutdown + """ + if save and nosave: + raise DataError("SHUTDOWN save and nosave cannot both be set") + args = ["SHUTDOWN"] + if save: + args.append("SAVE") + if nosave: + args.append("NOSAVE") + if now: + args.append("NOW") + if force: + args.append("FORCE") + if abort: + args.append("ABORT") + try: + self.execute_command(*args, **kwargs) + except ConnectionError: + # a ConnectionError here is expected + return + raise RedisError("SHUTDOWN seems to have failed.") + + def slaveof( + self, host: Union[str, None] = None, port: Union[int, None] = None, **kwargs + ) -> ResponseT: + """ + Set the server to be a replicated slave of the instance identified + by the ``host`` and ``port``. If called without arguments, the + instance is promoted to a master instead. + + For more information see https://redis.io/commands/slaveof + """ + if host is None and port is None: + return self.execute_command("SLAVEOF", b"NO", b"ONE", **kwargs) + return self.execute_command("SLAVEOF", host, port, **kwargs) + + def slowlog_get(self, num: Union[int, None] = None, **kwargs) -> ResponseT: + """ + Get the entries from the slowlog. If ``num`` is specified, get the + most recent ``num`` items. + + For more information see https://redis.io/commands/slowlog-get + """ + from redis.client import NEVER_DECODE + + args = ["SLOWLOG GET"] + if num is not None: + args.append(num) + decode_responses = self.get_connection_kwargs().get("decode_responses", False) + if decode_responses is True: + kwargs[NEVER_DECODE] = [] + return self.execute_command(*args, **kwargs) + + def slowlog_len(self, **kwargs) -> ResponseT: + """ + Get the number of items in the slowlog + + For more information see https://redis.io/commands/slowlog-len + """ + return self.execute_command("SLOWLOG LEN", **kwargs) + + def slowlog_reset(self, **kwargs) -> ResponseT: + """ + Remove all items in the slowlog + + For more information see https://redis.io/commands/slowlog-reset + """ + return self.execute_command("SLOWLOG RESET", **kwargs) + + def time(self, **kwargs) -> ResponseT: + """ + Returns the server time as a 2-item tuple of ints: + (seconds since epoch, microseconds into this second). + + For more information see https://redis.io/commands/time + """ + return self.execute_command("TIME", **kwargs) + + def wait(self, num_replicas: int, timeout: int, **kwargs) -> ResponseT: + """ + Redis synchronous replication + That returns the number of replicas that processed the query when + we finally have at least ``num_replicas``, or when the ``timeout`` was + reached. + + For more information see https://redis.io/commands/wait + """ + return self.execute_command("WAIT", num_replicas, timeout, **kwargs) + + def waitaof( + self, num_local: int, num_replicas: int, timeout: int, **kwargs + ) -> ResponseT: + """ + This command blocks the current client until all previous write + commands by that client are acknowledged as having been fsynced + to the AOF of the local Redis and/or at least the specified number + of replicas. + + For more information see https://redis.io/commands/waitaof + """ + return self.execute_command( + "WAITAOF", num_local, num_replicas, timeout, **kwargs + ) + + def hello(self): + """ + This function throws a NotImplementedError since it is intentionally + not supported. + """ + raise NotImplementedError( + "HELLO is intentionally not implemented in the client." + ) + + def failover(self): + """ + This function throws a NotImplementedError since it is intentionally + not supported. + """ + raise NotImplementedError( + "FAILOVER is intentionally not implemented in the client." + ) + + +AsyncManagementCommands = ManagementCommands + + +class AsyncManagementCommands(ManagementCommands): + async def command_info(self, **kwargs) -> None: + return super().command_info(**kwargs) + + async def debug_segfault(self, **kwargs) -> None: + return super().debug_segfault(**kwargs) + + async def memory_doctor(self, **kwargs) -> None: + return super().memory_doctor(**kwargs) + + async def memory_help(self, **kwargs) -> None: + return super().memory_help(**kwargs) + + async def shutdown( + self, + save: bool = False, + nosave: bool = False, + now: bool = False, + force: bool = False, + abort: bool = False, + **kwargs, + ) -> None: + """Shutdown the Redis server. If Redis has persistence configured, + data will be flushed before shutdown. If the "save" option is set, + a data flush will be attempted even if there is no persistence + configured. If the "nosave" option is set, no data flush will be + attempted. The "save" and "nosave" options cannot both be set. + + For more information see https://redis.io/commands/shutdown + """ + if save and nosave: + raise DataError("SHUTDOWN save and nosave cannot both be set") + args = ["SHUTDOWN"] + if save: + args.append("SAVE") + if nosave: + args.append("NOSAVE") + if now: + args.append("NOW") + if force: + args.append("FORCE") + if abort: + args.append("ABORT") + try: + await self.execute_command(*args, **kwargs) + except ConnectionError: + # a ConnectionError here is expected + return + raise RedisError("SHUTDOWN seems to have failed.") + + +class BitFieldOperation: + """ + Command builder for BITFIELD commands. + """ + + def __init__( + self, + client: Union["Redis", "AsyncRedis"], + key: str, + default_overflow: Union[str, None] = None, + ): + self.client = client + self.key = key + self._default_overflow = default_overflow + # for typing purposes, run the following in constructor and in reset() + self.operations: list[tuple[EncodableT, ...]] = [] + self._last_overflow = "WRAP" + self.reset() + + def reset(self): + """ + Reset the state of the instance to when it was constructed + """ + self.operations = [] + self._last_overflow = "WRAP" + self.overflow(self._default_overflow or self._last_overflow) + + def overflow(self, overflow: str): + """ + Update the overflow algorithm of successive INCRBY operations + :param overflow: Overflow algorithm, one of WRAP, SAT, FAIL. See the + Redis docs for descriptions of these algorithmsself. + :returns: a :py:class:`BitFieldOperation` instance. + """ + overflow = overflow.upper() + if overflow != self._last_overflow: + self._last_overflow = overflow + self.operations.append(("OVERFLOW", overflow)) + return self + + def incrby( + self, + fmt: str, + offset: BitfieldOffsetT, + increment: int, + overflow: Union[str, None] = None, + ): + """ + Increment a bitfield by a given amount. + :param fmt: format-string for the bitfield being updated, e.g. 'u8' + for an unsigned 8-bit integer. + :param offset: offset (in number of bits). If prefixed with a + '#', this is an offset multiplier, e.g. given the arguments + fmt='u8', offset='#2', the offset will be 16. + :param int increment: value to increment the bitfield by. + :param str overflow: overflow algorithm. Defaults to WRAP, but other + acceptable values are SAT and FAIL. See the Redis docs for + descriptions of these algorithms. + :returns: a :py:class:`BitFieldOperation` instance. + """ + if overflow is not None: + self.overflow(overflow) + + self.operations.append(("INCRBY", fmt, offset, increment)) + return self + + def get(self, fmt: str, offset: BitfieldOffsetT): + """ + Get the value of a given bitfield. + :param fmt: format-string for the bitfield being read, e.g. 'u8' for + an unsigned 8-bit integer. + :param offset: offset (in number of bits). If prefixed with a + '#', this is an offset multiplier, e.g. given the arguments + fmt='u8', offset='#2', the offset will be 16. + :returns: a :py:class:`BitFieldOperation` instance. + """ + self.operations.append(("GET", fmt, offset)) + return self + + def set(self, fmt: str, offset: BitfieldOffsetT, value: int): + """ + Set the value of a given bitfield. + :param fmt: format-string for the bitfield being read, e.g. 'u8' for + an unsigned 8-bit integer. + :param offset: offset (in number of bits). If prefixed with a + '#', this is an offset multiplier, e.g. given the arguments + fmt='u8', offset='#2', the offset will be 16. + :param int value: value to set at the given position. + :returns: a :py:class:`BitFieldOperation` instance. + """ + self.operations.append(("SET", fmt, offset, value)) + return self + + @property + def command(self): + cmd = ["BITFIELD", self.key] + for ops in self.operations: + cmd.extend(ops) + return cmd + + def execute(self) -> ResponseT: + """ + Execute the operation(s) in a single BITFIELD command. The return value + is a list of values corresponding to each operation. If the client + used to create this instance was a pipeline, the list of values + will be present within the pipeline's execute. + """ + command = self.command + self.reset() + return self.client.execute_command(*command) + + +class BasicKeyCommands(CommandsProtocol): + """ + Redis basic key-based commands + """ + + def append(self, key: KeyT, value: EncodableT) -> ResponseT: + """ + Appends the string ``value`` to the value at ``key``. If ``key`` + doesn't already exist, create it with a value of ``value``. + Returns the new length of the value at ``key``. + + For more information see https://redis.io/commands/append + """ + return self.execute_command("APPEND", key, value) + + def bitcount( + self, + key: KeyT, + start: Union[int, None] = None, + end: Union[int, None] = None, + mode: Optional[str] = None, + ) -> ResponseT: + """ + Returns the count of set bits in the value of ``key``. Optional + ``start`` and ``end`` parameters indicate which bytes to consider + + For more information see https://redis.io/commands/bitcount + """ + params = [key] + if start is not None and end is not None: + params.append(start) + params.append(end) + elif (start is not None and end is None) or (end is not None and start is None): + raise DataError("Both start and end must be specified") + if mode is not None: + params.append(mode) + return self.execute_command("BITCOUNT", *params) + + def bitfield( + self: Union["Redis", "AsyncRedis"], + key: KeyT, + default_overflow: Union[str, None] = None, + ) -> BitFieldOperation: + """ + Return a BitFieldOperation instance to conveniently construct one or + more bitfield operations on ``key``. + + For more information see https://redis.io/commands/bitfield + """ + return BitFieldOperation(self, key, default_overflow=default_overflow) + + def bitfield_ro( + self: Union["Redis", "AsyncRedis"], + key: KeyT, + encoding: str, + offset: BitfieldOffsetT, + items: Optional[list] = None, + ) -> ResponseT: + """ + Return an array of the specified bitfield values + where the first value is found using ``encoding`` and ``offset`` + parameters and remaining values are result of corresponding + encoding/offset pairs in optional list ``items`` + Read-only variant of the BITFIELD command. + + For more information see https://redis.io/commands/bitfield_ro + """ + params = [key, "GET", encoding, offset] + + items = items or [] + for encoding, offset in items: + params.extend(["GET", encoding, offset]) + return self.execute_command("BITFIELD_RO", *params) + + def bitop(self, operation: str, dest: KeyT, *keys: KeyT) -> ResponseT: + """ + Perform a bitwise operation using ``operation`` between ``keys`` and + store the result in ``dest``. + + For more information see https://redis.io/commands/bitop + """ + return self.execute_command("BITOP", operation, dest, *keys) + + def bitpos( + self, + key: KeyT, + bit: int, + start: Union[int, None] = None, + end: Union[int, None] = None, + mode: Optional[str] = None, + ) -> ResponseT: + """ + Return the position of the first bit set to 1 or 0 in a string. + ``start`` and ``end`` defines search range. The range is interpreted + as a range of bytes and not a range of bits, so start=0 and end=2 + means to look at the first three bytes. + + For more information see https://redis.io/commands/bitpos + """ + if bit not in (0, 1): + raise DataError("bit must be 0 or 1") + params = [key, bit] + + start is not None and params.append(start) + + if start is not None and end is not None: + params.append(end) + elif start is None and end is not None: + raise DataError("start argument is not set, when end is specified") + + if mode is not None: + params.append(mode) + return self.execute_command("BITPOS", *params) + + def copy( + self, + source: str, + destination: str, + destination_db: Union[str, None] = None, + replace: bool = False, + ) -> ResponseT: + """ + Copy the value stored in the ``source`` key to the ``destination`` key. + + ``destination_db`` an alternative destination database. By default, + the ``destination`` key is created in the source Redis database. + + ``replace`` whether the ``destination`` key should be removed before + copying the value to it. By default, the value is not copied if + the ``destination`` key already exists. + + For more information see https://redis.io/commands/copy + """ + params = [source, destination] + if destination_db is not None: + params.extend(["DB", destination_db]) + if replace: + params.append("REPLACE") + return self.execute_command("COPY", *params) + + def decrby(self, name: KeyT, amount: int = 1) -> ResponseT: + """ + Decrements the value of ``key`` by ``amount``. If no key exists, + the value will be initialized as 0 - ``amount`` + + For more information see https://redis.io/commands/decrby + """ + return self.execute_command("DECRBY", name, amount) + + decr = decrby + + def delete(self, *names: KeyT) -> ResponseT: + """ + Delete one or more keys specified by ``names`` + """ + return self.execute_command("DEL", *names) + + def __delitem__(self, name: KeyT): + self.delete(name) + + def dump(self, name: KeyT) -> ResponseT: + """ + Return a serialized version of the value stored at the specified key. + If key does not exist a nil bulk reply is returned. + + For more information see https://redis.io/commands/dump + """ + from redis.client import NEVER_DECODE + + options = {} + options[NEVER_DECODE] = [] + return self.execute_command("DUMP", name, **options) + + def exists(self, *names: KeyT) -> ResponseT: + """ + Returns the number of ``names`` that exist + + For more information see https://redis.io/commands/exists + """ + return self.execute_command("EXISTS", *names) + + __contains__ = exists + + def expire( + self, + name: KeyT, + time: ExpiryT, + nx: bool = False, + xx: bool = False, + gt: bool = False, + lt: bool = False, + ) -> ResponseT: + """ + Set an expire flag on key ``name`` for ``time`` seconds with given + ``option``. ``time`` can be represented by an integer or a Python timedelta + object. + + Valid options are: + NX -> Set expiry only when the key has no expiry + XX -> Set expiry only when the key has an existing expiry + GT -> Set expiry only when the new expiry is greater than current one + LT -> Set expiry only when the new expiry is less than current one + + For more information see https://redis.io/commands/expire + """ + if isinstance(time, datetime.timedelta): + time = int(time.total_seconds()) + + exp_option = list() + if nx: + exp_option.append("NX") + if xx: + exp_option.append("XX") + if gt: + exp_option.append("GT") + if lt: + exp_option.append("LT") + + return self.execute_command("EXPIRE", name, time, *exp_option) + + def expireat( + self, + name: KeyT, + when: AbsExpiryT, + nx: bool = False, + xx: bool = False, + gt: bool = False, + lt: bool = False, + ) -> ResponseT: + """ + Set an expire flag on key ``name`` with given ``option``. ``when`` + can be represented as an integer indicating unix time or a Python + datetime object. + + Valid options are: + -> NX -- Set expiry only when the key has no expiry + -> XX -- Set expiry only when the key has an existing expiry + -> GT -- Set expiry only when the new expiry is greater than current one + -> LT -- Set expiry only when the new expiry is less than current one + + For more information see https://redis.io/commands/expireat + """ + if isinstance(when, datetime.datetime): + when = int(when.timestamp()) + + exp_option = list() + if nx: + exp_option.append("NX") + if xx: + exp_option.append("XX") + if gt: + exp_option.append("GT") + if lt: + exp_option.append("LT") + + return self.execute_command("EXPIREAT", name, when, *exp_option) + + def expiretime(self, key: str) -> int: + """ + Returns the absolute Unix timestamp (since January 1, 1970) in seconds + at which the given key will expire. + + For more information see https://redis.io/commands/expiretime + """ + return self.execute_command("EXPIRETIME", key) + + def get(self, name: KeyT) -> ResponseT: + """ + Return the value at key ``name``, or None if the key doesn't exist + + For more information see https://redis.io/commands/get + """ + return self.execute_command("GET", name) + + def getdel(self, name: KeyT) -> ResponseT: + """ + Get the value at key ``name`` and delete the key. This command + is similar to GET, except for the fact that it also deletes + the key on success (if and only if the key's value type + is a string). + + For more information see https://redis.io/commands/getdel + """ + return self.execute_command("GETDEL", name) + + def getex( + self, + name: KeyT, + ex: Union[ExpiryT, None] = None, + px: Union[ExpiryT, None] = None, + exat: Union[AbsExpiryT, None] = None, + pxat: Union[AbsExpiryT, None] = None, + persist: bool = False, + ) -> ResponseT: + """ + Get the value of key and optionally set its expiration. + GETEX is similar to GET, but is a write command with + additional options. All time parameters can be given as + datetime.timedelta or integers. + + ``ex`` sets an expire flag on key ``name`` for ``ex`` seconds. + + ``px`` sets an expire flag on key ``name`` for ``px`` milliseconds. + + ``exat`` sets an expire flag on key ``name`` for ``ex`` seconds, + specified in unix time. + + ``pxat`` sets an expire flag on key ``name`` for ``ex`` milliseconds, + specified in unix time. + + ``persist`` remove the time to live associated with ``name``. + + For more information see https://redis.io/commands/getex + """ + + opset = {ex, px, exat, pxat} + if len(opset) > 2 or len(opset) > 1 and persist: + raise DataError( + "``ex``, ``px``, ``exat``, ``pxat``, " + "and ``persist`` are mutually exclusive." + ) + + pieces: list[EncodableT] = [] + # similar to set command + if ex is not None: + pieces.append("EX") + if isinstance(ex, datetime.timedelta): + ex = int(ex.total_seconds()) + pieces.append(ex) + if px is not None: + pieces.append("PX") + if isinstance(px, datetime.timedelta): + px = int(px.total_seconds() * 1000) + pieces.append(px) + # similar to pexpireat command + if exat is not None: + pieces.append("EXAT") + if isinstance(exat, datetime.datetime): + exat = int(exat.timestamp()) + pieces.append(exat) + if pxat is not None: + pieces.append("PXAT") + if isinstance(pxat, datetime.datetime): + pxat = int(pxat.timestamp() * 1000) + pieces.append(pxat) + if persist: + pieces.append("PERSIST") + + return self.execute_command("GETEX", name, *pieces) + + def __getitem__(self, name: KeyT): + """ + Return the value at key ``name``, raises a KeyError if the key + doesn't exist. + """ + value = self.get(name) + if value is not None: + return value + raise KeyError(name) + + def getbit(self, name: KeyT, offset: int) -> ResponseT: + """ + Returns an integer indicating the value of ``offset`` in ``name`` + + For more information see https://redis.io/commands/getbit + """ + return self.execute_command("GETBIT", name, offset) + + def getrange(self, key: KeyT, start: int, end: int) -> ResponseT: + """ + Returns the substring of the string value stored at ``key``, + determined by the offsets ``start`` and ``end`` (both are inclusive) + + For more information see https://redis.io/commands/getrange + """ + return self.execute_command("GETRANGE", key, start, end) + + def getset(self, name: KeyT, value: EncodableT) -> ResponseT: + """ + Sets the value at key ``name`` to ``value`` + and returns the old value at key ``name`` atomically. + + As per Redis 6.2, GETSET is considered deprecated. + Please use SET with GET parameter in new code. + + For more information see https://redis.io/commands/getset + """ + return self.execute_command("GETSET", name, value) + + def incrby(self, name: KeyT, amount: int = 1) -> ResponseT: + """ + Increments the value of ``key`` by ``amount``. If no key exists, + the value will be initialized as ``amount`` + + For more information see https://redis.io/commands/incrby + """ + return self.execute_command("INCRBY", name, amount) + + incr = incrby + + def incrbyfloat(self, name: KeyT, amount: float = 1.0) -> ResponseT: + """ + Increments the value at key ``name`` by floating ``amount``. + If no key exists, the value will be initialized as ``amount`` + + For more information see https://redis.io/commands/incrbyfloat + """ + return self.execute_command("INCRBYFLOAT", name, amount) + + def keys(self, pattern: PatternT = "*", **kwargs) -> ResponseT: + """ + Returns a list of keys matching ``pattern`` + + For more information see https://redis.io/commands/keys + """ + return self.execute_command("KEYS", pattern, **kwargs) + + def lmove( + self, first_list: str, second_list: str, src: str = "LEFT", dest: str = "RIGHT" + ) -> ResponseT: + """ + Atomically returns and removes the first/last element of a list, + pushing it as the first/last element on the destination list. + Returns the element being popped and pushed. + + For more information see https://redis.io/commands/lmove + """ + params = [first_list, second_list, src, dest] + return self.execute_command("LMOVE", *params) + + def blmove( + self, + first_list: str, + second_list: str, + timeout: int, + src: str = "LEFT", + dest: str = "RIGHT", + ) -> ResponseT: + """ + Blocking version of lmove. + + For more information see https://redis.io/commands/blmove + """ + params = [first_list, second_list, src, dest, timeout] + return self.execute_command("BLMOVE", *params) + + def mget(self, keys: KeysT, *args: EncodableT) -> ResponseT: + """ + Returns a list of values ordered identically to ``keys`` + + For more information see https://redis.io/commands/mget + """ + from redis.client import EMPTY_RESPONSE + + args = list_or_args(keys, args) + options = {} + if not args: + options[EMPTY_RESPONSE] = [] + return self.execute_command("MGET", *args, **options) + + def mset(self, mapping: Mapping[AnyKeyT, EncodableT]) -> ResponseT: + """ + Sets key/values based on a mapping. Mapping is a dictionary of + key/value pairs. Both keys and values should be strings or types that + can be cast to a string via str(). + + For more information see https://redis.io/commands/mset + """ + items = [] + for pair in mapping.items(): + items.extend(pair) + return self.execute_command("MSET", *items) + + def msetnx(self, mapping: Mapping[AnyKeyT, EncodableT]) -> ResponseT: + """ + Sets key/values based on a mapping if none of the keys are already set. + Mapping is a dictionary of key/value pairs. Both keys and values + should be strings or types that can be cast to a string via str(). + Returns a boolean indicating if the operation was successful. + + For more information see https://redis.io/commands/msetnx + """ + items = [] + for pair in mapping.items(): + items.extend(pair) + return self.execute_command("MSETNX", *items) + + def move(self, name: KeyT, db: int) -> ResponseT: + """ + Moves the key ``name`` to a different Redis database ``db`` + + For more information see https://redis.io/commands/move + """ + return self.execute_command("MOVE", name, db) + + def persist(self, name: KeyT) -> ResponseT: + """ + Removes an expiration on ``name`` + + For more information see https://redis.io/commands/persist + """ + return self.execute_command("PERSIST", name) + + def pexpire( + self, + name: KeyT, + time: ExpiryT, + nx: bool = False, + xx: bool = False, + gt: bool = False, + lt: bool = False, + ) -> ResponseT: + """ + Set an expire flag on key ``name`` for ``time`` milliseconds + with given ``option``. ``time`` can be represented by an + integer or a Python timedelta object. + + Valid options are: + NX -> Set expiry only when the key has no expiry + XX -> Set expiry only when the key has an existing expiry + GT -> Set expiry only when the new expiry is greater than current one + LT -> Set expiry only when the new expiry is less than current one + + For more information see https://redis.io/commands/pexpire + """ + if isinstance(time, datetime.timedelta): + time = int(time.total_seconds() * 1000) + + exp_option = list() + if nx: + exp_option.append("NX") + if xx: + exp_option.append("XX") + if gt: + exp_option.append("GT") + if lt: + exp_option.append("LT") + return self.execute_command("PEXPIRE", name, time, *exp_option) + + def pexpireat( + self, + name: KeyT, + when: AbsExpiryT, + nx: bool = False, + xx: bool = False, + gt: bool = False, + lt: bool = False, + ) -> ResponseT: + """ + Set an expire flag on key ``name`` with given ``option``. ``when`` + can be represented as an integer representing unix time in + milliseconds (unix time * 1000) or a Python datetime object. + + Valid options are: + NX -> Set expiry only when the key has no expiry + XX -> Set expiry only when the key has an existing expiry + GT -> Set expiry only when the new expiry is greater than current one + LT -> Set expiry only when the new expiry is less than current one + + For more information see https://redis.io/commands/pexpireat + """ + if isinstance(when, datetime.datetime): + when = int(when.timestamp() * 1000) + exp_option = list() + if nx: + exp_option.append("NX") + if xx: + exp_option.append("XX") + if gt: + exp_option.append("GT") + if lt: + exp_option.append("LT") + return self.execute_command("PEXPIREAT", name, when, *exp_option) + + def pexpiretime(self, key: str) -> int: + """ + Returns the absolute Unix timestamp (since January 1, 1970) in milliseconds + at which the given key will expire. + + For more information see https://redis.io/commands/pexpiretime + """ + return self.execute_command("PEXPIRETIME", key) + + def psetex(self, name: KeyT, time_ms: ExpiryT, value: EncodableT): + """ + Set the value of key ``name`` to ``value`` that expires in ``time_ms`` + milliseconds. ``time_ms`` can be represented by an integer or a Python + timedelta object + + For more information see https://redis.io/commands/psetex + """ + if isinstance(time_ms, datetime.timedelta): + time_ms = int(time_ms.total_seconds() * 1000) + return self.execute_command("PSETEX", name, time_ms, value) + + def pttl(self, name: KeyT) -> ResponseT: + """ + Returns the number of milliseconds until the key ``name`` will expire + + For more information see https://redis.io/commands/pttl + """ + return self.execute_command("PTTL", name) + + def hrandfield( + self, key: str, count: int = None, withvalues: bool = False + ) -> ResponseT: + """ + Return a random field from the hash value stored at key. + + count: if the argument is positive, return an array of distinct fields. + If called with a negative count, the behavior changes and the command + is allowed to return the same field multiple times. In this case, + the number of returned fields is the absolute value of the + specified count. + withvalues: The optional WITHVALUES modifier changes the reply so it + includes the respective values of the randomly selected hash fields. + + For more information see https://redis.io/commands/hrandfield + """ + params = [] + if count is not None: + params.append(count) + if withvalues: + params.append("WITHVALUES") + + return self.execute_command("HRANDFIELD", key, *params) + + def randomkey(self, **kwargs) -> ResponseT: + """ + Returns the name of a random key + + For more information see https://redis.io/commands/randomkey + """ + return self.execute_command("RANDOMKEY", **kwargs) + + def rename(self, src: KeyT, dst: KeyT) -> ResponseT: + """ + Rename key ``src`` to ``dst`` + + For more information see https://redis.io/commands/rename + """ + return self.execute_command("RENAME", src, dst) + + def renamenx(self, src: KeyT, dst: KeyT): + """ + Rename key ``src`` to ``dst`` if ``dst`` doesn't already exist + + For more information see https://redis.io/commands/renamenx + """ + return self.execute_command("RENAMENX", src, dst) + + def restore( + self, + name: KeyT, + ttl: float, + value: EncodableT, + replace: bool = False, + absttl: bool = False, + idletime: Union[int, None] = None, + frequency: Union[int, None] = None, + ) -> ResponseT: + """ + Create a key using the provided serialized value, previously obtained + using DUMP. + + ``replace`` allows an existing key on ``name`` to be overridden. If + it's not specified an error is raised on collision. + + ``absttl`` if True, specified ``ttl`` should represent an absolute Unix + timestamp in milliseconds in which the key will expire. (Redis 5.0 or + greater). + + ``idletime`` Used for eviction, this is the number of seconds the + key must be idle, prior to execution. + + ``frequency`` Used for eviction, this is the frequency counter of + the object stored at the key, prior to execution. + + For more information see https://redis.io/commands/restore + """ + params = [name, ttl, value] + if replace: + params.append("REPLACE") + if absttl: + params.append("ABSTTL") + if idletime is not None: + params.append("IDLETIME") + try: + params.append(int(idletime)) + except ValueError: + raise DataError("idletimemust be an integer") + + if frequency is not None: + params.append("FREQ") + try: + params.append(int(frequency)) + except ValueError: + raise DataError("frequency must be an integer") + + return self.execute_command("RESTORE", *params) + + def set( + self, + name: KeyT, + value: EncodableT, + ex: Union[ExpiryT, None] = None, + px: Union[ExpiryT, None] = None, + nx: bool = False, + xx: bool = False, + keepttl: bool = False, + get: bool = False, + exat: Union[AbsExpiryT, None] = None, + pxat: Union[AbsExpiryT, None] = None, + ) -> ResponseT: + """ + Set the value at key ``name`` to ``value`` + + ``ex`` sets an expire flag on key ``name`` for ``ex`` seconds. + + ``px`` sets an expire flag on key ``name`` for ``px`` milliseconds. + + ``nx`` if set to True, set the value at key ``name`` to ``value`` only + if it does not exist. + + ``xx`` if set to True, set the value at key ``name`` to ``value`` only + if it already exists. + + ``keepttl`` if True, retain the time to live associated with the key. + (Available since Redis 6.0) + + ``get`` if True, set the value at key ``name`` to ``value`` and return + the old value stored at key, or None if the key did not exist. + (Available since Redis 6.2) + + ``exat`` sets an expire flag on key ``name`` for ``ex`` seconds, + specified in unix time. + + ``pxat`` sets an expire flag on key ``name`` for ``ex`` milliseconds, + specified in unix time. + + For more information see https://redis.io/commands/set + """ + pieces: list[EncodableT] = [name, value] + options = {} + if ex is not None: + pieces.append("EX") + if isinstance(ex, datetime.timedelta): + pieces.append(int(ex.total_seconds())) + elif isinstance(ex, int): + pieces.append(ex) + elif isinstance(ex, str) and ex.isdigit(): + pieces.append(int(ex)) + else: + raise DataError("ex must be datetime.timedelta or int") + if px is not None: + pieces.append("PX") + if isinstance(px, datetime.timedelta): + pieces.append(int(px.total_seconds() * 1000)) + elif isinstance(px, int): + pieces.append(px) + else: + raise DataError("px must be datetime.timedelta or int") + if exat is not None: + pieces.append("EXAT") + if isinstance(exat, datetime.datetime): + exat = int(exat.timestamp()) + pieces.append(exat) + if pxat is not None: + pieces.append("PXAT") + if isinstance(pxat, datetime.datetime): + pxat = int(pxat.timestamp() * 1000) + pieces.append(pxat) + if keepttl: + pieces.append("KEEPTTL") + + if nx: + pieces.append("NX") + if xx: + pieces.append("XX") + + if get: + pieces.append("GET") + options["get"] = True + + return self.execute_command("SET", *pieces, **options) + + def __setitem__(self, name: KeyT, value: EncodableT): + self.set(name, value) + + def setbit(self, name: KeyT, offset: int, value: int) -> ResponseT: + """ + Flag the ``offset`` in ``name`` as ``value``. Returns an integer + indicating the previous value of ``offset``. + + For more information see https://redis.io/commands/setbit + """ + value = value and 1 or 0 + return self.execute_command("SETBIT", name, offset, value) + + def setex(self, name: KeyT, time: ExpiryT, value: EncodableT) -> ResponseT: + """ + Set the value of key ``name`` to ``value`` that expires in ``time`` + seconds. ``time`` can be represented by an integer or a Python + timedelta object. + + For more information see https://redis.io/commands/setex + """ + if isinstance(time, datetime.timedelta): + time = int(time.total_seconds()) + return self.execute_command("SETEX", name, time, value) + + def setnx(self, name: KeyT, value: EncodableT) -> ResponseT: + """ + Set the value of key ``name`` to ``value`` if key doesn't exist + + For more information see https://redis.io/commands/setnx + """ + return self.execute_command("SETNX", name, value) + + def setrange(self, name: KeyT, offset: int, value: EncodableT) -> ResponseT: + """ + Overwrite bytes in the value of ``name`` starting at ``offset`` with + ``value``. If ``offset`` plus the length of ``value`` exceeds the + length of the original value, the new value will be larger than before. + If ``offset`` exceeds the length of the original value, null bytes + will be used to pad between the end of the previous value and the start + of what's being injected. + + Returns the length of the new string. + + For more information see https://redis.io/commands/setrange + """ + return self.execute_command("SETRANGE", name, offset, value) + + def stralgo( + self, + algo: Literal["LCS"], + value1: KeyT, + value2: KeyT, + specific_argument: Union[Literal["strings"], Literal["keys"]] = "strings", + len: bool = False, + idx: bool = False, + minmatchlen: Union[int, None] = None, + withmatchlen: bool = False, + **kwargs, + ) -> ResponseT: + """ + Implements complex algorithms that operate on strings. + Right now the only algorithm implemented is the LCS algorithm + (longest common substring). However new algorithms could be + implemented in the future. + + ``algo`` Right now must be LCS + ``value1`` and ``value2`` Can be two strings or two keys + ``specific_argument`` Specifying if the arguments to the algorithm + will be keys or strings. strings is the default. + ``len`` Returns just the len of the match. + ``idx`` Returns the match positions in each string. + ``minmatchlen`` Restrict the list of matches to the ones of a given + minimal length. Can be provided only when ``idx`` set to True. + ``withmatchlen`` Returns the matches with the len of the match. + Can be provided only when ``idx`` set to True. + + For more information see https://redis.io/commands/stralgo + """ + # check validity + supported_algo = ["LCS"] + if algo not in supported_algo: + supported_algos_str = ", ".join(supported_algo) + raise DataError(f"The supported algorithms are: {supported_algos_str}") + if specific_argument not in ["keys", "strings"]: + raise DataError("specific_argument can be only keys or strings") + if len and idx: + raise DataError("len and idx cannot be provided together.") + + pieces: list[EncodableT] = [algo, specific_argument.upper(), value1, value2] + if len: + pieces.append(b"LEN") + if idx: + pieces.append(b"IDX") + try: + int(minmatchlen) + pieces.extend([b"MINMATCHLEN", minmatchlen]) + except TypeError: + pass + if withmatchlen: + pieces.append(b"WITHMATCHLEN") + + return self.execute_command( + "STRALGO", + *pieces, + len=len, + idx=idx, + minmatchlen=minmatchlen, + withmatchlen=withmatchlen, + **kwargs, + ) + + def strlen(self, name: KeyT) -> ResponseT: + """ + Return the number of bytes stored in the value of ``name`` + + For more information see https://redis.io/commands/strlen + """ + return self.execute_command("STRLEN", name) + + def substr(self, name: KeyT, start: int, end: int = -1) -> ResponseT: + """ + Return a substring of the string at key ``name``. ``start`` and ``end`` + are 0-based integers specifying the portion of the string to return. + """ + return self.execute_command("SUBSTR", name, start, end) + + def touch(self, *args: KeyT) -> ResponseT: + """ + Alters the last access time of a key(s) ``*args``. A key is ignored + if it does not exist. + + For more information see https://redis.io/commands/touch + """ + return self.execute_command("TOUCH", *args) + + def ttl(self, name: KeyT) -> ResponseT: + """ + Returns the number of seconds until the key ``name`` will expire + + For more information see https://redis.io/commands/ttl + """ + return self.execute_command("TTL", name) + + def type(self, name: KeyT) -> ResponseT: + """ + Returns the type of key ``name`` + + For more information see https://redis.io/commands/type + """ + return self.execute_command("TYPE", name) + + def watch(self, *names: KeyT) -> None: + """ + Watches the values at keys ``names``, or None if the key doesn't exist + + For more information see https://redis.io/commands/watch + """ + warnings.warn(DeprecationWarning("Call WATCH from a Pipeline object")) + + def unwatch(self) -> None: + """ + Unwatches the value at key ``name``, or None of the key doesn't exist + + For more information see https://redis.io/commands/unwatch + """ + warnings.warn(DeprecationWarning("Call UNWATCH from a Pipeline object")) + + def unlink(self, *names: KeyT) -> ResponseT: + """ + Unlink one or more keys specified by ``names`` + + For more information see https://redis.io/commands/unlink + """ + return self.execute_command("UNLINK", *names) + + def lcs( + self, + key1: str, + key2: str, + len: Optional[bool] = False, + idx: Optional[bool] = False, + minmatchlen: Optional[int] = 0, + withmatchlen: Optional[bool] = False, + ) -> Union[str, int, list]: + """ + Find the longest common subsequence between ``key1`` and ``key2``. + If ``len`` is true the length of the match will will be returned. + If ``idx`` is true the match position in each strings will be returned. + ``minmatchlen`` restrict the list of matches to the ones of + the given ``minmatchlen``. + If ``withmatchlen`` the length of the match also will be returned. + For more information see https://redis.io/commands/lcs + """ + pieces = [key1, key2] + if len: + pieces.append("LEN") + if idx: + pieces.append("IDX") + if minmatchlen != 0: + pieces.extend(["MINMATCHLEN", minmatchlen]) + if withmatchlen: + pieces.append("WITHMATCHLEN") + return self.execute_command("LCS", *pieces) + + +class AsyncBasicKeyCommands(BasicKeyCommands): + def __delitem__(self, name: KeyT): + raise TypeError("Async Redis client does not support class deletion") + + def __contains__(self, name: KeyT): + raise TypeError("Async Redis client does not support class inclusion") + + def __getitem__(self, name: KeyT): + raise TypeError("Async Redis client does not support class retrieval") + + def __setitem__(self, name: KeyT, value: EncodableT): + raise TypeError("Async Redis client does not support class assignment") + + async def watch(self, *names: KeyT) -> None: + return super().watch(*names) + + async def unwatch(self) -> None: + return super().unwatch() + + +class ListCommands(CommandsProtocol): + """ + Redis commands for List data type. + see: https://redis.io/topics/data-types#lists + """ + + def blpop( + self, keys: List, timeout: Optional[int] = 0 + ) -> Union[Awaitable[list], list]: + """ + LPOP a value off of the first non-empty list + named in the ``keys`` list. + + If none of the lists in ``keys`` has a value to LPOP, then block + for ``timeout`` seconds, or until a value gets pushed on to one + of the lists. + + If timeout is 0, then block indefinitely. + + For more information see https://redis.io/commands/blpop + """ + if timeout is None: + timeout = 0 + keys = list_or_args(keys, None) + keys.append(timeout) + return self.execute_command("BLPOP", *keys) + + def brpop( + self, keys: List, timeout: Optional[int] = 0 + ) -> Union[Awaitable[list], list]: + """ + RPOP a value off of the first non-empty list + named in the ``keys`` list. + + If none of the lists in ``keys`` has a value to RPOP, then block + for ``timeout`` seconds, or until a value gets pushed on to one + of the lists. + + If timeout is 0, then block indefinitely. + + For more information see https://redis.io/commands/brpop + """ + if timeout is None: + timeout = 0 + keys = list_or_args(keys, None) + keys.append(timeout) + return self.execute_command("BRPOP", *keys) + + def brpoplpush( + self, src: str, dst: str, timeout: Optional[int] = 0 + ) -> Union[Awaitable[Optional[str]], Optional[str]]: + """ + Pop a value off the tail of ``src``, push it on the head of ``dst`` + and then return it. + + This command blocks until a value is in ``src`` or until ``timeout`` + seconds elapse, whichever is first. A ``timeout`` value of 0 blocks + forever. + + For more information see https://redis.io/commands/brpoplpush + """ + if timeout is None: + timeout = 0 + return self.execute_command("BRPOPLPUSH", src, dst, timeout) + + def blmpop( + self, + timeout: float, + numkeys: int, + *args: List[str], + direction: str, + count: Optional[int] = 1, + ) -> Optional[list]: + """ + Pop ``count`` values (default 1) from first non-empty in the list + of provided key names. + + When all lists are empty this command blocks the connection until another + client pushes to it or until the timeout, timeout of 0 blocks indefinitely + + For more information see https://redis.io/commands/blmpop + """ + args = [timeout, numkeys, *args, direction, "COUNT", count] + + return self.execute_command("BLMPOP", *args) + + def lmpop( + self, + num_keys: int, + *args: List[str], + direction: str, + count: Optional[int] = 1, + ) -> Union[Awaitable[list], list]: + """ + Pop ``count`` values (default 1) first non-empty list key from the list + of args provided key names. + + For more information see https://redis.io/commands/lmpop + """ + args = [num_keys] + list(args) + [direction] + if count != 1: + args.extend(["COUNT", count]) + + return self.execute_command("LMPOP", *args) + + def lindex( + self, name: str, index: int + ) -> Union[Awaitable[Optional[str]], Optional[str]]: + """ + Return the item from list ``name`` at position ``index`` + + Negative indexes are supported and will return an item at the + end of the list + + For more information see https://redis.io/commands/lindex + """ + return self.execute_command("LINDEX", name, index) + + def linsert( + self, name: str, where: str, refvalue: str, value: str + ) -> Union[Awaitable[int], int]: + """ + Insert ``value`` in list ``name`` either immediately before or after + [``where``] ``refvalue`` + + Returns the new length of the list on success or -1 if ``refvalue`` + is not in the list. + + For more information see https://redis.io/commands/linsert + """ + return self.execute_command("LINSERT", name, where, refvalue, value) + + def llen(self, name: str) -> Union[Awaitable[int], int]: + """ + Return the length of the list ``name`` + + For more information see https://redis.io/commands/llen + """ + return self.execute_command("LLEN", name) + + def lpop( + self, + name: str, + count: Optional[int] = None, + ) -> Union[Awaitable[Union[str, List, None]], Union[str, List, None]]: + """ + Removes and returns the first elements of the list ``name``. + + By default, the command pops a single element from the beginning of + the list. When provided with the optional ``count`` argument, the reply + will consist of up to count elements, depending on the list's length. + + For more information see https://redis.io/commands/lpop + """ + if count is not None: + return self.execute_command("LPOP", name, count) + else: + return self.execute_command("LPOP", name) + + def lpush(self, name: str, *values: FieldT) -> Union[Awaitable[int], int]: + """ + Push ``values`` onto the head of the list ``name`` + + For more information see https://redis.io/commands/lpush + """ + return self.execute_command("LPUSH", name, *values) + + def lpushx(self, name: str, *values: FieldT) -> Union[Awaitable[int], int]: + """ + Push ``value`` onto the head of the list ``name`` if ``name`` exists + + For more information see https://redis.io/commands/lpushx + """ + return self.execute_command("LPUSHX", name, *values) + + def lrange(self, name: str, start: int, end: int) -> Union[Awaitable[list], list]: + """ + Return a slice of the list ``name`` between + position ``start`` and ``end`` + + ``start`` and ``end`` can be negative numbers just like + Python slicing notation + + For more information see https://redis.io/commands/lrange + """ + return self.execute_command("LRANGE", name, start, end) + + def lrem(self, name: str, count: int, value: str) -> Union[Awaitable[int], int]: + """ + Remove the first ``count`` occurrences of elements equal to ``value`` + from the list stored at ``name``. + + The count argument influences the operation in the following ways: + count > 0: Remove elements equal to value moving from head to tail. + count < 0: Remove elements equal to value moving from tail to head. + count = 0: Remove all elements equal to value. + + For more information see https://redis.io/commands/lrem + """ + return self.execute_command("LREM", name, count, value) + + def lset(self, name: str, index: int, value: str) -> Union[Awaitable[str], str]: + """ + Set element at ``index`` of list ``name`` to ``value`` + + For more information see https://redis.io/commands/lset + """ + return self.execute_command("LSET", name, index, value) + + def ltrim(self, name: str, start: int, end: int) -> Union[Awaitable[str], str]: + """ + Trim the list ``name``, removing all values not within the slice + between ``start`` and ``end`` + + ``start`` and ``end`` can be negative numbers just like + Python slicing notation + + For more information see https://redis.io/commands/ltrim + """ + return self.execute_command("LTRIM", name, start, end) + + def rpop( + self, + name: str, + count: Optional[int] = None, + ) -> Union[Awaitable[Union[str, List, None]], Union[str, List, None]]: + """ + Removes and returns the last elements of the list ``name``. + + By default, the command pops a single element from the end of the list. + When provided with the optional ``count`` argument, the reply will + consist of up to count elements, depending on the list's length. + + For more information see https://redis.io/commands/rpop + """ + if count is not None: + return self.execute_command("RPOP", name, count) + else: + return self.execute_command("RPOP", name) + + def rpoplpush(self, src: str, dst: str) -> Union[Awaitable[str], str]: + """ + RPOP a value off of the ``src`` list and atomically LPUSH it + on to the ``dst`` list. Returns the value. + + For more information see https://redis.io/commands/rpoplpush + """ + return self.execute_command("RPOPLPUSH", src, dst) + + def rpush(self, name: str, *values: FieldT) -> Union[Awaitable[int], int]: + """ + Push ``values`` onto the tail of the list ``name`` + + For more information see https://redis.io/commands/rpush + """ + return self.execute_command("RPUSH", name, *values) + + def rpushx(self, name: str, *values: str) -> Union[Awaitable[int], int]: + """ + Push ``value`` onto the tail of the list ``name`` if ``name`` exists + + For more information see https://redis.io/commands/rpushx + """ + return self.execute_command("RPUSHX", name, *values) + + def lpos( + self, + name: str, + value: str, + rank: Optional[int] = None, + count: Optional[int] = None, + maxlen: Optional[int] = None, + ) -> Union[str, List, None]: + """ + Get position of ``value`` within the list ``name`` + + If specified, ``rank`` indicates the "rank" of the first element to + return in case there are multiple copies of ``value`` in the list. + By default, LPOS returns the position of the first occurrence of + ``value`` in the list. When ``rank`` 2, LPOS returns the position of + the second ``value`` in the list. If ``rank`` is negative, LPOS + searches the list in reverse. For example, -1 would return the + position of the last occurrence of ``value`` and -2 would return the + position of the next to last occurrence of ``value``. + + If specified, ``count`` indicates that LPOS should return a list of + up to ``count`` positions. A ``count`` of 2 would return a list of + up to 2 positions. A ``count`` of 0 returns a list of all positions + matching ``value``. When ``count`` is specified and but ``value`` + does not exist in the list, an empty list is returned. + + If specified, ``maxlen`` indicates the maximum number of list + elements to scan. A ``maxlen`` of 1000 will only return the + position(s) of items within the first 1000 entries in the list. + A ``maxlen`` of 0 (the default) will scan the entire list. + + For more information see https://redis.io/commands/lpos + """ + pieces: list[EncodableT] = [name, value] + if rank is not None: + pieces.extend(["RANK", rank]) + + if count is not None: + pieces.extend(["COUNT", count]) + + if maxlen is not None: + pieces.extend(["MAXLEN", maxlen]) + + return self.execute_command("LPOS", *pieces) + + def sort( + self, + name: str, + start: Optional[int] = None, + num: Optional[int] = None, + by: Optional[str] = None, + get: Optional[List[str]] = None, + desc: bool = False, + alpha: bool = False, + store: Optional[str] = None, + groups: Optional[bool] = False, + ) -> Union[List, int]: + """ + Sort and return the list, set or sorted set at ``name``. + + ``start`` and ``num`` allow for paging through the sorted data + + ``by`` allows using an external key to weight and sort the items. + Use an "*" to indicate where in the key the item value is located + + ``get`` allows for returning items from external keys rather than the + sorted data itself. Use an "*" to indicate where in the key + the item value is located + + ``desc`` allows for reversing the sort + + ``alpha`` allows for sorting lexicographically rather than numerically + + ``store`` allows for storing the result of the sort into + the key ``store`` + + ``groups`` if set to True and if ``get`` contains at least two + elements, sort will return a list of tuples, each containing the + values fetched from the arguments to ``get``. + + For more information see https://redis.io/commands/sort + """ + if (start is not None and num is None) or (num is not None and start is None): + raise DataError("``start`` and ``num`` must both be specified") + + pieces: list[EncodableT] = [name] + if by is not None: + pieces.extend([b"BY", by]) + if start is not None and num is not None: + pieces.extend([b"LIMIT", start, num]) + if get is not None: + # If get is a string assume we want to get a single value. + # Otherwise assume it's an interable and we want to get multiple + # values. We can't just iterate blindly because strings are + # iterable. + if isinstance(get, (bytes, str)): + pieces.extend([b"GET", get]) + else: + for g in get: + pieces.extend([b"GET", g]) + if desc: + pieces.append(b"DESC") + if alpha: + pieces.append(b"ALPHA") + if store is not None: + pieces.extend([b"STORE", store]) + if groups: + if not get or isinstance(get, (bytes, str)) or len(get) < 2: + raise DataError( + 'when using "groups" the "get" argument ' + "must be specified and contain at least " + "two keys" + ) + + options = {"groups": len(get) if groups else None} + return self.execute_command("SORT", *pieces, **options) + + def sort_ro( + self, + key: str, + start: Optional[int] = None, + num: Optional[int] = None, + by: Optional[str] = None, + get: Optional[List[str]] = None, + desc: bool = False, + alpha: bool = False, + ) -> list: + """ + Returns the elements contained in the list, set or sorted set at key. + (read-only variant of the SORT command) + + ``start`` and ``num`` allow for paging through the sorted data + + ``by`` allows using an external key to weight and sort the items. + Use an "*" to indicate where in the key the item value is located + + ``get`` allows for returning items from external keys rather than the + sorted data itself. Use an "*" to indicate where in the key + the item value is located + + ``desc`` allows for reversing the sort + + ``alpha`` allows for sorting lexicographically rather than numerically + + For more information see https://redis.io/commands/sort_ro + """ + return self.sort( + key, start=start, num=num, by=by, get=get, desc=desc, alpha=alpha + ) + + +AsyncListCommands = ListCommands + + +class ScanCommands(CommandsProtocol): + """ + Redis SCAN commands. + see: https://redis.io/commands/scan + """ + + def scan( + self, + cursor: int = 0, + match: Union[PatternT, None] = None, + count: Union[int, None] = None, + _type: Union[str, None] = None, + **kwargs, + ) -> ResponseT: + """ + Incrementally return lists of key names. Also return a cursor + indicating the scan position. + + ``match`` allows for filtering the keys by pattern + + ``count`` provides a hint to Redis about the number of keys to + return per batch. + + ``_type`` filters the returned values by a particular Redis type. + Stock Redis instances allow for the following types: + HASH, LIST, SET, STREAM, STRING, ZSET + Additionally, Redis modules can expose other types as well. + + For more information see https://redis.io/commands/scan + """ + pieces: list[EncodableT] = [cursor] + if match is not None: + pieces.extend([b"MATCH", match]) + if count is not None: + pieces.extend([b"COUNT", count]) + if _type is not None: + pieces.extend([b"TYPE", _type]) + return self.execute_command("SCAN", *pieces, **kwargs) + + def scan_iter( + self, + match: Union[PatternT, None] = None, + count: Union[int, None] = None, + _type: Union[str, None] = None, + **kwargs, + ) -> Iterator: + """ + Make an iterator using the SCAN command so that the client doesn't + need to remember the cursor position. + + ``match`` allows for filtering the keys by pattern + + ``count`` provides a hint to Redis about the number of keys to + return per batch. + + ``_type`` filters the returned values by a particular Redis type. + Stock Redis instances allow for the following types: + HASH, LIST, SET, STREAM, STRING, ZSET + Additionally, Redis modules can expose other types as well. + """ + cursor = "0" + while cursor != 0: + cursor, data = self.scan( + cursor=cursor, match=match, count=count, _type=_type, **kwargs + ) + yield from data + + def sscan( + self, + name: KeyT, + cursor: int = 0, + match: Union[PatternT, None] = None, + count: Union[int, None] = None, + ) -> ResponseT: + """ + Incrementally return lists of elements in a set. Also return a cursor + indicating the scan position. + + ``match`` allows for filtering the keys by pattern + + ``count`` allows for hint the minimum number of returns + + For more information see https://redis.io/commands/sscan + """ + pieces: list[EncodableT] = [name, cursor] + if match is not None: + pieces.extend([b"MATCH", match]) + if count is not None: + pieces.extend([b"COUNT", count]) + return self.execute_command("SSCAN", *pieces) + + def sscan_iter( + self, + name: KeyT, + match: Union[PatternT, None] = None, + count: Union[int, None] = None, + ) -> Iterator: + """ + Make an iterator using the SSCAN command so that the client doesn't + need to remember the cursor position. + + ``match`` allows for filtering the keys by pattern + + ``count`` allows for hint the minimum number of returns + """ + cursor = "0" + while cursor != 0: + cursor, data = self.sscan(name, cursor=cursor, match=match, count=count) + yield from data + + def hscan( + self, + name: KeyT, + cursor: int = 0, + match: Union[PatternT, None] = None, + count: Union[int, None] = None, + ) -> ResponseT: + """ + Incrementally return key/value slices in a hash. Also return a cursor + indicating the scan position. + + ``match`` allows for filtering the keys by pattern + + ``count`` allows for hint the minimum number of returns + + For more information see https://redis.io/commands/hscan + """ + pieces: list[EncodableT] = [name, cursor] + if match is not None: + pieces.extend([b"MATCH", match]) + if count is not None: + pieces.extend([b"COUNT", count]) + return self.execute_command("HSCAN", *pieces) + + def hscan_iter( + self, + name: str, + match: Union[PatternT, None] = None, + count: Union[int, None] = None, + ) -> Iterator: + """ + Make an iterator using the HSCAN command so that the client doesn't + need to remember the cursor position. + + ``match`` allows for filtering the keys by pattern + + ``count`` allows for hint the minimum number of returns + """ + cursor = "0" + while cursor != 0: + cursor, data = self.hscan(name, cursor=cursor, match=match, count=count) + yield from data.items() + + def zscan( + self, + name: KeyT, + cursor: int = 0, + match: Union[PatternT, None] = None, + count: Union[int, None] = None, + score_cast_func: Union[type, Callable] = float, + ) -> ResponseT: + """ + Incrementally return lists of elements in a sorted set. Also return a + cursor indicating the scan position. + + ``match`` allows for filtering the keys by pattern + + ``count`` allows for hint the minimum number of returns + + ``score_cast_func`` a callable used to cast the score return value + + For more information see https://redis.io/commands/zscan + """ + pieces = [name, cursor] + if match is not None: + pieces.extend([b"MATCH", match]) + if count is not None: + pieces.extend([b"COUNT", count]) + options = {"score_cast_func": score_cast_func} + return self.execute_command("ZSCAN", *pieces, **options) + + def zscan_iter( + self, + name: KeyT, + match: Union[PatternT, None] = None, + count: Union[int, None] = None, + score_cast_func: Union[type, Callable] = float, + ) -> Iterator: + """ + Make an iterator using the ZSCAN command so that the client doesn't + need to remember the cursor position. + + ``match`` allows for filtering the keys by pattern + + ``count`` allows for hint the minimum number of returns + + ``score_cast_func`` a callable used to cast the score return value + """ + cursor = "0" + while cursor != 0: + cursor, data = self.zscan( + name, + cursor=cursor, + match=match, + count=count, + score_cast_func=score_cast_func, + ) + yield from data + + +class AsyncScanCommands(ScanCommands): + async def scan_iter( + self, + match: Union[PatternT, None] = None, + count: Union[int, None] = None, + _type: Union[str, None] = None, + **kwargs, + ) -> AsyncIterator: + """ + Make an iterator using the SCAN command so that the client doesn't + need to remember the cursor position. + + ``match`` allows for filtering the keys by pattern + + ``count`` provides a hint to Redis about the number of keys to + return per batch. + + ``_type`` filters the returned values by a particular Redis type. + Stock Redis instances allow for the following types: + HASH, LIST, SET, STREAM, STRING, ZSET + Additionally, Redis modules can expose other types as well. + """ + cursor = "0" + while cursor != 0: + cursor, data = await self.scan( + cursor=cursor, match=match, count=count, _type=_type, **kwargs + ) + for d in data: + yield d + + async def sscan_iter( + self, + name: KeyT, + match: Union[PatternT, None] = None, + count: Union[int, None] = None, + ) -> AsyncIterator: + """ + Make an iterator using the SSCAN command so that the client doesn't + need to remember the cursor position. + + ``match`` allows for filtering the keys by pattern + + ``count`` allows for hint the minimum number of returns + """ + cursor = "0" + while cursor != 0: + cursor, data = await self.sscan( + name, cursor=cursor, match=match, count=count + ) + for d in data: + yield d + + async def hscan_iter( + self, + name: str, + match: Union[PatternT, None] = None, + count: Union[int, None] = None, + ) -> AsyncIterator: + """ + Make an iterator using the HSCAN command so that the client doesn't + need to remember the cursor position. + + ``match`` allows for filtering the keys by pattern + + ``count`` allows for hint the minimum number of returns + """ + cursor = "0" + while cursor != 0: + cursor, data = await self.hscan( + name, cursor=cursor, match=match, count=count + ) + for it in data.items(): + yield it + + async def zscan_iter( + self, + name: KeyT, + match: Union[PatternT, None] = None, + count: Union[int, None] = None, + score_cast_func: Union[type, Callable] = float, + ) -> AsyncIterator: + """ + Make an iterator using the ZSCAN command so that the client doesn't + need to remember the cursor position. + + ``match`` allows for filtering the keys by pattern + + ``count`` allows for hint the minimum number of returns + + ``score_cast_func`` a callable used to cast the score return value + """ + cursor = "0" + while cursor != 0: + cursor, data = await self.zscan( + name, + cursor=cursor, + match=match, + count=count, + score_cast_func=score_cast_func, + ) + for d in data: + yield d + + +class SetCommands(CommandsProtocol): + """ + Redis commands for Set data type. + see: https://redis.io/topics/data-types#sets + """ + + def sadd(self, name: str, *values: FieldT) -> Union[Awaitable[int], int]: + """ + Add ``value(s)`` to set ``name`` + + For more information see https://redis.io/commands/sadd + """ + return self.execute_command("SADD", name, *values) + + def scard(self, name: str) -> Union[Awaitable[int], int]: + """ + Return the number of elements in set ``name`` + + For more information see https://redis.io/commands/scard + """ + return self.execute_command("SCARD", name) + + def sdiff(self, keys: List, *args: List) -> Union[Awaitable[list], list]: + """ + Return the difference of sets specified by ``keys`` + + For more information see https://redis.io/commands/sdiff + """ + args = list_or_args(keys, args) + return self.execute_command("SDIFF", *args) + + def sdiffstore( + self, dest: str, keys: List, *args: List + ) -> Union[Awaitable[int], int]: + """ + Store the difference of sets specified by ``keys`` into a new + set named ``dest``. Returns the number of keys in the new set. + + For more information see https://redis.io/commands/sdiffstore + """ + args = list_or_args(keys, args) + return self.execute_command("SDIFFSTORE", dest, *args) + + def sinter(self, keys: List, *args: List) -> Union[Awaitable[list], list]: + """ + Return the intersection of sets specified by ``keys`` + + For more information see https://redis.io/commands/sinter + """ + args = list_or_args(keys, args) + return self.execute_command("SINTER", *args) + + def sintercard( + self, numkeys: int, keys: List[str], limit: int = 0 + ) -> Union[Awaitable[int], int]: + """ + Return the cardinality of the intersect of multiple sets specified by ``keys`. + + When LIMIT provided (defaults to 0 and means unlimited), if the intersection + cardinality reaches limit partway through the computation, the algorithm will + exit and yield limit as the cardinality + + For more information see https://redis.io/commands/sintercard + """ + args = [numkeys, *keys, "LIMIT", limit] + return self.execute_command("SINTERCARD", *args) + + def sinterstore( + self, dest: str, keys: List, *args: List + ) -> Union[Awaitable[int], int]: + """ + Store the intersection of sets specified by ``keys`` into a new + set named ``dest``. Returns the number of keys in the new set. + + For more information see https://redis.io/commands/sinterstore + """ + args = list_or_args(keys, args) + return self.execute_command("SINTERSTORE", dest, *args) + + def sismember( + self, name: str, value: str + ) -> Union[Awaitable[Union[Literal[0], Literal[1]]], Union[Literal[0], Literal[1]]]: + """ + Return whether ``value`` is a member of set ``name``: + - 1 if the value is a member of the set. + - 0 if the value is not a member of the set or if key does not exist. + + For more information see https://redis.io/commands/sismember + """ + return self.execute_command("SISMEMBER", name, value) + + def smembers(self, name: str) -> Union[Awaitable[Set], Set]: + """ + Return all members of the set ``name`` + + For more information see https://redis.io/commands/smembers + """ + return self.execute_command("SMEMBERS", name) + + def smismember( + self, name: str, values: List, *args: List + ) -> Union[ + Awaitable[List[Union[Literal[0], Literal[1]]]], + List[Union[Literal[0], Literal[1]]], + ]: + """ + Return whether each value in ``values`` is a member of the set ``name`` + as a list of ``int`` in the order of ``values``: + - 1 if the value is a member of the set. + - 0 if the value is not a member of the set or if key does not exist. + + For more information see https://redis.io/commands/smismember + """ + args = list_or_args(values, args) + return self.execute_command("SMISMEMBER", name, *args) + + def smove(self, src: str, dst: str, value: str) -> Union[Awaitable[bool], bool]: + """ + Move ``value`` from set ``src`` to set ``dst`` atomically + + For more information see https://redis.io/commands/smove + """ + return self.execute_command("SMOVE", src, dst, value) + + def spop(self, name: str, count: Optional[int] = None) -> Union[str, List, None]: + """ + Remove and return a random member of set ``name`` + + For more information see https://redis.io/commands/spop + """ + args = (count is not None) and [count] or [] + return self.execute_command("SPOP", name, *args) + + def srandmember( + self, name: str, number: Optional[int] = None + ) -> Union[str, List, None]: + """ + If ``number`` is None, returns a random member of set ``name``. + + If ``number`` is supplied, returns a list of ``number`` random + members of set ``name``. Note this is only available when running + Redis 2.6+. + + For more information see https://redis.io/commands/srandmember + """ + args = (number is not None) and [number] or [] + return self.execute_command("SRANDMEMBER", name, *args) + + def srem(self, name: str, *values: FieldT) -> Union[Awaitable[int], int]: + """ + Remove ``values`` from set ``name`` + + For more information see https://redis.io/commands/srem + """ + return self.execute_command("SREM", name, *values) + + def sunion(self, keys: List, *args: List) -> Union[Awaitable[List], List]: + """ + Return the union of sets specified by ``keys`` + + For more information see https://redis.io/commands/sunion + """ + args = list_or_args(keys, args) + return self.execute_command("SUNION", *args) + + def sunionstore( + self, dest: str, keys: List, *args: List + ) -> Union[Awaitable[int], int]: + """ + Store the union of sets specified by ``keys`` into a new + set named ``dest``. Returns the number of keys in the new set. + + For more information see https://redis.io/commands/sunionstore + """ + args = list_or_args(keys, args) + return self.execute_command("SUNIONSTORE", dest, *args) + + +AsyncSetCommands = SetCommands + + +class StreamCommands(CommandsProtocol): + """ + Redis commands for Stream data type. + see: https://redis.io/topics/streams-intro + """ + + def xack(self, name: KeyT, groupname: GroupT, *ids: StreamIdT) -> ResponseT: + """ + Acknowledges the successful processing of one or more messages. + name: name of the stream. + groupname: name of the consumer group. + *ids: message ids to acknowledge. + + For more information see https://redis.io/commands/xack + """ + return self.execute_command("XACK", name, groupname, *ids) + + def xadd( + self, + name: KeyT, + fields: Dict[FieldT, EncodableT], + id: StreamIdT = "*", + maxlen: Union[int, None] = None, + approximate: bool = True, + nomkstream: bool = False, + minid: Union[StreamIdT, None] = None, + limit: Union[int, None] = None, + ) -> ResponseT: + """ + Add to a stream. + name: name of the stream + fields: dict of field/value pairs to insert into the stream + id: Location to insert this record. By default it is appended. + maxlen: truncate old stream members beyond this size. + Can't be specified with minid. + approximate: actual stream length may be slightly more than maxlen + nomkstream: When set to true, do not make a stream + minid: the minimum id in the stream to query. + Can't be specified with maxlen. + limit: specifies the maximum number of entries to retrieve + + For more information see https://redis.io/commands/xadd + """ + pieces: list[EncodableT] = [] + if maxlen is not None and minid is not None: + raise DataError("Only one of ```maxlen``` or ```minid``` may be specified") + + if maxlen is not None: + if not isinstance(maxlen, int) or maxlen < 0: + raise DataError("XADD maxlen must be non-negative integer") + pieces.append(b"MAXLEN") + if approximate: + pieces.append(b"~") + pieces.append(str(maxlen)) + if minid is not None: + pieces.append(b"MINID") + if approximate: + pieces.append(b"~") + pieces.append(minid) + if limit is not None: + pieces.extend([b"LIMIT", limit]) + if nomkstream: + pieces.append(b"NOMKSTREAM") + pieces.append(id) + if not isinstance(fields, dict) or len(fields) == 0: + raise DataError("XADD fields must be a non-empty dict") + for pair in fields.items(): + pieces.extend(pair) + return self.execute_command("XADD", name, *pieces) + + def xautoclaim( + self, + name: KeyT, + groupname: GroupT, + consumername: ConsumerT, + min_idle_time: int, + start_id: StreamIdT = "0-0", + count: Union[int, None] = None, + justid: bool = False, + ) -> ResponseT: + """ + Transfers ownership of pending stream entries that match the specified + criteria. Conceptually, equivalent to calling XPENDING and then XCLAIM, + but provides a more straightforward way to deal with message delivery + failures via SCAN-like semantics. + name: name of the stream. + groupname: name of the consumer group. + consumername: name of a consumer that claims the message. + min_idle_time: filter messages that were idle less than this amount of + milliseconds. + start_id: filter messages with equal or greater ID. + count: optional integer, upper limit of the number of entries that the + command attempts to claim. Set to 100 by default. + justid: optional boolean, false by default. Return just an array of IDs + of messages successfully claimed, without returning the actual message + + For more information see https://redis.io/commands/xautoclaim + """ + try: + if int(min_idle_time) < 0: + raise DataError( + "XAUTOCLAIM min_idle_time must be a nonnegative integer" + ) + except TypeError: + pass + + kwargs = {} + pieces = [name, groupname, consumername, min_idle_time, start_id] + + try: + if int(count) < 0: + raise DataError("XPENDING count must be a integer >= 0") + pieces.extend([b"COUNT", count]) + except TypeError: + pass + if justid: + pieces.append(b"JUSTID") + kwargs["parse_justid"] = True + + return self.execute_command("XAUTOCLAIM", *pieces, **kwargs) + + def xclaim( + self, + name: KeyT, + groupname: GroupT, + consumername: ConsumerT, + min_idle_time: int, + message_ids: Union[List[StreamIdT], Tuple[StreamIdT]], + idle: Union[int, None] = None, + time: Union[int, None] = None, + retrycount: Union[int, None] = None, + force: bool = False, + justid: bool = False, + ) -> ResponseT: + """ + Changes the ownership of a pending message. + + name: name of the stream. + + groupname: name of the consumer group. + + consumername: name of a consumer that claims the message. + + min_idle_time: filter messages that were idle less than this amount of + milliseconds + + message_ids: non-empty list or tuple of message IDs to claim + + idle: optional. Set the idle time (last time it was delivered) of the + message in ms + + time: optional integer. This is the same as idle but instead of a + relative amount of milliseconds, it sets the idle time to a specific + Unix time (in milliseconds). + + retrycount: optional integer. set the retry counter to the specified + value. This counter is incremented every time a message is delivered + again. + + force: optional boolean, false by default. Creates the pending message + entry in the PEL even if certain specified IDs are not already in the + PEL assigned to a different client. + + justid: optional boolean, false by default. Return just an array of IDs + of messages successfully claimed, without returning the actual message + + For more information see https://redis.io/commands/xclaim + """ + if not isinstance(min_idle_time, int) or min_idle_time < 0: + raise DataError("XCLAIM min_idle_time must be a non negative integer") + if not isinstance(message_ids, (list, tuple)) or not message_ids: + raise DataError( + "XCLAIM message_ids must be a non empty list or " + "tuple of message IDs to claim" + ) + + kwargs = {} + pieces: list[EncodableT] = [name, groupname, consumername, str(min_idle_time)] + pieces.extend(list(message_ids)) + + if idle is not None: + if not isinstance(idle, int): + raise DataError("XCLAIM idle must be an integer") + pieces.extend((b"IDLE", str(idle))) + if time is not None: + if not isinstance(time, int): + raise DataError("XCLAIM time must be an integer") + pieces.extend((b"TIME", str(time))) + if retrycount is not None: + if not isinstance(retrycount, int): + raise DataError("XCLAIM retrycount must be an integer") + pieces.extend((b"RETRYCOUNT", str(retrycount))) + + if force: + if not isinstance(force, bool): + raise DataError("XCLAIM force must be a boolean") + pieces.append(b"FORCE") + if justid: + if not isinstance(justid, bool): + raise DataError("XCLAIM justid must be a boolean") + pieces.append(b"JUSTID") + kwargs["parse_justid"] = True + return self.execute_command("XCLAIM", *pieces, **kwargs) + + def xdel(self, name: KeyT, *ids: StreamIdT) -> ResponseT: + """ + Deletes one or more messages from a stream. + name: name of the stream. + *ids: message ids to delete. + + For more information see https://redis.io/commands/xdel + """ + return self.execute_command("XDEL", name, *ids) + + def xgroup_create( + self, + name: KeyT, + groupname: GroupT, + id: StreamIdT = "$", + mkstream: bool = False, + entries_read: Optional[int] = None, + ) -> ResponseT: + """ + Create a new consumer group associated with a stream. + name: name of the stream. + groupname: name of the consumer group. + id: ID of the last item in the stream to consider already delivered. + + For more information see https://redis.io/commands/xgroup-create + """ + pieces: list[EncodableT] = ["XGROUP CREATE", name, groupname, id] + if mkstream: + pieces.append(b"MKSTREAM") + if entries_read is not None: + pieces.extend(["ENTRIESREAD", entries_read]) + + return self.execute_command(*pieces) + + def xgroup_delconsumer( + self, name: KeyT, groupname: GroupT, consumername: ConsumerT + ) -> ResponseT: + """ + Remove a specific consumer from a consumer group. + Returns the number of pending messages that the consumer had before it + was deleted. + name: name of the stream. + groupname: name of the consumer group. + consumername: name of consumer to delete + + For more information see https://redis.io/commands/xgroup-delconsumer + """ + return self.execute_command("XGROUP DELCONSUMER", name, groupname, consumername) + + def xgroup_destroy(self, name: KeyT, groupname: GroupT) -> ResponseT: + """ + Destroy a consumer group. + name: name of the stream. + groupname: name of the consumer group. + + For more information see https://redis.io/commands/xgroup-destroy + """ + return self.execute_command("XGROUP DESTROY", name, groupname) + + def xgroup_createconsumer( + self, name: KeyT, groupname: GroupT, consumername: ConsumerT + ) -> ResponseT: + """ + Consumers in a consumer group are auto-created every time a new + consumer name is mentioned by some command. + They can be explicitly created by using this command. + name: name of the stream. + groupname: name of the consumer group. + consumername: name of consumer to create. + + See: https://redis.io/commands/xgroup-createconsumer + """ + return self.execute_command( + "XGROUP CREATECONSUMER", name, groupname, consumername + ) + + def xgroup_setid( + self, + name: KeyT, + groupname: GroupT, + id: StreamIdT, + entries_read: Optional[int] = None, + ) -> ResponseT: + """ + Set the consumer group last delivered ID to something else. + name: name of the stream. + groupname: name of the consumer group. + id: ID of the last item in the stream to consider already delivered. + + For more information see https://redis.io/commands/xgroup-setid + """ + pieces = [name, groupname, id] + if entries_read is not None: + pieces.extend(["ENTRIESREAD", entries_read]) + return self.execute_command("XGROUP SETID", *pieces) + + def xinfo_consumers(self, name: KeyT, groupname: GroupT) -> ResponseT: + """ + Returns general information about the consumers in the group. + name: name of the stream. + groupname: name of the consumer group. + + For more information see https://redis.io/commands/xinfo-consumers + """ + return self.execute_command("XINFO CONSUMERS", name, groupname) + + def xinfo_groups(self, name: KeyT) -> ResponseT: + """ + Returns general information about the consumer groups of the stream. + name: name of the stream. + + For more information see https://redis.io/commands/xinfo-groups + """ + return self.execute_command("XINFO GROUPS", name) + + def xinfo_stream(self, name: KeyT, full: bool = False) -> ResponseT: + """ + Returns general information about the stream. + name: name of the stream. + full: optional boolean, false by default. Return full summary + + For more information see https://redis.io/commands/xinfo-stream + """ + pieces = [name] + options = {} + if full: + pieces.append(b"FULL") + options = {"full": full} + return self.execute_command("XINFO STREAM", *pieces, **options) + + def xlen(self, name: KeyT) -> ResponseT: + """ + Returns the number of elements in a given stream. + + For more information see https://redis.io/commands/xlen + """ + return self.execute_command("XLEN", name) + + def xpending(self, name: KeyT, groupname: GroupT) -> ResponseT: + """ + Returns information about pending messages of a group. + name: name of the stream. + groupname: name of the consumer group. + + For more information see https://redis.io/commands/xpending + """ + return self.execute_command("XPENDING", name, groupname) + + def xpending_range( + self, + name: KeyT, + groupname: GroupT, + min: StreamIdT, + max: StreamIdT, + count: int, + consumername: Union[ConsumerT, None] = None, + idle: Union[int, None] = None, + ) -> ResponseT: + """ + Returns information about pending messages, in a range. + + name: name of the stream. + groupname: name of the consumer group. + idle: available from version 6.2. filter entries by their + idle-time, given in milliseconds (optional). + min: minimum stream ID. + max: maximum stream ID. + count: number of messages to return + consumername: name of a consumer to filter by (optional). + """ + if {min, max, count} == {None}: + if idle is not None or consumername is not None: + raise DataError( + "if XPENDING is provided with idle time" + " or consumername, it must be provided" + " with min, max and count parameters" + ) + return self.xpending(name, groupname) + + pieces = [name, groupname] + if min is None or max is None or count is None: + raise DataError( + "XPENDING must be provided with min, max " + "and count parameters, or none of them." + ) + # idle + try: + if int(idle) < 0: + raise DataError("XPENDING idle must be a integer >= 0") + pieces.extend(["IDLE", idle]) + except TypeError: + pass + # count + try: + if int(count) < 0: + raise DataError("XPENDING count must be a integer >= 0") + pieces.extend([min, max, count]) + except TypeError: + pass + # consumername + if consumername: + pieces.append(consumername) + + return self.execute_command("XPENDING", *pieces, parse_detail=True) + + def xrange( + self, + name: KeyT, + min: StreamIdT = "-", + max: StreamIdT = "+", + count: Union[int, None] = None, + ) -> ResponseT: + """ + Read stream values within an interval. + + name: name of the stream. + + start: first stream ID. defaults to '-', + meaning the earliest available. + + finish: last stream ID. defaults to '+', + meaning the latest available. + + count: if set, only return this many items, beginning with the + earliest available. + + For more information see https://redis.io/commands/xrange + """ + pieces = [min, max] + if count is not None: + if not isinstance(count, int) or count < 1: + raise DataError("XRANGE count must be a positive integer") + pieces.append(b"COUNT") + pieces.append(str(count)) + + return self.execute_command("XRANGE", name, *pieces) + + def xread( + self, + streams: Dict[KeyT, StreamIdT], + count: Union[int, None] = None, + block: Union[int, None] = None, + ) -> ResponseT: + """ + Block and monitor multiple streams for new data. + + streams: a dict of stream names to stream IDs, where + IDs indicate the last ID already seen. + + count: if set, only return this many items, beginning with the + earliest available. + + block: number of milliseconds to wait, if nothing already present. + + For more information see https://redis.io/commands/xread + """ + pieces = [] + if block is not None: + if not isinstance(block, int) or block < 0: + raise DataError("XREAD block must be a non-negative integer") + pieces.append(b"BLOCK") + pieces.append(str(block)) + if count is not None: + if not isinstance(count, int) or count < 1: + raise DataError("XREAD count must be a positive integer") + pieces.append(b"COUNT") + pieces.append(str(count)) + if not isinstance(streams, dict) or len(streams) == 0: + raise DataError("XREAD streams must be a non empty dict") + pieces.append(b"STREAMS") + keys, values = zip(*streams.items()) + pieces.extend(keys) + pieces.extend(values) + return self.execute_command("XREAD", *pieces) + + def xreadgroup( + self, + groupname: str, + consumername: str, + streams: Dict[KeyT, StreamIdT], + count: Union[int, None] = None, + block: Union[int, None] = None, + noack: bool = False, + ) -> ResponseT: + """ + Read from a stream via a consumer group. + + groupname: name of the consumer group. + + consumername: name of the requesting consumer. + + streams: a dict of stream names to stream IDs, where + IDs indicate the last ID already seen. + + count: if set, only return this many items, beginning with the + earliest available. + + block: number of milliseconds to wait, if nothing already present. + noack: do not add messages to the PEL + + For more information see https://redis.io/commands/xreadgroup + """ + pieces: list[EncodableT] = [b"GROUP", groupname, consumername] + if count is not None: + if not isinstance(count, int) or count < 1: + raise DataError("XREADGROUP count must be a positive integer") + pieces.append(b"COUNT") + pieces.append(str(count)) + if block is not None: + if not isinstance(block, int) or block < 0: + raise DataError("XREADGROUP block must be a non-negative integer") + pieces.append(b"BLOCK") + pieces.append(str(block)) + if noack: + pieces.append(b"NOACK") + if not isinstance(streams, dict) or len(streams) == 0: + raise DataError("XREADGROUP streams must be a non empty dict") + pieces.append(b"STREAMS") + pieces.extend(streams.keys()) + pieces.extend(streams.values()) + return self.execute_command("XREADGROUP", *pieces) + + def xrevrange( + self, + name: KeyT, + max: StreamIdT = "+", + min: StreamIdT = "-", + count: Union[int, None] = None, + ) -> ResponseT: + """ + Read stream values within an interval, in reverse order. + + name: name of the stream + + start: first stream ID. defaults to '+', + meaning the latest available. + + finish: last stream ID. defaults to '-', + meaning the earliest available. + + count: if set, only return this many items, beginning with the + latest available. + + For more information see https://redis.io/commands/xrevrange + """ + pieces: list[EncodableT] = [max, min] + if count is not None: + if not isinstance(count, int) or count < 1: + raise DataError("XREVRANGE count must be a positive integer") + pieces.append(b"COUNT") + pieces.append(str(count)) + + return self.execute_command("XREVRANGE", name, *pieces) + + def xtrim( + self, + name: KeyT, + maxlen: Union[int, None] = None, + approximate: bool = True, + minid: Union[StreamIdT, None] = None, + limit: Union[int, None] = None, + ) -> ResponseT: + """ + Trims old messages from a stream. + name: name of the stream. + maxlen: truncate old stream messages beyond this size + Can't be specified with minid. + approximate: actual stream length may be slightly more than maxlen + minid: the minimum id in the stream to query + Can't be specified with maxlen. + limit: specifies the maximum number of entries to retrieve + + For more information see https://redis.io/commands/xtrim + """ + pieces: list[EncodableT] = [] + if maxlen is not None and minid is not None: + raise DataError("Only one of ``maxlen`` or ``minid`` may be specified") + + if maxlen is None and minid is None: + raise DataError("One of ``maxlen`` or ``minid`` must be specified") + + if maxlen is not None: + pieces.append(b"MAXLEN") + if minid is not None: + pieces.append(b"MINID") + if approximate: + pieces.append(b"~") + if maxlen is not None: + pieces.append(maxlen) + if minid is not None: + pieces.append(minid) + if limit is not None: + pieces.append(b"LIMIT") + pieces.append(limit) + + return self.execute_command("XTRIM", name, *pieces) + + +AsyncStreamCommands = StreamCommands + + +class SortedSetCommands(CommandsProtocol): + """ + Redis commands for Sorted Sets data type. + see: https://redis.io/topics/data-types-intro#redis-sorted-sets + """ + + def zadd( + self, + name: KeyT, + mapping: Mapping[AnyKeyT, EncodableT], + nx: bool = False, + xx: bool = False, + ch: bool = False, + incr: bool = False, + gt: bool = False, + lt: bool = False, + ) -> ResponseT: + """ + Set any number of element-name, score pairs to the key ``name``. Pairs + are specified as a dict of element-names keys to score values. + + ``nx`` forces ZADD to only create new elements and not to update + scores for elements that already exist. + + ``xx`` forces ZADD to only update scores of elements that already + exist. New elements will not be added. + + ``ch`` modifies the return value to be the numbers of elements changed. + Changed elements include new elements that were added and elements + whose scores changed. + + ``incr`` modifies ZADD to behave like ZINCRBY. In this mode only a + single element/score pair can be specified and the score is the amount + the existing score will be incremented by. When using this mode the + return value of ZADD will be the new score of the element. + + ``LT`` Only update existing elements if the new score is less than + the current score. This flag doesn't prevent adding new elements. + + ``GT`` Only update existing elements if the new score is greater than + the current score. This flag doesn't prevent adding new elements. + + The return value of ZADD varies based on the mode specified. With no + options, ZADD returns the number of new elements added to the sorted + set. + + ``NX``, ``LT``, and ``GT`` are mutually exclusive options. + + See: https://redis.io/commands/ZADD + """ + if not mapping: + raise DataError("ZADD requires at least one element/score pair") + if nx and xx: + raise DataError("ZADD allows either 'nx' or 'xx', not both") + if gt and lt: + raise DataError("ZADD allows either 'gt' or 'lt', not both") + if incr and len(mapping) != 1: + raise DataError( + "ZADD option 'incr' only works when passing a " + "single element/score pair" + ) + if nx and (gt or lt): + raise DataError("Only one of 'nx', 'lt', or 'gr' may be defined.") + + pieces: list[EncodableT] = [] + options = {} + if nx: + pieces.append(b"NX") + if xx: + pieces.append(b"XX") + if ch: + pieces.append(b"CH") + if incr: + pieces.append(b"INCR") + options["as_score"] = True + if gt: + pieces.append(b"GT") + if lt: + pieces.append(b"LT") + for pair in mapping.items(): + pieces.append(pair[1]) + pieces.append(pair[0]) + return self.execute_command("ZADD", name, *pieces, **options) + + def zcard(self, name: KeyT) -> ResponseT: + """ + Return the number of elements in the sorted set ``name`` + + For more information see https://redis.io/commands/zcard + """ + return self.execute_command("ZCARD", name) + + def zcount(self, name: KeyT, min: ZScoreBoundT, max: ZScoreBoundT) -> ResponseT: + """ + Returns the number of elements in the sorted set at key ``name`` with + a score between ``min`` and ``max``. + + For more information see https://redis.io/commands/zcount + """ + return self.execute_command("ZCOUNT", name, min, max) + + def zdiff(self, keys: KeysT, withscores: bool = False) -> ResponseT: + """ + Returns the difference between the first and all successive input + sorted sets provided in ``keys``. + + For more information see https://redis.io/commands/zdiff + """ + pieces = [len(keys), *keys] + if withscores: + pieces.append("WITHSCORES") + return self.execute_command("ZDIFF", *pieces) + + def zdiffstore(self, dest: KeyT, keys: KeysT) -> ResponseT: + """ + Computes the difference between the first and all successive input + sorted sets provided in ``keys`` and stores the result in ``dest``. + + For more information see https://redis.io/commands/zdiffstore + """ + pieces = [len(keys), *keys] + return self.execute_command("ZDIFFSTORE", dest, *pieces) + + def zincrby(self, name: KeyT, amount: float, value: EncodableT) -> ResponseT: + """ + Increment the score of ``value`` in sorted set ``name`` by ``amount`` + + For more information see https://redis.io/commands/zincrby + """ + return self.execute_command("ZINCRBY", name, amount, value) + + def zinter( + self, keys: KeysT, aggregate: Union[str, None] = None, withscores: bool = False + ) -> ResponseT: + """ + Return the intersect of multiple sorted sets specified by ``keys``. + With the ``aggregate`` option, it is possible to specify how the + results of the union are aggregated. This option defaults to SUM, + where the score of an element is summed across the inputs where it + exists. When this option is set to either MIN or MAX, the resulting + set will contain the minimum or maximum score of an element across + the inputs where it exists. + + For more information see https://redis.io/commands/zinter + """ + return self._zaggregate("ZINTER", None, keys, aggregate, withscores=withscores) + + def zinterstore( + self, + dest: KeyT, + keys: Union[Sequence[KeyT], Mapping[AnyKeyT, float]], + aggregate: Union[str, None] = None, + ) -> ResponseT: + """ + Intersect multiple sorted sets specified by ``keys`` into a new + sorted set, ``dest``. Scores in the destination will be aggregated + based on the ``aggregate``. This option defaults to SUM, where the + score of an element is summed across the inputs where it exists. + When this option is set to either MIN or MAX, the resulting set will + contain the minimum or maximum score of an element across the inputs + where it exists. + + For more information see https://redis.io/commands/zinterstore + """ + return self._zaggregate("ZINTERSTORE", dest, keys, aggregate) + + def zintercard( + self, numkeys: int, keys: List[str], limit: int = 0 + ) -> Union[Awaitable[int], int]: + """ + Return the cardinality of the intersect of multiple sorted sets + specified by ``keys`. + When LIMIT provided (defaults to 0 and means unlimited), if the intersection + cardinality reaches limit partway through the computation, the algorithm will + exit and yield limit as the cardinality + + For more information see https://redis.io/commands/zintercard + """ + args = [numkeys, *keys, "LIMIT", limit] + return self.execute_command("ZINTERCARD", *args) + + def zlexcount(self, name, min, max): + """ + Return the number of items in the sorted set ``name`` between the + lexicographical range ``min`` and ``max``. + + For more information see https://redis.io/commands/zlexcount + """ + return self.execute_command("ZLEXCOUNT", name, min, max) + + def zpopmax(self, name: KeyT, count: Union[int, None] = None) -> ResponseT: + """ + Remove and return up to ``count`` members with the highest scores + from the sorted set ``name``. + + For more information see https://redis.io/commands/zpopmax + """ + args = (count is not None) and [count] or [] + options = {"withscores": True} + return self.execute_command("ZPOPMAX", name, *args, **options) + + def zpopmin(self, name: KeyT, count: Union[int, None] = None) -> ResponseT: + """ + Remove and return up to ``count`` members with the lowest scores + from the sorted set ``name``. + + For more information see https://redis.io/commands/zpopmin + """ + args = (count is not None) and [count] or [] + options = {"withscores": True} + return self.execute_command("ZPOPMIN", name, *args, **options) + + def zrandmember( + self, key: KeyT, count: int = None, withscores: bool = False + ) -> ResponseT: + """ + Return a random element from the sorted set value stored at key. + + ``count`` if the argument is positive, return an array of distinct + fields. If called with a negative count, the behavior changes and + the command is allowed to return the same field multiple times. + In this case, the number of returned fields is the absolute value + of the specified count. + + ``withscores`` The optional WITHSCORES modifier changes the reply so it + includes the respective scores of the randomly selected elements from + the sorted set. + + For more information see https://redis.io/commands/zrandmember + """ + params = [] + if count is not None: + params.append(count) + if withscores: + params.append("WITHSCORES") + + return self.execute_command("ZRANDMEMBER", key, *params) + + def bzpopmax(self, keys: KeysT, timeout: TimeoutSecT = 0) -> ResponseT: + """ + ZPOPMAX a value off of the first non-empty sorted set + named in the ``keys`` list. + + If none of the sorted sets in ``keys`` has a value to ZPOPMAX, + then block for ``timeout`` seconds, or until a member gets added + to one of the sorted sets. + + If timeout is 0, then block indefinitely. + + For more information see https://redis.io/commands/bzpopmax + """ + if timeout is None: + timeout = 0 + keys = list_or_args(keys, None) + keys.append(timeout) + return self.execute_command("BZPOPMAX", *keys) + + def bzpopmin(self, keys: KeysT, timeout: TimeoutSecT = 0) -> ResponseT: + """ + ZPOPMIN a value off of the first non-empty sorted set + named in the ``keys`` list. + + If none of the sorted sets in ``keys`` has a value to ZPOPMIN, + then block for ``timeout`` seconds, or until a member gets added + to one of the sorted sets. + + If timeout is 0, then block indefinitely. + + For more information see https://redis.io/commands/bzpopmin + """ + if timeout is None: + timeout = 0 + keys: list[EncodableT] = list_or_args(keys, None) + keys.append(timeout) + return self.execute_command("BZPOPMIN", *keys) + + def zmpop( + self, + num_keys: int, + keys: List[str], + min: Optional[bool] = False, + max: Optional[bool] = False, + count: Optional[int] = 1, + ) -> Union[Awaitable[list], list]: + """ + Pop ``count`` values (default 1) off of the first non-empty sorted set + named in the ``keys`` list. + For more information see https://redis.io/commands/zmpop + """ + args = [num_keys] + keys + if (min and max) or (not min and not max): + raise DataError + elif min: + args.append("MIN") + else: + args.append("MAX") + if count != 1: + args.extend(["COUNT", count]) + + return self.execute_command("ZMPOP", *args) + + def bzmpop( + self, + timeout: float, + numkeys: int, + keys: List[str], + min: Optional[bool] = False, + max: Optional[bool] = False, + count: Optional[int] = 1, + ) -> Optional[list]: + """ + Pop ``count`` values (default 1) off of the first non-empty sorted set + named in the ``keys`` list. + + If none of the sorted sets in ``keys`` has a value to pop, + then block for ``timeout`` seconds, or until a member gets added + to one of the sorted sets. + + If timeout is 0, then block indefinitely. + + For more information see https://redis.io/commands/bzmpop + """ + args = [timeout, numkeys, *keys] + if (min and max) or (not min and not max): + raise DataError("Either min or max, but not both must be set") + elif min: + args.append("MIN") + else: + args.append("MAX") + args.extend(["COUNT", count]) + + return self.execute_command("BZMPOP", *args) + + def _zrange( + self, + command, + dest: Union[KeyT, None], + name: KeyT, + start: int, + end: int, + desc: bool = False, + byscore: bool = False, + bylex: bool = False, + withscores: bool = False, + score_cast_func: Union[type, Callable, None] = float, + offset: Union[int, None] = None, + num: Union[int, None] = None, + ) -> ResponseT: + if byscore and bylex: + raise DataError("``byscore`` and ``bylex`` can not be specified together.") + if (offset is not None and num is None) or (num is not None and offset is None): + raise DataError("``offset`` and ``num`` must both be specified.") + if bylex and withscores: + raise DataError( + "``withscores`` not supported in combination with ``bylex``." + ) + pieces = [command] + if dest: + pieces.append(dest) + pieces.extend([name, start, end]) + if byscore: + pieces.append("BYSCORE") + if bylex: + pieces.append("BYLEX") + if desc: + pieces.append("REV") + if offset is not None and num is not None: + pieces.extend(["LIMIT", offset, num]) + if withscores: + pieces.append("WITHSCORES") + options = {"withscores": withscores, "score_cast_func": score_cast_func} + return self.execute_command(*pieces, **options) + + def zrange( + self, + name: KeyT, + start: int, + end: int, + desc: bool = False, + withscores: bool = False, + score_cast_func: Union[type, Callable] = float, + byscore: bool = False, + bylex: bool = False, + offset: int = None, + num: int = None, + ) -> ResponseT: + """ + Return a range of values from sorted set ``name`` between + ``start`` and ``end`` sorted in ascending order. + + ``start`` and ``end`` can be negative, indicating the end of the range. + + ``desc`` a boolean indicating whether to sort the results in reversed + order. + + ``withscores`` indicates to return the scores along with the values. + The return type is a list of (value, score) pairs. + + ``score_cast_func`` a callable used to cast the score return value. + + ``byscore`` when set to True, returns the range of elements from the + sorted set having scores equal or between ``start`` and ``end``. + + ``bylex`` when set to True, returns the range of elements from the + sorted set between the ``start`` and ``end`` lexicographical closed + range intervals. + Valid ``start`` and ``end`` must start with ( or [, in order to specify + whether the range interval is exclusive or inclusive, respectively. + + ``offset`` and ``num`` are specified, then return a slice of the range. + Can't be provided when using ``bylex``. + + For more information see https://redis.io/commands/zrange + """ + # Need to support ``desc`` also when using old redis version + # because it was supported in 3.5.3 (of redis-py) + if not byscore and not bylex and (offset is None and num is None) and desc: + return self.zrevrange(name, start, end, withscores, score_cast_func) + + return self._zrange( + "ZRANGE", + None, + name, + start, + end, + desc, + byscore, + bylex, + withscores, + score_cast_func, + offset, + num, + ) + + def zrevrange( + self, + name: KeyT, + start: int, + end: int, + withscores: bool = False, + score_cast_func: Union[type, Callable] = float, + ) -> ResponseT: + """ + Return a range of values from sorted set ``name`` between + ``start`` and ``end`` sorted in descending order. + + ``start`` and ``end`` can be negative, indicating the end of the range. + + ``withscores`` indicates to return the scores along with the values + The return type is a list of (value, score) pairs + + ``score_cast_func`` a callable used to cast the score return value + + For more information see https://redis.io/commands/zrevrange + """ + pieces = ["ZREVRANGE", name, start, end] + if withscores: + pieces.append(b"WITHSCORES") + options = {"withscores": withscores, "score_cast_func": score_cast_func} + return self.execute_command(*pieces, **options) + + def zrangestore( + self, + dest: KeyT, + name: KeyT, + start: int, + end: int, + byscore: bool = False, + bylex: bool = False, + desc: bool = False, + offset: Union[int, None] = None, + num: Union[int, None] = None, + ) -> ResponseT: + """ + Stores in ``dest`` the result of a range of values from sorted set + ``name`` between ``start`` and ``end`` sorted in ascending order. + + ``start`` and ``end`` can be negative, indicating the end of the range. + + ``byscore`` when set to True, returns the range of elements from the + sorted set having scores equal or between ``start`` and ``end``. + + ``bylex`` when set to True, returns the range of elements from the + sorted set between the ``start`` and ``end`` lexicographical closed + range intervals. + Valid ``start`` and ``end`` must start with ( or [, in order to specify + whether the range interval is exclusive or inclusive, respectively. + + ``desc`` a boolean indicating whether to sort the results in reversed + order. + + ``offset`` and ``num`` are specified, then return a slice of the range. + Can't be provided when using ``bylex``. + + For more information see https://redis.io/commands/zrangestore + """ + return self._zrange( + "ZRANGESTORE", + dest, + name, + start, + end, + desc, + byscore, + bylex, + False, + None, + offset, + num, + ) + + def zrangebylex( + self, + name: KeyT, + min: EncodableT, + max: EncodableT, + start: Union[int, None] = None, + num: Union[int, None] = None, + ) -> ResponseT: + """ + Return the lexicographical range of values from sorted set ``name`` + between ``min`` and ``max``. + + If ``start`` and ``num`` are specified, then return a slice of the + range. + + For more information see https://redis.io/commands/zrangebylex + """ + if (start is not None and num is None) or (num is not None and start is None): + raise DataError("``start`` and ``num`` must both be specified") + pieces = ["ZRANGEBYLEX", name, min, max] + if start is not None and num is not None: + pieces.extend([b"LIMIT", start, num]) + return self.execute_command(*pieces) + + def zrevrangebylex( + self, + name: KeyT, + max: EncodableT, + min: EncodableT, + start: Union[int, None] = None, + num: Union[int, None] = None, + ) -> ResponseT: + """ + Return the reversed lexicographical range of values from sorted set + ``name`` between ``max`` and ``min``. + + If ``start`` and ``num`` are specified, then return a slice of the + range. + + For more information see https://redis.io/commands/zrevrangebylex + """ + if (start is not None and num is None) or (num is not None and start is None): + raise DataError("``start`` and ``num`` must both be specified") + pieces = ["ZREVRANGEBYLEX", name, max, min] + if start is not None and num is not None: + pieces.extend(["LIMIT", start, num]) + return self.execute_command(*pieces) + + def zrangebyscore( + self, + name: KeyT, + min: ZScoreBoundT, + max: ZScoreBoundT, + start: Union[int, None] = None, + num: Union[int, None] = None, + withscores: bool = False, + score_cast_func: Union[type, Callable] = float, + ) -> ResponseT: + """ + Return a range of values from the sorted set ``name`` with scores + between ``min`` and ``max``. + + If ``start`` and ``num`` are specified, then return a slice + of the range. + + ``withscores`` indicates to return the scores along with the values. + The return type is a list of (value, score) pairs + + `score_cast_func`` a callable used to cast the score return value + + For more information see https://redis.io/commands/zrangebyscore + """ + if (start is not None and num is None) or (num is not None and start is None): + raise DataError("``start`` and ``num`` must both be specified") + pieces = ["ZRANGEBYSCORE", name, min, max] + if start is not None and num is not None: + pieces.extend(["LIMIT", start, num]) + if withscores: + pieces.append("WITHSCORES") + options = {"withscores": withscores, "score_cast_func": score_cast_func} + return self.execute_command(*pieces, **options) + + def zrevrangebyscore( + self, + name: KeyT, + max: ZScoreBoundT, + min: ZScoreBoundT, + start: Union[int, None] = None, + num: Union[int, None] = None, + withscores: bool = False, + score_cast_func: Union[type, Callable] = float, + ): + """ + Return a range of values from the sorted set ``name`` with scores + between ``min`` and ``max`` in descending order. + + If ``start`` and ``num`` are specified, then return a slice + of the range. + + ``withscores`` indicates to return the scores along with the values. + The return type is a list of (value, score) pairs + + ``score_cast_func`` a callable used to cast the score return value + + For more information see https://redis.io/commands/zrevrangebyscore + """ + if (start is not None and num is None) or (num is not None and start is None): + raise DataError("``start`` and ``num`` must both be specified") + pieces = ["ZREVRANGEBYSCORE", name, max, min] + if start is not None and num is not None: + pieces.extend(["LIMIT", start, num]) + if withscores: + pieces.append("WITHSCORES") + options = {"withscores": withscores, "score_cast_func": score_cast_func} + return self.execute_command(*pieces, **options) + + def zrank( + self, + name: KeyT, + value: EncodableT, + withscore: bool = False, + ) -> ResponseT: + """ + Returns a 0-based value indicating the rank of ``value`` in sorted set + ``name``. + The optional WITHSCORE argument supplements the command's + reply with the score of the element returned. + + For more information see https://redis.io/commands/zrank + """ + if withscore: + return self.execute_command("ZRANK", name, value, "WITHSCORE") + return self.execute_command("ZRANK", name, value) + + def zrem(self, name: KeyT, *values: FieldT) -> ResponseT: + """ + Remove member ``values`` from sorted set ``name`` + + For more information see https://redis.io/commands/zrem + """ + return self.execute_command("ZREM", name, *values) + + def zremrangebylex(self, name: KeyT, min: EncodableT, max: EncodableT) -> ResponseT: + """ + Remove all elements in the sorted set ``name`` between the + lexicographical range specified by ``min`` and ``max``. + + Returns the number of elements removed. + + For more information see https://redis.io/commands/zremrangebylex + """ + return self.execute_command("ZREMRANGEBYLEX", name, min, max) + + def zremrangebyrank(self, name: KeyT, min: int, max: int) -> ResponseT: + """ + Remove all elements in the sorted set ``name`` with ranks between + ``min`` and ``max``. Values are 0-based, ordered from smallest score + to largest. Values can be negative indicating the highest scores. + Returns the number of elements removed + + For more information see https://redis.io/commands/zremrangebyrank + """ + return self.execute_command("ZREMRANGEBYRANK", name, min, max) + + def zremrangebyscore( + self, name: KeyT, min: ZScoreBoundT, max: ZScoreBoundT + ) -> ResponseT: + """ + Remove all elements in the sorted set ``name`` with scores + between ``min`` and ``max``. Returns the number of elements removed. + + For more information see https://redis.io/commands/zremrangebyscore + """ + return self.execute_command("ZREMRANGEBYSCORE", name, min, max) + + def zrevrank( + self, + name: KeyT, + value: EncodableT, + withscore: bool = False, + ) -> ResponseT: + """ + Returns a 0-based value indicating the descending rank of + ``value`` in sorted set ``name``. + The optional ``withscore`` argument supplements the command's + reply with the score of the element returned. + + For more information see https://redis.io/commands/zrevrank + """ + if withscore: + return self.execute_command("ZREVRANK", name, value, "WITHSCORE") + return self.execute_command("ZREVRANK", name, value) + + def zscore(self, name: KeyT, value: EncodableT) -> ResponseT: + """ + Return the score of element ``value`` in sorted set ``name`` + + For more information see https://redis.io/commands/zscore + """ + return self.execute_command("ZSCORE", name, value) + + def zunion( + self, + keys: Union[Sequence[KeyT], Mapping[AnyKeyT, float]], + aggregate: Union[str, None] = None, + withscores: bool = False, + ) -> ResponseT: + """ + Return the union of multiple sorted sets specified by ``keys``. + ``keys`` can be provided as dictionary of keys and their weights. + Scores will be aggregated based on the ``aggregate``, or SUM if + none is provided. + + For more information see https://redis.io/commands/zunion + """ + return self._zaggregate("ZUNION", None, keys, aggregate, withscores=withscores) + + def zunionstore( + self, + dest: KeyT, + keys: Union[Sequence[KeyT], Mapping[AnyKeyT, float]], + aggregate: Union[str, None] = None, + ) -> ResponseT: + """ + Union multiple sorted sets specified by ``keys`` into + a new sorted set, ``dest``. Scores in the destination will be + aggregated based on the ``aggregate``, or SUM if none is provided. + + For more information see https://redis.io/commands/zunionstore + """ + return self._zaggregate("ZUNIONSTORE", dest, keys, aggregate) + + def zmscore(self, key: KeyT, members: List[str]) -> ResponseT: + """ + Returns the scores associated with the specified members + in the sorted set stored at key. + ``members`` should be a list of the member name. + Return type is a list of score. + If the member does not exist, a None will be returned + in corresponding position. + + For more information see https://redis.io/commands/zmscore + """ + if not members: + raise DataError("ZMSCORE members must be a non-empty list") + pieces = [key] + members + return self.execute_command("ZMSCORE", *pieces) + + def _zaggregate( + self, + command: str, + dest: Union[KeyT, None], + keys: Union[Sequence[KeyT], Mapping[AnyKeyT, float]], + aggregate: Union[str, None] = None, + **options, + ) -> ResponseT: + pieces: list[EncodableT] = [command] + if dest is not None: + pieces.append(dest) + pieces.append(len(keys)) + if isinstance(keys, dict): + keys, weights = keys.keys(), keys.values() + else: + weights = None + pieces.extend(keys) + if weights: + pieces.append(b"WEIGHTS") + pieces.extend(weights) + if aggregate: + if aggregate.upper() in ["SUM", "MIN", "MAX"]: + pieces.append(b"AGGREGATE") + pieces.append(aggregate) + else: + raise DataError("aggregate can be sum, min or max.") + if options.get("withscores", False): + pieces.append(b"WITHSCORES") + return self.execute_command(*pieces, **options) + + +AsyncSortedSetCommands = SortedSetCommands + + +class HyperlogCommands(CommandsProtocol): + """ + Redis commands of HyperLogLogs data type. + see: https://redis.io/topics/data-types-intro#hyperloglogs + """ + + def pfadd(self, name: KeyT, *values: FieldT) -> ResponseT: + """ + Adds the specified elements to the specified HyperLogLog. + + For more information see https://redis.io/commands/pfadd + """ + return self.execute_command("PFADD", name, *values) + + def pfcount(self, *sources: KeyT) -> ResponseT: + """ + Return the approximated cardinality of + the set observed by the HyperLogLog at key(s). + + For more information see https://redis.io/commands/pfcount + """ + return self.execute_command("PFCOUNT", *sources) + + def pfmerge(self, dest: KeyT, *sources: KeyT) -> ResponseT: + """ + Merge N different HyperLogLogs into a single one. + + For more information see https://redis.io/commands/pfmerge + """ + return self.execute_command("PFMERGE", dest, *sources) + + +AsyncHyperlogCommands = HyperlogCommands + + +class HashCommands(CommandsProtocol): + """ + Redis commands for Hash data type. + see: https://redis.io/topics/data-types-intro#redis-hashes + """ + + def hdel(self, name: str, *keys: List) -> Union[Awaitable[int], int]: + """ + Delete ``keys`` from hash ``name`` + + For more information see https://redis.io/commands/hdel + """ + return self.execute_command("HDEL", name, *keys) + + def hexists(self, name: str, key: str) -> Union[Awaitable[bool], bool]: + """ + Returns a boolean indicating if ``key`` exists within hash ``name`` + + For more information see https://redis.io/commands/hexists + """ + return self.execute_command("HEXISTS", name, key) + + def hget( + self, name: str, key: str + ) -> Union[Awaitable[Optional[str]], Optional[str]]: + """ + Return the value of ``key`` within the hash ``name`` + + For more information see https://redis.io/commands/hget + """ + return self.execute_command("HGET", name, key) + + def hgetall(self, name: str) -> Union[Awaitable[dict], dict]: + """ + Return a Python dict of the hash's name/value pairs + + For more information see https://redis.io/commands/hgetall + """ + return self.execute_command("HGETALL", name) + + def hincrby( + self, name: str, key: str, amount: int = 1 + ) -> Union[Awaitable[int], int]: + """ + Increment the value of ``key`` in hash ``name`` by ``amount`` + + For more information see https://redis.io/commands/hincrby + """ + return self.execute_command("HINCRBY", name, key, amount) + + def hincrbyfloat( + self, name: str, key: str, amount: float = 1.0 + ) -> Union[Awaitable[float], float]: + """ + Increment the value of ``key`` in hash ``name`` by floating ``amount`` + + For more information see https://redis.io/commands/hincrbyfloat + """ + return self.execute_command("HINCRBYFLOAT", name, key, amount) + + def hkeys(self, name: str) -> Union[Awaitable[List], List]: + """ + Return the list of keys within hash ``name`` + + For more information see https://redis.io/commands/hkeys + """ + return self.execute_command("HKEYS", name) + + def hlen(self, name: str) -> Union[Awaitable[int], int]: + """ + Return the number of elements in hash ``name`` + + For more information see https://redis.io/commands/hlen + """ + return self.execute_command("HLEN", name) + + def hset( + self, + name: str, + key: Optional[str] = None, + value: Optional[str] = None, + mapping: Optional[dict] = None, + items: Optional[list] = None, + ) -> Union[Awaitable[int], int]: + """ + Set ``key`` to ``value`` within hash ``name``, + ``mapping`` accepts a dict of key/value pairs that will be + added to hash ``name``. + ``items`` accepts a list of key/value pairs that will be + added to hash ``name``. + Returns the number of fields that were added. + + For more information see https://redis.io/commands/hset + """ + if key is None and not mapping and not items: + raise DataError("'hset' with no key value pairs") + items = items or [] + if key is not None: + items.extend((key, value)) + if mapping: + for pair in mapping.items(): + items.extend(pair) + + return self.execute_command("HSET", name, *items) + + def hsetnx(self, name: str, key: str, value: str) -> Union[Awaitable[bool], bool]: + """ + Set ``key`` to ``value`` within hash ``name`` if ``key`` does not + exist. Returns 1 if HSETNX created a field, otherwise 0. + + For more information see https://redis.io/commands/hsetnx + """ + return self.execute_command("HSETNX", name, key, value) + + def hmset(self, name: str, mapping: dict) -> Union[Awaitable[str], str]: + """ + Set key to value within hash ``name`` for each corresponding + key and value from the ``mapping`` dict. + + For more information see https://redis.io/commands/hmset + """ + warnings.warn( + f"{self.__class__.__name__}.hmset() is deprecated. " + f"Use {self.__class__.__name__}.hset() instead.", + DeprecationWarning, + stacklevel=2, + ) + if not mapping: + raise DataError("'hmset' with 'mapping' of length 0") + items = [] + for pair in mapping.items(): + items.extend(pair) + return self.execute_command("HMSET", name, *items) + + def hmget(self, name: str, keys: List, *args: List) -> Union[Awaitable[List], List]: + """ + Returns a list of values ordered identically to ``keys`` + + For more information see https://redis.io/commands/hmget + """ + args = list_or_args(keys, args) + return self.execute_command("HMGET", name, *args) + + def hvals(self, name: str) -> Union[Awaitable[List], List]: + """ + Return the list of values within hash ``name`` + + For more information see https://redis.io/commands/hvals + """ + return self.execute_command("HVALS", name) + + def hstrlen(self, name: str, key: str) -> Union[Awaitable[int], int]: + """ + Return the number of bytes stored in the value of ``key`` + within hash ``name`` + + For more information see https://redis.io/commands/hstrlen + """ + return self.execute_command("HSTRLEN", name, key) + + +AsyncHashCommands = HashCommands + + +class Script: + """ + An executable Lua script object returned by ``register_script`` + """ + + def __init__(self, registered_client: "Redis", script: ScriptTextT): + self.registered_client = registered_client + self.script = script + # Precalculate and store the SHA1 hex digest of the script. + + if isinstance(script, str): + # We need the encoding from the client in order to generate an + # accurate byte representation of the script + try: + encoder = registered_client.connection_pool.get_encoder() + except AttributeError: + # Cluster + encoder = registered_client.get_encoder() + script = encoder.encode(script) + self.sha = hashlib.sha1(script).hexdigest() + + def __call__( + self, + keys: Union[Sequence[KeyT], None] = None, + args: Union[Iterable[EncodableT], None] = None, + client: Union["Redis", None] = None, + ): + """Execute the script, passing any required ``args``""" + keys = keys or [] + args = args or [] + if client is None: + client = self.registered_client + args = tuple(keys) + tuple(args) + # make sure the Redis server knows about the script + from redis.client import Pipeline + + if isinstance(client, Pipeline): + # Make sure the pipeline can register the script before executing. + client.scripts.add(self) + try: + return client.evalsha(self.sha, len(keys), *args) + except NoScriptError: + # Maybe the client is pointed to a different server than the client + # that created this instance? + # Overwrite the sha just in case there was a discrepancy. + self.sha = client.script_load(self.script) + return client.evalsha(self.sha, len(keys), *args) + + +class AsyncScript: + """ + An executable Lua script object returned by ``register_script`` + """ + + def __init__(self, registered_client: "AsyncRedis", script: ScriptTextT): + self.registered_client = registered_client + self.script = script + # Precalculate and store the SHA1 hex digest of the script. + + if isinstance(script, str): + # We need the encoding from the client in order to generate an + # accurate byte representation of the script + try: + encoder = registered_client.connection_pool.get_encoder() + except AttributeError: + # Cluster + encoder = registered_client.get_encoder() + script = encoder.encode(script) + self.sha = hashlib.sha1(script).hexdigest() + + async def __call__( + self, + keys: Union[Sequence[KeyT], None] = None, + args: Union[Iterable[EncodableT], None] = None, + client: Union["AsyncRedis", None] = None, + ): + """Execute the script, passing any required ``args``""" + keys = keys or [] + args = args or [] + if client is None: + client = self.registered_client + args = tuple(keys) + tuple(args) + # make sure the Redis server knows about the script + from redis.asyncio.client import Pipeline + + if isinstance(client, Pipeline): + # Make sure the pipeline can register the script before executing. + client.scripts.add(self) + try: + return await client.evalsha(self.sha, len(keys), *args) + except NoScriptError: + # Maybe the client is pointed to a different server than the client + # that created this instance? + # Overwrite the sha just in case there was a discrepancy. + self.sha = await client.script_load(self.script) + return await client.evalsha(self.sha, len(keys), *args) + + +class PubSubCommands(CommandsProtocol): + """ + Redis PubSub commands. + see https://redis.io/topics/pubsub + """ + + def publish(self, channel: ChannelT, message: EncodableT, **kwargs) -> ResponseT: + """ + Publish ``message`` on ``channel``. + Returns the number of subscribers the message was delivered to. + + For more information see https://redis.io/commands/publish + """ + return self.execute_command("PUBLISH", channel, message, **kwargs) + + def spublish(self, shard_channel: ChannelT, message: EncodableT) -> ResponseT: + """ + Posts a message to the given shard channel. + Returns the number of clients that received the message + + For more information see https://redis.io/commands/spublish + """ + return self.execute_command("SPUBLISH", shard_channel, message) + + def pubsub_channels(self, pattern: PatternT = "*", **kwargs) -> ResponseT: + """ + Return a list of channels that have at least one subscriber + + For more information see https://redis.io/commands/pubsub-channels + """ + return self.execute_command("PUBSUB CHANNELS", pattern, **kwargs) + + def pubsub_shardchannels(self, pattern: PatternT = "*", **kwargs) -> ResponseT: + """ + Return a list of shard_channels that have at least one subscriber + + For more information see https://redis.io/commands/pubsub-shardchannels + """ + return self.execute_command("PUBSUB SHARDCHANNELS", pattern, **kwargs) + + def pubsub_numpat(self, **kwargs) -> ResponseT: + """ + Returns the number of subscriptions to patterns + + For more information see https://redis.io/commands/pubsub-numpat + """ + return self.execute_command("PUBSUB NUMPAT", **kwargs) + + def pubsub_numsub(self, *args: ChannelT, **kwargs) -> ResponseT: + """ + Return a list of (channel, number of subscribers) tuples + for each channel given in ``*args`` + + For more information see https://redis.io/commands/pubsub-numsub + """ + return self.execute_command("PUBSUB NUMSUB", *args, **kwargs) + + def pubsub_shardnumsub(self, *args: ChannelT, **kwargs) -> ResponseT: + """ + Return a list of (shard_channel, number of subscribers) tuples + for each channel given in ``*args`` + + For more information see https://redis.io/commands/pubsub-shardnumsub + """ + return self.execute_command("PUBSUB SHARDNUMSUB", *args, **kwargs) + + +AsyncPubSubCommands = PubSubCommands + + +class ScriptCommands(CommandsProtocol): + """ + Redis Lua script commands. see: + https://redis.com/ebook/part-3-next-steps/chapter-11-scripting-redis-with-lua/ + """ + + def _eval( + self, command: str, script: str, numkeys: int, *keys_and_args: list + ) -> Union[Awaitable[str], str]: + return self.execute_command(command, script, numkeys, *keys_and_args) + + def eval( + self, script: str, numkeys: int, *keys_and_args: list + ) -> Union[Awaitable[str], str]: + """ + Execute the Lua ``script``, specifying the ``numkeys`` the script + will touch and the key names and argument values in ``keys_and_args``. + Returns the result of the script. + + In practice, use the object returned by ``register_script``. This + function exists purely for Redis API completion. + + For more information see https://redis.io/commands/eval + """ + return self._eval("EVAL", script, numkeys, *keys_and_args) + + def eval_ro( + self, script: str, numkeys: int, *keys_and_args: list + ) -> Union[Awaitable[str], str]: + """ + The read-only variant of the EVAL command + + Execute the read-only Lua ``script`` specifying the ``numkeys`` the script + will touch and the key names and argument values in ``keys_and_args``. + Returns the result of the script. + + For more information see https://redis.io/commands/eval_ro + """ + return self._eval("EVAL_RO", script, numkeys, *keys_and_args) + + def _evalsha( + self, command: str, sha: str, numkeys: int, *keys_and_args: list + ) -> Union[Awaitable[str], str]: + return self.execute_command(command, sha, numkeys, *keys_and_args) + + def evalsha( + self, sha: str, numkeys: int, *keys_and_args: list + ) -> Union[Awaitable[str], str]: + """ + Use the ``sha`` to execute a Lua script already registered via EVAL + or SCRIPT LOAD. Specify the ``numkeys`` the script will touch and the + key names and argument values in ``keys_and_args``. Returns the result + of the script. + + In practice, use the object returned by ``register_script``. This + function exists purely for Redis API completion. + + For more information see https://redis.io/commands/evalsha + """ + return self._evalsha("EVALSHA", sha, numkeys, *keys_and_args) + + def evalsha_ro( + self, sha: str, numkeys: int, *keys_and_args: list + ) -> Union[Awaitable[str], str]: + """ + The read-only variant of the EVALSHA command + + Use the ``sha`` to execute a read-only Lua script already registered via EVAL + or SCRIPT LOAD. Specify the ``numkeys`` the script will touch and the + key names and argument values in ``keys_and_args``. Returns the result + of the script. + + For more information see https://redis.io/commands/evalsha_ro + """ + return self._evalsha("EVALSHA_RO", sha, numkeys, *keys_and_args) + + def script_exists(self, *args: str) -> ResponseT: + """ + Check if a script exists in the script cache by specifying the SHAs of + each script as ``args``. Returns a list of boolean values indicating if + if each already script exists in the cache. + + For more information see https://redis.io/commands/script-exists + """ + return self.execute_command("SCRIPT EXISTS", *args) + + def script_debug(self, *args) -> None: + raise NotImplementedError( + "SCRIPT DEBUG is intentionally not implemented in the client." + ) + + def script_flush( + self, sync_type: Union[Literal["SYNC"], Literal["ASYNC"]] = None + ) -> ResponseT: + """Flush all scripts from the script cache. + + ``sync_type`` is by default SYNC (synchronous) but it can also be + ASYNC. + + For more information see https://redis.io/commands/script-flush + """ + + # Redis pre 6 had no sync_type. + if sync_type not in ["SYNC", "ASYNC", None]: + raise DataError( + "SCRIPT FLUSH defaults to SYNC in redis > 6.2, or " + "accepts SYNC/ASYNC. For older versions, " + "of redis leave as None." + ) + if sync_type is None: + pieces = [] + else: + pieces = [sync_type] + return self.execute_command("SCRIPT FLUSH", *pieces) + + def script_kill(self) -> ResponseT: + """ + Kill the currently executing Lua script + + For more information see https://redis.io/commands/script-kill + """ + return self.execute_command("SCRIPT KILL") + + def script_load(self, script: ScriptTextT) -> ResponseT: + """ + Load a Lua ``script`` into the script cache. Returns the SHA. + + For more information see https://redis.io/commands/script-load + """ + return self.execute_command("SCRIPT LOAD", script) + + def register_script(self: "Redis", script: ScriptTextT) -> Script: + """ + Register a Lua ``script`` specifying the ``keys`` it will touch. + Returns a Script object that is callable and hides the complexity of + deal with scripts, keys, and shas. This is the preferred way to work + with Lua scripts. + """ + return Script(self, script) + + +class AsyncScriptCommands(ScriptCommands): + async def script_debug(self, *args) -> None: + return super().script_debug() + + def register_script(self: "AsyncRedis", script: ScriptTextT) -> AsyncScript: + """ + Register a Lua ``script`` specifying the ``keys`` it will touch. + Returns a Script object that is callable and hides the complexity of + deal with scripts, keys, and shas. This is the preferred way to work + with Lua scripts. + """ + return AsyncScript(self, script) + + +class GeoCommands(CommandsProtocol): + """ + Redis Geospatial commands. + see: https://redis.com/redis-best-practices/indexing-patterns/geospatial/ + """ + + def geoadd( + self, + name: KeyT, + values: Sequence[EncodableT], + nx: bool = False, + xx: bool = False, + ch: bool = False, + ) -> ResponseT: + """ + Add the specified geospatial items to the specified key identified + by the ``name`` argument. The Geospatial items are given as ordered + members of the ``values`` argument, each item or place is formed by + the triad longitude, latitude and name. + + Note: You can use ZREM to remove elements. + + ``nx`` forces ZADD to only create new elements and not to update + scores for elements that already exist. + + ``xx`` forces ZADD to only update scores of elements that already + exist. New elements will not be added. + + ``ch`` modifies the return value to be the numbers of elements changed. + Changed elements include new elements that were added and elements + whose scores changed. + + For more information see https://redis.io/commands/geoadd + """ + if nx and xx: + raise DataError("GEOADD allows either 'nx' or 'xx', not both") + if len(values) % 3 != 0: + raise DataError("GEOADD requires places with lon, lat and name values") + pieces = [name] + if nx: + pieces.append("NX") + if xx: + pieces.append("XX") + if ch: + pieces.append("CH") + pieces.extend(values) + return self.execute_command("GEOADD", *pieces) + + def geodist( + self, name: KeyT, place1: FieldT, place2: FieldT, unit: Union[str, None] = None + ) -> ResponseT: + """ + Return the distance between ``place1`` and ``place2`` members of the + ``name`` key. + The units must be one of the following : m, km mi, ft. By default + meters are used. + + For more information see https://redis.io/commands/geodist + """ + pieces: list[EncodableT] = [name, place1, place2] + if unit and unit not in ("m", "km", "mi", "ft"): + raise DataError("GEODIST invalid unit") + elif unit: + pieces.append(unit) + return self.execute_command("GEODIST", *pieces) + + def geohash(self, name: KeyT, *values: FieldT) -> ResponseT: + """ + Return the geo hash string for each item of ``values`` members of + the specified key identified by the ``name`` argument. + + For more information see https://redis.io/commands/geohash + """ + return self.execute_command("GEOHASH", name, *values) + + def geopos(self, name: KeyT, *values: FieldT) -> ResponseT: + """ + Return the positions of each item of ``values`` as members of + the specified key identified by the ``name`` argument. Each position + is represented by the pairs lon and lat. + + For more information see https://redis.io/commands/geopos + """ + return self.execute_command("GEOPOS", name, *values) + + def georadius( + self, + name: KeyT, + longitude: float, + latitude: float, + radius: float, + unit: Union[str, None] = None, + withdist: bool = False, + withcoord: bool = False, + withhash: bool = False, + count: Union[int, None] = None, + sort: Union[str, None] = None, + store: Union[KeyT, None] = None, + store_dist: Union[KeyT, None] = None, + any: bool = False, + ) -> ResponseT: + """ + Return the members of the specified key identified by the + ``name`` argument which are within the borders of the area specified + with the ``latitude`` and ``longitude`` location and the maximum + distance from the center specified by the ``radius`` value. + + The units must be one of the following : m, km mi, ft. By default + + ``withdist`` indicates to return the distances of each place. + + ``withcoord`` indicates to return the latitude and longitude of + each place. + + ``withhash`` indicates to return the geohash string of each place. + + ``count`` indicates to return the number of elements up to N. + + ``sort`` indicates to return the places in a sorted way, ASC for + nearest to fairest and DESC for fairest to nearest. + + ``store`` indicates to save the places names in a sorted set named + with a specific key, each element of the destination sorted set is + populated with the score got from the original geo sorted set. + + ``store_dist`` indicates to save the places names in a sorted set + named with a specific key, instead of ``store`` the sorted set + destination score is set with the distance. + + For more information see https://redis.io/commands/georadius + """ + return self._georadiusgeneric( + "GEORADIUS", + name, + longitude, + latitude, + radius, + unit=unit, + withdist=withdist, + withcoord=withcoord, + withhash=withhash, + count=count, + sort=sort, + store=store, + store_dist=store_dist, + any=any, + ) + + def georadiusbymember( + self, + name: KeyT, + member: FieldT, + radius: float, + unit: Union[str, None] = None, + withdist: bool = False, + withcoord: bool = False, + withhash: bool = False, + count: Union[int, None] = None, + sort: Union[str, None] = None, + store: Union[KeyT, None] = None, + store_dist: Union[KeyT, None] = None, + any: bool = False, + ) -> ResponseT: + """ + This command is exactly like ``georadius`` with the sole difference + that instead of taking, as the center of the area to query, a longitude + and latitude value, it takes the name of a member already existing + inside the geospatial index represented by the sorted set. + + For more information see https://redis.io/commands/georadiusbymember + """ + return self._georadiusgeneric( + "GEORADIUSBYMEMBER", + name, + member, + radius, + unit=unit, + withdist=withdist, + withcoord=withcoord, + withhash=withhash, + count=count, + sort=sort, + store=store, + store_dist=store_dist, + any=any, + ) + + def _georadiusgeneric( + self, command: str, *args: EncodableT, **kwargs: Union[EncodableT, None] + ) -> ResponseT: + pieces = list(args) + if kwargs["unit"] and kwargs["unit"] not in ("m", "km", "mi", "ft"): + raise DataError("GEORADIUS invalid unit") + elif kwargs["unit"]: + pieces.append(kwargs["unit"]) + else: + pieces.append("m") + + if kwargs["any"] and kwargs["count"] is None: + raise DataError("``any`` can't be provided without ``count``") + + for arg_name, byte_repr in ( + ("withdist", "WITHDIST"), + ("withcoord", "WITHCOORD"), + ("withhash", "WITHHASH"), + ): + if kwargs[arg_name]: + pieces.append(byte_repr) + + if kwargs["count"] is not None: + pieces.extend(["COUNT", kwargs["count"]]) + if kwargs["any"]: + pieces.append("ANY") + + if kwargs["sort"]: + if kwargs["sort"] == "ASC": + pieces.append("ASC") + elif kwargs["sort"] == "DESC": + pieces.append("DESC") + else: + raise DataError("GEORADIUS invalid sort") + + if kwargs["store"] and kwargs["store_dist"]: + raise DataError("GEORADIUS store and store_dist cant be set together") + + if kwargs["store"]: + pieces.extend([b"STORE", kwargs["store"]]) + + if kwargs["store_dist"]: + pieces.extend([b"STOREDIST", kwargs["store_dist"]]) + + return self.execute_command(command, *pieces, **kwargs) + + def geosearch( + self, + name: KeyT, + member: Union[FieldT, None] = None, + longitude: Union[float, None] = None, + latitude: Union[float, None] = None, + unit: str = "m", + radius: Union[float, None] = None, + width: Union[float, None] = None, + height: Union[float, None] = None, + sort: Union[str, None] = None, + count: Union[int, None] = None, + any: bool = False, + withcoord: bool = False, + withdist: bool = False, + withhash: bool = False, + ) -> ResponseT: + """ + Return the members of specified key identified by the + ``name`` argument, which are within the borders of the + area specified by a given shape. This command extends the + GEORADIUS command, so in addition to searching within circular + areas, it supports searching within rectangular areas. + + This command should be used in place of the deprecated + GEORADIUS and GEORADIUSBYMEMBER commands. + + ``member`` Use the position of the given existing + member in the sorted set. Can't be given with ``longitude`` + and ``latitude``. + + ``longitude`` and ``latitude`` Use the position given by + this coordinates. Can't be given with ``member`` + ``radius`` Similar to GEORADIUS, search inside circular + area according the given radius. Can't be given with + ``height`` and ``width``. + ``height`` and ``width`` Search inside an axis-aligned + rectangle, determined by the given height and width. + Can't be given with ``radius`` + + ``unit`` must be one of the following : m, km, mi, ft. + `m` for meters (the default value), `km` for kilometers, + `mi` for miles and `ft` for feet. + + ``sort`` indicates to return the places in a sorted way, + ASC for nearest to furthest and DESC for furthest to nearest. + + ``count`` limit the results to the first count matching items. + + ``any`` is set to True, the command will return as soon as + enough matches are found. Can't be provided without ``count`` + + ``withdist`` indicates to return the distances of each place. + ``withcoord`` indicates to return the latitude and longitude of + each place. + + ``withhash`` indicates to return the geohash string of each place. + + For more information see https://redis.io/commands/geosearch + """ + + return self._geosearchgeneric( + "GEOSEARCH", + name, + member=member, + longitude=longitude, + latitude=latitude, + unit=unit, + radius=radius, + width=width, + height=height, + sort=sort, + count=count, + any=any, + withcoord=withcoord, + withdist=withdist, + withhash=withhash, + store=None, + store_dist=None, + ) + + def geosearchstore( + self, + dest: KeyT, + name: KeyT, + member: Union[FieldT, None] = None, + longitude: Union[float, None] = None, + latitude: Union[float, None] = None, + unit: str = "m", + radius: Union[float, None] = None, + width: Union[float, None] = None, + height: Union[float, None] = None, + sort: Union[str, None] = None, + count: Union[int, None] = None, + any: bool = False, + storedist: bool = False, + ) -> ResponseT: + """ + This command is like GEOSEARCH, but stores the result in + ``dest``. By default, it stores the results in the destination + sorted set with their geospatial information. + if ``store_dist`` set to True, the command will stores the + items in a sorted set populated with their distance from the + center of the circle or box, as a floating-point number. + + For more information see https://redis.io/commands/geosearchstore + """ + return self._geosearchgeneric( + "GEOSEARCHSTORE", + dest, + name, + member=member, + longitude=longitude, + latitude=latitude, + unit=unit, + radius=radius, + width=width, + height=height, + sort=sort, + count=count, + any=any, + withcoord=None, + withdist=None, + withhash=None, + store=None, + store_dist=storedist, + ) + + def _geosearchgeneric( + self, command: str, *args: EncodableT, **kwargs: Union[EncodableT, None] + ) -> ResponseT: + pieces = list(args) + + # FROMMEMBER or FROMLONLAT + if kwargs["member"] is None: + if kwargs["longitude"] is None or kwargs["latitude"] is None: + raise DataError("GEOSEARCH must have member or longitude and latitude") + if kwargs["member"]: + if kwargs["longitude"] or kwargs["latitude"]: + raise DataError( + "GEOSEARCH member and longitude or latitude cant be set together" + ) + pieces.extend([b"FROMMEMBER", kwargs["member"]]) + if kwargs["longitude"] is not None and kwargs["latitude"] is not None: + pieces.extend([b"FROMLONLAT", kwargs["longitude"], kwargs["latitude"]]) + + # BYRADIUS or BYBOX + if kwargs["radius"] is None: + if kwargs["width"] is None or kwargs["height"] is None: + raise DataError("GEOSEARCH must have radius or width and height") + if kwargs["unit"] is None: + raise DataError("GEOSEARCH must have unit") + if kwargs["unit"].lower() not in ("m", "km", "mi", "ft"): + raise DataError("GEOSEARCH invalid unit") + if kwargs["radius"]: + if kwargs["width"] or kwargs["height"]: + raise DataError( + "GEOSEARCH radius and width or height cant be set together" + ) + pieces.extend([b"BYRADIUS", kwargs["radius"], kwargs["unit"]]) + if kwargs["width"] and kwargs["height"]: + pieces.extend([b"BYBOX", kwargs["width"], kwargs["height"], kwargs["unit"]]) + + # sort + if kwargs["sort"]: + if kwargs["sort"].upper() == "ASC": + pieces.append(b"ASC") + elif kwargs["sort"].upper() == "DESC": + pieces.append(b"DESC") + else: + raise DataError("GEOSEARCH invalid sort") + + # count any + if kwargs["count"]: + pieces.extend([b"COUNT", kwargs["count"]]) + if kwargs["any"]: + pieces.append(b"ANY") + elif kwargs["any"]: + raise DataError("GEOSEARCH ``any`` can't be provided without count") + + # other properties + for arg_name, byte_repr in ( + ("withdist", b"WITHDIST"), + ("withcoord", b"WITHCOORD"), + ("withhash", b"WITHHASH"), + ("store_dist", b"STOREDIST"), + ): + if kwargs[arg_name]: + pieces.append(byte_repr) + + return self.execute_command(command, *pieces, **kwargs) + + +AsyncGeoCommands = GeoCommands + + +class ModuleCommands(CommandsProtocol): + """ + Redis Module commands. + see: https://redis.io/topics/modules-intro + """ + + def module_load(self, path, *args) -> ResponseT: + """ + Loads the module from ``path``. + Passes all ``*args`` to the module, during loading. + Raises ``ModuleError`` if a module is not found at ``path``. + + For more information see https://redis.io/commands/module-load + """ + return self.execute_command("MODULE LOAD", path, *args) + + def module_loadex( + self, + path: str, + options: Optional[List[str]] = None, + args: Optional[List[str]] = None, + ) -> ResponseT: + """ + Loads a module from a dynamic library at runtime with configuration directives. + + For more information see https://redis.io/commands/module-loadex + """ + pieces = [] + if options is not None: + pieces.append("CONFIG") + pieces.extend(options) + if args is not None: + pieces.append("ARGS") + pieces.extend(args) + + return self.execute_command("MODULE LOADEX", path, *pieces) + + def module_unload(self, name) -> ResponseT: + """ + Unloads the module ``name``. + Raises ``ModuleError`` if ``name`` is not in loaded modules. + + For more information see https://redis.io/commands/module-unload + """ + return self.execute_command("MODULE UNLOAD", name) + + def module_list(self) -> ResponseT: + """ + Returns a list of dictionaries containing the name and version of + all loaded modules. + + For more information see https://redis.io/commands/module-list + """ + return self.execute_command("MODULE LIST") + + def command_info(self) -> None: + raise NotImplementedError( + "COMMAND INFO is intentionally not implemented in the client." + ) + + def command_count(self) -> ResponseT: + return self.execute_command("COMMAND COUNT") + + def command_getkeys(self, *args) -> ResponseT: + return self.execute_command("COMMAND GETKEYS", *args) + + def command(self) -> ResponseT: + return self.execute_command("COMMAND") + + +class Script: + """ + An executable Lua script object returned by ``register_script`` + """ + + def __init__(self, registered_client, script): + self.registered_client = registered_client + self.script = script + # Precalculate and store the SHA1 hex digest of the script. + + if isinstance(script, str): + # We need the encoding from the client in order to generate an + # accurate byte representation of the script + encoder = self.get_encoder() + script = encoder.encode(script) + self.sha = hashlib.sha1(script).hexdigest() + + def __call__(self, keys=[], args=[], client=None): + "Execute the script, passing any required ``args``" + if client is None: + client = self.registered_client + args = tuple(keys) + tuple(args) + # make sure the Redis server knows about the script + from redis.client import Pipeline + + if isinstance(client, Pipeline): + # Make sure the pipeline can register the script before executing. + client.scripts.add(self) + try: + return client.evalsha(self.sha, len(keys), *args) + except NoScriptError: + # Maybe the client is pointed to a different server than the client + # that created this instance? + # Overwrite the sha just in case there was a discrepancy. + self.sha = client.script_load(self.script) + return client.evalsha(self.sha, len(keys), *args) + + def get_encoder(self): + """Get the encoder to encode string scripts into bytes.""" + try: + return self.registered_client.get_encoder() + except AttributeError: + # DEPRECATED + # In version <=4.1.2, this was the code we used to get the encoder. + # However, after 4.1.2 we added support for scripting in clustered + # redis. ClusteredRedis doesn't have a `.connection_pool` attribute + # so we changed the Script class to use + # `self.registered_client.get_encoder` (see above). + # However, that is technically a breaking change, as consumers who + # use Scripts directly might inject a `registered_client` that + # doesn't have a `.get_encoder` field. This try/except prevents us + # from breaking backward-compatibility. Ideally, it would be + # removed in the next major release. + return self.registered_client.connection_pool.get_encoder() + + +class AsyncModuleCommands(ModuleCommands): + async def command_info(self) -> None: + return super().command_info() + + +class ClusterCommands(CommandsProtocol): + """ + Class for Redis Cluster commands + """ + + def cluster(self, cluster_arg, *args, **kwargs) -> ResponseT: + return self.execute_command(f"CLUSTER {cluster_arg.upper()}", *args, **kwargs) + + def readwrite(self, **kwargs) -> ResponseT: + """ + Disables read queries for a connection to a Redis Cluster slave node. + + For more information see https://redis.io/commands/readwrite + """ + return self.execute_command("READWRITE", **kwargs) + + def readonly(self, **kwargs) -> ResponseT: + """ + Enables read queries for a connection to a Redis Cluster replica node. + + For more information see https://redis.io/commands/readonly + """ + return self.execute_command("READONLY", **kwargs) + + +AsyncClusterCommands = ClusterCommands + + +class FunctionCommands: + """ + Redis Function commands + """ + + def function_load( + self, code: str, replace: Optional[bool] = False + ) -> Union[Awaitable[str], str]: + """ + Load a library to Redis. + :param code: the source code (must start with + Shebang statement that provides a metadata about the library) + :param replace: changes the behavior to overwrite the existing library + with the new contents. + Return the library name that was loaded. + + For more information see https://redis.io/commands/function-load + """ + pieces = ["REPLACE"] if replace else [] + pieces.append(code) + return self.execute_command("FUNCTION LOAD", *pieces) + + def function_delete(self, library: str) -> Union[Awaitable[str], str]: + """ + Delete the library called ``library`` and all its functions. + + For more information see https://redis.io/commands/function-delete + """ + return self.execute_command("FUNCTION DELETE", library) + + def function_flush(self, mode: str = "SYNC") -> Union[Awaitable[str], str]: + """ + Deletes all the libraries. + + For more information see https://redis.io/commands/function-flush + """ + return self.execute_command("FUNCTION FLUSH", mode) + + def function_list( + self, library: Optional[str] = "*", withcode: Optional[bool] = False + ) -> Union[Awaitable[List], List]: + """ + Return information about the functions and libraries. + :param library: pecify a pattern for matching library names + :param withcode: cause the server to include the libraries source + implementation in the reply + """ + args = ["LIBRARYNAME", library] + if withcode: + args.append("WITHCODE") + return self.execute_command("FUNCTION LIST", *args) + + def _fcall( + self, command: str, function, numkeys: int, *keys_and_args: Optional[List] + ) -> Union[Awaitable[str], str]: + return self.execute_command(command, function, numkeys, *keys_and_args) + + def fcall( + self, function, numkeys: int, *keys_and_args: Optional[List] + ) -> Union[Awaitable[str], str]: + """ + Invoke a function. + + For more information see https://redis.io/commands/fcall + """ + return self._fcall("FCALL", function, numkeys, *keys_and_args) + + def fcall_ro( + self, function, numkeys: int, *keys_and_args: Optional[List] + ) -> Union[Awaitable[str], str]: + """ + This is a read-only variant of the FCALL command that cannot + execute commands that modify data. + + For more information see https://redis.io/commands/fcal_ro + """ + return self._fcall("FCALL_RO", function, numkeys, *keys_and_args) + + def function_dump(self) -> Union[Awaitable[str], str]: + """ + Return the serialized payload of loaded libraries. + + For more information see https://redis.io/commands/function-dump + """ + from redis.client import NEVER_DECODE + + options = {} + options[NEVER_DECODE] = [] + + return self.execute_command("FUNCTION DUMP", **options) + + def function_restore( + self, payload: str, policy: Optional[str] = "APPEND" + ) -> Union[Awaitable[str], str]: + """ + Restore libraries from the serialized ``payload``. + You can use the optional policy argument to provide a policy + for handling existing libraries. + + For more information see https://redis.io/commands/function-restore + """ + return self.execute_command("FUNCTION RESTORE", payload, policy) + + def function_kill(self) -> Union[Awaitable[str], str]: + """ + Kill a function that is currently executing. + + For more information see https://redis.io/commands/function-kill + """ + return self.execute_command("FUNCTION KILL") + + def function_stats(self) -> Union[Awaitable[List], List]: + """ + Return information about the function that's currently running + and information about the available execution engines. + + For more information see https://redis.io/commands/function-stats + """ + return self.execute_command("FUNCTION STATS") + + +AsyncFunctionCommands = FunctionCommands + + +class GearsCommands: + def tfunction_load( + self, lib_code: str, replace: bool = False, config: Union[str, None] = None + ) -> ResponseT: + """ + Load a new library to RedisGears. + + ``lib_code`` - the library code. + ``config`` - a string representation of a JSON object + that will be provided to the library on load time, + for more information refer to + https://github.com/RedisGears/RedisGears/blob/master/docs/function_advance_topics.md#library-configuration + ``replace`` - an optional argument, instructs RedisGears to replace the + function if its already exists + + For more information see https://redis.io/commands/tfunction-load/ + """ + pieces = [] + if replace: + pieces.append("REPLACE") + if config is not None: + pieces.extend(["CONFIG", config]) + pieces.append(lib_code) + return self.execute_command("TFUNCTION LOAD", *pieces) + + def tfunction_delete(self, lib_name: str) -> ResponseT: + """ + Delete a library from RedisGears. + + ``lib_name`` the library name to delete. + + For more information see https://redis.io/commands/tfunction-delete/ + """ + return self.execute_command("TFUNCTION DELETE", lib_name) + + def tfunction_list( + self, + with_code: bool = False, + verbose: int = 0, + lib_name: Union[str, None] = None, + ) -> ResponseT: + """ + List the functions with additional information about each function. + + ``with_code`` Show libraries code. + ``verbose`` output verbosity level, higher number will increase verbosity level + ``lib_name`` specifying a library name (can be used multiple times to show multiple libraries in a single command) # noqa + + For more information see https://redis.io/commands/tfunction-list/ + """ + pieces = [] + if with_code: + pieces.append("WITHCODE") + if verbose >= 1 and verbose <= 3: + pieces.append("v" * verbose) + else: + raise DataError("verbose can be 1, 2 or 3") + if lib_name is not None: + pieces.append("LIBRARY") + pieces.append(lib_name) + + return self.execute_command("TFUNCTION LIST", *pieces) + + def _tfcall( + self, + lib_name: str, + func_name: str, + keys: KeysT = None, + _async: bool = False, + *args: List, + ) -> ResponseT: + pieces = [f"{lib_name}.{func_name}"] + if keys is not None: + pieces.append(len(keys)) + pieces.extend(keys) + else: + pieces.append(0) + if args is not None: + pieces.extend(args) + if _async: + return self.execute_command("TFCALLASYNC", *pieces) + return self.execute_command("TFCALL", *pieces) + + def tfcall( + self, + lib_name: str, + func_name: str, + keys: KeysT = None, + *args: List, + ) -> ResponseT: + """ + Invoke a function. + + ``lib_name`` - the library name contains the function. + ``func_name`` - the function name to run. + ``keys`` - the keys that will be touched by the function. + ``args`` - Additional argument to pass to the function. + + For more information see https://redis.io/commands/tfcall/ + """ + return self._tfcall(lib_name, func_name, keys, False, *args) + + def tfcall_async( + self, + lib_name: str, + func_name: str, + keys: KeysT = None, + *args: List, + ) -> ResponseT: + """ + Invoke an async function (coroutine). + + ``lib_name`` - the library name contains the function. + ``func_name`` - the function name to run. + ``keys`` - the keys that will be touched by the function. + ``args`` - Additional argument to pass to the function. + + For more information see https://redis.io/commands/tfcall/ + """ + return self._tfcall(lib_name, func_name, keys, True, *args) + + +AsyncGearsCommands = GearsCommands + + +class DataAccessCommands( + BasicKeyCommands, + HyperlogCommands, + HashCommands, + GeoCommands, + ListCommands, + ScanCommands, + SetCommands, + StreamCommands, + SortedSetCommands, +): + """ + A class containing all of the implemented data access redis commands. + This class is to be used as a mixin for synchronous Redis clients. + """ + + +class AsyncDataAccessCommands( + AsyncBasicKeyCommands, + AsyncHyperlogCommands, + AsyncHashCommands, + AsyncGeoCommands, + AsyncListCommands, + AsyncScanCommands, + AsyncSetCommands, + AsyncStreamCommands, + AsyncSortedSetCommands, +): + """ + A class containing all of the implemented data access redis commands. + This class is to be used as a mixin for asynchronous Redis clients. + """ + + +class CoreCommands( + ACLCommands, + ClusterCommands, + DataAccessCommands, + ManagementCommands, + ModuleCommands, + PubSubCommands, + ScriptCommands, + FunctionCommands, + GearsCommands, +): + """ + A class containing all of the implemented redis commands. This class is + to be used as a mixin for synchronous Redis clients. + """ + + +class AsyncCoreCommands( + AsyncACLCommands, + AsyncClusterCommands, + AsyncDataAccessCommands, + AsyncManagementCommands, + AsyncModuleCommands, + AsyncPubSubCommands, + AsyncScriptCommands, + AsyncFunctionCommands, + AsyncGearsCommands, +): + """ + A class containing all of the implemented redis commands. This class is + to be used as a mixin for asynchronous Redis clients. + """ diff --git a/.venv/Lib/site-packages/redis/commands/graph/__init__.py b/.venv/Lib/site-packages/redis/commands/graph/__init__.py new file mode 100644 index 00000000..ffaf1fb4 --- /dev/null +++ b/.venv/Lib/site-packages/redis/commands/graph/__init__.py @@ -0,0 +1,263 @@ +import warnings + +from ..helpers import quote_string, random_string, stringify_param_value +from .commands import AsyncGraphCommands, GraphCommands +from .edge import Edge # noqa +from .node import Node # noqa +from .path import Path # noqa + +DB_LABELS = "DB.LABELS" +DB_RAELATIONSHIPTYPES = "DB.RELATIONSHIPTYPES" +DB_PROPERTYKEYS = "DB.PROPERTYKEYS" + + +class Graph(GraphCommands): + """ + Graph, collection of nodes and edges. + """ + + def __init__(self, client, name=random_string()): + """ + Create a new graph. + """ + warnings.warn( + DeprecationWarning( + "RedisGraph support is deprecated as of Redis Stack 7.2 \ + (https://redis.com/blog/redisgraph-eol/)" + ) + ) + self.NAME = name # Graph key + self.client = client + self.execute_command = client.execute_command + + self.nodes = {} + self.edges = [] + self._labels = [] # List of node labels. + self._properties = [] # List of properties. + self._relationship_types = [] # List of relation types. + self.version = 0 # Graph version + + @property + def name(self): + return self.NAME + + def _clear_schema(self): + self._labels = [] + self._properties = [] + self._relationship_types = [] + + def _refresh_schema(self): + self._clear_schema() + self._refresh_labels() + self._refresh_relations() + self._refresh_attributes() + + def _refresh_labels(self): + lbls = self.labels() + + # Unpack data. + self._labels = [l[0] for _, l in enumerate(lbls)] + + def _refresh_relations(self): + rels = self.relationship_types() + + # Unpack data. + self._relationship_types = [r[0] for _, r in enumerate(rels)] + + def _refresh_attributes(self): + props = self.property_keys() + + # Unpack data. + self._properties = [p[0] for _, p in enumerate(props)] + + def get_label(self, idx): + """ + Returns a label by it's index + + Args: + + idx: + The index of the label + """ + try: + label = self._labels[idx] + except IndexError: + # Refresh labels. + self._refresh_labels() + label = self._labels[idx] + return label + + def get_relation(self, idx): + """ + Returns a relationship type by it's index + + Args: + + idx: + The index of the relation + """ + try: + relationship_type = self._relationship_types[idx] + except IndexError: + # Refresh relationship types. + self._refresh_relations() + relationship_type = self._relationship_types[idx] + return relationship_type + + def get_property(self, idx): + """ + Returns a property by it's index + + Args: + + idx: + The index of the property + """ + try: + p = self._properties[idx] + except IndexError: + # Refresh properties. + self._refresh_attributes() + p = self._properties[idx] + return p + + def add_node(self, node): + """ + Adds a node to the graph. + """ + if node.alias is None: + node.alias = random_string() + self.nodes[node.alias] = node + + def add_edge(self, edge): + """ + Adds an edge to the graph. + """ + if not (self.nodes[edge.src_node.alias] and self.nodes[edge.dest_node.alias]): + raise AssertionError("Both edge's end must be in the graph") + + self.edges.append(edge) + + def _build_params_header(self, params): + if params is None: + return "" + if not isinstance(params, dict): + raise TypeError("'params' must be a dict") + # Header starts with "CYPHER" + params_header = "CYPHER " + for key, value in params.items(): + params_header += str(key) + "=" + stringify_param_value(value) + " " + return params_header + + # Procedures. + def call_procedure(self, procedure, *args, read_only=False, **kwagrs): + args = [quote_string(arg) for arg in args] + q = f"CALL {procedure}({','.join(args)})" + + y = kwagrs.get("y", None) + if y is not None: + q += f"YIELD {','.join(y)}" + + return self.query(q, read_only=read_only) + + def labels(self): + return self.call_procedure(DB_LABELS, read_only=True).result_set + + def relationship_types(self): + return self.call_procedure(DB_RAELATIONSHIPTYPES, read_only=True).result_set + + def property_keys(self): + return self.call_procedure(DB_PROPERTYKEYS, read_only=True).result_set + + +class AsyncGraph(Graph, AsyncGraphCommands): + """Async version for Graph""" + + async def _refresh_labels(self): + lbls = await self.labels() + + # Unpack data. + self._labels = [l[0] for _, l in enumerate(lbls)] + + async def _refresh_attributes(self): + props = await self.property_keys() + + # Unpack data. + self._properties = [p[0] for _, p in enumerate(props)] + + async def _refresh_relations(self): + rels = await self.relationship_types() + + # Unpack data. + self._relationship_types = [r[0] for _, r in enumerate(rels)] + + async def get_label(self, idx): + """ + Returns a label by it's index + + Args: + + idx: + The index of the label + """ + try: + label = self._labels[idx] + except IndexError: + # Refresh labels. + await self._refresh_labels() + label = self._labels[idx] + return label + + async def get_property(self, idx): + """ + Returns a property by it's index + + Args: + + idx: + The index of the property + """ + try: + p = self._properties[idx] + except IndexError: + # Refresh properties. + await self._refresh_attributes() + p = self._properties[idx] + return p + + async def get_relation(self, idx): + """ + Returns a relationship type by it's index + + Args: + + idx: + The index of the relation + """ + try: + relationship_type = self._relationship_types[idx] + except IndexError: + # Refresh relationship types. + await self._refresh_relations() + relationship_type = self._relationship_types[idx] + return relationship_type + + async def call_procedure(self, procedure, *args, read_only=False, **kwagrs): + args = [quote_string(arg) for arg in args] + q = f"CALL {procedure}({','.join(args)})" + + y = kwagrs.get("y", None) + if y is not None: + f"YIELD {','.join(y)}" + return await self.query(q, read_only=read_only) + + async def labels(self): + return ((await self.call_procedure(DB_LABELS, read_only=True))).result_set + + async def property_keys(self): + return (await self.call_procedure(DB_PROPERTYKEYS, read_only=True)).result_set + + async def relationship_types(self): + return ( + await self.call_procedure(DB_RAELATIONSHIPTYPES, read_only=True) + ).result_set diff --git a/.venv/Lib/site-packages/redis/commands/graph/__pycache__/__init__.cpython-311.pyc b/.venv/Lib/site-packages/redis/commands/graph/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 00000000..d83a08f0 Binary files /dev/null and b/.venv/Lib/site-packages/redis/commands/graph/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/redis/commands/graph/__pycache__/commands.cpython-311.pyc b/.venv/Lib/site-packages/redis/commands/graph/__pycache__/commands.cpython-311.pyc new file mode 100644 index 00000000..adee1ba2 Binary files /dev/null and b/.venv/Lib/site-packages/redis/commands/graph/__pycache__/commands.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/redis/commands/graph/__pycache__/edge.cpython-311.pyc b/.venv/Lib/site-packages/redis/commands/graph/__pycache__/edge.cpython-311.pyc new file mode 100644 index 00000000..c5685cb2 Binary files /dev/null and b/.venv/Lib/site-packages/redis/commands/graph/__pycache__/edge.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/redis/commands/graph/__pycache__/exceptions.cpython-311.pyc b/.venv/Lib/site-packages/redis/commands/graph/__pycache__/exceptions.cpython-311.pyc new file mode 100644 index 00000000..a4068807 Binary files /dev/null and b/.venv/Lib/site-packages/redis/commands/graph/__pycache__/exceptions.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/redis/commands/graph/__pycache__/execution_plan.cpython-311.pyc b/.venv/Lib/site-packages/redis/commands/graph/__pycache__/execution_plan.cpython-311.pyc new file mode 100644 index 00000000..3e3b4fff Binary files /dev/null and b/.venv/Lib/site-packages/redis/commands/graph/__pycache__/execution_plan.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/redis/commands/graph/__pycache__/node.cpython-311.pyc b/.venv/Lib/site-packages/redis/commands/graph/__pycache__/node.cpython-311.pyc new file mode 100644 index 00000000..71463bf4 Binary files /dev/null and b/.venv/Lib/site-packages/redis/commands/graph/__pycache__/node.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/redis/commands/graph/__pycache__/path.cpython-311.pyc b/.venv/Lib/site-packages/redis/commands/graph/__pycache__/path.cpython-311.pyc new file mode 100644 index 00000000..63e27175 Binary files /dev/null and b/.venv/Lib/site-packages/redis/commands/graph/__pycache__/path.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/redis/commands/graph/__pycache__/query_result.cpython-311.pyc b/.venv/Lib/site-packages/redis/commands/graph/__pycache__/query_result.cpython-311.pyc new file mode 100644 index 00000000..ba5a93a5 Binary files /dev/null and b/.venv/Lib/site-packages/redis/commands/graph/__pycache__/query_result.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/redis/commands/graph/commands.py b/.venv/Lib/site-packages/redis/commands/graph/commands.py new file mode 100644 index 00000000..762ab42e --- /dev/null +++ b/.venv/Lib/site-packages/redis/commands/graph/commands.py @@ -0,0 +1,313 @@ +from redis import DataError +from redis.exceptions import ResponseError + +from .exceptions import VersionMismatchException +from .execution_plan import ExecutionPlan +from .query_result import AsyncQueryResult, QueryResult + +PROFILE_CMD = "GRAPH.PROFILE" +RO_QUERY_CMD = "GRAPH.RO_QUERY" +QUERY_CMD = "GRAPH.QUERY" +DELETE_CMD = "GRAPH.DELETE" +SLOWLOG_CMD = "GRAPH.SLOWLOG" +CONFIG_CMD = "GRAPH.CONFIG" +LIST_CMD = "GRAPH.LIST" +EXPLAIN_CMD = "GRAPH.EXPLAIN" + + +class GraphCommands: + """RedisGraph Commands""" + + def commit(self): + """ + Create entire graph. + """ + if len(self.nodes) == 0 and len(self.edges) == 0: + return None + + query = "CREATE " + for _, node in self.nodes.items(): + query += str(node) + "," + + query += ",".join([str(edge) for edge in self.edges]) + + # Discard leading comma. + if query[-1] == ",": + query = query[:-1] + + return self.query(query) + + def query(self, q, params=None, timeout=None, read_only=False, profile=False): + """ + Executes a query against the graph. + For more information see `GRAPH.QUERY `_. # noqa + + Args: + + q : str + The query. + params : dict + Query parameters. + timeout : int + Maximum runtime for read queries in milliseconds. + read_only : bool + Executes a readonly query if set to True. + profile : bool + Return details on results produced by and time + spent in each operation. + """ + + # maintain original 'q' + query = q + + # handle query parameters + query = self._build_params_header(params) + query + + # construct query command + # ask for compact result-set format + # specify known graph version + if profile: + cmd = PROFILE_CMD + else: + cmd = RO_QUERY_CMD if read_only else QUERY_CMD + command = [cmd, self.name, query, "--compact"] + + # include timeout is specified + if isinstance(timeout, int): + command.extend(["timeout", timeout]) + elif timeout is not None: + raise Exception("Timeout argument must be a positive integer") + + # issue query + try: + response = self.execute_command(*command) + return QueryResult(self, response, profile) + except ResponseError as e: + if "unknown command" in str(e) and read_only: + # `GRAPH.RO_QUERY` is unavailable in older versions. + return self.query(q, params, timeout, read_only=False) + raise e + except VersionMismatchException as e: + # client view over the graph schema is out of sync + # set client version and refresh local schema + self.version = e.version + self._refresh_schema() + # re-issue query + return self.query(q, params, timeout, read_only) + + def merge(self, pattern): + """ + Merge pattern. + """ + query = "MERGE " + query += str(pattern) + + return self.query(query) + + def delete(self): + """ + Deletes graph. + For more information see `DELETE `_. # noqa + """ + self._clear_schema() + return self.execute_command(DELETE_CMD, self.name) + + # declared here, to override the built in redis.db.flush() + def flush(self): + """ + Commit the graph and reset the edges and the nodes to zero length. + """ + self.commit() + self.nodes = {} + self.edges = [] + + def bulk(self, **kwargs): + """Internal only. Not supported.""" + raise NotImplementedError( + "GRAPH.BULK is internal only. " + "Use https://github.com/redisgraph/redisgraph-bulk-loader." + ) + + def profile(self, query): + """ + Execute a query and produce an execution plan augmented with metrics + for each operation's execution. Return a string representation of a + query execution plan, with details on results produced by and time + spent in each operation. + For more information see `GRAPH.PROFILE `_. # noqa + """ + return self.query(query, profile=True) + + def slowlog(self): + """ + Get a list containing up to 10 of the slowest queries issued + against the given graph ID. + For more information see `GRAPH.SLOWLOG `_. # noqa + + Each item in the list has the following structure: + 1. A unix timestamp at which the log entry was processed. + 2. The issued command. + 3. The issued query. + 4. The amount of time needed for its execution, in milliseconds. + """ + return self.execute_command(SLOWLOG_CMD, self.name) + + def config(self, name, value=None, set=False): + """ + Retrieve or update a RedisGraph configuration. + For more information see `https://redis.io/commands/graph.config-get/>`_. # noqa + + Args: + + name : str + The name of the configuration + value : + The value we want to set (can be used only when `set` is on) + set : bool + Turn on to set a configuration. Default behavior is get. + """ + params = ["SET" if set else "GET", name] + if value is not None: + if set: + params.append(value) + else: + raise DataError( + "``value`` can be provided only when ``set`` is True" + ) # noqa + return self.execute_command(CONFIG_CMD, *params) + + def list_keys(self): + """ + Lists all graph keys in the keyspace. + For more information see `GRAPH.LIST `_. # noqa + """ + return self.execute_command(LIST_CMD) + + def execution_plan(self, query, params=None): + """ + Get the execution plan for given query, + GRAPH.EXPLAIN returns an array of operations. + + Args: + query: the query that will be executed + params: query parameters + """ + query = self._build_params_header(params) + query + + plan = self.execute_command(EXPLAIN_CMD, self.name, query) + if isinstance(plan[0], bytes): + plan = [b.decode() for b in plan] + return "\n".join(plan) + + def explain(self, query, params=None): + """ + Get the execution plan for given query, + GRAPH.EXPLAIN returns ExecutionPlan object. + For more information see `GRAPH.EXPLAIN `_. # noqa + + Args: + query: the query that will be executed + params: query parameters + """ + query = self._build_params_header(params) + query + + plan = self.execute_command(EXPLAIN_CMD, self.name, query) + return ExecutionPlan(plan) + + +class AsyncGraphCommands(GraphCommands): + async def query(self, q, params=None, timeout=None, read_only=False, profile=False): + """ + Executes a query against the graph. + For more information see `GRAPH.QUERY `_. # noqa + + Args: + + q : str + The query. + params : dict + Query parameters. + timeout : int + Maximum runtime for read queries in milliseconds. + read_only : bool + Executes a readonly query if set to True. + profile : bool + Return details on results produced by and time + spent in each operation. + """ + + # maintain original 'q' + query = q + + # handle query parameters + query = self._build_params_header(params) + query + + # construct query command + # ask for compact result-set format + # specify known graph version + if profile: + cmd = PROFILE_CMD + else: + cmd = RO_QUERY_CMD if read_only else QUERY_CMD + command = [cmd, self.name, query, "--compact"] + + # include timeout is specified + if isinstance(timeout, int): + command.extend(["timeout", timeout]) + elif timeout is not None: + raise Exception("Timeout argument must be a positive integer") + + # issue query + try: + response = await self.execute_command(*command) + return await AsyncQueryResult().initialize(self, response, profile) + except ResponseError as e: + if "unknown command" in str(e) and read_only: + # `GRAPH.RO_QUERY` is unavailable in older versions. + return await self.query(q, params, timeout, read_only=False) + raise e + except VersionMismatchException as e: + # client view over the graph schema is out of sync + # set client version and refresh local schema + self.version = e.version + self._refresh_schema() + # re-issue query + return await self.query(q, params, timeout, read_only) + + async def execution_plan(self, query, params=None): + """ + Get the execution plan for given query, + GRAPH.EXPLAIN returns an array of operations. + + Args: + query: the query that will be executed + params: query parameters + """ + query = self._build_params_header(params) + query + + plan = await self.execute_command(EXPLAIN_CMD, self.name, query) + if isinstance(plan[0], bytes): + plan = [b.decode() for b in plan] + return "\n".join(plan) + + async def explain(self, query, params=None): + """ + Get the execution plan for given query, + GRAPH.EXPLAIN returns ExecutionPlan object. + + Args: + query: the query that will be executed + params: query parameters + """ + query = self._build_params_header(params) + query + + plan = await self.execute_command(EXPLAIN_CMD, self.name, query) + return ExecutionPlan(plan) + + async def flush(self): + """ + Commit the graph and reset the edges and the nodes to zero length. + """ + await self.commit() + self.nodes = {} + self.edges = [] diff --git a/.venv/Lib/site-packages/redis/commands/graph/edge.py b/.venv/Lib/site-packages/redis/commands/graph/edge.py new file mode 100644 index 00000000..6ee195f1 --- /dev/null +++ b/.venv/Lib/site-packages/redis/commands/graph/edge.py @@ -0,0 +1,91 @@ +from ..helpers import quote_string +from .node import Node + + +class Edge: + """ + An edge connecting two nodes. + """ + + def __init__(self, src_node, relation, dest_node, edge_id=None, properties=None): + """ + Create a new edge. + """ + if src_node is None or dest_node is None: + # NOTE(bors-42): It makes sense to change AssertionError to + # ValueError here + raise AssertionError("Both src_node & dest_node must be provided") + + self.id = edge_id + self.relation = relation or "" + self.properties = properties or {} + self.src_node = src_node + self.dest_node = dest_node + + def to_string(self): + res = "" + if self.properties: + props = ",".join( + key + ":" + str(quote_string(val)) + for key, val in sorted(self.properties.items()) + ) + res += "{" + props + "}" + + return res + + def __str__(self): + # Source node. + if isinstance(self.src_node, Node): + res = str(self.src_node) + else: + res = "()" + + # Edge + res += "-[" + if self.relation: + res += ":" + self.relation + if self.properties: + props = ",".join( + key + ":" + str(quote_string(val)) + for key, val in sorted(self.properties.items()) + ) + res += "{" + props + "}" + res += "]->" + + # Dest node. + if isinstance(self.dest_node, Node): + res += str(self.dest_node) + else: + res += "()" + + return res + + def __eq__(self, rhs): + # Type checking + if not isinstance(rhs, Edge): + return False + + # Quick positive check, if both IDs are set. + if self.id is not None and rhs.id is not None and self.id == rhs.id: + return True + + # Source and destination nodes should match. + if self.src_node != rhs.src_node: + return False + + if self.dest_node != rhs.dest_node: + return False + + # Relation should match. + if self.relation != rhs.relation: + return False + + # Quick check for number of properties. + if len(self.properties) != len(rhs.properties): + return False + + # Compare properties. + if self.properties != rhs.properties: + return False + + return True diff --git a/.venv/Lib/site-packages/redis/commands/graph/exceptions.py b/.venv/Lib/site-packages/redis/commands/graph/exceptions.py new file mode 100644 index 00000000..4bbac100 --- /dev/null +++ b/.venv/Lib/site-packages/redis/commands/graph/exceptions.py @@ -0,0 +1,3 @@ +class VersionMismatchException(Exception): + def __init__(self, version): + self.version = version diff --git a/.venv/Lib/site-packages/redis/commands/graph/execution_plan.py b/.venv/Lib/site-packages/redis/commands/graph/execution_plan.py new file mode 100644 index 00000000..179a80cc --- /dev/null +++ b/.venv/Lib/site-packages/redis/commands/graph/execution_plan.py @@ -0,0 +1,211 @@ +import re + + +class ProfileStats: + """ + ProfileStats, runtime execution statistics of operation. + """ + + def __init__(self, records_produced, execution_time): + self.records_produced = records_produced + self.execution_time = execution_time + + +class Operation: + """ + Operation, single operation within execution plan. + """ + + def __init__(self, name, args=None, profile_stats=None): + """ + Create a new operation. + + Args: + name: string that represents the name of the operation + args: operation arguments + profile_stats: profile statistics + """ + self.name = name + self.args = args + self.profile_stats = profile_stats + self.children = [] + + def append_child(self, child): + if not isinstance(child, Operation) or self is child: + raise Exception("child must be Operation") + + self.children.append(child) + return self + + def child_count(self): + return len(self.children) + + def __eq__(self, o: object) -> bool: + if not isinstance(o, Operation): + return False + + return self.name == o.name and self.args == o.args + + def __str__(self) -> str: + args_str = "" if self.args is None else " | " + self.args + return f"{self.name}{args_str}" + + +class ExecutionPlan: + """ + ExecutionPlan, collection of operations. + """ + + def __init__(self, plan): + """ + Create a new execution plan. + + Args: + plan: array of strings that represents the collection operations + the output from GRAPH.EXPLAIN + """ + if not isinstance(plan, list): + raise Exception("plan must be an array") + + if isinstance(plan[0], bytes): + plan = [b.decode() for b in plan] + + self.plan = plan + self.structured_plan = self._operation_tree() + + def _compare_operations(self, root_a, root_b): + """ + Compare execution plan operation tree + + Return: True if operation trees are equal, False otherwise + """ + + # compare current root + if root_a != root_b: + return False + + # make sure root have the same number of children + if root_a.child_count() != root_b.child_count(): + return False + + # recursively compare children + for i in range(root_a.child_count()): + if not self._compare_operations(root_a.children[i], root_b.children[i]): + return False + + return True + + def __str__(self) -> str: + def aggraget_str(str_children): + return "\n".join( + [ + " " + line + for str_child in str_children + for line in str_child.splitlines() + ] + ) + + def combine_str(x, y): + return f"{x}\n{y}" + + return self._operation_traverse( + self.structured_plan, str, aggraget_str, combine_str + ) + + def __eq__(self, o: object) -> bool: + """Compares two execution plans + + Return: True if the two plans are equal False otherwise + """ + # make sure 'o' is an execution-plan + if not isinstance(o, ExecutionPlan): + return False + + # get root for both plans + root_a = self.structured_plan + root_b = o.structured_plan + + # compare execution trees + return self._compare_operations(root_a, root_b) + + def _operation_traverse(self, op, op_f, aggregate_f, combine_f): + """ + Traverse operation tree recursively applying functions + + Args: + op: operation to traverse + op_f: function applied for each operation + aggregate_f: aggregation function applied for all children of a single operation + combine_f: combine function applied for the operation result and the children result + """ # noqa + # apply op_f for each operation + op_res = op_f(op) + if len(op.children) == 0: + return op_res # no children return + else: + # apply _operation_traverse recursively + children = [ + self._operation_traverse(child, op_f, aggregate_f, combine_f) + for child in op.children + ] + # combine the operation result with the children aggregated result + return combine_f(op_res, aggregate_f(children)) + + def _operation_tree(self): + """Build the operation tree from the string representation""" + + # initial state + i = 0 + level = 0 + stack = [] + current = None + + def _create_operation(args): + profile_stats = None + name = args[0].strip() + args.pop(0) + if len(args) > 0 and "Records produced" in args[-1]: + records_produced = int( + re.search("Records produced: (\\d+)", args[-1]).group(1) + ) + execution_time = float( + re.search("Execution time: (\\d+.\\d+) ms", args[-1]).group(1) + ) + profile_stats = ProfileStats(records_produced, execution_time) + args.pop(-1) + return Operation( + name, None if len(args) == 0 else args[0].strip(), profile_stats + ) + + # iterate plan operations + while i < len(self.plan): + current_op = self.plan[i] + op_level = current_op.count(" ") + if op_level == level: + # if the operation level equal to the current level + # set the current operation and move next + child = _create_operation(current_op.split("|")) + if current: + current = stack.pop() + current.append_child(child) + current = child + i += 1 + elif op_level == level + 1: + # if the operation is child of the current operation + # add it as child and set as current operation + child = _create_operation(current_op.split("|")) + current.append_child(child) + stack.append(current) + current = child + level += 1 + i += 1 + elif op_level < level: + # if the operation is not child of current operation + # go back to it's parent operation + levels_back = level - op_level + 1 + for _ in range(levels_back): + current = stack.pop() + level -= levels_back + else: + raise Exception("corrupted plan") + return stack[0] diff --git a/.venv/Lib/site-packages/redis/commands/graph/node.py b/.venv/Lib/site-packages/redis/commands/graph/node.py new file mode 100644 index 00000000..4546a393 --- /dev/null +++ b/.venv/Lib/site-packages/redis/commands/graph/node.py @@ -0,0 +1,88 @@ +from ..helpers import quote_string + + +class Node: + """ + A node within the graph. + """ + + def __init__(self, node_id=None, alias=None, label=None, properties=None): + """ + Create a new node. + """ + self.id = node_id + self.alias = alias + if isinstance(label, list): + label = [inner_label for inner_label in label if inner_label != ""] + + if ( + label is None + or label == "" + or (isinstance(label, list) and len(label) == 0) + ): + self.label = None + self.labels = None + elif isinstance(label, str): + self.label = label + self.labels = [label] + elif isinstance(label, list) and all( + [isinstance(inner_label, str) for inner_label in label] + ): + self.label = label[0] + self.labels = label + else: + raise AssertionError( + "label should be either None, string or a list of strings" + ) + + self.properties = properties or {} + + def to_string(self): + res = "" + if self.properties: + props = ",".join( + key + ":" + str(quote_string(val)) + for key, val in sorted(self.properties.items()) + ) + res += "{" + props + "}" + + return res + + def __str__(self): + res = "(" + if self.alias: + res += self.alias + if self.labels: + res += ":" + ":".join(self.labels) + if self.properties: + props = ",".join( + key + ":" + str(quote_string(val)) + for key, val in sorted(self.properties.items()) + ) + res += "{" + props + "}" + res += ")" + + return res + + def __eq__(self, rhs): + # Type checking + if not isinstance(rhs, Node): + return False + + # Quick positive check, if both IDs are set. + if self.id is not None and rhs.id is not None and self.id != rhs.id: + return False + + # Label should match. + if self.label != rhs.label: + return False + + # Quick check for number of properties. + if len(self.properties) != len(rhs.properties): + return False + + # Compare properties. + if self.properties != rhs.properties: + return False + + return True diff --git a/.venv/Lib/site-packages/redis/commands/graph/path.py b/.venv/Lib/site-packages/redis/commands/graph/path.py new file mode 100644 index 00000000..ee22dc8c --- /dev/null +++ b/.venv/Lib/site-packages/redis/commands/graph/path.py @@ -0,0 +1,78 @@ +from .edge import Edge +from .node import Node + + +class Path: + def __init__(self, nodes, edges): + if not (isinstance(nodes, list) and isinstance(edges, list)): + raise TypeError("nodes and edges must be list") + + self._nodes = nodes + self._edges = edges + self.append_type = Node + + @classmethod + def new_empty_path(cls): + return cls([], []) + + def nodes(self): + return self._nodes + + def edges(self): + return self._edges + + def get_node(self, index): + return self._nodes[index] + + def get_relationship(self, index): + return self._edges[index] + + def first_node(self): + return self._nodes[0] + + def last_node(self): + return self._nodes[-1] + + def edge_count(self): + return len(self._edges) + + def nodes_count(self): + return len(self._nodes) + + def add_node(self, node): + if not isinstance(node, self.append_type): + raise AssertionError("Add Edge before adding Node") + self._nodes.append(node) + self.append_type = Edge + return self + + def add_edge(self, edge): + if not isinstance(edge, self.append_type): + raise AssertionError("Add Node before adding Edge") + self._edges.append(edge) + self.append_type = Node + return self + + def __eq__(self, other): + # Type checking + if not isinstance(other, Path): + return False + + return self.nodes() == other.nodes() and self.edges() == other.edges() + + def __str__(self): + res = "<" + edge_count = self.edge_count() + for i in range(0, edge_count): + node_id = self.get_node(i).id + res += "(" + str(node_id) + ")" + edge = self.get_relationship(i) + res += ( + "-[" + str(int(edge.id)) + "]->" + if edge.src_node == node_id + else "<-[" + str(int(edge.id)) + "]-" + ) + node_id = self.get_node(edge_count).id + res += "(" + str(node_id) + ")" + res += ">" + return res diff --git a/.venv/Lib/site-packages/redis/commands/graph/query_result.py b/.venv/Lib/site-packages/redis/commands/graph/query_result.py new file mode 100644 index 00000000..7c7f58b9 --- /dev/null +++ b/.venv/Lib/site-packages/redis/commands/graph/query_result.py @@ -0,0 +1,573 @@ +import sys +from collections import OrderedDict +from distutils.util import strtobool + +# from prettytable import PrettyTable +from redis import ResponseError + +from .edge import Edge +from .exceptions import VersionMismatchException +from .node import Node +from .path import Path + +LABELS_ADDED = "Labels added" +LABELS_REMOVED = "Labels removed" +NODES_CREATED = "Nodes created" +NODES_DELETED = "Nodes deleted" +RELATIONSHIPS_DELETED = "Relationships deleted" +PROPERTIES_SET = "Properties set" +PROPERTIES_REMOVED = "Properties removed" +RELATIONSHIPS_CREATED = "Relationships created" +INDICES_CREATED = "Indices created" +INDICES_DELETED = "Indices deleted" +CACHED_EXECUTION = "Cached execution" +INTERNAL_EXECUTION_TIME = "internal execution time" + +STATS = [ + LABELS_ADDED, + LABELS_REMOVED, + NODES_CREATED, + PROPERTIES_SET, + PROPERTIES_REMOVED, + RELATIONSHIPS_CREATED, + NODES_DELETED, + RELATIONSHIPS_DELETED, + INDICES_CREATED, + INDICES_DELETED, + CACHED_EXECUTION, + INTERNAL_EXECUTION_TIME, +] + + +class ResultSetColumnTypes: + COLUMN_UNKNOWN = 0 + COLUMN_SCALAR = 1 + COLUMN_NODE = 2 # Unused as of RedisGraph v2.1.0, retained for backwards compatibility. # noqa + COLUMN_RELATION = 3 # Unused as of RedisGraph v2.1.0, retained for backwards compatibility. # noqa + + +class ResultSetScalarTypes: + VALUE_UNKNOWN = 0 + VALUE_NULL = 1 + VALUE_STRING = 2 + VALUE_INTEGER = 3 + VALUE_BOOLEAN = 4 + VALUE_DOUBLE = 5 + VALUE_ARRAY = 6 + VALUE_EDGE = 7 + VALUE_NODE = 8 + VALUE_PATH = 9 + VALUE_MAP = 10 + VALUE_POINT = 11 + + +class QueryResult: + def __init__(self, graph, response, profile=False): + """ + A class that represents a result of the query operation. + + Args: + + graph: + The graph on which the query was executed. + response: + The response from the server. + profile: + A boolean indicating if the query command was "GRAPH.PROFILE" + """ + self.graph = graph + self.header = [] + self.result_set = [] + + # in case of an error an exception will be raised + self._check_for_errors(response) + + if len(response) == 1: + self.parse_statistics(response[0]) + elif profile: + self.parse_profile(response) + else: + # start by parsing statistics, matches the one we have + self.parse_statistics(response[-1]) # Last element. + self.parse_results(response) + + def _check_for_errors(self, response): + """ + Check if the response contains an error. + """ + if isinstance(response[0], ResponseError): + error = response[0] + if str(error) == "version mismatch": + version = response[1] + error = VersionMismatchException(version) + raise error + + # If we encountered a run-time error, the last response + # element will be an exception + if isinstance(response[-1], ResponseError): + raise response[-1] + + def parse_results(self, raw_result_set): + """ + Parse the query execution result returned from the server. + """ + self.header = self.parse_header(raw_result_set) + + # Empty header. + if len(self.header) == 0: + return + + self.result_set = self.parse_records(raw_result_set) + + def parse_statistics(self, raw_statistics): + """ + Parse the statistics returned in the response. + """ + self.statistics = {} + + # decode statistics + for idx, stat in enumerate(raw_statistics): + if isinstance(stat, bytes): + raw_statistics[idx] = stat.decode() + + for s in STATS: + v = self._get_value(s, raw_statistics) + if v is not None: + self.statistics[s] = v + + def parse_header(self, raw_result_set): + """ + Parse the header of the result. + """ + # An array of column name/column type pairs. + header = raw_result_set[0] + return header + + def parse_records(self, raw_result_set): + """ + Parses the result set and returns a list of records. + """ + records = [ + [ + self.parse_record_types[self.header[idx][0]](cell) + for idx, cell in enumerate(row) + ] + for row in raw_result_set[1] + ] + + return records + + def parse_entity_properties(self, props): + """ + Parse node / edge properties. + """ + # [[name, value type, value] X N] + properties = {} + for prop in props: + prop_name = self.graph.get_property(prop[0]) + prop_value = self.parse_scalar(prop[1:]) + properties[prop_name] = prop_value + + return properties + + def parse_string(self, cell): + """ + Parse the cell as a string. + """ + if isinstance(cell, bytes): + return cell.decode() + elif not isinstance(cell, str): + return str(cell) + else: + return cell + + def parse_node(self, cell): + """ + Parse the cell to a node. + """ + # Node ID (integer), + # [label string offset (integer)], + # [[name, value type, value] X N] + + node_id = int(cell[0]) + labels = None + if len(cell[1]) > 0: + labels = [] + for inner_label in cell[1]: + labels.append(self.graph.get_label(inner_label)) + properties = self.parse_entity_properties(cell[2]) + return Node(node_id=node_id, label=labels, properties=properties) + + def parse_edge(self, cell): + """ + Parse the cell to an edge. + """ + # Edge ID (integer), + # reltype string offset (integer), + # src node ID offset (integer), + # dest node ID offset (integer), + # [[name, value, value type] X N] + + edge_id = int(cell[0]) + relation = self.graph.get_relation(cell[1]) + src_node_id = int(cell[2]) + dest_node_id = int(cell[3]) + properties = self.parse_entity_properties(cell[4]) + return Edge( + src_node_id, relation, dest_node_id, edge_id=edge_id, properties=properties + ) + + def parse_path(self, cell): + """ + Parse the cell to a path. + """ + nodes = self.parse_scalar(cell[0]) + edges = self.parse_scalar(cell[1]) + return Path(nodes, edges) + + def parse_map(self, cell): + """ + Parse the cell as a map. + """ + m = OrderedDict() + n_entries = len(cell) + + # A map is an array of key value pairs. + # 1. key (string) + # 2. array: (value type, value) + for i in range(0, n_entries, 2): + key = self.parse_string(cell[i]) + m[key] = self.parse_scalar(cell[i + 1]) + + return m + + def parse_point(self, cell): + """ + Parse the cell to point. + """ + p = {} + # A point is received an array of the form: [latitude, longitude] + # It is returned as a map of the form: {"latitude": latitude, "longitude": longitude} # noqa + p["latitude"] = float(cell[0]) + p["longitude"] = float(cell[1]) + return p + + def parse_null(self, cell): + """ + Parse a null value. + """ + return None + + def parse_integer(self, cell): + """ + Parse the integer value from the cell. + """ + return int(cell) + + def parse_boolean(self, value): + """ + Parse the cell value as a boolean. + """ + value = value.decode() if isinstance(value, bytes) else value + try: + scalar = True if strtobool(value) else False + except ValueError: + sys.stderr.write("unknown boolean type\n") + scalar = None + return scalar + + def parse_double(self, cell): + """ + Parse the cell as a double. + """ + return float(cell) + + def parse_array(self, value): + """ + Parse an array of values. + """ + scalar = [self.parse_scalar(value[i]) for i in range(len(value))] + return scalar + + def parse_unknown(self, cell): + """ + Parse a cell of unknown type. + """ + sys.stderr.write("Unknown type\n") + return None + + def parse_scalar(self, cell): + """ + Parse a scalar value from a cell in the result set. + """ + scalar_type = int(cell[0]) + value = cell[1] + scalar = self.parse_scalar_types[scalar_type](value) + + return scalar + + def parse_profile(self, response): + self.result_set = [x[0 : x.index(",")].strip() for x in response] + + def is_empty(self): + return len(self.result_set) == 0 + + @staticmethod + def _get_value(prop, statistics): + for stat in statistics: + if prop in stat: + return float(stat.split(": ")[1].split(" ")[0]) + + return None + + def _get_stat(self, stat): + return self.statistics[stat] if stat in self.statistics else 0 + + @property + def labels_added(self): + """Returns the number of labels added in the query""" + return self._get_stat(LABELS_ADDED) + + @property + def labels_removed(self): + """Returns the number of labels removed in the query""" + return self._get_stat(LABELS_REMOVED) + + @property + def nodes_created(self): + """Returns the number of nodes created in the query""" + return self._get_stat(NODES_CREATED) + + @property + def nodes_deleted(self): + """Returns the number of nodes deleted in the query""" + return self._get_stat(NODES_DELETED) + + @property + def properties_set(self): + """Returns the number of properties set in the query""" + return self._get_stat(PROPERTIES_SET) + + @property + def properties_removed(self): + """Returns the number of properties removed in the query""" + return self._get_stat(PROPERTIES_REMOVED) + + @property + def relationships_created(self): + """Returns the number of relationships created in the query""" + return self._get_stat(RELATIONSHIPS_CREATED) + + @property + def relationships_deleted(self): + """Returns the number of relationships deleted in the query""" + return self._get_stat(RELATIONSHIPS_DELETED) + + @property + def indices_created(self): + """Returns the number of indices created in the query""" + return self._get_stat(INDICES_CREATED) + + @property + def indices_deleted(self): + """Returns the number of indices deleted in the query""" + return self._get_stat(INDICES_DELETED) + + @property + def cached_execution(self): + """Returns whether or not the query execution plan was cached""" + return self._get_stat(CACHED_EXECUTION) == 1 + + @property + def run_time_ms(self): + """Returns the server execution time of the query""" + return self._get_stat(INTERNAL_EXECUTION_TIME) + + @property + def parse_scalar_types(self): + return { + ResultSetScalarTypes.VALUE_NULL: self.parse_null, + ResultSetScalarTypes.VALUE_STRING: self.parse_string, + ResultSetScalarTypes.VALUE_INTEGER: self.parse_integer, + ResultSetScalarTypes.VALUE_BOOLEAN: self.parse_boolean, + ResultSetScalarTypes.VALUE_DOUBLE: self.parse_double, + ResultSetScalarTypes.VALUE_ARRAY: self.parse_array, + ResultSetScalarTypes.VALUE_NODE: self.parse_node, + ResultSetScalarTypes.VALUE_EDGE: self.parse_edge, + ResultSetScalarTypes.VALUE_PATH: self.parse_path, + ResultSetScalarTypes.VALUE_MAP: self.parse_map, + ResultSetScalarTypes.VALUE_POINT: self.parse_point, + ResultSetScalarTypes.VALUE_UNKNOWN: self.parse_unknown, + } + + @property + def parse_record_types(self): + return { + ResultSetColumnTypes.COLUMN_SCALAR: self.parse_scalar, + ResultSetColumnTypes.COLUMN_NODE: self.parse_node, + ResultSetColumnTypes.COLUMN_RELATION: self.parse_edge, + ResultSetColumnTypes.COLUMN_UNKNOWN: self.parse_unknown, + } + + +class AsyncQueryResult(QueryResult): + """ + Async version for the QueryResult class - a class that + represents a result of the query operation. + """ + + def __init__(self): + """ + To init the class you must call self.initialize() + """ + pass + + async def initialize(self, graph, response, profile=False): + """ + Initializes the class. + Args: + + graph: + The graph on which the query was executed. + response: + The response from the server. + profile: + A boolean indicating if the query command was "GRAPH.PROFILE" + """ + self.graph = graph + self.header = [] + self.result_set = [] + + # in case of an error an exception will be raised + self._check_for_errors(response) + + if len(response) == 1: + self.parse_statistics(response[0]) + elif profile: + self.parse_profile(response) + else: + # start by parsing statistics, matches the one we have + self.parse_statistics(response[-1]) # Last element. + await self.parse_results(response) + + return self + + async def parse_node(self, cell): + """ + Parses a node from the cell. + """ + # Node ID (integer), + # [label string offset (integer)], + # [[name, value type, value] X N] + + labels = None + if len(cell[1]) > 0: + labels = [] + for inner_label in cell[1]: + labels.append(await self.graph.get_label(inner_label)) + properties = await self.parse_entity_properties(cell[2]) + node_id = int(cell[0]) + return Node(node_id=node_id, label=labels, properties=properties) + + async def parse_scalar(self, cell): + """ + Parses a scalar value from the server response. + """ + scalar_type = int(cell[0]) + value = cell[1] + try: + scalar = await self.parse_scalar_types[scalar_type](value) + except TypeError: + # Not all of the functions are async + scalar = self.parse_scalar_types[scalar_type](value) + + return scalar + + async def parse_records(self, raw_result_set): + """ + Parses the result set and returns a list of records. + """ + records = [] + for row in raw_result_set[1]: + record = [ + await self.parse_record_types[self.header[idx][0]](cell) + for idx, cell in enumerate(row) + ] + records.append(record) + + return records + + async def parse_results(self, raw_result_set): + """ + Parse the query execution result returned from the server. + """ + self.header = self.parse_header(raw_result_set) + + # Empty header. + if len(self.header) == 0: + return + + self.result_set = await self.parse_records(raw_result_set) + + async def parse_entity_properties(self, props): + """ + Parse node / edge properties. + """ + # [[name, value type, value] X N] + properties = {} + for prop in props: + prop_name = await self.graph.get_property(prop[0]) + prop_value = await self.parse_scalar(prop[1:]) + properties[prop_name] = prop_value + + return properties + + async def parse_edge(self, cell): + """ + Parse the cell to an edge. + """ + # Edge ID (integer), + # reltype string offset (integer), + # src node ID offset (integer), + # dest node ID offset (integer), + # [[name, value, value type] X N] + + edge_id = int(cell[0]) + relation = await self.graph.get_relation(cell[1]) + src_node_id = int(cell[2]) + dest_node_id = int(cell[3]) + properties = await self.parse_entity_properties(cell[4]) + return Edge( + src_node_id, relation, dest_node_id, edge_id=edge_id, properties=properties + ) + + async def parse_path(self, cell): + """ + Parse the cell to a path. + """ + nodes = await self.parse_scalar(cell[0]) + edges = await self.parse_scalar(cell[1]) + return Path(nodes, edges) + + async def parse_map(self, cell): + """ + Parse the cell to a map. + """ + m = OrderedDict() + n_entries = len(cell) + + # A map is an array of key value pairs. + # 1. key (string) + # 2. array: (value type, value) + for i in range(0, n_entries, 2): + key = self.parse_string(cell[i]) + m[key] = await self.parse_scalar(cell[i + 1]) + + return m + + async def parse_array(self, value): + """ + Parse array value. + """ + scalar = [await self.parse_scalar(value[i]) for i in range(len(value))] + return scalar diff --git a/.venv/Lib/site-packages/redis/commands/helpers.py b/.venv/Lib/site-packages/redis/commands/helpers.py new file mode 100644 index 00000000..324d981d --- /dev/null +++ b/.venv/Lib/site-packages/redis/commands/helpers.py @@ -0,0 +1,166 @@ +import copy +import random +import string +from typing import List, Tuple + +import redis +from redis.typing import KeysT, KeyT + + +def list_or_args(keys: KeysT, args: Tuple[KeyT, ...]) -> List[KeyT]: + # returns a single new list combining keys and args + try: + iter(keys) + # a string or bytes instance can be iterated, but indicates + # keys wasn't passed as a list + if isinstance(keys, (bytes, str)): + keys = [keys] + else: + keys = list(keys) + except TypeError: + keys = [keys] + if args: + keys.extend(args) + return keys + + +def nativestr(x): + """Return the decoded binary string, or a string, depending on type.""" + r = x.decode("utf-8", "replace") if isinstance(x, bytes) else x + if r == "null": + return + return r + + +def delist(x): + """Given a list of binaries, return the stringified version.""" + if x is None: + return x + return [nativestr(obj) for obj in x] + + +def parse_to_list(response): + """Optimistically parse the response to a list.""" + res = [] + + if response is None: + return res + + for item in response: + try: + res.append(int(item)) + except ValueError: + try: + res.append(float(item)) + except ValueError: + res.append(nativestr(item)) + except TypeError: + res.append(None) + return res + + +def parse_list_to_dict(response): + res = {} + for i in range(0, len(response), 2): + if isinstance(response[i], list): + res["Child iterators"].append(parse_list_to_dict(response[i])) + elif isinstance(response[i + 1], list): + res["Child iterators"] = [parse_list_to_dict(response[i + 1])] + else: + try: + res[response[i]] = float(response[i + 1]) + except (TypeError, ValueError): + res[response[i]] = response[i + 1] + return res + + +def parse_to_dict(response): + if response is None: + return {} + + res = {} + for det in response: + if isinstance(det[1], list): + res[det[0]] = parse_list_to_dict(det[1]) + else: + try: # try to set the attribute. may be provided without value + try: # try to convert the value to float + res[det[0]] = float(det[1]) + except (TypeError, ValueError): + res[det[0]] = det[1] + except IndexError: + pass + return res + + +def random_string(length=10): + """ + Returns a random N character long string. + """ + return "".join( # nosec + random.choice(string.ascii_lowercase) for x in range(length) + ) + + +def quote_string(v): + """ + RedisGraph strings must be quoted, + quote_string wraps given v with quotes incase + v is a string. + """ + + if isinstance(v, bytes): + v = v.decode() + elif not isinstance(v, str): + return v + if len(v) == 0: + return '""' + + v = v.replace("\\", "\\\\") + v = v.replace('"', '\\"') + + return f'"{v}"' + + +def decode_dict_keys(obj): + """Decode the keys of the given dictionary with utf-8.""" + newobj = copy.copy(obj) + for k in obj.keys(): + if isinstance(k, bytes): + newobj[k.decode("utf-8")] = newobj[k] + newobj.pop(k) + return newobj + + +def stringify_param_value(value): + """ + Turn a parameter value into a string suitable for the params header of + a Cypher command. + You may pass any value that would be accepted by `json.dumps()`. + + Ways in which output differs from that of `str()`: + * Strings are quoted. + * None --> "null". + * In dictionaries, keys are _not_ quoted. + + :param value: The parameter value to be turned into a string. + :return: string + """ + + if isinstance(value, str): + return quote_string(value) + elif value is None: + return "null" + elif isinstance(value, (list, tuple)): + return f'[{",".join(map(stringify_param_value, value))}]' + elif isinstance(value, dict): + return f'{{{",".join(f"{k}:{stringify_param_value(v)}" for k, v in value.items())}}}' # noqa + else: + return str(value) + + +def get_protocol_version(client): + if isinstance(client, redis.Redis) or isinstance(client, redis.asyncio.Redis): + return client.connection_pool.connection_kwargs.get("protocol") + elif isinstance(client, redis.cluster.AbstractRedisCluster): + return client.nodes_manager.connection_kwargs.get("protocol") diff --git a/.venv/Lib/site-packages/redis/commands/json/__init__.py b/.venv/Lib/site-packages/redis/commands/json/__init__.py new file mode 100644 index 00000000..01077e6b --- /dev/null +++ b/.venv/Lib/site-packages/redis/commands/json/__init__.py @@ -0,0 +1,147 @@ +from json import JSONDecodeError, JSONDecoder, JSONEncoder + +import redis + +from ..helpers import get_protocol_version, nativestr +from .commands import JSONCommands +from .decoders import bulk_of_jsons, decode_list + + +class JSON(JSONCommands): + """ + Create a client for talking to json. + + :param decoder: + :type json.JSONDecoder: An instance of json.JSONDecoder + + :param encoder: + :type json.JSONEncoder: An instance of json.JSONEncoder + """ + + def __init__( + self, client, version=None, decoder=JSONDecoder(), encoder=JSONEncoder() + ): + """ + Create a client for talking to json. + + :param decoder: + :type json.JSONDecoder: An instance of json.JSONDecoder + + :param encoder: + :type json.JSONEncoder: An instance of json.JSONEncoder + """ + # Set the module commands' callbacks + self._MODULE_CALLBACKS = { + "JSON.ARRPOP": self._decode, + "JSON.DEBUG": self._decode, + "JSON.GET": self._decode, + "JSON.MERGE": lambda r: r and nativestr(r) == "OK", + "JSON.MGET": bulk_of_jsons(self._decode), + "JSON.MSET": lambda r: r and nativestr(r) == "OK", + "JSON.RESP": self._decode, + "JSON.SET": lambda r: r and nativestr(r) == "OK", + "JSON.TOGGLE": self._decode, + } + + _RESP2_MODULE_CALLBACKS = { + "JSON.ARRAPPEND": self._decode, + "JSON.ARRINDEX": self._decode, + "JSON.ARRINSERT": self._decode, + "JSON.ARRLEN": self._decode, + "JSON.ARRTRIM": self._decode, + "JSON.CLEAR": int, + "JSON.DEL": int, + "JSON.FORGET": int, + "JSON.GET": self._decode, + "JSON.NUMINCRBY": self._decode, + "JSON.NUMMULTBY": self._decode, + "JSON.OBJKEYS": self._decode, + "JSON.STRAPPEND": self._decode, + "JSON.OBJLEN": self._decode, + "JSON.STRLEN": self._decode, + "JSON.TOGGLE": self._decode, + } + + _RESP3_MODULE_CALLBACKS = {} + + self.client = client + self.execute_command = client.execute_command + self.MODULE_VERSION = version + + if get_protocol_version(self.client) in ["3", 3]: + self._MODULE_CALLBACKS.update(_RESP3_MODULE_CALLBACKS) + else: + self._MODULE_CALLBACKS.update(_RESP2_MODULE_CALLBACKS) + + for key, value in self._MODULE_CALLBACKS.items(): + self.client.set_response_callback(key, value) + + self.__encoder__ = encoder + self.__decoder__ = decoder + + def _decode(self, obj): + """Get the decoder.""" + if obj is None: + return obj + + try: + x = self.__decoder__.decode(obj) + if x is None: + raise TypeError + return x + except TypeError: + try: + return self.__decoder__.decode(obj.decode()) + except AttributeError: + return decode_list(obj) + except (AttributeError, JSONDecodeError): + return decode_list(obj) + + def _encode(self, obj): + """Get the encoder.""" + return self.__encoder__.encode(obj) + + def pipeline(self, transaction=True, shard_hint=None): + """Creates a pipeline for the JSON module, that can be used for executing + JSON commands, as well as classic core commands. + + Usage example: + + r = redis.Redis() + pipe = r.json().pipeline() + pipe.jsonset('foo', '.', {'hello!': 'world'}) + pipe.jsonget('foo') + pipe.jsonget('notakey') + """ + if isinstance(self.client, redis.RedisCluster): + p = ClusterPipeline( + nodes_manager=self.client.nodes_manager, + commands_parser=self.client.commands_parser, + startup_nodes=self.client.nodes_manager.startup_nodes, + result_callbacks=self.client.result_callbacks, + cluster_response_callbacks=self.client.cluster_response_callbacks, + cluster_error_retry_attempts=self.client.cluster_error_retry_attempts, + read_from_replicas=self.client.read_from_replicas, + reinitialize_steps=self.client.reinitialize_steps, + lock=self.client._lock, + ) + + else: + p = Pipeline( + connection_pool=self.client.connection_pool, + response_callbacks=self._MODULE_CALLBACKS, + transaction=transaction, + shard_hint=shard_hint, + ) + + p._encode = self._encode + p._decode = self._decode + return p + + +class ClusterPipeline(JSONCommands, redis.cluster.ClusterPipeline): + """Cluster pipeline for the module.""" + + +class Pipeline(JSONCommands, redis.client.Pipeline): + """Pipeline for the module.""" diff --git a/.venv/Lib/site-packages/redis/commands/json/__pycache__/__init__.cpython-311.pyc b/.venv/Lib/site-packages/redis/commands/json/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 00000000..1d16612c Binary files /dev/null and b/.venv/Lib/site-packages/redis/commands/json/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/redis/commands/json/__pycache__/_util.cpython-311.pyc b/.venv/Lib/site-packages/redis/commands/json/__pycache__/_util.cpython-311.pyc new file mode 100644 index 00000000..b14c369e Binary files /dev/null and b/.venv/Lib/site-packages/redis/commands/json/__pycache__/_util.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/redis/commands/json/__pycache__/commands.cpython-311.pyc b/.venv/Lib/site-packages/redis/commands/json/__pycache__/commands.cpython-311.pyc new file mode 100644 index 00000000..870eee34 Binary files /dev/null and b/.venv/Lib/site-packages/redis/commands/json/__pycache__/commands.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/redis/commands/json/__pycache__/decoders.cpython-311.pyc b/.venv/Lib/site-packages/redis/commands/json/__pycache__/decoders.cpython-311.pyc new file mode 100644 index 00000000..a4dabe0e Binary files /dev/null and b/.venv/Lib/site-packages/redis/commands/json/__pycache__/decoders.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/redis/commands/json/__pycache__/path.cpython-311.pyc b/.venv/Lib/site-packages/redis/commands/json/__pycache__/path.cpython-311.pyc new file mode 100644 index 00000000..ef363fdd Binary files /dev/null and b/.venv/Lib/site-packages/redis/commands/json/__pycache__/path.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/redis/commands/json/_util.py b/.venv/Lib/site-packages/redis/commands/json/_util.py new file mode 100644 index 00000000..3400bcd6 --- /dev/null +++ b/.venv/Lib/site-packages/redis/commands/json/_util.py @@ -0,0 +1,3 @@ +from typing import Any, Dict, List, Union + +JsonType = Union[str, int, float, bool, None, Dict[str, Any], List[Any]] diff --git a/.venv/Lib/site-packages/redis/commands/json/commands.py b/.venv/Lib/site-packages/redis/commands/json/commands.py new file mode 100644 index 00000000..0f92e0d6 --- /dev/null +++ b/.venv/Lib/site-packages/redis/commands/json/commands.py @@ -0,0 +1,429 @@ +import os +from json import JSONDecodeError, loads +from typing import Dict, List, Optional, Tuple, Union + +from redis.exceptions import DataError +from redis.utils import deprecated_function + +from ._util import JsonType +from .decoders import decode_dict_keys +from .path import Path + + +class JSONCommands: + """json commands.""" + + def arrappend( + self, name: str, path: Optional[str] = Path.root_path(), *args: List[JsonType] + ) -> List[Union[int, None]]: + """Append the objects ``args`` to the array under the + ``path` in key ``name``. + + For more information see `JSON.ARRAPPEND `_.. + """ # noqa + pieces = [name, str(path)] + for o in args: + pieces.append(self._encode(o)) + return self.execute_command("JSON.ARRAPPEND", *pieces) + + def arrindex( + self, + name: str, + path: str, + scalar: int, + start: Optional[int] = None, + stop: Optional[int] = None, + ) -> List[Union[int, None]]: + """ + Return the index of ``scalar`` in the JSON array under ``path`` at key + ``name``. + + The search can be limited using the optional inclusive ``start`` + and exclusive ``stop`` indices. + + For more information see `JSON.ARRINDEX `_. + """ # noqa + pieces = [name, str(path), self._encode(scalar)] + if start is not None: + pieces.append(start) + if stop is not None: + pieces.append(stop) + + return self.execute_command("JSON.ARRINDEX", *pieces) + + def arrinsert( + self, name: str, path: str, index: int, *args: List[JsonType] + ) -> List[Union[int, None]]: + """Insert the objects ``args`` to the array at index ``index`` + under the ``path` in key ``name``. + + For more information see `JSON.ARRINSERT `_. + """ # noqa + pieces = [name, str(path), index] + for o in args: + pieces.append(self._encode(o)) + return self.execute_command("JSON.ARRINSERT", *pieces) + + def arrlen( + self, name: str, path: Optional[str] = Path.root_path() + ) -> List[Union[int, None]]: + """Return the length of the array JSON value under ``path`` + at key``name``. + + For more information see `JSON.ARRLEN `_. + """ # noqa + return self.execute_command("JSON.ARRLEN", name, str(path)) + + def arrpop( + self, + name: str, + path: Optional[str] = Path.root_path(), + index: Optional[int] = -1, + ) -> List[Union[str, None]]: + """Pop the element at ``index`` in the array JSON value under + ``path`` at key ``name``. + + For more information see `JSON.ARRPOP `_. + """ # noqa + return self.execute_command("JSON.ARRPOP", name, str(path), index) + + def arrtrim( + self, name: str, path: str, start: int, stop: int + ) -> List[Union[int, None]]: + """Trim the array JSON value under ``path`` at key ``name`` to the + inclusive range given by ``start`` and ``stop``. + + For more information see `JSON.ARRTRIM `_. + """ # noqa + return self.execute_command("JSON.ARRTRIM", name, str(path), start, stop) + + def type(self, name: str, path: Optional[str] = Path.root_path()) -> List[str]: + """Get the type of the JSON value under ``path`` from key ``name``. + + For more information see `JSON.TYPE `_. + """ # noqa + return self.execute_command("JSON.TYPE", name, str(path)) + + def resp(self, name: str, path: Optional[str] = Path.root_path()) -> List: + """Return the JSON value under ``path`` at key ``name``. + + For more information see `JSON.RESP `_. + """ # noqa + return self.execute_command("JSON.RESP", name, str(path)) + + def objkeys( + self, name: str, path: Optional[str] = Path.root_path() + ) -> List[Union[List[str], None]]: + """Return the key names in the dictionary JSON value under ``path`` at + key ``name``. + + For more information see `JSON.OBJKEYS `_. + """ # noqa + return self.execute_command("JSON.OBJKEYS", name, str(path)) + + def objlen(self, name: str, path: Optional[str] = Path.root_path()) -> int: + """Return the length of the dictionary JSON value under ``path`` at key + ``name``. + + For more information see `JSON.OBJLEN `_. + """ # noqa + return self.execute_command("JSON.OBJLEN", name, str(path)) + + def numincrby(self, name: str, path: str, number: int) -> str: + """Increment the numeric (integer or floating point) JSON value under + ``path`` at key ``name`` by the provided ``number``. + + For more information see `JSON.NUMINCRBY `_. + """ # noqa + return self.execute_command( + "JSON.NUMINCRBY", name, str(path), self._encode(number) + ) + + @deprecated_function(version="4.0.0", reason="deprecated since redisjson 1.0.0") + def nummultby(self, name: str, path: str, number: int) -> str: + """Multiply the numeric (integer or floating point) JSON value under + ``path`` at key ``name`` with the provided ``number``. + + For more information see `JSON.NUMMULTBY `_. + """ # noqa + return self.execute_command( + "JSON.NUMMULTBY", name, str(path), self._encode(number) + ) + + def clear(self, name: str, path: Optional[str] = Path.root_path()) -> int: + """Empty arrays and objects (to have zero slots/keys without deleting the + array/object). + + Return the count of cleared paths (ignoring non-array and non-objects + paths). + + For more information see `JSON.CLEAR `_. + """ # noqa + return self.execute_command("JSON.CLEAR", name, str(path)) + + def delete(self, key: str, path: Optional[str] = Path.root_path()) -> int: + """Delete the JSON value stored at key ``key`` under ``path``. + + For more information see `JSON.DEL `_. + """ + return self.execute_command("JSON.DEL", key, str(path)) + + # forget is an alias for delete + forget = delete + + def get( + self, name: str, *args, no_escape: Optional[bool] = False + ) -> List[JsonType]: + """ + Get the object stored as a JSON value at key ``name``. + + ``args`` is zero or more paths, and defaults to root path + ```no_escape`` is a boolean flag to add no_escape option to get + non-ascii characters + + For more information see `JSON.GET `_. + """ # noqa + pieces = [name] + if no_escape: + pieces.append("noescape") + + if len(args) == 0: + pieces.append(Path.root_path()) + + else: + for p in args: + pieces.append(str(p)) + + # Handle case where key doesn't exist. The JSONDecoder would raise a + # TypeError exception since it can't decode None + try: + return self.execute_command("JSON.GET", *pieces) + except TypeError: + return None + + def mget(self, keys: List[str], path: str) -> List[JsonType]: + """ + Get the objects stored as a JSON values under ``path``. ``keys`` + is a list of one or more keys. + + For more information see `JSON.MGET `_. + """ # noqa + pieces = [] + pieces += keys + pieces.append(str(path)) + return self.execute_command("JSON.MGET", *pieces) + + def set( + self, + name: str, + path: str, + obj: JsonType, + nx: Optional[bool] = False, + xx: Optional[bool] = False, + decode_keys: Optional[bool] = False, + ) -> Optional[str]: + """ + Set the JSON value at key ``name`` under the ``path`` to ``obj``. + + ``nx`` if set to True, set ``value`` only if it does not exist. + ``xx`` if set to True, set ``value`` only if it exists. + ``decode_keys`` If set to True, the keys of ``obj`` will be decoded + with utf-8. + + For the purpose of using this within a pipeline, this command is also + aliased to JSON.SET. + + For more information see `JSON.SET `_. + """ + if decode_keys: + obj = decode_dict_keys(obj) + + pieces = [name, str(path), self._encode(obj)] + + # Handle existential modifiers + if nx and xx: + raise Exception( + "nx and xx are mutually exclusive: use one, the " + "other or neither - but not both" + ) + elif nx: + pieces.append("NX") + elif xx: + pieces.append("XX") + return self.execute_command("JSON.SET", *pieces) + + def mset(self, triplets: List[Tuple[str, str, JsonType]]) -> Optional[str]: + """ + Set the JSON value at key ``name`` under the ``path`` to ``obj`` + for one or more keys. + + ``triplets`` is a list of one or more triplets of key, path, value. + + For the purpose of using this within a pipeline, this command is also + aliased to JSON.MSET. + + For more information see `JSON.MSET `_. + """ + pieces = [] + for triplet in triplets: + pieces.extend([triplet[0], str(triplet[1]), self._encode(triplet[2])]) + return self.execute_command("JSON.MSET", *pieces) + + def merge( + self, + name: str, + path: str, + obj: JsonType, + decode_keys: Optional[bool] = False, + ) -> Optional[str]: + """ + Merges a given JSON value into matching paths. Consequently, JSON values + at matching paths are updated, deleted, or expanded with new children + + ``decode_keys`` If set to True, the keys of ``obj`` will be decoded + with utf-8. + + For more information see `JSON.MERGE `_. + """ + if decode_keys: + obj = decode_dict_keys(obj) + + pieces = [name, str(path), self._encode(obj)] + + return self.execute_command("JSON.MERGE", *pieces) + + def set_file( + self, + name: str, + path: str, + file_name: str, + nx: Optional[bool] = False, + xx: Optional[bool] = False, + decode_keys: Optional[bool] = False, + ) -> Optional[str]: + """ + Set the JSON value at key ``name`` under the ``path`` to the content + of the json file ``file_name``. + + ``nx`` if set to True, set ``value`` only if it does not exist. + ``xx`` if set to True, set ``value`` only if it exists. + ``decode_keys`` If set to True, the keys of ``obj`` will be decoded + with utf-8. + + """ + + with open(file_name, "r") as fp: + file_content = loads(fp.read()) + + return self.set(name, path, file_content, nx=nx, xx=xx, decode_keys=decode_keys) + + def set_path( + self, + json_path: str, + root_folder: str, + nx: Optional[bool] = False, + xx: Optional[bool] = False, + decode_keys: Optional[bool] = False, + ) -> List[Dict[str, bool]]: + """ + Iterate over ``root_folder`` and set each JSON file to a value + under ``json_path`` with the file name as the key. + + ``nx`` if set to True, set ``value`` only if it does not exist. + ``xx`` if set to True, set ``value`` only if it exists. + ``decode_keys`` If set to True, the keys of ``obj`` will be decoded + with utf-8. + + """ + set_files_result = {} + for root, dirs, files in os.walk(root_folder): + for file in files: + file_path = os.path.join(root, file) + try: + file_name = file_path.rsplit(".")[0] + self.set_file( + file_name, + json_path, + file_path, + nx=nx, + xx=xx, + decode_keys=decode_keys, + ) + set_files_result[file_path] = True + except JSONDecodeError: + set_files_result[file_path] = False + + return set_files_result + + def strlen(self, name: str, path: Optional[str] = None) -> List[Union[int, None]]: + """Return the length of the string JSON value under ``path`` at key + ``name``. + + For more information see `JSON.STRLEN `_. + """ # noqa + pieces = [name] + if path is not None: + pieces.append(str(path)) + return self.execute_command("JSON.STRLEN", *pieces) + + def toggle( + self, name: str, path: Optional[str] = Path.root_path() + ) -> Union[bool, List[Optional[int]]]: + """Toggle boolean value under ``path`` at key ``name``. + returning the new value. + + For more information see `JSON.TOGGLE `_. + """ # noqa + return self.execute_command("JSON.TOGGLE", name, str(path)) + + def strappend( + self, name: str, value: str, path: Optional[int] = Path.root_path() + ) -> Union[int, List[Optional[int]]]: + """Append to the string JSON value. If two options are specified after + the key name, the path is determined to be the first. If a single + option is passed, then the root_path (i.e Path.root_path()) is used. + + For more information see `JSON.STRAPPEND `_. + """ # noqa + pieces = [name, str(path), self._encode(value)] + return self.execute_command("JSON.STRAPPEND", *pieces) + + def debug( + self, + subcommand: str, + key: Optional[str] = None, + path: Optional[str] = Path.root_path(), + ) -> Union[int, List[str]]: + """Return the memory usage in bytes of a value under ``path`` from + key ``name``. + + For more information see `JSON.DEBUG `_. + """ # noqa + valid_subcommands = ["MEMORY", "HELP"] + if subcommand not in valid_subcommands: + raise DataError("The only valid subcommands are ", str(valid_subcommands)) + pieces = [subcommand] + if subcommand == "MEMORY": + if key is None: + raise DataError("No key specified") + pieces.append(key) + pieces.append(str(path)) + return self.execute_command("JSON.DEBUG", *pieces) + + @deprecated_function( + version="4.0.0", reason="redisjson-py supported this, call get directly." + ) + def jsonget(self, *args, **kwargs): + return self.get(*args, **kwargs) + + @deprecated_function( + version="4.0.0", reason="redisjson-py supported this, call get directly." + ) + def jsonmget(self, *args, **kwargs): + return self.mget(*args, **kwargs) + + @deprecated_function( + version="4.0.0", reason="redisjson-py supported this, call get directly." + ) + def jsonset(self, *args, **kwargs): + return self.set(*args, **kwargs) diff --git a/.venv/Lib/site-packages/redis/commands/json/decoders.py b/.venv/Lib/site-packages/redis/commands/json/decoders.py new file mode 100644 index 00000000..b9384711 --- /dev/null +++ b/.venv/Lib/site-packages/redis/commands/json/decoders.py @@ -0,0 +1,60 @@ +import copy +import re + +from ..helpers import nativestr + + +def bulk_of_jsons(d): + """Replace serialized JSON values with objects in a + bulk array response (list). + """ + + def _f(b): + for index, item in enumerate(b): + if item is not None: + b[index] = d(item) + return b + + return _f + + +def decode_dict_keys(obj): + """Decode the keys of the given dictionary with utf-8.""" + newobj = copy.copy(obj) + for k in obj.keys(): + if isinstance(k, bytes): + newobj[k.decode("utf-8")] = newobj[k] + newobj.pop(k) + return newobj + + +def unstring(obj): + """ + Attempt to parse string to native integer formats. + One can't simply call int/float in a try/catch because there is a + semantic difference between (for example) 15.0 and 15. + """ + floatreg = "^\\d+.\\d+$" + match = re.findall(floatreg, obj) + if match != []: + return float(match[0]) + + intreg = "^\\d+$" + match = re.findall(intreg, obj) + if match != []: + return int(match[0]) + return obj + + +def decode_list(b): + """ + Given a non-deserializable object, make a best effort to + return a useful set of results. + """ + if isinstance(b, list): + return [nativestr(obj) for obj in b] + elif isinstance(b, bytes): + return unstring(nativestr(b)) + elif isinstance(b, str): + return unstring(b) + return b diff --git a/.venv/Lib/site-packages/redis/commands/json/path.py b/.venv/Lib/site-packages/redis/commands/json/path.py new file mode 100644 index 00000000..bfb0ab2d --- /dev/null +++ b/.venv/Lib/site-packages/redis/commands/json/path.py @@ -0,0 +1,16 @@ +class Path: + """This class represents a path in a JSON value.""" + + strPath = "" + + @staticmethod + def root_path(): + """Return the root path's string representation.""" + return "." + + def __init__(self, path): + """Make a new path based on the string representation in `path`.""" + self.strPath = path + + def __repr__(self): + return self.strPath diff --git a/.venv/Lib/site-packages/redis/commands/redismodules.py b/.venv/Lib/site-packages/redis/commands/redismodules.py new file mode 100644 index 00000000..7e2045a7 --- /dev/null +++ b/.venv/Lib/site-packages/redis/commands/redismodules.py @@ -0,0 +1,103 @@ +from json import JSONDecoder, JSONEncoder + + +class RedisModuleCommands: + """This class contains the wrapper functions to bring supported redis + modules into the command namespace. + """ + + def json(self, encoder=JSONEncoder(), decoder=JSONDecoder()): + """Access the json namespace, providing support for redis json.""" + + from .json import JSON + + jj = JSON(client=self, encoder=encoder, decoder=decoder) + return jj + + def ft(self, index_name="idx"): + """Access the search namespace, providing support for redis search.""" + + from .search import Search + + s = Search(client=self, index_name=index_name) + return s + + def ts(self): + """Access the timeseries namespace, providing support for + redis timeseries data. + """ + + from .timeseries import TimeSeries + + s = TimeSeries(client=self) + return s + + def bf(self): + """Access the bloom namespace.""" + + from .bf import BFBloom + + bf = BFBloom(client=self) + return bf + + def cf(self): + """Access the bloom namespace.""" + + from .bf import CFBloom + + cf = CFBloom(client=self) + return cf + + def cms(self): + """Access the bloom namespace.""" + + from .bf import CMSBloom + + cms = CMSBloom(client=self) + return cms + + def topk(self): + """Access the bloom namespace.""" + + from .bf import TOPKBloom + + topk = TOPKBloom(client=self) + return topk + + def tdigest(self): + """Access the bloom namespace.""" + + from .bf import TDigestBloom + + tdigest = TDigestBloom(client=self) + return tdigest + + def graph(self, index_name="idx"): + """Access the graph namespace, providing support for + redis graph data. + """ + + from .graph import Graph + + g = Graph(client=self, name=index_name) + return g + + +class AsyncRedisModuleCommands(RedisModuleCommands): + def ft(self, index_name="idx"): + """Access the search namespace, providing support for redis search.""" + + from .search import AsyncSearch + + s = AsyncSearch(client=self, index_name=index_name) + return s + + def graph(self, index_name="idx"): + """Access the graph namespace, providing support for + redis graph data. + """ + + from .graph import AsyncGraph + + g = AsyncGraph(client=self, name=index_name) + return g diff --git a/.venv/Lib/site-packages/redis/commands/search/__init__.py b/.venv/Lib/site-packages/redis/commands/search/__init__.py new file mode 100644 index 00000000..a2bb23b7 --- /dev/null +++ b/.venv/Lib/site-packages/redis/commands/search/__init__.py @@ -0,0 +1,189 @@ +import redis + +from ...asyncio.client import Pipeline as AsyncioPipeline +from .commands import ( + AGGREGATE_CMD, + CONFIG_CMD, + INFO_CMD, + PROFILE_CMD, + SEARCH_CMD, + SPELLCHECK_CMD, + SYNDUMP_CMD, + AsyncSearchCommands, + SearchCommands, +) + + +class Search(SearchCommands): + """ + Create a client for talking to search. + It abstracts the API of the module and lets you just use the engine. + """ + + class BatchIndexer: + """ + A batch indexer allows you to automatically batch + document indexing in pipelines, flushing it every N documents. + """ + + def __init__(self, client, chunk_size=1000): + self.client = client + self.execute_command = client.execute_command + self._pipeline = client.pipeline(transaction=False, shard_hint=None) + self.total = 0 + self.chunk_size = chunk_size + self.current_chunk = 0 + + def __del__(self): + if self.current_chunk: + self.commit() + + def add_document( + self, + doc_id, + nosave=False, + score=1.0, + payload=None, + replace=False, + partial=False, + no_create=False, + **fields, + ): + """ + Add a document to the batch query + """ + self.client._add_document( + doc_id, + conn=self._pipeline, + nosave=nosave, + score=score, + payload=payload, + replace=replace, + partial=partial, + no_create=no_create, + **fields, + ) + self.current_chunk += 1 + self.total += 1 + if self.current_chunk >= self.chunk_size: + self.commit() + + def add_document_hash(self, doc_id, score=1.0, replace=False): + """ + Add a hash to the batch query + """ + self.client._add_document_hash( + doc_id, conn=self._pipeline, score=score, replace=replace + ) + self.current_chunk += 1 + self.total += 1 + if self.current_chunk >= self.chunk_size: + self.commit() + + def commit(self): + """ + Manually commit and flush the batch indexing query + """ + self._pipeline.execute() + self.current_chunk = 0 + + def __init__(self, client, index_name="idx"): + """ + Create a new Client for the given index_name. + The default name is `idx` + + If conn is not None, we employ an already existing redis connection + """ + self._MODULE_CALLBACKS = {} + self.client = client + self.index_name = index_name + self.execute_command = client.execute_command + self._pipeline = client.pipeline + self._RESP2_MODULE_CALLBACKS = { + INFO_CMD: self._parse_info, + SEARCH_CMD: self._parse_search, + AGGREGATE_CMD: self._parse_aggregate, + PROFILE_CMD: self._parse_profile, + SPELLCHECK_CMD: self._parse_spellcheck, + CONFIG_CMD: self._parse_config_get, + SYNDUMP_CMD: self._parse_syndump, + } + + def pipeline(self, transaction=True, shard_hint=None): + """Creates a pipeline for the SEARCH module, that can be used for executing + SEARCH commands, as well as classic core commands. + """ + p = Pipeline( + connection_pool=self.client.connection_pool, + response_callbacks=self._MODULE_CALLBACKS, + transaction=transaction, + shard_hint=shard_hint, + ) + p.index_name = self.index_name + return p + + +class AsyncSearch(Search, AsyncSearchCommands): + class BatchIndexer(Search.BatchIndexer): + """ + A batch indexer allows you to automatically batch + document indexing in pipelines, flushing it every N documents. + """ + + async def add_document( + self, + doc_id, + nosave=False, + score=1.0, + payload=None, + replace=False, + partial=False, + no_create=False, + **fields, + ): + """ + Add a document to the batch query + """ + self.client._add_document( + doc_id, + conn=self._pipeline, + nosave=nosave, + score=score, + payload=payload, + replace=replace, + partial=partial, + no_create=no_create, + **fields, + ) + self.current_chunk += 1 + self.total += 1 + if self.current_chunk >= self.chunk_size: + await self.commit() + + async def commit(self): + """ + Manually commit and flush the batch indexing query + """ + await self._pipeline.execute() + self.current_chunk = 0 + + def pipeline(self, transaction=True, shard_hint=None): + """Creates a pipeline for the SEARCH module, that can be used for executing + SEARCH commands, as well as classic core commands. + """ + p = AsyncPipeline( + connection_pool=self.client.connection_pool, + response_callbacks=self._MODULE_CALLBACKS, + transaction=transaction, + shard_hint=shard_hint, + ) + p.index_name = self.index_name + return p + + +class Pipeline(SearchCommands, redis.client.Pipeline): + """Pipeline for the module.""" + + +class AsyncPipeline(AsyncSearchCommands, AsyncioPipeline, Pipeline): + """AsyncPipeline for the module.""" diff --git a/.venv/Lib/site-packages/redis/commands/search/__pycache__/__init__.cpython-311.pyc b/.venv/Lib/site-packages/redis/commands/search/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 00000000..bbc0d60d Binary files /dev/null and b/.venv/Lib/site-packages/redis/commands/search/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/redis/commands/search/__pycache__/_util.cpython-311.pyc b/.venv/Lib/site-packages/redis/commands/search/__pycache__/_util.cpython-311.pyc new file mode 100644 index 00000000..2c9fd91b Binary files /dev/null and b/.venv/Lib/site-packages/redis/commands/search/__pycache__/_util.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/redis/commands/search/__pycache__/aggregation.cpython-311.pyc b/.venv/Lib/site-packages/redis/commands/search/__pycache__/aggregation.cpython-311.pyc new file mode 100644 index 00000000..8e9bca57 Binary files /dev/null and b/.venv/Lib/site-packages/redis/commands/search/__pycache__/aggregation.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/redis/commands/search/__pycache__/commands.cpython-311.pyc b/.venv/Lib/site-packages/redis/commands/search/__pycache__/commands.cpython-311.pyc new file mode 100644 index 00000000..24ab72f2 Binary files /dev/null and b/.venv/Lib/site-packages/redis/commands/search/__pycache__/commands.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/redis/commands/search/__pycache__/document.cpython-311.pyc b/.venv/Lib/site-packages/redis/commands/search/__pycache__/document.cpython-311.pyc new file mode 100644 index 00000000..dc5ba948 Binary files /dev/null and b/.venv/Lib/site-packages/redis/commands/search/__pycache__/document.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/redis/commands/search/__pycache__/field.cpython-311.pyc b/.venv/Lib/site-packages/redis/commands/search/__pycache__/field.cpython-311.pyc new file mode 100644 index 00000000..23766cea Binary files /dev/null and b/.venv/Lib/site-packages/redis/commands/search/__pycache__/field.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/redis/commands/search/__pycache__/indexDefinition.cpython-311.pyc b/.venv/Lib/site-packages/redis/commands/search/__pycache__/indexDefinition.cpython-311.pyc new file mode 100644 index 00000000..91c91d07 Binary files /dev/null and b/.venv/Lib/site-packages/redis/commands/search/__pycache__/indexDefinition.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/redis/commands/search/__pycache__/query.cpython-311.pyc b/.venv/Lib/site-packages/redis/commands/search/__pycache__/query.cpython-311.pyc new file mode 100644 index 00000000..4e1117d7 Binary files /dev/null and b/.venv/Lib/site-packages/redis/commands/search/__pycache__/query.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/redis/commands/search/__pycache__/querystring.cpython-311.pyc b/.venv/Lib/site-packages/redis/commands/search/__pycache__/querystring.cpython-311.pyc new file mode 100644 index 00000000..f7b03453 Binary files /dev/null and b/.venv/Lib/site-packages/redis/commands/search/__pycache__/querystring.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/redis/commands/search/__pycache__/reducers.cpython-311.pyc b/.venv/Lib/site-packages/redis/commands/search/__pycache__/reducers.cpython-311.pyc new file mode 100644 index 00000000..9689a621 Binary files /dev/null and b/.venv/Lib/site-packages/redis/commands/search/__pycache__/reducers.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/redis/commands/search/__pycache__/result.cpython-311.pyc b/.venv/Lib/site-packages/redis/commands/search/__pycache__/result.cpython-311.pyc new file mode 100644 index 00000000..14116acf Binary files /dev/null and b/.venv/Lib/site-packages/redis/commands/search/__pycache__/result.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/redis/commands/search/__pycache__/suggestion.cpython-311.pyc b/.venv/Lib/site-packages/redis/commands/search/__pycache__/suggestion.cpython-311.pyc new file mode 100644 index 00000000..ce0deff3 Binary files /dev/null and b/.venv/Lib/site-packages/redis/commands/search/__pycache__/suggestion.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/redis/commands/search/_util.py b/.venv/Lib/site-packages/redis/commands/search/_util.py new file mode 100644 index 00000000..dd1dff33 --- /dev/null +++ b/.venv/Lib/site-packages/redis/commands/search/_util.py @@ -0,0 +1,7 @@ +def to_string(s): + if isinstance(s, str): + return s + elif isinstance(s, bytes): + return s.decode("utf-8", "ignore") + else: + return s # Not a string we care about diff --git a/.venv/Lib/site-packages/redis/commands/search/aggregation.py b/.venv/Lib/site-packages/redis/commands/search/aggregation.py new file mode 100644 index 00000000..50d18f47 --- /dev/null +++ b/.venv/Lib/site-packages/redis/commands/search/aggregation.py @@ -0,0 +1,372 @@ +from typing import List, Union + +FIELDNAME = object() + + +class Limit: + def __init__(self, offset: int = 0, count: int = 0) -> None: + self.offset = offset + self.count = count + + def build_args(self): + if self.count: + return ["LIMIT", str(self.offset), str(self.count)] + else: + return [] + + +class Reducer: + """ + Base reducer object for all reducers. + + See the `redisearch.reducers` module for the actual reducers. + """ + + NAME = None + + def __init__(self, *args: List[str]) -> None: + self._args = args + self._field = None + self._alias = None + + def alias(self, alias: str) -> "Reducer": + """ + Set the alias for this reducer. + + ### Parameters + + - **alias**: The value of the alias for this reducer. If this is the + special value `aggregation.FIELDNAME` then this reducer will be + aliased using the same name as the field upon which it operates. + Note that using `FIELDNAME` is only possible on reducers which + operate on a single field value. + + This method returns the `Reducer` object making it suitable for + chaining. + """ + if alias is FIELDNAME: + if not self._field: + raise ValueError("Cannot use FIELDNAME alias with no field") + # Chop off initial '@' + alias = self._field[1:] + self._alias = alias + return self + + @property + def args(self) -> List[str]: + return self._args + + +class SortDirection: + """ + This special class is used to indicate sort direction. + """ + + DIRSTRING = None + + def __init__(self, field: str) -> None: + self.field = field + + +class Asc(SortDirection): + """ + Indicate that the given field should be sorted in ascending order + """ + + DIRSTRING = "ASC" + + +class Desc(SortDirection): + """ + Indicate that the given field should be sorted in descending order + """ + + DIRSTRING = "DESC" + + +class AggregateRequest: + """ + Aggregation request which can be passed to `Client.aggregate`. + """ + + def __init__(self, query: str = "*") -> None: + """ + Create an aggregation request. This request may then be passed to + `client.aggregate()`. + + In order for the request to be usable, it must contain at least one + group. + + - **query** Query string for filtering records. + + All member methods (except `build_args()`) + return the object itself, making them useful for chaining. + """ + self._query = query + self._aggregateplan = [] + self._loadfields = [] + self._loadall = False + self._max = 0 + self._with_schema = False + self._verbatim = False + self._cursor = [] + self._dialect = None + + def load(self, *fields: List[str]) -> "AggregateRequest": + """ + Indicate the fields to be returned in the response. These fields are + returned in addition to any others implicitly specified. + + ### Parameters + + - **fields**: If fields not specified, all the fields will be loaded. + Otherwise, fields should be given in the format of `@field`. + """ + if fields: + self._loadfields.extend(fields) + else: + self._loadall = True + return self + + def group_by( + self, fields: List[str], *reducers: Union[Reducer, List[Reducer]] + ) -> "AggregateRequest": + """ + Specify by which fields to group the aggregation. + + ### Parameters + + - **fields**: Fields to group by. This can either be a single string, + or a list of strings. both cases, the field should be specified as + `@field`. + - **reducers**: One or more reducers. Reducers may be found in the + `aggregation` module. + """ + fields = [fields] if isinstance(fields, str) else fields + reducers = [reducers] if isinstance(reducers, Reducer) else reducers + + ret = ["GROUPBY", str(len(fields)), *fields] + for reducer in reducers: + ret += ["REDUCE", reducer.NAME, str(len(reducer.args))] + ret.extend(reducer.args) + if reducer._alias is not None: + ret += ["AS", reducer._alias] + + self._aggregateplan.extend(ret) + return self + + def apply(self, **kwexpr) -> "AggregateRequest": + """ + Specify one or more projection expressions to add to each result + + ### Parameters + + - **kwexpr**: One or more key-value pairs for a projection. The key is + the alias for the projection, and the value is the projection + expression itself, for example `apply(square_root="sqrt(@foo)")` + """ + for alias, expr in kwexpr.items(): + ret = ["APPLY", expr] + if alias is not None: + ret += ["AS", alias] + self._aggregateplan.extend(ret) + + return self + + def limit(self, offset: int, num: int) -> "AggregateRequest": + """ + Sets the limit for the most recent group or query. + + If no group has been defined yet (via `group_by()`) then this sets + the limit for the initial pool of results from the query. Otherwise, + this limits the number of items operated on from the previous group. + + Setting a limit on the initial search results may be useful when + attempting to execute an aggregation on a sample of a large data set. + + ### Parameters + + - **offset**: Result offset from which to begin paging + - **num**: Number of results to return + + + Example of sorting the initial results: + + ``` + AggregateRequest("@sale_amount:[10000, inf]")\ + .limit(0, 10)\ + .group_by("@state", r.count()) + ``` + + Will only group by the states found in the first 10 results of the + query `@sale_amount:[10000, inf]`. On the other hand, + + ``` + AggregateRequest("@sale_amount:[10000, inf]")\ + .limit(0, 1000)\ + .group_by("@state", r.count()\ + .limit(0, 10) + ``` + + Will group all the results matching the query, but only return the + first 10 groups. + + If you only wish to return a *top-N* style query, consider using + `sort_by()` instead. + + """ + _limit = Limit(offset, num) + self._aggregateplan.extend(_limit.build_args()) + return self + + def sort_by(self, *fields: List[str], **kwargs) -> "AggregateRequest": + """ + Indicate how the results should be sorted. This can also be used for + *top-N* style queries + + ### Parameters + + - **fields**: The fields by which to sort. This can be either a single + field or a list of fields. If you wish to specify order, you can + use the `Asc` or `Desc` wrapper classes. + - **max**: Maximum number of results to return. This can be + used instead of `LIMIT` and is also faster. + + + Example of sorting by `foo` ascending and `bar` descending: + + ``` + sort_by(Asc("@foo"), Desc("@bar")) + ``` + + Return the top 10 customers: + + ``` + AggregateRequest()\ + .group_by("@customer", r.sum("@paid").alias(FIELDNAME))\ + .sort_by(Desc("@paid"), max=10) + ``` + """ + if isinstance(fields, (str, SortDirection)): + fields = [fields] + + fields_args = [] + for f in fields: + if isinstance(f, SortDirection): + fields_args += [f.field, f.DIRSTRING] + else: + fields_args += [f] + + ret = ["SORTBY", str(len(fields_args))] + ret.extend(fields_args) + max = kwargs.get("max", 0) + if max > 0: + ret += ["MAX", str(max)] + + self._aggregateplan.extend(ret) + return self + + def filter(self, expressions: Union[str, List[str]]) -> "AggregateRequest": + """ + Specify filter for post-query results using predicates relating to + values in the result set. + + ### Parameters + + - **fields**: Fields to group by. This can either be a single string, + or a list of strings. + """ + if isinstance(expressions, str): + expressions = [expressions] + + for expression in expressions: + self._aggregateplan.extend(["FILTER", expression]) + + return self + + def with_schema(self) -> "AggregateRequest": + """ + If set, the `schema` property will contain a list of `[field, type]` + entries in the result object. + """ + self._with_schema = True + return self + + def verbatim(self) -> "AggregateRequest": + self._verbatim = True + return self + + def cursor(self, count: int = 0, max_idle: float = 0.0) -> "AggregateRequest": + args = ["WITHCURSOR"] + if count: + args += ["COUNT", str(count)] + if max_idle: + args += ["MAXIDLE", str(max_idle * 1000)] + self._cursor = args + return self + + def build_args(self) -> List[str]: + # @foo:bar ... + ret = [self._query] + + if self._with_schema: + ret.append("WITHSCHEMA") + + if self._verbatim: + ret.append("VERBATIM") + + if self._cursor: + ret += self._cursor + + if self._loadall: + ret.append("LOAD") + ret.append("*") + elif self._loadfields: + ret.append("LOAD") + ret.append(str(len(self._loadfields))) + ret.extend(self._loadfields) + + if self._dialect: + ret.extend(["DIALECT", self._dialect]) + + ret.extend(self._aggregateplan) + + return ret + + def dialect(self, dialect: int) -> "AggregateRequest": + """ + Add a dialect field to the aggregate command. + + - **dialect** - dialect version to execute the query under + """ + self._dialect = dialect + return self + + +class Cursor: + def __init__(self, cid: int) -> None: + self.cid = cid + self.max_idle = 0 + self.count = 0 + + def build_args(self): + args = [str(self.cid)] + if self.max_idle: + args += ["MAXIDLE", str(self.max_idle)] + if self.count: + args += ["COUNT", str(self.count)] + return args + + +class AggregateResult: + def __init__(self, rows, cursor: Cursor, schema) -> None: + self.rows = rows + self.cursor = cursor + self.schema = schema + + def __repr__(self) -> (str, str): + cid = self.cursor.cid if self.cursor else -1 + return ( + f"<{self.__class__.__name__} at 0x{id(self):x} " + f"Rows={len(self.rows)}, Cursor={cid}>" + ) diff --git a/.venv/Lib/site-packages/redis/commands/search/commands.py b/.venv/Lib/site-packages/redis/commands/search/commands.py new file mode 100644 index 00000000..2df2b5a7 --- /dev/null +++ b/.venv/Lib/site-packages/redis/commands/search/commands.py @@ -0,0 +1,1117 @@ +import itertools +import time +from typing import Dict, List, Optional, Union + +from redis.client import Pipeline +from redis.utils import deprecated_function + +from ..helpers import get_protocol_version, parse_to_dict +from ._util import to_string +from .aggregation import AggregateRequest, AggregateResult, Cursor +from .document import Document +from .query import Query +from .result import Result +from .suggestion import SuggestionParser + +NUMERIC = "NUMERIC" + +CREATE_CMD = "FT.CREATE" +ALTER_CMD = "FT.ALTER" +SEARCH_CMD = "FT.SEARCH" +ADD_CMD = "FT.ADD" +ADDHASH_CMD = "FT.ADDHASH" +DROP_CMD = "FT.DROP" +DROPINDEX_CMD = "FT.DROPINDEX" +EXPLAIN_CMD = "FT.EXPLAIN" +EXPLAINCLI_CMD = "FT.EXPLAINCLI" +DEL_CMD = "FT.DEL" +AGGREGATE_CMD = "FT.AGGREGATE" +PROFILE_CMD = "FT.PROFILE" +CURSOR_CMD = "FT.CURSOR" +SPELLCHECK_CMD = "FT.SPELLCHECK" +DICT_ADD_CMD = "FT.DICTADD" +DICT_DEL_CMD = "FT.DICTDEL" +DICT_DUMP_CMD = "FT.DICTDUMP" +GET_CMD = "FT.GET" +MGET_CMD = "FT.MGET" +CONFIG_CMD = "FT.CONFIG" +TAGVALS_CMD = "FT.TAGVALS" +ALIAS_ADD_CMD = "FT.ALIASADD" +ALIAS_UPDATE_CMD = "FT.ALIASUPDATE" +ALIAS_DEL_CMD = "FT.ALIASDEL" +INFO_CMD = "FT.INFO" +SUGADD_COMMAND = "FT.SUGADD" +SUGDEL_COMMAND = "FT.SUGDEL" +SUGLEN_COMMAND = "FT.SUGLEN" +SUGGET_COMMAND = "FT.SUGGET" +SYNUPDATE_CMD = "FT.SYNUPDATE" +SYNDUMP_CMD = "FT.SYNDUMP" + +NOOFFSETS = "NOOFFSETS" +NOFIELDS = "NOFIELDS" +NOHL = "NOHL" +NOFREQS = "NOFREQS" +MAXTEXTFIELDS = "MAXTEXTFIELDS" +TEMPORARY = "TEMPORARY" +STOPWORDS = "STOPWORDS" +SKIPINITIALSCAN = "SKIPINITIALSCAN" +WITHSCORES = "WITHSCORES" +FUZZY = "FUZZY" +WITHPAYLOADS = "WITHPAYLOADS" + + +class SearchCommands: + """Search commands.""" + + def _parse_results(self, cmd, res, **kwargs): + if get_protocol_version(self.client) in ["3", 3]: + return res + else: + return self._RESP2_MODULE_CALLBACKS[cmd](res, **kwargs) + + def _parse_info(self, res, **kwargs): + it = map(to_string, res) + return dict(zip(it, it)) + + def _parse_search(self, res, **kwargs): + return Result( + res, + not kwargs["query"]._no_content, + duration=kwargs["duration"], + has_payload=kwargs["query"]._with_payloads, + with_scores=kwargs["query"]._with_scores, + ) + + def _parse_aggregate(self, res, **kwargs): + return self._get_aggregate_result(res, kwargs["query"], kwargs["has_cursor"]) + + def _parse_profile(self, res, **kwargs): + query = kwargs["query"] + if isinstance(query, AggregateRequest): + result = self._get_aggregate_result(res[0], query, query._cursor) + else: + result = Result( + res[0], + not query._no_content, + duration=kwargs["duration"], + has_payload=query._with_payloads, + with_scores=query._with_scores, + ) + + return result, parse_to_dict(res[1]) + + def _parse_spellcheck(self, res, **kwargs): + corrections = {} + if res == 0: + return corrections + + for _correction in res: + if isinstance(_correction, int) and _correction == 0: + continue + + if len(_correction) != 3: + continue + if not _correction[2]: + continue + if not _correction[2][0]: + continue + + # For spellcheck output + # 1) 1) "TERM" + # 2) "{term1}" + # 3) 1) 1) "{score1}" + # 2) "{suggestion1}" + # 2) 1) "{score2}" + # 2) "{suggestion2}" + # + # Following dictionary will be made + # corrections = { + # '{term1}': [ + # {'score': '{score1}', 'suggestion': '{suggestion1}'}, + # {'score': '{score2}', 'suggestion': '{suggestion2}'} + # ] + # } + corrections[_correction[1]] = [ + {"score": _item[0], "suggestion": _item[1]} for _item in _correction[2] + ] + + return corrections + + def _parse_config_get(self, res, **kwargs): + return {kvs[0]: kvs[1] for kvs in res} if res else {} + + def _parse_syndump(self, res, **kwargs): + return {res[i]: res[i + 1] for i in range(0, len(res), 2)} + + def batch_indexer(self, chunk_size=100): + """ + Create a new batch indexer from the client with a given chunk size + """ + return self.BatchIndexer(self, chunk_size=chunk_size) + + def create_index( + self, + fields, + no_term_offsets=False, + no_field_flags=False, + stopwords=None, + definition=None, + max_text_fields=False, + temporary=None, + no_highlight=False, + no_term_frequencies=False, + skip_initial_scan=False, + ): + """ + Create the search index. The index must not already exist. + + ### Parameters: + + - **fields**: a list of TextField or NumericField objects + - **no_term_offsets**: If true, we will not save term offsets in + the index + - **no_field_flags**: If true, we will not save field flags that + allow searching in specific fields + - **stopwords**: If not None, we create the index with this custom + stopword list. The list can be empty + - **max_text_fields**: If true, we will encode indexes as if there + were more than 32 text fields which allows you to add additional + fields (beyond 32). + - **temporary**: Create a lightweight temporary index which will + expire after the specified period of inactivity (in seconds). The + internal idle timer is reset whenever the index is searched or added to. + - **no_highlight**: If true, disabling highlighting support. + Also implied by no_term_offsets. + - **no_term_frequencies**: If true, we avoid saving the term frequencies + in the index. + - **skip_initial_scan**: If true, we do not scan and index. + + For more information see `FT.CREATE `_. + """ # noqa + + args = [CREATE_CMD, self.index_name] + if definition is not None: + args += definition.args + if max_text_fields: + args.append(MAXTEXTFIELDS) + if temporary is not None and isinstance(temporary, int): + args.append(TEMPORARY) + args.append(temporary) + if no_term_offsets: + args.append(NOOFFSETS) + if no_highlight: + args.append(NOHL) + if no_field_flags: + args.append(NOFIELDS) + if no_term_frequencies: + args.append(NOFREQS) + if skip_initial_scan: + args.append(SKIPINITIALSCAN) + if stopwords is not None and isinstance(stopwords, (list, tuple, set)): + args += [STOPWORDS, len(stopwords)] + if len(stopwords) > 0: + args += list(stopwords) + + args.append("SCHEMA") + try: + args += list(itertools.chain(*(f.redis_args() for f in fields))) + except TypeError: + args += fields.redis_args() + + return self.execute_command(*args) + + def alter_schema_add(self, fields: List[str]): + """ + Alter the existing search index by adding new fields. The index + must already exist. + + ### Parameters: + + - **fields**: a list of Field objects to add for the index + + For more information see `FT.ALTER `_. + """ # noqa + + args = [ALTER_CMD, self.index_name, "SCHEMA", "ADD"] + try: + args += list(itertools.chain(*(f.redis_args() for f in fields))) + except TypeError: + args += fields.redis_args() + + return self.execute_command(*args) + + def dropindex(self, delete_documents: bool = False): + """ + Drop the index if it exists. + Replaced `drop_index` in RediSearch 2.0. + Default behavior was changed to not delete the indexed documents. + + ### Parameters: + + - **delete_documents**: If `True`, all documents will be deleted. + + For more information see `FT.DROPINDEX `_. + """ # noqa + delete_str = "DD" if delete_documents else "" + return self.execute_command(DROPINDEX_CMD, self.index_name, delete_str) + + def _add_document( + self, + doc_id, + conn=None, + nosave=False, + score=1.0, + payload=None, + replace=False, + partial=False, + language=None, + no_create=False, + **fields, + ): + """ + Internal add_document used for both batch and single doc indexing + """ + + if partial or no_create: + replace = True + + args = [ADD_CMD, self.index_name, doc_id, score] + if nosave: + args.append("NOSAVE") + if payload is not None: + args.append("PAYLOAD") + args.append(payload) + if replace: + args.append("REPLACE") + if partial: + args.append("PARTIAL") + if no_create: + args.append("NOCREATE") + if language: + args += ["LANGUAGE", language] + args.append("FIELDS") + args += list(itertools.chain(*fields.items())) + + if conn is not None: + return conn.execute_command(*args) + + return self.execute_command(*args) + + def _add_document_hash( + self, doc_id, conn=None, score=1.0, language=None, replace=False + ): + """ + Internal add_document_hash used for both batch and single doc indexing + """ + + args = [ADDHASH_CMD, self.index_name, doc_id, score] + + if replace: + args.append("REPLACE") + + if language: + args += ["LANGUAGE", language] + + if conn is not None: + return conn.execute_command(*args) + + return self.execute_command(*args) + + @deprecated_function( + version="2.0.0", reason="deprecated since redisearch 2.0, call hset instead" + ) + def add_document( + self, + doc_id: str, + nosave: bool = False, + score: float = 1.0, + payload: bool = None, + replace: bool = False, + partial: bool = False, + language: Optional[str] = None, + no_create: str = False, + **fields: List[str], + ): + """ + Add a single document to the index. + + ### Parameters + + - **doc_id**: the id of the saved document. + - **nosave**: if set to true, we just index the document, and don't + save a copy of it. This means that searches will just + return ids. + - **score**: the document ranking, between 0.0 and 1.0 + - **payload**: optional inner-index payload we can save for fast + i access in scoring functions + - **replace**: if True, and the document already is in the index, + we perform an update and reindex the document + - **partial**: if True, the fields specified will be added to the + existing document. + This has the added benefit that any fields specified + with `no_index` + will not be reindexed again. Implies `replace` + - **language**: Specify the language used for document tokenization. + - **no_create**: if True, the document is only updated and reindexed + if it already exists. + If the document does not exist, an error will be + returned. Implies `replace` + - **fields** kwargs dictionary of the document fields to be saved + and/or indexed. + NOTE: Geo points shoule be encoded as strings of "lon,lat" + """ # noqa + return self._add_document( + doc_id, + conn=None, + nosave=nosave, + score=score, + payload=payload, + replace=replace, + partial=partial, + language=language, + no_create=no_create, + **fields, + ) + + @deprecated_function( + version="2.0.0", reason="deprecated since redisearch 2.0, call hset instead" + ) + def add_document_hash(self, doc_id, score=1.0, language=None, replace=False): + """ + Add a hash document to the index. + + ### Parameters + + - **doc_id**: the document's id. This has to be an existing HASH key + in Redis that will hold the fields the index needs. + - **score**: the document ranking, between 0.0 and 1.0 + - **replace**: if True, and the document already is in the index, we + perform an update and reindex the document + - **language**: Specify the language used for document tokenization. + """ # noqa + return self._add_document_hash( + doc_id, conn=None, score=score, language=language, replace=replace + ) + + def delete_document(self, doc_id, conn=None, delete_actual_document=False): + """ + Delete a document from index + Returns 1 if the document was deleted, 0 if not + + ### Parameters + + - **delete_actual_document**: if set to True, RediSearch also delete + the actual document if it is in the index + """ # noqa + args = [DEL_CMD, self.index_name, doc_id] + if delete_actual_document: + args.append("DD") + + if conn is not None: + return conn.execute_command(*args) + + return self.execute_command(*args) + + def load_document(self, id): + """ + Load a single document by id + """ + fields = self.client.hgetall(id) + f2 = {to_string(k): to_string(v) for k, v in fields.items()} + fields = f2 + + try: + del fields["id"] + except KeyError: + pass + + return Document(id=id, **fields) + + def get(self, *ids): + """ + Returns the full contents of multiple documents. + + ### Parameters + + - **ids**: the ids of the saved documents. + + """ + + return self.execute_command(MGET_CMD, self.index_name, *ids) + + def info(self): + """ + Get info an stats about the the current index, including the number of + documents, memory consumption, etc + + For more information see `FT.INFO `_. + """ + + res = self.execute_command(INFO_CMD, self.index_name) + return self._parse_results(INFO_CMD, res) + + def get_params_args( + self, query_params: Union[Dict[str, Union[str, int, float, bytes]], None] + ): + if query_params is None: + return [] + args = [] + if len(query_params) > 0: + args.append("params") + args.append(len(query_params) * 2) + for key, value in query_params.items(): + args.append(key) + args.append(value) + return args + + def _mk_query_args( + self, query, query_params: Union[Dict[str, Union[str, int, float, bytes]], None] + ): + args = [self.index_name] + + if isinstance(query, str): + # convert the query from a text to a query object + query = Query(query) + if not isinstance(query, Query): + raise ValueError(f"Bad query type {type(query)}") + + args += query.get_args() + args += self.get_params_args(query_params) + + return args, query + + def search( + self, + query: Union[str, Query], + query_params: Union[Dict[str, Union[str, int, float, bytes]], None] = None, + ): + """ + Search the index for a given query, and return a result of documents + + ### Parameters + + - **query**: the search query. Either a text for simple queries with + default parameters, or a Query object for complex queries. + See RediSearch's documentation on query format + + For more information see `FT.SEARCH `_. + """ # noqa + args, query = self._mk_query_args(query, query_params=query_params) + st = time.time() + res = self.execute_command(SEARCH_CMD, *args) + + if isinstance(res, Pipeline): + return res + + return self._parse_results( + SEARCH_CMD, res, query=query, duration=(time.time() - st) * 1000.0 + ) + + def explain( + self, + query: Union[str, Query], + query_params: Dict[str, Union[str, int, float]] = None, + ): + """Returns the execution plan for a complex query. + + For more information see `FT.EXPLAIN `_. + """ # noqa + args, query_text = self._mk_query_args(query, query_params=query_params) + return self.execute_command(EXPLAIN_CMD, *args) + + def explain_cli(self, query: Union[str, Query]): # noqa + raise NotImplementedError("EXPLAINCLI will not be implemented.") + + def aggregate( + self, + query: Union[str, Query], + query_params: Dict[str, Union[str, int, float]] = None, + ): + """ + Issue an aggregation query. + + ### Parameters + + **query**: This can be either an `AggregateRequest`, or a `Cursor` + + An `AggregateResult` object is returned. You can access the rows from + its `rows` property, which will always yield the rows of the result. + + For more information see `FT.AGGREGATE `_. + """ # noqa + if isinstance(query, AggregateRequest): + has_cursor = bool(query._cursor) + cmd = [AGGREGATE_CMD, self.index_name] + query.build_args() + elif isinstance(query, Cursor): + has_cursor = True + cmd = [CURSOR_CMD, "READ", self.index_name] + query.build_args() + else: + raise ValueError("Bad query", query) + cmd += self.get_params_args(query_params) + + raw = self.execute_command(*cmd) + return self._parse_results( + AGGREGATE_CMD, raw, query=query, has_cursor=has_cursor + ) + + def _get_aggregate_result( + self, raw: List, query: Union[str, Query, AggregateRequest], has_cursor: bool + ): + if has_cursor: + if isinstance(query, Cursor): + query.cid = raw[1] + cursor = query + else: + cursor = Cursor(raw[1]) + raw = raw[0] + else: + cursor = None + + if isinstance(query, AggregateRequest) and query._with_schema: + schema = raw[0] + rows = raw[2:] + else: + schema = None + rows = raw[1:] + + return AggregateResult(rows, cursor, schema) + + def profile( + self, + query: Union[str, Query, AggregateRequest], + limited: bool = False, + query_params: Optional[Dict[str, Union[str, int, float]]] = None, + ): + """ + Performs a search or aggregate command and collects performance + information. + + ### Parameters + + **query**: This can be either an `AggregateRequest`, `Query` or string. + **limited**: If set to True, removes details of reader iterator. + **query_params**: Define one or more value parameters. + Each parameter has a name and a value. + + """ + st = time.time() + cmd = [PROFILE_CMD, self.index_name, ""] + if limited: + cmd.append("LIMITED") + cmd.append("QUERY") + + if isinstance(query, AggregateRequest): + cmd[2] = "AGGREGATE" + cmd += query.build_args() + elif isinstance(query, Query): + cmd[2] = "SEARCH" + cmd += query.get_args() + cmd += self.get_params_args(query_params) + else: + raise ValueError("Must provide AggregateRequest object or Query object.") + + res = self.execute_command(*cmd) + + return self._parse_results( + PROFILE_CMD, res, query=query, duration=(time.time() - st) * 1000.0 + ) + + def spellcheck(self, query, distance=None, include=None, exclude=None): + """ + Issue a spellcheck query + + ### Parameters + + **query**: search query. + **distance***: the maximal Levenshtein distance for spelling + suggestions (default: 1, max: 4). + **include**: specifies an inclusion custom dictionary. + **exclude**: specifies an exclusion custom dictionary. + + For more information see `FT.SPELLCHECK `_. + """ # noqa + cmd = [SPELLCHECK_CMD, self.index_name, query] + if distance: + cmd.extend(["DISTANCE", distance]) + + if include: + cmd.extend(["TERMS", "INCLUDE", include]) + + if exclude: + cmd.extend(["TERMS", "EXCLUDE", exclude]) + + res = self.execute_command(*cmd) + + return self._parse_results(SPELLCHECK_CMD, res) + + def dict_add(self, name: str, *terms: List[str]): + """Adds terms to a dictionary. + + ### Parameters + + - **name**: Dictionary name. + - **terms**: List of items for adding to the dictionary. + + For more information see `FT.DICTADD `_. + """ # noqa + cmd = [DICT_ADD_CMD, name] + cmd.extend(terms) + return self.execute_command(*cmd) + + def dict_del(self, name: str, *terms: List[str]): + """Deletes terms from a dictionary. + + ### Parameters + + - **name**: Dictionary name. + - **terms**: List of items for removing from the dictionary. + + For more information see `FT.DICTDEL `_. + """ # noqa + cmd = [DICT_DEL_CMD, name] + cmd.extend(terms) + return self.execute_command(*cmd) + + def dict_dump(self, name: str): + """Dumps all terms in the given dictionary. + + ### Parameters + + - **name**: Dictionary name. + + For more information see `FT.DICTDUMP `_. + """ # noqa + cmd = [DICT_DUMP_CMD, name] + return self.execute_command(*cmd) + + def config_set(self, option: str, value: str) -> bool: + """Set runtime configuration option. + + ### Parameters + + - **option**: the name of the configuration option. + - **value**: a value for the configuration option. + + For more information see `FT.CONFIG SET `_. + """ # noqa + cmd = [CONFIG_CMD, "SET", option, value] + raw = self.execute_command(*cmd) + return raw == "OK" + + def config_get(self, option: str) -> str: + """Get runtime configuration option value. + + ### Parameters + + - **option**: the name of the configuration option. + + For more information see `FT.CONFIG GET `_. + """ # noqa + cmd = [CONFIG_CMD, "GET", option] + res = self.execute_command(*cmd) + return self._parse_results(CONFIG_CMD, res) + + def tagvals(self, tagfield: str): + """ + Return a list of all possible tag values + + ### Parameters + + - **tagfield**: Tag field name + + For more information see `FT.TAGVALS `_. + """ # noqa + + return self.execute_command(TAGVALS_CMD, self.index_name, tagfield) + + def aliasadd(self, alias: str): + """ + Alias a search index - will fail if alias already exists + + ### Parameters + + - **alias**: Name of the alias to create + + For more information see `FT.ALIASADD `_. + """ # noqa + + return self.execute_command(ALIAS_ADD_CMD, alias, self.index_name) + + def aliasupdate(self, alias: str): + """ + Updates an alias - will fail if alias does not already exist + + ### Parameters + + - **alias**: Name of the alias to create + + For more information see `FT.ALIASUPDATE `_. + """ # noqa + + return self.execute_command(ALIAS_UPDATE_CMD, alias, self.index_name) + + def aliasdel(self, alias: str): + """ + Removes an alias to a search index + + ### Parameters + + - **alias**: Name of the alias to delete + + For more information see `FT.ALIASDEL `_. + """ # noqa + return self.execute_command(ALIAS_DEL_CMD, alias) + + def sugadd(self, key, *suggestions, **kwargs): + """ + Add suggestion terms to the AutoCompleter engine. Each suggestion has + a score and string. + If kwargs["increment"] is true and the terms are already in the + server's dictionary, we increment their scores. + + For more information see `FT.SUGADD `_. + """ # noqa + # If Transaction is not False it will MULTI/EXEC which will error + pipe = self.pipeline(transaction=False) + for sug in suggestions: + args = [SUGADD_COMMAND, key, sug.string, sug.score] + if kwargs.get("increment"): + args.append("INCR") + if sug.payload: + args.append("PAYLOAD") + args.append(sug.payload) + + pipe.execute_command(*args) + + return pipe.execute()[-1] + + def suglen(self, key: str) -> int: + """ + Return the number of entries in the AutoCompleter index. + + For more information see `FT.SUGLEN `_. + """ # noqa + return self.execute_command(SUGLEN_COMMAND, key) + + def sugdel(self, key: str, string: str) -> int: + """ + Delete a string from the AutoCompleter index. + Returns 1 if the string was found and deleted, 0 otherwise. + + For more information see `FT.SUGDEL `_. + """ # noqa + return self.execute_command(SUGDEL_COMMAND, key, string) + + def sugget( + self, + key: str, + prefix: str, + fuzzy: bool = False, + num: int = 10, + with_scores: bool = False, + with_payloads: bool = False, + ) -> List[SuggestionParser]: + """ + Get a list of suggestions from the AutoCompleter, for a given prefix. + + Parameters: + + prefix : str + The prefix we are searching. **Must be valid ascii or utf-8** + fuzzy : bool + If set to true, the prefix search is done in fuzzy mode. + **NOTE**: Running fuzzy searches on short (<3 letters) prefixes + can be very + slow, and even scan the entire index. + with_scores : bool + If set to true, we also return the (refactored) score of + each suggestion. + This is normally not needed, and is NOT the original score + inserted into the index. + with_payloads : bool + Return suggestion payloads + num : int + The maximum number of results we return. Note that we might + return less. The algorithm trims irrelevant suggestions. + + Returns: + + list: + A list of Suggestion objects. If with_scores was False, the + score of all suggestions is 1. + + For more information see `FT.SUGGET `_. + """ # noqa + args = [SUGGET_COMMAND, key, prefix, "MAX", num] + if fuzzy: + args.append(FUZZY) + if with_scores: + args.append(WITHSCORES) + if with_payloads: + args.append(WITHPAYLOADS) + + res = self.execute_command(*args) + results = [] + if not res: + return results + + parser = SuggestionParser(with_scores, with_payloads, res) + return [s for s in parser] + + def synupdate(self, groupid: str, skipinitial: bool = False, *terms: List[str]): + """ + Updates a synonym group. + The command is used to create or update a synonym group with + additional terms. + Only documents which were indexed after the update will be affected. + + Parameters: + + groupid : + Synonym group id. + skipinitial : bool + If set to true, we do not scan and index. + terms : + The terms. + + For more information see `FT.SYNUPDATE `_. + """ # noqa + cmd = [SYNUPDATE_CMD, self.index_name, groupid] + if skipinitial: + cmd.extend(["SKIPINITIALSCAN"]) + cmd.extend(terms) + return self.execute_command(*cmd) + + def syndump(self): + """ + Dumps the contents of a synonym group. + + The command is used to dump the synonyms data structure. + Returns a list of synonym terms and their synonym group ids. + + For more information see `FT.SYNDUMP `_. + """ # noqa + res = self.execute_command(SYNDUMP_CMD, self.index_name) + return self._parse_results(SYNDUMP_CMD, res) + + +class AsyncSearchCommands(SearchCommands): + async def info(self): + """ + Get info an stats about the the current index, including the number of + documents, memory consumption, etc + + For more information see `FT.INFO `_. + """ + + res = await self.execute_command(INFO_CMD, self.index_name) + return self._parse_results(INFO_CMD, res) + + async def search( + self, + query: Union[str, Query], + query_params: Dict[str, Union[str, int, float]] = None, + ): + """ + Search the index for a given query, and return a result of documents + + ### Parameters + + - **query**: the search query. Either a text for simple queries with + default parameters, or a Query object for complex queries. + See RediSearch's documentation on query format + + For more information see `FT.SEARCH `_. + """ # noqa + args, query = self._mk_query_args(query, query_params=query_params) + st = time.time() + res = await self.execute_command(SEARCH_CMD, *args) + + if isinstance(res, Pipeline): + return res + + return self._parse_results( + SEARCH_CMD, res, query=query, duration=(time.time() - st) * 1000.0 + ) + + async def aggregate( + self, + query: Union[str, Query], + query_params: Dict[str, Union[str, int, float]] = None, + ): + """ + Issue an aggregation query. + + ### Parameters + + **query**: This can be either an `AggregateRequest`, or a `Cursor` + + An `AggregateResult` object is returned. You can access the rows from + its `rows` property, which will always yield the rows of the result. + + For more information see `FT.AGGREGATE `_. + """ # noqa + if isinstance(query, AggregateRequest): + has_cursor = bool(query._cursor) + cmd = [AGGREGATE_CMD, self.index_name] + query.build_args() + elif isinstance(query, Cursor): + has_cursor = True + cmd = [CURSOR_CMD, "READ", self.index_name] + query.build_args() + else: + raise ValueError("Bad query", query) + cmd += self.get_params_args(query_params) + + raw = await self.execute_command(*cmd) + return self._parse_results( + AGGREGATE_CMD, raw, query=query, has_cursor=has_cursor + ) + + async def spellcheck(self, query, distance=None, include=None, exclude=None): + """ + Issue a spellcheck query + + ### Parameters + + **query**: search query. + **distance***: the maximal Levenshtein distance for spelling + suggestions (default: 1, max: 4). + **include**: specifies an inclusion custom dictionary. + **exclude**: specifies an exclusion custom dictionary. + + For more information see `FT.SPELLCHECK `_. + """ # noqa + cmd = [SPELLCHECK_CMD, self.index_name, query] + if distance: + cmd.extend(["DISTANCE", distance]) + + if include: + cmd.extend(["TERMS", "INCLUDE", include]) + + if exclude: + cmd.extend(["TERMS", "EXCLUDE", exclude]) + + res = await self.execute_command(*cmd) + + return self._parse_results(SPELLCHECK_CMD, res) + + async def config_set(self, option: str, value: str) -> bool: + """Set runtime configuration option. + + ### Parameters + + - **option**: the name of the configuration option. + - **value**: a value for the configuration option. + + For more information see `FT.CONFIG SET `_. + """ # noqa + cmd = [CONFIG_CMD, "SET", option, value] + raw = await self.execute_command(*cmd) + return raw == "OK" + + async def config_get(self, option: str) -> str: + """Get runtime configuration option value. + + ### Parameters + + - **option**: the name of the configuration option. + + For more information see `FT.CONFIG GET `_. + """ # noqa + cmd = [CONFIG_CMD, "GET", option] + res = {} + res = await self.execute_command(*cmd) + return self._parse_results(CONFIG_CMD, res) + + async def load_document(self, id): + """ + Load a single document by id + """ + fields = await self.client.hgetall(id) + f2 = {to_string(k): to_string(v) for k, v in fields.items()} + fields = f2 + + try: + del fields["id"] + except KeyError: + pass + + return Document(id=id, **fields) + + async def sugadd(self, key, *suggestions, **kwargs): + """ + Add suggestion terms to the AutoCompleter engine. Each suggestion has + a score and string. + If kwargs["increment"] is true and the terms are already in the + server's dictionary, we increment their scores. + + For more information see `FT.SUGADD `_. + """ # noqa + # If Transaction is not False it will MULTI/EXEC which will error + pipe = self.pipeline(transaction=False) + for sug in suggestions: + args = [SUGADD_COMMAND, key, sug.string, sug.score] + if kwargs.get("increment"): + args.append("INCR") + if sug.payload: + args.append("PAYLOAD") + args.append(sug.payload) + + pipe.execute_command(*args) + + return (await pipe.execute())[-1] + + async def sugget( + self, + key: str, + prefix: str, + fuzzy: bool = False, + num: int = 10, + with_scores: bool = False, + with_payloads: bool = False, + ) -> List[SuggestionParser]: + """ + Get a list of suggestions from the AutoCompleter, for a given prefix. + + Parameters: + + prefix : str + The prefix we are searching. **Must be valid ascii or utf-8** + fuzzy : bool + If set to true, the prefix search is done in fuzzy mode. + **NOTE**: Running fuzzy searches on short (<3 letters) prefixes + can be very + slow, and even scan the entire index. + with_scores : bool + If set to true, we also return the (refactored) score of + each suggestion. + This is normally not needed, and is NOT the original score + inserted into the index. + with_payloads : bool + Return suggestion payloads + num : int + The maximum number of results we return. Note that we might + return less. The algorithm trims irrelevant suggestions. + + Returns: + + list: + A list of Suggestion objects. If with_scores was False, the + score of all suggestions is 1. + + For more information see `FT.SUGGET `_. + """ # noqa + args = [SUGGET_COMMAND, key, prefix, "MAX", num] + if fuzzy: + args.append(FUZZY) + if with_scores: + args.append(WITHSCORES) + if with_payloads: + args.append(WITHPAYLOADS) + + ret = await self.execute_command(*args) + results = [] + if not ret: + return results + + parser = SuggestionParser(with_scores, with_payloads, ret) + return [s for s in parser] diff --git a/.venv/Lib/site-packages/redis/commands/search/document.py b/.venv/Lib/site-packages/redis/commands/search/document.py new file mode 100644 index 00000000..47534ec2 --- /dev/null +++ b/.venv/Lib/site-packages/redis/commands/search/document.py @@ -0,0 +1,17 @@ +class Document: + """ + Represents a single document in a result set + """ + + def __init__(self, id, payload=None, **fields): + self.id = id + self.payload = payload + for k, v in fields.items(): + setattr(self, k, v) + + def __repr__(self): + return f"Document {self.__dict__}" + + def __getitem__(self, item): + value = getattr(self, item) + return value diff --git a/.venv/Lib/site-packages/redis/commands/search/field.py b/.venv/Lib/site-packages/redis/commands/search/field.py new file mode 100644 index 00000000..76eb58c2 --- /dev/null +++ b/.venv/Lib/site-packages/redis/commands/search/field.py @@ -0,0 +1,168 @@ +from typing import List + +from redis import DataError + + +class Field: + NUMERIC = "NUMERIC" + TEXT = "TEXT" + WEIGHT = "WEIGHT" + GEO = "GEO" + TAG = "TAG" + VECTOR = "VECTOR" + SORTABLE = "SORTABLE" + NOINDEX = "NOINDEX" + AS = "AS" + + def __init__( + self, + name: str, + args: List[str] = None, + sortable: bool = False, + no_index: bool = False, + as_name: str = None, + ): + if args is None: + args = [] + self.name = name + self.args = args + self.args_suffix = list() + self.as_name = as_name + + if sortable: + self.args_suffix.append(Field.SORTABLE) + if no_index: + self.args_suffix.append(Field.NOINDEX) + + if no_index and not sortable: + raise ValueError("Non-Sortable non-Indexable fields are ignored") + + def append_arg(self, value): + self.args.append(value) + + def redis_args(self): + args = [self.name] + if self.as_name: + args += [self.AS, self.as_name] + args += self.args + args += self.args_suffix + return args + + +class TextField(Field): + """ + TextField is used to define a text field in a schema definition + """ + + NOSTEM = "NOSTEM" + PHONETIC = "PHONETIC" + + def __init__( + self, + name: str, + weight: float = 1.0, + no_stem: bool = False, + phonetic_matcher: str = None, + withsuffixtrie: bool = False, + **kwargs, + ): + Field.__init__(self, name, args=[Field.TEXT, Field.WEIGHT, weight], **kwargs) + + if no_stem: + Field.append_arg(self, self.NOSTEM) + if phonetic_matcher and phonetic_matcher in [ + "dm:en", + "dm:fr", + "dm:pt", + "dm:es", + ]: + Field.append_arg(self, self.PHONETIC) + Field.append_arg(self, phonetic_matcher) + if withsuffixtrie: + Field.append_arg(self, "WITHSUFFIXTRIE") + + +class NumericField(Field): + """ + NumericField is used to define a numeric field in a schema definition + """ + + def __init__(self, name: str, **kwargs): + Field.__init__(self, name, args=[Field.NUMERIC], **kwargs) + + +class GeoField(Field): + """ + GeoField is used to define a geo-indexing field in a schema definition + """ + + def __init__(self, name: str, **kwargs): + Field.__init__(self, name, args=[Field.GEO], **kwargs) + + +class TagField(Field): + """ + TagField is a tag-indexing field with simpler compression and tokenization. + See http://redisearch.io/Tags/ + """ + + SEPARATOR = "SEPARATOR" + CASESENSITIVE = "CASESENSITIVE" + + def __init__( + self, + name: str, + separator: str = ",", + case_sensitive: bool = False, + withsuffixtrie: bool = False, + **kwargs, + ): + args = [Field.TAG, self.SEPARATOR, separator] + if case_sensitive: + args.append(self.CASESENSITIVE) + if withsuffixtrie: + args.append("WITHSUFFIXTRIE") + + Field.__init__(self, name, args=args, **kwargs) + + +class VectorField(Field): + """ + Allows vector similarity queries against the value in this attribute. + See https://oss.redis.com/redisearch/Vectors/#vector_fields. + """ + + def __init__(self, name: str, algorithm: str, attributes: dict, **kwargs): + """ + Create Vector Field. Notice that Vector cannot have sortable or no_index tag, + although it's also a Field. + + ``name`` is the name of the field. + + ``algorithm`` can be "FLAT" or "HNSW". + + ``attributes`` each algorithm can have specific attributes. Some of them + are mandatory and some of them are optional. See + https://oss.redis.com/redisearch/master/Vectors/#specific_creation_attributes_per_algorithm + for more information. + """ + sort = kwargs.get("sortable", False) + noindex = kwargs.get("no_index", False) + + if sort or noindex: + raise DataError("Cannot set 'sortable' or 'no_index' in Vector fields.") + + if algorithm.upper() not in ["FLAT", "HNSW"]: + raise DataError( + "Realtime vector indexing supporting 2 Indexing Methods:" + "'FLAT' and 'HNSW'." + ) + + attr_li = [] + + for key, value in attributes.items(): + attr_li.extend([key, value]) + + Field.__init__( + self, name, args=[Field.VECTOR, algorithm, len(attr_li), *attr_li], **kwargs + ) diff --git a/.venv/Lib/site-packages/redis/commands/search/indexDefinition.py b/.venv/Lib/site-packages/redis/commands/search/indexDefinition.py new file mode 100644 index 00000000..a668e85b --- /dev/null +++ b/.venv/Lib/site-packages/redis/commands/search/indexDefinition.py @@ -0,0 +1,79 @@ +from enum import Enum + + +class IndexType(Enum): + """Enum of the currently supported index types.""" + + HASH = 1 + JSON = 2 + + +class IndexDefinition: + """IndexDefinition is used to define a index definition for automatic + indexing on Hash or Json update.""" + + def __init__( + self, + prefix=[], + filter=None, + language_field=None, + language=None, + score_field=None, + score=1.0, + payload_field=None, + index_type=None, + ): + self.args = [] + self._append_index_type(index_type) + self._append_prefix(prefix) + self._append_filter(filter) + self._append_language(language_field, language) + self._append_score(score_field, score) + self._append_payload(payload_field) + + def _append_index_type(self, index_type): + """Append `ON HASH` or `ON JSON` according to the enum.""" + if index_type is IndexType.HASH: + self.args.extend(["ON", "HASH"]) + elif index_type is IndexType.JSON: + self.args.extend(["ON", "JSON"]) + elif index_type is not None: + raise RuntimeError(f"index_type must be one of {list(IndexType)}") + + def _append_prefix(self, prefix): + """Append PREFIX.""" + if len(prefix) > 0: + self.args.append("PREFIX") + self.args.append(len(prefix)) + for p in prefix: + self.args.append(p) + + def _append_filter(self, filter): + """Append FILTER.""" + if filter is not None: + self.args.append("FILTER") + self.args.append(filter) + + def _append_language(self, language_field, language): + """Append LANGUAGE_FIELD and LANGUAGE.""" + if language_field is not None: + self.args.append("LANGUAGE_FIELD") + self.args.append(language_field) + if language is not None: + self.args.append("LANGUAGE") + self.args.append(language) + + def _append_score(self, score_field, score): + """Append SCORE_FIELD and SCORE.""" + if score_field is not None: + self.args.append("SCORE_FIELD") + self.args.append(score_field) + if score is not None: + self.args.append("SCORE") + self.args.append(score) + + def _append_payload(self, payload_field): + """Append PAYLOAD_FIELD.""" + if payload_field is not None: + self.args.append("PAYLOAD_FIELD") + self.args.append(payload_field) diff --git a/.venv/Lib/site-packages/redis/commands/search/query.py b/.venv/Lib/site-packages/redis/commands/search/query.py new file mode 100644 index 00000000..113ddf9d --- /dev/null +++ b/.venv/Lib/site-packages/redis/commands/search/query.py @@ -0,0 +1,362 @@ +from typing import List, Optional, Union + + +class Query: + """ + Query is used to build complex queries that have more parameters than just + the query string. The query string is set in the constructor, and other + options have setter functions. + + The setter functions return the query object, so they can be chained, + i.e. `Query("foo").verbatim().filter(...)` etc. + """ + + def __init__(self, query_string: str) -> None: + """ + Create a new query object. + The query string is set in the constructor, and other options have + setter functions. + """ + + self._query_string: str = query_string + self._offset: int = 0 + self._num: int = 10 + self._no_content: bool = False + self._no_stopwords: bool = False + self._fields: Optional[List[str]] = None + self._verbatim: bool = False + self._with_payloads: bool = False + self._with_scores: bool = False + self._scorer: Optional[str] = None + self._filters: List = list() + self._ids: Optional[List[str]] = None + self._slop: int = -1 + self._timeout: Optional[float] = None + self._in_order: bool = False + self._sortby: Optional[SortbyField] = None + self._return_fields: List = [] + self._summarize_fields: List = [] + self._highlight_fields: List = [] + self._language: Optional[str] = None + self._expander: Optional[str] = None + self._dialect: Optional[int] = None + + def query_string(self) -> str: + """Return the query string of this query only.""" + return self._query_string + + def limit_ids(self, *ids) -> "Query": + """Limit the results to a specific set of pre-known document + ids of any length.""" + self._ids = ids + return self + + def return_fields(self, *fields) -> "Query": + """Add fields to return fields.""" + self._return_fields += fields + return self + + def return_field(self, field: str, as_field: Optional[str] = None) -> "Query": + """Add field to return fields (Optional: add 'AS' name + to the field).""" + self._return_fields.append(field) + if as_field is not None: + self._return_fields += ("AS", as_field) + return self + + def _mk_field_list(self, fields: List[str]) -> List: + if not fields: + return [] + return [fields] if isinstance(fields, str) else list(fields) + + def summarize( + self, + fields: Optional[List] = None, + context_len: Optional[int] = None, + num_frags: Optional[int] = None, + sep: Optional[str] = None, + ) -> "Query": + """ + Return an abridged format of the field, containing only the segments of + the field which contain the matching term(s). + + If `fields` is specified, then only the mentioned fields are + summarized; otherwise all results are summarized. + + Server side defaults are used for each option (except `fields`) + if not specified + + - **fields** List of fields to summarize. All fields are summarized + if not specified + - **context_len** Amount of context to include with each fragment + - **num_frags** Number of fragments per document + - **sep** Separator string to separate fragments + """ + args = ["SUMMARIZE"] + fields = self._mk_field_list(fields) + if fields: + args += ["FIELDS", str(len(fields))] + fields + + if context_len is not None: + args += ["LEN", str(context_len)] + if num_frags is not None: + args += ["FRAGS", str(num_frags)] + if sep is not None: + args += ["SEPARATOR", sep] + + self._summarize_fields = args + return self + + def highlight( + self, fields: Optional[List[str]] = None, tags: Optional[List[str]] = None + ) -> None: + """ + Apply specified markup to matched term(s) within the returned field(s). + + - **fields** If specified then only those mentioned fields are + highlighted, otherwise all fields are highlighted + - **tags** A list of two strings to surround the match. + """ + args = ["HIGHLIGHT"] + fields = self._mk_field_list(fields) + if fields: + args += ["FIELDS", str(len(fields))] + fields + if tags: + args += ["TAGS"] + list(tags) + + self._highlight_fields = args + return self + + def language(self, language: str) -> "Query": + """ + Analyze the query as being in the specified language. + + :param language: The language (e.g. `chinese` or `english`) + """ + self._language = language + return self + + def slop(self, slop: int) -> "Query": + """Allow a maximum of N intervening non matched terms between + phrase terms (0 means exact phrase). + """ + self._slop = slop + return self + + def timeout(self, timeout: float) -> "Query": + """overrides the timeout parameter of the module""" + self._timeout = timeout + return self + + def in_order(self) -> "Query": + """ + Match only documents where the query terms appear in + the same order in the document. + i.e. for the query "hello world", we do not match "world hello" + """ + self._in_order = True + return self + + def scorer(self, scorer: str) -> "Query": + """ + Use a different scoring function to evaluate document relevance. + Default is `TFIDF`. + + :param scorer: The scoring function to use + (e.g. `TFIDF.DOCNORM` or `BM25`) + """ + self._scorer = scorer + return self + + def get_args(self) -> List[str]: + """Format the redis arguments for this query and return them.""" + args = [self._query_string] + args += self._get_args_tags() + args += self._summarize_fields + self._highlight_fields + args += ["LIMIT", self._offset, self._num] + return args + + def _get_args_tags(self) -> List[str]: + args = [] + if self._no_content: + args.append("NOCONTENT") + if self._fields: + args.append("INFIELDS") + args.append(len(self._fields)) + args += self._fields + if self._verbatim: + args.append("VERBATIM") + if self._no_stopwords: + args.append("NOSTOPWORDS") + if self._filters: + for flt in self._filters: + if not isinstance(flt, Filter): + raise AttributeError("Did not receive a Filter object.") + args += flt.args + if self._with_payloads: + args.append("WITHPAYLOADS") + if self._scorer: + args += ["SCORER", self._scorer] + if self._with_scores: + args.append("WITHSCORES") + if self._ids: + args.append("INKEYS") + args.append(len(self._ids)) + args += self._ids + if self._slop >= 0: + args += ["SLOP", self._slop] + if self._timeout is not None: + args += ["TIMEOUT", self._timeout] + if self._in_order: + args.append("INORDER") + if self._return_fields: + args.append("RETURN") + args.append(len(self._return_fields)) + args += self._return_fields + if self._sortby: + if not isinstance(self._sortby, SortbyField): + raise AttributeError("Did not receive a SortByField.") + args.append("SORTBY") + args += self._sortby.args + if self._language: + args += ["LANGUAGE", self._language] + if self._expander: + args += ["EXPANDER", self._expander] + if self._dialect: + args += ["DIALECT", self._dialect] + + return args + + def paging(self, offset: int, num: int) -> "Query": + """ + Set the paging for the query (defaults to 0..10). + + - **offset**: Paging offset for the results. Defaults to 0 + - **num**: How many results do we want + """ + self._offset = offset + self._num = num + return self + + def verbatim(self) -> "Query": + """Set the query to be verbatim, i.e. use no query expansion + or stemming. + """ + self._verbatim = True + return self + + def no_content(self) -> "Query": + """Set the query to only return ids and not the document content.""" + self._no_content = True + return self + + def no_stopwords(self) -> "Query": + """ + Prevent the query from being filtered for stopwords. + Only useful in very big queries that you are certain contain + no stopwords. + """ + self._no_stopwords = True + return self + + def with_payloads(self) -> "Query": + """Ask the engine to return document payloads.""" + self._with_payloads = True + return self + + def with_scores(self) -> "Query": + """Ask the engine to return document search scores.""" + self._with_scores = True + return self + + def limit_fields(self, *fields: List[str]) -> "Query": + """ + Limit the search to specific TEXT fields only. + + - **fields**: A list of strings, case sensitive field names + from the defined schema. + """ + self._fields = fields + return self + + def add_filter(self, flt: "Filter") -> "Query": + """ + Add a numeric or geo filter to the query. + **Currently only one of each filter is supported by the engine** + + - **flt**: A NumericFilter or GeoFilter object, used on a + corresponding field + """ + + self._filters.append(flt) + return self + + def sort_by(self, field: str, asc: bool = True) -> "Query": + """ + Add a sortby field to the query. + + - **field** - the name of the field to sort by + - **asc** - when `True`, sorting will be done in asceding order + """ + self._sortby = SortbyField(field, asc) + return self + + def expander(self, expander: str) -> "Query": + """ + Add a expander field to the query. + + - **expander** - the name of the expander + """ + self._expander = expander + return self + + def dialect(self, dialect: int) -> "Query": + """ + Add a dialect field to the query. + + - **dialect** - dialect version to execute the query under + """ + self._dialect = dialect + return self + + +class Filter: + def __init__(self, keyword: str, field: str, *args: List[str]) -> None: + self.args = [keyword, field] + list(args) + + +class NumericFilter(Filter): + INF = "+inf" + NEG_INF = "-inf" + + def __init__( + self, + field: str, + minval: Union[int, str], + maxval: Union[int, str], + minExclusive: bool = False, + maxExclusive: bool = False, + ) -> None: + args = [ + minval if not minExclusive else f"({minval}", + maxval if not maxExclusive else f"({maxval}", + ] + + Filter.__init__(self, "FILTER", field, *args) + + +class GeoFilter(Filter): + METERS = "m" + KILOMETERS = "km" + FEET = "ft" + MILES = "mi" + + def __init__( + self, field: str, lon: float, lat: float, radius: float, unit: str = KILOMETERS + ) -> None: + Filter.__init__(self, "GEOFILTER", field, lon, lat, radius, unit) + + +class SortbyField: + def __init__(self, field: str, asc=True) -> None: + self.args = [field, "ASC" if asc else "DESC"] diff --git a/.venv/Lib/site-packages/redis/commands/search/querystring.py b/.venv/Lib/site-packages/redis/commands/search/querystring.py new file mode 100644 index 00000000..3ff13209 --- /dev/null +++ b/.venv/Lib/site-packages/redis/commands/search/querystring.py @@ -0,0 +1,317 @@ +def tags(*t): + """ + Indicate that the values should be matched to a tag field + + ### Parameters + + - **t**: Tags to search for + """ + if not t: + raise ValueError("At least one tag must be specified") + return TagValue(*t) + + +def between(a, b, inclusive_min=True, inclusive_max=True): + """ + Indicate that value is a numeric range + """ + return RangeValue(a, b, inclusive_min=inclusive_min, inclusive_max=inclusive_max) + + +def equal(n): + """ + Match a numeric value + """ + return between(n, n) + + +def lt(n): + """ + Match any value less than n + """ + return between(None, n, inclusive_max=False) + + +def le(n): + """ + Match any value less or equal to n + """ + return between(None, n, inclusive_max=True) + + +def gt(n): + """ + Match any value greater than n + """ + return between(n, None, inclusive_min=False) + + +def ge(n): + """ + Match any value greater or equal to n + """ + return between(n, None, inclusive_min=True) + + +def geo(lat, lon, radius, unit="km"): + """ + Indicate that value is a geo region + """ + return GeoValue(lat, lon, radius, unit) + + +class Value: + @property + def combinable(self): + """ + Whether this type of value may be combined with other values + for the same field. This makes the filter potentially more efficient + """ + return False + + @staticmethod + def make_value(v): + """ + Convert an object to a value, if it is not a value already + """ + if isinstance(v, Value): + return v + return ScalarValue(v) + + def to_string(self): + raise NotImplementedError() + + def __str__(self): + return self.to_string() + + +class RangeValue(Value): + combinable = False + + def __init__(self, a, b, inclusive_min=False, inclusive_max=False): + if a is None: + a = "-inf" + if b is None: + b = "inf" + self.range = [str(a), str(b)] + self.inclusive_min = inclusive_min + self.inclusive_max = inclusive_max + + def to_string(self): + return "[{1}{0[0]} {2}{0[1]}]".format( + self.range, + "(" if not self.inclusive_min else "", + "(" if not self.inclusive_max else "", + ) + + +class ScalarValue(Value): + combinable = True + + def __init__(self, v): + self.v = str(v) + + def to_string(self): + return self.v + + +class TagValue(Value): + combinable = False + + def __init__(self, *tags): + self.tags = tags + + def to_string(self): + return "{" + " | ".join(str(t) for t in self.tags) + "}" + + +class GeoValue(Value): + def __init__(self, lon, lat, radius, unit="km"): + self.lon = lon + self.lat = lat + self.radius = radius + self.unit = unit + + def to_string(self): + return f"[{self.lon} {self.lat} {self.radius} {self.unit}]" + + +class Node: + def __init__(self, *children, **kwparams): + """ + Create a node + + ### Parameters + + - **children**: One or more sub-conditions. These can be additional + `intersect`, `disjunct`, `union`, `optional`, or any other `Node` + type. + + The semantics of multiple conditions are dependent on the type of + query. For an `intersection` node, this amounts to a logical AND, + for a `union` node, this amounts to a logical `OR`. + + - **kwparams**: key-value parameters. Each key is the name of a field, + and the value should be a field value. This can be one of the + following: + + - Simple string (for text field matches) + - value returned by one of the helper functions + - list of either a string or a value + + + ### Examples + + Field `num` should be between 1 and 10 + ``` + intersect(num=between(1, 10) + ``` + + Name can either be `bob` or `john` + + ``` + union(name=("bob", "john")) + ``` + + Don't select countries in Israel, Japan, or US + + ``` + disjunct_union(country=("il", "jp", "us")) + ``` + """ + + self.params = [] + + kvparams = {} + for k, v in kwparams.items(): + curvals = kvparams.setdefault(k, []) + if isinstance(v, (str, int, float)): + curvals.append(Value.make_value(v)) + elif isinstance(v, Value): + curvals.append(v) + else: + curvals.extend(Value.make_value(subv) for subv in v) + + self.params += [Node.to_node(p) for p in children] + + for k, v in kvparams.items(): + self.params.extend(self.join_fields(k, v)) + + def join_fields(self, key, vals): + if len(vals) == 1: + return [BaseNode(f"@{key}:{vals[0].to_string()}")] + if not vals[0].combinable: + return [BaseNode(f"@{key}:{v.to_string()}") for v in vals] + s = BaseNode(f"@{key}:({self.JOINSTR.join(v.to_string() for v in vals)})") + return [s] + + @classmethod + def to_node(cls, obj): # noqa + if isinstance(obj, Node): + return obj + return BaseNode(obj) + + @property + def JOINSTR(self): + raise NotImplementedError() + + def to_string(self, with_parens=None): + with_parens = self._should_use_paren(with_parens) + pre, post = ("(", ")") if with_parens else ("", "") + return f"{pre}{self.JOINSTR.join(n.to_string() for n in self.params)}{post}" + + def _should_use_paren(self, optval): + if optval is not None: + return optval + return len(self.params) > 1 + + def __str__(self): + return self.to_string() + + +class BaseNode(Node): + def __init__(self, s): + super().__init__() + self.s = str(s) + + def to_string(self, with_parens=None): + return self.s + + +class IntersectNode(Node): + """ + Create an intersection node. All children need to be satisfied in order for + this node to evaluate as true + """ + + JOINSTR = " " + + +class UnionNode(Node): + """ + Create a union node. Any of the children need to be satisfied in order for + this node to evaluate as true + """ + + JOINSTR = "|" + + +class DisjunctNode(IntersectNode): + """ + Create a disjunct node. In order for this node to be true, all of its + children must evaluate to false + """ + + def to_string(self, with_parens=None): + with_parens = self._should_use_paren(with_parens) + ret = super().to_string(with_parens=False) + if with_parens: + return "(-" + ret + ")" + else: + return "-" + ret + + +class DistjunctUnion(DisjunctNode): + """ + This node is true if *all* of its children are false. This is equivalent to + ``` + disjunct(union(...)) + ``` + """ + + JOINSTR = "|" + + +class OptionalNode(IntersectNode): + """ + Create an optional node. If this nodes evaluates to true, then the document + will be rated higher in score/rank. + """ + + def to_string(self, with_parens=None): + with_parens = self._should_use_paren(with_parens) + ret = super().to_string(with_parens=False) + if with_parens: + return "(~" + ret + ")" + else: + return "~" + ret + + +def intersect(*args, **kwargs): + return IntersectNode(*args, **kwargs) + + +def union(*args, **kwargs): + return UnionNode(*args, **kwargs) + + +def disjunct(*args, **kwargs): + return DisjunctNode(*args, **kwargs) + + +def disjunct_union(*args, **kwargs): + return DistjunctUnion(*args, **kwargs) + + +def querystring(*args, **kwargs): + return intersect(*args, **kwargs).to_string() diff --git a/.venv/Lib/site-packages/redis/commands/search/reducers.py b/.venv/Lib/site-packages/redis/commands/search/reducers.py new file mode 100644 index 00000000..8b60f232 --- /dev/null +++ b/.venv/Lib/site-packages/redis/commands/search/reducers.py @@ -0,0 +1,182 @@ +from typing import Union + +from .aggregation import Asc, Desc, Reducer, SortDirection + + +class FieldOnlyReducer(Reducer): + """See https://redis.io/docs/interact/search-and-query/search/aggregations/""" + + def __init__(self, field: str) -> None: + super().__init__(field) + self._field = field + + +class count(Reducer): + """ + Counts the number of results in the group + """ + + NAME = "COUNT" + + def __init__(self) -> None: + super().__init__() + + +class sum(FieldOnlyReducer): + """ + Calculates the sum of all the values in the given fields within the group + """ + + NAME = "SUM" + + def __init__(self, field: str) -> None: + super().__init__(field) + + +class min(FieldOnlyReducer): + """ + Calculates the smallest value in the given field within the group + """ + + NAME = "MIN" + + def __init__(self, field: str) -> None: + super().__init__(field) + + +class max(FieldOnlyReducer): + """ + Calculates the largest value in the given field within the group + """ + + NAME = "MAX" + + def __init__(self, field: str) -> None: + super().__init__(field) + + +class avg(FieldOnlyReducer): + """ + Calculates the mean value in the given field within the group + """ + + NAME = "AVG" + + def __init__(self, field: str) -> None: + super().__init__(field) + + +class tolist(FieldOnlyReducer): + """ + Returns all the matched properties in a list + """ + + NAME = "TOLIST" + + def __init__(self, field: str) -> None: + super().__init__(field) + + +class count_distinct(FieldOnlyReducer): + """ + Calculate the number of distinct values contained in all the results in + the group for the given field + """ + + NAME = "COUNT_DISTINCT" + + def __init__(self, field: str) -> None: + super().__init__(field) + + +class count_distinctish(FieldOnlyReducer): + """ + Calculate the number of distinct values contained in all the results in the + group for the given field. This uses a faster algorithm than + `count_distinct` but is less accurate + """ + + NAME = "COUNT_DISTINCTISH" + + +class quantile(Reducer): + """ + Return the value for the nth percentile within the range of values for the + field within the group. + """ + + NAME = "QUANTILE" + + def __init__(self, field: str, pct: float) -> None: + super().__init__(field, str(pct)) + self._field = field + + +class stddev(FieldOnlyReducer): + """ + Return the standard deviation for the values within the group + """ + + NAME = "STDDEV" + + def __init__(self, field: str) -> None: + super().__init__(field) + + +class first_value(Reducer): + """ + Selects the first value within the group according to sorting parameters + """ + + NAME = "FIRST_VALUE" + + def __init__(self, field: str, *byfields: Union[Asc, Desc]) -> None: + """ + Selects the first value of the given field within the group. + + ### Parameter + + - **field**: Source field used for the value + - **byfields**: How to sort the results. This can be either the + *class* of `aggregation.Asc` or `aggregation.Desc` in which + case the field `field` is also used as the sort input. + + `byfields` can also be one or more *instances* of `Asc` or `Desc` + indicating the sort order for these fields + """ + + fieldstrs = [] + if ( + len(byfields) == 1 + and isinstance(byfields[0], type) + and issubclass(byfields[0], SortDirection) + ): + byfields = [byfields[0](field)] + + for f in byfields: + fieldstrs += [f.field, f.DIRSTRING] + + args = [field] + if fieldstrs: + args += ["BY"] + fieldstrs + super().__init__(*args) + self._field = field + + +class random_sample(Reducer): + """ + Returns a random sample of items from the dataset, from the given property + """ + + NAME = "RANDOM_SAMPLE" + + def __init__(self, field: str, size: int) -> None: + """ + ### Parameter + + **field**: Field to sample from + **size**: Return this many items (can be less) + """ + args = [field, str(size)] + super().__init__(*args) + self._field = field diff --git a/.venv/Lib/site-packages/redis/commands/search/result.py b/.venv/Lib/site-packages/redis/commands/search/result.py new file mode 100644 index 00000000..5b19e6fa --- /dev/null +++ b/.venv/Lib/site-packages/redis/commands/search/result.py @@ -0,0 +1,73 @@ +from ._util import to_string +from .document import Document + + +class Result: + """ + Represents the result of a search query, and has an array of Document + objects + """ + + def __init__( + self, res, hascontent, duration=0, has_payload=False, with_scores=False + ): + """ + - **snippets**: An optional dictionary of the form + {field: snippet_size} for snippet formatting + """ + + self.total = res[0] + self.duration = duration + self.docs = [] + + step = 1 + if hascontent: + step = step + 1 + if has_payload: + step = step + 1 + if with_scores: + step = step + 1 + + offset = 2 if with_scores else 1 + + for i in range(1, len(res), step): + id = to_string(res[i]) + payload = to_string(res[i + offset]) if has_payload else None + # fields_offset = 2 if has_payload else 1 + fields_offset = offset + 1 if has_payload else offset + score = float(res[i + 1]) if with_scores else None + + fields = {} + if hascontent and res[i + fields_offset] is not None: + fields = ( + dict( + dict( + zip( + map(to_string, res[i + fields_offset][::2]), + map(to_string, res[i + fields_offset][1::2]), + ) + ) + ) + if hascontent + else {} + ) + try: + del fields["id"] + except KeyError: + pass + + try: + fields["json"] = fields["$"] + del fields["$"] + except KeyError: + pass + + doc = ( + Document(id, score=score, payload=payload, **fields) + if with_scores + else Document(id, payload=payload, **fields) + ) + self.docs.append(doc) + + def __repr__(self) -> str: + return f"Result{{{self.total} total, docs: {self.docs}}}" diff --git a/.venv/Lib/site-packages/redis/commands/search/suggestion.py b/.venv/Lib/site-packages/redis/commands/search/suggestion.py new file mode 100644 index 00000000..499c8d91 --- /dev/null +++ b/.venv/Lib/site-packages/redis/commands/search/suggestion.py @@ -0,0 +1,55 @@ +from typing import Optional + +from ._util import to_string + + +class Suggestion: + """ + Represents a single suggestion being sent or returned from the + autocomplete server + """ + + def __init__( + self, string: str, score: float = 1.0, payload: Optional[str] = None + ) -> None: + self.string = to_string(string) + self.payload = to_string(payload) + self.score = score + + def __repr__(self) -> str: + return self.string + + +class SuggestionParser: + """ + Internal class used to parse results from the `SUGGET` command. + This needs to consume either 1, 2, or 3 values at a time from + the return value depending on what objects were requested + """ + + def __init__(self, with_scores: bool, with_payloads: bool, ret) -> None: + self.with_scores = with_scores + self.with_payloads = with_payloads + + if with_scores and with_payloads: + self.sugsize = 3 + self._scoreidx = 1 + self._payloadidx = 2 + elif with_scores: + self.sugsize = 2 + self._scoreidx = 1 + elif with_payloads: + self.sugsize = 2 + self._payloadidx = 1 + else: + self.sugsize = 1 + self._scoreidx = -1 + + self._sugs = ret + + def __iter__(self): + for i in range(0, len(self._sugs), self.sugsize): + ss = self._sugs[i] + score = float(self._sugs[i + self._scoreidx]) if self.with_scores else 1.0 + payload = self._sugs[i + self._payloadidx] if self.with_payloads else None + yield Suggestion(ss, score, payload) diff --git a/.venv/Lib/site-packages/redis/commands/sentinel.py b/.venv/Lib/site-packages/redis/commands/sentinel.py new file mode 100644 index 00000000..f7457579 --- /dev/null +++ b/.venv/Lib/site-packages/redis/commands/sentinel.py @@ -0,0 +1,99 @@ +import warnings + + +class SentinelCommands: + """ + A class containing the commands specific to redis sentinel. This class is + to be used as a mixin. + """ + + def sentinel(self, *args): + """Redis Sentinel's SENTINEL command.""" + warnings.warn(DeprecationWarning("Use the individual sentinel_* methods")) + + def sentinel_get_master_addr_by_name(self, service_name): + """Returns a (host, port) pair for the given ``service_name``""" + return self.execute_command("SENTINEL GET-MASTER-ADDR-BY-NAME", service_name) + + def sentinel_master(self, service_name): + """Returns a dictionary containing the specified masters state.""" + return self.execute_command("SENTINEL MASTER", service_name) + + def sentinel_masters(self): + """Returns a list of dictionaries containing each master's state.""" + return self.execute_command("SENTINEL MASTERS") + + def sentinel_monitor(self, name, ip, port, quorum): + """Add a new master to Sentinel to be monitored""" + return self.execute_command("SENTINEL MONITOR", name, ip, port, quorum) + + def sentinel_remove(self, name): + """Remove a master from Sentinel's monitoring""" + return self.execute_command("SENTINEL REMOVE", name) + + def sentinel_sentinels(self, service_name): + """Returns a list of sentinels for ``service_name``""" + return self.execute_command("SENTINEL SENTINELS", service_name) + + def sentinel_set(self, name, option, value): + """Set Sentinel monitoring parameters for a given master""" + return self.execute_command("SENTINEL SET", name, option, value) + + def sentinel_slaves(self, service_name): + """Returns a list of slaves for ``service_name``""" + return self.execute_command("SENTINEL SLAVES", service_name) + + def sentinel_reset(self, pattern): + """ + This command will reset all the masters with matching name. + The pattern argument is a glob-style pattern. + + The reset process clears any previous state in a master (including a + failover in progress), and removes every slave and sentinel already + discovered and associated with the master. + """ + return self.execute_command("SENTINEL RESET", pattern, once=True) + + def sentinel_failover(self, new_master_name): + """ + Force a failover as if the master was not reachable, and without + asking for agreement to other Sentinels (however a new version of the + configuration will be published so that the other Sentinels will + update their configurations). + """ + return self.execute_command("SENTINEL FAILOVER", new_master_name) + + def sentinel_ckquorum(self, new_master_name): + """ + Check if the current Sentinel configuration is able to reach the + quorum needed to failover a master, and the majority needed to + authorize the failover. + + This command should be used in monitoring systems to check if a + Sentinel deployment is ok. + """ + return self.execute_command("SENTINEL CKQUORUM", new_master_name, once=True) + + def sentinel_flushconfig(self): + """ + Force Sentinel to rewrite its configuration on disk, including the + current Sentinel state. + + Normally Sentinel rewrites the configuration every time something + changes in its state (in the context of the subset of the state which + is persisted on disk across restart). + However sometimes it is possible that the configuration file is lost + because of operation errors, disk failures, package upgrade scripts or + configuration managers. In those cases a way to to force Sentinel to + rewrite the configuration file is handy. + + This command works even if the previous configuration file is + completely missing. + """ + return self.execute_command("SENTINEL FLUSHCONFIG") + + +class AsyncSentinelCommands(SentinelCommands): + async def sentinel(self, *args) -> None: + """Redis Sentinel's SENTINEL command.""" + super().sentinel(*args) diff --git a/.venv/Lib/site-packages/redis/commands/timeseries/__init__.py b/.venv/Lib/site-packages/redis/commands/timeseries/__init__.py new file mode 100644 index 00000000..4188b93d --- /dev/null +++ b/.venv/Lib/site-packages/redis/commands/timeseries/__init__.py @@ -0,0 +1,108 @@ +import redis +from redis._parsers.helpers import bool_ok + +from ..helpers import get_protocol_version, parse_to_list +from .commands import ( + ALTER_CMD, + CREATE_CMD, + CREATERULE_CMD, + DEL_CMD, + DELETERULE_CMD, + GET_CMD, + INFO_CMD, + MGET_CMD, + MRANGE_CMD, + MREVRANGE_CMD, + QUERYINDEX_CMD, + RANGE_CMD, + REVRANGE_CMD, + TimeSeriesCommands, +) +from .info import TSInfo +from .utils import parse_get, parse_m_get, parse_m_range, parse_range + + +class TimeSeries(TimeSeriesCommands): + """ + This class subclasses redis-py's `Redis` and implements RedisTimeSeries's + commands (prefixed with "ts"). + The client allows to interact with RedisTimeSeries and use all of it's + functionality. + """ + + def __init__(self, client=None, **kwargs): + """Create a new RedisTimeSeries client.""" + # Set the module commands' callbacks + self._MODULE_CALLBACKS = { + ALTER_CMD: bool_ok, + CREATE_CMD: bool_ok, + CREATERULE_CMD: bool_ok, + DELETERULE_CMD: bool_ok, + } + + _RESP2_MODULE_CALLBACKS = { + DEL_CMD: int, + GET_CMD: parse_get, + INFO_CMD: TSInfo, + MGET_CMD: parse_m_get, + MRANGE_CMD: parse_m_range, + MREVRANGE_CMD: parse_m_range, + RANGE_CMD: parse_range, + REVRANGE_CMD: parse_range, + QUERYINDEX_CMD: parse_to_list, + } + _RESP3_MODULE_CALLBACKS = {} + + self.client = client + self.execute_command = client.execute_command + + if get_protocol_version(self.client) in ["3", 3]: + self._MODULE_CALLBACKS.update(_RESP3_MODULE_CALLBACKS) + else: + self._MODULE_CALLBACKS.update(_RESP2_MODULE_CALLBACKS) + + for k, v in self._MODULE_CALLBACKS.items(): + self.client.set_response_callback(k, v) + + def pipeline(self, transaction=True, shard_hint=None): + """Creates a pipeline for the TimeSeries module, that can be used + for executing only TimeSeries commands and core commands. + + Usage example: + + r = redis.Redis() + pipe = r.ts().pipeline() + for i in range(100): + pipeline.add("with_pipeline", i, 1.1 * i) + pipeline.execute() + + """ + if isinstance(self.client, redis.RedisCluster): + p = ClusterPipeline( + nodes_manager=self.client.nodes_manager, + commands_parser=self.client.commands_parser, + startup_nodes=self.client.nodes_manager.startup_nodes, + result_callbacks=self.client.result_callbacks, + cluster_response_callbacks=self.client.cluster_response_callbacks, + cluster_error_retry_attempts=self.client.cluster_error_retry_attempts, + read_from_replicas=self.client.read_from_replicas, + reinitialize_steps=self.client.reinitialize_steps, + lock=self.client._lock, + ) + + else: + p = Pipeline( + connection_pool=self.client.connection_pool, + response_callbacks=self._MODULE_CALLBACKS, + transaction=transaction, + shard_hint=shard_hint, + ) + return p + + +class ClusterPipeline(TimeSeriesCommands, redis.cluster.ClusterPipeline): + """Cluster pipeline for the module.""" + + +class Pipeline(TimeSeriesCommands, redis.client.Pipeline): + """Pipeline for the module.""" diff --git a/.venv/Lib/site-packages/redis/commands/timeseries/__pycache__/__init__.cpython-311.pyc b/.venv/Lib/site-packages/redis/commands/timeseries/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 00000000..b6d35ccb Binary files /dev/null and b/.venv/Lib/site-packages/redis/commands/timeseries/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/redis/commands/timeseries/__pycache__/commands.cpython-311.pyc b/.venv/Lib/site-packages/redis/commands/timeseries/__pycache__/commands.cpython-311.pyc new file mode 100644 index 00000000..5f3887d3 Binary files /dev/null and b/.venv/Lib/site-packages/redis/commands/timeseries/__pycache__/commands.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/redis/commands/timeseries/__pycache__/info.cpython-311.pyc b/.venv/Lib/site-packages/redis/commands/timeseries/__pycache__/info.cpython-311.pyc new file mode 100644 index 00000000..4c4f5350 Binary files /dev/null and b/.venv/Lib/site-packages/redis/commands/timeseries/__pycache__/info.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/redis/commands/timeseries/__pycache__/utils.cpython-311.pyc b/.venv/Lib/site-packages/redis/commands/timeseries/__pycache__/utils.cpython-311.pyc new file mode 100644 index 00000000..3aa23e86 Binary files /dev/null and b/.venv/Lib/site-packages/redis/commands/timeseries/__pycache__/utils.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/redis/commands/timeseries/commands.py b/.venv/Lib/site-packages/redis/commands/timeseries/commands.py new file mode 100644 index 00000000..13e3cdf4 --- /dev/null +++ b/.venv/Lib/site-packages/redis/commands/timeseries/commands.py @@ -0,0 +1,896 @@ +from typing import Dict, List, Optional, Tuple, Union + +from redis.exceptions import DataError +from redis.typing import KeyT, Number + +ADD_CMD = "TS.ADD" +ALTER_CMD = "TS.ALTER" +CREATERULE_CMD = "TS.CREATERULE" +CREATE_CMD = "TS.CREATE" +DECRBY_CMD = "TS.DECRBY" +DELETERULE_CMD = "TS.DELETERULE" +DEL_CMD = "TS.DEL" +GET_CMD = "TS.GET" +INCRBY_CMD = "TS.INCRBY" +INFO_CMD = "TS.INFO" +MADD_CMD = "TS.MADD" +MGET_CMD = "TS.MGET" +MRANGE_CMD = "TS.MRANGE" +MREVRANGE_CMD = "TS.MREVRANGE" +QUERYINDEX_CMD = "TS.QUERYINDEX" +RANGE_CMD = "TS.RANGE" +REVRANGE_CMD = "TS.REVRANGE" + + +class TimeSeriesCommands: + """RedisTimeSeries Commands.""" + + def create( + self, + key: KeyT, + retention_msecs: Optional[int] = None, + uncompressed: Optional[bool] = False, + labels: Optional[Dict[str, str]] = None, + chunk_size: Optional[int] = None, + duplicate_policy: Optional[str] = None, + ): + """ + Create a new time-series. + + Args: + + key: + time-series key + retention_msecs: + Maximum age for samples compared to highest reported timestamp (in milliseconds). + If None or 0 is passed then the series is not trimmed at all. + uncompressed: + Changes data storage from compressed (by default) to uncompressed + labels: + Set of label-value pairs that represent metadata labels of the key. + chunk_size: + Memory size, in bytes, allocated for each data chunk. + Must be a multiple of 8 in the range [128 .. 1048576]. + duplicate_policy: + Policy for handling multiple samples with identical timestamps. + Can be one of: + - 'block': an error will occur for any out of order sample. + - 'first': ignore the new value. + - 'last': override with latest value. + - 'min': only override if the value is lower than the existing value. + - 'max': only override if the value is higher than the existing value. + + For more information: https://redis.io/commands/ts.create/ + """ # noqa + params = [key] + self._append_retention(params, retention_msecs) + self._append_uncompressed(params, uncompressed) + self._append_chunk_size(params, chunk_size) + self._append_duplicate_policy(params, CREATE_CMD, duplicate_policy) + self._append_labels(params, labels) + + return self.execute_command(CREATE_CMD, *params) + + def alter( + self, + key: KeyT, + retention_msecs: Optional[int] = None, + labels: Optional[Dict[str, str]] = None, + chunk_size: Optional[int] = None, + duplicate_policy: Optional[str] = None, + ): + """ + Update the retention, chunk size, duplicate policy, and labels of an existing + time series. + + Args: + + key: + time-series key + retention_msecs: + Maximum retention period, compared to maximal existing timestamp (in milliseconds). + If None or 0 is passed then the series is not trimmed at all. + labels: + Set of label-value pairs that represent metadata labels of the key. + chunk_size: + Memory size, in bytes, allocated for each data chunk. + Must be a multiple of 8 in the range [128 .. 1048576]. + duplicate_policy: + Policy for handling multiple samples with identical timestamps. + Can be one of: + - 'block': an error will occur for any out of order sample. + - 'first': ignore the new value. + - 'last': override with latest value. + - 'min': only override if the value is lower than the existing value. + - 'max': only override if the value is higher than the existing value. + + For more information: https://redis.io/commands/ts.alter/ + """ # noqa + params = [key] + self._append_retention(params, retention_msecs) + self._append_chunk_size(params, chunk_size) + self._append_duplicate_policy(params, ALTER_CMD, duplicate_policy) + self._append_labels(params, labels) + + return self.execute_command(ALTER_CMD, *params) + + def add( + self, + key: KeyT, + timestamp: Union[int, str], + value: Number, + retention_msecs: Optional[int] = None, + uncompressed: Optional[bool] = False, + labels: Optional[Dict[str, str]] = None, + chunk_size: Optional[int] = None, + duplicate_policy: Optional[str] = None, + ): + """ + Append (or create and append) a new sample to a time series. + + Args: + + key: + time-series key + timestamp: + Timestamp of the sample. * can be used for automatic timestamp (using the system clock). + value: + Numeric data value of the sample + retention_msecs: + Maximum retention period, compared to maximal existing timestamp (in milliseconds). + If None or 0 is passed then the series is not trimmed at all. + uncompressed: + Changes data storage from compressed (by default) to uncompressed + labels: + Set of label-value pairs that represent metadata labels of the key. + chunk_size: + Memory size, in bytes, allocated for each data chunk. + Must be a multiple of 8 in the range [128 .. 1048576]. + duplicate_policy: + Policy for handling multiple samples with identical timestamps. + Can be one of: + - 'block': an error will occur for any out of order sample. + - 'first': ignore the new value. + - 'last': override with latest value. + - 'min': only override if the value is lower than the existing value. + - 'max': only override if the value is higher than the existing value. + + For more information: https://redis.io/commands/ts.add/ + """ # noqa + params = [key, timestamp, value] + self._append_retention(params, retention_msecs) + self._append_uncompressed(params, uncompressed) + self._append_chunk_size(params, chunk_size) + self._append_duplicate_policy(params, ADD_CMD, duplicate_policy) + self._append_labels(params, labels) + + return self.execute_command(ADD_CMD, *params) + + def madd(self, ktv_tuples: List[Tuple[KeyT, Union[int, str], Number]]): + """ + Append (or create and append) a new `value` to series + `key` with `timestamp`. + Expects a list of `tuples` as (`key`,`timestamp`, `value`). + Return value is an array with timestamps of insertions. + + For more information: https://redis.io/commands/ts.madd/ + """ # noqa + params = [] + for ktv in ktv_tuples: + params.extend(ktv) + + return self.execute_command(MADD_CMD, *params) + + def incrby( + self, + key: KeyT, + value: Number, + timestamp: Optional[Union[int, str]] = None, + retention_msecs: Optional[int] = None, + uncompressed: Optional[bool] = False, + labels: Optional[Dict[str, str]] = None, + chunk_size: Optional[int] = None, + ): + """ + Increment (or create an time-series and increment) the latest sample's of a series. + This command can be used as a counter or gauge that automatically gets history as a time series. + + Args: + + key: + time-series key + value: + Numeric data value of the sample + timestamp: + Timestamp of the sample. * can be used for automatic timestamp (using the system clock). + retention_msecs: + Maximum age for samples compared to last event time (in milliseconds). + If None or 0 is passed then the series is not trimmed at all. + uncompressed: + Changes data storage from compressed (by default) to uncompressed + labels: + Set of label-value pairs that represent metadata labels of the key. + chunk_size: + Memory size, in bytes, allocated for each data chunk. + + For more information: https://redis.io/commands/ts.incrby/ + """ # noqa + params = [key, value] + self._append_timestamp(params, timestamp) + self._append_retention(params, retention_msecs) + self._append_uncompressed(params, uncompressed) + self._append_chunk_size(params, chunk_size) + self._append_labels(params, labels) + + return self.execute_command(INCRBY_CMD, *params) + + def decrby( + self, + key: KeyT, + value: Number, + timestamp: Optional[Union[int, str]] = None, + retention_msecs: Optional[int] = None, + uncompressed: Optional[bool] = False, + labels: Optional[Dict[str, str]] = None, + chunk_size: Optional[int] = None, + ): + """ + Decrement (or create an time-series and decrement) the latest sample's of a series. + This command can be used as a counter or gauge that automatically gets history as a time series. + + Args: + + key: + time-series key + value: + Numeric data value of the sample + timestamp: + Timestamp of the sample. * can be used for automatic timestamp (using the system clock). + retention_msecs: + Maximum age for samples compared to last event time (in milliseconds). + If None or 0 is passed then the series is not trimmed at all. + uncompressed: + Changes data storage from compressed (by default) to uncompressed + labels: + Set of label-value pairs that represent metadata labels of the key. + chunk_size: + Memory size, in bytes, allocated for each data chunk. + + For more information: https://redis.io/commands/ts.decrby/ + """ # noqa + params = [key, value] + self._append_timestamp(params, timestamp) + self._append_retention(params, retention_msecs) + self._append_uncompressed(params, uncompressed) + self._append_chunk_size(params, chunk_size) + self._append_labels(params, labels) + + return self.execute_command(DECRBY_CMD, *params) + + def delete(self, key: KeyT, from_time: int, to_time: int): + """ + Delete all samples between two timestamps for a given time series. + + Args: + + key: + time-series key. + from_time: + Start timestamp for the range deletion. + to_time: + End timestamp for the range deletion. + + For more information: https://redis.io/commands/ts.del/ + """ # noqa + return self.execute_command(DEL_CMD, key, from_time, to_time) + + def createrule( + self, + source_key: KeyT, + dest_key: KeyT, + aggregation_type: str, + bucket_size_msec: int, + align_timestamp: Optional[int] = None, + ): + """ + Create a compaction rule from values added to `source_key` into `dest_key`. + + Args: + + source_key: + Key name for source time series + dest_key: + Key name for destination (compacted) time series + aggregation_type: + Aggregation type: One of the following: + [`avg`, `sum`, `min`, `max`, `range`, `count`, `first`, `last`, `std.p`, + `std.s`, `var.p`, `var.s`, `twa`] + bucket_size_msec: + Duration of each bucket, in milliseconds + align_timestamp: + Assure that there is a bucket that starts at exactly align_timestamp and + align all other buckets accordingly. + + For more information: https://redis.io/commands/ts.createrule/ + """ # noqa + params = [source_key, dest_key] + self._append_aggregation(params, aggregation_type, bucket_size_msec) + if align_timestamp is not None: + params.append(align_timestamp) + + return self.execute_command(CREATERULE_CMD, *params) + + def deleterule(self, source_key: KeyT, dest_key: KeyT): + """ + Delete a compaction rule from `source_key` to `dest_key`.. + + For more information: https://redis.io/commands/ts.deleterule/ + """ # noqa + return self.execute_command(DELETERULE_CMD, source_key, dest_key) + + def __range_params( + self, + key: KeyT, + from_time: Union[int, str], + to_time: Union[int, str], + count: Optional[int], + aggregation_type: Optional[str], + bucket_size_msec: Optional[int], + filter_by_ts: Optional[List[int]], + filter_by_min_value: Optional[int], + filter_by_max_value: Optional[int], + align: Optional[Union[int, str]], + latest: Optional[bool], + bucket_timestamp: Optional[str], + empty: Optional[bool], + ): + """Create TS.RANGE and TS.REVRANGE arguments.""" + params = [key, from_time, to_time] + self._append_latest(params, latest) + self._append_filer_by_ts(params, filter_by_ts) + self._append_filer_by_value(params, filter_by_min_value, filter_by_max_value) + self._append_count(params, count) + self._append_align(params, align) + self._append_aggregation(params, aggregation_type, bucket_size_msec) + self._append_bucket_timestamp(params, bucket_timestamp) + self._append_empty(params, empty) + + return params + + def range( + self, + key: KeyT, + from_time: Union[int, str], + to_time: Union[int, str], + count: Optional[int] = None, + aggregation_type: Optional[str] = None, + bucket_size_msec: Optional[int] = 0, + filter_by_ts: Optional[List[int]] = None, + filter_by_min_value: Optional[int] = None, + filter_by_max_value: Optional[int] = None, + align: Optional[Union[int, str]] = None, + latest: Optional[bool] = False, + bucket_timestamp: Optional[str] = None, + empty: Optional[bool] = False, + ): + """ + Query a range in forward direction for a specific time-serie. + + Args: + + key: + Key name for timeseries. + from_time: + Start timestamp for the range query. - can be used to express the minimum possible timestamp (0). + to_time: + End timestamp for range query, + can be used to express the maximum possible timestamp. + count: + Limits the number of returned samples. + aggregation_type: + Optional aggregation type. Can be one of [`avg`, `sum`, `min`, `max`, + `range`, `count`, `first`, `last`, `std.p`, `std.s`, `var.p`, `var.s`, `twa`] + bucket_size_msec: + Time bucket for aggregation in milliseconds. + filter_by_ts: + List of timestamps to filter the result by specific timestamps. + filter_by_min_value: + Filter result by minimum value (must mention also filter by_max_value). + filter_by_max_value: + Filter result by maximum value (must mention also filter by_min_value). + align: + Timestamp for alignment control for aggregation. + latest: + Used when a time series is a compaction, reports the compacted value of the + latest possibly partial bucket + bucket_timestamp: + Controls how bucket timestamps are reported. Can be one of [`-`, `low`, `+`, + `high`, `~`, `mid`]. + empty: + Reports aggregations for empty buckets. + + For more information: https://redis.io/commands/ts.range/ + """ # noqa + params = self.__range_params( + key, + from_time, + to_time, + count, + aggregation_type, + bucket_size_msec, + filter_by_ts, + filter_by_min_value, + filter_by_max_value, + align, + latest, + bucket_timestamp, + empty, + ) + return self.execute_command(RANGE_CMD, *params) + + def revrange( + self, + key: KeyT, + from_time: Union[int, str], + to_time: Union[int, str], + count: Optional[int] = None, + aggregation_type: Optional[str] = None, + bucket_size_msec: Optional[int] = 0, + filter_by_ts: Optional[List[int]] = None, + filter_by_min_value: Optional[int] = None, + filter_by_max_value: Optional[int] = None, + align: Optional[Union[int, str]] = None, + latest: Optional[bool] = False, + bucket_timestamp: Optional[str] = None, + empty: Optional[bool] = False, + ): + """ + Query a range in reverse direction for a specific time-series. + + **Note**: This command is only available since RedisTimeSeries >= v1.4 + + Args: + + key: + Key name for timeseries. + from_time: + Start timestamp for the range query. - can be used to express the minimum possible timestamp (0). + to_time: + End timestamp for range query, + can be used to express the maximum possible timestamp. + count: + Limits the number of returned samples. + aggregation_type: + Optional aggregation type. Can be one of [`avg`, `sum`, `min`, `max`, + `range`, `count`, `first`, `last`, `std.p`, `std.s`, `var.p`, `var.s`, `twa`] + bucket_size_msec: + Time bucket for aggregation in milliseconds. + filter_by_ts: + List of timestamps to filter the result by specific timestamps. + filter_by_min_value: + Filter result by minimum value (must mention also filter_by_max_value). + filter_by_max_value: + Filter result by maximum value (must mention also filter_by_min_value). + align: + Timestamp for alignment control for aggregation. + latest: + Used when a time series is a compaction, reports the compacted value of the + latest possibly partial bucket + bucket_timestamp: + Controls how bucket timestamps are reported. Can be one of [`-`, `low`, `+`, + `high`, `~`, `mid`]. + empty: + Reports aggregations for empty buckets. + + For more information: https://redis.io/commands/ts.revrange/ + """ # noqa + params = self.__range_params( + key, + from_time, + to_time, + count, + aggregation_type, + bucket_size_msec, + filter_by_ts, + filter_by_min_value, + filter_by_max_value, + align, + latest, + bucket_timestamp, + empty, + ) + return self.execute_command(REVRANGE_CMD, *params) + + def __mrange_params( + self, + aggregation_type: Optional[str], + bucket_size_msec: Optional[int], + count: Optional[int], + filters: List[str], + from_time: Union[int, str], + to_time: Union[int, str], + with_labels: Optional[bool], + filter_by_ts: Optional[List[int]], + filter_by_min_value: Optional[int], + filter_by_max_value: Optional[int], + groupby: Optional[str], + reduce: Optional[str], + select_labels: Optional[List[str]], + align: Optional[Union[int, str]], + latest: Optional[bool], + bucket_timestamp: Optional[str], + empty: Optional[bool], + ): + """Create TS.MRANGE and TS.MREVRANGE arguments.""" + params = [from_time, to_time] + self._append_latest(params, latest) + self._append_filer_by_ts(params, filter_by_ts) + self._append_filer_by_value(params, filter_by_min_value, filter_by_max_value) + self._append_with_labels(params, with_labels, select_labels) + self._append_count(params, count) + self._append_align(params, align) + self._append_aggregation(params, aggregation_type, bucket_size_msec) + self._append_bucket_timestamp(params, bucket_timestamp) + self._append_empty(params, empty) + params.extend(["FILTER"]) + params += filters + self._append_groupby_reduce(params, groupby, reduce) + return params + + def mrange( + self, + from_time: Union[int, str], + to_time: Union[int, str], + filters: List[str], + count: Optional[int] = None, + aggregation_type: Optional[str] = None, + bucket_size_msec: Optional[int] = 0, + with_labels: Optional[bool] = False, + filter_by_ts: Optional[List[int]] = None, + filter_by_min_value: Optional[int] = None, + filter_by_max_value: Optional[int] = None, + groupby: Optional[str] = None, + reduce: Optional[str] = None, + select_labels: Optional[List[str]] = None, + align: Optional[Union[int, str]] = None, + latest: Optional[bool] = False, + bucket_timestamp: Optional[str] = None, + empty: Optional[bool] = False, + ): + """ + Query a range across multiple time-series by filters in forward direction. + + Args: + + from_time: + Start timestamp for the range query. `-` can be used to express the minimum possible timestamp (0). + to_time: + End timestamp for range query, `+` can be used to express the maximum possible timestamp. + filters: + filter to match the time-series labels. + count: + Limits the number of returned samples. + aggregation_type: + Optional aggregation type. Can be one of [`avg`, `sum`, `min`, `max`, + `range`, `count`, `first`, `last`, `std.p`, `std.s`, `var.p`, `var.s`, `twa`] + bucket_size_msec: + Time bucket for aggregation in milliseconds. + with_labels: + Include in the reply all label-value pairs representing metadata labels of the time series. + filter_by_ts: + List of timestamps to filter the result by specific timestamps. + filter_by_min_value: + Filter result by minimum value (must mention also filter_by_max_value). + filter_by_max_value: + Filter result by maximum value (must mention also filter_by_min_value). + groupby: + Grouping by fields the results (must mention also reduce). + reduce: + Applying reducer functions on each group. Can be one of [`avg` `sum`, `min`, + `max`, `range`, `count`, `std.p`, `std.s`, `var.p`, `var.s`]. + select_labels: + Include in the reply only a subset of the key-value pair labels of a series. + align: + Timestamp for alignment control for aggregation. + latest: + Used when a time series is a compaction, reports the compacted + value of the latest possibly partial bucket + bucket_timestamp: + Controls how bucket timestamps are reported. Can be one of [`-`, `low`, `+`, + `high`, `~`, `mid`]. + empty: + Reports aggregations for empty buckets. + + For more information: https://redis.io/commands/ts.mrange/ + """ # noqa + params = self.__mrange_params( + aggregation_type, + bucket_size_msec, + count, + filters, + from_time, + to_time, + with_labels, + filter_by_ts, + filter_by_min_value, + filter_by_max_value, + groupby, + reduce, + select_labels, + align, + latest, + bucket_timestamp, + empty, + ) + + return self.execute_command(MRANGE_CMD, *params) + + def mrevrange( + self, + from_time: Union[int, str], + to_time: Union[int, str], + filters: List[str], + count: Optional[int] = None, + aggregation_type: Optional[str] = None, + bucket_size_msec: Optional[int] = 0, + with_labels: Optional[bool] = False, + filter_by_ts: Optional[List[int]] = None, + filter_by_min_value: Optional[int] = None, + filter_by_max_value: Optional[int] = None, + groupby: Optional[str] = None, + reduce: Optional[str] = None, + select_labels: Optional[List[str]] = None, + align: Optional[Union[int, str]] = None, + latest: Optional[bool] = False, + bucket_timestamp: Optional[str] = None, + empty: Optional[bool] = False, + ): + """ + Query a range across multiple time-series by filters in reverse direction. + + Args: + + from_time: + Start timestamp for the range query. - can be used to express the minimum possible timestamp (0). + to_time: + End timestamp for range query, + can be used to express the maximum possible timestamp. + filters: + Filter to match the time-series labels. + count: + Limits the number of returned samples. + aggregation_type: + Optional aggregation type. Can be one of [`avg`, `sum`, `min`, `max`, + `range`, `count`, `first`, `last`, `std.p`, `std.s`, `var.p`, `var.s`, `twa`] + bucket_size_msec: + Time bucket for aggregation in milliseconds. + with_labels: + Include in the reply all label-value pairs representing metadata labels of the time series. + filter_by_ts: + List of timestamps to filter the result by specific timestamps. + filter_by_min_value: + Filter result by minimum value (must mention also filter_by_max_value). + filter_by_max_value: + Filter result by maximum value (must mention also filter_by_min_value). + groupby: + Grouping by fields the results (must mention also reduce). + reduce: + Applying reducer functions on each group. Can be one of [`avg` `sum`, `min`, + `max`, `range`, `count`, `std.p`, `std.s`, `var.p`, `var.s`]. + select_labels: + Include in the reply only a subset of the key-value pair labels of a series. + align: + Timestamp for alignment control for aggregation. + latest: + Used when a time series is a compaction, reports the compacted + value of the latest possibly partial bucket + bucket_timestamp: + Controls how bucket timestamps are reported. Can be one of [`-`, `low`, `+`, + `high`, `~`, `mid`]. + empty: + Reports aggregations for empty buckets. + + For more information: https://redis.io/commands/ts.mrevrange/ + """ # noqa + params = self.__mrange_params( + aggregation_type, + bucket_size_msec, + count, + filters, + from_time, + to_time, + with_labels, + filter_by_ts, + filter_by_min_value, + filter_by_max_value, + groupby, + reduce, + select_labels, + align, + latest, + bucket_timestamp, + empty, + ) + + return self.execute_command(MREVRANGE_CMD, *params) + + def get(self, key: KeyT, latest: Optional[bool] = False): + """# noqa + Get the last sample of `key`. + `latest` used when a time series is a compaction, reports the compacted + value of the latest (possibly partial) bucket + + For more information: https://redis.io/commands/ts.get/ + """ # noqa + params = [key] + self._append_latest(params, latest) + return self.execute_command(GET_CMD, *params) + + def mget( + self, + filters: List[str], + with_labels: Optional[bool] = False, + select_labels: Optional[List[str]] = None, + latest: Optional[bool] = False, + ): + """# noqa + Get the last samples matching the specific `filter`. + + Args: + + filters: + Filter to match the time-series labels. + with_labels: + Include in the reply all label-value pairs representing metadata + labels of the time series. + select_labels: + Include in the reply only a subset of the key-value pair labels of a series. + latest: + Used when a time series is a compaction, reports the compacted + value of the latest possibly partial bucket + + For more information: https://redis.io/commands/ts.mget/ + """ # noqa + params = [] + self._append_latest(params, latest) + self._append_with_labels(params, with_labels, select_labels) + params.extend(["FILTER"]) + params += filters + return self.execute_command(MGET_CMD, *params) + + def info(self, key: KeyT): + """# noqa + Get information of `key`. + + For more information: https://redis.io/commands/ts.info/ + """ # noqa + return self.execute_command(INFO_CMD, key) + + def queryindex(self, filters: List[str]): + """# noqa + Get all time series keys matching the `filter` list. + + For more information: https://redis.io/commands/ts.queryindex/ + """ # noq + return self.execute_command(QUERYINDEX_CMD, *filters) + + @staticmethod + def _append_uncompressed(params: List[str], uncompressed: Optional[bool]): + """Append UNCOMPRESSED tag to params.""" + if uncompressed: + params.extend(["UNCOMPRESSED"]) + + @staticmethod + def _append_with_labels( + params: List[str], + with_labels: Optional[bool], + select_labels: Optional[List[str]], + ): + """Append labels behavior to params.""" + if with_labels and select_labels: + raise DataError( + "with_labels and select_labels cannot be provided together." + ) + + if with_labels: + params.extend(["WITHLABELS"]) + if select_labels: + params.extend(["SELECTED_LABELS", *select_labels]) + + @staticmethod + def _append_groupby_reduce( + params: List[str], groupby: Optional[str], reduce: Optional[str] + ): + """Append GROUPBY REDUCE property to params.""" + if groupby is not None and reduce is not None: + params.extend(["GROUPBY", groupby, "REDUCE", reduce.upper()]) + + @staticmethod + def _append_retention(params: List[str], retention: Optional[int]): + """Append RETENTION property to params.""" + if retention is not None: + params.extend(["RETENTION", retention]) + + @staticmethod + def _append_labels(params: List[str], labels: Optional[List[str]]): + """Append LABELS property to params.""" + if labels: + params.append("LABELS") + for k, v in labels.items(): + params.extend([k, v]) + + @staticmethod + def _append_count(params: List[str], count: Optional[int]): + """Append COUNT property to params.""" + if count is not None: + params.extend(["COUNT", count]) + + @staticmethod + def _append_timestamp(params: List[str], timestamp: Optional[int]): + """Append TIMESTAMP property to params.""" + if timestamp is not None: + params.extend(["TIMESTAMP", timestamp]) + + @staticmethod + def _append_align(params: List[str], align: Optional[Union[int, str]]): + """Append ALIGN property to params.""" + if align is not None: + params.extend(["ALIGN", align]) + + @staticmethod + def _append_aggregation( + params: List[str], + aggregation_type: Optional[str], + bucket_size_msec: Optional[int], + ): + """Append AGGREGATION property to params.""" + if aggregation_type is not None: + params.extend(["AGGREGATION", aggregation_type, bucket_size_msec]) + + @staticmethod + def _append_chunk_size(params: List[str], chunk_size: Optional[int]): + """Append CHUNK_SIZE property to params.""" + if chunk_size is not None: + params.extend(["CHUNK_SIZE", chunk_size]) + + @staticmethod + def _append_duplicate_policy( + params: List[str], command: Optional[str], duplicate_policy: Optional[str] + ): + """Append DUPLICATE_POLICY property to params on CREATE + and ON_DUPLICATE on ADD. + """ + if duplicate_policy is not None: + if command == "TS.ADD": + params.extend(["ON_DUPLICATE", duplicate_policy]) + else: + params.extend(["DUPLICATE_POLICY", duplicate_policy]) + + @staticmethod + def _append_filer_by_ts(params: List[str], ts_list: Optional[List[int]]): + """Append FILTER_BY_TS property to params.""" + if ts_list is not None: + params.extend(["FILTER_BY_TS", *ts_list]) + + @staticmethod + def _append_filer_by_value( + params: List[str], min_value: Optional[int], max_value: Optional[int] + ): + """Append FILTER_BY_VALUE property to params.""" + if min_value is not None and max_value is not None: + params.extend(["FILTER_BY_VALUE", min_value, max_value]) + + @staticmethod + def _append_latest(params: List[str], latest: Optional[bool]): + """Append LATEST property to params.""" + if latest: + params.append("LATEST") + + @staticmethod + def _append_bucket_timestamp(params: List[str], bucket_timestamp: Optional[str]): + """Append BUCKET_TIMESTAMP property to params.""" + if bucket_timestamp is not None: + params.extend(["BUCKETTIMESTAMP", bucket_timestamp]) + + @staticmethod + def _append_empty(params: List[str], empty: Optional[bool]): + """Append EMPTY property to params.""" + if empty: + params.append("EMPTY") diff --git a/.venv/Lib/site-packages/redis/commands/timeseries/info.py b/.venv/Lib/site-packages/redis/commands/timeseries/info.py new file mode 100644 index 00000000..3a384dc0 --- /dev/null +++ b/.venv/Lib/site-packages/redis/commands/timeseries/info.py @@ -0,0 +1,91 @@ +from ..helpers import nativestr +from .utils import list_to_dict + + +class TSInfo: + """ + Hold information and statistics on the time-series. + Can be created using ``tsinfo`` command + https://oss.redis.com/redistimeseries/commands/#tsinfo. + """ + + rules = [] + labels = [] + sourceKey = None + chunk_count = None + memory_usage = None + total_samples = None + retention_msecs = None + last_time_stamp = None + first_time_stamp = None + + max_samples_per_chunk = None + chunk_size = None + duplicate_policy = None + + def __init__(self, args): + """ + Hold information and statistics on the time-series. + + The supported params that can be passed as args: + + rules: + A list of compaction rules of the time series. + sourceKey: + Key name for source time series in case the current series + is a target of a rule. + chunkCount: + Number of Memory Chunks used for the time series. + memoryUsage: + Total number of bytes allocated for the time series. + totalSamples: + Total number of samples in the time series. + labels: + A list of label-value pairs that represent the metadata + labels of the time series. + retentionTime: + Retention time, in milliseconds, for the time series. + lastTimestamp: + Last timestamp present in the time series. + firstTimestamp: + First timestamp present in the time series. + maxSamplesPerChunk: + Deprecated. + chunkSize: + Amount of memory, in bytes, allocated for data. + duplicatePolicy: + Policy that will define handling of duplicate samples. + + Can read more about on + https://oss.redis.com/redistimeseries/configuration/#duplicate_policy + """ + response = dict(zip(map(nativestr, args[::2]), args[1::2])) + self.rules = response.get("rules") + self.source_key = response.get("sourceKey") + self.chunk_count = response.get("chunkCount") + self.memory_usage = response.get("memoryUsage") + self.total_samples = response.get("totalSamples") + self.labels = list_to_dict(response.get("labels")) + self.retention_msecs = response.get("retentionTime") + self.last_timestamp = response.get("lastTimestamp") + self.first_timestamp = response.get("firstTimestamp") + if "maxSamplesPerChunk" in response: + self.max_samples_per_chunk = response["maxSamplesPerChunk"] + self.chunk_size = ( + self.max_samples_per_chunk * 16 + ) # backward compatible changes + if "chunkSize" in response: + self.chunk_size = response["chunkSize"] + if "duplicatePolicy" in response: + self.duplicate_policy = response["duplicatePolicy"] + if type(self.duplicate_policy) == bytes: + self.duplicate_policy = self.duplicate_policy.decode() + + def get(self, item): + try: + return self.__getitem__(item) + except AttributeError: + return None + + def __getitem__(self, item): + return getattr(self, item) diff --git a/.venv/Lib/site-packages/redis/commands/timeseries/utils.py b/.venv/Lib/site-packages/redis/commands/timeseries/utils.py new file mode 100644 index 00000000..c49b0402 --- /dev/null +++ b/.venv/Lib/site-packages/redis/commands/timeseries/utils.py @@ -0,0 +1,44 @@ +from ..helpers import nativestr + + +def list_to_dict(aList): + return {nativestr(aList[i][0]): nativestr(aList[i][1]) for i in range(len(aList))} + + +def parse_range(response): + """Parse range response. Used by TS.RANGE and TS.REVRANGE.""" + return [tuple((r[0], float(r[1]))) for r in response] + + +def parse_m_range(response): + """Parse multi range response. Used by TS.MRANGE and TS.MREVRANGE.""" + res = [] + for item in response: + res.append({nativestr(item[0]): [list_to_dict(item[1]), parse_range(item[2])]}) + return sorted(res, key=lambda d: list(d.keys())) + + +def parse_get(response): + """Parse get response. Used by TS.GET.""" + if not response: + return None + return int(response[0]), float(response[1]) + + +def parse_m_get(response): + """Parse multi get response. Used by TS.MGET.""" + res = [] + for item in response: + if not item[2]: + res.append({nativestr(item[0]): [list_to_dict(item[1]), None, None]}) + else: + res.append( + { + nativestr(item[0]): [ + list_to_dict(item[1]), + int(item[2][0]), + float(item[2][1]), + ] + } + ) + return sorted(res, key=lambda d: list(d.keys())) diff --git a/.venv/Lib/site-packages/redis/compat.py b/.venv/Lib/site-packages/redis/compat.py new file mode 100644 index 00000000..e4784934 --- /dev/null +++ b/.venv/Lib/site-packages/redis/compat.py @@ -0,0 +1,6 @@ +# flake8: noqa +try: + from typing import Literal, Protocol, TypedDict # lgtm [py/unused-import] +except ImportError: + from typing_extensions import Literal # lgtm [py/unused-import] + from typing_extensions import Protocol, TypedDict diff --git a/.venv/Lib/site-packages/redis/connection.py b/.venv/Lib/site-packages/redis/connection.py new file mode 100644 index 00000000..b39ba28f --- /dev/null +++ b/.venv/Lib/site-packages/redis/connection.py @@ -0,0 +1,1336 @@ +import copy +import os +import socket +import ssl +import sys +import threading +import weakref +from abc import abstractmethod +from itertools import chain +from queue import Empty, Full, LifoQueue +from time import time +from typing import Any, Callable, List, Optional, Type, Union +from urllib.parse import parse_qs, unquote, urlparse + +from ._parsers import Encoder, _HiredisParser, _RESP2Parser, _RESP3Parser +from .backoff import NoBackoff +from .credentials import CredentialProvider, UsernamePasswordCredentialProvider +from .exceptions import ( + AuthenticationError, + AuthenticationWrongNumberOfArgsError, + ChildDeadlockedError, + ConnectionError, + DataError, + RedisError, + ResponseError, + TimeoutError, +) +from .retry import Retry +from .utils import ( + CRYPTOGRAPHY_AVAILABLE, + HIREDIS_AVAILABLE, + HIREDIS_PACK_AVAILABLE, + SSL_AVAILABLE, + get_lib_version, + str_if_bytes, +) + +if HIREDIS_AVAILABLE: + import hiredis + +SYM_STAR = b"*" +SYM_DOLLAR = b"$" +SYM_CRLF = b"\r\n" +SYM_EMPTY = b"" + +DEFAULT_RESP_VERSION = 2 + +SENTINEL = object() + +DefaultParser: Type[Union[_RESP2Parser, _RESP3Parser, _HiredisParser]] +if HIREDIS_AVAILABLE: + DefaultParser = _HiredisParser +else: + DefaultParser = _RESP2Parser + + +class HiredisRespSerializer: + def pack(self, *args: List): + """Pack a series of arguments into the Redis protocol""" + output = [] + + if isinstance(args[0], str): + args = tuple(args[0].encode().split()) + args[1:] + elif b" " in args[0]: + args = tuple(args[0].split()) + args[1:] + try: + output.append(hiredis.pack_command(args)) + except TypeError: + _, value, traceback = sys.exc_info() + raise DataError(value).with_traceback(traceback) + + return output + + +class PythonRespSerializer: + def __init__(self, buffer_cutoff, encode) -> None: + self._buffer_cutoff = buffer_cutoff + self.encode = encode + + def pack(self, *args): + """Pack a series of arguments into the Redis protocol""" + output = [] + # the client might have included 1 or more literal arguments in + # the command name, e.g., 'CONFIG GET'. The Redis server expects these + # arguments to be sent separately, so split the first argument + # manually. These arguments should be bytestrings so that they are + # not encoded. + if isinstance(args[0], str): + args = tuple(args[0].encode().split()) + args[1:] + elif b" " in args[0]: + args = tuple(args[0].split()) + args[1:] + + buff = SYM_EMPTY.join((SYM_STAR, str(len(args)).encode(), SYM_CRLF)) + + buffer_cutoff = self._buffer_cutoff + for arg in map(self.encode, args): + # to avoid large string mallocs, chunk the command into the + # output list if we're sending large values or memoryviews + arg_length = len(arg) + if ( + len(buff) > buffer_cutoff + or arg_length > buffer_cutoff + or isinstance(arg, memoryview) + ): + buff = SYM_EMPTY.join( + (buff, SYM_DOLLAR, str(arg_length).encode(), SYM_CRLF) + ) + output.append(buff) + output.append(arg) + buff = SYM_CRLF + else: + buff = SYM_EMPTY.join( + ( + buff, + SYM_DOLLAR, + str(arg_length).encode(), + SYM_CRLF, + arg, + SYM_CRLF, + ) + ) + output.append(buff) + return output + + +class AbstractConnection: + "Manages communication to and from a Redis server" + + def __init__( + self, + db: int = 0, + password: Optional[str] = None, + socket_timeout: Optional[float] = None, + socket_connect_timeout: Optional[float] = None, + retry_on_timeout: bool = False, + retry_on_error=SENTINEL, + encoding: str = "utf-8", + encoding_errors: str = "strict", + decode_responses: bool = False, + parser_class=DefaultParser, + socket_read_size: int = 65536, + health_check_interval: int = 0, + client_name: Optional[str] = None, + lib_name: Optional[str] = "redis-py", + lib_version: Optional[str] = get_lib_version(), + username: Optional[str] = None, + retry: Union[Any, None] = None, + redis_connect_func: Optional[Callable[[], None]] = None, + credential_provider: Optional[CredentialProvider] = None, + protocol: Optional[int] = 2, + command_packer: Optional[Callable[[], None]] = None, + ): + """ + Initialize a new Connection. + To specify a retry policy for specific errors, first set + `retry_on_error` to a list of the error/s to retry on, then set + `retry` to a valid `Retry` object. + To retry on TimeoutError, `retry_on_timeout` can also be set to `True`. + """ + if (username or password) and credential_provider is not None: + raise DataError( + "'username' and 'password' cannot be passed along with 'credential_" + "provider'. Please provide only one of the following arguments: \n" + "1. 'password' and (optional) 'username'\n" + "2. 'credential_provider'" + ) + self.pid = os.getpid() + self.db = db + self.client_name = client_name + self.lib_name = lib_name + self.lib_version = lib_version + self.credential_provider = credential_provider + self.password = password + self.username = username + self.socket_timeout = socket_timeout + if socket_connect_timeout is None: + socket_connect_timeout = socket_timeout + self.socket_connect_timeout = socket_connect_timeout + self.retry_on_timeout = retry_on_timeout + if retry_on_error is SENTINEL: + retry_on_error = [] + if retry_on_timeout: + # Add TimeoutError to the errors list to retry on + retry_on_error.append(TimeoutError) + self.retry_on_error = retry_on_error + if retry or retry_on_error: + if retry is None: + self.retry = Retry(NoBackoff(), 1) + else: + # deep-copy the Retry object as it is mutable + self.retry = copy.deepcopy(retry) + # Update the retry's supported errors with the specified errors + self.retry.update_supported_errors(retry_on_error) + else: + self.retry = Retry(NoBackoff(), 0) + self.health_check_interval = health_check_interval + self.next_health_check = 0 + self.redis_connect_func = redis_connect_func + self.encoder = Encoder(encoding, encoding_errors, decode_responses) + self._sock = None + self._socket_read_size = socket_read_size + self.set_parser(parser_class) + self._connect_callbacks = [] + self._buffer_cutoff = 6000 + try: + p = int(protocol) + except TypeError: + p = DEFAULT_RESP_VERSION + except ValueError: + raise ConnectionError("protocol must be an integer") + finally: + if p < 2 or p > 3: + raise ConnectionError("protocol must be either 2 or 3") + # p = DEFAULT_RESP_VERSION + self.protocol = p + self._command_packer = self._construct_command_packer(command_packer) + + def __repr__(self): + repr_args = ",".join([f"{k}={v}" for k, v in self.repr_pieces()]) + return f"{self.__class__.__name__}<{repr_args}>" + + @abstractmethod + def repr_pieces(self): + pass + + def __del__(self): + try: + self.disconnect() + except Exception: + pass + + def _construct_command_packer(self, packer): + if packer is not None: + return packer + elif HIREDIS_PACK_AVAILABLE: + return HiredisRespSerializer() + else: + return PythonRespSerializer(self._buffer_cutoff, self.encoder.encode) + + def _register_connect_callback(self, callback): + wm = weakref.WeakMethod(callback) + if wm not in self._connect_callbacks: + self._connect_callbacks.append(wm) + + def _deregister_connect_callback(self, callback): + try: + self._connect_callbacks.remove(weakref.WeakMethod(callback)) + except ValueError: + pass + + def set_parser(self, parser_class): + """ + Creates a new instance of parser_class with socket size: + _socket_read_size and assigns it to the parser for the connection + :param parser_class: The required parser class + """ + self._parser = parser_class(socket_read_size=self._socket_read_size) + + def connect(self): + "Connects to the Redis server if not already connected" + if self._sock: + return + try: + sock = self.retry.call_with_retry( + lambda: self._connect(), lambda error: self.disconnect(error) + ) + except socket.timeout: + raise TimeoutError("Timeout connecting to server") + except OSError as e: + raise ConnectionError(self._error_message(e)) + + self._sock = sock + try: + if self.redis_connect_func is None: + # Use the default on_connect function + self.on_connect() + else: + # Use the passed function redis_connect_func + self.redis_connect_func(self) + except RedisError: + # clean up after any error in on_connect + self.disconnect() + raise + + # run any user callbacks. right now the only internal callback + # is for pubsub channel/pattern resubscription + # first, remove any dead weakrefs + self._connect_callbacks = [ref for ref in self._connect_callbacks if ref()] + for ref in self._connect_callbacks: + callback = ref() + if callback: + callback(self) + + @abstractmethod + def _connect(self): + pass + + @abstractmethod + def _host_error(self): + pass + + @abstractmethod + def _error_message(self, exception): + pass + + def on_connect(self): + "Initialize the connection, authenticate and select a database" + self._parser.on_connect(self) + parser = self._parser + + auth_args = None + # if credential provider or username and/or password are set, authenticate + if self.credential_provider or (self.username or self.password): + cred_provider = ( + self.credential_provider + or UsernamePasswordCredentialProvider(self.username, self.password) + ) + auth_args = cred_provider.get_credentials() + + # if resp version is specified and we have auth args, + # we need to send them via HELLO + if auth_args and self.protocol not in [2, "2"]: + if isinstance(self._parser, _RESP2Parser): + self.set_parser(_RESP3Parser) + # update cluster exception classes + self._parser.EXCEPTION_CLASSES = parser.EXCEPTION_CLASSES + self._parser.on_connect(self) + if len(auth_args) == 1: + auth_args = ["default", auth_args[0]] + self.send_command("HELLO", self.protocol, "AUTH", *auth_args) + response = self.read_response() + # if response.get(b"proto") != self.protocol and response.get( + # "proto" + # ) != self.protocol: + # raise ConnectionError("Invalid RESP version") + elif auth_args: + # avoid checking health here -- PING will fail if we try + # to check the health prior to the AUTH + self.send_command("AUTH", *auth_args, check_health=False) + + try: + auth_response = self.read_response() + except AuthenticationWrongNumberOfArgsError: + # a username and password were specified but the Redis + # server seems to be < 6.0.0 which expects a single password + # arg. retry auth with just the password. + # https://github.com/andymccurdy/redis-py/issues/1274 + self.send_command("AUTH", auth_args[-1], check_health=False) + auth_response = self.read_response() + + if str_if_bytes(auth_response) != "OK": + raise AuthenticationError("Invalid Username or Password") + + # if resp version is specified, switch to it + elif self.protocol not in [2, "2"]: + if isinstance(self._parser, _RESP2Parser): + self.set_parser(_RESP3Parser) + # update cluster exception classes + self._parser.EXCEPTION_CLASSES = parser.EXCEPTION_CLASSES + self._parser.on_connect(self) + self.send_command("HELLO", self.protocol) + response = self.read_response() + if ( + response.get(b"proto") != self.protocol + and response.get("proto") != self.protocol + ): + raise ConnectionError("Invalid RESP version") + + # if a client_name is given, set it + if self.client_name: + self.send_command("CLIENT", "SETNAME", self.client_name) + if str_if_bytes(self.read_response()) != "OK": + raise ConnectionError("Error setting client name") + + try: + # set the library name and version + if self.lib_name: + self.send_command("CLIENT", "SETINFO", "LIB-NAME", self.lib_name) + self.read_response() + except ResponseError: + pass + + try: + if self.lib_version: + self.send_command("CLIENT", "SETINFO", "LIB-VER", self.lib_version) + self.read_response() + except ResponseError: + pass + + # if a database is specified, switch to it + if self.db: + self.send_command("SELECT", self.db) + if str_if_bytes(self.read_response()) != "OK": + raise ConnectionError("Invalid Database") + + def disconnect(self, *args): + "Disconnects from the Redis server" + self._parser.on_disconnect() + + conn_sock = self._sock + self._sock = None + if conn_sock is None: + return + + if os.getpid() == self.pid: + try: + conn_sock.shutdown(socket.SHUT_RDWR) + except OSError: + pass + + try: + conn_sock.close() + except OSError: + pass + + def _send_ping(self): + """Send PING, expect PONG in return""" + self.send_command("PING", check_health=False) + if str_if_bytes(self.read_response()) != "PONG": + raise ConnectionError("Bad response from PING health check") + + def _ping_failed(self, error): + """Function to call when PING fails""" + self.disconnect() + + def check_health(self): + """Check the health of the connection with a PING/PONG""" + if self.health_check_interval and time() > self.next_health_check: + self.retry.call_with_retry(self._send_ping, self._ping_failed) + + def send_packed_command(self, command, check_health=True): + """Send an already packed command to the Redis server""" + if not self._sock: + self.connect() + # guard against health check recursion + if check_health: + self.check_health() + try: + if isinstance(command, str): + command = [command] + for item in command: + self._sock.sendall(item) + except socket.timeout: + self.disconnect() + raise TimeoutError("Timeout writing to socket") + except OSError as e: + self.disconnect() + if len(e.args) == 1: + errno, errmsg = "UNKNOWN", e.args[0] + else: + errno = e.args[0] + errmsg = e.args[1] + raise ConnectionError(f"Error {errno} while writing to socket. {errmsg}.") + except BaseException: + # BaseExceptions can be raised when a socket send operation is not + # finished, e.g. due to a timeout. Ideally, a caller could then re-try + # to send un-sent data. However, the send_packed_command() API + # does not support it so there is no point in keeping the connection open. + self.disconnect() + raise + + def send_command(self, *args, **kwargs): + """Pack and send a command to the Redis server""" + self.send_packed_command( + self._command_packer.pack(*args), + check_health=kwargs.get("check_health", True), + ) + + def can_read(self, timeout=0): + """Poll the socket to see if there's data that can be read.""" + sock = self._sock + if not sock: + self.connect() + + host_error = self._host_error() + + try: + return self._parser.can_read(timeout) + except OSError as e: + self.disconnect() + raise ConnectionError(f"Error while reading from {host_error}: {e.args}") + + def read_response( + self, + disable_decoding=False, + *, + disconnect_on_error=True, + push_request=False, + ): + """Read the response from a previously sent command""" + + host_error = self._host_error() + + try: + if self.protocol in ["3", 3] and not HIREDIS_AVAILABLE: + response = self._parser.read_response( + disable_decoding=disable_decoding, push_request=push_request + ) + else: + response = self._parser.read_response(disable_decoding=disable_decoding) + except socket.timeout: + if disconnect_on_error: + self.disconnect() + raise TimeoutError(f"Timeout reading from {host_error}") + except OSError as e: + if disconnect_on_error: + self.disconnect() + raise ConnectionError( + f"Error while reading from {host_error}" f" : {e.args}" + ) + except BaseException: + # Also by default close in case of BaseException. A lot of code + # relies on this behaviour when doing Command/Response pairs. + # See #1128. + if disconnect_on_error: + self.disconnect() + raise + + if self.health_check_interval: + self.next_health_check = time() + self.health_check_interval + + if isinstance(response, ResponseError): + try: + raise response + finally: + del response # avoid creating ref cycles + return response + + def pack_command(self, *args): + """Pack a series of arguments into the Redis protocol""" + return self._command_packer.pack(*args) + + def pack_commands(self, commands): + """Pack multiple commands into the Redis protocol""" + output = [] + pieces = [] + buffer_length = 0 + buffer_cutoff = self._buffer_cutoff + + for cmd in commands: + for chunk in self._command_packer.pack(*cmd): + chunklen = len(chunk) + if ( + buffer_length > buffer_cutoff + or chunklen > buffer_cutoff + or isinstance(chunk, memoryview) + ): + if pieces: + output.append(SYM_EMPTY.join(pieces)) + buffer_length = 0 + pieces = [] + + if chunklen > buffer_cutoff or isinstance(chunk, memoryview): + output.append(chunk) + else: + pieces.append(chunk) + buffer_length += chunklen + + if pieces: + output.append(SYM_EMPTY.join(pieces)) + return output + + +class Connection(AbstractConnection): + "Manages TCP communication to and from a Redis server" + + def __init__( + self, + host="localhost", + port=6379, + socket_keepalive=False, + socket_keepalive_options=None, + socket_type=0, + **kwargs, + ): + self.host = host + self.port = int(port) + self.socket_keepalive = socket_keepalive + self.socket_keepalive_options = socket_keepalive_options or {} + self.socket_type = socket_type + super().__init__(**kwargs) + + def repr_pieces(self): + pieces = [("host", self.host), ("port", self.port), ("db", self.db)] + if self.client_name: + pieces.append(("client_name", self.client_name)) + return pieces + + def _connect(self): + "Create a TCP socket connection" + # we want to mimic what socket.create_connection does to support + # ipv4/ipv6, but we want to set options prior to calling + # socket.connect() + err = None + for res in socket.getaddrinfo( + self.host, self.port, self.socket_type, socket.SOCK_STREAM + ): + family, socktype, proto, canonname, socket_address = res + sock = None + try: + sock = socket.socket(family, socktype, proto) + # TCP_NODELAY + sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) + + # TCP_KEEPALIVE + if self.socket_keepalive: + sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1) + for k, v in self.socket_keepalive_options.items(): + sock.setsockopt(socket.IPPROTO_TCP, k, v) + + # set the socket_connect_timeout before we connect + sock.settimeout(self.socket_connect_timeout) + + # connect + sock.connect(socket_address) + + # set the socket_timeout now that we're connected + sock.settimeout(self.socket_timeout) + return sock + + except OSError as _: + err = _ + if sock is not None: + sock.close() + + if err is not None: + raise err + raise OSError("socket.getaddrinfo returned an empty list") + + def _host_error(self): + return f"{self.host}:{self.port}" + + def _error_message(self, exception): + # args for socket.error can either be (errno, "message") + # or just "message" + + host_error = self._host_error() + + if len(exception.args) == 1: + try: + return f"Error connecting to {host_error}. \ + {exception.args[0]}." + except AttributeError: + return f"Connection Error: {exception.args[0]}" + else: + try: + return ( + f"Error {exception.args[0]} connecting to " + f"{host_error}. {exception.args[1]}." + ) + except AttributeError: + return f"Connection Error: {exception.args[0]}" + + +class SSLConnection(Connection): + """Manages SSL connections to and from the Redis server(s). + This class extends the Connection class, adding SSL functionality, and making + use of ssl.SSLContext (https://docs.python.org/3/library/ssl.html#ssl.SSLContext) + """ # noqa + + def __init__( + self, + ssl_keyfile=None, + ssl_certfile=None, + ssl_cert_reqs="required", + ssl_ca_certs=None, + ssl_ca_data=None, + ssl_check_hostname=False, + ssl_ca_path=None, + ssl_password=None, + ssl_validate_ocsp=False, + ssl_validate_ocsp_stapled=False, + ssl_ocsp_context=None, + ssl_ocsp_expected_cert=None, + **kwargs, + ): + """Constructor + + Args: + ssl_keyfile: Path to an ssl private key. Defaults to None. + ssl_certfile: Path to an ssl certificate. Defaults to None. + ssl_cert_reqs: The string value for the SSLContext.verify_mode (none, optional, required). Defaults to "required". + ssl_ca_certs: The path to a file of concatenated CA certificates in PEM format. Defaults to None. + ssl_ca_data: Either an ASCII string of one or more PEM-encoded certificates or a bytes-like object of DER-encoded certificates. + ssl_check_hostname: If set, match the hostname during the SSL handshake. Defaults to False. + ssl_ca_path: The path to a directory containing several CA certificates in PEM format. Defaults to None. + ssl_password: Password for unlocking an encrypted private key. Defaults to None. + + ssl_validate_ocsp: If set, perform a full ocsp validation (i.e not a stapled verification) + ssl_validate_ocsp_stapled: If set, perform a validation on a stapled ocsp response + ssl_ocsp_context: A fully initialized OpenSSL.SSL.Context object to be used in verifying the ssl_ocsp_expected_cert + ssl_ocsp_expected_cert: A PEM armoured string containing the expected certificate to be returned from the ocsp verification service. + + Raises: + RedisError + """ # noqa + if not SSL_AVAILABLE: + raise RedisError("Python wasn't built with SSL support") + + self.keyfile = ssl_keyfile + self.certfile = ssl_certfile + if ssl_cert_reqs is None: + ssl_cert_reqs = ssl.CERT_NONE + elif isinstance(ssl_cert_reqs, str): + CERT_REQS = { + "none": ssl.CERT_NONE, + "optional": ssl.CERT_OPTIONAL, + "required": ssl.CERT_REQUIRED, + } + if ssl_cert_reqs not in CERT_REQS: + raise RedisError( + f"Invalid SSL Certificate Requirements Flag: {ssl_cert_reqs}" + ) + ssl_cert_reqs = CERT_REQS[ssl_cert_reqs] + self.cert_reqs = ssl_cert_reqs + self.ca_certs = ssl_ca_certs + self.ca_data = ssl_ca_data + self.ca_path = ssl_ca_path + self.check_hostname = ssl_check_hostname + self.certificate_password = ssl_password + self.ssl_validate_ocsp = ssl_validate_ocsp + self.ssl_validate_ocsp_stapled = ssl_validate_ocsp_stapled + self.ssl_ocsp_context = ssl_ocsp_context + self.ssl_ocsp_expected_cert = ssl_ocsp_expected_cert + super().__init__(**kwargs) + + def _connect(self): + "Wrap the socket with SSL support" + sock = super()._connect() + context = ssl.create_default_context() + context.check_hostname = self.check_hostname + context.verify_mode = self.cert_reqs + if self.certfile or self.keyfile: + context.load_cert_chain( + certfile=self.certfile, + keyfile=self.keyfile, + password=self.certificate_password, + ) + if ( + self.ca_certs is not None + or self.ca_path is not None + or self.ca_data is not None + ): + context.load_verify_locations( + cafile=self.ca_certs, capath=self.ca_path, cadata=self.ca_data + ) + sslsock = context.wrap_socket(sock, server_hostname=self.host) + if self.ssl_validate_ocsp is True and CRYPTOGRAPHY_AVAILABLE is False: + raise RedisError("cryptography is not installed.") + + if self.ssl_validate_ocsp_stapled and self.ssl_validate_ocsp: + raise RedisError( + "Either an OCSP staple or pure OCSP connection must be validated " + "- not both." + ) + + # validation for the stapled case + if self.ssl_validate_ocsp_stapled: + import OpenSSL + + from .ocsp import ocsp_staple_verifier + + # if a context is provided use it - otherwise, a basic context + if self.ssl_ocsp_context is None: + staple_ctx = OpenSSL.SSL.Context(OpenSSL.SSL.SSLv23_METHOD) + staple_ctx.use_certificate_file(self.certfile) + staple_ctx.use_privatekey_file(self.keyfile) + else: + staple_ctx = self.ssl_ocsp_context + + staple_ctx.set_ocsp_client_callback( + ocsp_staple_verifier, self.ssl_ocsp_expected_cert + ) + + # need another socket + con = OpenSSL.SSL.Connection(staple_ctx, socket.socket()) + con.request_ocsp() + con.connect((self.host, self.port)) + con.do_handshake() + con.shutdown() + return sslsock + + # pure ocsp validation + if self.ssl_validate_ocsp is True and CRYPTOGRAPHY_AVAILABLE: + from .ocsp import OCSPVerifier + + o = OCSPVerifier(sslsock, self.host, self.port, self.ca_certs) + if o.is_valid(): + return sslsock + else: + raise ConnectionError("ocsp validation error") + return sslsock + + +class UnixDomainSocketConnection(AbstractConnection): + "Manages UDS communication to and from a Redis server" + + def __init__(self, path="", socket_timeout=None, **kwargs): + self.path = path + self.socket_timeout = socket_timeout + super().__init__(**kwargs) + + def repr_pieces(self): + pieces = [("path", self.path), ("db", self.db)] + if self.client_name: + pieces.append(("client_name", self.client_name)) + return pieces + + def _connect(self): + "Create a Unix domain socket connection" + sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + sock.settimeout(self.socket_connect_timeout) + sock.connect(self.path) + sock.settimeout(self.socket_timeout) + return sock + + def _host_error(self): + return self.path + + def _error_message(self, exception): + # args for socket.error can either be (errno, "message") + # or just "message" + host_error = self._host_error() + if len(exception.args) == 1: + return ( + f"Error connecting to unix socket: {host_error}. {exception.args[0]}." + ) + else: + return ( + f"Error {exception.args[0]} connecting to unix socket: " + f"{host_error}. {exception.args[1]}." + ) + + +FALSE_STRINGS = ("0", "F", "FALSE", "N", "NO") + + +def to_bool(value): + if value is None or value == "": + return None + if isinstance(value, str) and value.upper() in FALSE_STRINGS: + return False + return bool(value) + + +URL_QUERY_ARGUMENT_PARSERS = { + "db": int, + "socket_timeout": float, + "socket_connect_timeout": float, + "socket_keepalive": to_bool, + "retry_on_timeout": to_bool, + "retry_on_error": list, + "max_connections": int, + "health_check_interval": int, + "ssl_check_hostname": to_bool, +} + + +def parse_url(url): + if not ( + url.startswith("redis://") + or url.startswith("rediss://") + or url.startswith("unix://") + ): + raise ValueError( + "Redis URL must specify one of the following " + "schemes (redis://, rediss://, unix://)" + ) + + url = urlparse(url) + kwargs = {} + + for name, value in parse_qs(url.query).items(): + if value and len(value) > 0: + value = unquote(value[0]) + parser = URL_QUERY_ARGUMENT_PARSERS.get(name) + if parser: + try: + kwargs[name] = parser(value) + except (TypeError, ValueError): + raise ValueError(f"Invalid value for `{name}` in connection URL.") + else: + kwargs[name] = value + + if url.username: + kwargs["username"] = unquote(url.username) + if url.password: + kwargs["password"] = unquote(url.password) + + # We only support redis://, rediss:// and unix:// schemes. + if url.scheme == "unix": + if url.path: + kwargs["path"] = unquote(url.path) + kwargs["connection_class"] = UnixDomainSocketConnection + + else: # implied: url.scheme in ("redis", "rediss"): + if url.hostname: + kwargs["host"] = unquote(url.hostname) + if url.port: + kwargs["port"] = int(url.port) + + # If there's a path argument, use it as the db argument if a + # querystring value wasn't specified + if url.path and "db" not in kwargs: + try: + kwargs["db"] = int(unquote(url.path).replace("/", "")) + except (AttributeError, ValueError): + pass + + if url.scheme == "rediss": + kwargs["connection_class"] = SSLConnection + + return kwargs + + +class ConnectionPool: + """ + Create a connection pool. ``If max_connections`` is set, then this + object raises :py:class:`~redis.exceptions.ConnectionError` when the pool's + limit is reached. + + By default, TCP connections are created unless ``connection_class`` + is specified. Use class:`.UnixDomainSocketConnection` for + unix sockets. + + Any additional keyword arguments are passed to the constructor of + ``connection_class``. + """ + + @classmethod + def from_url(cls, url, **kwargs): + """ + Return a connection pool configured from the given URL. + + For example:: + + redis://[[username]:[password]]@localhost:6379/0 + rediss://[[username]:[password]]@localhost:6379/0 + unix://[username@]/path/to/socket.sock?db=0[&password=password] + + Three URL schemes are supported: + + - `redis://` creates a TCP socket connection. See more at: + + - `rediss://` creates a SSL wrapped TCP socket connection. See more at: + + - ``unix://``: creates a Unix Domain Socket connection. + + The username, password, hostname, path and all querystring values + are passed through urllib.parse.unquote in order to replace any + percent-encoded values with their corresponding characters. + + There are several ways to specify a database number. The first value + found will be used: + + 1. A ``db`` querystring option, e.g. redis://localhost?db=0 + 2. If using the redis:// or rediss:// schemes, the path argument + of the url, e.g. redis://localhost/0 + 3. A ``db`` keyword argument to this function. + + If none of these options are specified, the default db=0 is used. + + All querystring options are cast to their appropriate Python types. + Boolean arguments can be specified with string values "True"/"False" + or "Yes"/"No". Values that cannot be properly cast cause a + ``ValueError`` to be raised. Once parsed, the querystring arguments + and keyword arguments are passed to the ``ConnectionPool``'s + class initializer. In the case of conflicting arguments, querystring + arguments always win. + """ + url_options = parse_url(url) + + if "connection_class" in kwargs: + url_options["connection_class"] = kwargs["connection_class"] + + kwargs.update(url_options) + return cls(**kwargs) + + def __init__( + self, + connection_class=Connection, + max_connections: Optional[int] = None, + **connection_kwargs, + ): + max_connections = max_connections or 2**31 + if not isinstance(max_connections, int) or max_connections < 0: + raise ValueError('"max_connections" must be a positive integer') + + self.connection_class = connection_class + self.connection_kwargs = connection_kwargs + self.max_connections = max_connections + + # a lock to protect the critical section in _checkpid(). + # this lock is acquired when the process id changes, such as + # after a fork. during this time, multiple threads in the child + # process could attempt to acquire this lock. the first thread + # to acquire the lock will reset the data structures and lock + # object of this pool. subsequent threads acquiring this lock + # will notice the first thread already did the work and simply + # release the lock. + self._fork_lock = threading.Lock() + self.reset() + + def __repr__(self) -> (str, str): + return ( + f"{type(self).__name__}" + f"<{repr(self.connection_class(**self.connection_kwargs))}>" + ) + + def reset(self) -> None: + self._lock = threading.Lock() + self._created_connections = 0 + self._available_connections = [] + self._in_use_connections = set() + + # this must be the last operation in this method. while reset() is + # called when holding _fork_lock, other threads in this process + # can call _checkpid() which compares self.pid and os.getpid() without + # holding any lock (for performance reasons). keeping this assignment + # as the last operation ensures that those other threads will also + # notice a pid difference and block waiting for the first thread to + # release _fork_lock. when each of these threads eventually acquire + # _fork_lock, they will notice that another thread already called + # reset() and they will immediately release _fork_lock and continue on. + self.pid = os.getpid() + + def _checkpid(self) -> None: + # _checkpid() attempts to keep ConnectionPool fork-safe on modern + # systems. this is called by all ConnectionPool methods that + # manipulate the pool's state such as get_connection() and release(). + # + # _checkpid() determines whether the process has forked by comparing + # the current process id to the process id saved on the ConnectionPool + # instance. if these values are the same, _checkpid() simply returns. + # + # when the process ids differ, _checkpid() assumes that the process + # has forked and that we're now running in the child process. the child + # process cannot use the parent's file descriptors (e.g., sockets). + # therefore, when _checkpid() sees the process id change, it calls + # reset() in order to reinitialize the child's ConnectionPool. this + # will cause the child to make all new connection objects. + # + # _checkpid() is protected by self._fork_lock to ensure that multiple + # threads in the child process do not call reset() multiple times. + # + # there is an extremely small chance this could fail in the following + # scenario: + # 1. process A calls _checkpid() for the first time and acquires + # self._fork_lock. + # 2. while holding self._fork_lock, process A forks (the fork() + # could happen in a different thread owned by process A) + # 3. process B (the forked child process) inherits the + # ConnectionPool's state from the parent. that state includes + # a locked _fork_lock. process B will not be notified when + # process A releases the _fork_lock and will thus never be + # able to acquire the _fork_lock. + # + # to mitigate this possible deadlock, _checkpid() will only wait 5 + # seconds to acquire _fork_lock. if _fork_lock cannot be acquired in + # that time it is assumed that the child is deadlocked and a + # redis.ChildDeadlockedError error is raised. + if self.pid != os.getpid(): + acquired = self._fork_lock.acquire(timeout=5) + if not acquired: + raise ChildDeadlockedError + # reset() the instance for the new process if another thread + # hasn't already done so + try: + if self.pid != os.getpid(): + self.reset() + finally: + self._fork_lock.release() + + def get_connection(self, command_name: str, *keys, **options) -> "Connection": + "Get a connection from the pool" + self._checkpid() + with self._lock: + try: + connection = self._available_connections.pop() + except IndexError: + connection = self.make_connection() + self._in_use_connections.add(connection) + + try: + # ensure this connection is connected to Redis + connection.connect() + # connections that the pool provides should be ready to send + # a command. if not, the connection was either returned to the + # pool before all data has been read or the socket has been + # closed. either way, reconnect and verify everything is good. + try: + if connection.can_read(): + raise ConnectionError("Connection has data") + except (ConnectionError, OSError): + connection.disconnect() + connection.connect() + if connection.can_read(): + raise ConnectionError("Connection not ready") + except BaseException: + # release the connection back to the pool so that we don't + # leak it + self.release(connection) + raise + + return connection + + def get_encoder(self) -> Encoder: + "Return an encoder based on encoding settings" + kwargs = self.connection_kwargs + return Encoder( + encoding=kwargs.get("encoding", "utf-8"), + encoding_errors=kwargs.get("encoding_errors", "strict"), + decode_responses=kwargs.get("decode_responses", False), + ) + + def make_connection(self) -> "Connection": + "Create a new connection" + if self._created_connections >= self.max_connections: + raise ConnectionError("Too many connections") + self._created_connections += 1 + return self.connection_class(**self.connection_kwargs) + + def release(self, connection: "Connection") -> None: + "Releases the connection back to the pool" + self._checkpid() + with self._lock: + try: + self._in_use_connections.remove(connection) + except KeyError: + # Gracefully fail when a connection is returned to this pool + # that the pool doesn't actually own + pass + + if self.owns_connection(connection): + self._available_connections.append(connection) + else: + # pool doesn't own this connection. do not add it back + # to the pool and decrement the count so that another + # connection can take its place if needed + self._created_connections -= 1 + connection.disconnect() + return + + def owns_connection(self, connection: "Connection") -> int: + return connection.pid == self.pid + + def disconnect(self, inuse_connections: bool = True) -> None: + """ + Disconnects connections in the pool + + If ``inuse_connections`` is True, disconnect connections that are + current in use, potentially by other threads. Otherwise only disconnect + connections that are idle in the pool. + """ + self._checkpid() + with self._lock: + if inuse_connections: + connections = chain( + self._available_connections, self._in_use_connections + ) + else: + connections = self._available_connections + + for connection in connections: + connection.disconnect() + + def close(self) -> None: + """Close the pool, disconnecting all connections""" + self.disconnect() + + def set_retry(self, retry: "Retry") -> None: + self.connection_kwargs.update({"retry": retry}) + for conn in self._available_connections: + conn.retry = retry + for conn in self._in_use_connections: + conn.retry = retry + + +class BlockingConnectionPool(ConnectionPool): + """ + Thread-safe blocking connection pool:: + + >>> from redis.client import Redis + >>> client = Redis(connection_pool=BlockingConnectionPool()) + + It performs the same function as the default + :py:class:`~redis.ConnectionPool` implementation, in that, + it maintains a pool of reusable connections that can be shared by + multiple redis clients (safely across threads if required). + + The difference is that, in the event that a client tries to get a + connection from the pool when all of connections are in use, rather than + raising a :py:class:`~redis.ConnectionError` (as the default + :py:class:`~redis.ConnectionPool` implementation does), it + makes the client wait ("blocks") for a specified number of seconds until + a connection becomes available. + + Use ``max_connections`` to increase / decrease the pool size:: + + >>> pool = BlockingConnectionPool(max_connections=10) + + Use ``timeout`` to tell it either how many seconds to wait for a connection + to become available, or to block forever: + + >>> # Block forever. + >>> pool = BlockingConnectionPool(timeout=None) + + >>> # Raise a ``ConnectionError`` after five seconds if a connection is + >>> # not available. + >>> pool = BlockingConnectionPool(timeout=5) + """ + + def __init__( + self, + max_connections=50, + timeout=20, + connection_class=Connection, + queue_class=LifoQueue, + **connection_kwargs, + ): + self.queue_class = queue_class + self.timeout = timeout + super().__init__( + connection_class=connection_class, + max_connections=max_connections, + **connection_kwargs, + ) + + def reset(self): + # Create and fill up a thread safe queue with ``None`` values. + self.pool = self.queue_class(self.max_connections) + while True: + try: + self.pool.put_nowait(None) + except Full: + break + + # Keep a list of actual connection instances so that we can + # disconnect them later. + self._connections = [] + + # this must be the last operation in this method. while reset() is + # called when holding _fork_lock, other threads in this process + # can call _checkpid() which compares self.pid and os.getpid() without + # holding any lock (for performance reasons). keeping this assignment + # as the last operation ensures that those other threads will also + # notice a pid difference and block waiting for the first thread to + # release _fork_lock. when each of these threads eventually acquire + # _fork_lock, they will notice that another thread already called + # reset() and they will immediately release _fork_lock and continue on. + self.pid = os.getpid() + + def make_connection(self): + "Make a fresh connection." + connection = self.connection_class(**self.connection_kwargs) + self._connections.append(connection) + return connection + + def get_connection(self, command_name, *keys, **options): + """ + Get a connection, blocking for ``self.timeout`` until a connection + is available from the pool. + + If the connection returned is ``None`` then creates a new connection. + Because we use a last-in first-out queue, the existing connections + (having been returned to the pool after the initial ``None`` values + were added) will be returned before ``None`` values. This means we only + create new connections when we need to, i.e.: the actual number of + connections will only increase in response to demand. + """ + # Make sure we haven't changed process. + self._checkpid() + + # Try and get a connection from the pool. If one isn't available within + # self.timeout then raise a ``ConnectionError``. + connection = None + try: + connection = self.pool.get(block=True, timeout=self.timeout) + except Empty: + # Note that this is not caught by the redis client and will be + # raised unless handled by application code. If you want never to + raise ConnectionError("No connection available.") + + # If the ``connection`` is actually ``None`` then that's a cue to make + # a new connection to add to the pool. + if connection is None: + connection = self.make_connection() + + try: + # ensure this connection is connected to Redis + connection.connect() + # connections that the pool provides should be ready to send + # a command. if not, the connection was either returned to the + # pool before all data has been read or the socket has been + # closed. either way, reconnect and verify everything is good. + try: + if connection.can_read(): + raise ConnectionError("Connection has data") + except (ConnectionError, OSError): + connection.disconnect() + connection.connect() + if connection.can_read(): + raise ConnectionError("Connection not ready") + except BaseException: + # release the connection back to the pool so that we don't leak it + self.release(connection) + raise + + return connection + + def release(self, connection): + "Releases the connection back to the pool." + # Make sure we haven't changed process. + self._checkpid() + if not self.owns_connection(connection): + # pool doesn't own this connection. do not add it back + # to the pool. instead add a None value which is a placeholder + # that will cause the pool to recreate the connection if + # its needed. + connection.disconnect() + self.pool.put_nowait(None) + return + + # Put the connection back into the pool. + try: + self.pool.put_nowait(connection) + except Full: + # perhaps the pool has been reset() after a fork? regardless, + # we don't want this connection + pass + + def disconnect(self): + "Disconnects all connections in the pool." + self._checkpid() + for connection in self._connections: + connection.disconnect() diff --git a/.venv/Lib/site-packages/redis/crc.py b/.venv/Lib/site-packages/redis/crc.py new file mode 100644 index 00000000..e2612411 --- /dev/null +++ b/.venv/Lib/site-packages/redis/crc.py @@ -0,0 +1,23 @@ +from binascii import crc_hqx + +from redis.typing import EncodedT + +# Redis Cluster's key space is divided into 16384 slots. +# For more information see: https://github.com/redis/redis/issues/2576 +REDIS_CLUSTER_HASH_SLOTS = 16384 + +__all__ = ["key_slot", "REDIS_CLUSTER_HASH_SLOTS"] + + +def key_slot(key: EncodedT, bucket: int = REDIS_CLUSTER_HASH_SLOTS) -> int: + """Calculate key slot for a given key. + See Keys distribution model in https://redis.io/topics/cluster-spec + :param key - bytes + :param bucket - int + """ + start = key.find(b"{") + if start > -1: + end = key.find(b"}", start + 1) + if end > -1 and end != start + 1: + key = key[start + 1 : end] + return crc_hqx(key, 0) % bucket diff --git a/.venv/Lib/site-packages/redis/credentials.py b/.venv/Lib/site-packages/redis/credentials.py new file mode 100644 index 00000000..7ba26dcd --- /dev/null +++ b/.venv/Lib/site-packages/redis/credentials.py @@ -0,0 +1,26 @@ +from typing import Optional, Tuple, Union + + +class CredentialProvider: + """ + Credentials Provider. + """ + + def get_credentials(self) -> Union[Tuple[str], Tuple[str, str]]: + raise NotImplementedError("get_credentials must be implemented") + + +class UsernamePasswordCredentialProvider(CredentialProvider): + """ + Simple implementation of CredentialProvider that just wraps static + username and password. + """ + + def __init__(self, username: Optional[str] = None, password: Optional[str] = None): + self.username = username or "" + self.password = password or "" + + def get_credentials(self): + if self.username: + return self.username, self.password + return (self.password,) diff --git a/.venv/Lib/site-packages/redis/exceptions.py b/.venv/Lib/site-packages/redis/exceptions.py new file mode 100644 index 00000000..7cf15a7d --- /dev/null +++ b/.venv/Lib/site-packages/redis/exceptions.py @@ -0,0 +1,218 @@ +"Core exceptions raised by the Redis client" + + +class RedisError(Exception): + pass + + +class ConnectionError(RedisError): + pass + + +class TimeoutError(RedisError): + pass + + +class AuthenticationError(ConnectionError): + pass + + +class AuthorizationError(ConnectionError): + pass + + +class BusyLoadingError(ConnectionError): + pass + + +class InvalidResponse(RedisError): + pass + + +class ResponseError(RedisError): + pass + + +class DataError(RedisError): + pass + + +class PubSubError(RedisError): + pass + + +class WatchError(RedisError): + pass + + +class NoScriptError(ResponseError): + pass + + +class OutOfMemoryError(ResponseError): + """ + Indicates the database is full. Can only occur when either: + * Redis maxmemory-policy=noeviction + * Redis maxmemory-policy=volatile* and there are no evictable keys + + For more information see `Memory optimization in Redis `_. # noqa + """ + + pass + + +class ExecAbortError(ResponseError): + pass + + +class ReadOnlyError(ResponseError): + pass + + +class NoPermissionError(ResponseError): + pass + + +class ModuleError(ResponseError): + pass + + +class LockError(RedisError, ValueError): + "Errors acquiring or releasing a lock" + # NOTE: For backwards compatibility, this class derives from ValueError. + # This was originally chosen to behave like threading.Lock. + pass + + +class LockNotOwnedError(LockError): + "Error trying to extend or release a lock that is (no longer) owned" + pass + + +class ChildDeadlockedError(Exception): + "Error indicating that a child process is deadlocked after a fork()" + pass + + +class AuthenticationWrongNumberOfArgsError(ResponseError): + """ + An error to indicate that the wrong number of args + were sent to the AUTH command + """ + + pass + + +class RedisClusterException(Exception): + """ + Base exception for the RedisCluster client + """ + + pass + + +class ClusterError(RedisError): + """ + Cluster errors occurred multiple times, resulting in an exhaustion of the + command execution TTL + """ + + pass + + +class ClusterDownError(ClusterError, ResponseError): + """ + Error indicated CLUSTERDOWN error received from cluster. + By default Redis Cluster nodes stop accepting queries if they detect there + is at least a hash slot uncovered (no available node is serving it). + This way if the cluster is partially down (for example a range of hash + slots are no longer covered) the entire cluster eventually becomes + unavailable. It automatically returns available as soon as all the slots + are covered again. + """ + + def __init__(self, resp): + self.args = (resp,) + self.message = resp + + +class AskError(ResponseError): + """ + Error indicated ASK error received from cluster. + When a slot is set as MIGRATING, the node will accept all queries that + pertain to this hash slot, but only if the key in question exists, + otherwise the query is forwarded using a -ASK redirection to the node that + is target of the migration. + + src node: MIGRATING to dst node + get > ASK error + ask dst node > ASKING command + dst node: IMPORTING from src node + asking command only affects next command + any op will be allowed after asking command + """ + + def __init__(self, resp): + """should only redirect to master node""" + self.args = (resp,) + self.message = resp + slot_id, new_node = resp.split(" ") + host, port = new_node.rsplit(":", 1) + self.slot_id = int(slot_id) + self.node_addr = self.host, self.port = host, int(port) + + +class TryAgainError(ResponseError): + """ + Error indicated TRYAGAIN error received from cluster. + Operations on keys that don't exist or are - during resharding - split + between the source and destination nodes, will generate a -TRYAGAIN error. + """ + + def __init__(self, *args, **kwargs): + pass + + +class ClusterCrossSlotError(ResponseError): + """ + Error indicated CROSSSLOT error received from cluster. + A CROSSSLOT error is generated when keys in a request don't hash to the + same slot. + """ + + message = "Keys in request don't hash to the same slot" + + +class MovedError(AskError): + """ + Error indicated MOVED error received from cluster. + A request sent to a node that doesn't serve this key will be replayed with + a MOVED error that points to the correct node. + """ + + pass + + +class MasterDownError(ClusterDownError): + """ + Error indicated MASTERDOWN error received from cluster. + Link with MASTER is down and replica-serve-stale-data is set to 'no'. + """ + + pass + + +class SlotNotCoveredError(RedisClusterException): + """ + This error only happens in the case where the connection pool will try to + fetch what node that is covered by a given slot. + + If this error is raised the client should drop the current node layout and + attempt to reconnect and refresh the node layout again + """ + + pass + + +class MaxConnectionsError(ConnectionError): + ... diff --git a/.venv/Lib/site-packages/redis/lock.py b/.venv/Lib/site-packages/redis/lock.py new file mode 100644 index 00000000..4cca102d --- /dev/null +++ b/.venv/Lib/site-packages/redis/lock.py @@ -0,0 +1,308 @@ +import threading +import time as mod_time +import uuid +from types import SimpleNamespace, TracebackType +from typing import Optional, Type + +from redis.exceptions import LockError, LockNotOwnedError +from redis.typing import Number + + +class Lock: + """ + A shared, distributed Lock. Using Redis for locking allows the Lock + to be shared across processes and/or machines. + + It's left to the user to resolve deadlock issues and make sure + multiple clients play nicely together. + """ + + lua_release = None + lua_extend = None + lua_reacquire = None + + # KEYS[1] - lock name + # ARGV[1] - token + # return 1 if the lock was released, otherwise 0 + LUA_RELEASE_SCRIPT = """ + local token = redis.call('get', KEYS[1]) + if not token or token ~= ARGV[1] then + return 0 + end + redis.call('del', KEYS[1]) + return 1 + """ + + # KEYS[1] - lock name + # ARGV[1] - token + # ARGV[2] - additional milliseconds + # ARGV[3] - "0" if the additional time should be added to the lock's + # existing ttl or "1" if the existing ttl should be replaced + # return 1 if the locks time was extended, otherwise 0 + LUA_EXTEND_SCRIPT = """ + local token = redis.call('get', KEYS[1]) + if not token or token ~= ARGV[1] then + return 0 + end + local expiration = redis.call('pttl', KEYS[1]) + if not expiration then + expiration = 0 + end + if expiration < 0 then + return 0 + end + + local newttl = ARGV[2] + if ARGV[3] == "0" then + newttl = ARGV[2] + expiration + end + redis.call('pexpire', KEYS[1], newttl) + return 1 + """ + + # KEYS[1] - lock name + # ARGV[1] - token + # ARGV[2] - milliseconds + # return 1 if the locks time was reacquired, otherwise 0 + LUA_REACQUIRE_SCRIPT = """ + local token = redis.call('get', KEYS[1]) + if not token or token ~= ARGV[1] then + return 0 + end + redis.call('pexpire', KEYS[1], ARGV[2]) + return 1 + """ + + def __init__( + self, + redis, + name: str, + timeout: Optional[Number] = None, + sleep: Number = 0.1, + blocking: bool = True, + blocking_timeout: Optional[Number] = None, + thread_local: bool = True, + ): + """ + Create a new Lock instance named ``name`` using the Redis client + supplied by ``redis``. + + ``timeout`` indicates a maximum life for the lock in seconds. + By default, it will remain locked until release() is called. + ``timeout`` can be specified as a float or integer, both representing + the number of seconds to wait. + + ``sleep`` indicates the amount of time to sleep in seconds per loop + iteration when the lock is in blocking mode and another client is + currently holding the lock. + + ``blocking`` indicates whether calling ``acquire`` should block until + the lock has been acquired or to fail immediately, causing ``acquire`` + to return False and the lock not being acquired. Defaults to True. + Note this value can be overridden by passing a ``blocking`` + argument to ``acquire``. + + ``blocking_timeout`` indicates the maximum amount of time in seconds to + spend trying to acquire the lock. A value of ``None`` indicates + continue trying forever. ``blocking_timeout`` can be specified as a + float or integer, both representing the number of seconds to wait. + + ``thread_local`` indicates whether the lock token is placed in + thread-local storage. By default, the token is placed in thread local + storage so that a thread only sees its token, not a token set by + another thread. Consider the following timeline: + + time: 0, thread-1 acquires `my-lock`, with a timeout of 5 seconds. + thread-1 sets the token to "abc" + time: 1, thread-2 blocks trying to acquire `my-lock` using the + Lock instance. + time: 5, thread-1 has not yet completed. redis expires the lock + key. + time: 5, thread-2 acquired `my-lock` now that it's available. + thread-2 sets the token to "xyz" + time: 6, thread-1 finishes its work and calls release(). if the + token is *not* stored in thread local storage, then + thread-1 would see the token value as "xyz" and would be + able to successfully release the thread-2's lock. + + In some use cases it's necessary to disable thread local storage. For + example, if you have code where one thread acquires a lock and passes + that lock instance to a worker thread to release later. If thread + local storage isn't disabled in this case, the worker thread won't see + the token set by the thread that acquired the lock. Our assumption + is that these cases aren't common and as such default to using + thread local storage. + """ + self.redis = redis + self.name = name + self.timeout = timeout + self.sleep = sleep + self.blocking = blocking + self.blocking_timeout = blocking_timeout + self.thread_local = bool(thread_local) + self.local = threading.local() if self.thread_local else SimpleNamespace() + self.local.token = None + self.register_scripts() + + def register_scripts(self) -> None: + cls = self.__class__ + client = self.redis + if cls.lua_release is None: + cls.lua_release = client.register_script(cls.LUA_RELEASE_SCRIPT) + if cls.lua_extend is None: + cls.lua_extend = client.register_script(cls.LUA_EXTEND_SCRIPT) + if cls.lua_reacquire is None: + cls.lua_reacquire = client.register_script(cls.LUA_REACQUIRE_SCRIPT) + + def __enter__(self) -> "Lock": + if self.acquire(): + return self + raise LockError("Unable to acquire lock within the time specified") + + def __exit__( + self, + exc_type: Optional[Type[BaseException]], + exc_value: Optional[BaseException], + traceback: Optional[TracebackType], + ) -> None: + self.release() + + def acquire( + self, + sleep: Optional[Number] = None, + blocking: Optional[bool] = None, + blocking_timeout: Optional[Number] = None, + token: Optional[str] = None, + ): + """ + Use Redis to hold a shared, distributed lock named ``name``. + Returns True once the lock is acquired. + + If ``blocking`` is False, always return immediately. If the lock + was acquired, return True, otherwise return False. + + ``blocking_timeout`` specifies the maximum number of seconds to + wait trying to acquire the lock. + + ``token`` specifies the token value to be used. If provided, token + must be a bytes object or a string that can be encoded to a bytes + object with the default encoding. If a token isn't specified, a UUID + will be generated. + """ + if sleep is None: + sleep = self.sleep + if token is None: + token = uuid.uuid1().hex.encode() + else: + encoder = self.redis.get_encoder() + token = encoder.encode(token) + if blocking is None: + blocking = self.blocking + if blocking_timeout is None: + blocking_timeout = self.blocking_timeout + stop_trying_at = None + if blocking_timeout is not None: + stop_trying_at = mod_time.monotonic() + blocking_timeout + while True: + if self.do_acquire(token): + self.local.token = token + return True + if not blocking: + return False + next_try_at = mod_time.monotonic() + sleep + if stop_trying_at is not None and next_try_at > stop_trying_at: + return False + mod_time.sleep(sleep) + + def do_acquire(self, token: str) -> bool: + if self.timeout: + # convert to milliseconds + timeout = int(self.timeout * 1000) + else: + timeout = None + if self.redis.set(self.name, token, nx=True, px=timeout): + return True + return False + + def locked(self) -> bool: + """ + Returns True if this key is locked by any process, otherwise False. + """ + return self.redis.get(self.name) is not None + + def owned(self) -> bool: + """ + Returns True if this key is locked by this lock, otherwise False. + """ + stored_token = self.redis.get(self.name) + # need to always compare bytes to bytes + # TODO: this can be simplified when the context manager is finished + if stored_token and not isinstance(stored_token, bytes): + encoder = self.redis.get_encoder() + stored_token = encoder.encode(stored_token) + return self.local.token is not None and stored_token == self.local.token + + def release(self) -> None: + """ + Releases the already acquired lock + """ + expected_token = self.local.token + if expected_token is None: + raise LockError("Cannot release an unlocked lock") + self.local.token = None + self.do_release(expected_token) + + def do_release(self, expected_token: str) -> None: + if not bool( + self.lua_release(keys=[self.name], args=[expected_token], client=self.redis) + ): + raise LockNotOwnedError("Cannot release a lock that's no longer owned") + + def extend(self, additional_time: int, replace_ttl: bool = False) -> bool: + """ + Adds more time to an already acquired lock. + + ``additional_time`` can be specified as an integer or a float, both + representing the number of seconds to add. + + ``replace_ttl`` if False (the default), add `additional_time` to + the lock's existing ttl. If True, replace the lock's ttl with + `additional_time`. + """ + if self.local.token is None: + raise LockError("Cannot extend an unlocked lock") + if self.timeout is None: + raise LockError("Cannot extend a lock with no timeout") + return self.do_extend(additional_time, replace_ttl) + + def do_extend(self, additional_time: int, replace_ttl: bool) -> bool: + additional_time = int(additional_time * 1000) + if not bool( + self.lua_extend( + keys=[self.name], + args=[self.local.token, additional_time, "1" if replace_ttl else "0"], + client=self.redis, + ) + ): + raise LockNotOwnedError("Cannot extend a lock that's no longer owned") + return True + + def reacquire(self) -> bool: + """ + Resets a TTL of an already acquired lock back to a timeout value. + """ + if self.local.token is None: + raise LockError("Cannot reacquire an unlocked lock") + if self.timeout is None: + raise LockError("Cannot reacquire a lock with no timeout") + return self.do_reacquire() + + def do_reacquire(self) -> bool: + timeout = int(self.timeout * 1000) + if not bool( + self.lua_reacquire( + keys=[self.name], args=[self.local.token, timeout], client=self.redis + ) + ): + raise LockNotOwnedError("Cannot reacquire a lock that's no longer owned") + return True diff --git a/.venv/Lib/site-packages/redis/ocsp.py b/.venv/Lib/site-packages/redis/ocsp.py new file mode 100644 index 00000000..b0420b47 --- /dev/null +++ b/.venv/Lib/site-packages/redis/ocsp.py @@ -0,0 +1,307 @@ +import base64 +import datetime +import ssl +from urllib.parse import urljoin, urlparse + +import cryptography.hazmat.primitives.hashes +import requests +from cryptography import hazmat, x509 +from cryptography.exceptions import InvalidSignature +from cryptography.hazmat import backends +from cryptography.hazmat.primitives.asymmetric.dsa import DSAPublicKey +from cryptography.hazmat.primitives.asymmetric.ec import ECDSA, EllipticCurvePublicKey +from cryptography.hazmat.primitives.asymmetric.padding import PKCS1v15 +from cryptography.hazmat.primitives.asymmetric.rsa import RSAPublicKey +from cryptography.hazmat.primitives.hashes import SHA1, Hash +from cryptography.hazmat.primitives.serialization import Encoding, PublicFormat +from cryptography.x509 import ocsp +from redis.exceptions import AuthorizationError, ConnectionError + + +def _verify_response(issuer_cert, ocsp_response): + pubkey = issuer_cert.public_key() + try: + if isinstance(pubkey, RSAPublicKey): + pubkey.verify( + ocsp_response.signature, + ocsp_response.tbs_response_bytes, + PKCS1v15(), + ocsp_response.signature_hash_algorithm, + ) + elif isinstance(pubkey, DSAPublicKey): + pubkey.verify( + ocsp_response.signature, + ocsp_response.tbs_response_bytes, + ocsp_response.signature_hash_algorithm, + ) + elif isinstance(pubkey, EllipticCurvePublicKey): + pubkey.verify( + ocsp_response.signature, + ocsp_response.tbs_response_bytes, + ECDSA(ocsp_response.signature_hash_algorithm), + ) + else: + pubkey.verify(ocsp_response.signature, ocsp_response.tbs_response_bytes) + except InvalidSignature: + raise ConnectionError("failed to valid ocsp response") + + +def _check_certificate(issuer_cert, ocsp_bytes, validate=True): + """A wrapper the return the validity of a known ocsp certificate""" + + ocsp_response = ocsp.load_der_ocsp_response(ocsp_bytes) + + if ocsp_response.response_status == ocsp.OCSPResponseStatus.UNAUTHORIZED: + raise AuthorizationError("you are not authorized to view this ocsp certificate") + if ocsp_response.response_status == ocsp.OCSPResponseStatus.SUCCESSFUL: + if ocsp_response.certificate_status != ocsp.OCSPCertStatus.GOOD: + raise ConnectionError( + f'Received an {str(ocsp_response.certificate_status).split(".")[1]} ' + "ocsp certificate status" + ) + else: + raise ConnectionError( + "failed to retrieve a sucessful response from the ocsp responder" + ) + + if ocsp_response.this_update >= datetime.datetime.now(): + raise ConnectionError("ocsp certificate was issued in the future") + + if ( + ocsp_response.next_update + and ocsp_response.next_update < datetime.datetime.now() + ): + raise ConnectionError("ocsp certificate has invalid update - in the past") + + responder_name = ocsp_response.responder_name + issuer_hash = ocsp_response.issuer_key_hash + responder_hash = ocsp_response.responder_key_hash + + cert_to_validate = issuer_cert + if ( + responder_name is not None + and responder_name == issuer_cert.subject + or responder_hash == issuer_hash + ): + cert_to_validate = issuer_cert + else: + certs = ocsp_response.certificates + responder_certs = _get_certificates( + certs, issuer_cert, responder_name, responder_hash + ) + + try: + responder_cert = responder_certs[0] + except IndexError: + raise ConnectionError("no certificates found for the responder") + + ext = responder_cert.extensions.get_extension_for_class(x509.ExtendedKeyUsage) + if ext is None or x509.oid.ExtendedKeyUsageOID.OCSP_SIGNING not in ext.value: + raise ConnectionError("delegate not autorized for ocsp signing") + cert_to_validate = responder_cert + + if validate: + _verify_response(cert_to_validate, ocsp_response) + return True + + +def _get_certificates(certs, issuer_cert, responder_name, responder_hash): + if responder_name is None: + certificates = [ + c + for c in certs + if _get_pubkey_hash(c) == responder_hash and c.issuer == issuer_cert.subject + ] + else: + certificates = [ + c + for c in certs + if c.subject == responder_name and c.issuer == issuer_cert.subject + ] + + return certificates + + +def _get_pubkey_hash(certificate): + pubkey = certificate.public_key() + + # https://stackoverflow.com/a/46309453/600498 + if isinstance(pubkey, RSAPublicKey): + h = pubkey.public_bytes(Encoding.DER, PublicFormat.PKCS1) + elif isinstance(pubkey, EllipticCurvePublicKey): + h = pubkey.public_bytes(Encoding.X962, PublicFormat.UncompressedPoint) + else: + h = pubkey.public_bytes(Encoding.DER, PublicFormat.SubjectPublicKeyInfo) + + sha1 = Hash(SHA1(), backend=backends.default_backend()) + sha1.update(h) + return sha1.finalize() + + +def ocsp_staple_verifier(con, ocsp_bytes, expected=None): + """An implemention of a function for set_ocsp_client_callback in PyOpenSSL. + + This function validates that the provide ocsp_bytes response is valid, + and matches the expected, stapled responses. + """ + if ocsp_bytes in [b"", None]: + raise ConnectionError("no ocsp response present") + + issuer_cert = None + peer_cert = con.get_peer_certificate().to_cryptography() + for c in con.get_peer_cert_chain(): + cert = c.to_cryptography() + if cert.subject == peer_cert.issuer: + issuer_cert = cert + break + + if issuer_cert is None: + raise ConnectionError("no matching issuer cert found in certificate chain") + + if expected is not None: + e = x509.load_pem_x509_certificate(expected) + if peer_cert != e: + raise ConnectionError("received and expected certificates do not match") + + return _check_certificate(issuer_cert, ocsp_bytes) + + +class OCSPVerifier: + """A class to verify ssl sockets for RFC6960/RFC6961. This can be used + when using direct validation of OCSP responses and certificate revocations. + + @see https://datatracker.ietf.org/doc/html/rfc6960 + @see https://datatracker.ietf.org/doc/html/rfc6961 + """ + + def __init__(self, sock, host, port, ca_certs=None): + self.SOCK = sock + self.HOST = host + self.PORT = port + self.CA_CERTS = ca_certs + + def _bin2ascii(self, der): + """Convert SSL certificates in a binary (DER) format to ASCII PEM.""" + + pem = ssl.DER_cert_to_PEM_cert(der) + cert = x509.load_pem_x509_certificate(pem.encode(), backends.default_backend()) + return cert + + def components_from_socket(self): + """This function returns the certificate, primary issuer, and primary ocsp + server in the chain for a socket already wrapped with ssl. + """ + + # convert the binary certifcate to text + der = self.SOCK.getpeercert(True) + if der is False: + raise ConnectionError("no certificate found for ssl peer") + cert = self._bin2ascii(der) + return self._certificate_components(cert) + + def _certificate_components(self, cert): + """Given an SSL certificate, retract the useful components for + validating the certificate status with an OCSP server. + + Args: + cert ([bytes]): A PEM encoded ssl certificate + """ + + try: + aia = cert.extensions.get_extension_for_oid( + x509.oid.ExtensionOID.AUTHORITY_INFORMATION_ACCESS + ).value + except cryptography.x509.extensions.ExtensionNotFound: + raise ConnectionError("No AIA information present in ssl certificate") + + # fetch certificate issuers + issuers = [ + i + for i in aia + if i.access_method == x509.oid.AuthorityInformationAccessOID.CA_ISSUERS + ] + try: + issuer = issuers[0].access_location.value + except IndexError: + issuer = None + + # now, the series of ocsp server entries + ocsps = [ + i + for i in aia + if i.access_method == x509.oid.AuthorityInformationAccessOID.OCSP + ] + + try: + ocsp = ocsps[0].access_location.value + except IndexError: + raise ConnectionError("no ocsp servers in certificate") + + return cert, issuer, ocsp + + def components_from_direct_connection(self): + """Return the certificate, primary issuer, and primary ocsp server + from the host defined by the socket. This is useful in cases where + different certificates are occasionally presented. + """ + + pem = ssl.get_server_certificate((self.HOST, self.PORT), ca_certs=self.CA_CERTS) + cert = x509.load_pem_x509_certificate(pem.encode(), backends.default_backend()) + return self._certificate_components(cert) + + def build_certificate_url(self, server, cert, issuer_cert): + """Return the complete url to the ocsp""" + orb = ocsp.OCSPRequestBuilder() + + # add_certificate returns an initialized OCSPRequestBuilder + orb = orb.add_certificate( + cert, issuer_cert, cryptography.hazmat.primitives.hashes.SHA256() + ) + request = orb.build() + + path = base64.b64encode( + request.public_bytes(hazmat.primitives.serialization.Encoding.DER) + ) + url = urljoin(server, path.decode("ascii")) + return url + + def check_certificate(self, server, cert, issuer_url): + """Checks the validitity of an ocsp server for an issuer""" + + r = requests.get(issuer_url) + if not r.ok: + raise ConnectionError("failed to fetch issuer certificate") + der = r.content + issuer_cert = self._bin2ascii(der) + + ocsp_url = self.build_certificate_url(server, cert, issuer_cert) + + # HTTP 1.1 mandates the addition of the Host header in ocsp responses + header = { + "Host": urlparse(ocsp_url).netloc, + "Content-Type": "application/ocsp-request", + } + r = requests.get(ocsp_url, headers=header) + if not r.ok: + raise ConnectionError("failed to fetch ocsp certificate") + return _check_certificate(issuer_cert, r.content, True) + + def is_valid(self): + """Returns the validity of the certificate wrapping our socket. + This first retrieves for validate the certificate, issuer_url, + and ocsp_server for certificate validate. Then retrieves the + issuer certificate from the issuer_url, and finally checks + the validity of OCSP revocation status. + """ + + # validate the certificate + try: + cert, issuer_url, ocsp_server = self.components_from_socket() + if issuer_url is None: + raise ConnectionError("no issuers found in certificate chain") + return self.check_certificate(ocsp_server, cert, issuer_url) + except AuthorizationError: + cert, issuer_url, ocsp_server = self.components_from_direct_connection() + if issuer_url is None: + raise ConnectionError("no issuers found in certificate chain") + return self.check_certificate(ocsp_server, cert, issuer_url) diff --git a/.venv/Lib/site-packages/redis/py.typed b/.venv/Lib/site-packages/redis/py.typed new file mode 100644 index 00000000..e69de29b diff --git a/.venv/Lib/site-packages/redis/retry.py b/.venv/Lib/site-packages/redis/retry.py new file mode 100644 index 00000000..60644305 --- /dev/null +++ b/.venv/Lib/site-packages/redis/retry.py @@ -0,0 +1,54 @@ +import socket +from time import sleep + +from redis.exceptions import ConnectionError, TimeoutError + + +class Retry: + """Retry a specific number of times after a failure""" + + def __init__( + self, + backoff, + retries, + supported_errors=(ConnectionError, TimeoutError, socket.timeout), + ): + """ + Initialize a `Retry` object with a `Backoff` object + that retries a maximum of `retries` times. + `retries` can be negative to retry forever. + You can specify the types of supported errors which trigger + a retry with the `supported_errors` parameter. + """ + self._backoff = backoff + self._retries = retries + self._supported_errors = supported_errors + + def update_supported_errors(self, specified_errors: list): + """ + Updates the supported errors with the specified error types + """ + self._supported_errors = tuple( + set(self._supported_errors + tuple(specified_errors)) + ) + + def call_with_retry(self, do, fail): + """ + Execute an operation that might fail and returns its result, or + raise the exception that was thrown depending on the `Backoff` object. + `do`: the operation to call. Expects no argument. + `fail`: the failure handler, expects the last error that was thrown + """ + self._backoff.reset() + failures = 0 + while True: + try: + return do() + except self._supported_errors as error: + failures += 1 + fail(error) + if self._retries >= 0 and failures > self._retries: + raise error + backoff = self._backoff.compute(failures) + if backoff > 0: + sleep(backoff) diff --git a/.venv/Lib/site-packages/redis/sentinel.py b/.venv/Lib/site-packages/redis/sentinel.py new file mode 100644 index 00000000..41f308d1 --- /dev/null +++ b/.venv/Lib/site-packages/redis/sentinel.py @@ -0,0 +1,389 @@ +import random +import weakref +from typing import Optional + +from redis.client import Redis +from redis.commands import SentinelCommands +from redis.connection import Connection, ConnectionPool, SSLConnection +from redis.exceptions import ConnectionError, ReadOnlyError, ResponseError, TimeoutError +from redis.utils import str_if_bytes + + +class MasterNotFoundError(ConnectionError): + pass + + +class SlaveNotFoundError(ConnectionError): + pass + + +class SentinelManagedConnection(Connection): + def __init__(self, **kwargs): + self.connection_pool = kwargs.pop("connection_pool") + super().__init__(**kwargs) + + def __repr__(self): + pool = self.connection_pool + s = f"{type(self).__name__}" + if self.host: + host_info = f",host={self.host},port={self.port}" + s = s % host_info + return s + + def connect_to(self, address): + self.host, self.port = address + super().connect() + if self.connection_pool.check_connection: + self.send_command("PING") + if str_if_bytes(self.read_response()) != "PONG": + raise ConnectionError("PING failed") + + def _connect_retry(self): + if self._sock: + return # already connected + if self.connection_pool.is_master: + self.connect_to(self.connection_pool.get_master_address()) + else: + for slave in self.connection_pool.rotate_slaves(): + try: + return self.connect_to(slave) + except ConnectionError: + continue + raise SlaveNotFoundError # Never be here + + def connect(self): + return self.retry.call_with_retry(self._connect_retry, lambda error: None) + + def read_response( + self, + disable_decoding=False, + *, + disconnect_on_error: Optional[bool] = False, + push_request: Optional[bool] = False, + ): + try: + return super().read_response( + disable_decoding=disable_decoding, + disconnect_on_error=disconnect_on_error, + push_request=push_request, + ) + except ReadOnlyError: + if self.connection_pool.is_master: + # When talking to a master, a ReadOnlyError when likely + # indicates that the previous master that we're still connected + # to has been demoted to a slave and there's a new master. + # calling disconnect will force the connection to re-query + # sentinel during the next connect() attempt. + self.disconnect() + raise ConnectionError("The previous master is now a slave") + raise + + +class SentinelManagedSSLConnection(SentinelManagedConnection, SSLConnection): + pass + + +class SentinelConnectionPoolProxy: + def __init__( + self, + connection_pool, + is_master, + check_connection, + service_name, + sentinel_manager, + ): + self.connection_pool_ref = weakref.ref(connection_pool) + self.is_master = is_master + self.check_connection = check_connection + self.service_name = service_name + self.sentinel_manager = sentinel_manager + self.reset() + + def reset(self): + self.master_address = None + self.slave_rr_counter = None + + def get_master_address(self): + master_address = self.sentinel_manager.discover_master(self.service_name) + if self.is_master and self.master_address != master_address: + self.master_address = master_address + # disconnect any idle connections so that they reconnect + # to the new master the next time that they are used. + connection_pool = self.connection_pool_ref() + if connection_pool is not None: + connection_pool.disconnect(inuse_connections=False) + return master_address + + def rotate_slaves(self): + slaves = self.sentinel_manager.discover_slaves(self.service_name) + if slaves: + if self.slave_rr_counter is None: + self.slave_rr_counter = random.randint(0, len(slaves) - 1) + for _ in range(len(slaves)): + self.slave_rr_counter = (self.slave_rr_counter + 1) % len(slaves) + slave = slaves[self.slave_rr_counter] + yield slave + # Fallback to the master connection + try: + yield self.get_master_address() + except MasterNotFoundError: + pass + raise SlaveNotFoundError(f"No slave found for {self.service_name!r}") + + +class SentinelConnectionPool(ConnectionPool): + """ + Sentinel backed connection pool. + + If ``check_connection`` flag is set to True, SentinelManagedConnection + sends a PING command right after establishing the connection. + """ + + def __init__(self, service_name, sentinel_manager, **kwargs): + kwargs["connection_class"] = kwargs.get( + "connection_class", + SentinelManagedSSLConnection + if kwargs.pop("ssl", False) + else SentinelManagedConnection, + ) + self.is_master = kwargs.pop("is_master", True) + self.check_connection = kwargs.pop("check_connection", False) + self.proxy = SentinelConnectionPoolProxy( + connection_pool=self, + is_master=self.is_master, + check_connection=self.check_connection, + service_name=service_name, + sentinel_manager=sentinel_manager, + ) + super().__init__(**kwargs) + self.connection_kwargs["connection_pool"] = self.proxy + self.service_name = service_name + self.sentinel_manager = sentinel_manager + + def __repr__(self): + role = "master" if self.is_master else "slave" + return f"{type(self).__name__}>> from redis.sentinel import Sentinel + >>> sentinel = Sentinel([('localhost', 26379)], socket_timeout=0.1) + >>> master = sentinel.master_for('mymaster', socket_timeout=0.1) + >>> master.set('foo', 'bar') + >>> slave = sentinel.slave_for('mymaster', socket_timeout=0.1) + >>> slave.get('foo') + b'bar' + + ``sentinels`` is a list of sentinel nodes. Each node is represented by + a pair (hostname, port). + + ``min_other_sentinels`` defined a minimum number of peers for a sentinel. + When querying a sentinel, if it doesn't meet this threshold, responses + from that sentinel won't be considered valid. + + ``sentinel_kwargs`` is a dictionary of connection arguments used when + connecting to sentinel instances. Any argument that can be passed to + a normal Redis connection can be specified here. If ``sentinel_kwargs`` is + not specified, any socket_timeout and socket_keepalive options specified + in ``connection_kwargs`` will be used. + + ``connection_kwargs`` are keyword arguments that will be used when + establishing a connection to a Redis server. + """ + + def __init__( + self, + sentinels, + min_other_sentinels=0, + sentinel_kwargs=None, + **connection_kwargs, + ): + # if sentinel_kwargs isn't defined, use the socket_* options from + # connection_kwargs + if sentinel_kwargs is None: + sentinel_kwargs = { + k: v for k, v in connection_kwargs.items() if k.startswith("socket_") + } + self.sentinel_kwargs = sentinel_kwargs + + self.sentinels = [ + Redis(hostname, port, **self.sentinel_kwargs) + for hostname, port in sentinels + ] + self.min_other_sentinels = min_other_sentinels + self.connection_kwargs = connection_kwargs + + def execute_command(self, *args, **kwargs): + """ + Execute Sentinel command in sentinel nodes. + once - If set to True, then execute the resulting command on a single + node at random, rather than across the entire sentinel cluster. + """ + once = bool(kwargs.get("once", False)) + if "once" in kwargs.keys(): + kwargs.pop("once") + + if once: + random.choice(self.sentinels).execute_command(*args, **kwargs) + else: + for sentinel in self.sentinels: + sentinel.execute_command(*args, **kwargs) + return True + + def __repr__(self): + sentinel_addresses = [] + for sentinel in self.sentinels: + sentinel_addresses.append( + "{host}:{port}".format_map(sentinel.connection_pool.connection_kwargs) + ) + return f'{type(self).__name__}' + + def check_master_state(self, state, service_name): + if not state["is_master"] or state["is_sdown"] or state["is_odown"]: + return False + # Check if our sentinel doesn't see other nodes + if state["num-other-sentinels"] < self.min_other_sentinels: + return False + return True + + def discover_master(self, service_name): + """ + Asks sentinel servers for the Redis master's address corresponding + to the service labeled ``service_name``. + + Returns a pair (address, port) or raises MasterNotFoundError if no + master is found. + """ + collected_errors = list() + for sentinel_no, sentinel in enumerate(self.sentinels): + try: + masters = sentinel.sentinel_masters() + except (ConnectionError, TimeoutError) as e: + collected_errors.append(f"{sentinel} - {e!r}") + continue + state = masters.get(service_name) + if state and self.check_master_state(state, service_name): + # Put this sentinel at the top of the list + self.sentinels[0], self.sentinels[sentinel_no] = ( + sentinel, + self.sentinels[0], + ) + return state["ip"], state["port"] + + error_info = "" + if len(collected_errors) > 0: + error_info = f" : {', '.join(collected_errors)}" + raise MasterNotFoundError(f"No master found for {service_name!r}{error_info}") + + def filter_slaves(self, slaves): + "Remove slaves that are in an ODOWN or SDOWN state" + slaves_alive = [] + for slave in slaves: + if slave["is_odown"] or slave["is_sdown"]: + continue + slaves_alive.append((slave["ip"], slave["port"])) + return slaves_alive + + def discover_slaves(self, service_name): + "Returns a list of alive slaves for service ``service_name``" + for sentinel in self.sentinels: + try: + slaves = sentinel.sentinel_slaves(service_name) + except (ConnectionError, ResponseError, TimeoutError): + continue + slaves = self.filter_slaves(slaves) + if slaves: + return slaves + return [] + + def master_for( + self, + service_name, + redis_class=Redis, + connection_pool_class=SentinelConnectionPool, + **kwargs, + ): + """ + Returns a redis client instance for the ``service_name`` master. + + A :py:class:`~redis.sentinel.SentinelConnectionPool` class is + used to retrieve the master's address before establishing a new + connection. + + NOTE: If the master's address has changed, any cached connections to + the old master are closed. + + By default clients will be a :py:class:`~redis.Redis` instance. + Specify a different class to the ``redis_class`` argument if you + desire something different. + + The ``connection_pool_class`` specifies the connection pool to + use. The :py:class:`~redis.sentinel.SentinelConnectionPool` + will be used by default. + + All other keyword arguments are merged with any connection_kwargs + passed to this class and passed to the connection pool as keyword + arguments to be used to initialize Redis connections. + """ + kwargs["is_master"] = True + connection_kwargs = dict(self.connection_kwargs) + connection_kwargs.update(kwargs) + return redis_class.from_pool( + connection_pool_class(service_name, self, **connection_kwargs) + ) + + def slave_for( + self, + service_name, + redis_class=Redis, + connection_pool_class=SentinelConnectionPool, + **kwargs, + ): + """ + Returns redis client instance for the ``service_name`` slave(s). + + A SentinelConnectionPool class is used to retrieve the slave's + address before establishing a new connection. + + By default clients will be a :py:class:`~redis.Redis` instance. + Specify a different class to the ``redis_class`` argument if you + desire something different. + + The ``connection_pool_class`` specifies the connection pool to use. + The SentinelConnectionPool will be used by default. + + All other keyword arguments are merged with any connection_kwargs + passed to this class and passed to the connection pool as keyword + arguments to be used to initialize Redis connections. + """ + kwargs["is_master"] = False + connection_kwargs = dict(self.connection_kwargs) + connection_kwargs.update(kwargs) + return redis_class.from_pool( + connection_pool_class(service_name, self, **connection_kwargs) + ) diff --git a/.venv/Lib/site-packages/redis/typing.py b/.venv/Lib/site-packages/redis/typing.py new file mode 100644 index 00000000..56a1e99b --- /dev/null +++ b/.venv/Lib/site-packages/redis/typing.py @@ -0,0 +1,65 @@ +# from __future__ import annotations + +from datetime import datetime, timedelta +from typing import ( + TYPE_CHECKING, + Any, + Awaitable, + Iterable, + Mapping, + Type, + TypeVar, + Union, +) + +from redis.compat import Protocol + +if TYPE_CHECKING: + from redis._parsers import Encoder + from redis.asyncio.connection import ConnectionPool as AsyncConnectionPool + from redis.connection import ConnectionPool + + +Number = Union[int, float] +EncodedT = Union[bytes, memoryview] +DecodedT = Union[str, int, float] +EncodableT = Union[EncodedT, DecodedT] +AbsExpiryT = Union[int, datetime] +ExpiryT = Union[int, timedelta] +ZScoreBoundT = Union[float, str] # str allows for the [ or ( prefix +BitfieldOffsetT = Union[int, str] # str allows for #x syntax +_StringLikeT = Union[bytes, str, memoryview] +KeyT = _StringLikeT # Main redis key space +PatternT = _StringLikeT # Patterns matched against keys, fields etc +FieldT = EncodableT # Fields within hash tables, streams and geo commands +KeysT = Union[KeyT, Iterable[KeyT]] +ChannelT = _StringLikeT +GroupT = _StringLikeT # Consumer group +ConsumerT = _StringLikeT # Consumer name +StreamIdT = Union[int, _StringLikeT] +ScriptTextT = _StringLikeT +TimeoutSecT = Union[int, float, _StringLikeT] +# Mapping is not covariant in the key type, which prevents +# Mapping[_StringLikeT, X] from accepting arguments of type Dict[str, X]. Using +# a TypeVar instead of a Union allows mappings with any of the permitted types +# to be passed. Care is needed if there is more than one such mapping in a +# type signature because they will all be required to be the same key type. +AnyKeyT = TypeVar("AnyKeyT", bytes, str, memoryview) +AnyFieldT = TypeVar("AnyFieldT", bytes, str, memoryview) +AnyChannelT = TypeVar("AnyChannelT", bytes, str, memoryview) + +ExceptionMappingT = Mapping[str, Union[Type[Exception], Mapping[str, Type[Exception]]]] + + +class CommandsProtocol(Protocol): + connection_pool: Union["AsyncConnectionPool", "ConnectionPool"] + + def execute_command(self, *args, **options): + ... + + +class ClusterCommandsProtocol(CommandsProtocol, Protocol): + encoder: "Encoder" + + def execute_command(self, *args, **options) -> Union[Any, Awaitable]: + ... diff --git a/.venv/Lib/site-packages/redis/utils.py b/.venv/Lib/site-packages/redis/utils.py new file mode 100644 index 00000000..01fdfed7 --- /dev/null +++ b/.venv/Lib/site-packages/redis/utils.py @@ -0,0 +1,147 @@ +import logging +import sys +from contextlib import contextmanager +from functools import wraps +from typing import Any, Dict, Mapping, Union + +try: + import hiredis # noqa + + # Only support Hiredis >= 1.0: + HIREDIS_AVAILABLE = not hiredis.__version__.startswith("0.") + HIREDIS_PACK_AVAILABLE = hasattr(hiredis, "pack_command") +except ImportError: + HIREDIS_AVAILABLE = False + HIREDIS_PACK_AVAILABLE = False + +try: + import ssl # noqa + + SSL_AVAILABLE = True +except ImportError: + SSL_AVAILABLE = False + +try: + import cryptography # noqa + + CRYPTOGRAPHY_AVAILABLE = True +except ImportError: + CRYPTOGRAPHY_AVAILABLE = False + +if sys.version_info >= (3, 8): + from importlib import metadata +else: + import importlib_metadata as metadata + + +def from_url(url, **kwargs): + """ + Returns an active Redis client generated from the given database URL. + + Will attempt to extract the database id from the path url fragment, if + none is provided. + """ + from redis.client import Redis + + return Redis.from_url(url, **kwargs) + + +@contextmanager +def pipeline(redis_obj): + p = redis_obj.pipeline() + yield p + p.execute() + + +def str_if_bytes(value: Union[str, bytes]) -> str: + return ( + value.decode("utf-8", errors="replace") if isinstance(value, bytes) else value + ) + + +def safe_str(value): + return str(str_if_bytes(value)) + + +def dict_merge(*dicts: Mapping[str, Any]) -> Dict[str, Any]: + """ + Merge all provided dicts into 1 dict. + *dicts : `dict` + dictionaries to merge + """ + merged = {} + + for d in dicts: + merged.update(d) + + return merged + + +def list_keys_to_dict(key_list, callback): + return dict.fromkeys(key_list, callback) + + +def merge_result(command, res): + """ + Merge all items in `res` into a list. + + This command is used when sending a command to multiple nodes + and the result from each node should be merged into a single list. + + res : 'dict' + """ + result = set() + + for v in res.values(): + for value in v: + result.add(value) + + return list(result) + + +def warn_deprecated(name, reason="", version="", stacklevel=2): + import warnings + + msg = f"Call to deprecated {name}." + if reason: + msg += f" ({reason})" + if version: + msg += f" -- Deprecated since version {version}." + warnings.warn(msg, category=DeprecationWarning, stacklevel=stacklevel) + + +def deprecated_function(reason="", version="", name=None): + """ + Decorator to mark a function as deprecated. + """ + + def decorator(func): + @wraps(func) + def wrapper(*args, **kwargs): + warn_deprecated(name or func.__name__, reason, version, stacklevel=3) + return func(*args, **kwargs) + + return wrapper + + return decorator + + +def _set_info_logger(): + """ + Set up a logger that log info logs to stdout. + (This is used by the default push response handler) + """ + if "push_response" not in logging.root.manager.loggerDict.keys(): + logger = logging.getLogger("push_response") + logger.setLevel(logging.INFO) + handler = logging.StreamHandler() + handler.setLevel(logging.INFO) + logger.addHandler(handler) + + +def get_lib_version(): + try: + libver = metadata.version("redis") + except metadata.PackageNotFoundError: + libver = "99.99.99" + return libver diff --git a/.venv/Lib/site-packages/six-1.16.0.dist-info/INSTALLER b/.venv/Lib/site-packages/six-1.16.0.dist-info/INSTALLER new file mode 100644 index 00000000..a1b589e3 --- /dev/null +++ b/.venv/Lib/site-packages/six-1.16.0.dist-info/INSTALLER @@ -0,0 +1 @@ +pip diff --git a/.venv/Lib/site-packages/six-1.16.0.dist-info/LICENSE b/.venv/Lib/site-packages/six-1.16.0.dist-info/LICENSE new file mode 100644 index 00000000..de663311 --- /dev/null +++ b/.venv/Lib/site-packages/six-1.16.0.dist-info/LICENSE @@ -0,0 +1,18 @@ +Copyright (c) 2010-2020 Benjamin Peterson + +Permission is hereby granted, free of charge, to any person obtaining a copy of +this software and associated documentation files (the "Software"), to deal in +the Software without restriction, including without limitation the rights to +use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of +the Software, and to permit persons to whom the Software is furnished to do so, +subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS +FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR +COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER +IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN +CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. diff --git a/.venv/Lib/site-packages/six-1.16.0.dist-info/METADATA b/.venv/Lib/site-packages/six-1.16.0.dist-info/METADATA new file mode 100644 index 00000000..6d7525c2 --- /dev/null +++ b/.venv/Lib/site-packages/six-1.16.0.dist-info/METADATA @@ -0,0 +1,49 @@ +Metadata-Version: 2.1 +Name: six +Version: 1.16.0 +Summary: Python 2 and 3 compatibility utilities +Home-page: https://github.com/benjaminp/six +Author: Benjamin Peterson +Author-email: benjamin@python.org +License: MIT +Platform: UNKNOWN +Classifier: Development Status :: 5 - Production/Stable +Classifier: Programming Language :: Python :: 2 +Classifier: Programming Language :: Python :: 3 +Classifier: Intended Audience :: Developers +Classifier: License :: OSI Approved :: MIT License +Classifier: Topic :: Software Development :: Libraries +Classifier: Topic :: Utilities +Requires-Python: >=2.7, !=3.0.*, !=3.1.*, !=3.2.* + +.. image:: https://img.shields.io/pypi/v/six.svg + :target: https://pypi.org/project/six/ + :alt: six on PyPI + +.. image:: https://travis-ci.org/benjaminp/six.svg?branch=master + :target: https://travis-ci.org/benjaminp/six + :alt: six on TravisCI + +.. image:: https://readthedocs.org/projects/six/badge/?version=latest + :target: https://six.readthedocs.io/ + :alt: six's documentation on Read the Docs + +.. image:: https://img.shields.io/badge/license-MIT-green.svg + :target: https://github.com/benjaminp/six/blob/master/LICENSE + :alt: MIT License badge + +Six is a Python 2 and 3 compatibility library. It provides utility functions +for smoothing over the differences between the Python versions with the goal of +writing Python code that is compatible on both Python versions. See the +documentation for more information on what is provided. + +Six supports Python 2.7 and 3.3+. It is contained in only one Python +file, so it can be easily copied into your project. (The copyright and license +notice must be retained.) + +Online documentation is at https://six.readthedocs.io/. + +Bugs can be reported to https://github.com/benjaminp/six. The code can also +be found there. + + diff --git a/.venv/Lib/site-packages/six-1.16.0.dist-info/RECORD b/.venv/Lib/site-packages/six-1.16.0.dist-info/RECORD new file mode 100644 index 00000000..28b18f29 --- /dev/null +++ b/.venv/Lib/site-packages/six-1.16.0.dist-info/RECORD @@ -0,0 +1,8 @@ +__pycache__/six.cpython-311.pyc,, +six-1.16.0.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4 +six-1.16.0.dist-info/LICENSE,sha256=i7hQxWWqOJ_cFvOkaWWtI9gq3_YPI5P8J2K2MYXo5sk,1066 +six-1.16.0.dist-info/METADATA,sha256=VQcGIFCAEmfZcl77E5riPCN4v2TIsc_qtacnjxKHJoI,1795 +six-1.16.0.dist-info/RECORD,, +six-1.16.0.dist-info/WHEEL,sha256=Z-nyYpwrcSqxfdux5Mbn_DQ525iP7J2DG3JgGvOYyTQ,110 +six-1.16.0.dist-info/top_level.txt,sha256=_iVH_iYEtEXnD8nYGQYpYFUvkUW9sEO1GYbkeKSAais,4 +six.py,sha256=TOOfQi7nFGfMrIvtdr6wX4wyHH8M7aknmuLfo2cBBrM,34549 diff --git a/.venv/Lib/site-packages/six-1.16.0.dist-info/WHEEL b/.venv/Lib/site-packages/six-1.16.0.dist-info/WHEEL new file mode 100644 index 00000000..01b8fc7d --- /dev/null +++ b/.venv/Lib/site-packages/six-1.16.0.dist-info/WHEEL @@ -0,0 +1,6 @@ +Wheel-Version: 1.0 +Generator: bdist_wheel (0.36.2) +Root-Is-Purelib: true +Tag: py2-none-any +Tag: py3-none-any + diff --git a/.venv/Lib/site-packages/six-1.16.0.dist-info/top_level.txt b/.venv/Lib/site-packages/six-1.16.0.dist-info/top_level.txt new file mode 100644 index 00000000..ffe2fce4 --- /dev/null +++ b/.venv/Lib/site-packages/six-1.16.0.dist-info/top_level.txt @@ -0,0 +1 @@ +six diff --git a/.venv/Lib/site-packages/six.py b/.venv/Lib/site-packages/six.py new file mode 100644 index 00000000..4e15675d --- /dev/null +++ b/.venv/Lib/site-packages/six.py @@ -0,0 +1,998 @@ +# Copyright (c) 2010-2020 Benjamin Peterson +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +"""Utilities for writing code that runs on Python 2 and 3""" + +from __future__ import absolute_import + +import functools +import itertools +import operator +import sys +import types + +__author__ = "Benjamin Peterson " +__version__ = "1.16.0" + + +# Useful for very coarse version differentiation. +PY2 = sys.version_info[0] == 2 +PY3 = sys.version_info[0] == 3 +PY34 = sys.version_info[0:2] >= (3, 4) + +if PY3: + string_types = str, + integer_types = int, + class_types = type, + text_type = str + binary_type = bytes + + MAXSIZE = sys.maxsize +else: + string_types = basestring, + integer_types = (int, long) + class_types = (type, types.ClassType) + text_type = unicode + binary_type = str + + if sys.platform.startswith("java"): + # Jython always uses 32 bits. + MAXSIZE = int((1 << 31) - 1) + else: + # It's possible to have sizeof(long) != sizeof(Py_ssize_t). + class X(object): + + def __len__(self): + return 1 << 31 + try: + len(X()) + except OverflowError: + # 32-bit + MAXSIZE = int((1 << 31) - 1) + else: + # 64-bit + MAXSIZE = int((1 << 63) - 1) + del X + +if PY34: + from importlib.util import spec_from_loader +else: + spec_from_loader = None + + +def _add_doc(func, doc): + """Add documentation to a function.""" + func.__doc__ = doc + + +def _import_module(name): + """Import module, returning the module after the last dot.""" + __import__(name) + return sys.modules[name] + + +class _LazyDescr(object): + + def __init__(self, name): + self.name = name + + def __get__(self, obj, tp): + result = self._resolve() + setattr(obj, self.name, result) # Invokes __set__. + try: + # This is a bit ugly, but it avoids running this again by + # removing this descriptor. + delattr(obj.__class__, self.name) + except AttributeError: + pass + return result + + +class MovedModule(_LazyDescr): + + def __init__(self, name, old, new=None): + super(MovedModule, self).__init__(name) + if PY3: + if new is None: + new = name + self.mod = new + else: + self.mod = old + + def _resolve(self): + return _import_module(self.mod) + + def __getattr__(self, attr): + _module = self._resolve() + value = getattr(_module, attr) + setattr(self, attr, value) + return value + + +class _LazyModule(types.ModuleType): + + def __init__(self, name): + super(_LazyModule, self).__init__(name) + self.__doc__ = self.__class__.__doc__ + + def __dir__(self): + attrs = ["__doc__", "__name__"] + attrs += [attr.name for attr in self._moved_attributes] + return attrs + + # Subclasses should override this + _moved_attributes = [] + + +class MovedAttribute(_LazyDescr): + + def __init__(self, name, old_mod, new_mod, old_attr=None, new_attr=None): + super(MovedAttribute, self).__init__(name) + if PY3: + if new_mod is None: + new_mod = name + self.mod = new_mod + if new_attr is None: + if old_attr is None: + new_attr = name + else: + new_attr = old_attr + self.attr = new_attr + else: + self.mod = old_mod + if old_attr is None: + old_attr = name + self.attr = old_attr + + def _resolve(self): + module = _import_module(self.mod) + return getattr(module, self.attr) + + +class _SixMetaPathImporter(object): + + """ + A meta path importer to import six.moves and its submodules. + + This class implements a PEP302 finder and loader. It should be compatible + with Python 2.5 and all existing versions of Python3 + """ + + def __init__(self, six_module_name): + self.name = six_module_name + self.known_modules = {} + + def _add_module(self, mod, *fullnames): + for fullname in fullnames: + self.known_modules[self.name + "." + fullname] = mod + + def _get_module(self, fullname): + return self.known_modules[self.name + "." + fullname] + + def find_module(self, fullname, path=None): + if fullname in self.known_modules: + return self + return None + + def find_spec(self, fullname, path, target=None): + if fullname in self.known_modules: + return spec_from_loader(fullname, self) + return None + + def __get_module(self, fullname): + try: + return self.known_modules[fullname] + except KeyError: + raise ImportError("This loader does not know module " + fullname) + + def load_module(self, fullname): + try: + # in case of a reload + return sys.modules[fullname] + except KeyError: + pass + mod = self.__get_module(fullname) + if isinstance(mod, MovedModule): + mod = mod._resolve() + else: + mod.__loader__ = self + sys.modules[fullname] = mod + return mod + + def is_package(self, fullname): + """ + Return true, if the named module is a package. + + We need this method to get correct spec objects with + Python 3.4 (see PEP451) + """ + return hasattr(self.__get_module(fullname), "__path__") + + def get_code(self, fullname): + """Return None + + Required, if is_package is implemented""" + self.__get_module(fullname) # eventually raises ImportError + return None + get_source = get_code # same as get_code + + def create_module(self, spec): + return self.load_module(spec.name) + + def exec_module(self, module): + pass + +_importer = _SixMetaPathImporter(__name__) + + +class _MovedItems(_LazyModule): + + """Lazy loading of moved objects""" + __path__ = [] # mark as package + + +_moved_attributes = [ + MovedAttribute("cStringIO", "cStringIO", "io", "StringIO"), + MovedAttribute("filter", "itertools", "builtins", "ifilter", "filter"), + MovedAttribute("filterfalse", "itertools", "itertools", "ifilterfalse", "filterfalse"), + MovedAttribute("input", "__builtin__", "builtins", "raw_input", "input"), + MovedAttribute("intern", "__builtin__", "sys"), + MovedAttribute("map", "itertools", "builtins", "imap", "map"), + MovedAttribute("getcwd", "os", "os", "getcwdu", "getcwd"), + MovedAttribute("getcwdb", "os", "os", "getcwd", "getcwdb"), + MovedAttribute("getoutput", "commands", "subprocess"), + MovedAttribute("range", "__builtin__", "builtins", "xrange", "range"), + MovedAttribute("reload_module", "__builtin__", "importlib" if PY34 else "imp", "reload"), + MovedAttribute("reduce", "__builtin__", "functools"), + MovedAttribute("shlex_quote", "pipes", "shlex", "quote"), + MovedAttribute("StringIO", "StringIO", "io"), + MovedAttribute("UserDict", "UserDict", "collections"), + MovedAttribute("UserList", "UserList", "collections"), + MovedAttribute("UserString", "UserString", "collections"), + MovedAttribute("xrange", "__builtin__", "builtins", "xrange", "range"), + MovedAttribute("zip", "itertools", "builtins", "izip", "zip"), + MovedAttribute("zip_longest", "itertools", "itertools", "izip_longest", "zip_longest"), + MovedModule("builtins", "__builtin__"), + MovedModule("configparser", "ConfigParser"), + MovedModule("collections_abc", "collections", "collections.abc" if sys.version_info >= (3, 3) else "collections"), + MovedModule("copyreg", "copy_reg"), + MovedModule("dbm_gnu", "gdbm", "dbm.gnu"), + MovedModule("dbm_ndbm", "dbm", "dbm.ndbm"), + MovedModule("_dummy_thread", "dummy_thread", "_dummy_thread" if sys.version_info < (3, 9) else "_thread"), + MovedModule("http_cookiejar", "cookielib", "http.cookiejar"), + MovedModule("http_cookies", "Cookie", "http.cookies"), + MovedModule("html_entities", "htmlentitydefs", "html.entities"), + MovedModule("html_parser", "HTMLParser", "html.parser"), + MovedModule("http_client", "httplib", "http.client"), + MovedModule("email_mime_base", "email.MIMEBase", "email.mime.base"), + MovedModule("email_mime_image", "email.MIMEImage", "email.mime.image"), + MovedModule("email_mime_multipart", "email.MIMEMultipart", "email.mime.multipart"), + MovedModule("email_mime_nonmultipart", "email.MIMENonMultipart", "email.mime.nonmultipart"), + MovedModule("email_mime_text", "email.MIMEText", "email.mime.text"), + MovedModule("BaseHTTPServer", "BaseHTTPServer", "http.server"), + MovedModule("CGIHTTPServer", "CGIHTTPServer", "http.server"), + MovedModule("SimpleHTTPServer", "SimpleHTTPServer", "http.server"), + MovedModule("cPickle", "cPickle", "pickle"), + MovedModule("queue", "Queue"), + MovedModule("reprlib", "repr"), + MovedModule("socketserver", "SocketServer"), + MovedModule("_thread", "thread", "_thread"), + MovedModule("tkinter", "Tkinter"), + MovedModule("tkinter_dialog", "Dialog", "tkinter.dialog"), + MovedModule("tkinter_filedialog", "FileDialog", "tkinter.filedialog"), + MovedModule("tkinter_scrolledtext", "ScrolledText", "tkinter.scrolledtext"), + MovedModule("tkinter_simpledialog", "SimpleDialog", "tkinter.simpledialog"), + MovedModule("tkinter_tix", "Tix", "tkinter.tix"), + MovedModule("tkinter_ttk", "ttk", "tkinter.ttk"), + MovedModule("tkinter_constants", "Tkconstants", "tkinter.constants"), + MovedModule("tkinter_dnd", "Tkdnd", "tkinter.dnd"), + MovedModule("tkinter_colorchooser", "tkColorChooser", + "tkinter.colorchooser"), + MovedModule("tkinter_commondialog", "tkCommonDialog", + "tkinter.commondialog"), + MovedModule("tkinter_tkfiledialog", "tkFileDialog", "tkinter.filedialog"), + MovedModule("tkinter_font", "tkFont", "tkinter.font"), + MovedModule("tkinter_messagebox", "tkMessageBox", "tkinter.messagebox"), + MovedModule("tkinter_tksimpledialog", "tkSimpleDialog", + "tkinter.simpledialog"), + MovedModule("urllib_parse", __name__ + ".moves.urllib_parse", "urllib.parse"), + MovedModule("urllib_error", __name__ + ".moves.urllib_error", "urllib.error"), + MovedModule("urllib", __name__ + ".moves.urllib", __name__ + ".moves.urllib"), + MovedModule("urllib_robotparser", "robotparser", "urllib.robotparser"), + MovedModule("xmlrpc_client", "xmlrpclib", "xmlrpc.client"), + MovedModule("xmlrpc_server", "SimpleXMLRPCServer", "xmlrpc.server"), +] +# Add windows specific modules. +if sys.platform == "win32": + _moved_attributes += [ + MovedModule("winreg", "_winreg"), + ] + +for attr in _moved_attributes: + setattr(_MovedItems, attr.name, attr) + if isinstance(attr, MovedModule): + _importer._add_module(attr, "moves." + attr.name) +del attr + +_MovedItems._moved_attributes = _moved_attributes + +moves = _MovedItems(__name__ + ".moves") +_importer._add_module(moves, "moves") + + +class Module_six_moves_urllib_parse(_LazyModule): + + """Lazy loading of moved objects in six.moves.urllib_parse""" + + +_urllib_parse_moved_attributes = [ + MovedAttribute("ParseResult", "urlparse", "urllib.parse"), + MovedAttribute("SplitResult", "urlparse", "urllib.parse"), + MovedAttribute("parse_qs", "urlparse", "urllib.parse"), + MovedAttribute("parse_qsl", "urlparse", "urllib.parse"), + MovedAttribute("urldefrag", "urlparse", "urllib.parse"), + MovedAttribute("urljoin", "urlparse", "urllib.parse"), + MovedAttribute("urlparse", "urlparse", "urllib.parse"), + MovedAttribute("urlsplit", "urlparse", "urllib.parse"), + MovedAttribute("urlunparse", "urlparse", "urllib.parse"), + MovedAttribute("urlunsplit", "urlparse", "urllib.parse"), + MovedAttribute("quote", "urllib", "urllib.parse"), + MovedAttribute("quote_plus", "urllib", "urllib.parse"), + MovedAttribute("unquote", "urllib", "urllib.parse"), + MovedAttribute("unquote_plus", "urllib", "urllib.parse"), + MovedAttribute("unquote_to_bytes", "urllib", "urllib.parse", "unquote", "unquote_to_bytes"), + MovedAttribute("urlencode", "urllib", "urllib.parse"), + MovedAttribute("splitquery", "urllib", "urllib.parse"), + MovedAttribute("splittag", "urllib", "urllib.parse"), + MovedAttribute("splituser", "urllib", "urllib.parse"), + MovedAttribute("splitvalue", "urllib", "urllib.parse"), + MovedAttribute("uses_fragment", "urlparse", "urllib.parse"), + MovedAttribute("uses_netloc", "urlparse", "urllib.parse"), + MovedAttribute("uses_params", "urlparse", "urllib.parse"), + MovedAttribute("uses_query", "urlparse", "urllib.parse"), + MovedAttribute("uses_relative", "urlparse", "urllib.parse"), +] +for attr in _urllib_parse_moved_attributes: + setattr(Module_six_moves_urllib_parse, attr.name, attr) +del attr + +Module_six_moves_urllib_parse._moved_attributes = _urllib_parse_moved_attributes + +_importer._add_module(Module_six_moves_urllib_parse(__name__ + ".moves.urllib_parse"), + "moves.urllib_parse", "moves.urllib.parse") + + +class Module_six_moves_urllib_error(_LazyModule): + + """Lazy loading of moved objects in six.moves.urllib_error""" + + +_urllib_error_moved_attributes = [ + MovedAttribute("URLError", "urllib2", "urllib.error"), + MovedAttribute("HTTPError", "urllib2", "urllib.error"), + MovedAttribute("ContentTooShortError", "urllib", "urllib.error"), +] +for attr in _urllib_error_moved_attributes: + setattr(Module_six_moves_urllib_error, attr.name, attr) +del attr + +Module_six_moves_urllib_error._moved_attributes = _urllib_error_moved_attributes + +_importer._add_module(Module_six_moves_urllib_error(__name__ + ".moves.urllib.error"), + "moves.urllib_error", "moves.urllib.error") + + +class Module_six_moves_urllib_request(_LazyModule): + + """Lazy loading of moved objects in six.moves.urllib_request""" + + +_urllib_request_moved_attributes = [ + MovedAttribute("urlopen", "urllib2", "urllib.request"), + MovedAttribute("install_opener", "urllib2", "urllib.request"), + MovedAttribute("build_opener", "urllib2", "urllib.request"), + MovedAttribute("pathname2url", "urllib", "urllib.request"), + MovedAttribute("url2pathname", "urllib", "urllib.request"), + MovedAttribute("getproxies", "urllib", "urllib.request"), + MovedAttribute("Request", "urllib2", "urllib.request"), + MovedAttribute("OpenerDirector", "urllib2", "urllib.request"), + MovedAttribute("HTTPDefaultErrorHandler", "urllib2", "urllib.request"), + MovedAttribute("HTTPRedirectHandler", "urllib2", "urllib.request"), + MovedAttribute("HTTPCookieProcessor", "urllib2", "urllib.request"), + MovedAttribute("ProxyHandler", "urllib2", "urllib.request"), + MovedAttribute("BaseHandler", "urllib2", "urllib.request"), + MovedAttribute("HTTPPasswordMgr", "urllib2", "urllib.request"), + MovedAttribute("HTTPPasswordMgrWithDefaultRealm", "urllib2", "urllib.request"), + MovedAttribute("AbstractBasicAuthHandler", "urllib2", "urllib.request"), + MovedAttribute("HTTPBasicAuthHandler", "urllib2", "urllib.request"), + MovedAttribute("ProxyBasicAuthHandler", "urllib2", "urllib.request"), + MovedAttribute("AbstractDigestAuthHandler", "urllib2", "urllib.request"), + MovedAttribute("HTTPDigestAuthHandler", "urllib2", "urllib.request"), + MovedAttribute("ProxyDigestAuthHandler", "urllib2", "urllib.request"), + MovedAttribute("HTTPHandler", "urllib2", "urllib.request"), + MovedAttribute("HTTPSHandler", "urllib2", "urllib.request"), + MovedAttribute("FileHandler", "urllib2", "urllib.request"), + MovedAttribute("FTPHandler", "urllib2", "urllib.request"), + MovedAttribute("CacheFTPHandler", "urllib2", "urllib.request"), + MovedAttribute("UnknownHandler", "urllib2", "urllib.request"), + MovedAttribute("HTTPErrorProcessor", "urllib2", "urllib.request"), + MovedAttribute("urlretrieve", "urllib", "urllib.request"), + MovedAttribute("urlcleanup", "urllib", "urllib.request"), + MovedAttribute("URLopener", "urllib", "urllib.request"), + MovedAttribute("FancyURLopener", "urllib", "urllib.request"), + MovedAttribute("proxy_bypass", "urllib", "urllib.request"), + MovedAttribute("parse_http_list", "urllib2", "urllib.request"), + MovedAttribute("parse_keqv_list", "urllib2", "urllib.request"), +] +for attr in _urllib_request_moved_attributes: + setattr(Module_six_moves_urllib_request, attr.name, attr) +del attr + +Module_six_moves_urllib_request._moved_attributes = _urllib_request_moved_attributes + +_importer._add_module(Module_six_moves_urllib_request(__name__ + ".moves.urllib.request"), + "moves.urllib_request", "moves.urllib.request") + + +class Module_six_moves_urllib_response(_LazyModule): + + """Lazy loading of moved objects in six.moves.urllib_response""" + + +_urllib_response_moved_attributes = [ + MovedAttribute("addbase", "urllib", "urllib.response"), + MovedAttribute("addclosehook", "urllib", "urllib.response"), + MovedAttribute("addinfo", "urllib", "urllib.response"), + MovedAttribute("addinfourl", "urllib", "urllib.response"), +] +for attr in _urllib_response_moved_attributes: + setattr(Module_six_moves_urllib_response, attr.name, attr) +del attr + +Module_six_moves_urllib_response._moved_attributes = _urllib_response_moved_attributes + +_importer._add_module(Module_six_moves_urllib_response(__name__ + ".moves.urllib.response"), + "moves.urllib_response", "moves.urllib.response") + + +class Module_six_moves_urllib_robotparser(_LazyModule): + + """Lazy loading of moved objects in six.moves.urllib_robotparser""" + + +_urllib_robotparser_moved_attributes = [ + MovedAttribute("RobotFileParser", "robotparser", "urllib.robotparser"), +] +for attr in _urllib_robotparser_moved_attributes: + setattr(Module_six_moves_urllib_robotparser, attr.name, attr) +del attr + +Module_six_moves_urllib_robotparser._moved_attributes = _urllib_robotparser_moved_attributes + +_importer._add_module(Module_six_moves_urllib_robotparser(__name__ + ".moves.urllib.robotparser"), + "moves.urllib_robotparser", "moves.urllib.robotparser") + + +class Module_six_moves_urllib(types.ModuleType): + + """Create a six.moves.urllib namespace that resembles the Python 3 namespace""" + __path__ = [] # mark as package + parse = _importer._get_module("moves.urllib_parse") + error = _importer._get_module("moves.urllib_error") + request = _importer._get_module("moves.urllib_request") + response = _importer._get_module("moves.urllib_response") + robotparser = _importer._get_module("moves.urllib_robotparser") + + def __dir__(self): + return ['parse', 'error', 'request', 'response', 'robotparser'] + +_importer._add_module(Module_six_moves_urllib(__name__ + ".moves.urllib"), + "moves.urllib") + + +def add_move(move): + """Add an item to six.moves.""" + setattr(_MovedItems, move.name, move) + + +def remove_move(name): + """Remove item from six.moves.""" + try: + delattr(_MovedItems, name) + except AttributeError: + try: + del moves.__dict__[name] + except KeyError: + raise AttributeError("no such move, %r" % (name,)) + + +if PY3: + _meth_func = "__func__" + _meth_self = "__self__" + + _func_closure = "__closure__" + _func_code = "__code__" + _func_defaults = "__defaults__" + _func_globals = "__globals__" +else: + _meth_func = "im_func" + _meth_self = "im_self" + + _func_closure = "func_closure" + _func_code = "func_code" + _func_defaults = "func_defaults" + _func_globals = "func_globals" + + +try: + advance_iterator = next +except NameError: + def advance_iterator(it): + return it.next() +next = advance_iterator + + +try: + callable = callable +except NameError: + def callable(obj): + return any("__call__" in klass.__dict__ for klass in type(obj).__mro__) + + +if PY3: + def get_unbound_function(unbound): + return unbound + + create_bound_method = types.MethodType + + def create_unbound_method(func, cls): + return func + + Iterator = object +else: + def get_unbound_function(unbound): + return unbound.im_func + + def create_bound_method(func, obj): + return types.MethodType(func, obj, obj.__class__) + + def create_unbound_method(func, cls): + return types.MethodType(func, None, cls) + + class Iterator(object): + + def next(self): + return type(self).__next__(self) + + callable = callable +_add_doc(get_unbound_function, + """Get the function out of a possibly unbound function""") + + +get_method_function = operator.attrgetter(_meth_func) +get_method_self = operator.attrgetter(_meth_self) +get_function_closure = operator.attrgetter(_func_closure) +get_function_code = operator.attrgetter(_func_code) +get_function_defaults = operator.attrgetter(_func_defaults) +get_function_globals = operator.attrgetter(_func_globals) + + +if PY3: + def iterkeys(d, **kw): + return iter(d.keys(**kw)) + + def itervalues(d, **kw): + return iter(d.values(**kw)) + + def iteritems(d, **kw): + return iter(d.items(**kw)) + + def iterlists(d, **kw): + return iter(d.lists(**kw)) + + viewkeys = operator.methodcaller("keys") + + viewvalues = operator.methodcaller("values") + + viewitems = operator.methodcaller("items") +else: + def iterkeys(d, **kw): + return d.iterkeys(**kw) + + def itervalues(d, **kw): + return d.itervalues(**kw) + + def iteritems(d, **kw): + return d.iteritems(**kw) + + def iterlists(d, **kw): + return d.iterlists(**kw) + + viewkeys = operator.methodcaller("viewkeys") + + viewvalues = operator.methodcaller("viewvalues") + + viewitems = operator.methodcaller("viewitems") + +_add_doc(iterkeys, "Return an iterator over the keys of a dictionary.") +_add_doc(itervalues, "Return an iterator over the values of a dictionary.") +_add_doc(iteritems, + "Return an iterator over the (key, value) pairs of a dictionary.") +_add_doc(iterlists, + "Return an iterator over the (key, [values]) pairs of a dictionary.") + + +if PY3: + def b(s): + return s.encode("latin-1") + + def u(s): + return s + unichr = chr + import struct + int2byte = struct.Struct(">B").pack + del struct + byte2int = operator.itemgetter(0) + indexbytes = operator.getitem + iterbytes = iter + import io + StringIO = io.StringIO + BytesIO = io.BytesIO + del io + _assertCountEqual = "assertCountEqual" + if sys.version_info[1] <= 1: + _assertRaisesRegex = "assertRaisesRegexp" + _assertRegex = "assertRegexpMatches" + _assertNotRegex = "assertNotRegexpMatches" + else: + _assertRaisesRegex = "assertRaisesRegex" + _assertRegex = "assertRegex" + _assertNotRegex = "assertNotRegex" +else: + def b(s): + return s + # Workaround for standalone backslash + + def u(s): + return unicode(s.replace(r'\\', r'\\\\'), "unicode_escape") + unichr = unichr + int2byte = chr + + def byte2int(bs): + return ord(bs[0]) + + def indexbytes(buf, i): + return ord(buf[i]) + iterbytes = functools.partial(itertools.imap, ord) + import StringIO + StringIO = BytesIO = StringIO.StringIO + _assertCountEqual = "assertItemsEqual" + _assertRaisesRegex = "assertRaisesRegexp" + _assertRegex = "assertRegexpMatches" + _assertNotRegex = "assertNotRegexpMatches" +_add_doc(b, """Byte literal""") +_add_doc(u, """Text literal""") + + +def assertCountEqual(self, *args, **kwargs): + return getattr(self, _assertCountEqual)(*args, **kwargs) + + +def assertRaisesRegex(self, *args, **kwargs): + return getattr(self, _assertRaisesRegex)(*args, **kwargs) + + +def assertRegex(self, *args, **kwargs): + return getattr(self, _assertRegex)(*args, **kwargs) + + +def assertNotRegex(self, *args, **kwargs): + return getattr(self, _assertNotRegex)(*args, **kwargs) + + +if PY3: + exec_ = getattr(moves.builtins, "exec") + + def reraise(tp, value, tb=None): + try: + if value is None: + value = tp() + if value.__traceback__ is not tb: + raise value.with_traceback(tb) + raise value + finally: + value = None + tb = None + +else: + def exec_(_code_, _globs_=None, _locs_=None): + """Execute code in a namespace.""" + if _globs_ is None: + frame = sys._getframe(1) + _globs_ = frame.f_globals + if _locs_ is None: + _locs_ = frame.f_locals + del frame + elif _locs_ is None: + _locs_ = _globs_ + exec("""exec _code_ in _globs_, _locs_""") + + exec_("""def reraise(tp, value, tb=None): + try: + raise tp, value, tb + finally: + tb = None +""") + + +if sys.version_info[:2] > (3,): + exec_("""def raise_from(value, from_value): + try: + raise value from from_value + finally: + value = None +""") +else: + def raise_from(value, from_value): + raise value + + +print_ = getattr(moves.builtins, "print", None) +if print_ is None: + def print_(*args, **kwargs): + """The new-style print function for Python 2.4 and 2.5.""" + fp = kwargs.pop("file", sys.stdout) + if fp is None: + return + + def write(data): + if not isinstance(data, basestring): + data = str(data) + # If the file has an encoding, encode unicode with it. + if (isinstance(fp, file) and + isinstance(data, unicode) and + fp.encoding is not None): + errors = getattr(fp, "errors", None) + if errors is None: + errors = "strict" + data = data.encode(fp.encoding, errors) + fp.write(data) + want_unicode = False + sep = kwargs.pop("sep", None) + if sep is not None: + if isinstance(sep, unicode): + want_unicode = True + elif not isinstance(sep, str): + raise TypeError("sep must be None or a string") + end = kwargs.pop("end", None) + if end is not None: + if isinstance(end, unicode): + want_unicode = True + elif not isinstance(end, str): + raise TypeError("end must be None or a string") + if kwargs: + raise TypeError("invalid keyword arguments to print()") + if not want_unicode: + for arg in args: + if isinstance(arg, unicode): + want_unicode = True + break + if want_unicode: + newline = unicode("\n") + space = unicode(" ") + else: + newline = "\n" + space = " " + if sep is None: + sep = space + if end is None: + end = newline + for i, arg in enumerate(args): + if i: + write(sep) + write(arg) + write(end) +if sys.version_info[:2] < (3, 3): + _print = print_ + + def print_(*args, **kwargs): + fp = kwargs.get("file", sys.stdout) + flush = kwargs.pop("flush", False) + _print(*args, **kwargs) + if flush and fp is not None: + fp.flush() + +_add_doc(reraise, """Reraise an exception.""") + +if sys.version_info[0:2] < (3, 4): + # This does exactly the same what the :func:`py3:functools.update_wrapper` + # function does on Python versions after 3.2. It sets the ``__wrapped__`` + # attribute on ``wrapper`` object and it doesn't raise an error if any of + # the attributes mentioned in ``assigned`` and ``updated`` are missing on + # ``wrapped`` object. + def _update_wrapper(wrapper, wrapped, + assigned=functools.WRAPPER_ASSIGNMENTS, + updated=functools.WRAPPER_UPDATES): + for attr in assigned: + try: + value = getattr(wrapped, attr) + except AttributeError: + continue + else: + setattr(wrapper, attr, value) + for attr in updated: + getattr(wrapper, attr).update(getattr(wrapped, attr, {})) + wrapper.__wrapped__ = wrapped + return wrapper + _update_wrapper.__doc__ = functools.update_wrapper.__doc__ + + def wraps(wrapped, assigned=functools.WRAPPER_ASSIGNMENTS, + updated=functools.WRAPPER_UPDATES): + return functools.partial(_update_wrapper, wrapped=wrapped, + assigned=assigned, updated=updated) + wraps.__doc__ = functools.wraps.__doc__ + +else: + wraps = functools.wraps + + +def with_metaclass(meta, *bases): + """Create a base class with a metaclass.""" + # This requires a bit of explanation: the basic idea is to make a dummy + # metaclass for one level of class instantiation that replaces itself with + # the actual metaclass. + class metaclass(type): + + def __new__(cls, name, this_bases, d): + if sys.version_info[:2] >= (3, 7): + # This version introduced PEP 560 that requires a bit + # of extra care (we mimic what is done by __build_class__). + resolved_bases = types.resolve_bases(bases) + if resolved_bases is not bases: + d['__orig_bases__'] = bases + else: + resolved_bases = bases + return meta(name, resolved_bases, d) + + @classmethod + def __prepare__(cls, name, this_bases): + return meta.__prepare__(name, bases) + return type.__new__(metaclass, 'temporary_class', (), {}) + + +def add_metaclass(metaclass): + """Class decorator for creating a class with a metaclass.""" + def wrapper(cls): + orig_vars = cls.__dict__.copy() + slots = orig_vars.get('__slots__') + if slots is not None: + if isinstance(slots, str): + slots = [slots] + for slots_var in slots: + orig_vars.pop(slots_var) + orig_vars.pop('__dict__', None) + orig_vars.pop('__weakref__', None) + if hasattr(cls, '__qualname__'): + orig_vars['__qualname__'] = cls.__qualname__ + return metaclass(cls.__name__, cls.__bases__, orig_vars) + return wrapper + + +def ensure_binary(s, encoding='utf-8', errors='strict'): + """Coerce **s** to six.binary_type. + + For Python 2: + - `unicode` -> encoded to `str` + - `str` -> `str` + + For Python 3: + - `str` -> encoded to `bytes` + - `bytes` -> `bytes` + """ + if isinstance(s, binary_type): + return s + if isinstance(s, text_type): + return s.encode(encoding, errors) + raise TypeError("not expecting type '%s'" % type(s)) + + +def ensure_str(s, encoding='utf-8', errors='strict'): + """Coerce *s* to `str`. + + For Python 2: + - `unicode` -> encoded to `str` + - `str` -> `str` + + For Python 3: + - `str` -> `str` + - `bytes` -> decoded to `str` + """ + # Optimization: Fast return for the common case. + if type(s) is str: + return s + if PY2 and isinstance(s, text_type): + return s.encode(encoding, errors) + elif PY3 and isinstance(s, binary_type): + return s.decode(encoding, errors) + elif not isinstance(s, (text_type, binary_type)): + raise TypeError("not expecting type '%s'" % type(s)) + return s + + +def ensure_text(s, encoding='utf-8', errors='strict'): + """Coerce *s* to six.text_type. + + For Python 2: + - `unicode` -> `unicode` + - `str` -> `unicode` + + For Python 3: + - `str` -> `str` + - `bytes` -> decoded to `str` + """ + if isinstance(s, binary_type): + return s.decode(encoding, errors) + elif isinstance(s, text_type): + return s + else: + raise TypeError("not expecting type '%s'" % type(s)) + + +def python_2_unicode_compatible(klass): + """ + A class decorator that defines __unicode__ and __str__ methods under Python 2. + Under Python 3 it does nothing. + + To support Python 2 and 3 with a single code base, define a __str__ method + returning text and apply this decorator to the class. + """ + if PY2: + if '__str__' not in klass.__dict__: + raise ValueError("@python_2_unicode_compatible cannot be applied " + "to %s because it doesn't define __str__()." % + klass.__name__) + klass.__unicode__ = klass.__str__ + klass.__str__ = lambda self: self.__unicode__().encode('utf-8') + return klass + + +# Complete the moves implementation. +# This code is at the end of this module to speed up module loading. +# Turn this module into a package. +__path__ = [] # required for PEP 302 and PEP 451 +__package__ = __name__ # see PEP 366 @ReservedAssignment +if globals().get("__spec__") is not None: + __spec__.submodule_search_locations = [] # PEP 451 @UndefinedVariable +# Remove other six meta path importers, since they cause problems. This can +# happen if six is removed from sys.modules and then reloaded. (Setuptools does +# this for some reason.) +if sys.meta_path: + for i, importer in enumerate(sys.meta_path): + # Here's some real nastiness: Another "instance" of the six module might + # be floating around. Therefore, we can't use isinstance() to check for + # the six meta path importer, since the other six instance will have + # inserted an importer with different class. + if (type(importer).__name__ == "_SixMetaPathImporter" and + importer.name == __name__): + del sys.meta_path[i] + break + del i, importer +# Finally, add the importer to the meta path import hook. +sys.meta_path.append(_importer) diff --git a/.venv/Lib/site-packages/spotipy-2.23.0.dist-info/INSTALLER b/.venv/Lib/site-packages/spotipy-2.23.0.dist-info/INSTALLER new file mode 100644 index 00000000..a1b589e3 --- /dev/null +++ b/.venv/Lib/site-packages/spotipy-2.23.0.dist-info/INSTALLER @@ -0,0 +1 @@ +pip diff --git a/.venv/Lib/site-packages/spotipy-2.23.0.dist-info/LICENSE.md b/.venv/Lib/site-packages/spotipy-2.23.0.dist-info/LICENSE.md new file mode 100644 index 00000000..f121286c --- /dev/null +++ b/.venv/Lib/site-packages/spotipy-2.23.0.dist-info/LICENSE.md @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2021 Paul Lamere + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/.venv/Lib/site-packages/spotipy-2.23.0.dist-info/METADATA b/.venv/Lib/site-packages/spotipy-2.23.0.dist-info/METADATA new file mode 100644 index 00000000..beeb19c0 --- /dev/null +++ b/.venv/Lib/site-packages/spotipy-2.23.0.dist-info/METADATA @@ -0,0 +1,100 @@ +Metadata-Version: 2.1 +Name: spotipy +Version: 2.23.0 +Summary: A light weight Python library for the Spotify Web API +Home-page: https://spotipy.readthedocs.org/ +Author: @plamere +Author-email: paul@echonest.com +License: MIT +Project-URL: Source, https://github.com/plamere/spotipy +Description-Content-Type: text/markdown +License-File: LICENSE.md +Requires-Dist: redis (>=3.5.3) +Requires-Dist: requests (>=2.25.0) +Requires-Dist: six (>=1.15.0) +Requires-Dist: urllib3 (>=1.26.0) +Requires-Dist: redis (<4.0.0) ; python_version < "3.4" +Provides-Extra: doc +Requires-Dist: Sphinx (>=1.5.2) ; extra == 'doc' +Provides-Extra: test +Requires-Dist: mock (==2.0.0) ; extra == 'test' + +# Spotipy + +##### A light weight Python library for the Spotify Web API + +![Tests](https://github.com/plamere/spotipy/workflows/Tests/badge.svg?branch=master) [![Documentation Status](https://readthedocs.org/projects/spotipy/badge/?version=latest)](https://spotipy.readthedocs.io/en/latest/?badge=latest) + +## Documentation + +Spotipy's full documentation is online at [Spotipy Documentation](http://spotipy.readthedocs.org/). + +## Installation + +```bash +pip install spotipy +``` + +alternatively, for Windows users + +```bash +py -m pip install spotipy +``` + +or upgrade + +```bash +pip install spotipy --upgrade +``` + +## Quick Start + +A full set of examples can be found in the [online documentation](http://spotipy.readthedocs.org/) and in the [Spotipy examples directory](https://github.com/plamere/spotipy/tree/master/examples). + +To get started, install spotipy and create an app on https://developers.spotify.com/. +Add your new ID and SECRET to your environment: + +### Without user authentication + +```python +import spotipy +from spotipy.oauth2 import SpotifyClientCredentials + +sp = spotipy.Spotify(auth_manager=SpotifyClientCredentials(client_id="YOUR_APP_CLIENT_ID", + client_secret="YOUR_APP_CLIENT_SECRET")) + +results = sp.search(q='weezer', limit=20) +for idx, track in enumerate(results['tracks']['items']): + print(idx, track['name']) +``` + +### With user authentication + +A redirect URI must be added to your application at [My Dashboard](https://developer.spotify.com/dashboard/applications) to access user authenticated features. + +```python +import spotipy +from spotipy.oauth2 import SpotifyOAuth + +sp = spotipy.Spotify(auth_manager=SpotifyOAuth(client_id="YOUR_APP_CLIENT_ID", + client_secret="YOUR_APP_CLIENT_SECRET", + redirect_uri="YOUR_APP_REDIRECT_URI", + scope="user-library-read")) + +results = sp.current_user_saved_tracks() +for idx, item in enumerate(results['items']): + track = item['track'] + print(idx, track['artists'][0]['name'], " – ", track['name']) +``` + +## Reporting Issues + +For common questions please check our [FAQ](FAQ.md). + +You can ask questions about Spotipy on +[Stack Overflow](http://stackoverflow.com/questions/ask). +Don’t forget to add the *Spotipy* tag, and any other relevant tags as well, before posting. + +If you have suggestions, bugs or other issues specific to this library, +file them [here](https://github.com/plamere/spotipy/issues). +Or just send a pull request. diff --git a/.venv/Lib/site-packages/spotipy-2.23.0.dist-info/RECORD b/.venv/Lib/site-packages/spotipy-2.23.0.dist-info/RECORD new file mode 100644 index 00000000..2802103a --- /dev/null +++ b/.venv/Lib/site-packages/spotipy-2.23.0.dist-info/RECORD @@ -0,0 +1,19 @@ +spotipy-2.23.0.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4 +spotipy-2.23.0.dist-info/LICENSE.md,sha256=tsEBFbMqRzu097t-Ecqy-G6uyDeOFEqFcrWtQuS3I_8,1068 +spotipy-2.23.0.dist-info/METADATA,sha256=I0-D4pcOm-EPF5NayigGceAVQ-Jrz4z47H4YatDSq8s,3253 +spotipy-2.23.0.dist-info/RECORD,, +spotipy-2.23.0.dist-info/REQUESTED,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +spotipy-2.23.0.dist-info/WHEEL,sha256=pkctZYzUS4AYVn6dJ-7367OJZivF2e8RA9b_ZBjif18,92 +spotipy-2.23.0.dist-info/top_level.txt,sha256=dSwjjk5rAZzv_g4EkBg7OkAUjW6k02wB8jjA29BrSa0,8 +spotipy/__init__.py,sha256=6DLF5dHdanFTLYpi1yrQKlC44FgLiTTa-hycEQ53ZFw,159 +spotipy/__pycache__/__init__.cpython-311.pyc,, +spotipy/__pycache__/cache_handler.cpython-311.pyc,, +spotipy/__pycache__/client.cpython-311.pyc,, +spotipy/__pycache__/exceptions.cpython-311.pyc,, +spotipy/__pycache__/oauth2.cpython-311.pyc,, +spotipy/__pycache__/util.cpython-311.pyc,, +spotipy/cache_handler.py,sha256=1PA6rG_MApoNr-KlDQU_WSB1CYEX5W-tKViQF6K30vs,6181 +spotipy/client.py,sha256=Eqsyi1o49YbxI_QU5kpZte52JA8HTX_m2AziIj4jPaI,75075 +spotipy/exceptions.py,sha256=s4lt7yTc8PeAiJGnShH7nu7HKXKGVnRouRaHT51Xo4w,568 +spotipy/oauth2.py,sha256=rXWNs77R5u8oLSpZbigFsMctuQlfayEiA5I-8_7D0kU,53619 +spotipy/util.py,sha256=Cg8U74tNWsq3Lz_VOttp_qYvnGiVLXr6f8IPeoKA9-w,3996 diff --git a/.venv/Lib/site-packages/spotipy-2.23.0.dist-info/REQUESTED b/.venv/Lib/site-packages/spotipy-2.23.0.dist-info/REQUESTED new file mode 100644 index 00000000..e69de29b diff --git a/.venv/Lib/site-packages/spotipy-2.23.0.dist-info/WHEEL b/.venv/Lib/site-packages/spotipy-2.23.0.dist-info/WHEEL new file mode 100644 index 00000000..1f37c02f --- /dev/null +++ b/.venv/Lib/site-packages/spotipy-2.23.0.dist-info/WHEEL @@ -0,0 +1,5 @@ +Wheel-Version: 1.0 +Generator: bdist_wheel (0.40.0) +Root-Is-Purelib: true +Tag: py3-none-any + diff --git a/.venv/Lib/site-packages/spotipy-2.23.0.dist-info/top_level.txt b/.venv/Lib/site-packages/spotipy-2.23.0.dist-info/top_level.txt new file mode 100644 index 00000000..9dc31a22 --- /dev/null +++ b/.venv/Lib/site-packages/spotipy-2.23.0.dist-info/top_level.txt @@ -0,0 +1 @@ +spotipy diff --git a/.venv/Lib/site-packages/spotipy/__init__.py b/.venv/Lib/site-packages/spotipy/__init__.py new file mode 100644 index 00000000..7f3d8599 --- /dev/null +++ b/.venv/Lib/site-packages/spotipy/__init__.py @@ -0,0 +1,5 @@ +from .cache_handler import * # noqa +from .client import * # noqa +from .exceptions import * # noqa +from .oauth2 import * # noqa +from .util import * # noqa diff --git a/.venv/Lib/site-packages/spotipy/__pycache__/__init__.cpython-311.pyc b/.venv/Lib/site-packages/spotipy/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 00000000..6c8fe9b6 Binary files /dev/null and b/.venv/Lib/site-packages/spotipy/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/spotipy/__pycache__/cache_handler.cpython-311.pyc b/.venv/Lib/site-packages/spotipy/__pycache__/cache_handler.cpython-311.pyc new file mode 100644 index 00000000..781e9c6f Binary files /dev/null and b/.venv/Lib/site-packages/spotipy/__pycache__/cache_handler.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/spotipy/__pycache__/client.cpython-311.pyc b/.venv/Lib/site-packages/spotipy/__pycache__/client.cpython-311.pyc new file mode 100644 index 00000000..c8d67225 Binary files /dev/null and b/.venv/Lib/site-packages/spotipy/__pycache__/client.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/spotipy/__pycache__/exceptions.cpython-311.pyc b/.venv/Lib/site-packages/spotipy/__pycache__/exceptions.cpython-311.pyc new file mode 100644 index 00000000..c545f93e Binary files /dev/null and b/.venv/Lib/site-packages/spotipy/__pycache__/exceptions.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/spotipy/__pycache__/oauth2.cpython-311.pyc b/.venv/Lib/site-packages/spotipy/__pycache__/oauth2.cpython-311.pyc new file mode 100644 index 00000000..b66f13d5 Binary files /dev/null and b/.venv/Lib/site-packages/spotipy/__pycache__/oauth2.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/spotipy/__pycache__/util.cpython-311.pyc b/.venv/Lib/site-packages/spotipy/__pycache__/util.cpython-311.pyc new file mode 100644 index 00000000..e6dbeccf Binary files /dev/null and b/.venv/Lib/site-packages/spotipy/__pycache__/util.cpython-311.pyc differ diff --git a/.venv/Lib/site-packages/spotipy/cache_handler.py b/.venv/Lib/site-packages/spotipy/cache_handler.py new file mode 100644 index 00000000..9a6d703b --- /dev/null +++ b/.venv/Lib/site-packages/spotipy/cache_handler.py @@ -0,0 +1,210 @@ +__all__ = [ + 'CacheHandler', + 'CacheFileHandler', + 'DjangoSessionCacheHandler', + 'FlaskSessionCacheHandler', + 'MemoryCacheHandler', + 'RedisCacheHandler'] + +import errno +import json +import logging +import os +from spotipy.util import CLIENT_CREDS_ENV_VARS + +from redis import RedisError + +logger = logging.getLogger(__name__) + + +class CacheHandler(): + """ + An abstraction layer for handling the caching and retrieval of + authorization tokens. + + Custom extensions of this class must implement get_cached_token + and save_token_to_cache methods with the same input and output + structure as the CacheHandler class. + """ + + def get_cached_token(self): + """ + Get and return a token_info dictionary object. + """ + # return token_info + raise NotImplementedError() + + def save_token_to_cache(self, token_info): + """ + Save a token_info dictionary object to the cache and return None. + """ + raise NotImplementedError() + return None + + +class CacheFileHandler(CacheHandler): + """ + Handles reading and writing cached Spotify authorization tokens + as json files on disk. + """ + + def __init__(self, + cache_path=None, + username=None, + encoder_cls=None): + """ + Parameters: + * cache_path: May be supplied, will otherwise be generated + (takes precedence over `username`) + * username: May be supplied or set as environment variable + (will set `cache_path` to `.cache-{username}`) + * encoder_cls: May be supplied as a means of overwriting the + default serializer used for writing tokens to disk + """ + self.encoder_cls = encoder_cls + if cache_path: + self.cache_path = cache_path + else: + cache_path = ".cache" + username = (username or os.getenv(CLIENT_CREDS_ENV_VARS["client_username"])) + if username: + cache_path += "-" + str(username) + self.cache_path = cache_path + + def get_cached_token(self): + token_info = None + + try: + f = open(self.cache_path) + token_info_string = f.read() + f.close() + token_info = json.loads(token_info_string) + + except IOError as error: + if error.errno == errno.ENOENT: + logger.debug("cache does not exist at: %s", self.cache_path) + else: + logger.warning("Couldn't read cache at: %s", self.cache_path) + + return token_info + + def save_token_to_cache(self, token_info): + try: + f = open(self.cache_path, "w") + f.write(json.dumps(token_info, cls=self.encoder_cls)) + f.close() + except IOError: + logger.warning('Couldn\'t write token to cache at: %s', + self.cache_path) + + +class MemoryCacheHandler(CacheHandler): + """ + A cache handler that simply stores the token info in memory as an + instance attribute of this class. The token info will be lost when this + instance is freed. + """ + + def __init__(self, token_info=None): + """ + Parameters: + * token_info: The token info to store in memory. Can be None. + """ + self.token_info = token_info + + def get_cached_token(self): + return self.token_info + + def save_token_to_cache(self, token_info): + self.token_info = token_info + + +class DjangoSessionCacheHandler(CacheHandler): + """ + A cache handler that stores the token info in the session framework + provided by Django. + + Read more at https://docs.djangoproject.com/en/3.2/topics/http/sessions/ + """ + + def __init__(self, request): + """ + Parameters: + * request: HttpRequest object provided by Django for every + incoming request + """ + self.request = request + + def get_cached_token(self): + token_info = None + try: + token_info = self.request.session['token_info'] + except KeyError: + logger.debug("Token not found in the session") + + return token_info + + def save_token_to_cache(self, token_info): + try: + self.request.session['token_info'] = token_info + except Exception as e: + logger.warning("Error saving token to cache: " + str(e)) + + +class FlaskSessionCacheHandler(CacheHandler): + """ + A cache handler that stores the token info in the session framework + provided by flask. + """ + + def __init__(self, session): + self.session = session + + def get_cached_token(self): + token_info = None + try: + token_info = self.session["token_info"] + except KeyError: + logger.debug("Token not found in the session") + + return token_info + + def save_token_to_cache(self, token_info): + try: + self.session["token_info"] = token_info + except Exception as e: + logger.warning("Error saving token to cache: " + str(e)) + + +class RedisCacheHandler(CacheHandler): + """ + A cache handler that stores the token info in the Redis. + """ + + def __init__(self, redis, key=None): + """ + Parameters: + * redis: Redis object provided by redis-py library + (https://github.com/redis/redis-py) + * key: May be supplied, will otherwise be generated + (takes precedence over `token_info`) + """ + self.redis = redis + self.key = key if key else 'token_info' + + def get_cached_token(self): + token_info = None + try: + token_info = self.redis.get(self.key) + if token_info: + return json.loads(token_info) + except RedisError as e: + logger.warning('Error getting token from cache: ' + str(e)) + + return token_info + + def save_token_to_cache(self, token_info): + try: + self.redis.set(self.key, json.dumps(token_info)) + except RedisError as e: + logger.warning('Error saving token to cache: ' + str(e)) diff --git a/.venv/Lib/site-packages/spotipy/client.py b/.venv/Lib/site-packages/spotipy/client.py new file mode 100644 index 00000000..d3b918f0 --- /dev/null +++ b/.venv/Lib/site-packages/spotipy/client.py @@ -0,0 +1,2035 @@ +# -*- coding: utf-8 -*- + +""" A simple and thin Python library for the Spotify Web API """ + +__all__ = ["Spotify", "SpotifyException"] + +import json +import logging +import re +import warnings + +import requests +import six +import urllib3 + +from spotipy.exceptions import SpotifyException + +from collections import defaultdict + +logger = logging.getLogger(__name__) + + +class Spotify(object): + """ + Example usage:: + + import spotipy + + urn = 'spotify:artist:3jOstUTkEu2JkjvRdBA5Gu' + sp = spotipy.Spotify() + + artist = sp.artist(urn) + print(artist) + + user = sp.user('plamere') + print(user) + """ + max_retries = 3 + default_retry_codes = (429, 500, 502, 503, 504) + country_codes = [ + "AD", + "AR", + "AU", + "AT", + "BE", + "BO", + "BR", + "BG", + "CA", + "CL", + "CO", + "CR", + "CY", + "CZ", + "DK", + "DO", + "EC", + "SV", + "EE", + "FI", + "FR", + "DE", + "GR", + "GT", + "HN", + "HK", + "HU", + "IS", + "ID", + "IE", + "IT", + "JP", + "LV", + "LI", + "LT", + "LU", + "MY", + "MT", + "MX", + "MC", + "NL", + "NZ", + "NI", + "NO", + "PA", + "PY", + "PE", + "PH", + "PL", + "PT", + "SG", + "ES", + "SK", + "SE", + "CH", + "TW", + "TR", + "GB", + "US", + "UY"] + + # Spotify URI scheme defined in [1], and the ID format as base-62 in [2]. + # + # Unfortunately the IANA specification is out of date and doesn't include the new types + # show and episode. Additionally, for the user URI, it does not specify which characters + # are valid for usernames, so the assumption is alphanumeric which coincidentially are also + # the same ones base-62 uses. + # In limited manual exploration this seems to hold true, as newly accounts are assigned an + # identifier that looks like the base-62 of all other IDs, but some older accounts only have + # numbers and even older ones seemed to have been allowed to freely pick this name. + # + # [1] https://www.iana.org/assignments/uri-schemes/prov/spotify + # [2] https://developer.spotify.com/documentation/web-api/#spotify-uris-and-ids + _regex_spotify_uri = r'^spotify:(?:(?Ptrack|artist|album|playlist|show|episode):(?P[0-9A-Za-z]+)|user:(?P[0-9A-Za-z]+):playlist:(?P[0-9A-Za-z]+))$' # noqa: E501 + + # Spotify URLs are defined at [1]. The assumption is made that they are all + # pointing to open.spotify.com, so a regex is used to parse them as well, + # instead of a more complex URL parsing function. + # + # [1] https://developer.spotify.com/documentation/web-api/#spotify-uris-and-ids + _regex_spotify_url = r'^(http[s]?:\/\/)?open.spotify.com\/(?Ptrack|artist|album|playlist|show|episode|user)\/(?P[0-9A-Za-z]+)(\?.*)?$' # noqa: E501 + + _regex_base62 = r'^[0-9A-Za-z]+$' + + def __init__( + self, + auth=None, + requests_session=True, + client_credentials_manager=None, + oauth_manager=None, + auth_manager=None, + proxies=None, + requests_timeout=5, + status_forcelist=None, + retries=max_retries, + status_retries=max_retries, + backoff_factor=0.3, + language=None, + ): + """ + Creates a Spotify API client. + + :param auth: An access token (optional) + :param requests_session: + A Requests session object or a truthy value to create one. + A falsy value disables sessions. + It should generally be a good idea to keep sessions enabled + for performance reasons (connection pooling). + :param client_credentials_manager: + SpotifyClientCredentials object + :param oauth_manager: + SpotifyOAuth object + :param auth_manager: + SpotifyOauth, SpotifyClientCredentials, + or SpotifyImplicitGrant object + :param proxies: + Definition of proxies (optional). + See Requests doc https://2.python-requests.org/en/master/user/advanced/#proxies + :param requests_timeout: + Tell Requests to stop waiting for a response after a given + number of seconds + :param status_forcelist: + Tell requests what type of status codes retries should occur on + :param retries: + Total number of retries to allow + :param status_retries: + Number of times to retry on bad status codes + :param backoff_factor: + A backoff factor to apply between attempts after the second try + See urllib3 https://urllib3.readthedocs.io/en/latest/reference/urllib3.util.html + :param language: + The language parameter advertises what language the user prefers to see. + See ISO-639-1 language code: https://en.wikipedia.org/wiki/List_of_ISO_639-1_codes + """ + self.prefix = "https://api.spotify.com/v1/" + self._auth = auth + self.client_credentials_manager = client_credentials_manager + self.oauth_manager = oauth_manager + self.auth_manager = auth_manager + self.proxies = proxies + self.requests_timeout = requests_timeout + self.status_forcelist = status_forcelist or self.default_retry_codes + self.backoff_factor = backoff_factor + self.retries = retries + self.status_retries = status_retries + self.language = language + + if isinstance(requests_session, requests.Session): + self._session = requests_session + else: + if requests_session: # Build a new session. + self._build_session() + else: # Use the Requests API module as a "session". + self._session = requests.api + + def set_auth(self, auth): + self._auth = auth + + @property + def auth_manager(self): + return self._auth_manager + + @auth_manager.setter + def auth_manager(self, auth_manager): + if auth_manager is not None: + self._auth_manager = auth_manager + else: + self._auth_manager = ( + self.client_credentials_manager or self.oauth_manager + ) + + def __del__(self): + """Make sure the connection (pool) gets closed""" + if isinstance(self._session, requests.Session): + self._session.close() + + def _build_session(self): + self._session = requests.Session() + retry = urllib3.Retry( + total=self.retries, + connect=None, + read=False, + allowed_methods=frozenset(['GET', 'POST', 'PUT', 'DELETE']), + status=self.status_retries, + backoff_factor=self.backoff_factor, + status_forcelist=self.status_forcelist) + + adapter = requests.adapters.HTTPAdapter(max_retries=retry) + self._session.mount('http://', adapter) + self._session.mount('https://', adapter) + + def _auth_headers(self): + if self._auth: + return {"Authorization": "Bearer {0}".format(self._auth)} + if not self.auth_manager: + return {} + try: + token = self.auth_manager.get_access_token(as_dict=False) + except TypeError: + token = self.auth_manager.get_access_token() + return {"Authorization": "Bearer {0}".format(token)} + + def _internal_call(self, method, url, payload, params): + args = dict(params=params) + if not url.startswith("http"): + url = self.prefix + url + headers = self._auth_headers() + + if "content_type" in args["params"]: + headers["Content-Type"] = args["params"]["content_type"] + del args["params"]["content_type"] + if payload: + args["data"] = payload + else: + headers["Content-Type"] = "application/json" + if payload: + args["data"] = json.dumps(payload) + + if self.language is not None: + headers["Accept-Language"] = self.language + + logger.debug('Sending %s to %s with Params: %s Headers: %s and Body: %r ', + method, url, args.get("params"), headers, args.get('data')) + + try: + response = self._session.request( + method, url, headers=headers, proxies=self.proxies, + timeout=self.requests_timeout, **args + ) + + response.raise_for_status() + results = response.json() + except requests.exceptions.HTTPError as http_error: + response = http_error.response + try: + json_response = response.json() + error = json_response.get("error", {}) + msg = error.get("message") + reason = error.get("reason") + except ValueError: + # if the response cannot be decoded into JSON (which raises a ValueError), + # then try to decode it into text + + # if we receive an empty string (which is falsy), then replace it with `None` + msg = response.text or None + reason = None + + logger.error( + 'HTTP Error for %s to %s with Params: %s returned %s due to %s', + method, url, args.get("params"), response.status_code, msg + ) + + raise SpotifyException( + response.status_code, + -1, + "%s:\n %s" % (response.url, msg), + reason=reason, + headers=response.headers, + ) + except requests.exceptions.RetryError as retry_error: + request = retry_error.request + logger.error('Max Retries reached') + try: + reason = retry_error.args[0].reason + except (IndexError, AttributeError): + reason = None + raise SpotifyException( + 429, + -1, + "%s:\n %s" % (request.path_url, "Max Retries"), + reason=reason + ) + except ValueError: + results = None + + logger.debug('RESULTS: %s', results) + return results + + def _get(self, url, args=None, payload=None, **kwargs): + if args: + kwargs.update(args) + + return self._internal_call("GET", url, payload, kwargs) + + def _post(self, url, args=None, payload=None, **kwargs): + if args: + kwargs.update(args) + return self._internal_call("POST", url, payload, kwargs) + + def _delete(self, url, args=None, payload=None, **kwargs): + if args: + kwargs.update(args) + return self._internal_call("DELETE", url, payload, kwargs) + + def _put(self, url, args=None, payload=None, **kwargs): + if args: + kwargs.update(args) + return self._internal_call("PUT", url, payload, kwargs) + + def next(self, result): + """ returns the next result given a paged result + + Parameters: + - result - a previously returned paged result + """ + if result["next"]: + return self._get(result["next"]) + else: + return None + + def previous(self, result): + """ returns the previous result given a paged result + + Parameters: + - result - a previously returned paged result + """ + if result["previous"]: + return self._get(result["previous"]) + else: + return None + + def track(self, track_id, market=None): + """ returns a single track given the track's ID, URI or URL + + Parameters: + - track_id - a spotify URI, URL or ID + - market - an ISO 3166-1 alpha-2 country code. + """ + + trid = self._get_id("track", track_id) + return self._get("tracks/" + trid, market=market) + + def tracks(self, tracks, market=None): + """ returns a list of tracks given a list of track IDs, URIs, or URLs + + Parameters: + - tracks - a list of spotify URIs, URLs or IDs. Maximum: 50 IDs. + - market - an ISO 3166-1 alpha-2 country code. + """ + + tlist = [self._get_id("track", t) for t in tracks] + return self._get("tracks/?ids=" + ",".join(tlist), market=market) + + def artist(self, artist_id): + """ returns a single artist given the artist's ID, URI or URL + + Parameters: + - artist_id - an artist ID, URI or URL + """ + + trid = self._get_id("artist", artist_id) + return self._get("artists/" + trid) + + def artists(self, artists): + """ returns a list of artists given the artist IDs, URIs, or URLs + + Parameters: + - artists - a list of artist IDs, URIs or URLs + """ + + tlist = [self._get_id("artist", a) for a in artists] + return self._get("artists/?ids=" + ",".join(tlist)) + + def artist_albums( + self, artist_id, album_type=None, country=None, limit=20, offset=0 + ): + """ Get Spotify catalog information about an artist's albums + + Parameters: + - artist_id - the artist ID, URI or URL + - album_type - 'album', 'single', 'appears_on', 'compilation' + - country - limit the response to one particular country. + - limit - the number of albums to return + - offset - the index of the first album to return + """ + + trid = self._get_id("artist", artist_id) + return self._get( + "artists/" + trid + "/albums", + album_type=album_type, + country=country, + limit=limit, + offset=offset, + ) + + def artist_top_tracks(self, artist_id, country="US"): + """ Get Spotify catalog information about an artist's top 10 tracks + by country. + + Parameters: + - artist_id - the artist ID, URI or URL + - country - limit the response to one particular country. + """ + + trid = self._get_id("artist", artist_id) + return self._get("artists/" + trid + "/top-tracks", country=country) + + def artist_related_artists(self, artist_id): + """ Get Spotify catalog information about artists similar to an + identified artist. Similarity is based on analysis of the + Spotify community's listening history. + + Parameters: + - artist_id - the artist ID, URI or URL + """ + trid = self._get_id("artist", artist_id) + return self._get("artists/" + trid + "/related-artists") + + def album(self, album_id, market=None): + """ returns a single album given the album's ID, URIs or URL + + Parameters: + - album_id - the album ID, URI or URL + - market - an ISO 3166-1 alpha-2 country code + """ + + trid = self._get_id("album", album_id) + if market is not None: + return self._get("albums/" + trid + '?market=' + market) + else: + return self._get("albums/" + trid) + + def album_tracks(self, album_id, limit=50, offset=0, market=None): + """ Get Spotify catalog information about an album's tracks + + Parameters: + - album_id - the album ID, URI or URL + - limit - the number of items to return + - offset - the index of the first item to return + - market - an ISO 3166-1 alpha-2 country code. + + """ + + trid = self._get_id("album", album_id) + return self._get( + "albums/" + trid + "/tracks/", limit=limit, offset=offset, market=market + ) + + def albums(self, albums, market=None): + """ returns a list of albums given the album IDs, URIs, or URLs + + Parameters: + - albums - a list of album IDs, URIs or URLs + - market - an ISO 3166-1 alpha-2 country code + """ + + tlist = [self._get_id("album", a) for a in albums] + if market is not None: + return self._get("albums/?ids=" + ",".join(tlist) + '&market=' + market) + else: + return self._get("albums/?ids=" + ",".join(tlist)) + + def show(self, show_id, market=None): + """ returns a single show given the show's ID, URIs or URL + + Parameters: + - show_id - the show ID, URI or URL + - market - an ISO 3166-1 alpha-2 country code. + The show must be available in the given market. + If user-based authorization is in use, the user's country + takes precedence. If neither market nor user country are + provided, the content is considered unavailable for the client. + """ + + trid = self._get_id("show", show_id) + return self._get("shows/" + trid, market=market) + + def shows(self, shows, market=None): + """ returns a list of shows given the show IDs, URIs, or URLs + + Parameters: + - shows - a list of show IDs, URIs or URLs + - market - an ISO 3166-1 alpha-2 country code. + Only shows available in the given market will be returned. + If user-based authorization is in use, the user's country + takes precedence. If neither market nor user country are + provided, the content is considered unavailable for the client. + """ + + tlist = [self._get_id("show", s) for s in shows] + return self._get("shows/?ids=" + ",".join(tlist), market=market) + + def show_episodes(self, show_id, limit=50, offset=0, market=None): + """ Get Spotify catalog information about a show's episodes + + Parameters: + - show_id - the show ID, URI or URL + - limit - the number of items to return + - offset - the index of the first item to return + - market - an ISO 3166-1 alpha-2 country code. + Only episodes available in the given market will be returned. + If user-based authorization is in use, the user's country + takes precedence. If neither market nor user country are + provided, the content is considered unavailable for the client. + """ + + trid = self._get_id("show", show_id) + return self._get( + "shows/" + trid + "/episodes/", limit=limit, offset=offset, market=market + ) + + def episode(self, episode_id, market=None): + """ returns a single episode given the episode's ID, URIs or URL + + Parameters: + - episode_id - the episode ID, URI or URL + - market - an ISO 3166-1 alpha-2 country code. + The episode must be available in the given market. + If user-based authorization is in use, the user's country + takes precedence. If neither market nor user country are + provided, the content is considered unavailable for the client. + """ + + trid = self._get_id("episode", episode_id) + return self._get("episodes/" + trid, market=market) + + def episodes(self, episodes, market=None): + """ returns a list of episodes given the episode IDs, URIs, or URLs + + Parameters: + - episodes - a list of episode IDs, URIs or URLs + - market - an ISO 3166-1 alpha-2 country code. + Only episodes available in the given market will be returned. + If user-based authorization is in use, the user's country + takes precedence. If neither market nor user country are + provided, the content is considered unavailable for the client. + """ + + tlist = [self._get_id("episode", e) for e in episodes] + return self._get("episodes/?ids=" + ",".join(tlist), market=market) + + def search(self, q, limit=10, offset=0, type="track", market=None): + """ searches for an item + + Parameters: + - q - the search query (see how to write a query in the + official documentation https://developer.spotify.com/documentation/web-api/reference/search/) # noqa + - limit - the number of items to return (min = 1, default = 10, max = 50). The limit is applied + within each type, not on the total response. + - offset - the index of the first item to return + - type - the types of items to return. One or more of 'artist', 'album', + 'track', 'playlist', 'show', and 'episode'. If multiple types are desired, + pass in a comma separated string; e.g., 'track,album,episode'. + - market - An ISO 3166-1 alpha-2 country code or the string + from_token. + """ + return self._get( + "search", q=q, limit=limit, offset=offset, type=type, market=market + ) + + def search_markets(self, q, limit=10, offset=0, type="track", markets=None, total=None): + """ (experimental) Searches multiple markets for an item + + Parameters: + - q - the search query (see how to write a query in the + official documentation https://developer.spotify.com/documentation/web-api/reference/search/) # noqa + - limit - the number of items to return (min = 1, default = 10, max = 50). If a search is to be done on multiple + markets, then this limit is applied to each market. (e.g. search US, CA, MX each with a limit of 10). + If multiple types are specified, this applies to each type. + - offset - the index of the first item to return + - type - the types of items to return. One or more of 'artist', 'album', + 'track', 'playlist', 'show', or 'episode'. If multiple types are desired, pass in a comma separated string. + - markets - A list of ISO 3166-1 alpha-2 country codes. Search all country markets by default. + - total - the total number of results to return across multiple markets and types. + """ + warnings.warn( + "Searching multiple markets is an experimental feature. " + "Please be aware that this method's inputs and outputs can change in the future.", + UserWarning, + ) + if not markets: + markets = self.country_codes + + if not (isinstance(markets, list) or isinstance(markets, tuple)): + markets = [] + + warnings.warn( + "Searching multiple markets is poorly performing.", + UserWarning, + ) + return self._search_multiple_markets(q, limit, offset, type, markets, total) + + def user(self, user): + """ Gets basic profile information about a Spotify User + + Parameters: + - user - the id of the usr + """ + return self._get("users/" + user) + + def current_user_playlists(self, limit=50, offset=0): + """ Get current user playlists without required getting his profile + Parameters: + - limit - the number of items to return + - offset - the index of the first item to return + """ + return self._get("me/playlists", limit=limit, offset=offset) + + def playlist(self, playlist_id, fields=None, market=None, additional_types=("track",)): + """ Gets playlist by id. + + Parameters: + - playlist - the id of the playlist + - fields - which fields to return + - market - An ISO 3166-1 alpha-2 country code or the + string from_token. + - additional_types - list of item types to return. + valid types are: track and episode + """ + plid = self._get_id("playlist", playlist_id) + return self._get( + "playlists/%s" % (plid), + fields=fields, + market=market, + additional_types=",".join(additional_types), + ) + + def playlist_tracks( + self, + playlist_id, + fields=None, + limit=100, + offset=0, + market=None, + additional_types=("track",) + ): + """ Get full details of the tracks of a playlist. + + Parameters: + - playlist_id - the playlist ID, URI or URL + - fields - which fields to return + - limit - the maximum number of tracks to return + - offset - the index of the first track to return + - market - an ISO 3166-1 alpha-2 country code. + - additional_types - list of item types to return. + valid types are: track and episode + """ + warnings.warn( + "You should use `playlist_items(playlist_id, ...," + "additional_types=('track',))` instead", + DeprecationWarning, + ) + return self.playlist_items(playlist_id, fields, limit, offset, + market, additional_types) + + def playlist_items( + self, + playlist_id, + fields=None, + limit=100, + offset=0, + market=None, + additional_types=("track", "episode") + ): + """ Get full details of the tracks and episodes of a playlist. + + Parameters: + - playlist_id - the playlist ID, URI or URL + - fields - which fields to return + - limit - the maximum number of tracks to return + - offset - the index of the first track to return + - market - an ISO 3166-1 alpha-2 country code. + - additional_types - list of item types to return. + valid types are: track and episode + """ + plid = self._get_id("playlist", playlist_id) + return self._get( + "playlists/%s/tracks" % (plid), + limit=limit, + offset=offset, + fields=fields, + market=market, + additional_types=",".join(additional_types) + ) + + def playlist_cover_image(self, playlist_id): + """ Get cover image of a playlist. + + Parameters: + - playlist_id - the playlist ID, URI or URL + """ + plid = self._get_id("playlist", playlist_id) + return self._get("playlists/%s/images" % (plid)) + + def playlist_upload_cover_image(self, playlist_id, image_b64): + """ Replace the image used to represent a specific playlist + + Parameters: + - playlist_id - the id of the playlist + - image_b64 - image data as a Base64 encoded JPEG image string + (maximum payload size is 256 KB) + """ + plid = self._get_id("playlist", playlist_id) + return self._put( + "playlists/{}/images".format(plid), + payload=image_b64, + content_type="image/jpeg", + ) + + def user_playlist(self, user, playlist_id=None, fields=None, market=None): + warnings.warn( + "You should use `playlist(playlist_id)` instead", + DeprecationWarning, + ) + + """ Gets a single playlist of a user + + Parameters: + - user - the id of the user + - playlist_id - the id of the playlist + - fields - which fields to return + """ + if playlist_id is None: + return self._get("users/%s/starred" % user) + return self.playlist(playlist_id, fields=fields, market=market) + + def user_playlist_tracks( + self, + user=None, + playlist_id=None, + fields=None, + limit=100, + offset=0, + market=None, + ): + warnings.warn( + "You should use `playlist_tracks(playlist_id)` instead", + DeprecationWarning, + ) + + """ Get full details of the tracks of a playlist owned by a user. + + Parameters: + - user - the id of the user + - playlist_id - the id of the playlist + - fields - which fields to return + - limit - the maximum number of tracks to return + - offset - the index of the first track to return + - market - an ISO 3166-1 alpha-2 country code. + """ + return self.playlist_tracks( + playlist_id, + limit=limit, + offset=offset, + fields=fields, + market=market, + ) + + def user_playlists(self, user, limit=50, offset=0): + """ Gets playlists of a user + + Parameters: + - user - the id of the usr + - limit - the number of items to return + - offset - the index of the first item to return + """ + return self._get( + "users/%s/playlists" % user, limit=limit, offset=offset + ) + + def user_playlist_create(self, user, name, public=True, collaborative=False, description=""): + """ Creates a playlist for a user + + Parameters: + - user - the id of the user + - name - the name of the playlist + - public - is the created playlist public + - collaborative - is the created playlist collaborative + - description - the description of the playlist + """ + data = { + "name": name, + "public": public, + "collaborative": collaborative, + "description": description + } + + return self._post("users/%s/playlists" % (user,), payload=data) + + def user_playlist_change_details( + self, + user, + playlist_id, + name=None, + public=None, + collaborative=None, + description=None, + ): + warnings.warn( + "You should use `playlist_change_details(playlist_id, ...)` instead", + DeprecationWarning, + ) + """ Changes a playlist's name and/or public/private state + + Parameters: + - user - the id of the user + - playlist_id - the id of the playlist + - name - optional name of the playlist + - public - optional is the playlist public + - collaborative - optional is the playlist collaborative + - description - optional description of the playlist + """ + + return self.playlist_change_details(playlist_id, name, public, + collaborative, description) + + def user_playlist_unfollow(self, user, playlist_id): + """ Unfollows (deletes) a playlist for a user + + Parameters: + - user - the id of the user + - name - the name of the playlist + """ + warnings.warn( + "You should use `current_user_unfollow_playlist(playlist_id)` instead", + DeprecationWarning, + ) + return self.current_user_unfollow_playlist(playlist_id) + + def user_playlist_add_tracks( + self, user, playlist_id, tracks, position=None + ): + warnings.warn( + "You should use `playlist_add_items(playlist_id, tracks)` instead", + DeprecationWarning, + ) + """ Adds tracks to a playlist + + Parameters: + - user - the id of the user + - playlist_id - the id of the playlist + - tracks - a list of track URIs, URLs or IDs + - position - the position to add the tracks + """ + tracks = [self._get_uri("track", tid) for tid in tracks] + return self.playlist_add_items(playlist_id, tracks, position) + + def user_playlist_add_episodes( + self, user, playlist_id, episodes, position=None + ): + warnings.warn( + "You should use `playlist_add_items(playlist_id, episodes)` instead", + DeprecationWarning, + ) + """ Adds episodes to a playlist + + Parameters: + - user - the id of the user + - playlist_id - the id of the playlist + - episodes - a list of track URIs, URLs or IDs + - position - the position to add the episodes + """ + episodes = [self._get_uri("episode", tid) for tid in episodes] + return self.playlist_add_items(playlist_id, episodes, position) + + def user_playlist_replace_tracks(self, user, playlist_id, tracks): + """ Replace all tracks in a playlist for a user + + Parameters: + - user - the id of the user + - playlist_id - the id of the playlist + - tracks - the list of track ids to add to the playlist + """ + warnings.warn( + "You should use `playlist_replace_items(playlist_id, tracks)` instead", + DeprecationWarning, + ) + return self.playlist_replace_items(playlist_id, tracks) + + def user_playlist_reorder_tracks( + self, + user, + playlist_id, + range_start, + insert_before, + range_length=1, + snapshot_id=None, + ): + """ Reorder tracks in a playlist from a user + + Parameters: + - user - the id of the user + - playlist_id - the id of the playlist + - range_start - the position of the first track to be reordered + - range_length - optional the number of tracks to be reordered + (default: 1) + - insert_before - the position where the tracks should be + inserted + - snapshot_id - optional playlist's snapshot ID + """ + warnings.warn( + "You should use `playlist_reorder_items(playlist_id, ...)` instead", + DeprecationWarning, + ) + return self.playlist_reorder_items(playlist_id, range_start, + insert_before, range_length, + snapshot_id) + + def user_playlist_remove_all_occurrences_of_tracks( + self, user, playlist_id, tracks, snapshot_id=None + ): + """ Removes all occurrences of the given tracks from the given playlist + + Parameters: + - user - the id of the user + - playlist_id - the id of the playlist + - tracks - the list of track ids to remove from the playlist + - snapshot_id - optional id of the playlist snapshot + + """ + warnings.warn( + "You should use `playlist_remove_all_occurrences_of_items" + "(playlist_id, tracks)` instead", + DeprecationWarning, + ) + return self.playlist_remove_all_occurrences_of_items(playlist_id, + tracks, + snapshot_id) + + def user_playlist_remove_specific_occurrences_of_tracks( + self, user, playlist_id, tracks, snapshot_id=None + ): + """ Removes all occurrences of the given tracks from the given playlist + + Parameters: + - user - the id of the user + - playlist_id - the id of the playlist + - tracks - an array of objects containing Spotify URIs of the + tracks to remove with their current positions in the + playlist. For example: + [ { "uri":"4iV5W9uYEdYUVa79Axb7Rh", "positions":[2] }, + { "uri":"1301WleyT98MSxVHPZCA6M", "positions":[7] } ] + - snapshot_id - optional id of the playlist snapshot + """ + warnings.warn( + "You should use `playlist_remove_specific_occurrences_of_items" + "(playlist_id, tracks)` instead", + DeprecationWarning, + ) + plid = self._get_id("playlist", playlist_id) + ftracks = [] + for tr in tracks: + ftracks.append( + { + "uri": self._get_uri("track", tr["uri"]), + "positions": tr["positions"], + } + ) + payload = {"tracks": ftracks} + if snapshot_id: + payload["snapshot_id"] = snapshot_id + return self._delete( + "users/%s/playlists/%s/tracks" % (user, plid), payload=payload + ) + + def user_playlist_follow_playlist(self, playlist_owner_id, playlist_id): + """ + Add the current authenticated user as a follower of a playlist. + + Parameters: + - playlist_owner_id - the user id of the playlist owner + - playlist_id - the id of the playlist + + """ + warnings.warn( + "You should use `current_user_follow_playlist(playlist_id)` instead", + DeprecationWarning, + ) + return self.current_user_follow_playlist(playlist_id) + + def user_playlist_is_following( + self, playlist_owner_id, playlist_id, user_ids + ): + """ + Check to see if the given users are following the given playlist + + Parameters: + - playlist_owner_id - the user id of the playlist owner + - playlist_id - the id of the playlist + - user_ids - the ids of the users that you want to check to see + if they follow the playlist. Maximum: 5 ids. + + """ + warnings.warn( + "You should use `playlist_is_following(playlist_id, user_ids)` instead", + DeprecationWarning, + ) + return self.playlist_is_following(playlist_id, user_ids) + + def playlist_change_details( + self, + playlist_id, + name=None, + public=None, + collaborative=None, + description=None, + ): + """ Changes a playlist's name and/or public/private state, + collaborative state, and/or description + + Parameters: + - playlist_id - the id of the playlist + - name - optional name of the playlist + - public - optional is the playlist public + - collaborative - optional is the playlist collaborative + - description - optional description of the playlist + """ + + data = {} + if isinstance(name, six.string_types): + data["name"] = name + if isinstance(public, bool): + data["public"] = public + if isinstance(collaborative, bool): + data["collaborative"] = collaborative + if isinstance(description, six.string_types): + data["description"] = description + return self._put( + "playlists/%s" % (self._get_id("playlist", playlist_id)), payload=data + ) + + def current_user_unfollow_playlist(self, playlist_id): + """ Unfollows (deletes) a playlist for the current authenticated + user + + Parameters: + - name - the name of the playlist + """ + return self._delete( + "playlists/%s/followers" % (playlist_id) + ) + + def playlist_add_items( + self, playlist_id, items, position=None + ): + """ Adds tracks/episodes to a playlist + + Parameters: + - playlist_id - the id of the playlist + - items - a list of track/episode URIs or URLs + - position - the position to add the tracks + """ + plid = self._get_id("playlist", playlist_id) + ftracks = [self._get_uri("track", tid) for tid in items] + return self._post( + "playlists/%s/tracks" % (plid), + payload=ftracks, + position=position, + ) + + def playlist_replace_items(self, playlist_id, items): + """ Replace all tracks/episodes in a playlist + + Parameters: + - playlist_id - the id of the playlist + - items - list of track/episode ids to comprise playlist + """ + plid = self._get_id("playlist", playlist_id) + ftracks = [self._get_uri("track", tid) for tid in items] + payload = {"uris": ftracks} + return self._put( + "playlists/%s/tracks" % (plid), payload=payload + ) + + def playlist_reorder_items( + self, + playlist_id, + range_start, + insert_before, + range_length=1, + snapshot_id=None, + ): + """ Reorder tracks in a playlist + + Parameters: + - playlist_id - the id of the playlist + - range_start - the position of the first track to be reordered + - range_length - optional the number of tracks to be reordered + (default: 1) + - insert_before - the position where the tracks should be + inserted + - snapshot_id - optional playlist's snapshot ID + """ + plid = self._get_id("playlist", playlist_id) + payload = { + "range_start": range_start, + "range_length": range_length, + "insert_before": insert_before, + } + if snapshot_id: + payload["snapshot_id"] = snapshot_id + return self._put( + "playlists/%s/tracks" % (plid), payload=payload + ) + + def playlist_remove_all_occurrences_of_items( + self, playlist_id, items, snapshot_id=None + ): + """ Removes all occurrences of the given tracks/episodes from the given playlist + + Parameters: + - playlist_id - the id of the playlist + - items - list of track/episode ids to remove from the playlist + - snapshot_id - optional id of the playlist snapshot + + """ + + plid = self._get_id("playlist", playlist_id) + ftracks = [self._get_uri("track", tid) for tid in items] + payload = {"tracks": [{"uri": track} for track in ftracks]} + if snapshot_id: + payload["snapshot_id"] = snapshot_id + return self._delete( + "playlists/%s/tracks" % (plid), payload=payload + ) + + def playlist_remove_specific_occurrences_of_items( + self, playlist_id, items, snapshot_id=None + ): + """ Removes all occurrences of the given tracks from the given playlist + + Parameters: + - playlist_id - the id of the playlist + - items - an array of objects containing Spotify URIs of the + tracks/episodes to remove with their current positions in + the playlist. For example: + [ { "uri":"4iV5W9uYEdYUVa79Axb7Rh", "positions":[2] }, + { "uri":"1301WleyT98MSxVHPZCA6M", "positions":[7] } ] + - snapshot_id - optional id of the playlist snapshot + """ + + plid = self._get_id("playlist", playlist_id) + ftracks = [] + for tr in items: + ftracks.append( + { + "uri": self._get_uri("track", tr["uri"]), + "positions": tr["positions"], + } + ) + payload = {"tracks": ftracks} + if snapshot_id: + payload["snapshot_id"] = snapshot_id + return self._delete( + "playlists/%s/tracks" % (plid), payload=payload + ) + + def current_user_follow_playlist(self, playlist_id): + """ + Add the current authenticated user as a follower of a playlist. + + Parameters: + - playlist_id - the id of the playlist + + """ + return self._put( + "playlists/{}/followers".format(playlist_id) + ) + + def playlist_is_following( + self, playlist_id, user_ids + ): + """ + Check to see if the given users are following the given playlist + + Parameters: + - playlist_id - the id of the playlist + - user_ids - the ids of the users that you want to check to see + if they follow the playlist. Maximum: 5 ids. + + """ + endpoint = "playlists/{}/followers/contains?ids={}" + return self._get( + endpoint.format(playlist_id, ",".join(user_ids)) + ) + + def me(self): + """ Get detailed profile information about the current user. + An alias for the 'current_user' method. + """ + return self._get("me/") + + def current_user(self): + """ Get detailed profile information about the current user. + An alias for the 'me' method. + """ + return self.me() + + def current_user_playing_track(self): + """ Get information about the current users currently playing track. + """ + return self._get("me/player/currently-playing") + + def current_user_saved_albums(self, limit=20, offset=0, market=None): + """ Gets a list of the albums saved in the current authorized user's + "Your Music" library + + Parameters: + - limit - the number of albums to return (MAX_LIMIT=50) + - offset - the index of the first album to return + - market - an ISO 3166-1 alpha-2 country code. + + """ + return self._get("me/albums", limit=limit, offset=offset, market=market) + + def current_user_saved_albums_add(self, albums=[]): + """ Add one or more albums to the current user's + "Your Music" library. + Parameters: + - albums - a list of album URIs, URLs or IDs + """ + + alist = [self._get_id("album", a) for a in albums] + return self._put("me/albums?ids=" + ",".join(alist)) + + def current_user_saved_albums_delete(self, albums=[]): + """ Remove one or more albums from the current user's + "Your Music" library. + + Parameters: + - albums - a list of album URIs, URLs or IDs + """ + alist = [self._get_id("album", a) for a in albums] + return self._delete("me/albums/?ids=" + ",".join(alist)) + + def current_user_saved_albums_contains(self, albums=[]): + """ Check if one or more albums is already saved in + the current Spotify user’s “Your Music” library. + + Parameters: + - albums - a list of album URIs, URLs or IDs + """ + alist = [self._get_id("album", a) for a in albums] + return self._get("me/albums/contains?ids=" + ",".join(alist)) + + def current_user_saved_tracks(self, limit=20, offset=0, market=None): + """ Gets a list of the tracks saved in the current authorized user's + "Your Music" library + + Parameters: + - limit - the number of tracks to return + - offset - the index of the first track to return + - market - an ISO 3166-1 alpha-2 country code + + """ + return self._get("me/tracks", limit=limit, offset=offset, market=market) + + def current_user_saved_tracks_add(self, tracks=None): + """ Add one or more tracks to the current user's + "Your Music" library. + + Parameters: + - tracks - a list of track URIs, URLs or IDs + """ + tlist = [] + if tracks is not None: + tlist = [self._get_id("track", t) for t in tracks] + return self._put("me/tracks/?ids=" + ",".join(tlist)) + + def current_user_saved_tracks_delete(self, tracks=None): + """ Remove one or more tracks from the current user's + "Your Music" library. + + Parameters: + - tracks - a list of track URIs, URLs or IDs + """ + tlist = [] + if tracks is not None: + tlist = [self._get_id("track", t) for t in tracks] + return self._delete("me/tracks/?ids=" + ",".join(tlist)) + + def current_user_saved_tracks_contains(self, tracks=None): + """ Check if one or more tracks is already saved in + the current Spotify user’s “Your Music” library. + + Parameters: + - tracks - a list of track URIs, URLs or IDs + """ + tlist = [] + if tracks is not None: + tlist = [self._get_id("track", t) for t in tracks] + return self._get("me/tracks/contains?ids=" + ",".join(tlist)) + + def current_user_saved_episodes(self, limit=20, offset=0, market=None): + """ Gets a list of the episodes saved in the current authorized user's + "Your Music" library + + Parameters: + - limit - the number of episodes to return + - offset - the index of the first episode to return + - market - an ISO 3166-1 alpha-2 country code + + """ + return self._get("me/episodes", limit=limit, offset=offset, market=market) + + def current_user_saved_episodes_add(self, episodes=None): + """ Add one or more episodes to the current user's + "Your Music" library. + + Parameters: + - episodes - a list of episode URIs, URLs or IDs + """ + elist = [] + if episodes is not None: + elist = [self._get_id("episode", e) for e in episodes] + return self._put("me/episodes/?ids=" + ",".join(elist)) + + def current_user_saved_episodes_delete(self, episodes=None): + """ Remove one or more episodes from the current user's + "Your Music" library. + + Parameters: + - episodes - a list of episode URIs, URLs or IDs + """ + elist = [] + if episodes is not None: + elist = [self._get_id("episode", e) for e in episodes] + return self._delete("me/episodes/?ids=" + ",".join(elist)) + + def current_user_saved_episodes_contains(self, episodes=None): + """ Check if one or more episodes is already saved in + the current Spotify user’s “Your Music” library. + + Parameters: + - episodes - a list of episode URIs, URLs or IDs + """ + elist = [] + if episodes is not None: + elist = [self._get_id("episode", e) for e in episodes] + return self._get("me/episodes/contains?ids=" + ",".join(elist)) + + def current_user_saved_shows(self, limit=20, offset=0, market=None): + """ Gets a list of the shows saved in the current authorized user's + "Your Music" library + + Parameters: + - limit - the number of shows to return + - offset - the index of the first show to return + - market - an ISO 3166-1 alpha-2 country code + + """ + return self._get("me/shows", limit=limit, offset=offset, market=market) + + def current_user_saved_shows_add(self, shows=[]): + """ Add one or more albums to the current user's + "Your Music" library. + Parameters: + - shows - a list of show URIs, URLs or IDs + """ + slist = [self._get_id("show", s) for s in shows] + return self._put("me/shows?ids=" + ",".join(slist)) + + def current_user_saved_shows_delete(self, shows=[]): + """ Remove one or more shows from the current user's + "Your Music" library. + + Parameters: + - shows - a list of show URIs, URLs or IDs + """ + slist = [self._get_id("show", s) for s in shows] + return self._delete("me/shows/?ids=" + ",".join(slist)) + + def current_user_saved_shows_contains(self, shows=[]): + """ Check if one or more shows is already saved in + the current Spotify user’s “Your Music” library. + + Parameters: + - shows - a list of show URIs, URLs or IDs + """ + slist = [self._get_id("show", s) for s in shows] + return self._get("me/shows/contains?ids=" + ",".join(slist)) + + def current_user_followed_artists(self, limit=20, after=None): + """ Gets a list of the artists followed by the current authorized user + + Parameters: + - limit - the number of artists to return + - after - the last artist ID retrieved from the previous + request + + """ + return self._get( + "me/following", type="artist", limit=limit, after=after + ) + + def current_user_following_artists(self, ids=None): + """ Check if the current user is following certain artists + + Returns list of booleans respective to ids + + Parameters: + - ids - a list of artist URIs, URLs or IDs + """ + idlist = [] + if ids is not None: + idlist = [self._get_id("artist", i) for i in ids] + return self._get( + "me/following/contains", ids=",".join(idlist), type="artist" + ) + + def current_user_following_users(self, ids=None): + """ Check if the current user is following certain users + + Returns list of booleans respective to ids + + Parameters: + - ids - a list of user URIs, URLs or IDs + """ + idlist = [] + if ids is not None: + idlist = [self._get_id("user", i) for i in ids] + return self._get( + "me/following/contains", ids=",".join(idlist), type="user" + ) + + def current_user_top_artists( + self, limit=20, offset=0, time_range="medium_term" + ): + """ Get the current user's top artists + + Parameters: + - limit - the number of entities to return + - offset - the index of the first entity to return + - time_range - Over what time frame are the affinities computed + Valid-values: short_term, medium_term, long_term + """ + return self._get( + "me/top/artists", time_range=time_range, limit=limit, offset=offset + ) + + def current_user_top_tracks( + self, limit=20, offset=0, time_range="medium_term" + ): + """ Get the current user's top tracks + + Parameters: + - limit - the number of entities to return + - offset - the index of the first entity to return + - time_range - Over what time frame are the affinities computed + Valid-values: short_term, medium_term, long_term + """ + return self._get( + "me/top/tracks", time_range=time_range, limit=limit, offset=offset + ) + + def current_user_recently_played(self, limit=50, after=None, before=None): + """ Get the current user's recently played tracks + + Parameters: + - limit - the number of entities to return + - after - unix timestamp in milliseconds. Returns all items + after (but not including) this cursor position. + Cannot be used if before is specified. + - before - unix timestamp in milliseconds. Returns all items + before (but not including) this cursor position. + Cannot be used if after is specified + """ + return self._get( + "me/player/recently-played", + limit=limit, + after=after, + before=before, + ) + + def user_follow_artists(self, ids=[]): + """ Follow one or more artists + Parameters: + - ids - a list of artist IDs + """ + return self._put("me/following?type=artist&ids=" + ",".join(ids)) + + def user_follow_users(self, ids=[]): + """ Follow one or more users + Parameters: + - ids - a list of user IDs + """ + return self._put("me/following?type=user&ids=" + ",".join(ids)) + + def user_unfollow_artists(self, ids=[]): + """ Unfollow one or more artists + Parameters: + - ids - a list of artist IDs + """ + return self._delete("me/following?type=artist&ids=" + ",".join(ids)) + + def user_unfollow_users(self, ids=[]): + """ Unfollow one or more users + Parameters: + - ids - a list of user IDs + """ + return self._delete("me/following?type=user&ids=" + ",".join(ids)) + + def featured_playlists( + self, locale=None, country=None, timestamp=None, limit=20, offset=0 + ): + """ Get a list of Spotify featured playlists + + Parameters: + - locale - The desired language, consisting of a lowercase ISO + 639-1 alpha-2 language code and an uppercase ISO 3166-1 alpha-2 + country code, joined by an underscore. + + - country - An ISO 3166-1 alpha-2 country code. + + - timestamp - A timestamp in ISO 8601 format: + yyyy-MM-ddTHH:mm:ss. Use this parameter to specify the user's + local time to get results tailored for that specific date and + time in the day + + - limit - The maximum number of items to return. Default: 20. + Minimum: 1. Maximum: 50 + + - offset - The index of the first item to return. Default: 0 + (the first object). Use with limit to get the next set of + items. + """ + return self._get( + "browse/featured-playlists", + locale=locale, + country=country, + timestamp=timestamp, + limit=limit, + offset=offset, + ) + + def new_releases(self, country=None, limit=20, offset=0): + """ Get a list of new album releases featured in Spotify + + Parameters: + - country - An ISO 3166-1 alpha-2 country code. + + - limit - The maximum number of items to return. Default: 20. + Minimum: 1. Maximum: 50 + + - offset - The index of the first item to return. Default: 0 + (the first object). Use with limit to get the next set of + items. + """ + return self._get( + "browse/new-releases", country=country, limit=limit, offset=offset + ) + + def category(self, category_id, country=None, locale=None): + """ Get info about a category + + Parameters: + - category_id - The Spotify category ID for the category. + + - country - An ISO 3166-1 alpha-2 country code. + - locale - The desired language, consisting of an ISO 639-1 alpha-2 + language code and an ISO 3166-1 alpha-2 country code, joined + by an underscore. + """ + return self._get( + "browse/categories/" + category_id, + country=country, + locale=locale, + ) + + def categories(self, country=None, locale=None, limit=20, offset=0): + """ Get a list of categories + + Parameters: + - country - An ISO 3166-1 alpha-2 country code. + - locale - The desired language, consisting of an ISO 639-1 alpha-2 + language code and an ISO 3166-1 alpha-2 country code, joined + by an underscore. + + - limit - The maximum number of items to return. Default: 20. + Minimum: 1. Maximum: 50 + + - offset - The index of the first item to return. Default: 0 + (the first object). Use with limit to get the next set of + items. + """ + return self._get( + "browse/categories", + country=country, + locale=locale, + limit=limit, + offset=offset, + ) + + def category_playlists( + self, category_id=None, country=None, limit=20, offset=0 + ): + """ Get a list of playlists for a specific Spotify category + + Parameters: + - category_id - The Spotify category ID for the category. + + - country - An ISO 3166-1 alpha-2 country code. + + - limit - The maximum number of items to return. Default: 20. + Minimum: 1. Maximum: 50 + + - offset - The index of the first item to return. Default: 0 + (the first object). Use with limit to get the next set of + items. + """ + return self._get( + "browse/categories/" + category_id + "/playlists", + country=country, + limit=limit, + offset=offset, + ) + + def recommendations( + self, + seed_artists=None, + seed_genres=None, + seed_tracks=None, + limit=20, + country=None, + **kwargs + ): + """ Get a list of recommended tracks for one to five seeds. + (at least one of `seed_artists`, `seed_tracks` and `seed_genres` + are needed) + + Parameters: + - seed_artists - a list of artist IDs, URIs or URLs + - seed_tracks - a list of track IDs, URIs or URLs + - seed_genres - a list of genre names. Available genres for + recommendations can be found by calling + recommendation_genre_seeds + + - country - An ISO 3166-1 alpha-2 country code. If provided, + all results will be playable in this country. + + - limit - The maximum number of items to return. Default: 20. + Minimum: 1. Maximum: 100 + + - min/max/target_ - For the tuneable track + attributes listed in the documentation, these values + provide filters and targeting on results. + """ + params = dict(limit=limit) + if seed_artists: + params["seed_artists"] = ",".join( + [self._get_id("artist", a) for a in seed_artists] + ) + if seed_genres: + params["seed_genres"] = ",".join(seed_genres) + if seed_tracks: + params["seed_tracks"] = ",".join( + [self._get_id("track", t) for t in seed_tracks] + ) + if country: + params["market"] = country + + for attribute in [ + "acousticness", + "danceability", + "duration_ms", + "energy", + "instrumentalness", + "key", + "liveness", + "loudness", + "mode", + "popularity", + "speechiness", + "tempo", + "time_signature", + "valence", + ]: + for prefix in ["min_", "max_", "target_"]: + param = prefix + attribute + if param in kwargs: + params[param] = kwargs[param] + return self._get("recommendations", **params) + + def recommendation_genre_seeds(self): + """ Get a list of genres available for the recommendations function. + """ + return self._get("recommendations/available-genre-seeds") + + def audio_analysis(self, track_id): + """ Get audio analysis for a track based upon its Spotify ID + Parameters: + - track_id - a track URI, URL or ID + """ + trid = self._get_id("track", track_id) + return self._get("audio-analysis/" + trid) + + def audio_features(self, tracks=[]): + """ Get audio features for one or multiple tracks based upon their Spotify IDs + Parameters: + - tracks - a list of track URIs, URLs or IDs, maximum: 100 ids + """ + if isinstance(tracks, str): + trackid = self._get_id("track", tracks) + results = self._get("audio-features/?ids=" + trackid) + else: + tlist = [self._get_id("track", t) for t in tracks] + results = self._get("audio-features/?ids=" + ",".join(tlist)) + # the response has changed, look for the new style first, and if + # its not there, fallback on the old style + if "audio_features" in results: + return results["audio_features"] + else: + return results + + def devices(self): + """ Get a list of user's available devices. + """ + return self._get("me/player/devices") + + def current_playback(self, market=None, additional_types=None): + """ Get information about user's current playback. + + Parameters: + - market - an ISO 3166-1 alpha-2 country code. + - additional_types - `episode` to get podcast track information + """ + return self._get("me/player", market=market, additional_types=additional_types) + + def currently_playing(self, market=None, additional_types=None): + """ Get user's currently playing track. + + Parameters: + - market - an ISO 3166-1 alpha-2 country code. + - additional_types - `episode` to get podcast track information + """ + return self._get("me/player/currently-playing", market=market, + additional_types=additional_types) + + def transfer_playback(self, device_id, force_play=True): + """ Transfer playback to another device. + Note that the API accepts a list of device ids, but only + actually supports one. + + Parameters: + - device_id - transfer playback to this device + - force_play - true: after transfer, play. false: + keep current state. + """ + data = {"device_ids": [device_id], "play": force_play} + return self._put("me/player", payload=data) + + def start_playback( + self, device_id=None, context_uri=None, uris=None, offset=None, position_ms=None + ): + """ Start or resume user's playback. + + Provide a `context_uri` to start playback of an album, + artist, or playlist. + + Provide a `uris` list to start playback of one or more + tracks. + + Provide `offset` as {"position": } or {"uri": ""} + to start playback at a particular offset. + + Parameters: + - device_id - device target for playback + - context_uri - spotify context uri to play + - uris - spotify track uris + - offset - offset into context by index or track + - position_ms - (optional) indicates from what position to start playback. + Must be a positive number. Passing in a position that is + greater than the length of the track will cause the player to + start playing the next song. + """ + if context_uri is not None and uris is not None: + logger.warning("Specify either context uri or uris, not both") + return + if uris is not None and not isinstance(uris, list): + logger.warning("URIs must be a list") + return + data = {} + if context_uri is not None: + data["context_uri"] = context_uri + if uris is not None: + data["uris"] = uris + if offset is not None: + data["offset"] = offset + if position_ms is not None: + data["position_ms"] = position_ms + return self._put( + self._append_device_id("me/player/play", device_id), payload=data + ) + + def pause_playback(self, device_id=None): + """ Pause user's playback. + + Parameters: + - device_id - device target for playback + """ + return self._put(self._append_device_id("me/player/pause", device_id)) + + def next_track(self, device_id=None): + """ Skip user's playback to next track. + + Parameters: + - device_id - device target for playback + """ + return self._post(self._append_device_id("me/player/next", device_id)) + + def previous_track(self, device_id=None): + """ Skip user's playback to previous track. + + Parameters: + - device_id - device target for playback + """ + return self._post( + self._append_device_id("me/player/previous", device_id) + ) + + def seek_track(self, position_ms, device_id=None): + """ Seek to position in current track. + + Parameters: + - position_ms - position in milliseconds to seek to + - device_id - device target for playback + """ + if not isinstance(position_ms, int): + logger.warning("Position_ms must be an integer") + return + return self._put( + self._append_device_id( + "me/player/seek?position_ms=%s" % position_ms, device_id + ) + ) + + def repeat(self, state, device_id=None): + """ Set repeat mode for playback. + + Parameters: + - state - `track`, `context`, or `off` + - device_id - device target for playback + """ + if state not in ["track", "context", "off"]: + logger.warning("Invalid state") + return + self._put( + self._append_device_id( + "me/player/repeat?state=%s" % state, device_id + ) + ) + + def volume(self, volume_percent, device_id=None): + """ Set playback volume. + + Parameters: + - volume_percent - volume between 0 and 100 + - device_id - device target for playback + """ + if not isinstance(volume_percent, int): + logger.warning("Volume must be an integer") + return + if volume_percent < 0 or volume_percent > 100: + logger.warning("Volume must be between 0 and 100, inclusive") + return + self._put( + self._append_device_id( + "me/player/volume?volume_percent=%s" % volume_percent, + device_id, + ) + ) + + def shuffle(self, state, device_id=None): + """ Toggle playback shuffling. + + Parameters: + - state - true or false + - device_id - device target for playback + """ + if not isinstance(state, bool): + logger.warning("state must be a boolean") + return + state = str(state).lower() + self._put( + self._append_device_id( + "me/player/shuffle?state=%s" % state, device_id + ) + ) + + def queue(self): + """ Gets the current user's queue """ + return self._get("me/player/queue") + + def add_to_queue(self, uri, device_id=None): + """ Adds a song to the end of a user's queue + + If device A is currently playing music and you try to add to the queue + and pass in the id for device B, you will get a + 'Player command failed: Restriction violated' error + I therefore recommend leaving device_id as None so that the active device is targeted + + :param uri: song uri, id, or url + :param device_id: + the id of a Spotify device. + If None, then the active device is used. + + """ + + uri = self._get_uri("track", uri) + + endpoint = "me/player/queue?uri=%s" % uri + + if device_id is not None: + endpoint += "&device_id=%s" % device_id + + return self._post(endpoint) + + def available_markets(self): + """ Get the list of markets where Spotify is available. + Returns a list of the countries in which Spotify is available, identified by their + ISO 3166-1 alpha-2 country code with additional country codes for special territories. + """ + return self._get("markets") + + def _append_device_id(self, path, device_id): + """ Append device ID to API path. + + Parameters: + - device_id - device id to append + """ + if device_id: + if "?" in path: + path += "&device_id=%s" % device_id + else: + path += "?device_id=%s" % device_id + return path + + def _get_id(self, type, id): + uri_match = re.search(Spotify._regex_spotify_uri, id) + if uri_match is not None: + uri_match_groups = uri_match.groupdict() + if uri_match_groups['type'] != type: + # TODO change to a ValueError in v3 + raise SpotifyException(400, -1, "Unexpected Spotify URI type.") + return uri_match_groups['id'] + + url_match = re.search(Spotify._regex_spotify_url, id) + if url_match is not None: + url_match_groups = url_match.groupdict() + if url_match_groups['type'] != type: + raise SpotifyException(400, -1, "Unexpected Spotify URL type.") + # TODO change to a ValueError in v3 + return url_match_groups['id'] + + # Raw identifiers might be passed, ensure they are also base-62 + if re.search(Spotify._regex_base62, id) is not None: + return id + + # TODO change to a ValueError in v3 + raise SpotifyException(400, -1, "Unsupported URL / URI.") + + def _get_uri(self, type, id): + if self._is_uri(id): + return id + else: + return "spotify:" + type + ":" + self._get_id(type, id) + + def _is_uri(self, uri): + return re.search(Spotify._regex_spotify_uri, uri) is not None + + def _search_multiple_markets(self, q, limit, offset, type, markets, total): + if total and limit > total: + limit = total + warnings.warn( + "limit was auto-adjusted to equal {} as it must not be higher than total".format( + total), + UserWarning, + ) + + results = defaultdict(dict) + item_types = [item_type + "s" for item_type in type.split(",")] + count = 0 + + for country in markets: + result = self._get( + "search", q=q, limit=limit, offset=offset, type=type, market=country + ) + for item_type in item_types: + results[country][item_type] = result[item_type] + + # Truncate the items list to the current limit + if len(results[country][item_type]['items']) > limit: + results[country][item_type]['items'] = \ + results[country][item_type]['items'][:limit] + + count += len(results[country][item_type]['items']) + if total and limit > total - count: + # when approaching `total` results, adjust `limit` to not request more + # items than needed + limit = total - count + + if total and count >= total: + return results + + return results diff --git a/.venv/Lib/site-packages/spotipy/exceptions.py b/.venv/Lib/site-packages/spotipy/exceptions.py new file mode 100644 index 00000000..df503f10 --- /dev/null +++ b/.venv/Lib/site-packages/spotipy/exceptions.py @@ -0,0 +1,16 @@ +class SpotifyException(Exception): + + def __init__(self, http_status, code, msg, reason=None, headers=None): + self.http_status = http_status + self.code = code + self.msg = msg + self.reason = reason + # `headers` is used to support `Retry-After` in the event of a + # 429 status code. + if headers is None: + headers = {} + self.headers = headers + + def __str__(self): + return 'http status: {0}, code:{1} - {2}, reason: {3}'.format( + self.http_status, self.code, self.msg, self.reason) diff --git a/.venv/Lib/site-packages/spotipy/oauth2.py b/.venv/Lib/site-packages/spotipy/oauth2.py new file mode 100644 index 00000000..125c87c9 --- /dev/null +++ b/.venv/Lib/site-packages/spotipy/oauth2.py @@ -0,0 +1,1308 @@ +# -*- coding: utf-8 -*- + +__all__ = [ + "SpotifyClientCredentials", + "SpotifyOAuth", + "SpotifyOauthError", + "SpotifyStateError", + "SpotifyImplicitGrant", + "SpotifyPKCE" +] + +import base64 +import logging +import os +import time +import warnings +import webbrowser + +import requests +# Workaround to support both python 2 & 3 +import six +import six.moves.urllib.parse as urllibparse +from six.moves.BaseHTTPServer import BaseHTTPRequestHandler, HTTPServer +from six.moves.urllib_parse import parse_qsl, urlparse + +from spotipy.cache_handler import CacheFileHandler, CacheHandler +from spotipy.util import CLIENT_CREDS_ENV_VARS, get_host_port, normalize_scope + +logger = logging.getLogger(__name__) + + +class SpotifyOauthError(Exception): + """ Error during Auth Code or Implicit Grant flow """ + + def __init__(self, message, error=None, error_description=None, *args, **kwargs): + self.error = error + self.error_description = error_description + self.__dict__.update(kwargs) + super(SpotifyOauthError, self).__init__(message, *args, **kwargs) + + +class SpotifyStateError(SpotifyOauthError): + """ The state sent and state received were different """ + + def __init__(self, local_state=None, remote_state=None, message=None, + error=None, error_description=None, *args, **kwargs): + if not message: + message = ("Expected " + local_state + " but recieved " + + remote_state) + super(SpotifyOauthError, self).__init__(message, error, + error_description, *args, + **kwargs) + + +def _make_authorization_headers(client_id, client_secret): + auth_header = base64.b64encode( + six.text_type(client_id + ":" + client_secret).encode("ascii") + ) + return {"Authorization": "Basic %s" % auth_header.decode("ascii")} + + +def _ensure_value(value, env_key): + env_val = CLIENT_CREDS_ENV_VARS[env_key] + _val = value or os.getenv(env_val) + if _val is None: + msg = "No %s. Pass it or set a %s environment variable." % ( + env_key, + env_val, + ) + raise SpotifyOauthError(msg) + return _val + + +class SpotifyAuthBase(object): + def __init__(self, requests_session): + if isinstance(requests_session, requests.Session): + self._session = requests_session + else: + if requests_session: # Build a new session. + self._session = requests.Session() + else: # Use the Requests API module as a "session". + from requests import api + self._session = api + + def _normalize_scope(self, scope): + return normalize_scope(scope) + + @property + def client_id(self): + return self._client_id + + @client_id.setter + def client_id(self, val): + self._client_id = _ensure_value(val, "client_id") + + @property + def client_secret(self): + return self._client_secret + + @client_secret.setter + def client_secret(self, val): + self._client_secret = _ensure_value(val, "client_secret") + + @property + def redirect_uri(self): + return self._redirect_uri + + @redirect_uri.setter + def redirect_uri(self, val): + self._redirect_uri = _ensure_value(val, "redirect_uri") + + @staticmethod + def _get_user_input(prompt): + try: + return raw_input(prompt) + except NameError: + return input(prompt) + + @staticmethod + def is_token_expired(token_info): + now = int(time.time()) + return token_info["expires_at"] - now < 60 + + @staticmethod + def _is_scope_subset(needle_scope, haystack_scope): + needle_scope = set(needle_scope.split()) if needle_scope else set() + haystack_scope = ( + set(haystack_scope.split()) if haystack_scope else set() + ) + return needle_scope <= haystack_scope + + def _handle_oauth_error(self, http_error): + response = http_error.response + try: + error_payload = response.json() + error = error_payload.get('error') + error_description = error_payload.get('error_description') + except ValueError: + # if the response cannot be decoded into JSON (which raises a ValueError), + # then try to decode it into text + + # if we receive an empty string (which is falsy), then replace it with `None` + error = response.text or None + error_description = None + + raise SpotifyOauthError( + 'error: {0}, error_description: {1}'.format( + error, error_description + ), + error=error, + error_description=error_description + ) + + def __del__(self): + """Make sure the connection (pool) gets closed""" + if isinstance(self._session, requests.Session): + self._session.close() + + +class SpotifyClientCredentials(SpotifyAuthBase): + OAUTH_TOKEN_URL = "https://accounts.spotify.com/api/token" + + def __init__( + self, + client_id=None, + client_secret=None, + proxies=None, + requests_session=True, + requests_timeout=None, + cache_handler=None + ): + """ + Creates a Client Credentials Flow Manager. + + The Client Credentials flow is used in server-to-server authentication. + Only endpoints that do not access user information can be accessed. + This means that endpoints that require authorization scopes cannot be accessed. + The advantage, however, of this authorization flow is that it does not require any + user interaction + + You can either provide a client_id and client_secret to the + constructor or set SPOTIPY_CLIENT_ID and SPOTIPY_CLIENT_SECRET + environment variables + + Parameters: + * client_id: Must be supplied or set as environment variable + * client_secret: Must be supplied or set as environment variable + * proxies: Optional, proxy for the requests library to route through + * requests_session: A Requests session + * requests_timeout: Optional, tell Requests to stop waiting for a response after + a given number of seconds + * cache_handler: An instance of the `CacheHandler` class to handle + getting and saving cached authorization tokens. + Optional, will otherwise use `CacheFileHandler`. + (takes precedence over `cache_path` and `username`) + + """ + + super(SpotifyClientCredentials, self).__init__(requests_session) + + self.client_id = client_id + self.client_secret = client_secret + self.proxies = proxies + self.requests_timeout = requests_timeout + if cache_handler: + assert issubclass(cache_handler.__class__, CacheHandler), \ + "cache_handler must be a subclass of CacheHandler: " + str(type(cache_handler)) \ + + " != " + str(CacheHandler) + self.cache_handler = cache_handler + else: + self.cache_handler = CacheFileHandler() + + def get_access_token(self, as_dict=True, check_cache=True): + """ + If a valid access token is in memory, returns it + Else fetches a new token and returns it + + Parameters: + - as_dict - a boolean indicating if returning the access token + as a token_info dictionary, otherwise it will be returned + as a string. + """ + if as_dict: + warnings.warn( + "You're using 'as_dict = True'." + "get_access_token will return the token string directly in future " + "versions. Please adjust your code accordingly, or use " + "get_cached_token instead.", + DeprecationWarning, + stacklevel=2, + ) + + if check_cache: + token_info = self.cache_handler.get_cached_token() + if token_info and not self.is_token_expired(token_info): + return token_info if as_dict else token_info["access_token"] + + token_info = self._request_access_token() + token_info = self._add_custom_values_to_token_info(token_info) + self.cache_handler.save_token_to_cache(token_info) + return token_info if as_dict else token_info["access_token"] + + def _request_access_token(self): + """Gets client credentials access token """ + payload = {"grant_type": "client_credentials"} + + headers = _make_authorization_headers( + self.client_id, self.client_secret + ) + + logger.debug( + "sending POST request to %s with Headers: %s and Body: %r", + self.OAUTH_TOKEN_URL, headers, payload + ) + + try: + response = self._session.post( + self.OAUTH_TOKEN_URL, + data=payload, + headers=headers, + verify=True, + proxies=self.proxies, + timeout=self.requests_timeout, + ) + response.raise_for_status() + token_info = response.json() + return token_info + except requests.exceptions.HTTPError as http_error: + self._handle_oauth_error(http_error) + + def _add_custom_values_to_token_info(self, token_info): + """ + Store some values that aren't directly provided by a Web API + response. + """ + token_info["expires_at"] = int(time.time()) + token_info["expires_in"] + return token_info + + +class SpotifyOAuth(SpotifyAuthBase): + """ + Implements Authorization Code Flow for Spotify's OAuth implementation. + """ + OAUTH_AUTHORIZE_URL = "https://accounts.spotify.com/authorize" + OAUTH_TOKEN_URL = "https://accounts.spotify.com/api/token" + + def __init__( + self, + client_id=None, + client_secret=None, + redirect_uri=None, + state=None, + scope=None, + cache_path=None, + username=None, + proxies=None, + show_dialog=False, + requests_session=True, + requests_timeout=None, + open_browser=True, + cache_handler=None + ): + """ + Creates a SpotifyOAuth object + + Parameters: + * client_id: Must be supplied or set as environment variable + * client_secret: Must be supplied or set as environment variable + * redirect_uri: Must be supplied or set as environment variable + * state: Optional, no verification is performed + * scope: Optional, either a list of scopes or comma separated string of scopes. + e.g, "playlist-read-private,playlist-read-collaborative" + * cache_path: (deprecated) Optional, will otherwise be generated + (takes precedence over `username`) + * username: (deprecated) Optional or set as environment variable + (will set `cache_path` to `.cache-{username}`) + * proxies: Optional, proxy for the requests library to route through + * show_dialog: Optional, interpreted as boolean + * requests_session: A Requests session + * requests_timeout: Optional, tell Requests to stop waiting for a response after + a given number of seconds + * open_browser: Optional, whether or not the web browser should be opened to + authorize a user + * cache_handler: An instance of the `CacheHandler` class to handle + getting and saving cached authorization tokens. + Optional, will otherwise use `CacheFileHandler`. + (takes precedence over `cache_path` and `username`) + """ + + super(SpotifyOAuth, self).__init__(requests_session) + + self.client_id = client_id + self.client_secret = client_secret + self.redirect_uri = redirect_uri + self.state = state + self.scope = self._normalize_scope(scope) + if username or cache_path: + warnings.warn("Specifying cache_path or username as arguments to SpotifyOAuth " + + "will be deprecated. Instead, please create a CacheFileHandler " + + "instance with the desired cache_path and username and pass it " + + "to SpotifyOAuth as the cache_handler. For example:\n\n" + + "\tfrom spotipy.oauth2 import CacheFileHandler\n" + + "\thandler = CacheFileHandler(cache_path=cache_path, " + + "username=username)\n" + + "\tsp = spotipy.SpotifyOAuth(client_id, client_secret, " + + "redirect_uri," + + " cache_handler=handler)", + DeprecationWarning + ) + if cache_handler: + warnings.warn("A cache_handler has been specified along with a cache_path or " + + "username. The cache_path and username arguments will be ignored.") + if cache_handler: + assert issubclass(cache_handler.__class__, CacheHandler), \ + "cache_handler must be a subclass of CacheHandler: " + str(type(cache_handler)) \ + + " != " + str(CacheHandler) + self.cache_handler = cache_handler + else: + username = (username or os.getenv(CLIENT_CREDS_ENV_VARS["client_username"])) + self.cache_handler = CacheFileHandler( + username=username, + cache_path=cache_path + ) + self.proxies = proxies + self.requests_timeout = requests_timeout + self.show_dialog = show_dialog + self.open_browser = open_browser + + def validate_token(self, token_info): + if token_info is None: + return None + + # if scopes don't match, then bail + if "scope" not in token_info or not self._is_scope_subset( + self.scope, token_info["scope"] + ): + return None + + if self.is_token_expired(token_info): + token_info = self.refresh_access_token( + token_info["refresh_token"] + ) + + return token_info + + def get_authorize_url(self, state=None): + """ Gets the URL to use to authorize this app + """ + payload = { + "client_id": self.client_id, + "response_type": "code", + "redirect_uri": self.redirect_uri, + } + if self.scope: + payload["scope"] = self.scope + if state is None: + state = self.state + if state is not None: + payload["state"] = state + if self.show_dialog: + payload["show_dialog"] = True + + urlparams = urllibparse.urlencode(payload) + + return "%s?%s" % (self.OAUTH_AUTHORIZE_URL, urlparams) + + def parse_response_code(self, url): + """ Parse the response code in the given response url + + Parameters: + - url - the response url + """ + _, code = self.parse_auth_response_url(url) + if code is None: + return url + else: + return code + + @staticmethod + def parse_auth_response_url(url): + query_s = urlparse(url).query + form = dict(parse_qsl(query_s)) + if "error" in form: + raise SpotifyOauthError("Received error from auth server: " + "{}".format(form["error"]), + error=form["error"]) + return tuple(form.get(param) for param in ["state", "code"]) + + def _make_authorization_headers(self): + return _make_authorization_headers(self.client_id, self.client_secret) + + def _open_auth_url(self): + auth_url = self.get_authorize_url() + try: + webbrowser.open(auth_url) + logger.info("Opened %s in your browser", auth_url) + except webbrowser.Error: + logger.error("Please navigate here: %s", auth_url) + + def _get_auth_response_interactive(self, open_browser=False): + if open_browser: + self._open_auth_url() + prompt = "Enter the URL you were redirected to: " + else: + url = self.get_authorize_url() + prompt = ( + "Go to the following URL: {}\n" + "Enter the URL you were redirected to: ".format(url) + ) + response = self._get_user_input(prompt) + state, code = SpotifyOAuth.parse_auth_response_url(response) + if self.state is not None and self.state != state: + raise SpotifyStateError(self.state, state) + return code + + def _get_auth_response_local_server(self, redirect_port): + server = start_local_http_server(redirect_port) + self._open_auth_url() + server.handle_request() + + if server.error is not None: + raise server.error + elif self.state is not None and server.state != self.state: + raise SpotifyStateError(self.state, server.state) + elif server.auth_code is not None: + return server.auth_code + else: + raise SpotifyOauthError("Server listening on localhost has not been accessed") + + def get_auth_response(self, open_browser=None): + logger.info('User authentication requires interaction with your ' + 'web browser. Once you enter your credentials and ' + 'give authorization, you will be redirected to ' + 'a url. Paste that url you were directed to to ' + 'complete the authorization.') + + redirect_info = urlparse(self.redirect_uri) + redirect_host, redirect_port = get_host_port(redirect_info.netloc) + + if open_browser is None: + open_browser = self.open_browser + + if ( + open_browser + and redirect_host in ("127.0.0.1", "localhost") + and redirect_info.scheme == "http" + ): + # Only start a local http server if a port is specified + if redirect_port: + return self._get_auth_response_local_server(redirect_port) + else: + logger.warning('Using `%s` as redirect URI without a port. ' + 'Specify a port (e.g. `%s:8080`) to allow ' + 'automatic retrieval of authentication code ' + 'instead of having to copy and paste ' + 'the URL your browser is redirected to.', + redirect_host, redirect_host) + + return self._get_auth_response_interactive(open_browser=open_browser) + + def get_authorization_code(self, response=None): + if response: + return self.parse_response_code(response) + return self.get_auth_response() + + def get_access_token(self, code=None, as_dict=True, check_cache=True): + """ Gets the access token for the app given the code + + Parameters: + - code - the response code + - as_dict - a boolean indicating if returning the access token + as a token_info dictionary, otherwise it will be returned + as a string. + """ + if as_dict: + warnings.warn( + "You're using 'as_dict = True'." + "get_access_token will return the token string directly in future " + "versions. Please adjust your code accordingly, or use " + "get_cached_token instead.", + DeprecationWarning, + stacklevel=2, + ) + if check_cache: + token_info = self.validate_token(self.cache_handler.get_cached_token()) + if token_info is not None: + if self.is_token_expired(token_info): + token_info = self.refresh_access_token( + token_info["refresh_token"] + ) + return token_info if as_dict else token_info["access_token"] + + payload = { + "redirect_uri": self.redirect_uri, + "code": code or self.get_auth_response(), + "grant_type": "authorization_code", + } + if self.scope: + payload["scope"] = self.scope + if self.state: + payload["state"] = self.state + + headers = self._make_authorization_headers() + + logger.debug( + "sending POST request to %s with Headers: %s and Body: %r", + self.OAUTH_TOKEN_URL, headers, payload + ) + + try: + response = self._session.post( + self.OAUTH_TOKEN_URL, + data=payload, + headers=headers, + verify=True, + proxies=self.proxies, + timeout=self.requests_timeout, + ) + response.raise_for_status() + token_info = response.json() + token_info = self._add_custom_values_to_token_info(token_info) + self.cache_handler.save_token_to_cache(token_info) + return token_info if as_dict else token_info["access_token"] + except requests.exceptions.HTTPError as http_error: + self._handle_oauth_error(http_error) + + def refresh_access_token(self, refresh_token): + payload = { + "refresh_token": refresh_token, + "grant_type": "refresh_token", + } + + headers = self._make_authorization_headers() + + logger.debug( + "sending POST request to %s with Headers: %s and Body: %r", + self.OAUTH_TOKEN_URL, headers, payload + ) + + try: + response = self._session.post( + self.OAUTH_TOKEN_URL, + data=payload, + headers=headers, + proxies=self.proxies, + timeout=self.requests_timeout, + ) + response.raise_for_status() + token_info = response.json() + token_info = self._add_custom_values_to_token_info(token_info) + if "refresh_token" not in token_info: + token_info["refresh_token"] = refresh_token + self.cache_handler.save_token_to_cache(token_info) + return token_info + except requests.exceptions.HTTPError as http_error: + self._handle_oauth_error(http_error) + + def _add_custom_values_to_token_info(self, token_info): + """ + Store some values that aren't directly provided by a Web API + response. + """ + token_info["expires_at"] = int(time.time()) + token_info["expires_in"] + token_info["scope"] = self.scope + return token_info + + def get_cached_token(self): + warnings.warn("Calling get_cached_token directly on the SpotifyOAuth object will be " + + "deprecated. Instead, please specify a CacheFileHandler instance as " + + "the cache_handler in SpotifyOAuth and use the CacheFileHandler's " + + "get_cached_token method. You can replace:\n\tsp.get_cached_token()" + + "\n\nWith:\n\tsp.validate_token(sp.cache_handler.get_cached_token())", + DeprecationWarning + ) + return self.validate_token(self.cache_handler.get_cached_token()) + + def _save_token_info(self, token_info): + warnings.warn("Calling _save_token_info directly on the SpotifyOAuth object will be " + + "deprecated. Instead, please specify a CacheFileHandler instance as " + + "the cache_handler in SpotifyOAuth and use the CacheFileHandler's " + + "save_token_to_cache method.", + DeprecationWarning + ) + self.cache_handler.save_token_to_cache(token_info) + return None + + +class SpotifyPKCE(SpotifyAuthBase): + """ Implements PKCE Authorization Flow for client apps + + This auth manager enables *user and non-user* endpoints with only + a client ID, redirect URI, and username. When the app requests + an access token for the first time, the user is prompted to + authorize the new client app. After authorizing the app, the client + app is then given both access and refresh tokens. This is the + preferred way of authorizing a mobile/desktop client. + + """ + + OAUTH_AUTHORIZE_URL = "https://accounts.spotify.com/authorize" + OAUTH_TOKEN_URL = "https://accounts.spotify.com/api/token" + + def __init__(self, + client_id=None, + redirect_uri=None, + state=None, + scope=None, + cache_path=None, + username=None, + proxies=None, + requests_timeout=None, + requests_session=True, + open_browser=True, + cache_handler=None): + """ + Creates Auth Manager with the PKCE Auth flow. + + Parameters: + * client_id: Must be supplied or set as environment variable + * redirect_uri: Must be supplied or set as environment variable + * state: Optional, no verification is performed + * scope: Optional, either a list of scopes or comma separated string of scopes. + e.g, "playlist-read-private,playlist-read-collaborative" + * cache_path: (deprecated) Optional, will otherwise be generated + (takes precedence over `username`) + * username: (deprecated) Optional or set as environment variable + (will set `cache_path` to `.cache-{username}`) + * proxies: Optional, proxy for the requests library to route through + * requests_timeout: Optional, tell Requests to stop waiting for a response after + a given number of seconds + * requests_session: A Requests session + * open_browser: Optional, whether or not the web browser should be opened to + authorize a user + * cache_handler: An instance of the `CacheHandler` class to handle + getting and saving cached authorization tokens. + Optional, will otherwise use `CacheFileHandler`. + (takes precedence over `cache_path` and `username`) + """ + + super(SpotifyPKCE, self).__init__(requests_session) + self.client_id = client_id + self.redirect_uri = redirect_uri + self.state = state + self.scope = self._normalize_scope(scope) + if username or cache_path: + warnings.warn("Specifying cache_path or username as arguments to SpotifyPKCE " + + "will be deprecated. Instead, please create a CacheFileHandler " + + "instance with the desired cache_path and username and pass it " + + "to SpotifyPKCE as the cache_handler. For example:\n\n" + + "\tfrom spotipy.oauth2 import CacheFileHandler\n" + + "\thandler = CacheFileHandler(cache_path=cache_path, " + + "username=username)\n" + + "\tsp = spotipy.SpotifyImplicitGrant(client_id, client_secret, " + + "redirect_uri, cache_handler=handler)", + DeprecationWarning + ) + if cache_handler: + warnings.warn("A cache_handler has been specified along with a cache_path or " + + "username. The cache_path and username arguments will be ignored.") + if cache_handler: + assert issubclass(type(cache_handler), CacheHandler), \ + "type(cache_handler): " + str(type(cache_handler)) + " != " + str(CacheHandler) + self.cache_handler = cache_handler + else: + username = (username or os.getenv(CLIENT_CREDS_ENV_VARS["client_username"])) + self.cache_handler = CacheFileHandler( + username=username, + cache_path=cache_path + ) + self.proxies = proxies + self.requests_timeout = requests_timeout + + self._code_challenge_method = "S256" # Spotify requires SHA256 + self.code_verifier = None + self.code_challenge = None + self.authorization_code = None + self.open_browser = open_browser + + def _get_code_verifier(self): + """ Spotify PCKE code verifier - See step 1 of the reference guide below + Reference: + https://developer.spotify.com/documentation/general/guides/authorization-guide/#authorization-code-flow-with-proof-key-for-code-exchange-pkce + """ + # Range (33,96) is used to select between 44-128 base64 characters for the + # next operation. The range looks weird because base64 is 6 bytes + import random + length = random.randint(33, 96) + + # The seeded length generates between a 44 and 128 base64 characters encoded string + try: + import secrets + verifier = secrets.token_urlsafe(length) + except ImportError: # For python 3.5 support + import base64 + import os + rand_bytes = os.urandom(length) + verifier = base64.urlsafe_b64encode(rand_bytes).decode('utf-8').replace('=', '') + return verifier + + def _get_code_challenge(self): + """ Spotify PCKE code challenge - See step 1 of the reference guide below + Reference: + https://developer.spotify.com/documentation/general/guides/authorization-guide/#authorization-code-flow-with-proof-key-for-code-exchange-pkce + """ + import base64 + import hashlib + code_challenge_digest = hashlib.sha256(self.code_verifier.encode('utf-8')).digest() + code_challenge = base64.urlsafe_b64encode(code_challenge_digest).decode('utf-8') + return code_challenge.replace('=', '') + + def get_authorize_url(self, state=None): + """ Gets the URL to use to authorize this app """ + if not self.code_challenge: + self.get_pkce_handshake_parameters() + payload = { + "client_id": self.client_id, + "response_type": "code", + "redirect_uri": self.redirect_uri, + "code_challenge_method": self._code_challenge_method, + "code_challenge": self.code_challenge + } + if self.scope: + payload["scope"] = self.scope + if state is None: + state = self.state + if state is not None: + payload["state"] = state + urlparams = urllibparse.urlencode(payload) + return "%s?%s" % (self.OAUTH_AUTHORIZE_URL, urlparams) + + def _open_auth_url(self, state=None): + auth_url = self.get_authorize_url(state) + try: + webbrowser.open(auth_url) + logger.info("Opened %s in your browser", auth_url) + except webbrowser.Error: + logger.error("Please navigate here: %s", auth_url) + + def _get_auth_response(self, open_browser=None): + logger.info('User authentication requires interaction with your ' + 'web browser. Once you enter your credentials and ' + 'give authorization, you will be redirected to ' + 'a url. Paste that url you were directed to to ' + 'complete the authorization.') + + redirect_info = urlparse(self.redirect_uri) + redirect_host, redirect_port = get_host_port(redirect_info.netloc) + + if open_browser is None: + open_browser = self.open_browser + + if ( + open_browser + and redirect_host in ("127.0.0.1", "localhost") + and redirect_info.scheme == "http" + ): + # Only start a local http server if a port is specified + if redirect_port: + return self._get_auth_response_local_server(redirect_port) + else: + logger.warning('Using `%s` as redirect URI without a port. ' + 'Specify a port (e.g. `%s:8080`) to allow ' + 'automatic retrieval of authentication code ' + 'instead of having to copy and paste ' + 'the URL your browser is redirected to.', + redirect_host, redirect_host) + return self._get_auth_response_interactive(open_browser=open_browser) + + def _get_auth_response_local_server(self, redirect_port): + server = start_local_http_server(redirect_port) + self._open_auth_url() + server.handle_request() + + if self.state is not None and server.state != self.state: + raise SpotifyStateError(self.state, server.state) + + if server.auth_code is not None: + return server.auth_code + elif server.error is not None: + raise SpotifyOauthError("Received error from OAuth server: {}".format(server.error)) + else: + raise SpotifyOauthError("Server listening on localhost has not been accessed") + + def _get_auth_response_interactive(self, open_browser=False): + if open_browser or self.open_browser: + self._open_auth_url() + prompt = "Enter the URL you were redirected to: " + else: + url = self.get_authorize_url() + prompt = ( + "Go to the following URL: {}\n" + "Enter the URL you were redirected to: ".format(url) + ) + response = self._get_user_input(prompt) + state, code = self.parse_auth_response_url(response) + if self.state is not None and self.state != state: + raise SpotifyStateError(self.state, state) + return code + + def get_authorization_code(self, response=None): + if response: + return self.parse_response_code(response) + return self._get_auth_response() + + def validate_token(self, token_info): + if token_info is None: + return None + + # if scopes don't match, then bail + if "scope" not in token_info or not self._is_scope_subset( + self.scope, token_info["scope"] + ): + return None + + if self.is_token_expired(token_info): + token_info = self.refresh_access_token( + token_info["refresh_token"] + ) + + return token_info + + def _add_custom_values_to_token_info(self, token_info): + """ + Store some values that aren't directly provided by a Web API + response. + """ + token_info["expires_at"] = int(time.time()) + token_info["expires_in"] + return token_info + + def get_pkce_handshake_parameters(self): + self.code_verifier = self._get_code_verifier() + self.code_challenge = self._get_code_challenge() + + def get_access_token(self, code=None, check_cache=True): + """ Gets the access token for the app + + If the code is not given and no cached token is used, an + authentication window will be shown to the user to get a new + code. + + Parameters: + - code - the response code from authentication + - check_cache - if true, checks for a locally stored token + before requesting a new token + """ + + if check_cache: + token_info = self.validate_token(self.cache_handler.get_cached_token()) + if token_info is not None: + if self.is_token_expired(token_info): + token_info = self.refresh_access_token( + token_info["refresh_token"] + ) + return token_info["access_token"] + + if self.code_verifier is None or self.code_challenge is None: + self.get_pkce_handshake_parameters() + + payload = { + "client_id": self.client_id, + "grant_type": "authorization_code", + "code": code or self.get_authorization_code(), + "redirect_uri": self.redirect_uri, + "code_verifier": self.code_verifier + } + + headers = {"Content-Type": "application/x-www-form-urlencoded"} + + logger.debug( + "sending POST request to %s with Headers: %s and Body: %r", + self.OAUTH_TOKEN_URL, headers, payload + ) + + try: + response = self._session.post( + self.OAUTH_TOKEN_URL, + data=payload, + headers=headers, + verify=True, + proxies=self.proxies, + timeout=self.requests_timeout, + ) + response.raise_for_status() + token_info = response.json() + token_info = self._add_custom_values_to_token_info(token_info) + self.cache_handler.save_token_to_cache(token_info) + return token_info["access_token"] + except requests.exceptions.HTTPError as http_error: + self._handle_oauth_error(http_error) + + def refresh_access_token(self, refresh_token): + payload = { + "refresh_token": refresh_token, + "grant_type": "refresh_token", + "client_id": self.client_id, + } + + headers = {"Content-Type": "application/x-www-form-urlencoded"} + + logger.debug( + "sending POST request to %s with Headers: %s and Body: %r", + self.OAUTH_TOKEN_URL, headers, payload + ) + + try: + response = self._session.post( + self.OAUTH_TOKEN_URL, + data=payload, + headers=headers, + proxies=self.proxies, + timeout=self.requests_timeout, + ) + response.raise_for_status() + token_info = response.json() + token_info = self._add_custom_values_to_token_info(token_info) + if "refresh_token" not in token_info: + token_info["refresh_token"] = refresh_token + self.cache_handler.save_token_to_cache(token_info) + return token_info + except requests.exceptions.HTTPError as http_error: + self._handle_oauth_error(http_error) + + def parse_response_code(self, url): + """ Parse the response code in the given response url + + Parameters: + - url - the response url + """ + _, code = self.parse_auth_response_url(url) + if code is None: + return url + else: + return code + + @staticmethod + def parse_auth_response_url(url): + return SpotifyOAuth.parse_auth_response_url(url) + + def get_cached_token(self): + warnings.warn("Calling get_cached_token directly on the SpotifyPKCE object will be " + + "deprecated. Instead, please specify a CacheFileHandler instance as " + + "the cache_handler in SpotifyOAuth and use the CacheFileHandler's " + + "get_cached_token method. You can replace:\n\tsp.get_cached_token()" + + "\n\nWith:\n\tsp.validate_token(sp.cache_handler.get_cached_token())", + DeprecationWarning + ) + return self.validate_token(self.cache_handler.get_cached_token()) + + def _save_token_info(self, token_info): + warnings.warn("Calling _save_token_info directly on the SpotifyOAuth object will be " + + "deprecated. Instead, please specify a CacheFileHandler instance as " + + "the cache_handler in SpotifyOAuth and use the CacheFileHandler's " + + "save_token_to_cache method.", + DeprecationWarning + ) + self.cache_handler.save_token_to_cache(token_info) + return None + + +class SpotifyImplicitGrant(SpotifyAuthBase): + """ Implements Implicit Grant Flow for client apps + + This auth manager enables *user and non-user* endpoints with only + a client secret, redirect uri, and username. The user will need to + copy and paste a URI from the browser every hour. + + Security Warning + ----------------- + The OAuth standard no longer recommends the Implicit Grant Flow for + client-side code. Spotify has implemented the OAuth-suggested PKCE + extension that removes the need for a client secret in the + Authentication Code flow. Use the SpotifyPKCE auth manager instead + of SpotifyImplicitGrant. + + SpotifyPKCE contains all of the functionality of + SpotifyImplicitGrant, plus automatic response retrieval and + refreshable tokens. Only a few replacements need to be made: + + * get_auth_response()['access_token'] -> + get_access_token(get_authorization_code()) + * get_auth_response() -> + get_access_token(get_authorization_code()); get_cached_token() + * parse_response_token(url)['access_token'] -> + get_access_token(parse_response_code(url)) + * parse_response_token(url) -> + get_access_token(parse_response_code(url)); get_cached_token() + + The security concern in the Implicit Grant flow is that the token is + returned in the URL and can be intercepted through the browser. A + request with an authorization code and proof of origin could not be + easily intercepted without a compromised network. + """ + OAUTH_AUTHORIZE_URL = "https://accounts.spotify.com/authorize" + + def __init__(self, + client_id=None, + redirect_uri=None, + state=None, + scope=None, + cache_path=None, + username=None, + show_dialog=False, + cache_handler=None): + """ Creates Auth Manager using the Implicit Grant flow + + **See help(SpotifyImplicitGrant) for full Security Warning** + + Parameters + ---------- + * client_id: Must be supplied or set as environment variable + * redirect_uri: Must be supplied or set as environment variable + * state: May be supplied, no verification is performed + * scope: Optional, either a list of scopes or comma separated string of scopes. + e.g, "playlist-read-private,playlist-read-collaborative" + * cache_handler: An instance of the `CacheHandler` class to handle + getting and saving cached authorization tokens. + May be supplied, will otherwise use `CacheFileHandler`. + (takes precedence over `cache_path` and `username`) + * cache_path: (deprecated) May be supplied, will otherwise be generated + (takes precedence over `username`) + * username: (deprecated) May be supplied or set as environment variable + (will set `cache_path` to `.cache-{username}`) + * show_dialog: Interpreted as boolean + """ + logger.warning("The OAuth standard no longer recommends the Implicit " + "Grant Flow for client-side code. Use the SpotifyPKCE " + "auth manager instead of SpotifyImplicitGrant. For " + "more details and a guide to switching, see " + "help(SpotifyImplicitGrant).") + + self.client_id = client_id + self.redirect_uri = redirect_uri + self.state = state + if username or cache_path: + warnings.warn("Specifying cache_path or username as arguments to " + + "SpotifyImplicitGrant will be deprecated. Instead, please create " + + "a CacheFileHandler instance with the desired cache_path and " + + "username and pass it to SpotifyImplicitGrant as the " + + "cache_handler. For example:\n\n" + + "\tfrom spotipy.oauth2 import CacheFileHandler\n" + + "\thandler = CacheFileHandler(cache_path=cache_path, " + + "username=username)\n" + + "\tsp = spotipy.SpotifyImplicitGrant(client_id, client_secret, " + + "redirect_uri, cache_handler=handler)", + DeprecationWarning + ) + if cache_handler: + warnings.warn("A cache_handler has been specified along with a cache_path or " + + "username. The cache_path and username arguments will be ignored.") + if cache_handler: + assert issubclass(type(cache_handler), CacheHandler), \ + "type(cache_handler): " + str(type(cache_handler)) + " != " + str(CacheHandler) + self.cache_handler = cache_handler + else: + username = (username or os.getenv(CLIENT_CREDS_ENV_VARS["client_username"])) + self.cache_handler = CacheFileHandler( + username=username, + cache_path=cache_path + ) + self.scope = self._normalize_scope(scope) + self.show_dialog = show_dialog + self._session = None # As to not break inherited __del__ + + def validate_token(self, token_info): + if token_info is None: + return None + + # if scopes don't match, then bail + if "scope" not in token_info or not self._is_scope_subset( + self.scope, token_info["scope"] + ): + return None + + if self.is_token_expired(token_info): + return None + + return token_info + + def get_access_token(self, + state=None, + response=None, + check_cache=True): + """ Gets Auth Token from cache (preferred) or user interaction + + Parameters + ---------- + * state: May be given, overrides (without changing) self.state + * response: URI with token, can break expiration checks + * check_cache: Interpreted as boolean + """ + if check_cache: + token_info = self.validate_token(self.cache_handler.get_cached_token()) + if not (token_info is None or self.is_token_expired(token_info)): + return token_info["access_token"] + + if response: + token_info = self.parse_response_token(response) + else: + token_info = self.get_auth_response(state) + token_info = self._add_custom_values_to_token_info(token_info) + self.cache_handler.save_token_to_cache(token_info) + + return token_info["access_token"] + + def get_authorize_url(self, state=None): + """ Gets the URL to use to authorize this app """ + payload = { + "client_id": self.client_id, + "response_type": "token", + "redirect_uri": self.redirect_uri, + } + if self.scope: + payload["scope"] = self.scope + if state is None: + state = self.state + if state is not None: + payload["state"] = state + if self.show_dialog: + payload["show_dialog"] = True + + urlparams = urllibparse.urlencode(payload) + + return "%s?%s" % (self.OAUTH_AUTHORIZE_URL, urlparams) + + def parse_response_token(self, url, state=None): + """ Parse the response code in the given response url """ + remote_state, token, t_type, exp_in = self.parse_auth_response_url(url) + if state is None: + state = self.state + if state is not None and remote_state != state: + raise SpotifyStateError(state, remote_state) + return {"access_token": token, "token_type": t_type, + "expires_in": exp_in, "state": state} + + @staticmethod + def parse_auth_response_url(url): + url_components = urlparse(url) + fragment_s = url_components.fragment + query_s = url_components.query + form = dict(i.split('=') for i + in (fragment_s or query_s or url).split('&')) + if "error" in form: + raise SpotifyOauthError("Received error from auth server: " + "{}".format(form["error"]), + state=form["state"]) + if "expires_in" in form: + form["expires_in"] = int(form["expires_in"]) + return tuple(form.get(param) for param in ["state", "access_token", + "token_type", "expires_in"]) + + def _open_auth_url(self, state=None): + auth_url = self.get_authorize_url(state) + try: + webbrowser.open(auth_url) + logger.info("Opened %s in your browser", auth_url) + except webbrowser.Error: + logger.error("Please navigate here: %s", auth_url) + + def get_auth_response(self, state=None): + """ Gets a new auth **token** with user interaction """ + logger.info('User authentication requires interaction with your ' + 'web browser. Once you enter your credentials and ' + 'give authorization, you will be redirected to ' + 'a url. Paste that url you were directed to to ' + 'complete the authorization.') + + redirect_info = urlparse(self.redirect_uri) + redirect_host, redirect_port = get_host_port(redirect_info.netloc) + # Implicit Grant tokens are returned in a hash fragment + # which is only available to the browser. Therefore, interactive + # URL retrieval is required. + if (redirect_host in ("127.0.0.1", "localhost") + and redirect_info.scheme == "http" and redirect_port): + logger.warning('Using a local redirect URI with a ' + 'port, likely expecting automatic ' + 'retrieval. Due to technical limitations, ' + 'the authentication token cannot be ' + 'automatically retrieved and must be ' + 'copied and pasted.') + + self._open_auth_url(state) + logger.info('Paste that url you were directed to in order to ' + 'complete the authorization') + response = SpotifyImplicitGrant._get_user_input("Enter the URL you " + "were redirected to: ") + return self.parse_response_token(response, state) + + def _add_custom_values_to_token_info(self, token_info): + """ + Store some values that aren't directly provided by a Web API + response. + """ + token_info["expires_at"] = int(time.time()) + token_info["expires_in"] + token_info["scope"] = self.scope + return token_info + + def get_cached_token(self): + warnings.warn("Calling get_cached_token directly on the SpotifyImplicitGrant " + + "object will be deprecated. Instead, please specify a " + + "CacheFileHandler instance as the cache_handler in SpotifyOAuth " + + "and use the CacheFileHandler's get_cached_token method. " + + "You can replace:\n\tsp.get_cached_token()" + + "\n\nWith:\n\tsp.validate_token(sp.cache_handler.get_cached_token())", + DeprecationWarning + ) + return self.validate_token(self.cache_handler.get_cached_token()) + + def _save_token_info(self, token_info): + warnings.warn("Calling _save_token_info directly on the SpotifyImplicitGrant " + + "object will be deprecated. Instead, please specify a " + + "CacheFileHandler instance as the cache_handler in SpotifyOAuth " + + "and use the CacheFileHandler's save_token_to_cache method.", + DeprecationWarning + ) + self.cache_handler.save_token_to_cache(token_info) + return None + + +class RequestHandler(BaseHTTPRequestHandler): + def do_GET(self): + self.server.auth_code = self.server.error = None + try: + state, auth_code = SpotifyOAuth.parse_auth_response_url(self.path) + self.server.state = state + self.server.auth_code = auth_code + except SpotifyOauthError as error: + self.server.error = error + + self.send_response(200) + self.send_header("Content-Type", "text/html") + self.end_headers() + + if self.server.auth_code: + status = "successful" + elif self.server.error: + status = "failed ({})".format(self.server.error) + else: + self._write("

Invalid request

") + return + + self._write(""" + + +

Authentication status: {}

+This window can be closed. + + + +""".format(status)) + + def _write(self, text): + return self.wfile.write(text.encode("utf-8")) + + def log_message(self, format, *args): + return + + +def start_local_http_server(port, handler=RequestHandler): + server = HTTPServer(("127.0.0.1", port), handler) + server.allow_reuse_address = True + server.auth_code = None + server.auth_token_form = None + server.error = None + return server diff --git a/.venv/Lib/site-packages/spotipy/util.py b/.venv/Lib/site-packages/spotipy/util.py new file mode 100644 index 00000000..b949a618 --- /dev/null +++ b/.venv/Lib/site-packages/spotipy/util.py @@ -0,0 +1,135 @@ +# -*- coding: utf-8 -*- + +""" Shows a user's playlists (need to be authenticated via oauth) """ + +__all__ = ["CLIENT_CREDS_ENV_VARS", "prompt_for_user_token"] + +import logging +import os +import warnings + +import spotipy + +LOGGER = logging.getLogger(__name__) + +CLIENT_CREDS_ENV_VARS = { + "client_id": "SPOTIPY_CLIENT_ID", + "client_secret": "SPOTIPY_CLIENT_SECRET", + "client_username": "SPOTIPY_CLIENT_USERNAME", + "redirect_uri": "SPOTIPY_REDIRECT_URI", +} + + +def prompt_for_user_token( + username=None, + scope=None, + client_id=None, + client_secret=None, + redirect_uri=None, + cache_path=None, + oauth_manager=None, + show_dialog=False +): + warnings.warn( + "'prompt_for_user_token' is deprecated." + "Use the following instead: " + " auth_manager=SpotifyOAuth(scope=scope)" + " spotipy.Spotify(auth_manager=auth_manager)", + DeprecationWarning + ) + """ prompts the user to login if necessary and returns + the user token suitable for use with the spotipy.Spotify + constructor + + Parameters: + + - username - the Spotify username (optional) + - scope - the desired scope of the request (optional) + - client_id - the client id of your app (required) + - client_secret - the client secret of your app (required) + - redirect_uri - the redirect URI of your app (required) + - cache_path - path to location to save tokens (optional) + - oauth_manager - Oauth manager object (optional) + - show_dialog - If true, a login prompt always shows (optional, defaults to False) + + """ + if not oauth_manager: + if not client_id: + client_id = os.getenv("SPOTIPY_CLIENT_ID") + + if not client_secret: + client_secret = os.getenv("SPOTIPY_CLIENT_SECRET") + + if not redirect_uri: + redirect_uri = os.getenv("SPOTIPY_REDIRECT_URI") + + if not client_id: + LOGGER.warning( + """ + You need to set your Spotify API credentials. + You can do this by setting environment variables like so: + + export SPOTIPY_CLIENT_ID='your-spotify-client-id' + export SPOTIPY_CLIENT_SECRET='your-spotify-client-secret' + export SPOTIPY_REDIRECT_URI='your-app-redirect-url' + + Get your credentials at + https://developer.spotify.com/my-applications + """ + ) + raise spotipy.SpotifyException(550, -1, "no credentials set") + + sp_oauth = oauth_manager or spotipy.SpotifyOAuth( + client_id, + client_secret, + redirect_uri, + scope=scope, + cache_path=cache_path, + username=username, + show_dialog=show_dialog + ) + + # try to get a valid token for this user, from the cache, + # if not in the cache, then create a new (this will send + # the user to a web page where they can authorize this app) + + token_info = sp_oauth.validate_token(sp_oauth.cache_handler.get_cached_token()) + + if not token_info: + code = sp_oauth.get_auth_response() + token = sp_oauth.get_access_token(code, as_dict=False) + else: + return token_info["access_token"] + + # Auth'ed API request + if token: + return token + else: + return None + + +def get_host_port(netloc): + if ":" in netloc: + host, port = netloc.split(":", 1) + port = int(port) + else: + host = netloc + port = None + + return host, port + + +def normalize_scope(scope): + if scope: + if isinstance(scope, str): + scopes = scope.split(',') + elif isinstance(scope, list) or isinstance(scope, tuple): + scopes = scope + else: + raise Exception( + "Unsupported scope value, please either provide a list of scopes, " + "or a string of scopes separated by commas" + ) + return " ".join(sorted(scopes)) + else: + return None diff --git a/.venv/Scripts/dotenv.exe b/.venv/Scripts/dotenv.exe new file mode 100644 index 00000000..9f330ce0 Binary files /dev/null and b/.venv/Scripts/dotenv.exe differ diff --git a/__pycache__/webapp.cpython-310.pyc b/__pycache__/webapp.cpython-310.pyc new file mode 100644 index 00000000..0d98b847 Binary files /dev/null and b/__pycache__/webapp.cpython-310.pyc differ diff --git a/__pycache__/webapp.cpython-311.pyc b/__pycache__/webapp.cpython-311.pyc index 333d01c8..e16d30f5 100644 Binary files a/__pycache__/webapp.cpython-311.pyc and b/__pycache__/webapp.cpython-311.pyc differ diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 00000000..068cdab6 Binary files /dev/null and b/requirements.txt differ diff --git a/static/logo.png b/static/logo.png new file mode 100644 index 00000000..c19556ae Binary files /dev/null and b/static/logo.png differ diff --git a/tailwind.config.js b/tailwind.config.js new file mode 100644 index 00000000..80af1a8e --- /dev/null +++ b/tailwind.config.js @@ -0,0 +1,9 @@ +/** @type {import('tailwindcss').Config} */ +module.exports = { + content: ["./templates/*.{html,js}"], + theme: { + extend: {}, + }, + plugins: [], +} + diff --git a/templates/about.html b/templates/about.html new file mode 100644 index 00000000..9e05cdc8 --- /dev/null +++ b/templates/about.html @@ -0,0 +1,112 @@ + + + + + + About Rememberescence + + + + + + + + +
+

About Rememberescence

+

Rememberescence is a platform that allows you to discover and share your memories with others. By answering a few simple questions, we can generate a unique story that captures your childhood memories. Sign up now to get started!

+

Made with love by Yihoi Jung, Triya Augustine, Luana Madeira, and Arianne Ghislaine Rull

+ Sign Up +
+ + \ No newline at end of file diff --git a/templates/home copy.html b/templates/home copy.html new file mode 100644 index 00000000..298e70f6 --- /dev/null +++ b/templates/home copy.html @@ -0,0 +1,146 @@ + + + + + + Flask Cohere App + + + + +
+ {% if session %} +

Welcome {{session.userinfo.name}}!

+

Logout

+ + + {% else %} +
+

Welcome to Rememberescence

+

Login to get started!

+
+ {% endif %} + + +

Flask Cohere App

+ +
+ + +
+ + +
+ + +
+ + +
+ + +
+ +
+ +
+ + + + + diff --git a/templates/home.html b/templates/home.html index 72e6d0d6..298e70f6 100644 --- a/templates/home.html +++ b/templates/home.html @@ -4,47 +4,143 @@ Flask Cohere App + - - -

Flask Cohere App

- -
- -
- -
- -
- -
+ + +
+ {% if session %} +

Welcome {{session.userinfo.name}}!

+

Logout

+ + + {% else %} +
+

Welcome to Rememberescence

+

Login to get started!

+
+ {% endif %} + + +

Flask Cohere App

+ +
+ + +
+ + +
+ + +
+ + +
+ + +
+ +
+ +
- \ No newline at end of file + diff --git a/templates/input.css b/templates/input.css new file mode 100644 index 00000000..bd6213e1 --- /dev/null +++ b/templates/input.css @@ -0,0 +1,3 @@ +@tailwind base; +@tailwind components; +@tailwind utilities; \ No newline at end of file diff --git a/templates/login copy 2.html b/templates/login copy 2.html new file mode 100644 index 00000000..e13fb869 --- /dev/null +++ b/templates/login copy 2.html @@ -0,0 +1,160 @@ + + + + + + + + Flask Cohere App + + + + + + + + +
+

Welcome to Rememberescence

+

Discover and share your memories with others.

+ Get Started +
+ +
+

About Rememberescence

+

Rememberescence is a platform that allows you to discover and share your memories with others. By answering a few simple questions, we can generate a unique story that captures your childhood memories. Sign up now to get started!

+ Sign Up +
+ + \ No newline at end of file diff --git a/templates/login copy.html b/templates/login copy.html new file mode 100644 index 00000000..dd88f1dc --- /dev/null +++ b/templates/login copy.html @@ -0,0 +1,292 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + Flask Cohere App + + + + + + + + +
+

Welcome to Rememberescence

+

Discover and share your memories with others.

+ Get Started +
+ +
+

About Rememberescence

+

Rememberescence is a platform that allows you to discover and share your memories with others. By answering a few simple questions, we can generate a unique story that captures your childhood memories. Sign up now to get started!

+ Sign Up +
+ + \ No newline at end of file diff --git a/templates/login.html b/templates/login.html new file mode 100644 index 00000000..8dce671a --- /dev/null +++ b/templates/login.html @@ -0,0 +1,162 @@ + + + + + + + Flask Cohere App + + + + + + + + +
+

Welcome to Rememberescence

+

Discover and share your memories with others.

+ Get Started +
+ + +
+

+

About Rememberescence

+

Rememberescence is a platform that allows you to discover and share your memories with others. By answering a few simple questions, we can generate a unique story that captures your childhood memories. Sign up now to get started!

+ Sign Up +
+ + + \ No newline at end of file diff --git a/templates/logo.png b/templates/logo.png new file mode 100644 index 00000000..10dac816 Binary files /dev/null and b/templates/logo.png differ diff --git a/templates/output.css b/templates/output.css new file mode 100644 index 00000000..bc406f35 --- /dev/null +++ b/templates/output.css @@ -0,0 +1,717 @@ +/* +! tailwindcss v3.4.1 | MIT License | https://tailwindcss.com +*/ + +/* +1. Prevent padding and border from affecting element width. (https://github.com/mozdevs/cssremedy/issues/4) +2. Allow adding a border to an element by just adding a border-width. (https://github.com/tailwindcss/tailwindcss/pull/116) +*/ + +*, +::before, +::after { + box-sizing: border-box; + /* 1 */ + border-width: 0; + /* 2 */ + border-style: solid; + /* 2 */ + border-color: #e5e7eb; + /* 2 */ +} + +::before, +::after { + --tw-content: ''; +} + +/* +1. Use a consistent sensible line-height in all browsers. +2. Prevent adjustments of font size after orientation changes in iOS. +3. Use a more readable tab size. +4. Use the user's configured `sans` font-family by default. +5. Use the user's configured `sans` font-feature-settings by default. +6. Use the user's configured `sans` font-variation-settings by default. +7. Disable tap highlights on iOS +*/ + +html, +:host { + line-height: 1.5; + /* 1 */ + -webkit-text-size-adjust: 100%; + /* 2 */ + -moz-tab-size: 4; + /* 3 */ + -o-tab-size: 4; + tab-size: 4; + /* 3 */ + font-family: ui-sans-serif, system-ui, sans-serif, "Apple Color Emoji", "Segoe UI Emoji", "Segoe UI Symbol", "Noto Color Emoji"; + /* 4 */ + font-feature-settings: normal; + /* 5 */ + font-variation-settings: normal; + /* 6 */ + -webkit-tap-highlight-color: transparent; + /* 7 */ +} + +/* +1. Remove the margin in all browsers. +2. Inherit line-height from `html` so users can set them as a class directly on the `html` element. +*/ + +body { + margin: 0; + /* 1 */ + line-height: inherit; + /* 2 */ +} + +/* +1. Add the correct height in Firefox. +2. Correct the inheritance of border color in Firefox. (https://bugzilla.mozilla.org/show_bug.cgi?id=190655) +3. Ensure horizontal rules are visible by default. +*/ + +hr { + height: 0; + /* 1 */ + color: inherit; + /* 2 */ + border-top-width: 1px; + /* 3 */ +} + +/* +Add the correct text decoration in Chrome, Edge, and Safari. +*/ + +abbr:where([title]) { + -webkit-text-decoration: underline dotted; + text-decoration: underline dotted; +} + +/* +Remove the default font size and weight for headings. +*/ + +h1, +h2, +h3, +h4, +h5, +h6 { + font-size: inherit; + font-weight: inherit; +} + +/* +Reset links to optimize for opt-in styling instead of opt-out. +*/ + +a { + color: inherit; + text-decoration: inherit; +} + +/* +Add the correct font weight in Edge and Safari. +*/ + +b, +strong { + font-weight: bolder; +} + +/* +1. Use the user's configured `mono` font-family by default. +2. Use the user's configured `mono` font-feature-settings by default. +3. Use the user's configured `mono` font-variation-settings by default. +4. Correct the odd `em` font sizing in all browsers. +*/ + +code, +kbd, +samp, +pre { + font-family: ui-monospace, SFMono-Regular, Menlo, Monaco, Consolas, "Liberation Mono", "Courier New", monospace; + /* 1 */ + font-feature-settings: normal; + /* 2 */ + font-variation-settings: normal; + /* 3 */ + font-size: 1em; + /* 4 */ +} + +/* +Add the correct font size in all browsers. +*/ + +small { + font-size: 80%; +} + +/* +Prevent `sub` and `sup` elements from affecting the line height in all browsers. +*/ + +sub, +sup { + font-size: 75%; + line-height: 0; + position: relative; + vertical-align: baseline; +} + +sub { + bottom: -0.25em; +} + +sup { + top: -0.5em; +} + +/* +1. Remove text indentation from table contents in Chrome and Safari. (https://bugs.chromium.org/p/chromium/issues/detail?id=999088, https://bugs.webkit.org/show_bug.cgi?id=201297) +2. Correct table border color inheritance in all Chrome and Safari. (https://bugs.chromium.org/p/chromium/issues/detail?id=935729, https://bugs.webkit.org/show_bug.cgi?id=195016) +3. Remove gaps between table borders by default. +*/ + +table { + text-indent: 0; + /* 1 */ + border-color: inherit; + /* 2 */ + border-collapse: collapse; + /* 3 */ +} + +/* +1. Change the font styles in all browsers. +2. Remove the margin in Firefox and Safari. +3. Remove default padding in all browsers. +*/ + +button, +input, +optgroup, +select, +textarea { + font-family: inherit; + /* 1 */ + font-feature-settings: inherit; + /* 1 */ + font-variation-settings: inherit; + /* 1 */ + font-size: 100%; + /* 1 */ + font-weight: inherit; + /* 1 */ + line-height: inherit; + /* 1 */ + color: inherit; + /* 1 */ + margin: 0; + /* 2 */ + padding: 0; + /* 3 */ +} + +/* +Remove the inheritance of text transform in Edge and Firefox. +*/ + +button, +select { + text-transform: none; +} + +/* +1. Correct the inability to style clickable types in iOS and Safari. +2. Remove default button styles. +*/ + +button, +[type='button'], +[type='reset'], +[type='submit'] { + -webkit-appearance: button; + /* 1 */ + background-color: transparent; + /* 2 */ + background-image: none; + /* 2 */ +} + +/* +Use the modern Firefox focus style for all focusable elements. +*/ + +:-moz-focusring { + outline: auto; +} + +/* +Remove the additional `:invalid` styles in Firefox. (https://github.com/mozilla/gecko-dev/blob/2f9eacd9d3d995c937b4251a5557d95d494c9be1/layout/style/res/forms.css#L728-L737) +*/ + +:-moz-ui-invalid { + box-shadow: none; +} + +/* +Add the correct vertical alignment in Chrome and Firefox. +*/ + +progress { + vertical-align: baseline; +} + +/* +Correct the cursor style of increment and decrement buttons in Safari. +*/ + +::-webkit-inner-spin-button, +::-webkit-outer-spin-button { + height: auto; +} + +/* +1. Correct the odd appearance in Chrome and Safari. +2. Correct the outline style in Safari. +*/ + +[type='search'] { + -webkit-appearance: textfield; + /* 1 */ + outline-offset: -2px; + /* 2 */ +} + +/* +Remove the inner padding in Chrome and Safari on macOS. +*/ + +::-webkit-search-decoration { + -webkit-appearance: none; +} + +/* +1. Correct the inability to style clickable types in iOS and Safari. +2. Change font properties to `inherit` in Safari. +*/ + +::-webkit-file-upload-button { + -webkit-appearance: button; + /* 1 */ + font: inherit; + /* 2 */ +} + +/* +Add the correct display in Chrome and Safari. +*/ + +summary { + display: list-item; +} + +/* +Removes the default spacing and border for appropriate elements. +*/ + +blockquote, +dl, +dd, +h1, +h2, +h3, +h4, +h5, +h6, +hr, +figure, +p, +pre { + margin: 0; +} + +fieldset { + margin: 0; + padding: 0; +} + +legend { + padding: 0; +} + +ol, +ul, +menu { + list-style: none; + margin: 0; + padding: 0; +} + +/* +Reset default styling for dialogs. +*/ + +dialog { + padding: 0; +} + +/* +Prevent resizing textareas horizontally by default. +*/ + +textarea { + resize: vertical; +} + +/* +1. Reset the default placeholder opacity in Firefox. (https://github.com/tailwindlabs/tailwindcss/issues/3300) +2. Set the default placeholder color to the user's configured gray 400 color. +*/ + +input::-moz-placeholder, textarea::-moz-placeholder { + opacity: 1; + /* 1 */ + color: #9ca3af; + /* 2 */ +} + +input::placeholder, +textarea::placeholder { + opacity: 1; + /* 1 */ + color: #9ca3af; + /* 2 */ +} + +/* +Set the default cursor for buttons. +*/ + +button, +[role="button"] { + cursor: pointer; +} + +/* +Make sure disabled buttons don't get the pointer cursor. +*/ + +:disabled { + cursor: default; +} + +/* +1. Make replaced elements `display: block` by default. (https://github.com/mozdevs/cssremedy/issues/14) +2. Add `vertical-align: middle` to align replaced elements more sensibly by default. (https://github.com/jensimmons/cssremedy/issues/14#issuecomment-634934210) + This can trigger a poorly considered lint error in some tools but is included by design. +*/ + +img, +svg, +video, +canvas, +audio, +iframe, +embed, +object { + display: block; + /* 1 */ + vertical-align: middle; + /* 2 */ +} + +/* +Constrain images and videos to the parent width and preserve their intrinsic aspect ratio. (https://github.com/mozdevs/cssremedy/issues/14) +*/ + +img, +video { + max-width: 100%; + height: auto; +} + +/* Make elements with the HTML hidden attribute stay hidden by default */ + +[hidden] { + display: none; +} + +*, ::before, ::after { + --tw-border-spacing-x: 0; + --tw-border-spacing-y: 0; + --tw-translate-x: 0; + --tw-translate-y: 0; + --tw-rotate: 0; + --tw-skew-x: 0; + --tw-skew-y: 0; + --tw-scale-x: 1; + --tw-scale-y: 1; + --tw-pan-x: ; + --tw-pan-y: ; + --tw-pinch-zoom: ; + --tw-scroll-snap-strictness: proximity; + --tw-gradient-from-position: ; + --tw-gradient-via-position: ; + --tw-gradient-to-position: ; + --tw-ordinal: ; + --tw-slashed-zero: ; + --tw-numeric-figure: ; + --tw-numeric-spacing: ; + --tw-numeric-fraction: ; + --tw-ring-inset: ; + --tw-ring-offset-width: 0px; + --tw-ring-offset-color: #fff; + --tw-ring-color: rgb(59 130 246 / 0.5); + --tw-ring-offset-shadow: 0 0 #0000; + --tw-ring-shadow: 0 0 #0000; + --tw-shadow: 0 0 #0000; + --tw-shadow-colored: 0 0 #0000; + --tw-blur: ; + --tw-brightness: ; + --tw-contrast: ; + --tw-grayscale: ; + --tw-hue-rotate: ; + --tw-invert: ; + --tw-saturate: ; + --tw-sepia: ; + --tw-drop-shadow: ; + --tw-backdrop-blur: ; + --tw-backdrop-brightness: ; + --tw-backdrop-contrast: ; + --tw-backdrop-grayscale: ; + --tw-backdrop-hue-rotate: ; + --tw-backdrop-invert: ; + --tw-backdrop-opacity: ; + --tw-backdrop-saturate: ; + --tw-backdrop-sepia: ; +} + +::backdrop { + --tw-border-spacing-x: 0; + --tw-border-spacing-y: 0; + --tw-translate-x: 0; + --tw-translate-y: 0; + --tw-rotate: 0; + --tw-skew-x: 0; + --tw-skew-y: 0; + --tw-scale-x: 1; + --tw-scale-y: 1; + --tw-pan-x: ; + --tw-pan-y: ; + --tw-pinch-zoom: ; + --tw-scroll-snap-strictness: proximity; + --tw-gradient-from-position: ; + --tw-gradient-via-position: ; + --tw-gradient-to-position: ; + --tw-ordinal: ; + --tw-slashed-zero: ; + --tw-numeric-figure: ; + --tw-numeric-spacing: ; + --tw-numeric-fraction: ; + --tw-ring-inset: ; + --tw-ring-offset-width: 0px; + --tw-ring-offset-color: #fff; + --tw-ring-color: rgb(59 130 246 / 0.5); + --tw-ring-offset-shadow: 0 0 #0000; + --tw-ring-shadow: 0 0 #0000; + --tw-shadow: 0 0 #0000; + --tw-shadow-colored: 0 0 #0000; + --tw-blur: ; + --tw-brightness: ; + --tw-contrast: ; + --tw-grayscale: ; + --tw-hue-rotate: ; + --tw-invert: ; + --tw-saturate: ; + --tw-sepia: ; + --tw-drop-shadow: ; + --tw-backdrop-blur: ; + --tw-backdrop-brightness: ; + --tw-backdrop-contrast: ; + --tw-backdrop-grayscale: ; + --tw-backdrop-hue-rotate: ; + --tw-backdrop-invert: ; + --tw-backdrop-opacity: ; + --tw-backdrop-saturate: ; + --tw-backdrop-sepia: ; +} + +.container { + width: 100%; +} + +@media (min-width: 640px) { + .container { + max-width: 640px; + } +} + +@media (min-width: 768px) { + .container { + max-width: 768px; + } +} + +@media (min-width: 1024px) { + .container { + max-width: 1024px; + } +} + +@media (min-width: 1280px) { + .container { + max-width: 1280px; + } +} + +@media (min-width: 1536px) { + .container { + max-width: 1536px; + } +} + +.mx-auto { + margin-left: auto; + margin-right: auto; +} + +.mb-4 { + margin-bottom: 1rem; +} + +.mt-4 { + margin-top: 1rem; +} + +.mt-8 { + margin-top: 2rem; +} + +.block { + display: block; +} + +.flex { + display: flex; +} + +.min-h-screen { + min-height: 100vh; +} + +.w-full { + width: 100%; +} + +.grow { + flex-grow: 1; +} + +.items-center { + align-items: center; +} + +.justify-center { + justify-content: center; +} + +.space-y-4 > :not([hidden]) ~ :not([hidden]) { + --tw-space-y-reverse: 0; + margin-top: calc(1rem * calc(1 - var(--tw-space-y-reverse))); + margin-bottom: calc(1rem * var(--tw-space-y-reverse)); +} + +.rounded { + border-radius: 0.25rem; +} + +.bg-blue-500 { + --tw-bg-opacity: 1; + background-color: rgb(59 130 246 / var(--tw-bg-opacity)); +} + +.bg-white { + --tw-bg-opacity: 1; + background-color: rgb(255 255 255 / var(--tw-bg-opacity)); +} + +.bg-gradient-to-r { + background-image: linear-gradient(to right, var(--tw-gradient-stops)); +} + +.from-pink-500 { + --tw-gradient-from: #ec4899 var(--tw-gradient-from-position); + --tw-gradient-to: rgb(236 72 153 / 0) var(--tw-gradient-to-position); + --tw-gradient-stops: var(--tw-gradient-from), var(--tw-gradient-to); +} + +.via-purple-500 { + --tw-gradient-to: rgb(168 85 247 / 0) var(--tw-gradient-to-position); + --tw-gradient-stops: var(--tw-gradient-from), #a855f7 var(--tw-gradient-via-position), var(--tw-gradient-to); +} + +.to-indigo-500 { + --tw-gradient-to: #6366f1 var(--tw-gradient-to-position); +} + +.px-4 { + padding-left: 1rem; + padding-right: 1rem; +} + +.py-2 { + padding-top: 0.5rem; + padding-bottom: 0.5rem; +} + +.text-2xl { + font-size: 1.5rem; + line-height: 2rem; +} + +.text-3xl { + font-size: 1.875rem; + line-height: 2.25rem; +} + +.text-4xl { + font-size: 2.25rem; + line-height: 2.5rem; +} + +.font-bold { + font-weight: 700; +} + +.text-black { + --tw-text-opacity: 1; + color: rgb(0 0 0 / var(--tw-text-opacity)); +} + +.text-blue-300 { + --tw-text-opacity: 1; + color: rgb(147 197 253 / var(--tw-text-opacity)); +} + +.text-white { + --tw-text-opacity: 1; + color: rgb(255 255 255 / var(--tw-text-opacity)); +} + +.hover\:bg-blue-400:hover { + --tw-bg-opacity: 1; + background-color: rgb(96 165 250 / var(--tw-bg-opacity)); +} + +.hover\:text-blue-100:hover { + --tw-text-opacity: 1; + color: rgb(219 234 254 / var(--tw-text-opacity)); +} \ No newline at end of file diff --git a/templates/testerLog.html b/templates/testerLog.html new file mode 100644 index 00000000..dbafd630 --- /dev/null +++ b/templates/testerLog.html @@ -0,0 +1,76 @@ + + + + + + Flask Cohere App - Login + + + + + +
+

Login

+
+ + + + + +
+ +

Don't have an account? Sign up

+
+ + + + \ No newline at end of file diff --git a/webapp.py b/webapp.py index 3c71ec40..549fc766 100644 --- a/webapp.py +++ b/webapp.py @@ -1,6 +1,15 @@ from flask import Flask, render_template, request, jsonify import cohere import os +import spotipy +from spotipy.oauth2 import SpotifyClientCredentials +import json +from os import environ as env +from urllib.parse import quote_plus, urlencode + +from authlib.integrations.flask_client import OAuth +from dotenv import find_dotenv, load_dotenv +from flask import Flask, redirect, render_template, session, url_for app = Flask(__name__) @@ -9,9 +18,64 @@ print(f"API Key: {cohere_api_key}") co = cohere.Client(cohere_api_key) +# Initialize Spotipy client +client_credentials_manager = SpotifyClientCredentials(client_id='55be60be52eb4b658763757653391641', + client_secret='8ce1ad629229430ab7d2d7278ebdb9bd') +sp = spotipy.Spotify(client_credentials_manager=client_credentials_manager) + + +ENV_FILE = find_dotenv() +if ENV_FILE: + load_dotenv(ENV_FILE) + +app = Flask(__name__) +app.secret_key = env.get("APP_SECRET_KEY") + +oauth = OAuth(app) + +oauth.register( + "auth0", + client_id=env.get("AUTH0_CLIENT_ID"), + client_secret=env.get("AUTH0_CLIENT_SECRET"), + client_kwargs={ + "scope": "openid profile email", + }, + server_metadata_url=f'https://{env.get("AUTH0_DOMAIN")}/.well-known/openid-configuration' +) + +@app.route("/login") +def login(): + return oauth.auth0.authorize_redirect( + redirect_uri=url_for("callback", _external=True) + ) + +@app.route("/callback", methods=["GET", "POST"]) +def callback(): + token = oauth.auth0.authorize_access_token() + session["user"] = token + return redirect("/") + +@app.route("/logout") +def logout(): + session.clear() + return redirect( + "https://" + env.get("AUTH0_DOMAIN") + + "/v2/logout?" + + urlencode( + { + "returnTo": url_for("home", _external=True), + "client_id": env.get("AUTH0_CLIENT_ID"), + }, + quote_via=quote_plus, + ) + ) + @app.route("/") def home(): - return render_template("home.html") + if session.get('user'): + return render_template("home.html", session=session.get('user'), pretty=json.dumps(session.get('user'), indent=4)) + else: + return render_template("login.html") @app.route("/generate", methods=["POST"]) def generate_content(): @@ -20,19 +84,40 @@ def generate_content(): # Get user inputs from the JSON data user_interests = data.get("interests") user_location = data.get("location") - #user_interests = "legos, beyblade, star wars" - #user_location = "Canada" + user_age = data.get("age") + user_friends = data.get("friends") + + print(user_friends) # Use Cohere to generate content based on user inputs + parameters = f'Friends:{user_friends}, Interests: {user_interests}, Location: {user_location}' generation_response = co.generate( - prompt=f'Please write a short story for nostalgia based on these parameters: {user_interests} {user_location}' + prompt=f'Please write a nostalgic short story about the users childhood based on these parameters: {parameters} addressing the user as "you"' ) + # User's age & location used for childhood Spotify + + # Access the 'text' attribute to get the generated content generated_content = generation_response[0].text # Return the JSON response with the generated content return jsonify({"generated_content": generated_content}) +@app.route("/refine", methods=["POST"]) +def refine_story(): + data = request.json + + # Get the previously generated content from the JSON data + generated_content = data.get("generated_content") + + # Use Cohere to generate a new question based on the generated content + new_prompt = f'Based on this short story, come up with 1 short question to ask the user to get another parameter to put back into cohere: {generated_content}' + new_question_response = co.generate(prompt=new_prompt) + new_question = new_question_response[0].text + + # Return the new question as JSON response + return jsonify({"new_question": new_question}) + if __name__ == '__main__': app.run(debug=True) \ No newline at end of file