diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 1fd8519..5715436 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -3,7 +3,7 @@ Changelog `unreleased`_ ------------- -nothing yet +* Added support for auto refreshing tokens if the OAuth model is overriden as in the multi-provider example `7.0.0`_ (2023-05-10) --------------------- diff --git a/flask_dance/consumer/storage/sqla.py b/flask_dance/consumer/storage/sqla.py index 10cf307..42e3e83 100644 --- a/flask_dance/consumer/storage/sqla.py +++ b/flask_dance/consumer/storage/sqla.py @@ -214,6 +214,9 @@ def set(self, blueprint, token, user=None, user_id=None): has_user = hasattr(self.model, "user") if has_user and u: existing_query = existing_query.filter_by(user=u) + + # grab the existing model before we delete so that we can copy overriden columns + existing = existing_query.first() # queue up delete query -- won't be run until commit() existing_query.delete() # create a new model for this token @@ -222,6 +225,21 @@ def set(self, blueprint, token, user=None, user_id=None): kwargs["user_id"] = uid if has_user and u: kwargs["user"] = u + + if existing: + EXCLUDE_COLS = ["created_at"] + EXCLUDE_COLS.extend(kwargs.keys()) + # if the oauth model is overridden, make sure to copy the columns + column_names = [ + col.name + for col in self.model.__table__.columns + if not col.nullable + and not col.primary_key + and col.name not in EXCLUDE_COLS + ] + for name in column_names: + kwargs[name] = getattr(existing, name) + self.session.add(self.model(**kwargs)) # commit to delete and add simultaneously self.session.commit() diff --git a/tests/consumer/storage/test_sqla.py b/tests/consumer/storage/test_sqla.py index c129a97..168148a 100644 --- a/tests/consumer/storage/test_sqla.py +++ b/tests/consumer/storage/test_sqla.py @@ -128,7 +128,7 @@ def done(): "/oauth_done", ) - assert len(queries) == 2 + assert len(queries) == 3 # check the database authorizations = OAuth.query.all() @@ -211,7 +211,7 @@ def done(): "/oauth_done", ) - assert len(queries) == 3 + assert len(queries) == 4 # check the database alice = User.query.first() @@ -351,7 +351,7 @@ def load_user(userid): "/oauth_done", ) - assert len(queries) == 5 + assert len(queries) == 6 # lets do it again, with Bob as the logged in user -- he gets a different token if "_login_user" in flask.g: @@ -379,7 +379,7 @@ def load_user(userid): "/oauth_done", ) - assert len(queries) == 5 + assert len(queries) == 6 # check the database authorizations = OAuth.query.all() @@ -519,7 +519,7 @@ def logged_in(sender, token): "/oauth_done", ) - assert len(queries) == 5 + assert len(queries) == 6 # check the database users = User.query.all() @@ -692,7 +692,8 @@ def done(): def test_sqla_overwrite_token(app, db, blueprint, request): class OAuth(OAuthConsumerMixin, db.Model): - pass + provider_user_id = db.Column(db.String, nullable=False) + provider_user_login = db.Column(db.String, nullable=False) blueprint.storage = SQLAlchemyStorage(OAuth, db.session) @@ -708,6 +709,8 @@ def done(): existing = OAuth( provider="test-service", token={"access_token": "something", "token_type": "bearer", "scope": ["blah"]}, + provider_user_id="some-hash", + provider_user_login="user.name", ) db.session.add(existing) db.session.commit() @@ -730,7 +733,7 @@ def done(): "/oauth_done", ) - assert len(queries) == 2 + assert len(queries) == 3 # check that the database record was overwritten authorizations = OAuth.query.all() @@ -778,7 +781,7 @@ def done(): "/oauth_done", ) - assert len(queries) == 2 + assert len(queries) == 3 expected_token = {"access_token": "foobar", "token_type": "bearer", "scope": [""]}