From 4457282dd57487599ac2902c3ddff0d4a31c5a6a Mon Sep 17 00:00:00 2001
From: Julian Rother <julian@cccv.de>
Date: Wed, 8 Nov 2023 15:28:29 +0100
Subject: [PATCH] Fix OAuth2 authorization code invalidation

9bfd6f8 changed the format of authorization codes, but did not adapt the
invalidation code accordingly. Because of this, authorization codes were
not invalidated and could have been used multiple times to request access
tokens until expiring.
---
 tests/views/test_oauth2.py | 13 +++++++++++++
 uffd/views/oauth2.py       | 10 +++++++++-
 2 files changed, 22 insertions(+), 1 deletion(-)

diff --git a/tests/views/test_oauth2.py b/tests/views/test_oauth2.py
index 6d26baf..42067d5 100644
--- a/tests/views/test_oauth2.py
+++ b/tests/views/test_oauth2.py
@@ -192,6 +192,19 @@ class TestViews(UffdTestCase):
 			data={'grant_type': 'authorization_code', 'code': 'abcdef', 'redirect_uri': 'http://localhost:5009/callback', 'client_id': 'test', 'client_secret': 'testsecret'}, follow_redirects=True)
 		self.assertIn(r.status_code, [400, 401]) # oauthlib behaviour changed between v2.1.0 and v3.1.0
 		self.assertEqual(r.content_type, 'application/json')
+		self.assertEqual(r.json['error'], 'invalid_grant')
+
+	def test_token_code_invalidation(self):
+		code = self.get_auth_code()
+		r = self.client.post(path=url_for('oauth2.token'),
+			data={'grant_type': 'authorization_code', 'code': code, 'redirect_uri': 'http://localhost:5009/callback'},
+			headers={'Authorization': f'Basic dGVzdDp0ZXN0c2VjcmV0'}, follow_redirects=True)
+		self.assertEqual(r.status_code, 200)
+		r = self.client.post(path=url_for('oauth2.token'),
+			data={'grant_type': 'authorization_code', 'code': code, 'redirect_uri': 'http://localhost:5009/callback'},
+			headers={'Authorization': f'Basic dGVzdDp0ZXN0c2VjcmV0'}, follow_redirects=True)
+		self.assertIn(r.status_code, [400, 401]) # oauthlib behaviour changed between v2.1.0 and v3.1.0
+		self.assertEqual(r.json['error'], 'invalid_grant')
 
 	def test_token_invalid_client(self):
 		r = self.client.post(path=url_for('oauth2.token'),
diff --git a/uffd/views/oauth2.py b/uffd/views/oauth2.py
index 43d65fa..becb485 100644
--- a/uffd/views/oauth2.py
+++ b/uffd/views/oauth2.py
@@ -92,7 +92,15 @@ class UffdRequestValidator(oauthlib.oauth2.RequestValidator):
 		return True
 
 	def invalidate_authorization_code(self, client_id, code, oauthreq, *args, **kwargs):
-		OAuth2Grant.query.filter_by(client=oauthreq.client, code=code).delete()
+		if '-' not in code:
+			return
+		grant_id, grant_code = code.split('-', 2)
+		grant = OAuth2Grant.query.get(grant_id)
+		if not grant or grant.client != oauthreq.client:
+			return
+		if not secrets.compare_digest(grant.code, grant_code):
+			return
+		db.session.delete(grant)
 		db.session.commit()
 
 	def save_bearer_token(self, token_data, oauthreq, *args, **kwargs):
-- 
GitLab