diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 629463ac..cc8bf5dc 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -49,32 +49,32 @@ jobs: - name: Set up uv uses: astral-sh/setup-uv@v6 with: - cache-dependency-glob: | - setup.py cache-suffix: ${{ matrix.python-version }} enable-cache: true version: "latest" - + - name: Setup env + run: uv sync - name: Invoke tests run: | - + # Propagate build matrix information. ./devtools/setup_ci.sh # Bootstrap environment. source bootstrap.sh - + + # Run linter. + uv run ruff check . + + # Run type testing + uv run mypy + # Report about the test matrix slot. echo "Invoking tests with CrateDB ${CRATEDB_VERSION}" - - # Run linter. - poe lint + uv run coverage run -m pytest - # Run tests. - coverage run bin/test -vvv - # Set the stage for uploading the coverage report. - coverage xml + uv run coverage xml # https://github.com/codecov/codecov-action - name: Upload coverage results to Codecov diff --git a/DEVELOP.rst b/DEVELOP.rst index 2f39ede0..1e2fb962 100644 --- a/DEVELOP.rst +++ b/DEVELOP.rst @@ -25,47 +25,41 @@ Running tests ============= All tests will be invoked using the Python interpreter that was used when -creating the Python virtualenv. The test runner is `zope.testrunner`_. +creating the Python virtualenv. The test runner is `pytest`. -Some examples are outlined below. In order to learn about more details, -see, for example, `useful command-line options for zope-testrunner`_. Run all tests:: - - poe test + uv run pytest Run specific tests:: # Select modules. - bin/test -t test_cursor - bin/test -t client - bin/test -t testing + uv run pytest -k test_cursor.py + uv run pytest -k client # Select doctests. - bin/test -t http.rst + uv run pytest --doctest-glob="*.rst" + uv run pytest --doctest-glob="connect.rst" Ignore specific test directories:: - - bin/test --ignore_dir=testing + uv run pytest -k 'not testing' The ``LayerTest`` test cases have quite some overhead. Omitting them will save a few cycles (~70 seconds runtime):: - bin/test -t '!LayerTest' + uv run pytest -k 'not testing' Invoke all tests without integration tests (~10 seconds runtime):: - bin/test --layer '!crate.testing.layer.crate' --test '!LayerTest' + uv run pytest -k 'not testing' Yet ~60 test cases, but only ~1 second runtime:: - bin/test --layer '!crate.testing.layer.crate' --test '!LayerTest' \ - -t '!test_client_threaded' -t '!test_no_retry_on_read_timeout' \ - -t '!test_wait_for_http' -t '!test_table_clustered_by' + uv run pytest -k 'not testing and not test_wait_for_http and not test_client_multithreaded and not test_keep_alive and not test_no_retry_on_read_timeout' To inspect the whole list of test cases, run:: - bin/test --list-tests + uv run pytest --collect-only The CI setup on GitHub Actions (GHA) provides a full test matrix covering relevant Python versions. You can invoke the software tests against a specific @@ -84,15 +78,15 @@ Formatting and linting code To use Ruff for code formatting, according to the standards configured in ``pyproject.toml``, use:: - poe format + uv run ruff format To lint the code base using Ruff and mypy, use:: - poe lint + uv run ruff check && uv run mypy Linting and software testing, all together now:: - poe check + uv run pytest && uv run ruff check && uv run mypy Renew certificates @@ -169,8 +163,6 @@ nothing special you need to do to get the live docs to update. .. _Sphinx: http://sphinx-doc.org/ .. _tests/assets/pki/*.pem: https://github.com/crate/crate-python/tree/main/tests/assets/pki .. _twine: https://pypi.python.org/pypi/twine -.. _useful command-line options for zope-testrunner: https://pypi.org/project/zope.testrunner/#some-useful-command-line-options-to-get-you-started .. _uv: https://docs.astral.sh/uv/ .. _UV_PYTHON: https://docs.astral.sh/uv/configuration/environment/#uv_python .. _versions hosted on ReadTheDocs: https://readthedocs.org/projects/crate-python/versions/ -.. _zope.testrunner: https://pypi.org/project/zope.testrunner/ diff --git a/bin/test b/bin/test deleted file mode 100755 index 749ec64b..00000000 --- a/bin/test +++ /dev/null @@ -1,17 +0,0 @@ -#!/usr/bin/env python -import os -import sys -import zope.testrunner - -join = os.path.join -base = os.path.dirname(os.path.abspath(os.path.realpath(__file__))) -base = os.path.dirname(base) - - -sys.argv[0] = os.path.abspath(sys.argv[0]) - -if __name__ == '__main__': - zope.testrunner.run([ - '-vvvv', '--auto-color', - '--path', join(base, 'tests'), - ]) diff --git a/bootstrap.sh b/bootstrap.sh index 93795ad7..8f30f932 100644 --- a/bootstrap.sh +++ b/bootstrap.sh @@ -103,7 +103,6 @@ function main() { ensure_virtualenv activate_virtualenv before_setup - setup_package run_buildout deactivate_uv finalize diff --git a/pyproject.toml b/pyproject.toml index 08b0d321..19a92729 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,3 +1,55 @@ +[build-system] +requires = ["hatchling >= 1.26"] +build-backend = "hatchling.build" + +[tool.hatch.build.targets.wheel] +packages = ["src/crate"] + +[tool.hatch.version] +path = "src/crate/client/__init__.py" + +[project] +name = "crate-python" +dynamic = ["version"] +description = "CrateDB Python Client" +authors = [{ name = "Crate.io", email = "office@crate.io" }] +requires-python = ">=3.10" +readme = "README.rst" +license = { file = "LICENSE"} +classifiers = [ + "Development Status :: 5 - Production/Stable", + "Intended Audience :: Developers", + "License :: OSI Approved :: Apache Software License", + "Operating System :: OS Independent", + "Programming Language :: Python", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", + "Programming Language :: Python :: Implementation :: CPython", + "Programming Language :: Python :: Implementation :: PyPy", + "Topic :: Database", +] +dependencies = [ + "orjson>=3.11.3", + "urllib3>=2.5.0", +] + +[dependency-groups] +dev = [ + "certifi>=2025.10.5", + "coverage>=7.11.0", + "mypy>=1.18.2", + "pytest>=8.4.2", + "pytz>=2025.2", + "ruff>=0.14.2", + "setuptools>=80.9.0", + "stopit>=1.1.2", + "verlib2>=0.3.1", +] + + [tool.mypy] mypy_path = "src" packages = [ @@ -18,65 +70,67 @@ non_interactive = true line-length = 80 extend-exclude = [ - "/example_*", + "/example_*", ] lint.select = [ - # Builtins - "A", - # Bugbear - "B", - # comprehensions - "C4", - # Pycodestyle - "E", - # eradicate - "ERA", - # Pyflakes - "F", - # isort - "I", - # pandas-vet - "PD", - # return - "RET", - # Bandit - "S", - # print - "T20", - "W", - # flake8-2020 - "YTT", + # Builtins + "A", + # Bugbear + "B", + # comprehensions + "C4", + # Pycodestyle + "E", + # eradicate + "ERA", + # Pyflakes + "F", + # isort + "I", + # pandas-vet + "PD", + # return + "RET", + # Bandit + "S", + # print + "T20", + "W", + # flake8-2020 + "YTT", ] lint.extend-ignore = [ - # Unnecessary variable assignment before `return` statement - "RET504", - # Unnecessary `elif` after `return` statement - "RET505", + # Unnecessary variable assignment before `return` statement + "RET504", + # Unnecessary `elif` after `return` statement + "RET505", ] lint.per-file-ignores."example_*" = [ - "ERA001", # Found commented-out code - "T201", # Allow `print` + "ERA001", # Found commented-out code + "T201", # Allow `print` ] lint.per-file-ignores."devtools/*" = [ - "T201", # Allow `print` + "T201", # Allow `print` ] lint.per-file-ignores."examples/*" = [ - "ERA001", # Found commented-out code - "T201", # Allow `print` + "ERA001", # Found commented-out code + "T201", # Allow `print` ] lint.per-file-ignores."tests/*" = [ - "S106", # Possible hardcoded password assigned to argument: "password" - "S311", # Standard pseudo-random generators are not suitable for cryptographic purposes + "S101", # Asserts. + "S105", # Possible hardcoded password assigned to: "password" + "S106", # Possible hardcoded password assigned to argument: "password" + "S311", # Standard pseudo-random generators are not suitable for cryptographic purposes ] lint.per-file-ignores."src/crate/client/{connection.py,http.py}" = [ - "A004", # Import `ConnectionError` is shadowing a Python builtin - "A005", # Import `ConnectionError` is shadowing a Python builtin + "A004", # Import `ConnectionError` is shadowing a Python builtin + "A005", # Import `ConnectionError` is shadowing a Python builtin ] lint.per-file-ignores."tests/client/test_http.py" = [ - "A004", # Import `ConnectionError` is shadowing a Python builtin + "A004", # Import `ConnectionError` is shadowing a Python builtin ] diff --git a/requirements.txt b/requirements.txt index 8935d351..885e5c9d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,3 @@ setuptools<80.3 urllib3<2.4 -zc.buildout==3.3 -zope.interface==6.4.post2 +zc.buildout==3.3 \ No newline at end of file diff --git a/setup.py b/setup.py deleted file mode 100644 index 85a28d82..00000000 --- a/setup.py +++ /dev/null @@ -1,101 +0,0 @@ -# -*- coding: utf-8; -*- -# -# Licensed to CRATE Technology GmbH ("Crate") under one or more contributor -# license agreements. See the NOTICE file distributed with this work for -# additional information regarding copyright ownership. Crate licenses -# this file to you under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. You may -# obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT -# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the -# License for the specific language governing permissions and limitations -# under the License. -# -# However, if you have executed another commercial license agreement -# with Crate these terms will supersede the license and you may use the -# software solely pursuant to the terms of the relevant commercial agreement. - -import os -import re - -from setuptools import find_namespace_packages, setup - - -def read(path): - with open(os.path.join(os.path.dirname(__file__), path)) as f: - return f.read() - - -long_description = read("README.rst") -versionf_content = read("src/crate/client/__init__.py") -version_rex = r'^__version__ = [\'"]([^\'"]*)[\'"]$' -m = re.search(version_rex, versionf_content, re.M) -if m: - version = m.group(1) -else: - raise RuntimeError("Unable to find version string") - -setup( - name="crate", - version=version, - url="https://github.com/crate/crate-python", - author="Crate.io", - author_email="office@crate.io", - description="CrateDB Python Client", - long_description=long_description, - long_description_content_type="text/x-rst", - platforms=["any"], - license="Apache License 2.0", - keywords="cratedb db api dbapi database sql http rdbms olap", - packages=find_namespace_packages("src"), - package_dir={"": "src"}, - install_requires=[ - "orjson<4", - "urllib3", - "verlib2", - ], - extras_require={ - "doc": [ - "crate-docs-theme>=0.26.5", - "sphinx>=3.5,<9", - ], - "test": [ - 'backports.zoneinfo<1; python_version<"3.9"', - "certifi", - "createcoverage>=1,<2", - "mypy<1.18", - "poethepoet<1", - "ruff<0.14", - "stopit>=1.1.2,<2", - "pytz", - "zc.customdoctests>=1.0.1,<2", - "zope.testing>=4,<6", - "zope.testrunner>=5,<8", - ], - }, - python_requires=">=3.6", - package_data={"": ["*.txt"]}, - classifiers=[ - "Development Status :: 5 - Production/Stable", - "Intended Audience :: Developers", - "License :: OSI Approved :: Apache Software License", - "Operating System :: OS Independent", - "Programming Language :: Python", - "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.6", - "Programming Language :: Python :: 3.7", - "Programming Language :: Python :: 3.8", - "Programming Language :: Python :: 3.9", - "Programming Language :: Python :: 3.10", - "Programming Language :: Python :: 3.11", - "Programming Language :: Python :: 3.12", - "Programming Language :: Python :: 3.13", - "Programming Language :: Python :: Implementation :: CPython", - "Programming Language :: Python :: Implementation :: PyPy", - "Topic :: Database", - ], -) diff --git a/src/crate/client/connection.py b/src/crate/client/connection.py index b0a2a15b..0638a018 100644 --- a/src/crate/client/connection.py +++ b/src/crate/client/connection.py @@ -208,7 +208,7 @@ def _lowest_server_version(self): return lowest or Version("0.0.0") def __repr__(self): - return "".format(repr(self.client)) + return f"<{self.__class__.__qualname__} {self.client!r}>" def __enter__(self): return self diff --git a/src/crate/client/cursor.py b/src/crate/client/cursor.py index 2a82d502..587f1491 100644 --- a/src/crate/client/cursor.py +++ b/src/crate/client/cursor.py @@ -236,7 +236,10 @@ def _convert_rows(self): # Process result rows with conversion. for row in self._result["rows"]: - yield [convert(value) for convert, value in zip(converters, row)] + yield [ + convert(value) + for convert, value in zip(converters, row, strict=False) + ] @property def time_zone(self): diff --git a/src/crate/client/exceptions.py b/src/crate/client/exceptions.py index 3833eecc..5e99126b 100644 --- a/src/crate/client/exceptions.py +++ b/src/crate/client/exceptions.py @@ -86,7 +86,7 @@ def __init__(self, table, digest): self.digest = digest def __str__(self): - return "{table}/{digest}".format(table=self.table, digest=self.digest) + return f"{self.__class__.__qualname__}('{self.table}/{self.digest})'" class DigestNotFoundException(BlobException): diff --git a/src/crate/client/http.py b/src/crate/client/http.py index a1251d34..b1d51f02 100644 --- a/src/crate/client/http.py +++ b/src/crate/client/http.py @@ -326,7 +326,10 @@ def _pool_kw_args( return kw -def _remove_certs_for_non_https(server, kwargs): +def _remove_certs_for_non_https(server: str, kwargs: dict) -> dict: + """ + Removes certificates for http requests. + """ if server.lower().startswith("https"): return kwargs used_ssl_args = SSL_ONLY_ARGS & set(kwargs.keys()) diff --git a/tests/client/test_connection.py b/tests/client/test_connection.py index 0cc5e1ef..90b121f2 100644 --- a/tests/client/test_connection.py +++ b/tests/client/test_connection.py @@ -1,107 +1,169 @@ import datetime -from unittest import TestCase +from unittest.mock import MagicMock, patch +import pytest from urllib3 import Timeout from crate.client import connect from crate.client.connection import Connection +from crate.client.exceptions import ProgrammingError from crate.client.http import Client from .settings import crate_host -class ConnectionTest(TestCase): - def test_connection_mock(self): - """ - For testing purposes it is often useful to replace the client used for - communication with the CrateDB server with a stub or mock. - - This can be done by passing an object of the Client class when calling - the `connect` method. - """ - - class MyConnectionClient: - active_servers = ["localhost:4200"] - - def __init__(self): - pass - - def server_infos(self, server): - return ("localhost:4200", "my server", "0.42.0") - - connection = connect([crate_host], client=MyConnectionClient()) - self.assertIsInstance(connection, Connection) - self.assertEqual( - connection.client.server_infos("foo"), - ("localhost:4200", "my server", "0.42.0"), - ) - - def test_lowest_server_version(self): - infos = [ - (None, None, "0.42.3"), - (None, None, "0.41.8"), - (None, None, "not a version"), - ] - - client = Client(servers="localhost:4200 localhost:4201 localhost:4202") - client.server_infos = lambda server: infos.pop() - connection = connect(client=client) - self.assertEqual((0, 41, 8), connection.lowest_server_version.version) - connection.close() - - def test_invalid_server_version(self): - client = Client(servers="localhost:4200") - client.server_infos = lambda server: (None, None, "No version") - connection = connect(client=client) - self.assertEqual((0, 0, 0), connection.lowest_server_version.version) - connection.close() - - def test_context_manager(self): - with connect("localhost:4200") as conn: +def test_lowest_server_version(): + """ + Verify the lowest server version is correctly set. + """ + servers = "localhost:4200 localhost:4201 localhost:4202 localhost:4207" + infos = [ + (None, None, "1.0.3"), + (None, None, "5.5.2"), + (None, None, "6.0.0"), + (None, None, "not a version"), + ] + + client = Client(servers=servers) + client.server_infos = lambda server: infos.pop() + connection = connect(client=client) + assert (1, 0, 3) == connection.lowest_server_version.version + + +def test_connection_closes_access(): + """ + Verify that a connection closes on exit and that it also closes + the client. + """ + with patch( + "crate.client.connection.Client", spec=Client, return_value=MagicMock() + ) as client: + conn = connect() + conn.close() + + assert conn._closed + client.assert_called_once() + + # Should raise an exception if + # we try to access a cursor now. + with pytest.raises(ProgrammingError): + conn.cursor() + + with pytest.raises(ProgrammingError): + conn.commit() + + +def test_connection_closes_context_manager(): + """Verify that the context manager of the client closes the connection""" + with patch.object(connect, "close", autospec=True) as close_fn: + with connect(): pass - self.assertEqual(conn._closed, True) - - def test_with_timezone(self): - """ - The cursor can return timezone-aware `datetime` objects when requested. - - When switching the time zone at runtime on the connection object, only - new cursor objects will inherit the new time zone. - """ - - tz_mst = datetime.timezone(datetime.timedelta(hours=7), name="MST") - connection = connect("localhost:4200", time_zone=tz_mst) - cursor = connection.cursor() - self.assertEqual(cursor.time_zone.tzname(None), "MST") - self.assertEqual( - cursor.time_zone.utcoffset(None), datetime.timedelta(seconds=25200) - ) - - connection.time_zone = datetime.timezone.utc - cursor = connection.cursor() - self.assertEqual(cursor.time_zone.tzname(None), "UTC") - self.assertEqual( - cursor.time_zone.utcoffset(None), datetime.timedelta(0) - ) - - def test_timeout_float(self): - """ - Verify setting the timeout value as a scalar (float) works. - """ - with connect("localhost:4200", timeout=2.42) as conn: - self.assertEqual(conn.client._pool_kw["timeout"], 2.42) - - def test_timeout_string(self): - """ - Verify setting the timeout value as a scalar (string) works. - """ - with connect("localhost:4200", timeout="2.42") as conn: - self.assertEqual(conn.client._pool_kw["timeout"], 2.42) - - def test_timeout_object(self): - """ - Verify setting the timeout value as a Timeout object works. - """ - timeout = Timeout(connect=2.42, read=0.01) - with connect("localhost:4200", timeout=timeout) as conn: - self.assertEqual(conn.client._pool_kw["timeout"], timeout) + close_fn.assert_called_once() + + +def test_invalid_server_version(): + """ + Verify that when no correct version is set, + the default (0, 0, 0) is returned. + """ + client = Client(servers="localhost:4200") + client.server_infos = lambda server: (None, None, "No version") + connection = connect(client=client) + assert (0, 0, 0) == connection.lowest_server_version.version + + +def test_context_manager(): + """ + Verify the context manager implementation of `Connection`. + """ + close_method = "crate.client.http.Client.close" + with patch(close_method, return_value=MagicMock()) as close_func: + with connect("localhost:4200") as conn: + assert not conn._closed + + assert conn._closed + # Checks that the close method of the client + # is called when the connection is closed. + close_func.assert_called_once() + + +def test_connection_mock(): + """ + Verify that a custom client can be passed. + + + For testing purposes, it is often useful to replace the client used for + communication with the CrateDB server with a stub or mock. + + This can be done by passing an object of the Client class when calling + the `connect` method. + """ + + mock = MagicMock(spec=Client) + mock.server_infos.return_value = "localhost:4200", "my server", "0.42.0" + connection = connect(crate_host, client=mock) + + assert isinstance(connection, Connection) + assert connection.client.server_infos("foo") == ( + "localhost:4200", + "my server", + "0.42.0", + ) + + +def test_default_repr(): + """ + Verify default repr dunder method. + """ + conn = connect() + assert repr(conn) == ">" + + +def test_with_timezone(): + """ + Verify the logic of passing timezone objects to the client. + + The cursor can return timezone-aware `datetime` objects when requested. + + When switching the time zone at runtime on the connection object, only + new cursor objects will inherit the new time zone. + + These tests are complementary to timezone `test_cursor` + """ + + tz_mst = datetime.timezone(datetime.timedelta(hours=7), name="MST") + connection = connect("localhost:4200", time_zone=tz_mst) + cursor = connection.cursor() + + assert cursor.time_zone.tzname(None) == "MST" + assert cursor.time_zone.utcoffset(None) == datetime.timedelta(seconds=25200) + + connection.time_zone = datetime.timezone.utc + cursor = connection.cursor() + assert cursor.time_zone.tzname(None) == "UTC" + assert cursor.time_zone.utcoffset(None) == datetime.timedelta(0) + + +def test_timeout_float(): + """ + Verify setting the timeout value as a scalar (float) works. + """ + with connect("localhost:4200", timeout=2.42) as conn: + assert conn.client._pool_kw["timeout"] == 2.42 + + +def test_timeout_string(): + """ + Verify setting the timeout value as a scalar (string) works. + """ + with connect("localhost:4200", timeout="2.42") as conn: + assert conn.client._pool_kw["timeout"] == 2.42 + + +def test_timeout_object(): + """ + Verify setting the timeout value as a Timeout object works. + """ + timeout = Timeout(connect=2.42, read=0.01) + with connect("localhost:4200", timeout=timeout) as conn: + assert conn.client._pool_kw["timeout"] == timeout diff --git a/tests/client/test_cursor.py b/tests/client/test_cursor.py index 7f1a9f2f..411e29b1 100644 --- a/tests/client/test_cursor.py +++ b/tests/client/test_cursor.py @@ -21,8 +21,11 @@ import datetime from ipaddress import IPv4Address -from unittest import TestCase -from unittest.mock import MagicMock +from unittest import mock + +import pytest + +from crate.client.exceptions import ProgrammingError try: import zoneinfo @@ -33,416 +36,444 @@ from crate.client import connect from crate.client.converter import DataType, DefaultTypeConverter -from crate.client.http import Client -from crate.testing.util import ClientMocked - - -class CursorTest(TestCase): - @staticmethod - def get_mocked_connection(): - client = MagicMock(spec=Client) - return connect(client=client) - - def test_create_with_timezone_as_datetime_object(self): - """ - The cursor can return timezone-aware `datetime` objects when requested. - Switching the time zone at runtime on the cursor object is possible. - Here: Use a `datetime.timezone` instance. - """ - - connection = self.get_mocked_connection() - - tz_mst = datetime.timezone(datetime.timedelta(hours=7), name="MST") - cursor = connection.cursor(time_zone=tz_mst) - - self.assertEqual(cursor.time_zone.tzname(None), "MST") - self.assertEqual( - cursor.time_zone.utcoffset(None), datetime.timedelta(seconds=25200) - ) - - cursor.time_zone = datetime.timezone.utc - self.assertEqual(cursor.time_zone.tzname(None), "UTC") - self.assertEqual( - cursor.time_zone.utcoffset(None), datetime.timedelta(0) - ) - - def test_create_with_timezone_as_pytz_object(self): - """ - The cursor can return timezone-aware `datetime` objects when requested. - Here: Use a `pytz.timezone` instance. - """ - connection = self.get_mocked_connection() - cursor = connection.cursor(time_zone=pytz.timezone("Australia/Sydney")) - self.assertEqual(cursor.time_zone.tzname(None), "Australia/Sydney") - - # Apparently, when using `pytz`, the timezone object does not return - # an offset. Nevertheless, it works, as demonstrated per doctest in - # `cursor.txt`. - self.assertEqual(cursor.time_zone.utcoffset(None), None) - - def test_create_with_timezone_as_zoneinfo_object(self): - """ - The cursor can return timezone-aware `datetime` objects when requested. - Here: Use a `zoneinfo.ZoneInfo` instance. - """ - connection = self.get_mocked_connection() - cursor = connection.cursor( - time_zone=zoneinfo.ZoneInfo("Australia/Sydney") - ) - self.assertEqual(cursor.time_zone.key, "Australia/Sydney") - - def test_create_with_timezone_as_utc_offset_success(self): - """ - The cursor can return timezone-aware `datetime` objects when requested. - Here: Use a UTC offset in string format. - """ - connection = self.get_mocked_connection() - cursor = connection.cursor(time_zone="+0530") - self.assertEqual(cursor.time_zone.tzname(None), "+0530") - self.assertEqual( - cursor.time_zone.utcoffset(None), datetime.timedelta(seconds=19800) - ) - - connection = self.get_mocked_connection() - cursor = connection.cursor(time_zone="-1145") - self.assertEqual(cursor.time_zone.tzname(None), "-1145") - self.assertEqual( - cursor.time_zone.utcoffset(None), - datetime.timedelta(days=-1, seconds=44100), - ) - - def test_create_with_timezone_as_utc_offset_failure(self): - """ - Verify the cursor trips when trying to use invalid UTC offset strings. - """ - connection = self.get_mocked_connection() - with self.assertRaises(ValueError) as ex: - connection.cursor(time_zone="foobar") - self.assertEqual( - str(ex.exception), - "Time zone 'foobar' is given in invalid UTC offset format", - ) - - connection = self.get_mocked_connection() - with self.assertRaises(ValueError) as ex: - connection.cursor(time_zone="+abcd") - self.assertEqual( - str(ex.exception), - "Time zone '+abcd' is given in invalid UTC offset format: " - "invalid literal for int() with base 10: '+ab'", - ) - - def test_create_with_timezone_connection_cursor_precedence(self): - """ - Verify that the time zone specified on the cursor object instance - takes precedence over the one specified on the connection instance. - """ - client = MagicMock(spec=Client) - connection = connect( - client=client, time_zone=pytz.timezone("Australia/Sydney") - ) - cursor = connection.cursor(time_zone="+0530") - self.assertEqual(cursor.time_zone.tzname(None), "+0530") - self.assertEqual( - cursor.time_zone.utcoffset(None), datetime.timedelta(seconds=19800) - ) - - def test_execute_with_args(self): - client = MagicMock(spec=Client) - conn = connect(client=client) - c = conn.cursor() - statement = "select * from locations where position = ?" - c.execute(statement, 1) - client.sql.assert_called_once_with(statement, 1, None) - conn.close() - - def test_execute_with_bulk_args(self): - client = MagicMock(spec=Client) - conn = connect(client=client) - c = conn.cursor() - statement = "select * from locations where position = ?" - c.execute(statement, bulk_parameters=[[1]]) - client.sql.assert_called_once_with(statement, None, [[1]]) - conn.close() - - def test_execute_with_converter(self): - client = ClientMocked() - conn = connect(client=client) - - # Use the set of data type converters from `DefaultTypeConverter` - # and add another custom converter. - converter = DefaultTypeConverter( - { - DataType.BIT: lambda value: value is not None - and int(value[2:-1], 2) - or None - } - ) - - # Create a `Cursor` object with converter. - c = conn.cursor(converter=converter) - - # Make up a response using CrateDB data types `TEXT`, `IP`, - # `TIMESTAMP`, `BIT`. - conn.client.set_next_response( - { - "col_types": [4, 5, 11, 25], - "cols": ["name", "address", "timestamp", "bitmask"], - "rows": [ - ["foo", "10.10.10.1", 1658167836758, "B'0110'"], - [None, None, None, None], - ], - "rowcount": 1, - "duration": 123, - } - ) - - c.execute("") - result = c.fetchall() - self.assertEqual( - result, - [ - [ - "foo", - IPv4Address("10.10.10.1"), - datetime.datetime( - 2022, - 7, - 18, - 18, - 10, - 36, - 758000, - tzinfo=datetime.timezone.utc, - ), - 6, - ], - [None, None, None, None], - ], - ) - - conn.close() - - def test_execute_with_converter_and_invalid_data_type(self): - client = ClientMocked() - conn = connect(client=client) - converter = DefaultTypeConverter() - # Create a `Cursor` object with converter. - c = conn.cursor(converter=converter) - # Make up a response using CrateDB data types `TEXT`, `IP`, - # `TIMESTAMP`, `BIT`. - conn.client.set_next_response( - { - "col_types": [999], - "cols": ["foo"], - "rows": [ - ["n/a"], - ], - "rowcount": 1, - "duration": 123, - } +def test_cursor_fetch(mocked_connection): + """Verify fetchone/fetchmany behaviour""" + cursor = mocked_connection.cursor() + response = { + "col_types": [4, 5], + "cols": ["name", "address"], + "rows": [["foo", "10.10.10.1"], ["bar", "10.10.10.2"]], + "rowcount": 2, + "duration": 123, + } + with mock.patch.object( + mocked_connection.client, "sql", return_value=response + ): + cursor.execute("") + assert cursor.fetchone() == ["foo", "10.10.10.1"] + assert cursor.fetchmany() == [ + ["bar", "10.10.10.2"], + ] + + +def test_cursor_executemany(mocked_connection): + """ + Verify executemany. + """ + response = { + "col_types": [], + "cols": [], + "duration": 123, + "results": [{"rowcount": 1, "rowcount:": 1}], + } + with mock.patch.object( + mocked_connection.client, "sql", return_value=response + ): + cursor = mocked_connection.cursor() + result = cursor.executemany("some sql", ()) + + assert isinstance(result, list) + assert response["results"] == result + + +def test_create_with_timezone_as_datetime_object(mocked_connection): + """ + The cursor can return timezone-aware `datetime` objects when requested. + Switching the time zone at runtime on the cursor object is possible. + Here: Use a `datetime.timezone` instance. + """ + tz_mst = datetime.timezone(datetime.timedelta(hours=7), name="MST") + cursor = mocked_connection.cursor(time_zone=tz_mst) + + assert cursor.time_zone.tzname(None) == "MST" + assert cursor.time_zone.utcoffset(None) == datetime.timedelta(seconds=25200) + + cursor.time_zone = datetime.timezone.utc + + assert cursor.time_zone.tzname(None) == "UTC" + assert cursor.time_zone.utcoffset(None) == datetime.timedelta(0) + + +def test_create_with_timezone_as_pytz_object(mocked_connection): + """ + The cursor can return timezone-aware `datetime` objects when requested. + Here: Use a `pytz.timezone` instance. + """ + + cursor = mocked_connection.cursor( + time_zone=pytz.timezone("Australia/Sydney") + ) + assert cursor.time_zone.tzname(None) == "Australia/Sydney" + + # Apparently, when using `pytz`, the timezone object does not return + # an offset. Nevertheless, it works, as demonstrated per doctest in + # `cursor.txt`. + assert cursor.time_zone.utcoffset(None) is None + + +def test_create_with_timezone_as_zoneinfo_object(mocked_connection): + """ + The cursor can return timezone-aware `datetime` objects when requested. + Here: Use a `zoneinfo.ZoneInfo` instance. + """ + cursor = mocked_connection.cursor( + time_zone=zoneinfo.ZoneInfo("Australia/Sydney") + ) + assert cursor.time_zone.key == "Australia/Sydney" + + +def test_create_with_timezone_as_utc_offset_success(mocked_connection): + """ + Verify the cursor can return timezone-aware `datetime` objects when + requested. + + Here: Use a UTC offset in string format. + """ + + cursor = mocked_connection.cursor(time_zone="+0530") + assert cursor.time_zone.tzname(None) == "+0530" + assert cursor.time_zone.utcoffset(None) == datetime.timedelta(seconds=19800) + + cursor = mocked_connection.cursor(time_zone="-1145") + assert cursor.time_zone.tzname(None) == "-1145" + assert cursor.time_zone.utcoffset(None) == datetime.timedelta( + days=-1, seconds=44100 + ) + + +def test_create_with_timezone_as_utc_offset_failure(mocked_connection): + """ + Verify the cursor trips when trying to use invalid UTC offset strings. + """ + + with pytest.raises(ValueError) as err: + mocked_connection.cursor(time_zone="foobar") + assert err == "Time zone 'foobar' is given in invalid UTC offset format" + + with pytest.raises(ValueError) as err: + mocked_connection.cursor(time_zone="+abcd") + assert ( + err + == "Time zone '+abcd' is given in invalid UTC offset format: " + + "invalid literal for int() with base 10: '+ab'" ) - c.execute("") - with self.assertRaises(ValueError) as ex: - c.fetchone() - self.assertEqual(ex.exception.args, ("999 is not a valid DataType",)) - - def test_execute_array_with_converter(self): - client = ClientMocked() - conn = connect(client=client) - converter = DefaultTypeConverter() - cursor = conn.cursor(converter=converter) - - conn.client.set_next_response( - { - "col_types": [4, [100, 5]], - "cols": ["name", "address"], - "rows": [["foo", ["10.10.10.1", "10.10.10.2"]]], - "rowcount": 1, - "duration": 123, - } - ) +def test_create_with_timezone_connection_cursor_precedence(mocked_connection): + """ + Verify that the time zone specified on the cursor object instance + takes precedence over the one specified on the connection instance. + """ + connection = connect( + client=mocked_connection.client, + time_zone=pytz.timezone("Australia/Sydney"), + ) + cursor = connection.cursor(time_zone="+0530") + assert cursor.time_zone.tzname(None) == "+0530" + assert cursor.time_zone.utcoffset(None) == datetime.timedelta(seconds=19800) + + +def test_execute_with_args(mocked_connection): + """ + Verify that `cursor.execute` is called with the right parameters. + """ + cursor = mocked_connection.cursor() + statement = "select * from locations where position = ?" + cursor.execute(statement, 1) + mocked_connection.client.sql.assert_called_once_with(statement, 1, None) + + +def test_execute_with_bulk_args(mocked_connection): + """ + Verify that `cursor.execute` is called with the right parameters + when passing `bulk_parameters`. + """ + cursor = mocked_connection.cursor() + statement = "select * from locations where position = ?" + cursor.execute(statement, bulk_parameters=[[1]]) + mocked_connection.client.sql.assert_called_once_with(statement, None, [[1]]) + + +def test_execute_custom_converter(mocked_connection): + """ + Verify that a custom converter is correctly applied when passed to a cursor. + """ + # Extends the DefaultTypeConverter + converter = DefaultTypeConverter( + { + DataType.BIT: lambda value: value is not None + and int(value[2:-1], 2) + or None + } + ) + cursor = mocked_connection.cursor(converter=converter) + response = { + "col_types": [4, 5, 11, 25], + "cols": ["name", "address", "timestamp", "bitmask"], + "rows": [ + ["foo", "10.10.10.1", 1658167836758, "B'0110'"], + [None, None, None, None], + ], + "rowcount": 1, + "duration": 123, + } + + with mock.patch.object( + mocked_connection.client, "sql", return_value=response + ): cursor.execute("") - result = cursor.fetchone() - self.assertEqual( - result, + result = cursor.fetchall() + + assert result == [ [ "foo", - [IPv4Address("10.10.10.1"), IPv4Address("10.10.10.2")], + IPv4Address("10.10.10.1"), + datetime.datetime( + 2022, + 7, + 18, + 18, + 10, + 36, + 758000, + tzinfo=datetime.timezone.utc, + ), + 6, ], - ) - - def test_execute_array_with_converter_and_invalid_collection_type(self): - client = ClientMocked() - conn = connect(client=client) - converter = DefaultTypeConverter() - cursor = conn.cursor(converter=converter) - - # Converting collections only works for `ARRAY`s. (ID=100). - # When using `DOUBLE` (ID=6), it should croak. - conn.client.set_next_response( - { - "col_types": [4, [6, 5]], - "cols": ["name", "address"], - "rows": [["foo", ["10.10.10.1", "10.10.10.2"]]], - "rowcount": 1, - "duration": 123, - } - ) - + [None, None, None, None], + ] + + +def test_execute_with_converter_and_invalid_data_type(mocked_connection): + converter = DefaultTypeConverter() + + # Create a `Cursor` object with converter. + cursor = mocked_connection.cursor(converter=converter) + + response = { + "col_types": [999], + "cols": ["foo"], + "rows": [ + ["n/a"], + ], + "rowcount": 1, + "duration": 123, + } + with mock.patch.object( + mocked_connection.client, "sql", return_value=response + ): cursor.execute("") - - with self.assertRaises(ValueError) as ex: + with pytest.raises(ValueError) as e: cursor.fetchone() - self.assertEqual( - ex.exception.args, - ("Data type 6 is not implemented as collection type",), - ) - - def test_execute_nested_array_with_converter(self): - client = ClientMocked() - conn = connect(client=client) - converter = DefaultTypeConverter() - cursor = conn.cursor(converter=converter) - - conn.client.set_next_response( - { - "col_types": [4, [100, [100, 5]]], - "cols": ["name", "address_buckets"], - "rows": [ - [ - "foo", - [ - ["10.10.10.1", "10.10.10.2"], - ["10.10.10.3"], - [], - None, - ], - ] - ], - "rowcount": 1, - "duration": 123, - } - ) - + assert e.exception.args == "999 is not a valid DataType" + + +def test_execute_array_with_converter(mocked_connection): + converter = DefaultTypeConverter() + cursor = mocked_connection.cursor(converter=converter) + response = { + "col_types": [4, [100, 5]], + "cols": ["name", "address"], + "rows": [["foo", ["10.10.10.1", "10.10.10.2"]]], + "rowcount": 1, + "duration": 123, + } + with mock.patch.object( + mocked_connection.client, "sql", return_value=response + ): cursor.execute("") result = cursor.fetchone() - self.assertEqual( - result, + + assert result == [ + "foo", + [IPv4Address("10.10.10.1"), IPv4Address("10.10.10.2")], + ] + + +def test_execute_array_with_converter_invalid(mocked_connection): + converter = DefaultTypeConverter() + cursor = mocked_connection.cursor(converter=converter) + response = { + "col_types": [4, [6, 5]], + "cols": ["name", "address"], + "rows": [["foo", ["10.10.10.1", "10.10.10.2"]]], + "rowcount": 1, + "duration": 123, + } + # Converting collections only works for `ARRAY`s. (ID=100). + # When using `DOUBLE` (ID=6), it should raise an Exception. + with mock.patch.object( + mocked_connection.client, "sql", return_value=response + ): + cursor.execute("") + with pytest.raises(ValueError) as e: + cursor.fetchone() + assert e.exception.args == ( + "Data type 6 is not implemented as collection type" + ) + + +def test_execute_nested_array_with_converter(mocked_connection): + converter = DefaultTypeConverter() + cursor = mocked_connection.cursor(converter=converter) + response = { + "col_types": [4, [100, [100, 5]]], + "cols": ["name", "address_buckets"], + "rows": [ [ "foo", [ - [IPv4Address("10.10.10.1"), IPv4Address("10.10.10.2")], - [IPv4Address("10.10.10.3")], + ["10.10.10.1", "10.10.10.2"], + ["10.10.10.3"], [], None, ], + ] + ], + "rowcount": 1, + "duration": 123, + } + + with mock.patch.object( + mocked_connection.client, "sql", return_value=response + ): + cursor.execute("") + result = cursor.fetchone() + assert result == [ + "foo", + [ + [IPv4Address("10.10.10.1"), IPv4Address("10.10.10.2")], + [IPv4Address("10.10.10.3")], + [], + None, ], - ) - - def test_executemany_with_converter(self): - client = ClientMocked() - conn = connect(client=client) - converter = DefaultTypeConverter() - cursor = conn.cursor(converter=converter) - - conn.client.set_next_response( - { - "col_types": [4, 5], - "cols": ["name", "address"], - "rows": [["foo", "10.10.10.1"]], - "rowcount": 1, - "duration": 123, - } - ) - + ] + + +def test_executemany_with_converter(mocked_connection): + converter = DefaultTypeConverter() + cursor = mocked_connection.cursor(converter=converter) + response = { + "col_types": [4, 5], + "cols": ["name", "address"], + "rows": [["foo", "10.10.10.1"]], + "rowcount": 1, + "duration": 123, + } + with mock.patch.object( + mocked_connection.client, "sql", return_value=response + ): cursor.executemany("", []) result = cursor.fetchall() # ``executemany()`` is not intended to be used with statements # returning result sets. The result will always be empty. - self.assertEqual(result, []) - - def test_execute_with_timezone(self): - client = ClientMocked() - conn = connect(client=client) - - # Create a `Cursor` object with `time_zone`. - tz_mst = datetime.timezone(datetime.timedelta(hours=7), name="MST") - c = conn.cursor(time_zone=tz_mst) - - # Make up a response using CrateDB data type `TIMESTAMP`. - conn.client.set_next_response( - { - "col_types": [4, 11], - "cols": ["name", "timestamp"], - "rows": [ - ["foo", 1658167836758], - [None, None], - ], - } - ) - + assert result == [] + + +def test_execute_with_timezone(mocked_connection): + # Create a `Cursor` object with `time_zone`. + tz_mst = datetime.timezone(datetime.timedelta(hours=7), name="MST") + cursor = mocked_connection.cursor(time_zone=tz_mst) + + # Make up a response using CrateDB data type `TIMESTAMP`. + response = { + "col_types": [4, 11], + "cols": ["name", "timestamp"], + "rows": [ + ["foo", 1658167836758], + [None, None], + ], + } + with mock.patch.object( + mocked_connection.client, "sql", return_value=response + ): # Run execution and verify the returned `datetime` object is # timezone-aware, using the designated timezone object. - c.execute("") - result = c.fetchall() - self.assertEqual( - result, + cursor.execute("") + result = cursor.fetchall() + assert result == [ [ - [ - "foo", - datetime.datetime( - 2022, - 7, - 19, - 1, - 10, - 36, - 758000, - tzinfo=datetime.timezone( - datetime.timedelta(seconds=25200), "MST" - ), + "foo", + datetime.datetime( + 2022, + 7, + 19, + 1, + 10, + 36, + 758000, + tzinfo=datetime.timezone( + datetime.timedelta(seconds=25200), "MST" ), - ], - [ - None, - None, - ], + ), ], - ) - self.assertEqual(result[0][1].tzname(), "MST") + [ + None, + None, + ], + ] + + assert result[0][1].tzname() == "MST" # Change timezone and verify the returned `datetime` object is using it. - c.time_zone = datetime.timezone.utc - c.execute("") - result = c.fetchall() - self.assertEqual( - result, + cursor.time_zone = datetime.timezone.utc + cursor.execute("") + result = cursor.fetchall() + assert result == [ [ - [ - "foo", - datetime.datetime( - 2022, - 7, - 18, - 18, - 10, - 36, - 758000, - tzinfo=datetime.timezone.utc, - ), - ], - [ - None, - None, - ], + "foo", + datetime.datetime( + 2022, + 7, + 18, + 18, + 10, + 36, + 758000, + tzinfo=datetime.timezone.utc, + ), ], - ) - self.assertEqual(result[0][1].tzname(), "UTC") + [ + None, + None, + ], + ] + + assert result[0][1].tzname() == "UTC" + + +def test_cursor_close(mocked_connection): + """ + Verify that a cursor is not closed if not specifically closed. + """ + + cursor = mocked_connection.cursor() + cursor.execute("") + assert cursor._closed is False + + cursor.close() + + assert cursor._closed is True + assert not cursor._result + assert cursor.duration == -1 + + with pytest.raises(ProgrammingError, match="Connection closed"): + mocked_connection.close() + cursor.execute("") + + +def test_cursor_closes_access(mocked_connection): + """ + Verify that a cursor cannot be used once it is closed. + """ + + cursor = mocked_connection.cursor() + cursor.execute("") + + cursor.close() - conn.close() + with pytest.raises(ProgrammingError): + cursor.execute("s") diff --git a/tests/client/test_exceptions.py b/tests/client/test_exceptions.py index cb91e1a9..c2ed0976 100644 --- a/tests/client/test_exceptions.py +++ b/tests/client/test_exceptions.py @@ -1,13 +1,17 @@ -import unittest - from crate.client import Error +from crate.client.exceptions import BlobException + + +def test_error_with_msg(): + err = Error("foo") + assert str(err) == "foo" + +def test_error_with_error_trace(): + err = Error("foo", error_trace="### TRACE ###") + assert str(err), "foo\n### TRACE ###" -class ErrorTestCase(unittest.TestCase): - def test_error_with_msg(self): - err = Error("foo") - self.assertEqual(str(err), "foo") - def test_error_with_error_trace(self): - err = Error("foo", error_trace="### TRACE ###") - self.assertEqual(str(err), "foo\n### TRACE ###") +def test_blob_exception(): + err = BlobException(table="sometable", digest="somedigest") + assert str(err) == "BlobException('sometable/somedigest)'" diff --git a/tests/client/test_http.py b/tests/client/test_http.py index c4c0609e..946cdc00 100644 --- a/tests/client/test_http.py +++ b/tests/client/test_http.py @@ -19,29 +19,23 @@ # with Crate these terms will supersede the license and you may use the # software solely pursuant to the terms of the relevant commercial agreement. -import datetime as dt import json -import multiprocessing import os import queue import random import socket -import sys import time -import traceback -import uuid from base64 import b64decode -from decimal import Decimal -from http.server import BaseHTTPRequestHandler, HTTPServer -from multiprocessing.context import ForkProcess +from http.server import BaseHTTPRequestHandler from threading import Event, Thread -from unittest import TestCase from unittest.mock import MagicMock, patch from urllib.parse import parse_qs, urlparse import certifi +import pytest import urllib3.exceptions +from crate.client.connection import connect from crate.client.exceptions import ( ConnectionError, IntegrityError, @@ -51,57 +45,18 @@ Client, _get_socket_opts, _remove_certs_for_non_https, - json_dumps, ) +from tests.conftest import REQUEST_PATH, fake_response -REQUEST = "crate.client.http.Server.request" -CA_CERT_PATH = certifi.where() +mocked_request = MagicMock(spec=urllib3.response.HTTPResponse) -def fake_request(response=None): - def request(*args, **kwargs): - if isinstance(response, list): - resp = response.pop(0) - response.append(resp) - return resp - elif response: - return response - else: - return MagicMock(spec=urllib3.response.HTTPResponse) - - return request - - -def fake_response(status, reason=None, content_type="application/json"): - m = MagicMock(spec=urllib3.response.HTTPResponse) - m.status = status - m.reason = reason or "" - m.headers = {"content-type": content_type} - return m - - -def fake_redirect(location): +def fake_redirect(location: str) -> MagicMock: m = fake_response(307) m.get_redirect_location.return_value = location return m -def bad_bulk_response(): - r = fake_response(400, "Bad Request") - r.data = json.dumps( - { - "results": [ - {"rowcount": 1}, - {"error_message": "an error occured"}, - {"error_message": "another error"}, - {"error_message": ""}, - {"error_message": None}, - ] - } - ).encode() - return r - - def duplicate_key_exception(): r = fake_response(409, "Conflict") r.data = json.dumps( @@ -116,270 +71,227 @@ def duplicate_key_exception(): return r -def fail_sometimes(*args, **kwargs): - if random.randint(1, 100) % 10 == 0: +def fail_sometimes(*args, **kwargs) -> MagicMock: + """ + Function that fails with a 50% chance. It either returns a successful mocked + response or raises an urllib3 exception. + """ + if random.randint(1, 10) % 2: raise urllib3.exceptions.MaxRetryError(None, "/_sql", "") return fake_response(200) -class HttpClientTest(TestCase): - @patch( - REQUEST, - fake_request( - [ - fake_response(200), - fake_response(104, "Connection reset by peer"), - fake_response(503, "Service Unavailable"), - ] - ), +def test_connection_reset_exception(): + """ + Verify that a HTTP 503 status code response raises an exception. + """ + + expected_exception_msg = ( + "No more Servers available, exception" + " from last server: Service Unavailable" ) - def test_connection_reset_exception(self): + with patch( + REQUEST_PATH, + side_effect=[ + fake_response(200), + fake_response(104, "Connection reset by peer"), + fake_response(503, "Service Unavailable"), + ], + ): client = Client(servers="localhost:4200") - client.sql("select 1") - client.sql("select 2") - self.assertEqual( - ["http://localhost:4200"], list(client._active_servers) - ) - try: - client.sql("select 3") - except ProgrammingError: - self.assertEqual([], list(client._active_servers)) - else: - self.assertTrue(False) - finally: - client.close() + client.sql("select 1") # 200 response + client.sql("select 2") # 104 response + assert list(client._active_servers) == ["http://localhost:4200"] - def test_no_connection_exception(self): - client = Client(servers="localhost:9999") - self.assertRaises(ConnectionError, client.sql, "select foo") - client.close() + with pytest.raises(ProgrammingError, match=expected_exception_msg): + client.sql("select 3") # 503 response + assert not client._active_servers - @patch(REQUEST) - def test_http_error_is_re_raised(self, request): - request.side_effect = Exception - client = Client() - self.assertRaises(ProgrammingError, client.sql, "select foo") - client.close() +def test_no_connection_exception(): + """ + Verify that when no connection can be made to the server, + a `ConnectionError` is raised. + """ + client = Client(servers="localhost:9999") + with pytest.raises(ConnectionError): + client.sql("") - @patch(REQUEST) - def test_programming_error_contains_http_error_response_content( - self, request - ): - request.side_effect = Exception("this shouldn't be raised") +def test_http_error_is_re_raised(): + """ + Verify that when calling `REQUEST` if any error occurs, + a `ProgrammingError` exception is raised _from_ that exception. + """ + client = Client() + + exception_msg = "some exception did happen" + with patch(REQUEST_PATH, side_effect=Exception(exception_msg)): + with pytest.raises(ProgrammingError, match=exception_msg): + client.sql("select foo") + + +def test_programming_error_contains_http_error_response_content(): + """ + Verify that when calling `REQUEST` if any error occurs, + the raised `ProgrammingError` exception + contains the error message from the original error. + """ + expected_msg = "this message should appear" + with patch(REQUEST_PATH, side_effect=Exception(expected_msg)): client = Client() - try: + with pytest.raises(ProgrammingError, match=expected_msg): client.sql("select 1") - except ProgrammingError as e: - self.assertEqual("this shouldn't be raised", e.message) - else: - self.assertTrue(False) - finally: - client.close() - - @patch( - REQUEST, - fake_request( - [fake_response(200), fake_response(503, "Service Unavailable")] - ), - ) - def test_server_error_50x(self): - client = Client(servers="localhost:4200 localhost:4201") - client.sql("select 1") - client.sql("select 2") - try: - client.sql("select 3") - except ProgrammingError as e: - self.assertEqual( - "No more Servers available, " - + "exception from last server: Service Unavailable", - e.message, - ) - self.assertEqual([], list(client._active_servers)) - else: - self.assertTrue(False) - finally: - client.close() - def test_connect(self): - client = Client(servers="localhost:4200 localhost:4201") - self.assertEqual( - client._active_servers, - ["http://localhost:4200", "http://localhost:4201"], - ) - client.close() +def test_connect(): + """ + Verify the correctness of `server` parameter when `Client` is instantiated. + """ + client = Client(servers="localhost:4200 localhost:4201") + assert client._active_servers == [ + "http://localhost:4200", + "http://localhost:4201", + ] + + # By default, it's http://127.0.0.1:4200 + client = Client(servers=None) + assert client._active_servers == ["http://127.0.0.1:4200"] + + with pytest.raises(TypeError, match="expected string or bytes"): + Client(servers=[123, "127.0.0.1:4201", False]) + + +def test_redirect_handling(): + """ + Verify that when a redirect happens, that redirect uri + gets added to the server pool. + """ + with patch( + REQUEST_PATH, return_value=fake_redirect("http://localhost:4201") + ): client = Client(servers="localhost:4200") - self.assertEqual(client._active_servers, ["http://localhost:4200"]) - client.close() - - client = Client(servers=["localhost:4200"]) - self.assertEqual(client._active_servers, ["http://localhost:4200"]) - client.close() - - client = Client(servers=["localhost:4200", "127.0.0.1:4201"]) - self.assertEqual( - client._active_servers, - ["http://localhost:4200", "http://127.0.0.1:4201"], - ) - client.close() - - @patch(REQUEST, fake_request(fake_redirect("http://localhost:4201"))) - def test_redirect_handling(self): - client = Client(servers="localhost:4200") - try: - client.blob_get("blobs", "fake_digest") - except ProgrammingError: + + # Don't try to print the exception or use `match`, otherwise + # the recursion will not be short-circuited and it will hang. + with pytest.raises(ProgrammingError): # 4201 gets added to serverpool but isn't available # that's why we run into an infinite recursion # exception message is: maximum recursion depth exceeded - pass - self.assertEqual( - ["http://localhost:4200", "http://localhost:4201"], - sorted(client.server_pool.keys()), - ) - # the new non-https server must not contain any SSL only arguments - # regression test for github issue #179/#180 - self.assertEqual( - {"socket_options": _get_socket_opts(keepalive=True)}, - client.server_pool["http://localhost:4201"].pool.conn_kw, - ) - client.close() - - @patch(REQUEST) - def test_server_infos(self, request): - request.side_effect = urllib3.exceptions.MaxRetryError( - None, "/", "this shouldn't be raised" - ) - client = Client(servers="localhost:4200 localhost:4201") - self.assertRaises( - ConnectionError, client.server_infos, "http://localhost:4200" - ) - client.close() + client.blob_get("blobs", "fake_digest") - @patch(REQUEST, fake_request(fake_response(503))) - def test_server_infos_503(self): - client = Client(servers="localhost:4200") - self.assertRaises( - ConnectionError, client.server_infos, "http://localhost:4200" - ) - client.close() + assert sorted(client.server_pool.keys()) == [ + "http://localhost:4200", + "http://localhost:4201", + ] - @patch( - REQUEST, fake_request(fake_response(401, "Unauthorized", "text/html")) - ) - def test_server_infos_401(self): - client = Client(servers="localhost:4200") - try: - client.server_infos("http://localhost:4200") - except ProgrammingError as e: - self.assertEqual("401 Client Error: Unauthorized", e.message) - else: - self.assertTrue(False, msg="Exception should have been raised") - finally: - client.close() + # the new non-https server must not contain any SSL only arguments + # regression test for: + # - https://github.com/crate/crate-python/issues/179 + # - https://github.com/crate/crate-python/issues/180 - @patch(REQUEST, fake_request(bad_bulk_response())) - def test_bad_bulk_400(self): - client = Client(servers="localhost:4200") - try: - client.sql( - "Insert into users (name) values(?)", - bulk_parameters=[["douglas"], ["monthy"]], - ) - except ProgrammingError as e: - self.assertEqual("an error occured\nanother error", e.message) - else: - self.assertTrue(False, msg="Exception should have been raised") - finally: - client.close() + assert client.server_pool["http://localhost:4201"].pool.conn_kw == { + "socket_options": _get_socket_opts(keepalive=True) + } - @patch(REQUEST, autospec=True) - def test_decimal_serialization(self, request): - client = Client(servers="localhost:4200") - request.return_value = fake_response(200) - dec = Decimal(0.12) - client.sql("insert into users (float_col) values (?)", (dec,)) +def test_server_infos(): + """ + Verify that when a `MaxRetryError` is raised, a `ConnectionError` is raised. + """ + error = urllib3.exceptions.MaxRetryError(None, "/") + with patch(REQUEST_PATH, side_effect=error): + client = Client(servers="localhost:4200 localhost:4201") + with pytest.raises(ConnectionError): + client.server_infos("http://localhost:4200") - data = json.loads(request.call_args[1]["data"]) - self.assertEqual(data["args"], [str(dec)]) - client.close() - @patch(REQUEST, autospec=True) - def test_datetime_is_converted_to_ts(self, request): +def test_server_infos_401(): + """ + Verify that when a 401 status code is returned, a `ProgrammingError` + is raised. + """ + response = fake_response(401, "Unauthorized", "text/html") + with patch(REQUEST_PATH, return_value=response): client = Client(servers="localhost:4200") - request.return_value = fake_response(200) + with pytest.raises( + ProgrammingError, match="401 Client Error: Unauthorized" + ): + client.server_infos("http://localhost:4200") + - datetime = dt.datetime(2015, 2, 28, 7, 31, 40) - client.sql("insert into users (dt) values (?)", (datetime,)) +def test_bad_bulk_400(): + """ + Verify that a 400 response when doing a bulk request raises + a `ProgrammingException` with the error message of the response object's + key `error_message`, several error messages can be returned by the database. + """ + response = fake_response(400, "Bad Request") + response.data = json.dumps( + { + "results": [ + {"rowcount": 1}, + {"error_message": "an error occurred"}, + {"error_message": "another error"}, + {"error_message": ""}, + {"error_message": None}, + ] + } + ).encode() - # convert string to dict - # because the order of the keys isn't deterministic - data = json.loads(request.call_args[1]["data"]) - self.assertEqual(data["args"], [1425108700000]) - client.close() + client = Client(servers="localhost:4200") + with patch(REQUEST_PATH, return_value=response): + with pytest.raises( + ProgrammingError, match="an error occurred\nanother error" + ): + client.sql( + "Insert into users (name) values(?)", + bulk_parameters=[["douglas"], ["monthy"]], + ) - @patch(REQUEST, autospec=True) - def test_date_is_converted_to_ts(self, request): - client = Client(servers="localhost:4200") - request.return_value = fake_response(200) - - day = dt.date(2016, 4, 21) - client.sql("insert into users (dt) values (?)", (day,)) - data = json.loads(request.call_args[1]["data"]) - self.assertEqual(data["args"], [1461196800000]) - client.close() - - def test_socket_options_contain_keepalive(self): - server = "http://localhost:4200" - client = Client(servers=server) - conn_kw = client.server_pool[server].pool.conn_kw - self.assertIn( - (socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1), - conn_kw["socket_options"], - ) - client.close() - - @patch(REQUEST, autospec=True) - def test_uuid_serialization(self, request): - client = Client(servers="localhost:4200") - request.return_value = fake_response(200) - uid = uuid.uuid4() - client.sql("insert into my_table (str_col) values (?)", (uid,)) +def test_socket_options_contain_keepalive(): + """ + Verify that KEEPALIVE options are present at `socket_options` + """ + server = "http://localhost:4200" + client = Client(servers=server) + conn_kw = client.server_pool[server].pool.conn_kw + assert (socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1) in conn_kw[ + "socket_options" + ] - data = json.loads(request.call_args[1]["data"]) - self.assertEqual(data["args"], [str(uid)]) - client.close() - @patch(REQUEST, fake_request(duplicate_key_exception())) - def test_duplicate_key_error(self): - """ - Verify that an `IntegrityError` is raised on duplicate key errors, - instead of the more general `ProgrammingError`. - """ +def test_duplicate_key_error(): + """ + Verify that an `IntegrityError` is raised on duplicate key errors, + instead of the more general `ProgrammingError`. + """ + expected_error_msg = ( + r"DuplicateKeyException\[A document with " + r"the same primary key exists already\]" + ) + with patch(REQUEST_PATH, return_value=duplicate_key_exception()): client = Client(servers="localhost:4200") - with self.assertRaises(IntegrityError) as cm: + with pytest.raises(IntegrityError, match=expected_error_msg): client.sql("INSERT INTO testdrive (foo) VALUES (42)") - self.assertEqual( - cm.exception.message, - "DuplicateKeyException[A document with the " - "same primary key exists already]", - ) -@patch(REQUEST, fail_sometimes) -class ThreadSafeHttpClientTest(TestCase): +@patch(REQUEST_PATH, fail_sometimes) +def test_client_multithreaded(): """ - Using a pool of 5 Threads to emit commands to the multiple servers through - one Client-instance + Verify client multithreading using a pool of 5 Threads to emit commands to + the multiple servers through one Client-instance. - check if number of servers in _inactive_servers and _active_servers always - equals the number of servers initially given. - """ + Checks if the number of servers in _inactive_servers and _active_servers + always equals the number of servers initially given. + + Note: + This test is probabilistic and does not ensure that the + client is indeed thread-safe in all cases, it can only show that it + withstands this scenario. + """ servers = [ "127.0.0.1:44209", "127.0.0.2:44209", @@ -387,177 +299,178 @@ class ThreadSafeHttpClientTest(TestCase): ] num_threads = 5 num_commands = 1000 - thread_timeout = 5.0 # seconds + thread_timeout = 10.0 # seconds - def __init__(self, *args, **kwargs): - self.event = Event() - self.err_queue = queue.Queue() - super(ThreadSafeHttpClientTest, self).__init__(*args, **kwargs) + gate = Event() + error_queue = queue.Queue() - def setUp(self): - self.client = Client(self.servers) - self.client.retry_interval = 0.2 # faster retry + client = Client(servers) + client.retry_interval = 0.2 # faster retry - def tearDown(self): - self.client.close() - - def _run(self): - self.event.wait() # wait for the others - expected_num_servers = len(self.servers) - for _ in range(self.num_commands): + def worker(): + """ + Worker that sends many requests, if the `num_server` is not the + expected value at some point, an assertion will be added to the shared + error queue. + """ + gate.wait() # wait for the others + expected_num_servers = len(servers) + for _ in range(num_commands): try: - self.client.sql("select name from sys.cluster") + client.sql("select name from sys.cluster") except ConnectionError: + # Sometimes it will fail. pass try: - with self.client._lock: - num_servers = len(self.client._active_servers) + len( - self.client._inactive_servers + with client._lock: + num_servers = len(client._active_servers) + len( + client._inactive_servers ) - self.assertEqual( - expected_num_servers, - num_servers, - "expected %d but got %d" - % (expected_num_servers, num_servers), - ) - except AssertionError: - self.err_queue.put(sys.exc_info()) - - def test_client_threaded(self): - """ - Testing if lists of servers is handled correctly when client is used - from multiple threads with some requests failing. + assert num_servers == expected_num_servers, ( + f"expected {expected_num_servers} but got {num_servers}" + ) + except AssertionError as e: + error_queue.put(e) - **ATTENTION:** this test is probabilistic and does not ensure that the - client is indeed thread-safe in all cases, it can only show that it - withstands this scenario. - """ - threads = [ - Thread(target=self._run, name=str(x)) - for x in range(self.num_threads) - ] - for thread in threads: - thread.start() - - self.event.set() - for t in threads: - t.join(self.thread_timeout) - - if not self.err_queue.empty(): - self.assertTrue( - False, - "".join( - traceback.format_exception(*self.err_queue.get(block=False)) - ), - ) + threads = [Thread(target=worker, name=str(i)) for i in range(num_threads)] + + for thread in threads: + thread.start() + + gate.set() + + for t in threads: + t.join(timeout=thread_timeout) + + # If any thread is still alive after the timeout, consider it a failure. + alive = [t.name for t in threads if t.is_alive()] + if alive: + pytest.fail(f"Threads did not finish within {thread_timeout}s: {alive}") + if not error_queue.empty(): + # If an error happened, consider it a failure as well. + first_error_trace = error_queue.get(block=False) + pytest.fail(first_error_trace) -class ClientAddressRequestHandler(BaseHTTPRequestHandler): + +def test_params(): + """ + Verify client parameters translate correctly to query parameters. """ - http handler for use with HTTPServer + client = Client(["127.0.0.1:4200"], error_trace=True) + parsed = urlparse(client.path) + params = parse_qs(parsed.query) + + assert params["error_trace"] == ["true"] + assert params["types"] == ["true"] + + client = Client(["127.0.0.1:4200"]) + parsed = urlparse(client.path) + params = parse_qs(parsed.query) + + # Default is False + assert "error_trace" not in params + assert params["types"] == ["true"] - returns client host and port in crate-conform-responses + assert "/_sql?" in client.path + + +def test_client_ca(): + """ + Verify that if env variable `REQUESTS_CA_BUNDLE` is set, certs are + loaded into the pool. """ + with patch.dict(os.environ, {"REQUEST_PATH": certifi.where()}, clear=True): + client = Client("http://127.0.0.1:4200") + assert "ca_certs" in client._pool_kw - protocol_version = "HTTP/1.1" - def do_GET(self): - content_length = self.headers.get("content-length") - if content_length: - self.rfile.read(int(content_length)) - response = json.dumps( - { - "cols": ["host", "port"], - "rows": [self.client_address[0], self.client_address[1]], - "rowCount": 1, - } - ) - self.send_response(200) - self.send_header("Content-Length", len(response)) - self.send_header("Content-Type", "application/json; charset=UTF-8") - self.end_headers() - self.wfile.write(response.encode("UTF-8")) +def test_remove_certs_for_non_https(): + """ + Verify that `_remove_certs_for_non_https` correctly removes ca_certs. + """ + d = _remove_certs_for_non_https("https", {"ca_certs": 1}) + assert "ca_certs" in d - do_POST = do_PUT = do_DELETE = do_HEAD = do_GET + kwargs = {"ca_certs": 1, "foobar": 2, "cert_file": 3} + d = _remove_certs_for_non_https("http", kwargs) + assert "ca_certs" not in d + assert "cert_file" not in d + assert "foobar" in d -class KeepAliveClientTest(TestCase): - server_address = ("127.0.0.1", 65535) +def test_keep_alive(serve_http): + """ + Verify that when launching several requests, the connection is kept + alive and successfully terminates. - def __init__(self, *args, **kwargs): - super(KeepAliveClientTest, self).__init__(*args, **kwargs) - self.server_process = ForkProcess(target=self._run_server) + This uses a real http sever that mocks CrateDB-like responses. + """ - def setUp(self): - super(KeepAliveClientTest, self).setUp() - self.client = Client(["%s:%d" % self.server_address]) - self.server_process.start() - time.sleep(0.10) + class ClientAddressRequestHandler(BaseHTTPRequestHandler): + """ + http handler for use with HTTPServer - def tearDown(self): - self.server_process.terminate() - self.client.close() - super(KeepAliveClientTest, self).tearDown() + returns client host and port in crate-conform-responses + """ - def _run_server(self): - self.server = HTTPServer( - self.server_address, ClientAddressRequestHandler - ) - self.server.handle_request() + protocol_version = "HTTP/1.1" - def test_client_keepalive(self): - for _ in range(10): - result = self.client.sql("select * from fake") + def do_GET(self): + content_length = self.headers.get("content-length") + if content_length: + self.rfile.read(int(content_length)) - another_result = self.client.sql("select again from fake") - self.assertEqual(result, another_result) + response = json.dumps( + { + "cols": ["host", "port"], + "rows": [self.client_address[0], self.client_address[1]], + "rowCount": 1, + } + ) + self.send_response(200) + self.send_header("Content-Length", str(len(response))) + self.send_header("Content-Type", "application/json; charset=UTF-8") + self.end_headers() + self.wfile.write(response.encode("UTF-8")) -class ParamsTest(TestCase): - def test_params(self): - client = Client(["127.0.0.1:4200"], error_trace=True) - parsed = urlparse(client.path) - params = parse_qs(parsed.query) - self.assertEqual(params["error_trace"], ["true"]) - client.close() + do_POST = do_GET - def test_no_params(self): - client = Client() - self.assertEqual(client.path, "/_sql?types=true") - client.close() + with serve_http(ClientAddressRequestHandler) as (_, url): + with connect(url) as conn: + client = conn.client + for _ in range(25): + result = client.sql("select * from fake") + another_result = client.sql("select again from fake") + assert result == another_result -class RequestsCaBundleTest(TestCase): - def test_open_client(self): - os.environ["REQUESTS_CA_BUNDLE"] = CA_CERT_PATH - try: - Client("http://127.0.0.1:4200") - except ProgrammingError: - self.fail("HTTP not working with REQUESTS_CA_BUNDLE") - finally: - os.unsetenv("REQUESTS_CA_BUNDLE") - os.environ["REQUESTS_CA_BUNDLE"] = "" - def test_remove_certs_for_non_https(self): - d = _remove_certs_for_non_https("https", {"ca_certs": 1}) - self.assertIn("ca_certs", d) +def test_no_retry_on_read_timeout(serve_http): + timeout = 1 - kwargs = {"ca_certs": 1, "foobar": 2, "cert_file": 3} - d = _remove_certs_for_non_https("http", kwargs) - self.assertNotIn("ca_certs", d) - self.assertNotIn("cert_file", d) - self.assertIn("foobar", d) + class TimeoutRequestHandler(BaseHTTPRequestHandler): + """ + HTTP handler for use with TestingHTTPServer + updates the shared counter and waits so that the client times out + """ + def do_POST(self): + self.server.SHARED["count"] += 1 + time.sleep(timeout + 0.1) -class TimeoutRequestHandler(BaseHTTPRequestHandler): - """ - HTTP handler for use with TestingHTTPServer - updates the shared counter and waits so that the client times out - """ + def do_GET(self): + pass - def do_POST(self): - self.server.SHARED["count"] += 1 - time.sleep(5) + # Start the http server. + with serve_http(TimeoutRequestHandler) as (server, url): + # Connect to the server. + with connect(url, timeout=timeout) as conn: + # We expect it to raise a `ConnectionError` + with pytest.raises(ConnectionError, match="Read timed out"): + conn.client.sql("select * from fake") + assert server.SHARED.get("count") == 1 class SharedStateRequestHandler(BaseHTTPRequestHandler): @@ -594,140 +507,53 @@ def do_POST(self): self.end_headers() self.wfile.write(response.encode("utf-8")) + def do_GET(self): + pass + -class TestingHTTPServer(HTTPServer): +def test_default_schema(serve_http): """ - http server providing a shared dict + Verify that the schema is correctly sent. """ + test_schema = "some_schema" + with serve_http(SharedStateRequestHandler) as (server, url): + with connect(url, schema=test_schema) as conn: + conn.client.sql("select 1;") + assert server.SHARED.get("schema") == test_schema - manager = multiprocessing.Manager() - SHARED = manager.dict() - SHARED["count"] = 0 - SHARED["usernameFromXUser"] = None - SHARED["username"] = None - SHARED["password"] = None - SHARED["schema"] = None - - @classmethod - def run_server(cls, server_address, request_handler_cls): - cls(server_address, request_handler_cls).serve_forever() - -class TestingHttpServerTestCase(TestCase): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.assertIsNotNone(self.request_handler) - self.server_address = ("127.0.0.1", random.randint(65000, 65535)) - self.server_process = ForkProcess( - target=TestingHTTPServer.run_server, - args=(self.server_address, self.request_handler), - ) - - def setUp(self): - self.server_process.start() - self.wait_for_server() - - def wait_for_server(self): - while True: - try: - with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: - s.connect(self.server_address) - except Exception: - time.sleep(0.25) - else: - break - - def tearDown(self): - self.server_process.terminate() - - def clientWithKwargs(self, **kwargs): - return Client(["%s:%d" % self.server_address], timeout=1, **kwargs) - - -class RetryOnTimeoutServerTest(TestingHttpServerTestCase): - request_handler = TimeoutRequestHandler - - def setUp(self): - super().setUp() - self.client = self.clientWithKwargs() - - def tearDown(self): - super().tearDown() - self.client.close() - - def test_no_retry_on_read_timeout(self): - try: - self.client.sql("select * from fake") - except ConnectionError as e: - self.assertIn( - "Read timed out", - e.message, - msg="Error message must contain: Read timed out", - ) - self.assertEqual(TestingHTTPServer.SHARED["count"], 1) - - -class TestDefaultSchemaHeader(TestingHttpServerTestCase): - request_handler = SharedStateRequestHandler - - def setUp(self): - super().setUp() - self.client = self.clientWithKwargs(schema="my_custom_schema") - - def tearDown(self): - self.client.close() - super().tearDown() - - def test_default_schema(self): - self.client.sql("SELECT 1") - self.assertEqual(TestingHTTPServer.SHARED["schema"], "my_custom_schema") - - -class TestUsernameSentAsHeader(TestingHttpServerTestCase): - request_handler = SharedStateRequestHandler - - def setUp(self): - super().setUp() - self.clientWithoutUsername = self.clientWithKwargs() - self.clientWithUsername = self.clientWithKwargs(username="testDBUser") - self.clientWithUsernameAndPassword = self.clientWithKwargs( - username="testDBUser", password="test:password" - ) - - def tearDown(self): - self.clientWithoutUsername.close() - self.clientWithUsername.close() - self.clientWithUsernameAndPassword.close() - super().tearDown() - - def test_username(self): - self.clientWithoutUsername.sql("select * from fake") - self.assertEqual(TestingHTTPServer.SHARED["usernameFromXUser"], None) - self.assertEqual(TestingHTTPServer.SHARED["username"], None) - self.assertEqual(TestingHTTPServer.SHARED["password"], None) - - self.clientWithUsername.sql("select * from fake") - self.assertEqual( - TestingHTTPServer.SHARED["usernameFromXUser"], "testDBUser" - ) - self.assertEqual(TestingHTTPServer.SHARED["username"], "testDBUser") - self.assertEqual(TestingHTTPServer.SHARED["password"], None) - - self.clientWithUsernameAndPassword.sql("select * from fake") - self.assertEqual( - TestingHTTPServer.SHARED["usernameFromXUser"], "testDBUser" - ) - self.assertEqual(TestingHTTPServer.SHARED["username"], "testDBUser") - self.assertEqual(TestingHTTPServer.SHARED["password"], "test:password") - - -class TestCrateJsonEncoder(TestCase): - def test_naive_datetime(self): - data = dt.datetime.fromisoformat("2023-06-26T09:24:00.123") - result = json_dumps(data) - self.assertEqual(result, b"1687771440123") - - def test_aware_datetime(self): - data = dt.datetime.fromisoformat("2023-06-26T09:24:00.123+02:00") - result = json_dumps(data) - self.assertEqual(result, b"1687764240123") +def test_credentials(serve_http): + """ + Verify credentials are correctly set in the connection and client. + """ + with serve_http(SharedStateRequestHandler) as (server, url): + # Nothing default + with connect(url) as conn: + assert not conn.client.username + assert not conn.client.password + + conn.client.sql("select 1;") + assert not server.SHARED["usernameFromXUser"] + assert not server.SHARED["username"] + assert not server.SHARED["password"] + + # Just the username + username = "some_username" + with connect(url, username=username) as conn: + assert conn.client.username == username + assert not conn.client.password + + conn.client.sql("select 2;") + assert server.SHARED["usernameFromXUser"] == username + assert server.SHARED["username"] == username + assert not server.SHARED["password"] + + # Both username and password + password = "some_password" + with connect(url, username=username, password=password) as conn: + assert conn.client.username == username + assert conn.client.password == password + conn.client.sql("select 3;") + assert server.SHARED["usernameFromXUser"] == username + assert server.SHARED["username"] == username + assert server.SHARED["password"] == password diff --git a/tests/client/test_serialization.py b/tests/client/test_serialization.py new file mode 100644 index 00000000..b022d1bb --- /dev/null +++ b/tests/client/test_serialization.py @@ -0,0 +1,129 @@ +# -*- coding: utf-8; -*- +# +# Licensed to CRATE Technology GmbH ("Crate") under one or more contributor +# license agreements. See the NOTICE file distributed with this work for +# additional information regarding copyright ownership. Crate licenses +# this file to you under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. You may +# obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +# +# However, if you have executed another commercial license agreement +# with Crate these terms will supersede the license and you may use the +# software solely pursuant to the terms of the relevant commercial agreement. + +""" +Tests for serializing data, typically python objects + into CrateDB-sql compatible structures. +""" + +import datetime +import datetime as dt +import uuid +from decimal import Decimal +from unittest.mock import MagicMock, patch + +from crate.client.http import Client, json_dumps +from tests.conftest import REQUEST_PATH, fake_response + + +def test_data_is_serialized(): + """ + Verify that when a request is issued, `json_dumps` is called with + the right parameters and that a requests gets the output from json_dumps, + this verifies the entire serialization call chain, so in the following + tests we can just test `json_dumps` and ignore + `Client` altogether. + """ + mock = MagicMock(spec=bytes) + + with patch("crate.client.http.json_dumps", return_value=mock) as f: + with patch(REQUEST_PATH, return_value=fake_response(200)) as request: + client = Client(servers="localhost:4200") + client.sql( + "insert into t (a, b) values (?, ?)", + ( + datetime.datetime( + 2025, + 10, + 23, + 11, + ), + "ss", + ), + ) + + # Verify json_dumps is called with the right parameters. + f.assert_called_once_with( + { + "stmt": "insert into t (a, b) values (?, ?)", + "args": (datetime.datetime(2025, 10, 23, 11, 0), "ss"), + } + ) + + # Verify that the output of json_dumps is used as + # call argument for a request. + assert request.call_args[1]["data"] is mock + + +def test_naive_datetime_serialization(): + """ + Verify that a `datetime.datetime` can be serialized. + """ + data = dt.datetime(2015, 2, 28, 7, 31, 40) + result = json_dumps(data) + assert isinstance(result, bytes) + assert result == b"1425108700000" + + +def test_aware_datetime_serialization(): + """ + Verify that a `datetime` that is tz aware type can be serialized. + """ + data = dt.datetime.fromisoformat("2023-06-26T09:24:00.123+02:00") + result = json_dumps(data) + assert isinstance(result, bytes) + assert result == b"1687764240123" + + +def test_decimal_serialization(): + """ + Verify that a `Decimal` type can be serialized. + """ + + data = Decimal(0.12) + expected = b'"0.11999999999999999555910790149937383830547332763671875"' + result = json_dumps(data) + assert isinstance(result, bytes) + + # Question: Is this deterministic in every Python release? + assert result == expected + + +def test_date_serialization(): + """ + Verify that a `datetime.date` can be serialized. + """ + data = dt.date(2016, 4, 21) + result = json_dumps(data) + assert result == b"1461196800000" + + +def test_uuid_serialization(): + """ + Verify that a `uuid.UUID` can be serialized. + + We do not care about specific uuid versions, just the object that is + re-used across all versions of the uuid module. + """ + uuid_int = 50583033507982468033520929066863110751 + data = uuid.UUID(bytes=uuid_int.to_bytes(16, byteorder="big"), version=4) + result = json_dumps(data) + assert result == b'"260df019-a183-431f-ad46-115ccdf12a5f"' diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 00000000..66626354 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,92 @@ +import multiprocessing +import socket +import threading +from contextlib import contextmanager +from http.server import BaseHTTPRequestHandler, HTTPServer +from unittest.mock import MagicMock + +import pytest +import urllib3 + +import crate + +REQUEST_PATH = "crate.client.http.Server.request" + + +def fake_response( + status: int, reason: str = None, content_type: str = "application/json" +) -> MagicMock: + """ + Returns a mocked `urllib3.response.HTTPResponse` HTTP response. + """ + m = MagicMock(spec=urllib3.response.HTTPResponse) + m.status = status + m.reason = reason or "" + m.headers = {"content-type": content_type} + return m + + +@pytest.fixture +def mocked_connection(): + """ + Returns a crate `Connection` with a mocked `Client` + + Example: + def test_conn(mocked_connection): + cursor = mocked_connection.cursor() + statement = "select * from locations where position = ?" + cursor.execute(statement, 1) + mocked_connection.client.sql.called_with(statement, 1, None) + """ + yield crate.client.connect(client=MagicMock(spec=crate.client.http.Client)) + + +@pytest.fixture +def serve_http(): + """ + Returns a context manager that start an http server running + in another thread that returns CrateDB successful responses. + + It accepts an optional parameter, the handler class, it has to be an + instance of `BaseHTTPRequestHandler` + + The port will be an unused random port. + + Example: + def test_http(serve_http): + with serve_http() as url: + urllib3.urlopen(url) + + See `test_http.test_keep_alive` for more advance example. + """ + + @contextmanager + def _serve(handler_cls=BaseHTTPRequestHandler): + assert issubclass(handler_cls, BaseHTTPRequestHandler) # noqa: S101 + sock = socket.socket() + sock.bind(("127.0.0.1", 0)) + host, port = sock.getsockname() + sock.close() + + manager = multiprocessing.Manager() + SHARED = manager.dict() + SHARED["count"] = 0 + SHARED["usernameFromXUser"] = None + SHARED["username"] = None + SHARED["password"] = None + SHARED["schema"] = None + + server = HTTPServer((host, port), handler_cls) + + server.SHARED = SHARED + + thread = threading.Thread(target=server.serve_forever, daemon=False) + thread.start() + try: + yield server, f"http://{host}:{port}" + + finally: + server.shutdown() + thread.join() + + return _serve