From 9bfd6f81ee6633b73243de64ab10c2ae20562407 Mon Sep 17 00:00:00 2001
From: Julian Rother <julian@cccv.de>
Date: Mon, 6 Sep 2021 20:13:19 +0200
Subject: [PATCH] Verify OAuth2 codes/tokens in constant-time

This change effectivly invalidates all existing grants/tokens.
---
 uffd/oauth2/views.py | 19 +++++++++++++++----
 1 file changed, 15 insertions(+), 4 deletions(-)

diff --git a/uffd/oauth2/views.py b/uffd/oauth2/views.py
index e3300929..bfbe54dc 100644
--- a/uffd/oauth2/views.py
+++ b/uffd/oauth2/views.py
@@ -58,10 +58,16 @@ class UffdRequestValidator(oauthlib.oauth2.RequestValidator):
 		                    redirect_uri=oauthreq.redirect_uri, expires=expires, _scopes=' '.join(oauthreq.scopes))
 		db.session.add(grant)
 		db.session.commit()
+		code['code'] = f"{grant.id}-{code['code']}"
 
 	def validate_code(self, client_id, code, client, oauthreq, *args, **kwargs):
-		oauthreq.grant = OAuth2Grant.query.filter_by(client_id=client_id, code=code).first()
-		if not oauthreq.grant:
+		if '-' not in code:
+			return False
+		grant_id, grant_code = code.split('-', 2)
+		oauthreq.grant = OAuth2Grant.query.get(grant_id)
+		if not oauthreq.grant or oauthreq.grant.client_id != client_id:
+			return False
+		if not secrets.compare_digest(oauthreq.grant.code, grant_code):
 			return False
 		if datetime.datetime.utcnow() > oauthreq.grant.expires:
 			return False
@@ -88,6 +94,8 @@ class UffdRequestValidator(oauthlib.oauth2.RequestValidator):
 		)
 		db.session.add(tok)
 		db.session.commit()
+		token_data['access_token'] = f"{tok.id}-{token_data['access_token']}"
+		token_data['refresh_token'] = f"{tok.id}-{token_data['refresh_token']}"
 		return oauthreq.client.default_redirect_uri
 
 	def validate_grant_type(self, client_id, grant_type, client, oauthreq, *args, **kwargs):
@@ -97,8 +105,11 @@ class UffdRequestValidator(oauthlib.oauth2.RequestValidator):
 		return redirect_uri == oauthreq.grant.redirect_uri
 
 	def validate_bearer_token(self, token_value, scopes, oauthreq):
-		tok = OAuth2Token.query.filter_by(access_token=token_value).first()
-		if not tok:
+		if '-' not in token_value:
+			return False
+		tok_id, tok_secret = token_value.split('-', 2)
+		tok = OAuth2Token.query.get(tok_id)
+		if not tok or not secrets.compare_digest(tok.access_token, tok_secret):
 			return False
 		if datetime.datetime.utcnow() > tok.expires:
 			oauthreq.error_message = 'Token expired'
-- 
GitLab