From 23314ab10891cf5203199f6cc2a2f8d6fc2ccf2f Mon Sep 17 00:00:00 2001 From: Ryan Schaffer Date: Thu, 7 Sep 2023 18:16:55 -0400 Subject: [PATCH 1/7] account for overrides when generating new token --- flask_dance/consumer/storage/sqla.py | 15 +++++++++++++++ tests/consumer/storage/test_sqla.py | 7 +++++++ 2 files changed, 22 insertions(+) diff --git a/flask_dance/consumer/storage/sqla.py b/flask_dance/consumer/storage/sqla.py index 10cf307..91e9e66 100644 --- a/flask_dance/consumer/storage/sqla.py +++ b/flask_dance/consumer/storage/sqla.py @@ -214,6 +214,7 @@ def set(self, blueprint, token, user=None, user_id=None): has_user = hasattr(self.model, "user") if has_user and u: existing_query = existing_query.filter_by(user=u) + # queue up delete query -- won't be run until commit() existing_query.delete() # create a new model for this token @@ -222,6 +223,20 @@ def set(self, blueprint, token, user=None, user_id=None): kwargs["user_id"] = uid if has_user and u: kwargs["user"] = u + + existing = existing_query.first() + + # if the oauth model is overridden, make sure to copy the columns + column_names = [ + col.name + for col in self.model.__table__.columns + if not col.nullable + and not col.primary_key + and col.name not in kwargs.keys() + ] + for name in column_names: + kwargs[name] = getattr(existing, name) + self.session.add(self.model(**kwargs)) # commit to delete and add simultaneously self.session.commit() diff --git a/tests/consumer/storage/test_sqla.py b/tests/consumer/storage/test_sqla.py index c129a97..bc34110 100644 --- a/tests/consumer/storage/test_sqla.py +++ b/tests/consumer/storage/test_sqla.py @@ -692,6 +692,7 @@ def done(): def test_sqla_overwrite_token(app, db, blueprint, request): class OAuth(OAuthConsumerMixin, db.Model): + pass blueprint.storage = SQLAlchemyStorage(OAuth, db.session) @@ -708,6 +709,8 @@ def done(): existing = OAuth( provider="test-service", token={"access_token": "something", "token_type": "bearer", "scope": ["blah"]}, + provider_user_id="some-hash", + provider_user_login = "user.name" ) db.session.add(existing) db.session.commit() @@ -745,6 +748,10 @@ def done(): } +# def test_sqla_overwrite_token_override_model(app, db, blueprint, request): +# class OAuth(OAuthConsumerMixin, db.Model): + + def test_sqla_cache(app, db, blueprint, request): cache = Cache(app) From b02859e228d9a8849e1e72c083f354e05617667b Mon Sep 17 00:00:00 2001 From: Ryan Schaffer Date: Tue, 12 Sep 2023 10:31:22 -0400 Subject: [PATCH 2/7] Update SQLA to account for overriden cols when refreshing the token --- flask_dance/consumer/storage/sqla.py | 27 +++++++++++++++------------ tests/consumer/storage/test_sqla.py | 6 +++--- 2 files changed, 18 insertions(+), 15 deletions(-) diff --git a/flask_dance/consumer/storage/sqla.py b/flask_dance/consumer/storage/sqla.py index 91e9e66..42e3e83 100644 --- a/flask_dance/consumer/storage/sqla.py +++ b/flask_dance/consumer/storage/sqla.py @@ -215,6 +215,8 @@ def set(self, blueprint, token, user=None, user_id=None): if has_user and u: existing_query = existing_query.filter_by(user=u) + # grab the existing model before we delete so that we can copy overriden columns + existing = existing_query.first() # queue up delete query -- won't be run until commit() existing_query.delete() # create a new model for this token @@ -224,18 +226,19 @@ def set(self, blueprint, token, user=None, user_id=None): if has_user and u: kwargs["user"] = u - existing = existing_query.first() - - # if the oauth model is overridden, make sure to copy the columns - column_names = [ - col.name - for col in self.model.__table__.columns - if not col.nullable - and not col.primary_key - and col.name not in kwargs.keys() - ] - for name in column_names: - kwargs[name] = getattr(existing, name) + if existing: + EXCLUDE_COLS = ["created_at"] + EXCLUDE_COLS.extend(kwargs.keys()) + # if the oauth model is overridden, make sure to copy the columns + column_names = [ + col.name + for col in self.model.__table__.columns + if not col.nullable + and not col.primary_key + and col.name not in EXCLUDE_COLS + ] + for name in column_names: + kwargs[name] = getattr(existing, name) self.session.add(self.model(**kwargs)) # commit to delete and add simultaneously diff --git a/tests/consumer/storage/test_sqla.py b/tests/consumer/storage/test_sqla.py index bc34110..2076a0d 100644 --- a/tests/consumer/storage/test_sqla.py +++ b/tests/consumer/storage/test_sqla.py @@ -692,8 +692,8 @@ def done(): def test_sqla_overwrite_token(app, db, blueprint, request): class OAuth(OAuthConsumerMixin, db.Model): - - pass + provider_user_id=db.Column(db.String, nullable=False) + provider_user_login=db.Column(db.String, nullable=False) blueprint.storage = SQLAlchemyStorage(OAuth, db.session) @@ -733,7 +733,7 @@ def done(): "/oauth_done", ) - assert len(queries) == 2 + assert len(queries) == 3 # check that the database record was overwritten authorizations = OAuth.query.all() From f73de59ef94f39a054a75b6ed340538e78a3a6a7 Mon Sep 17 00:00:00 2001 From: Ryan Schaffer Date: Tue, 12 Sep 2023 10:39:25 -0400 Subject: [PATCH 3/7] clean up --- CHANGELOG.rst | 2 +- tests/consumer/storage/test_sqla.py | 4 ---- 2 files changed, 1 insertion(+), 5 deletions(-) diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 1fd8519..5715436 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -3,7 +3,7 @@ Changelog `unreleased`_ ------------- -nothing yet +* Added support for auto refreshing tokens if the OAuth model is overriden as in the multi-provider example `7.0.0`_ (2023-05-10) --------------------- diff --git a/tests/consumer/storage/test_sqla.py b/tests/consumer/storage/test_sqla.py index 2076a0d..41d0632 100644 --- a/tests/consumer/storage/test_sqla.py +++ b/tests/consumer/storage/test_sqla.py @@ -748,10 +748,6 @@ def done(): } -# def test_sqla_overwrite_token_override_model(app, db, blueprint, request): -# class OAuth(OAuthConsumerMixin, db.Model): - - def test_sqla_cache(app, db, blueprint, request): cache = Cache(app) From ace69a4494ed828918bfaf7ebfd7fb7471828ab3 Mon Sep 17 00:00:00 2001 From: Ryan Schaffer Date: Tue, 12 Sep 2023 11:21:13 -0400 Subject: [PATCH 4/7] Fix assertions in other tests to reflect the additional query added --- tests/consumer/storage/test_sqla.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/consumer/storage/test_sqla.py b/tests/consumer/storage/test_sqla.py index 41d0632..73e2aef 100644 --- a/tests/consumer/storage/test_sqla.py +++ b/tests/consumer/storage/test_sqla.py @@ -128,7 +128,7 @@ def done(): "/oauth_done", ) - assert len(queries) == 2 + assert len(queries) == 3 # check the database authorizations = OAuth.query.all() @@ -211,7 +211,7 @@ def done(): "/oauth_done", ) - assert len(queries) == 3 + assert len(queries) == 4 # check the database alice = User.query.first() @@ -351,7 +351,7 @@ def load_user(userid): "/oauth_done", ) - assert len(queries) == 5 + assert len(queries) == 6 # lets do it again, with Bob as the logged in user -- he gets a different token if "_login_user" in flask.g: @@ -519,7 +519,7 @@ def logged_in(sender, token): "/oauth_done", ) - assert len(queries) == 5 + assert len(queries) == 6 # check the database users = User.query.all() @@ -781,7 +781,7 @@ def done(): "/oauth_done", ) - assert len(queries) == 2 + assert len(queries) == 3 expected_token = {"access_token": "foobar", "token_type": "bearer", "scope": [""]} From 546f964254bace9a4933b6772dcb371ae777bff9 Mon Sep 17 00:00:00 2001 From: Ryan Schaffer Date: Wed, 13 Sep 2023 17:57:01 -0400 Subject: [PATCH 5/7] Fix off by one in test --- tests/consumer/storage/test_sqla.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/consumer/storage/test_sqla.py b/tests/consumer/storage/test_sqla.py index 73e2aef..585a17d 100644 --- a/tests/consumer/storage/test_sqla.py +++ b/tests/consumer/storage/test_sqla.py @@ -351,7 +351,7 @@ def load_user(userid): "/oauth_done", ) - assert len(queries) == 6 + assert len(queries) == 5 # lets do it again, with Bob as the logged in user -- he gets a different token if "_login_user" in flask.g: @@ -692,8 +692,8 @@ def done(): def test_sqla_overwrite_token(app, db, blueprint, request): class OAuth(OAuthConsumerMixin, db.Model): - provider_user_id=db.Column(db.String, nullable=False) - provider_user_login=db.Column(db.String, nullable=False) + provider_user_id = db.Column(db.String, nullable=False) + provider_user_login = db.Column(db.String, nullable=False) blueprint.storage = SQLAlchemyStorage(OAuth, db.session) @@ -710,7 +710,7 @@ def done(): provider="test-service", token={"access_token": "something", "token_type": "bearer", "scope": ["blah"]}, provider_user_id="some-hash", - provider_user_login = "user.name" + provider_user_login="user.name", ) db.session.add(existing) db.session.commit() From 377850f8868b4eac3efdfa4e2c83128faf01355b Mon Sep 17 00:00:00 2001 From: Ryan Schaffer Date: Thu, 14 Sep 2023 11:14:29 -0400 Subject: [PATCH 6/7] When the user require check is false we don't hit the additional query added for this update, but if they are then we need to account for the additional query --- tests/consumer/storage/test_sqla.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/consumer/storage/test_sqla.py b/tests/consumer/storage/test_sqla.py index 585a17d..680a242 100644 --- a/tests/consumer/storage/test_sqla.py +++ b/tests/consumer/storage/test_sqla.py @@ -379,7 +379,7 @@ def load_user(userid): "/oauth_done", ) - assert len(queries) == 5 + assert len(queries) == 6 # check the database authorizations = OAuth.query.all() From de5a6a5531ccf248d80ea4921ee0521134154b57 Mon Sep 17 00:00:00 2001 From: Ryan Schaffer Date: Thu, 14 Sep 2023 16:38:04 -0400 Subject: [PATCH 7/7] Tests passing locally for flask 2.0.3 and python 37 --- tests/consumer/storage/test_sqla.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/consumer/storage/test_sqla.py b/tests/consumer/storage/test_sqla.py index 680a242..168148a 100644 --- a/tests/consumer/storage/test_sqla.py +++ b/tests/consumer/storage/test_sqla.py @@ -351,7 +351,7 @@ def load_user(userid): "/oauth_done", ) - assert len(queries) == 5 + assert len(queries) == 6 # lets do it again, with Bob as the logged in user -- he gets a different token if "_login_user" in flask.g: