Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 28 additions & 3 deletions identity/django.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,30 @@ class Auth(WebFrameworkAuth):
your project's ``urlpatterns`` list in ``your_project/urls.py``.
"""

def __init__(self, *args, **kwargs):
def __init__(
self,
*args,
post_logout_view: Optional[callable] = None,
**kwargs,
):
"""Initialize the Auth class for a Django web application.

:param callable post_logout_view:
Optional.
If not provided, the user will be redirected to the root URL of the app.

If provided, it shall be the view (which is a function)
that will be redirected to, after the user has logged out.
For example, you will typically use this parameter like this::

from . import public_views # This module shall NOT import settings.AUTH
auth = Auth(
...,
post_logout_view=public_views.my_post_logout_view,
)

where ``my_post_logout_view`` is a Django view function.
"""
super(Auth, self).__init__(*args, **kwargs)
route, redirect_view = _parse_redirect_uri(self._redirect_uri)
self.urlpattern = path(route, include([
Expand All @@ -46,6 +69,7 @@ def __init__(self, *args, **kwargs):
self.auth_response,
),
]))
self._post_logout_view = post_logout_view

def login(
self,
Expand Down Expand Up @@ -109,8 +133,9 @@ def logout(self, request):
So you can use ``{% url "identity.django.logout" %}`` to get the url
from inside a template.
"""
return redirect(
self._build_auth(request.session).log_out(request.build_absolute_uri("/")))
return redirect(self._build_auth(request.session).log_out(request.build_absolute_uri(
reverse(self._post_logout_view) if self._post_logout_view else "/"
)))

def login_required( # Named after Django's login_required
self,
Expand Down
22 changes: 20 additions & 2 deletions identity/flask.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,14 @@ class Auth(PalletAuth):
_Session = Session
_redirect = redirect

def __init__(self, app: Optional[Flask], *args, **kwargs):
"""Create an identity helper for a web application.
def __init__(
self,
app: Optional[Flask],
*args,
post_logout_view: Optional[callable] = None,
**kwargs,
):
"""Initialize the Auth class for a Flask web application.

:param Flask app:
It can be a Flask app instance, or ``None``.
Expand Down Expand Up @@ -56,10 +62,17 @@ def build_app():

app = build_app()

:param callable post_logout_view:
Optional.
If not provided, the user will be redirected to the root URL of the app.
If provided, it shall be the view (which is a function)
that will be redirected to, after the user has logged out.

It also passes extra parameters to :class:`identity.web.WebFrameworkAuth`.
"""
self._request = request # Not available during class definition
self._session = session # Not available during class definition
self._post_logout_view = post_logout_view
super(Auth, self).__init__(app, *args, **kwargs)

def _render_auth_error( # type: ignore[override]
Expand Down Expand Up @@ -153,3 +166,8 @@ def call_an_api(*, context):
"""
return super(Auth, self).login_required(function, scopes=scopes)

def logout(self):
return super(Auth, self).logout(url_for(
self._post_logout_view.__name__, _external=True,
) if self._post_logout_view else None)

5 changes: 3 additions & 2 deletions identity/pallet.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,10 @@ def __getattribute__(self, name):
"@auth.login_required() or auth.logout() etc.")
return super(PalletAuth, self).__getattribute__(name)

def logout(self):
def logout(self, post_logout_redirect_uri: Optional[str] = None):
return self.__class__._redirect( # self._redirect(...) won't work
self._auth.log_out(self._request.url_root))
self._auth.log_out(post_logout_redirect_uri or self._request.url_root)
)

