1+ from unittest .mock import patch
2+
13from django .conf import settings
24from social_core .exceptions import AuthException
35
4- from ansible_base .authentication .middleware import SocialExceptionHandlerMiddleware
6+ from ansible_base .authentication .middleware import (
7+ AnsibleBaseCsrfViewMiddleware ,
8+ SocialExceptionHandlerMiddleware ,
9+ )
10+ from ansible_base .authentication .session import (
11+ AnsibleBaseCSRFCheck ,
12+ SessionAuthentication ,
13+ )
514
615
716def test_social_exception_handler_mw ():
@@ -21,3 +30,137 @@ def __init__(self):
2130 mw = SocialExceptionHandlerMiddleware (None )
2231 url = mw .get_redirect_uri (Request (), AuthException ("test" ))
2332 assert url == "/?auth_failed"
33+
34+
35+ def test_ansible_base_csrf_view_middleware_csrf_trusted_origins_hosts ():
36+ """Test that csrf_trusted_origins_hosts uses get_setting."""
37+ test_origins = ['https://example.com' , 'https://*.test.com' ]
38+
39+ with patch ('ansible_base.authentication.middleware.get_setting' ) as mock_get_setting :
40+ mock_get_setting .return_value = test_origins
41+
42+ middleware = AnsibleBaseCsrfViewMiddleware (lambda request : None )
43+ result = middleware .csrf_trusted_origins_hosts
44+
45+ mock_get_setting .assert_called_once_with ('CSRF_TRUSTED_ORIGINS' , [])
46+ # Should strip * from netloc
47+ assert result == ['example.com' , '.test.com' ]
48+
49+
50+ def test_ansible_base_csrf_view_middleware_allowed_origins_exact ():
51+ """Test that allowed_origins_exact uses get_setting."""
52+ test_origins = ['https://example.com' , 'https://*.test.com' ]
53+
54+ with patch ('ansible_base.authentication.middleware.get_setting' ) as mock_get_setting :
55+ mock_get_setting .return_value = test_origins
56+
57+ middleware = AnsibleBaseCsrfViewMiddleware (lambda request : None )
58+ result = middleware .allowed_origins_exact
59+
60+ mock_get_setting .assert_called_once_with ('CSRF_TRUSTED_ORIGINS' , [])
61+ # Should only include origins without *
62+ assert result == {'https://example.com' }
63+
64+
65+ def test_ansible_base_csrf_view_middleware_allowed_origin_subdomains ():
66+ """Test that allowed_origin_subdomains uses get_setting."""
67+ test_origins = ['https://*.example.com' , 'http://*.test.com' ]
68+
69+ with patch ('ansible_base.authentication.middleware.get_setting' ) as mock_get_setting :
70+ mock_get_setting .return_value = test_origins
71+
72+ middleware = AnsibleBaseCsrfViewMiddleware (lambda request : None )
73+ result = middleware .allowed_origin_subdomains
74+
75+ mock_get_setting .assert_called_once_with ('CSRF_TRUSTED_ORIGINS' , [])
76+ # Should group by scheme and strip *
77+ expected = {'https' : ['.example.com' ], 'http' : ['.test.com' ]}
78+ assert dict (result ) == expected
79+
80+
81+ def test_ansible_base_csrf_view_middleware_default_value ():
82+ """Test that middleware returns empty/default values when setting is empty."""
83+ with patch ('ansible_base.authentication.middleware.get_setting' ) as mock_get_setting :
84+ mock_get_setting .return_value = []
85+
86+ middleware = AnsibleBaseCsrfViewMiddleware (lambda request : None )
87+
88+ # Test all three properties
89+ assert middleware .csrf_trusted_origins_hosts == []
90+ assert middleware .allowed_origins_exact == set ()
91+ assert dict (middleware .allowed_origin_subdomains ) == {}
92+
93+ # get_setting should be called three times (once for each property)
94+ assert mock_get_setting .call_count == 3
95+
96+
97+ def test_ansible_base_csrf_check_inherits_from_ansible_base_csrf_view_middleware ():
98+ """Test that AnsibleBaseCSRFCheck inherits from AnsibleBaseCsrfViewMiddleware."""
99+ csrf_check = AnsibleBaseCSRFCheck (lambda request : None )
100+ assert isinstance (csrf_check , AnsibleBaseCsrfViewMiddleware )
101+
102+
103+ def test_ansible_base_csrf_check_reject_method ():
104+ """Test that AnsibleBaseCSRFCheck._reject returns the reason."""
105+ csrf_check = AnsibleBaseCSRFCheck (lambda request : None )
106+ reason = "Test CSRF failure reason"
107+ result = csrf_check ._reject (None , reason )
108+ assert result == reason
109+
110+
111+ def test_session_authentication_uses_ansible_base_csrf_check ():
112+ """Test that SessionAuthentication uses AnsibleBaseCSRFCheck for CSRF validation."""
113+ from unittest .mock import MagicMock , Mock
114+
115+ # Create a mock request with an authenticated user
116+ mock_request = Mock ()
117+ mock_request ._request = Mock ()
118+ mock_request ._request .user = Mock ()
119+ mock_request ._request .user .is_active = True
120+
121+ # Mock the AnsibleBaseCSRFCheck to track its usage
122+ with patch ('ansible_base.authentication.session.AnsibleBaseCSRFCheck' ) as mock_csrf_check_class :
123+ mock_csrf_check = Mock ()
124+ mock_csrf_check .process_request .return_value = None
125+ mock_csrf_check .process_view .return_value = None # No CSRF error
126+ mock_csrf_check_class .return_value = mock_csrf_check
127+
128+ # Create SessionAuthentication instance and call enforce_csrf
129+ session_auth = SessionAuthentication ()
130+ session_auth .enforce_csrf (mock_request )
131+
132+ # Verify AnsibleBaseCSRFCheck was instantiated
133+ mock_csrf_check_class .assert_called_once ()
134+
135+ # Verify process_request and process_view were called
136+ mock_csrf_check .process_request .assert_called_once_with (mock_request )
137+ mock_csrf_check .process_view .assert_called_once_with (mock_request , None , (), {})
138+
139+
140+ def test_session_authentication_csrf_failure_raises_permission_denied ():
141+ """Test that SessionAuthentication raises PermissionDenied when CSRF fails."""
142+ from rest_framework .exceptions import PermissionDenied
143+ from unittest .mock import Mock
144+
145+ # Create a mock request with an authenticated user
146+ mock_request = Mock ()
147+ mock_request ._request = Mock ()
148+ mock_request ._request .user = Mock ()
149+ mock_request ._request .user .is_active = True
150+
151+ # Mock the AnsibleBaseCSRFCheck to return a CSRF failure reason
152+ with patch ('ansible_base.authentication.session.AnsibleBaseCSRFCheck' ) as mock_csrf_check_class :
153+ mock_csrf_check = Mock ()
154+ mock_csrf_check .process_request .return_value = None
155+ mock_csrf_check .process_view .return_value = "CSRF token missing" # CSRF error
156+ mock_csrf_check_class .return_value = mock_csrf_check
157+
158+ # Create SessionAuthentication instance and call enforce_csrf
159+ session_auth = SessionAuthentication ()
160+
161+ # Should raise PermissionDenied with the CSRF failure reason
162+ try :
163+ session_auth .enforce_csrf (mock_request )
164+ assert False , "Expected PermissionDenied to be raised"
165+ except PermissionDenied as e :
166+ assert "CSRF Failed: CSRF token missing" in str (e )
0 commit comments