21 changes: 20 additions & 1 deletion pymongo/asynchronous/pool.py
@@ -813,6 +813,14 @@ def __init__(
        self.__pinned_sockets: set[AsyncConnection] = set()
        self.ncursors = 0
        self.ntxns = 0
        self.error_times: collections.deque[tuple[float, int]] = collections.deque()

    async def get_error_rate(self) -> float:
        async with self.lock:
            # Require at least 10 samples to compute an error rate.
            if len(self.error_times) < 10:
                return 0.0
            return sum(t[1] for t in self.error_times) / (len(self.error_times) or 1)

    async def ready(self) -> None:
        # Take the lock to avoid the race condition described in PYTHON-2699.
@@ -1182,7 +1190,16 @@ async def checkout(
                serverPort=self.address[1],
            )

        conn = await self._get_conn(checkout_started_time, handler=handler)
        try:
            conn = await self._get_conn(checkout_started_time, handler=handler)
        except Exception:
            self.error_times.append((time.monotonic(), 1))
            raise

        # Clear error-rate samples more than 10 seconds old.

@baileympearson commented on Nov 21, 2025:

We definitely don't need to figure this out now, because we're not certain we'll move forward with this approach.

But because this might be relevant to Iris' DSI workflow: in scenarios where one node is significantly more overloaded than the other, pruning stale error measurements after connection checkout will result in a slower recovery than necessary because server selection will continue to avoid the overloaded server, even after it has started to recover. The higher the error rate on one node, the longer it will take on average to reach the error rate measurement pruning logic.

And as the error rate approaches 100% on one node while the other stays healthy, the likelihood of ever selecting this server decreases and approaches 0 (if the error rate ever hit 100%, we'd always select the healthy server and never have a chance of selecting the overloaded one again).
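
One way to address this, in line with the author's reply below (an editorial sketch, not the PR's code): prune inside get_error_rate, which server selection calls for every candidate, so stale samples age out even when a server is never checked out.

    async def get_error_rate(self) -> float:
        async with self.lock:
            # Prune here instead of in checkout(): selection calls this for every
            # candidate server, so stale samples age out even if the server is
            # never actually chosen.
            watermark = time.monotonic() - 10
            while self.error_times and self.error_times[0][0] < watermark:
                self.error_times.popleft()
            # Require at least 10 samples to compute an error rate.
            if len(self.error_times) < 10:
                return 0.0
            return sum(t[1] for t in self.error_times) / len(self.error_times)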

@ShaneHarvey (Owner, Author) replied:

Good point, this logic would need to be moved into server selection to avoid that pitfall. Also, there's likely a more efficient algorithm to track the error rate. One way would be to keep 10 buckets that track the errors and total operations for each of the last 10 seconds. The rate is then a simple sum over <=10 buckets rather than over an unbounded list.

Something roughly like this (ignoring the phasing out of old data):

self.error_stats = {}
...
current_second = int(time.monotonic())
bucket = self.error_stats.setdefault(current_second, {"errors": 0, "requests": 0})

if error:
    bucket["errors"] += 1
bucket["requests"] += 1

Then:

    async def get_error_rate(self) -> float:
        current_second = int(time.monotonic())
        errors = 0
        requests = 0
        async with self.lock:
            for sec in self.error_stats:
                if sec < current_second - 10:
                    continue
                bucket = self.error_stats[sec]
                errors += bucket["errors"]
                requests += bucket["requests"]
        # Require at least 10 samples to compute an error rate.
        if requests < 10:
            return 0.0
        return float(errors) / requests
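
The phasing-out of old data could then be as simple as dropping stale buckets whenever the rate is computed; a possible sketch (hypothetical, not part of the suggestion above):

        # Drop buckets that have fallen outside the 10-second window so
        # error_stats stays bounded (roughly 10 entries).
        for sec in list(self.error_stats):
            if sec < current_second - 10:
                del self.error_stats[sec]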

        watermark = time.monotonic() - 10
        while self.error_times and self.error_times[0][0] < watermark:
            self.error_times.popleft()

        duration = time.monotonic() - checkout_started_time
        if self.enabled_for_cmap:
@@ -1202,8 +1219,10 @@ async def checkout(
            async with self.lock:
                self.active_contexts.add(conn.cancel_context)
            yield conn
            self.error_times.append((time.monotonic(), 0))
        # Catch KeyboardInterrupt, CancelledError, etc. and cleanup.
        except BaseException:
            self.error_times.append((time.monotonic(), 1))
            # Exception in caller. Ensure the connection gets returned.
            # Note that when pinned is True, the session owns the
            # connection and it is responsible for checking the connection
11 changes: 9 additions & 2 deletions pymongo/asynchronous/topology.py
@@ -391,10 +391,17 @@ async def _select_server(
        if len(servers) == 1:
            return servers[0]
        server1, server2 = random.sample(servers, 2)
        if server1.pool.operation_count <= server2.pool.operation_count:
        error_rate1 = await server1.pool.get_error_rate()
        error_rate2 = await server2.pool.get_error_rate()
        if error_rate1 < error_rate2 and (random.random() < (error_rate2 - error_rate1)):  # noqa: S311

A reviewer asked:

What's the purpose of random.random() < (error_rate2 - error_rate1)?

@ShaneHarvey (Owner, Author) replied on Nov 20, 2025:

The random choice scaling with the difference attempts to balance the load more evenly. Imagine a scenario where there are only 2 servers, A and B. A's error rate is 0% and B's is 1%. A naive approach is to always pick the server with the lower error rate, but that would mean we'd never choose server B. That's a problem since it means that as soon as any errors occur on one server, all requests to it will be rerouted.

Introducing randomness here means we instead bias only 1% of requests away from server B, which is clearly better than rerouting 100% of them.

Another example: A's error rate is 25% and B's is 75%. We know we want to bias some operations toward server A since it has a higher chance of success, but again not all of the requests. Scaling with the difference in error rate seemed to perform well in these scenarios (e.g. we observe fewer errors overall).

@ShaneHarvey (Owner, Author) added:

Another way to consider it: when the error rates for the two servers are about the same, we want to continue using operationCount-based selection. As the difference in error rates increases, we want to route more and more requests away from the server with the higher rate.
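
A small worked example of that behavior (an editorial sketch as a standalone script, not part of the PR): with error rates of 25% and 75%, the difference is 0.5, so about half of selections go straight to the healthier server and the rest fall through to the operationCount tie-break, leaving the overloaded server with roughly 25% of the traffic, which matches the expectation in the new test file below.

import random

def pick(rate1, rate2):
    # Mirrors the selection logic added to _select_server: bias away from the
    # higher-error server with probability equal to the difference in rates,
    # otherwise fall back to (simulated) operationCount-based selection.
    if rate1 < rate2 and random.random() < (rate2 - rate1):
        return 1
    elif rate1 > rate2 and random.random() < (rate1 - rate2):
        return 2
    return random.choice([1, 2])  # stand-in for the operationCount comparison

counts = {1: 0, 2: 0}
for _ in range(100_000):
    counts[pick(0.25, 0.75)] += 1

print(counts[1] / 100_000, counts[2] / 100_000)  # roughly 0.75 and 0.25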

Makes sense. I should have read the PR description, sorry 😅

I cannot resolve this comment for some reason. (permissions?)

Feel free to resolve.

            return server1
        else:
        elif error_rate1 > error_rate2 and (random.random() < (error_rate1 - error_rate2)):  # noqa: S311
            return server2
        else:
            if server1.pool.operation_count <= server2.pool.operation_count:
                return server1
            else:
                return server2

    async def select_server(
        self,
21 changes: 20 additions & 1 deletion pymongo/synchronous/pool.py
@@ -811,6 +811,14 @@ def __init__(
        self.__pinned_sockets: set[Connection] = set()
        self.ncursors = 0
        self.ntxns = 0
        self.error_times: collections.deque[tuple[float, int]] = collections.deque()

    def get_error_rate(self) -> float:
        with self.lock:
            # Require at least 10 samples to compute an error rate.
            if len(self.error_times) < 10:
                return 0.0
            return sum(t[1] for t in self.error_times) / (len(self.error_times) or 1)

    def ready(self) -> None:
        # Take the lock to avoid the race condition described in PYTHON-2699.
@@ -1178,7 +1186,16 @@ def checkout(
                serverPort=self.address[1],
            )

        conn = self._get_conn(checkout_started_time, handler=handler)
        try:
            conn = self._get_conn(checkout_started_time, handler=handler)
        except Exception:
            self.error_times.append((time.monotonic(), 1))
            raise

        # Clear error-rate samples more than 10 seconds old.
        watermark = time.monotonic() - 10
        while self.error_times and self.error_times[0][0] < watermark:
            self.error_times.popleft()

        duration = time.monotonic() - checkout_started_time
        if self.enabled_for_cmap:
@@ -1198,8 +1215,10 @@ def checkout(
            with self.lock:
                self.active_contexts.add(conn.cancel_context)
            yield conn
            self.error_times.append((time.monotonic(), 0))
        # Catch KeyboardInterrupt, CancelledError, etc. and cleanup.
        except BaseException:
            self.error_times.append((time.monotonic(), 1))
            # Exception in caller. Ensure the connection gets returned.
            # Note that when pinned is True, the session owns the
            # connection and it is responsible for checking the connection
11 changes: 9 additions & 2 deletions pymongo/synchronous/topology.py
@@ -391,10 +391,17 @@ def _select_server(
        if len(servers) == 1:
            return servers[0]
        server1, server2 = random.sample(servers, 2)
        if server1.pool.operation_count <= server2.pool.operation_count:
        error_rate1 = server1.pool.get_error_rate()
        error_rate2 = server2.pool.get_error_rate()
        if error_rate1 < error_rate2 and (random.random() < (error_rate2 - error_rate1)):  # noqa: S311
            return server1
        else:
        elif error_rate1 > error_rate2 and (random.random() < (error_rate1 - error_rate2)):  # noqa: S311
            return server2
        else:
            if server1.pool.operation_count <= server2.pool.operation_count:
                return server1
            else:
                return server2

    def select_server(
        self,
66 changes: 63 additions & 3 deletions test/test_server_selection_in_window.py
@@ -17,15 +17,18 @@

import asyncio
import os
import random
import threading
from pathlib import Path
from string import ascii_lowercase
from test import IntegrationTest, client_context, unittest
from test.helpers import ConcurrentRunner
from test.utils import flaky
from test.utils_selection_tests import create_topology
from test.utils_shared import (
    CMAPListener,
    OvertCommandListener,
    delay,
    wait_until,
)
from test.utils_spec_runner import SpecTestCreator
@@ -106,17 +109,25 @@ def __init__(self, collection, iterations):
        self.collection = collection
        self.iterations = iterations
        self.passed = False
        self.n_overload_errors = 0

    def run(self):
        from pymongo.errors import PyMongoError

        for _ in range(self.iterations):
            self.collection.find_one({})
            try:
                self.collection.find_one({"$where": delay(0.025)})
            except PyMongoError as exc:
                if not exc.has_error_label("SystemOverloadedError"):
                    raise
                self.n_overload_errors += 1
        self.passed = True


class TestProse(IntegrationTest):
    def frequencies(self, client, listener, n_finds=10):
        coll = client.test.test
        N_TASKS = 10
        N_TASKS = 20
        tasks = [FinderTask(coll, n_finds) for _ in range(N_TASKS)]
        for task in tasks:
            task.start()
@@ -126,14 +137,17 @@ def frequencies(self, client, listener, n_finds=10):
            self.assertTrue(task.passed)

        events = listener.started_events
        self.assertEqual(len(events), n_finds * N_TASKS)
        self.assertGreaterEqual(len(events), n_finds * N_TASKS)
        nodes = client.nodes
        self.assertEqual(len(nodes), 2)
        freqs = {address: 0.0 for address in nodes}
        for event in events:
            freqs[event.connection_id] += 1
        for address in freqs:
            freqs[address] = freqs[address] / float(len(events))
        freqs["overload_errors"] = sum(task.n_overload_errors for task in tasks)
        freqs["operations"] = sum(task.iterations for task in tasks)
        freqs["error_rate"] = freqs["overload_errors"] / float(freqs["operations"])
        return freqs

    @client_context.require_failCommand_appName
@@ -175,6 +189,52 @@ def test_load_balancing(self):
        freqs = self.frequencies(client, listener, n_finds=150)
        self.assertAlmostEqual(freqs[delayed_server], 0.50, delta=0.15)

    @client_context.require_failCommand_appName
    @client_context.require_multiple_mongoses
    def test_load_balancing_overload(self):
        listener = OvertCommandListener()
        cmap_listener = CMAPListener()
        # PYTHON-2584: Use a large localThresholdMS to avoid the impact of
        # varying RTTs.
        client = self.rs_client(
            client_context.mongos_seeds(),
            appName="loadBalancingTest",
            event_listeners=[listener, cmap_listener],
            localThresholdMS=30000,
            minPoolSize=10,
        )
        wait_until(lambda: len(client.nodes) == 2, "discover both nodes")
        # Wait for both pools to be populated.
        cmap_listener.wait_for_event(ConnectionReadyEvent, 20)
        # enable rate limiter
        client.test.test.insert_many(
            [{"str": "".join(random.choices(ascii_lowercase, k=512))} for _ in range(10)]
        )
        listener.reset()

        # Mock rate limiter errors on only one mongos.
        rejection_rate = 0.75
        delay_finds = {
            "configureFailPoint": "failCommand",
            "mode": {"activationProbability": rejection_rate},
            "data": {
                "failCommands": ["find"],
                "errorCode": 462,
                # Intentionally omit "RetryableError" to avoid retry behavior from impacting this test.
                "errorLabels": ["SystemOverloadedError"],
                "appName": "loadBalancingTest",
            },
        }
        nodes = client_context.client.nodes
        self.assertEqual(len(nodes), 1)
        delayed_server = next(iter(nodes))
        listener.reset()
        with self.fail_point(delay_finds):
            freqs = self.frequencies(client, listener, n_finds=200)
        print(f"\nOverloaded server: {delayed_server}")
        print(freqs)
        self.assertAlmostEqual(freqs[delayed_server], 1 - rejection_rate, delta=0.15)


if __name__ == "__main__":
    unittest.main()