async def get_error_rate(self, window: float = 10.0) -> float:
    """Return the fraction of recent connection checkout attempts that errored.

    Samples are ``(timestamp, outcome)`` pairs appended by ``checkout()``,
    with outcome ``1`` for an error and ``0`` for success.

    :param window: age in seconds beyond which samples are discarded;
        defaults to 10, matching the pruning watermark used by ``checkout()``.
    :return: error rate in ``[0.0, 1.0]``, or ``0.0`` when fewer than 10
        in-window samples exist (too few for a meaningful rate).
    """
    async with self.lock:
        # Prune samples that have aged out of the window here as well as in
        # checkout(); otherwise an idle pool keeps reporting a stale error
        # rate and server selection is biased against it indefinitely.
        watermark = time.monotonic() - window
        while self.error_times and self.error_times[0][0] < watermark:
            self.error_times.popleft()
        n_samples = len(self.error_times)
        # Require at least 10 samples to compute an error rate.
        if n_samples < 10:
            return 0.0
        # n_samples >= 10 here, so no division-by-zero guard is needed.
        return sum(outcome for _, outcome in self.error_times) / n_samples
def get_error_rate(self, window: float = 10.0) -> float:
    """Return the fraction of recent connection checkout attempts that errored.

    Synchronous twin of ``asynchronous.pool.Pool.get_error_rate``; keep the
    two implementations in lockstep.

    Samples are ``(timestamp, outcome)`` pairs appended by ``checkout()``,
    with outcome ``1`` for an error and ``0`` for success.

    :param window: age in seconds beyond which samples are discarded;
        defaults to 10, matching the pruning watermark used by ``checkout()``.
    :return: error rate in ``[0.0, 1.0]``, or ``0.0`` when fewer than 10
        in-window samples exist (too few for a meaningful rate).
    """
    with self.lock:
        # Prune samples that have aged out of the window here as well as in
        # checkout(); otherwise an idle pool keeps reporting a stale error
        # rate and server selection is biased against it indefinitely.
        watermark = time.monotonic() - window
        while self.error_times and self.error_times[0][0] < watermark:
            self.error_times.popleft()
        n_samples = len(self.error_times)
        # Require at least 10 samples to compute an error rate.
        if n_samples < 10:
            return 0.0
        # n_samples >= 10 here, so no division-by-zero guard is needed.
        return sum(outcome for _, outcome in self.error_times) / n_samples
@@ -1178,7 +1186,16 @@ def checkout( serverPort=self.address[1], ) - conn = self._get_conn(checkout_started_time, handler=handler) + try: + conn = self._get_conn(checkout_started_time, handler=handler) + except Exception: + self.error_times.append((time.monotonic(), 1)) + raise + + # clear old info from error rate >10 seconds old + watermark = time.monotonic() - 10 + while self.error_times and self.error_times[0][0] < watermark: + self.error_times.popleft() duration = time.monotonic() - checkout_started_time if self.enabled_for_cmap: @@ -1198,8 +1215,10 @@ def checkout( with self.lock: self.active_contexts.add(conn.cancel_context) yield conn + self.error_times.append((time.monotonic(), 0)) # Catch KeyboardInterrupt, CancelledError, etc. and cleanup. except BaseException: + self.error_times.append((time.monotonic(), 1)) # Exception in caller. Ensure the connection gets returned. # Note that when pinned is True, the session owns the # connection and it is responsible for checking the connection diff --git a/pymongo/synchronous/topology.py b/pymongo/synchronous/topology.py index 0f6592dfc0..75367b659d 100644 --- a/pymongo/synchronous/topology.py +++ b/pymongo/synchronous/topology.py @@ -391,10 +391,17 @@ def _select_server( if len(servers) == 1: return servers[0] server1, server2 = random.sample(servers, 2) - if server1.pool.operation_count <= server2.pool.operation_count: + error_rate1 = server1.pool.get_error_rate() + error_rate2 = server2.pool.get_error_rate() + if error_rate1 < error_rate2 and (random.random() < (error_rate2 - error_rate1)): # noqa: S311 return server1 - else: + elif error_rate1 > error_rate2 and (random.random() < (error_rate1 - error_rate2)): # noqa: S311 return server2 + else: + if server1.pool.operation_count <= server2.pool.operation_count: + return server1 + else: + return server2 def select_server( self, diff --git a/test/test_server_selection_in_window.py b/test/test_server_selection_in_window.py index fcf2cce0e0..3b3b43fd32 100644 --- 
def run(self):
    """Issue ``self.iterations`` slow finds, tolerating overload rejections.

    Finds rejected with the ``SystemOverloadedError`` label are counted in
    ``self.n_overload_errors``; any other PyMongoError is re-raised and
    leaves ``self.passed`` False.
    """
    from pymongo.errors import PyMongoError

    completed = 0
    while completed < self.iterations:
        completed += 1
        try:
            # Each find sleeps ~25ms server-side so operations overlap.
            self.collection.find_one({"$where": delay(0.025)})
        except PyMongoError as exc:
            if not exc.has_error_label("SystemOverloadedError"):
                # Unexpected failure: surface it to fail the test.
                raise
            self.n_overload_errors += 1
    self.passed = True
@client_context.require_failCommand_appName
@client_context.require_multiple_mongoses
def test_load_balancing_overload(self):
    """Verify server selection shifts load away from a mongos that is
    rejecting commands with the SystemOverloadedError label."""
    listener = OvertCommandListener()
    cmap_listener = CMAPListener()
    # PYTHON-2584: Use a large localThresholdMS to avoid the impact of
    # varying RTTs.
    client = self.rs_client(
        client_context.mongos_seeds(),
        appName="loadBalancingTest",
        event_listeners=[listener, cmap_listener],
        localThresholdMS=30000,
        minPoolSize=10,
    )
    wait_until(lambda: len(client.nodes) == 2, "discover both nodes")
    # Wait for both pools to be populated.
    cmap_listener.wait_for_event(ConnectionReadyEvent, 20)
    # enable rate limiter
    client.test.test.insert_many(
        [{"str": "".join(random.choices(ascii_lowercase, k=512))} for _ in range(10)]
    )

    # Mock rate limiter errors on only one mongos.
    rejection_rate = 0.75
    overload_finds = {
        "configureFailPoint": "failCommand",
        "mode": {"activationProbability": rejection_rate},
        "data": {
            "failCommands": ["find"],
            "errorCode": 462,
            # Intentionally omit "RetryableError" to avoid retry behavior from impacting this test.
            "errorLabels": ["SystemOverloadedError"],
            "appName": "loadBalancingTest",
        },
    }
    # client_context.client is connected to a single mongos; that is the
    # one the fail point lands on.
    nodes = client_context.client.nodes
    self.assertEqual(len(nodes), 1)
    overloaded_server = next(iter(nodes))
    # Reset once, immediately before the measured workload (a second
    # earlier reset was redundant and has been removed).
    listener.reset()
    with self.fail_point(overload_finds):
        freqs = self.frequencies(client, listener, n_finds=200)
    # Debug print() calls removed: test output should stay clean.
    # The overloaded server should receive roughly the fraction of
    # operations it does NOT reject.
    self.assertAlmostEqual(freqs[overloaded_server], 1 - rejection_rate, delta=0.15)