Skip to content
16 changes: 11 additions & 5 deletions google/auth/compute_engine/_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,11 +96,7 @@ def _get_metadata_ip_root(use_mtls: bool):

# Timeout in seconds to wait for the GCE metadata server when detecting the
# GCE environment.
try:
_METADATA_DEFAULT_TIMEOUT = int(os.getenv("GCE_METADATA_TIMEOUT", 3))
except ValueError: # pragma: NO COVER
_METADATA_DEFAULT_TIMEOUT = 3

_METADATA_PING_DEFAULT_TIMEOUT = 3
# Detect GCE Residency
_GOOGLE = "Google"
_GCE_PRODUCT_NAME_FILE = "/sys/class/dmi/id/product_name"
Expand Down Expand Up @@ -191,6 +187,16 @@ def ping(request, timeout=_METADATA_DEFAULT_TIMEOUT, retry_count=3):
# could lead to false negatives in the event that we are on GCE, but
# the metadata resolution was particularly slow. The latter case is
# "unlikely".

if timeout is None:
try:
timeout = float(os.getenv(
environment_vars.GCE_METADATA_TIMEOUT,
str(_METADATA_PING_DEFAULT_TIMEOUT)))
except ValueError:
timeout = _METADATA_PING_DEFAULT_TIMEOUT

retries = 0
headers = _METADATA_HEADERS.copy()
headers[metrics.API_CLIENT_HEADER] = metrics.mds_ping()

Expand Down
3 changes: 3 additions & 0 deletions google/auth/environment_vars.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,9 @@
Used to distinguish between GAE gen1 and GAE gen2+.
"""

GCE_METADATA_TIMEOUT = "GCE_METADATA_TIMEOUT"
"""Environment variable for setting timeouts in seconds for metadata queries."""

# AWS environment variables used with AWS workload identity pools to retrieve
# AWS security credentials and the AWS region needed to create a serialized
# signed requests to the AWS STS GetCalledIdentity API that can be exchanged
Expand Down
45 changes: 41 additions & 4 deletions tests/compute_engine/test__metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,9 +129,46 @@ def test_ping_success(mock_metrics_header_value):

request.assert_called_once_with(
method="GET",
url="http://169.254.169.254",
headers=MDS_PING_REQUEST_HEADER,
timeout=_metadata._METADATA_DEFAULT_TIMEOUT,
url=_metadata._METADATA_IP_ROOT,
headers=_metadata._METADATA_HEADER,
timeout=_metadata._METADATA_PING_DEFAULT_TIMEOUT,
)

@mock.patch("google.auth.metrics.mds_ping", return_value=MDS_PING_METRICS_HEADER_VALUE)
def test_ping_success_with_gce_metadata_timeout(mock_metrics_header_value):
request = make_request("", headers=_metadata._METADATA_HEADERS)
gce_metadata_timeout = .5
os.environ[
environment_vars.GCE_METADATA_TIMEOUT] = str(gce_metadata_timeout)

try:
assert _metadata.ping(request)
finally:
del os.environ[environment_vars.GCE_METADATA_TIMEOUT]

request.assert_called_once_with(
method="GET",
url=_metadata._METADATA_IP_ROOT,
headers=_metadata._METADATA_HEADER,
timeout=gce_metadata_timeout,
)

@mock.patch("google.auth.metrics.mds_ping", return_value=MDS_PING_METRICS_HEADER_VALUE)
def test_ping_success_with_invalid_gce_metadata_timeout(mock_metrics_header_value):
request = make_request("", headers=_metadata._METADATA_HEADERS)
os.environ[
environment_vars.GCE_METADATA_TIMEOUT] = "Not a valid float value!"

try:
assert _metadata.ping(request)
finally:
del os.environ[environment_vars.GCE_METADATA_TIMEOUT]

request.assert_called_once_with(
method="GET",
url=_metadata._METADATA_IP_ROOT,
headers=_metadata._METADATA_HEADERS,
timeout=_metadata._METADATA_PING_DEFAULT_TIMEOUT, # Fallback value.
)


Expand Down Expand Up @@ -183,7 +220,7 @@ def test_ping_success_custom_root(mock_metrics_header_value):
method="GET",
url="http://" + fake_ip,
headers=MDS_PING_REQUEST_HEADER,
timeout=_metadata._METADATA_DEFAULT_TIMEOUT,
timeout=_metadata._METADATA_PING_DEFAULT_TIMEOUT,
)


Expand Down
Loading