Skip to content
Snippets Groups Projects
Commit 4457282d authored by Julian's avatar Julian
Browse files

Fix OAuth2 authorization code invalidation

9bfd6f81 changed the format of authorization codes, but did not adapt the
invalidation code accordingly. Because of this, authorization codes were
not invalidated and could have been used multiple times to request access
tokens until expiring.
parent 7a94d7de
No related branches found
No related tags found
No related merge requests found
...@@ -192,6 +192,19 @@ class TestViews(UffdTestCase): ...@@ -192,6 +192,19 @@ class TestViews(UffdTestCase):
data={'grant_type': 'authorization_code', 'code': 'abcdef', 'redirect_uri': 'http://localhost:5009/callback', 'client_id': 'test', 'client_secret': 'testsecret'}, follow_redirects=True) data={'grant_type': 'authorization_code', 'code': 'abcdef', 'redirect_uri': 'http://localhost:5009/callback', 'client_id': 'test', 'client_secret': 'testsecret'}, follow_redirects=True)
self.assertIn(r.status_code, [400, 401]) # oauthlib behaviour changed between v2.1.0 and v3.1.0 self.assertIn(r.status_code, [400, 401]) # oauthlib behaviour changed between v2.1.0 and v3.1.0
self.assertEqual(r.content_type, 'application/json') self.assertEqual(r.content_type, 'application/json')
self.assertEqual(r.json['error'], 'invalid_grant')
def test_token_code_invalidation(self):
code = self.get_auth_code()
r = self.client.post(path=url_for('oauth2.token'),
data={'grant_type': 'authorization_code', 'code': code, 'redirect_uri': 'http://localhost:5009/callback'},
headers={'Authorization': f'Basic dGVzdDp0ZXN0c2VjcmV0'}, follow_redirects=True)
self.assertEqual(r.status_code, 200)
r = self.client.post(path=url_for('oauth2.token'),
data={'grant_type': 'authorization_code', 'code': code, 'redirect_uri': 'http://localhost:5009/callback'},
headers={'Authorization': f'Basic dGVzdDp0ZXN0c2VjcmV0'}, follow_redirects=True)
self.assertIn(r.status_code, [400, 401]) # oauthlib behaviour changed between v2.1.0 and v3.1.0
self.assertEqual(r.json['error'], 'invalid_grant')
def test_token_invalid_client(self): def test_token_invalid_client(self):
r = self.client.post(path=url_for('oauth2.token'), r = self.client.post(path=url_for('oauth2.token'),
......
...@@ -92,7 +92,15 @@ class UffdRequestValidator(oauthlib.oauth2.RequestValidator): ...@@ -92,7 +92,15 @@ class UffdRequestValidator(oauthlib.oauth2.RequestValidator):
return True return True
def invalidate_authorization_code(self, client_id, code, oauthreq, *args, **kwargs): def invalidate_authorization_code(self, client_id, code, oauthreq, *args, **kwargs):
OAuth2Grant.query.filter_by(client=oauthreq.client, code=code).delete() if '-' not in code:
return
grant_id, grant_code = code.split('-', 2)
grant = OAuth2Grant.query.get(grant_id)
if not grant or grant.client != oauthreq.client:
return
if not secrets.compare_digest(grant.code, grant_code):
return
db.session.delete(grant)
db.session.commit() db.session.commit()
def save_bearer_token(self, token_data, oauthreq, *args, **kwargs): def save_bearer_token(self, token_data, oauthreq, *args, **kwargs):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment