diff --git a/tests/views/test_oauth2.py b/tests/views/test_oauth2.py index 6d26baf6b5f1f910df17cbe9a04a3011eae1a169..42067d583ce6fd870e85a729ea877166c4745f58 100644 --- a/tests/views/test_oauth2.py +++ b/tests/views/test_oauth2.py @@ -192,6 +192,19 @@ class TestViews(UffdTestCase): data={'grant_type': 'authorization_code', 'code': 'abcdef', 'redirect_uri': 'http://localhost:5009/callback', 'client_id': 'test', 'client_secret': 'testsecret'}, follow_redirects=True) self.assertIn(r.status_code, [400, 401]) # oauthlib behaviour changed between v2.1.0 and v3.1.0 self.assertEqual(r.content_type, 'application/json') + self.assertEqual(r.json['error'], 'invalid_grant') + + def test_token_code_invalidation(self): + code = self.get_auth_code() + r = self.client.post(path=url_for('oauth2.token'), + data={'grant_type': 'authorization_code', 'code': code, 'redirect_uri': 'http://localhost:5009/callback'}, + headers={'Authorization': f'Basic dGVzdDp0ZXN0c2VjcmV0'}, follow_redirects=True) + self.assertEqual(r.status_code, 200) + r = self.client.post(path=url_for('oauth2.token'), + data={'grant_type': 'authorization_code', 'code': code, 'redirect_uri': 'http://localhost:5009/callback'}, + headers={'Authorization': f'Basic dGVzdDp0ZXN0c2VjcmV0'}, follow_redirects=True) + self.assertIn(r.status_code, [400, 401]) # oauthlib behaviour changed between v2.1.0 and v3.1.0 + self.assertEqual(r.json['error'], 'invalid_grant') def test_token_invalid_client(self): r = self.client.post(path=url_for('oauth2.token'), diff --git a/uffd/views/oauth2.py b/uffd/views/oauth2.py index 43d65fa61e5b6b1dd8d253b3a6e53c1bda065c11..becb485a6409de77dfd23ba606fc378f43eb138c 100644 --- a/uffd/views/oauth2.py +++ b/uffd/views/oauth2.py @@ -92,7 +92,15 @@ class UffdRequestValidator(oauthlib.oauth2.RequestValidator): return True def invalidate_authorization_code(self, client_id, code, oauthreq, *args, **kwargs): - OAuth2Grant.query.filter_by(client=oauthreq.client, code=code).delete() + if '-' not in code: + return + grant_id, grant_code = code.split('-', 2) + grant = OAuth2Grant.query.get(grant_id) + if not grant or grant.client != oauthreq.client: + return + if not secrets.compare_digest(grant.code, grant_code): + return + db.session.delete(grant) db.session.commit() def save_bearer_token(self, token_data, oauthreq, *args, **kwargs):