diff --git a/identity/django.py b/identity/django.py index ae943ba..b8ccacd 100644 --- a/identity/django.py +++ b/identity/django.py @@ -34,7 +34,30 @@ class Auth(WebFrameworkAuth): your project's ``urlpatterns`` list in ``your_project/urls.py``. """ - def __init__(self, *args, **kwargs): + def __init__( + self, + *args, + post_logout_view: Optional[callable] = None, + **kwargs, + ): + """Initialize the Auth class for a Django web application. + + :param callable post_logout_view: + Optional. + If not provided, the user will be redirected to the root URL of the app. + + If provided, it shall be the view (which is a function) + that will be redirected to, after the user has logged out. + For example, you will typically use this parameter like this:: + + from . import public_views # This module shall NOT import settings.AUTH + auth = Auth( + ..., + post_logout_view=public_views.my_post_logout_view, + ) + + where ``my_post_logout_view`` is a Django view function. + """ super(Auth, self).__init__(*args, **kwargs) route, redirect_view = _parse_redirect_uri(self._redirect_uri) self.urlpattern = path(route, include([ @@ -46,6 +69,7 @@ def __init__(self, *args, **kwargs): self.auth_response, ), ])) + self._post_logout_view = post_logout_view def login( self, @@ -109,8 +133,9 @@ def logout(self, request): So you can use ``{% url "identity.django.logout" %}`` to get the url from inside a template. """ - return redirect( - self._build_auth(request.session).log_out(request.build_absolute_uri("/"))) + return redirect(self._build_auth(request.session).log_out(request.build_absolute_uri( + reverse(self._post_logout_view) if self._post_logout_view else "/" + ))) def login_required( # Named after Django's login_required self, diff --git a/identity/flask.py b/identity/flask.py index 222eec2..a5f7585 100644 --- a/identity/flask.py +++ b/identity/flask.py @@ -13,8 +13,14 @@ class Auth(PalletAuth): _Session = Session _redirect = redirect - def __init__(self, app: Optional[Flask], *args, **kwargs): - """Create an identity helper for a web application. + def __init__( + self, + app: Optional[Flask], + *args, + post_logout_view: Optional[callable] = None, + **kwargs, + ): + """Initialize the Auth class for a Flask web application. :param Flask app: It can be a Flask app instance, or ``None``. @@ -56,10 +62,17 @@ def build_app(): app = build_app() + :param callable post_logout_view: + Optional. + If not provided, the user will be redirected to the root URL of the app. + If provided, it shall be the view (which is a function) + that will be redirected to, after the user has logged out. + It also passes extra parameters to :class:`identity.web.WebFrameworkAuth`. """ self._request = request # Not available during class definition self._session = session # Not available during class definition + self._post_logout_view = post_logout_view super(Auth, self).__init__(app, *args, **kwargs) def _render_auth_error( # type: ignore[override] @@ -153,3 +166,8 @@ def call_an_api(*, context): """ return super(Auth, self).login_required(function, scopes=scopes) + def logout(self): + return super(Auth, self).logout(url_for( + self._post_logout_view.__name__, _external=True, + ) if self._post_logout_view else None) + diff --git a/identity/pallet.py b/identity/pallet.py index 5a4d065..f435810 100644 --- a/identity/pallet.py +++ b/identity/pallet.py @@ -62,9 +62,10 @@ def __getattribute__(self, name): "@auth.login_required() or auth.logout() etc.") return super(PalletAuth, self).__getattribute__(name) - def logout(self): + def logout(self, post_logout_redirect_uri: Optional[str] = None): return self.__class__._redirect( # self._redirect(...) won't work - self._auth.log_out(self._request.url_root)) + self._auth.log_out(post_logout_redirect_uri or self._request.url_root) + ) def login_required( # Named after Django's login_required self, diff --git a/identity/quart.py b/identity/quart.py index eca7b6d..45ef978 100644 --- a/identity/quart.py +++ b/identity/quart.py @@ -13,7 +13,13 @@ class Auth(PalletAuth): _Session = Session _redirect = redirect - def __init__(self, app: Optional[Quart], *args, **kwargs): + def __init__( + self, + app: Optional[Quart], + *args, + post_logout_view: Optional[callable] = None, + **kwargs, + ): """Create an identity helper for a web application. :param Quart app: @@ -56,10 +62,17 @@ def build_app(): app = build_app() + :param callable post_logout_view: + Optional. + If not provided, the user will be redirected to the root URL of the app. + If provided, it shall be the view (which is a function) + that will be redirected to, after the user has logged out. + It also passes extra parameters to :class:`identity.web.WebFrameworkAuth`. """ self._request = request # Not available during class definition self._session = session # Not available during class definition + self._post_logout_view = post_logout_view super(Auth, self).__init__(app, *args, **kwargs) async def _render_auth_error(self, *, error, error_description=None): @@ -152,3 +165,8 @@ async def call_api(*, context): """ return super(Auth, self).login_required(function, scopes=scopes) + def logout(self): + return super(Auth, self).logout(url_for( + self._post_logout_view.__name__, _external=True, + ) if self._post_logout_view else None) + diff --git a/identity/web.py b/identity/web.py index 5d5566f..48ffc9a 100644 --- a/identity/web.py +++ b/identity/web.py @@ -298,13 +298,13 @@ def _get_oidc_config(self): "%s not found from OIDC config: %s", self._END_SESSION_ENDPOINT, conf) return conf - def log_out(self, homepage): + def log_out(self, post_logout_redirect_uri: str) -> str: # The vocabulary is "log out" (rather than "sign out") in the specs # https://openid.net/specs/openid-connect-frontchannel-1_0.html """Logs out the user from current app. - :param str homepage: - The page to be redirected to, after the log-out. + :param str post_logout_redirect_uri: + The absolute uri of the page to be redirected to, after the log-out. In Flask, you can pass in ``url_for("index", _external=True)``. :return: @@ -318,11 +318,11 @@ def log_out(self, homepage): # but its default (i.e. v1.0) endpoint will sign out the (only?) account endpoint = self._get_oidc_config().get(self._END_SESSION_ENDPOINT) if endpoint: - return f"{endpoint}?post_logout_redirect_uri={homepage}" + return f"{endpoint}?post_logout_redirect_uri={post_logout_redirect_uri}" except requests.exceptions.RequestException: logger.exception("Failed to get OIDC config") - logger.warning("No end_session_endpoint found. Fallback to %s", homepage) - return homepage + logger.warning("No end_session_endpoint found. Fallback to %s", post_logout_redirect_uri) + return post_logout_redirect_uri # Fall back to this def get_token_for_client(self, scopes): """Get access token for the current app, with specified scopes. diff --git a/tests/test_django.py b/tests/test_django.py index 6028e32..0fcf348 100644 --- a/tests/test_django.py +++ b/tests/test_django.py @@ -2,9 +2,14 @@ from unittest import mock import pytest +from django.conf import settings from identity.django import _parse_redirect_uri, Auth +urlpatterns = [] # This is required for Django to recognize the URL patterns +settings.configure( + ROOT_URLCONF='test_django', # Set the root URL configuration +) def test_parse_redirect_uri(): with pytest.raises(ValueError): @@ -35,3 +40,24 @@ def test_the_installed_package_contains_builtin_templates(): templates_found.add(t) assert templates_needed == templates_found +def test_logout(): + request = mock.MagicMock( + build_absolute_uri=lambda relative_uri: f"http://localhost{relative_uri}" + ) + with mock.patch('identity.web.requests.get', new=mock.MagicMock( + return_value=mock.MagicMock( + json=mock.MagicMock(return_value={ + "end_session_endpoint": "https://example.com/end_session", + }), + status_code=200, + ) + )): + response = Auth("client_id").logout(request) + assert response.status_code == 302 + assert response.url == "https://example.com/end_session?post_logout_redirect_uri=http://localhost/" + + auth = Auth("client_id", post_logout_view=lambda r: "You have logged out") + with mock.patch('identity.django.reverse', return_value="/post_logout"): + response = auth.logout(request) + assert response.status_code == 302 + assert response.url == "https://example.com/end_session?post_logout_redirect_uri=http://localhost/post_logout" diff --git a/tests/test_flask.py b/tests/test_flask.py index a63e113..6761448 100644 --- a/tests/test_flask.py +++ b/tests/test_flask.py @@ -18,29 +18,44 @@ def app(): # https://flask.palletsprojects.com/en/3.0.x/testing/ yield app shutil.rmtree("flask_session") # clean up -@pytest.fixture() -def auth(app): +def build_auth(app, post_logout_view=None): return Auth( app, client_id="fake", redirect_uri="http://localhost:5000/redirect", # To use auth code flow oidc_authority="https://example.com/foo", + post_logout_view=post_logout_view, ) -def test_logout(app, auth): +@pytest.mark.parametrize("customize_post_logout,expected_post_logout_uri", [ + (False, "http://localhost/app_root/"), + (True, "http://localhost/app_root/my_post_logout_page"), +]) +def test_logout(app, customize_post_logout, expected_post_logout_uri): + + @app.route("/my_post_logout_page") + def post_logout_view(): + return "You have logged out" + + auth = build_auth( + app, + post_logout_view=post_logout_view if customize_post_logout else None, + ) with patch.object(auth._auth, "_get_oidc_config", new=Mock(return_value={ "end_session_endpoint": "https://example.com/end_session", })): with app.test_request_context("/", method="GET"): - homepage = "http://localhost/app_root" - assert homepage in auth.logout().get_data(as_text=True), ( - "The homepage should be in the logout URL. There was a bug in 0.9.0.") + assert ( + f"?post_logout_redirect_uri={expected_post_logout_uri}" + in auth.logout().get_data(as_text=True) + ), "The post-login uri should be in the logout page" @patch("msal.authority.tenant_discovery", new=Mock(return_value={ "authorization_endpoint": "https://example.com/placeholder", "token_endpoint": "https://example.com/placeholder", })) -def test_login(app, auth): +def test_login(app): + auth = build_auth(app) @app.route("/path") @auth.login_required