def login_required( # Named after Django's login_required
self,
Expand Down
20 changes: 19 additions & 1 deletion identity/quart.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,13 @@ class Auth(PalletAuth):
_Session = Session
_redirect = redirect

def __init__(self, app: Optional[Quart], *args, **kwargs):
def __init__(
self,
app: Optional[Quart],
*args,
post_logout_view: Optional[callable] = None,
**kwargs,
):
"""Create an identity helper for a web application.

:param Quart app:
Expand Down Expand Up @@ -56,10 +62,17 @@ def build_app():

app = build_app()

:param callable post_logout_view:
Optional.
If not provided, the user will be redirected to the root URL of the app.
If provided, it shall be the view (which is a function)
that will be redirected to, after the user has logged out.

It also passes extra parameters to :class:`identity.web.WebFrameworkAuth`.
"""
self._request = request # Not available during class definition
self._session = session # Not available during class definition
self._post_logout_view = post_logout_view
super(Auth, self).__init__(app, *args, **kwargs)

async def _render_auth_error(self, *, error, error_description=None):
Expand Down Expand Up @@ -152,3 +165,8 @@ async def call_api(*, context):
"""
return super(Auth, self).login_required(function, scopes=scopes)

def logout(self):
return super(Auth, self).logout(url_for(
self._post_logout_view.__name__, _external=True,
) if self._post_logout_view else None)

12 changes: 6 additions & 6 deletions identity/web.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,13 +298,13 @@ def _get_oidc_config(self):
"%s not found from OIDC config: %s", self._END_SESSION_ENDPOINT, conf)
return conf

def log_out(self, homepage):
def log_out(self, post_logout_redirect_uri: str) -> str:
# The vocabulary is "log out" (rather than "sign out") in the specs
# https://openid.net/specs/openid-connect-frontchannel-1_0.html
"""Logs out the user from current app.

:param str homepage:
The page to be redirected to, after the log-out.
:param str post_logout_redirect_uri:
The absolute uri of the page to be redirected to, after the log-out.
In Flask, you can pass in ``url_for("index", _external=True)``.

:return:
Expand All @@ -318,11 +318,11 @@ def log_out(self, homepage):
# but its default (i.e. v1.0) endpoint will sign out the (only?) account
endpoint = self._get_oidc_config().get(self._END_SESSION_ENDPOINT)
if endpoint:
return f"{endpoint}?post_logout_redirect_uri={homepage}"
return f"{endpoint}?post_logout_redirect_uri={post_logout_redirect_uri}"
except requests.exceptions.RequestException:
logger.exception("Failed to get OIDC config")
logger.warning("No end_session_endpoint found. Fallback to %s", homepage)
return homepage
logger.warning("No end_session_endpoint found. Fallback to %s", post_logout_redirect_uri)
return post_logout_redirect_uri # Fall back to this

def get_token_for_client(self, scopes):
"""Get access token for the current app, with specified scopes.
Expand Down
26 changes: 26 additions & 0 deletions tests/test_django.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,14 @@
from unittest import mock

import pytest
from django.conf import settings

from identity.django import _parse_redirect_uri, Auth

urlpatterns = [] # This is required for Django to recognize the URL patterns
settings.configure(
ROOT_URLCONF='test_django', # Set the root URL configuration
)

def test_parse_redirect_uri():
with pytest.raises(ValueError):
Expand Down Expand Up @@ -35,3 +40,24 @@ def test_the_installed_package_contains_builtin_templates():
templates_found.add(t)
assert templates_needed == templates_found

def test_logout():
request = mock.MagicMock(
build_absolute_uri=lambda relative_uri: f"http://localhost{relative_uri}"
)
with mock.patch('identity.web.requests.get', new=mock.MagicMock(
return_value=mock.MagicMock(
json=mock.MagicMock(return_value={
"end_session_endpoint": "https://example.com/end_session",
}),
status_code=200,
)
)):
response = Auth("client_id").logout(request)
assert response.status_code == 302
assert response.url == "https://example.com/end_session?post_logout_redirect_uri=http://localhost/"

auth = Auth("client_id", post_logout_view=lambda r: "You have logged out")
with mock.patch('identity.django.reverse', return_value="/post_logout"):
response = auth.logout(request)
assert response.status_code == 302
assert response.url == "https://example.com/end_session?post_logout_redirect_uri=http://localhost/post_logout"
29 changes: 22 additions & 7 deletions tests/test_flask.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,29 +18,44 @@ def app(): # https://flask.palletsprojects.com/en/3.0.x/testing/
yield app
shutil.rmtree("flask_session") # clean up

@pytest.fixture()
def auth(app):
def build_auth(app, post_logout_view=None):
return Auth(
app,
client_id="fake",
redirect_uri="http://localhost:5000/redirect", # To use auth code flow
oidc_authority="https://example.com/foo",
post_logout_view=post_logout_view,
)

def test_logout(app, auth):
@pytest.mark.parametrize("customize_post_logout,expected_post_logout_uri", [
(False, "http://localhost/app_root/"),
(True, "http://localhost/app_root/my_post_logout_page"),
])
def test_logout(app, customize_post_logout, expected_post_logout_uri):

@app.route("/my_post_logout_page")
def post_logout_view():
return "You have logged out"

auth = build_auth(
app,
post_logout_view=post_logout_view if customize_post_logout else None,
)
with patch.object(auth._auth, "_get_oidc_config", new=Mock(return_value={
"end_session_endpoint": "https://example.com/end_session",
})):
with app.test_request_context("/", method="GET"):
homepage = "http://localhost/app_root"
assert homepage in auth.logout().get_data(as_text=True), (
"The homepage should be in the logout URL. There was a bug in 0.9.0.")
assert (
f"?post_logout_redirect_uri={expected_post_logout_uri}</a>"
in auth.logout().get_data(as_text=True)
), "The post-login uri should be in the logout page"

@patch("msal.authority.tenant_discovery", new=Mock(return_value={
"authorization_endpoint": "https://example.com/placeholder",
"token_endpoint": "https://example.com/placeholder",
}))
def test_login(app, auth):
def test_login(app):
auth = build_auth(app)

@app.route("/path")
@auth.login_required
Expand Down