From f9005144333b48eb140c4ffed81eff460d5b1d84 Mon Sep 17 00:00:00 2001 From: Hao Liu <44379968+TheRealHaoLiu@users.noreply.github.com> Date: Fri, 6 Jun 2025 10:46:05 -0400 Subject: [PATCH 1/3] Add AnsibleBaseCsrfViewMiddleware and update SessionAuthentication to use get_setting - Add AnsibleBaseCsrfViewMiddleware that reads CSRF_TRUSTED_ORIGINS using ansible_base.lib.utils.settings.get_setting instead of directly from Django settings - Override all three cached properties (csrf_trusted_origins_hosts, allowed_origins_exact, allowed_origin_subdomains) to use get_setting for dynamic configuration - Add AnsibleBaseCSRFCheck class that inherits from AnsibleBaseCsrfViewMiddleware - Modify SessionAuthentication.enforce_csrf to use AnsibleBaseCSRFCheck instead of Django's default CSRFCheck - Add comprehensive tests for both middleware and session authentication CSRF functionality - Enables CSRF_TRUSTED_ORIGINS to be dynamically loaded from various sources via ANSIBLE_BASE_SETTINGS_FUNCTION while maintaining backward compatibility --- ansible_base/authentication/middleware.py | 48 ++++++ ansible_base/authentication/session.py | 41 ++++- .../tests/authentication/test_middleware.py | 146 +++++++++++++++++- 3 files changed, 233 insertions(+), 2 deletions(-) diff --git a/ansible_base/authentication/middleware.py b/ansible_base/authentication/middleware.py index 33d011b1e..dad1e21e7 100644 --- a/ansible_base/authentication/middleware.py +++ b/ansible_base/authentication/middleware.py @@ -1,11 +1,16 @@ import logging +from collections import defaultdict +from urllib.parse import urlsplit from django.contrib.auth import BACKEND_SESSION_KEY from django.core.exceptions import ImproperlyConfigured +from django.middleware.csrf import CsrfViewMiddleware from django.utils.deprecation import MiddlewareMixin +from django.utils.functional import cached_property from social_django.middleware import SocialAuthExceptionMiddleware from ansible_base.authentication.authenticator_plugins.utils import get_authenticator_plugins +from ansible_base.lib.utils.settings import get_setting logger = logging.getLogger('ansible_base.authentication.middleware') @@ -51,3 +56,46 @@ def get_redirect_uri(self, request, exception): backend_name = getattr(backend, "name", "unknown-backend") logger.error(f"Auth failure for backend {backend_name} - {repr(exception)}, redirecting to {error_url}") return error_url + + +class AnsibleBaseCsrfViewMiddleware(CsrfViewMiddleware): + """ + CsrfViewMiddleware subclass that reads CSRF_TRUSTED_ORIGINS using + ansible_base.lib.utils.settings.get_setting instead of directly from + Django settings. + + This allows the setting to be dynamically loaded from various sources + as configured by the ANSIBLE_BASE_SETTINGS_FUNCTION setting. + + Overrides all cached properties that access settings.CSRF_TRUSTED_ORIGINS + to use get_setting instead. + """ + + @cached_property + def csrf_trusted_origins_hosts(self): + """ + Override to use get_setting instead of settings.CSRF_TRUSTED_ORIGINS. + """ + csrf_trusted_origins = get_setting('CSRF_TRUSTED_ORIGINS', []) + return [urlsplit(origin).netloc.lstrip("*") for origin in csrf_trusted_origins] + + @cached_property + def allowed_origins_exact(self): + """ + Override to use get_setting instead of settings.CSRF_TRUSTED_ORIGINS. + """ + csrf_trusted_origins = get_setting('CSRF_TRUSTED_ORIGINS', []) + return {origin for origin in csrf_trusted_origins if "*" not in origin} + + @cached_property + def allowed_origin_subdomains(self): + """ + Override to use get_setting instead of settings.CSRF_TRUSTED_ORIGINS. + A mapping of allowed schemes to list of allowed netlocs, where all + subdomains of the netloc are allowed. + """ + csrf_trusted_origins = get_setting('CSRF_TRUSTED_ORIGINS', []) + allowed_origin_subdomains = defaultdict(list) + for parsed in (urlsplit(origin) for origin in csrf_trusted_origins if "*" in origin): + allowed_origin_subdomains[parsed.scheme].append(parsed.netloc.lstrip("*")) + return allowed_origin_subdomains diff --git a/ansible_base/authentication/session.py b/ansible_base/authentication/session.py index a9eb63f79..b867cc211 100644 --- a/ansible_base/authentication/session.py +++ b/ansible_base/authentication/session.py @@ -1,10 +1,49 @@ -from rest_framework import authentication +from rest_framework import authentication, exceptions + +from ansible_base.authentication.middleware import ( + AnsibleBaseCsrfViewMiddleware, +) + + +class AnsibleBaseCSRFCheck(AnsibleBaseCsrfViewMiddleware): + """ + Custom CSRF check class that uses AnsibleBaseCsrfViewMiddleware + instead of Django's CsrfViewMiddleware for CSRF validation. + + This ensures that CSRF_TRUSTED_ORIGINS is read using get_setting + instead of directly from Django settings. + """ + + def _reject(self, request, reason): + # Return the failure reason instead of an HttpResponse + return reason class SessionAuthentication(authentication.SessionAuthentication): """ This class allows us to fail with a 401 if the user is not authenticated. + + Uses AnsibleBaseCsrfViewMiddleware for CSRF checking instead of Django's + default CsrfViewMiddleware, allowing CSRF_TRUSTED_ORIGINS to be read + dynamically using get_setting. """ def authenticate_header(self, request): return "Session" + + def enforce_csrf(self, request): + """ + Enforce CSRF validation for session based authentication using + AnsibleBaseCsrfViewMiddleware instead of Django's CsrfViewMiddleware. + """ + + def dummy_get_response(request): # pragma: no cover + return None + + check = AnsibleBaseCSRFCheck(dummy_get_response) + # populates request.META['CSRF_COOKIE'], which is used in process_view() + check.process_request(request) + reason = check.process_view(request, None, (), {}) + if reason: + # CSRF failed, bail with explicit error message + raise exceptions.PermissionDenied('CSRF Failed: %s' % reason) diff --git a/test_app/tests/authentication/test_middleware.py b/test_app/tests/authentication/test_middleware.py index 627bee931..932c99d06 100644 --- a/test_app/tests/authentication/test_middleware.py +++ b/test_app/tests/authentication/test_middleware.py @@ -1,7 +1,16 @@ +from unittest.mock import patch + from django.conf import settings from social_core.exceptions import AuthException -from ansible_base.authentication.middleware import SocialExceptionHandlerMiddleware +from ansible_base.authentication.middleware import ( + AnsibleBaseCsrfViewMiddleware, + SocialExceptionHandlerMiddleware, +) +from ansible_base.authentication.session import ( + AnsibleBaseCSRFCheck, + SessionAuthentication, +) def test_social_exception_handler_mw(): @@ -21,3 +30,138 @@ def __init__(self): mw = SocialExceptionHandlerMiddleware(None) url = mw.get_redirect_uri(Request(), AuthException("test")) assert url == "/?auth_failed" + + +def test_ansible_base_csrf_view_middleware_csrf_trusted_origins_hosts(): + """Test that csrf_trusted_origins_hosts uses get_setting.""" + test_origins = ['https://example.com', 'https://*.test.com'] + + with patch('ansible_base.authentication.middleware.get_setting') as mock_get_setting: + mock_get_setting.return_value = test_origins + + middleware = AnsibleBaseCsrfViewMiddleware(lambda request: None) + result = middleware.csrf_trusted_origins_hosts + + mock_get_setting.assert_called_once_with('CSRF_TRUSTED_ORIGINS', []) + # Should strip * from netloc + assert result == ['example.com', '.test.com'] + + +def test_ansible_base_csrf_view_middleware_allowed_origins_exact(): + """Test that allowed_origins_exact uses get_setting.""" + test_origins = ['https://example.com', 'https://*.test.com'] + + with patch('ansible_base.authentication.middleware.get_setting') as mock_get_setting: + mock_get_setting.return_value = test_origins + + middleware = AnsibleBaseCsrfViewMiddleware(lambda request: None) + result = middleware.allowed_origins_exact + + mock_get_setting.assert_called_once_with('CSRF_TRUSTED_ORIGINS', []) + # Should only include origins without * + assert result == {'https://example.com'} + + +def test_ansible_base_csrf_view_middleware_allowed_origin_subdomains(): + """Test that allowed_origin_subdomains uses get_setting.""" + test_origins = ['https://*.example.com', 'http://*.test.com'] + + with patch('ansible_base.authentication.middleware.get_setting') as mock_get_setting: + mock_get_setting.return_value = test_origins + + middleware = AnsibleBaseCsrfViewMiddleware(lambda request: None) + result = middleware.allowed_origin_subdomains + + mock_get_setting.assert_called_once_with('CSRF_TRUSTED_ORIGINS', []) + # Should group by scheme and strip * + expected = {'https': ['.example.com'], 'http': ['.test.com']} + assert dict(result) == expected + + +def test_ansible_base_csrf_view_middleware_default_value(): + """Test that middleware returns empty/default values when setting is empty.""" + with patch('ansible_base.authentication.middleware.get_setting') as mock_get_setting: + mock_get_setting.return_value = [] + + middleware = AnsibleBaseCsrfViewMiddleware(lambda request: None) + + # Test all three properties + assert middleware.csrf_trusted_origins_hosts == [] + assert middleware.allowed_origins_exact == set() + assert dict(middleware.allowed_origin_subdomains) == {} + + # get_setting should be called three times (once for each property) + assert mock_get_setting.call_count == 3 + + +def test_ansible_base_csrf_check_inherits_from_ansible_base_csrf_view_middleware(): + """Test that AnsibleBaseCSRFCheck inherits from AnsibleBaseCsrfViewMiddleware.""" + csrf_check = AnsibleBaseCSRFCheck(lambda request: None) + assert isinstance(csrf_check, AnsibleBaseCsrfViewMiddleware) + + +def test_ansible_base_csrf_check_reject_method(): + """Test that AnsibleBaseCSRFCheck._reject returns the reason.""" + csrf_check = AnsibleBaseCSRFCheck(lambda request: None) + reason = "Test CSRF failure reason" + result = csrf_check._reject(None, reason) + assert result == reason + + +def test_session_authentication_uses_ansible_base_csrf_check(): + """Test that SessionAuthentication uses AnsibleBaseCSRFCheck for CSRF validation.""" + from unittest.mock import Mock + + # Create a mock request with an authenticated user + mock_request = Mock() + mock_request._request = Mock() + mock_request._request.user = Mock() + mock_request._request.user.is_active = True + + # Mock the AnsibleBaseCSRFCheck to track its usage + with patch('ansible_base.authentication.session.AnsibleBaseCSRFCheck') as mock_csrf_check_class: + mock_csrf_check = Mock() + mock_csrf_check.process_request.return_value = None + mock_csrf_check.process_view.return_value = None # No CSRF error + mock_csrf_check_class.return_value = mock_csrf_check + + # Create SessionAuthentication instance and call enforce_csrf + session_auth = SessionAuthentication() + session_auth.enforce_csrf(mock_request) + + # Verify AnsibleBaseCSRFCheck was instantiated + mock_csrf_check_class.assert_called_once() + + # Verify process_request and process_view were called + mock_csrf_check.process_request.assert_called_once_with(mock_request) + mock_csrf_check.process_view.assert_called_once_with(mock_request, None, (), {}) + + +def test_session_authentication_csrf_failure_raises_permission_denied(): + """Test that SessionAuthentication raises PermissionDenied when CSRF fails.""" + from unittest.mock import Mock + + from rest_framework.exceptions import PermissionDenied + + # Create a mock request with an authenticated user + mock_request = Mock() + mock_request._request = Mock() + mock_request._request.user = Mock() + mock_request._request.user.is_active = True + + # Mock the AnsibleBaseCSRFCheck to return a CSRF failure reason + with patch('ansible_base.authentication.session.AnsibleBaseCSRFCheck') as mock_csrf_check_class: + mock_csrf_check = Mock() + mock_csrf_check.process_request.return_value = None + mock_csrf_check.process_view.return_value = "CSRF token missing" # CSRF error + mock_csrf_check_class.return_value = mock_csrf_check + + # Create SessionAuthentication instance and call enforce_csrf + session_auth = SessionAuthentication() + + # Should raise PermissionDenied with the CSRF failure reason + try: + session_auth.enforce_csrf(mock_request) + assert False, "Expected PermissionDenied to be raised" + except PermissionDenied as e: + assert "CSRF Failed: CSRF token missing" in str(e) From b3d6883cc6e3ddd6b5169e940e9fbe4cd9856c7a Mon Sep 17 00:00:00 2001 From: Brennan Paciorek Date: Mon, 7 Jul 2025 18:20:13 -0400 Subject: [PATCH 2/3] AAP-47499 attempt a different approach to this problem --- ansible_base/authentication/middleware.py | 48 ------ ansible_base/authentication/session.py | 40 ++--- .../tests/authentication/test_middleware.py | 142 ------------------ 3 files changed, 11 insertions(+), 219 deletions(-) diff --git a/ansible_base/authentication/middleware.py b/ansible_base/authentication/middleware.py index dad1e21e7..33d011b1e 100644 --- a/ansible_base/authentication/middleware.py +++ b/ansible_base/authentication/middleware.py @@ -1,16 +1,11 @@ import logging -from collections import defaultdict -from urllib.parse import urlsplit from django.contrib.auth import BACKEND_SESSION_KEY from django.core.exceptions import ImproperlyConfigured -from django.middleware.csrf import CsrfViewMiddleware from django.utils.deprecation import MiddlewareMixin -from django.utils.functional import cached_property from social_django.middleware import SocialAuthExceptionMiddleware from ansible_base.authentication.authenticator_plugins.utils import get_authenticator_plugins -from ansible_base.lib.utils.settings import get_setting logger = logging.getLogger('ansible_base.authentication.middleware') @@ -56,46 +51,3 @@ def get_redirect_uri(self, request, exception): backend_name = getattr(backend, "name", "unknown-backend") logger.error(f"Auth failure for backend {backend_name} - {repr(exception)}, redirecting to {error_url}") return error_url - - -class AnsibleBaseCsrfViewMiddleware(CsrfViewMiddleware): - """ - CsrfViewMiddleware subclass that reads CSRF_TRUSTED_ORIGINS using - ansible_base.lib.utils.settings.get_setting instead of directly from - Django settings. - - This allows the setting to be dynamically loaded from various sources - as configured by the ANSIBLE_BASE_SETTINGS_FUNCTION setting. - - Overrides all cached properties that access settings.CSRF_TRUSTED_ORIGINS - to use get_setting instead. - """ - - @cached_property - def csrf_trusted_origins_hosts(self): - """ - Override to use get_setting instead of settings.CSRF_TRUSTED_ORIGINS. - """ - csrf_trusted_origins = get_setting('CSRF_TRUSTED_ORIGINS', []) - return [urlsplit(origin).netloc.lstrip("*") for origin in csrf_trusted_origins] - - @cached_property - def allowed_origins_exact(self): - """ - Override to use get_setting instead of settings.CSRF_TRUSTED_ORIGINS. - """ - csrf_trusted_origins = get_setting('CSRF_TRUSTED_ORIGINS', []) - return {origin for origin in csrf_trusted_origins if "*" not in origin} - - @cached_property - def allowed_origin_subdomains(self): - """ - Override to use get_setting instead of settings.CSRF_TRUSTED_ORIGINS. - A mapping of allowed schemes to list of allowed netlocs, where all - subdomains of the netloc are allowed. - """ - csrf_trusted_origins = get_setting('CSRF_TRUSTED_ORIGINS', []) - allowed_origin_subdomains = defaultdict(list) - for parsed in (urlsplit(origin) for origin in csrf_trusted_origins if "*" in origin): - allowed_origin_subdomains[parsed.scheme].append(parsed.netloc.lstrip("*")) - return allowed_origin_subdomains diff --git a/ansible_base/authentication/session.py b/ansible_base/authentication/session.py index b867cc211..06d311c25 100644 --- a/ansible_base/authentication/session.py +++ b/ansible_base/authentication/session.py @@ -1,22 +1,7 @@ -from rest_framework import authentication, exceptions +from django.conf import settings +from rest_framework import authentication -from ansible_base.authentication.middleware import ( - AnsibleBaseCsrfViewMiddleware, -) - - -class AnsibleBaseCSRFCheck(AnsibleBaseCsrfViewMiddleware): - """ - Custom CSRF check class that uses AnsibleBaseCsrfViewMiddleware - instead of Django's CsrfViewMiddleware for CSRF validation. - - This ensures that CSRF_TRUSTED_ORIGINS is read using get_setting - instead of directly from Django settings. - """ - - def _reject(self, request, reason): - # Return the failure reason instead of an HttpResponse - return reason +from ansible_base.lib.utils.settings import get_setting class SessionAuthentication(authentication.SessionAuthentication): @@ -36,14 +21,11 @@ def enforce_csrf(self, request): Enforce CSRF validation for session based authentication using AnsibleBaseCsrfViewMiddleware instead of Django's CsrfViewMiddleware. """ - - def dummy_get_response(request): # pragma: no cover - return None - - check = AnsibleBaseCSRFCheck(dummy_get_response) - # populates request.META['CSRF_COOKIE'], which is used in process_view() - check.process_request(request) - reason = check.process_view(request, None, (), {}) - if reason: - # CSRF failed, bail with explicit error message - raise exceptions.PermissionDenied('CSRF Failed: %s' % reason) + csrf_trusted_origins = settings.CSRF_TRUSTED_ORIGINS + try: + # Temporarily patch the setting + settings.CSRF_TRUSTED_ORIGINS = get_setting("CSRF_TRUSTED_ORIGINS", csrf_trusted_origins) + return super().enforce_csrf(request) + finally: + # Revert setting after this is done + settings.CSRF_TRUSTED_ORIGINS = csrf_trusted_origins diff --git a/test_app/tests/authentication/test_middleware.py b/test_app/tests/authentication/test_middleware.py index 932c99d06..744afc3e4 100644 --- a/test_app/tests/authentication/test_middleware.py +++ b/test_app/tests/authentication/test_middleware.py @@ -1,16 +1,9 @@ -from unittest.mock import patch - from django.conf import settings from social_core.exceptions import AuthException from ansible_base.authentication.middleware import ( - AnsibleBaseCsrfViewMiddleware, SocialExceptionHandlerMiddleware, ) -from ansible_base.authentication.session import ( - AnsibleBaseCSRFCheck, - SessionAuthentication, -) def test_social_exception_handler_mw(): @@ -30,138 +23,3 @@ def __init__(self): mw = SocialExceptionHandlerMiddleware(None) url = mw.get_redirect_uri(Request(), AuthException("test")) assert url == "/?auth_failed" - - -def test_ansible_base_csrf_view_middleware_csrf_trusted_origins_hosts(): - """Test that csrf_trusted_origins_hosts uses get_setting.""" - test_origins = ['https://example.com', 'https://*.test.com'] - - with patch('ansible_base.authentication.middleware.get_setting') as mock_get_setting: - mock_get_setting.return_value = test_origins - - middleware = AnsibleBaseCsrfViewMiddleware(lambda request: None) - result = middleware.csrf_trusted_origins_hosts - - mock_get_setting.assert_called_once_with('CSRF_TRUSTED_ORIGINS', []) - # Should strip * from netloc - assert result == ['example.com', '.test.com'] - - -def test_ansible_base_csrf_view_middleware_allowed_origins_exact(): - """Test that allowed_origins_exact uses get_setting.""" - test_origins = ['https://example.com', 'https://*.test.com'] - - with patch('ansible_base.authentication.middleware.get_setting') as mock_get_setting: - mock_get_setting.return_value = test_origins - - middleware = AnsibleBaseCsrfViewMiddleware(lambda request: None) - result = middleware.allowed_origins_exact - - mock_get_setting.assert_called_once_with('CSRF_TRUSTED_ORIGINS', []) - # Should only include origins without * - assert result == {'https://example.com'} - - -def test_ansible_base_csrf_view_middleware_allowed_origin_subdomains(): - """Test that allowed_origin_subdomains uses get_setting.""" - test_origins = ['https://*.example.com', 'http://*.test.com'] - - with patch('ansible_base.authentication.middleware.get_setting') as mock_get_setting: - mock_get_setting.return_value = test_origins - - middleware = AnsibleBaseCsrfViewMiddleware(lambda request: None) - result = middleware.allowed_origin_subdomains - - mock_get_setting.assert_called_once_with('CSRF_TRUSTED_ORIGINS', []) - # Should group by scheme and strip * - expected = {'https': ['.example.com'], 'http': ['.test.com']} - assert dict(result) == expected - - -def test_ansible_base_csrf_view_middleware_default_value(): - """Test that middleware returns empty/default values when setting is empty.""" - with patch('ansible_base.authentication.middleware.get_setting') as mock_get_setting: - mock_get_setting.return_value = [] - - middleware = AnsibleBaseCsrfViewMiddleware(lambda request: None) - - # Test all three properties - assert middleware.csrf_trusted_origins_hosts == [] - assert middleware.allowed_origins_exact == set() - assert dict(middleware.allowed_origin_subdomains) == {} - - # get_setting should be called three times (once for each property) - assert mock_get_setting.call_count == 3 - - -def test_ansible_base_csrf_check_inherits_from_ansible_base_csrf_view_middleware(): - """Test that AnsibleBaseCSRFCheck inherits from AnsibleBaseCsrfViewMiddleware.""" - csrf_check = AnsibleBaseCSRFCheck(lambda request: None) - assert isinstance(csrf_check, AnsibleBaseCsrfViewMiddleware) - - -def test_ansible_base_csrf_check_reject_method(): - """Test that AnsibleBaseCSRFCheck._reject returns the reason.""" - csrf_check = AnsibleBaseCSRFCheck(lambda request: None) - reason = "Test CSRF failure reason" - result = csrf_check._reject(None, reason) - assert result == reason - - -def test_session_authentication_uses_ansible_base_csrf_check(): - """Test that SessionAuthentication uses AnsibleBaseCSRFCheck for CSRF validation.""" - from unittest.mock import Mock - - # Create a mock request with an authenticated user - mock_request = Mock() - mock_request._request = Mock() - mock_request._request.user = Mock() - mock_request._request.user.is_active = True - - # Mock the AnsibleBaseCSRFCheck to track its usage - with patch('ansible_base.authentication.session.AnsibleBaseCSRFCheck') as mock_csrf_check_class: - mock_csrf_check = Mock() - mock_csrf_check.process_request.return_value = None - mock_csrf_check.process_view.return_value = None # No CSRF error - mock_csrf_check_class.return_value = mock_csrf_check - - # Create SessionAuthentication instance and call enforce_csrf - session_auth = SessionAuthentication() - session_auth.enforce_csrf(mock_request) - - # Verify AnsibleBaseCSRFCheck was instantiated - mock_csrf_check_class.assert_called_once() - - # Verify process_request and process_view were called - mock_csrf_check.process_request.assert_called_once_with(mock_request) - mock_csrf_check.process_view.assert_called_once_with(mock_request, None, (), {}) - - -def test_session_authentication_csrf_failure_raises_permission_denied(): - """Test that SessionAuthentication raises PermissionDenied when CSRF fails.""" - from unittest.mock import Mock - - from rest_framework.exceptions import PermissionDenied - - # Create a mock request with an authenticated user - mock_request = Mock() - mock_request._request = Mock() - mock_request._request.user = Mock() - mock_request._request.user.is_active = True - - # Mock the AnsibleBaseCSRFCheck to return a CSRF failure reason - with patch('ansible_base.authentication.session.AnsibleBaseCSRFCheck') as mock_csrf_check_class: - mock_csrf_check = Mock() - mock_csrf_check.process_request.return_value = None - mock_csrf_check.process_view.return_value = "CSRF token missing" # CSRF error - mock_csrf_check_class.return_value = mock_csrf_check - - # Create SessionAuthentication instance and call enforce_csrf - session_auth = SessionAuthentication() - - # Should raise PermissionDenied with the CSRF failure reason - try: - session_auth.enforce_csrf(mock_request) - assert False, "Expected PermissionDenied to be raised" - except PermissionDenied as e: - assert "CSRF Failed: CSRF token missing" in str(e) From f1ade289e549d00f5a4ab3836945a010069846cc Mon Sep 17 00:00:00 2001 From: Brennan Paciorek Date: Tue, 8 Jul 2025 10:58:15 -0400 Subject: [PATCH 3/3] fix comments, move setings patch to decorator --- ansible_base/authentication/session.py | 18 +++++------------- ansible_base/lib/utils/settings.py | 18 ++++++++++++++++++ 2 files changed, 23 insertions(+), 13 deletions(-) diff --git a/ansible_base/authentication/session.py b/ansible_base/authentication/session.py index 06d311c25..24ffd89a9 100644 --- a/ansible_base/authentication/session.py +++ b/ansible_base/authentication/session.py @@ -1,31 +1,23 @@ -from django.conf import settings from rest_framework import authentication -from ansible_base.lib.utils.settings import get_setting +from ansible_base.lib.utils.settings import replace_trusted_origins class SessionAuthentication(authentication.SessionAuthentication): """ This class allows us to fail with a 401 if the user is not authenticated. - Uses AnsibleBaseCsrfViewMiddleware for CSRF checking instead of Django's - default CsrfViewMiddleware, allowing CSRF_TRUSTED_ORIGINS to be read - dynamically using get_setting. + Allows CSRF_TRUSTED_ORIGINS to be read dynamically using get_setting. + Reverting the value of CSRF_TRUSTED_ORIGINS afterwards. """ def authenticate_header(self, request): return "Session" + @replace_trusted_origins def enforce_csrf(self, request): """ Enforce CSRF validation for session based authentication using AnsibleBaseCsrfViewMiddleware instead of Django's CsrfViewMiddleware. """ - csrf_trusted_origins = settings.CSRF_TRUSTED_ORIGINS - try: - # Temporarily patch the setting - settings.CSRF_TRUSTED_ORIGINS = get_setting("CSRF_TRUSTED_ORIGINS", csrf_trusted_origins) - return super().enforce_csrf(request) - finally: - # Revert setting after this is done - settings.CSRF_TRUSTED_ORIGINS = csrf_trusted_origins + return super().enforce_csrf(request) diff --git a/ansible_base/lib/utils/settings.py b/ansible_base/lib/utils/settings.py index e71bd96ee..b79d8a93e 100644 --- a/ansible_base/lib/utils/settings.py +++ b/ansible_base/lib/utils/settings.py @@ -49,6 +49,24 @@ def get_function_from_setting(setting_name: str) -> Any: return None +def replace_trusted_origins(func): + """Decorator for patching the CSRF_TRUSTED_ORIGINS django setting using the potentially different value in get_setting for the duration of a + function call + """ + + def override_setting(*args, **kwargs): + csrf_trusted_origins = settings.CSRF_TRUSTED_ORIGINS + try: + # Temporarily patch the setting + settings.CSRF_TRUSTED_ORIGINS = get_setting("CSRF_TRUSTED_ORIGINS", csrf_trusted_origins) + return func(*args, **kwargs) + finally: + # Revert setting after this is done + settings.CSRF_TRUSTED_ORIGINS = csrf_trusted_origins + + return override_setting + + def get_from_import(module_name, attr): "Thin wrapper around importlib.import_module, mostly exists so that we can safely mock this in tests" module = importlib.import_module(module_name, package=attr)