From dd20dcc089e50d12565b0028d22872ee2037b9b7 Mon Sep 17 00:00:00 2001
From: Julian Rother <julian@cccv.de>
Date: Sun, 19 Sep 2021 22:08:58 +0200
Subject: [PATCH] Gracefully handle missing session keys

Fixes #2
---
 app.py      | 43 +++++++++++++++++++++++++++++++++----------
 test_app.py | 24 +++++++++++++++++++++---
 2 files changed, 54 insertions(+), 13 deletions(-)

diff --git a/app.py b/app.py
index de0a89c..e74f17a 100644
--- a/app.py
+++ b/app.py
@@ -1,7 +1,7 @@
 import os
 import secrets
 
-from flask import Flask, session, request, redirect, abort, Response
+from flask import Flask, session, request, redirect, abort, Response, url_for
 
 from requests_oauthlib import OAuth2Session
 
@@ -19,15 +19,23 @@ def create_app(test_config=None):
 
 	@app.route('/auth')
 	def auth():
-		if not session.get('user_id'):
+		try:
+			user_id = session['user_id']
+			user_name = session['user_name']
+			user_nickname = session['user_nickname']
+			user_email = session['user_email']
+			user_ldap_dn = session['user_ldap_dn']
+			user_groups = session['user_groups']
+		except KeyError:
+			session['cookies_enabled'] = True
 			abort(401)
 		resp = Response('Ok', 200)
-		resp.headers['OAUTH-USER-ID'] = session['user_id']
-		resp.headers['OAUTH-USER-NAME'] = session['user_name']
-		resp.headers['OAUTH-USER-NICKNAME'] = session['user_nickname']
-		resp.headers['OAUTH-USER-EMAIL'] = session['user_email']
-		resp.headers['OAUTH-USER-LDAP-DN'] = session['user_ldap_dn']
-		resp.headers['OAUTH-USER-GROUPS'] = ','.join(session['user_groups'])
+		resp.headers['OAUTH-USER-ID'] = user_id
+		resp.headers['OAUTH-USER-NAME'] = user_name
+		resp.headers['OAUTH-USER-NICKNAME'] = user_nickname
+		resp.headers['OAUTH-USER-EMAIL'] = user_email
+		resp.headers['OAUTH-USER-LDAP-DN'] = user_ldap_dn
+		resp.headers['OAUTH-USER-GROUPS'] = ','.join(user_groups)
 		return resp
 
 	def get_oauth(**kwargs):
@@ -36,8 +44,14 @@ def create_app(test_config=None):
 
 	@app.route('/login')
 	def login():
+		# The cookies_enabled check prevents redirect loops:
+		# login (sets state) -> idp_authorize -> callback (no state set) -> login
+		if not session.get('cookies_enabled'):
+			session['cookies_enabled'] = True
+			abort(400, description='Enable cookies and reload two times to continue')
 		client = get_oauth()
 		url, state = client.authorization_url(app.config['OAUTH2_AUTH_URL'])
+		session.clear()
 		session['state'] = state
 		parts = request.full_path.split('?rawurl=', 1)
 		if len(parts) == 2:
@@ -48,18 +62,27 @@ def create_app(test_config=None):
 
 	@app.route('/callback')
 	def callback():
-		client = get_oauth(state=session.pop('state'))
+		redirect_url = session.get('url', '/')
+		if 'state' not in session:
+			session.clear()
+			session['cookies_enabled'] = True
+			return redirect(url_for('login', url=redirect_url))
+		state = session['state']
+
+		client = get_oauth(state=state)
 		client.fetch_token(app.config['OAUTH2_TOKEN_URL'],
 			client_secret=request.headers['X-CLIENT-SECRET'],
 			authorization_response=request.url, verify=(not app.debug))
 		userinfo = client.get(app.config['OAUTH2_USERINFO_URL']).json()
+
+		session.clear()
 		session['user_id'] = userinfo['id']
 		session['user_name'] = userinfo['name']
 		session['user_nickname'] = userinfo['nickname']
 		session['user_email'] = userinfo['email']
 		session['user_ldap_dn'] = userinfo['ldap_dn']
 		session['user_groups'] = userinfo['groups']
-		return redirect(session.pop('url'))
+		return redirect(redirect_url)
 
 	@app.route('/logout')
 	def logout():
diff --git a/test_app.py b/test_app.py
index f72e356..cfe4f63 100644
--- a/test_app.py
+++ b/test_app.py
@@ -6,7 +6,6 @@ except ImportError:
 import json
 import urllib.parse
 
-from flask import session
 from requests import Session, Response
 
 from app import create_app
@@ -78,8 +77,12 @@ class TestCases(unittest.TestCase):
 	def test_auth_no_session(self):
 		r = self.client.get(path='/auth', headers=headers)
 		self.assertEqual(r.status_code, 401)
+		with self.client.session_transaction() as session:
+			self.assertEqual(session['cookies_enabled'], True)
 
 	def test_login(self):
+		with self.client.session_transaction() as session:
+			session['cookies_enabled'] = True
 		r = self.client.get(path='/login', query_string={'url': 'https://127.0.0.123:7654/app'}, headers=headers, follow_redirects=False)
 		self.assertEqual(r.status_code, 302)
 		url = urllib.parse.urlparse(r.location)
@@ -91,8 +94,13 @@ class TestCases(unittest.TestCase):
 		self.assertEqual(qs['client_id'], ['test_client_id'])
 		self.assertEqual(qs['redirect_uri'], ['https://127.0.0.123:7654/callback'])
 		self.assertGreater(len(qs['state'][0]), 8)
-		self.assertEqual(session['state'], qs['state'][0])
-		self.assertEqual(session['url'], 'https://127.0.0.123:7654/app')
+		with self.client.session_transaction() as session:
+			self.assertEqual(session['state'], qs['state'][0])
+			self.assertEqual(session['url'], 'https://127.0.0.123:7654/app')
+
+	def test_login_no_cookies(self):
+		r = self.client.get(path='/login', query_string={'url': 'https://127.0.0.123:7654/app'}, headers=headers, follow_redirects=False)
+		self.assertEqual(r.status_code, 400)
 
 	def test_callback(self):
 		code = 'testcode'
@@ -113,6 +121,16 @@ class TestCases(unittest.TestCase):
 			self.assertNotIn('state', session)
 			self.assertNotIn('url', session)
 
+	def test_callback_no_session(self):
+		code = 'testcode'
+		state = 'teststate'
+		r = self.client.get(path='/callback', headers=headers, query_string={'code': code, 'state': state}, follow_redirects=False)
+		self.assertEqual(r.status_code, 302)
+		url = urllib.parse.urlparse(r.location)
+		self.assertEqual(url.path, '/login')
+		with self.client.session_transaction() as session:
+			self.assertEqual(session['cookies_enabled'], True)
+
 	def test_auth_session(self):
 		with self.client.session_transaction() as session:
 			session['user_id'] = 1234
-- 
GitLab