Skip to content

Commit 110e535

Browse files
committed
address feedback
1 parent cc457da commit 110e535

File tree

2 files changed

+26
-27
lines changed

2 files changed

+26
-27
lines changed

google/auth/compute_engine/_metadata.py

Lines changed: 19 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -55,20 +55,20 @@
5555
def _validate_gce_mds_configured_environment():
5656
"""Validates the GCE metadata server environment configuration for mTLS.
5757
58-
mTLS is only supported when connecting to the default metadata host.
58+
mTLS is only supported when connecting to the default metadata server hosts.
5959
If we are in strict mode (which requires mTLS), ensure that the metadata host
60-
has not been overridden (which means mTLS will fail).
60+
has not been overridden to a custom value (which means mTLS will fail).
6161
6262
Raises:
6363
google.auth.exceptions.MutualTLSChannelError: if the environment
6464
configuration is invalid for mTLS.
6565
"""
6666
mode = _mtls._parse_mds_mode()
6767
if mode == _mtls.MdsMtlsMode.STRICT:
68-
if _GCE_METADATA_HOST != _GCE_DEFAULT_HOST:
69-
# mTLS is only supported when connecting to the default metadata host.
70-
# Raise an exception if we are in strict mode (which requires mTLS)
71-
# but the metadata host has been overridden. (which means mTLS will fail)
68+
# mTLS is only supported when connecting to the default metadata host.
69+
# Raise an exception if we are in strict mode (which requires mTLS)
70+
# but the metadata host has been overridden to a custom MDS. (which means mTLS will fail)
71+
if _GCE_METADATA_HOST not in _GCE_DEFAULT_MDS_HOSTS:
7272
raise exceptions.MutualTLSChannelError(
7373
"Mutual TLS is required, but the metadata host has been overridden. "
7474
"mTLS is only supported when connecting to the default metadata host."
@@ -143,7 +143,7 @@ def detect_gce_residency_linux():
143143
return content.startswith(_GOOGLE)
144144

145145

146-
def _prepare_request_for_mds(request, use_mtls=False):
146+
def _prepare_request_for_mds(request, use_mtls=False) -> None:
147147
"""Prepares a request for the metadata server.
148148
149149
This will check if mTLS should be used and mount the mTLS adapter if needed.
@@ -158,15 +158,16 @@ def _prepare_request_for_mds(request, use_mtls=False):
158158
If mTLS is enabled, the request will have the mTLS adapter mounted.
159159
Otherwise, the original request will be returned unchanged.
160160
"""
161-
if not use_mtls:
162-
return request
161+
# Only modify the request if mTLS is enabled.
162+
if use_mtls:
163+
# Ensure the request has a session to mount the adapter to.
164+
if not request.session:
165+
request.session = requests.Session()
163166

164-
adapter = _mtls.MdsMtlsAdapter()
165-
if not request.session:
166-
request.session = requests.Session()
167-
for host in _GCE_DEFAULT_MDS_HOSTS:
168-
request.session.mount(f"https://{host}/", adapter)
169-
return request
167+
adapter = _mtls.MdsMtlsAdapter()
168+
# Mount the adapter for all default GCE metadata hosts.
169+
for host in _GCE_DEFAULT_MDS_HOSTS:
170+
request.session.mount(f"https://{host}/", adapter)
170171

171172

172173
def ping(request, timeout=_METADATA_DEFAULT_TIMEOUT, retry_count=3):
@@ -183,7 +184,7 @@ def ping(request, timeout=_METADATA_DEFAULT_TIMEOUT, retry_count=3):
183184
bool: True if the metadata server is reachable, False otherwise.
184185
"""
185186
use_mtls = _mtls.should_use_mds_mtls()
186-
request = _prepare_request_for_mds(request, use_mtls=use_mtls)
187+
_prepare_request_for_mds(request, use_mtls=use_mtls)
187188
# NOTE: The explicit ``timeout`` is a workaround. The underlying
188189
# issue is that resolving an unknown host on some networks will take
189190
# 20-30 seconds; making this timeout short fixes the issue, but
@@ -270,14 +271,14 @@ def get(
270271
use_mtls = _mtls.should_use_mds_mtls()
271272
# Prepare the request object for mTLS if needed.
272273
# This will create a new request object with the mTLS session.
273-
request = _prepare_request_for_mds(request, use_mtls=use_mtls)
274+
_prepare_request_for_mds(request, use_mtls=use_mtls)
274275

275276
if root is None:
276277
root = _get_metadata_root(use_mtls)
277278

278279
# mTLS is only supported when connecting to the default metadata host.
279280
# If we are in strict mode (which requires mTLS), ensure that the metadata host
280-
# has not been overridden (which means mTLS will fail).
281+
# has not been overridden to a non-default host value (which means mTLS will fail).
281282
_validate_gce_mds_configured_environment()
282283

283284
base_url = urljoin(root, path)

tests/compute_engine/test__metadata.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -722,8 +722,8 @@ def test__prepare_request_for_mds_mtls(mock_mds_mtls_adapter):
722722

723723
def test__prepare_request_for_mds_no_mtls():
724724
request = mock.Mock()
725-
new_request = _metadata._prepare_request_for_mds(request, use_mtls=False)
726-
assert new_request is request
725+
_metadata._prepare_request_for_mds(request, use_mtls=False)
726+
request.session.mount.assert_not_called()
727727

728728

729729
@mock.patch("google.auth.metrics.mds_ping", return_value=MDS_PING_METRICS_HEADER_VALUE)
@@ -776,9 +776,11 @@ def test_get_mtls(mock_request, mock_should_use_mtls, mock_mds_mtls_adapter):
776776
"mds_mode, metadata_host, expect_exception",
777777
[
778778
(_metadata._mtls.MdsMtlsMode.STRICT, _metadata._GCE_DEFAULT_HOST, False),
779+
(_metadata._mtls.MdsMtlsMode.STRICT, _metadata._GCE_DEFAULT_MDS_IP, False),
779780
(_metadata._mtls.MdsMtlsMode.STRICT, "custom.host", True),
780781
(_metadata._mtls.MdsMtlsMode.NONE, "custom.host", False),
781782
(_metadata._mtls.MdsMtlsMode.DEFAULT, _metadata._GCE_DEFAULT_HOST, False),
783+
(_metadata._mtls.MdsMtlsMode.DEFAULT, _metadata._GCE_DEFAULT_MDS_IP, False),
782784
],
783785
)
784786
@mock.patch("google.auth.compute_engine._mtls._parse_mds_mode")
@@ -801,11 +803,10 @@ def test_validate_gce_mds_configured_environment(
801803
def test__prepare_request_for_mds_mtls_session_exists(mock_mds_mtls_adapter):
802804
mock_session = mock.create_autospec(requests.Session)
803805
request = google_auth_requests.Request(mock_session)
804-
new_request = _metadata._prepare_request_for_mds(request, use_mtls=True)
806+
_metadata._prepare_request_for_mds(request, use_mtls=True)
805807

806808
mock_mds_mtls_adapter.assert_called_once()
807809
assert mock_session.mount.call_count == len(_metadata._GCE_DEFAULT_MDS_HOSTS)
808-
assert new_request is request
809810

810811

811812
@mock.patch("google.auth.compute_engine._mtls.MdsMtlsAdapter")
@@ -815,11 +816,8 @@ def test__prepare_request_for_mds_mtls_no_session(mock_mds_mtls_adapter):
815816
request.session = None
816817

817818
with mock.patch("requests.Session") as mock_session_class:
818-
new_request = _metadata._prepare_request_for_mds(request, use_mtls=True)
819+
_metadata._prepare_request_for_mds(request, use_mtls=True)
819820

820821
mock_session_class.assert_called_once()
821822
mock_mds_mtls_adapter.assert_called_once()
822-
assert new_request.session.mount.call_count == len(
823-
_metadata._GCE_DEFAULT_MDS_HOSTS
824-
)
825-
assert new_request is request
823+
assert request.session.mount.call_count == len(_metadata._GCE_DEFAULT_MDS_HOSTS)

0 commit comments

Comments
 (0)