Skip to content
Snippets Groups Projects
Commit dd20dcc0 authored by Julian's avatar Julian
Browse files

Gracefully handle missing session keys

Fixes #2
parent edb31385
No related branches found
No related tags found
No related merge requests found
Pipeline #7773 failed
import os import os
import secrets import secrets
from flask import Flask, session, request, redirect, abort, Response from flask import Flask, session, request, redirect, abort, Response, url_for
from requests_oauthlib import OAuth2Session from requests_oauthlib import OAuth2Session
...@@ -19,15 +19,23 @@ def create_app(test_config=None): ...@@ -19,15 +19,23 @@ def create_app(test_config=None):
@app.route('/auth') @app.route('/auth')
def auth(): def auth():
if not session.get('user_id'): try:
user_id = session['user_id']
user_name = session['user_name']
user_nickname = session['user_nickname']
user_email = session['user_email']
user_ldap_dn = session['user_ldap_dn']
user_groups = session['user_groups']
except KeyError:
session['cookies_enabled'] = True
abort(401) abort(401)
resp = Response('Ok', 200) resp = Response('Ok', 200)
resp.headers['OAUTH-USER-ID'] = session['user_id'] resp.headers['OAUTH-USER-ID'] = user_id
resp.headers['OAUTH-USER-NAME'] = session['user_name'] resp.headers['OAUTH-USER-NAME'] = user_name
resp.headers['OAUTH-USER-NICKNAME'] = session['user_nickname'] resp.headers['OAUTH-USER-NICKNAME'] = user_nickname
resp.headers['OAUTH-USER-EMAIL'] = session['user_email'] resp.headers['OAUTH-USER-EMAIL'] = user_email
resp.headers['OAUTH-USER-LDAP-DN'] = session['user_ldap_dn'] resp.headers['OAUTH-USER-LDAP-DN'] = user_ldap_dn
resp.headers['OAUTH-USER-GROUPS'] = ','.join(session['user_groups']) resp.headers['OAUTH-USER-GROUPS'] = ','.join(user_groups)
return resp return resp
def get_oauth(**kwargs): def get_oauth(**kwargs):
...@@ -36,8 +44,14 @@ def create_app(test_config=None): ...@@ -36,8 +44,14 @@ def create_app(test_config=None):
@app.route('/login') @app.route('/login')
def login(): def login():
# The cookies_enabled check prevents redirect loops:
# login (sets state) -> idp_authorize -> callback (no state set) -> login
if not session.get('cookies_enabled'):
session['cookies_enabled'] = True
abort(400, description='Enable cookies and reload two times to continue')
client = get_oauth() client = get_oauth()
url, state = client.authorization_url(app.config['OAUTH2_AUTH_URL']) url, state = client.authorization_url(app.config['OAUTH2_AUTH_URL'])
session.clear()
session['state'] = state session['state'] = state
parts = request.full_path.split('?rawurl=', 1) parts = request.full_path.split('?rawurl=', 1)
if len(parts) == 2: if len(parts) == 2:
...@@ -48,18 +62,27 @@ def create_app(test_config=None): ...@@ -48,18 +62,27 @@ def create_app(test_config=None):
@app.route('/callback') @app.route('/callback')
def callback(): def callback():
client = get_oauth(state=session.pop('state')) redirect_url = session.get('url', '/')
if 'state' not in session:
session.clear()
session['cookies_enabled'] = True
return redirect(url_for('login', url=redirect_url))
state = session['state']
client = get_oauth(state=state)
client.fetch_token(app.config['OAUTH2_TOKEN_URL'], client.fetch_token(app.config['OAUTH2_TOKEN_URL'],
client_secret=request.headers['X-CLIENT-SECRET'], client_secret=request.headers['X-CLIENT-SECRET'],
authorization_response=request.url, verify=(not app.debug)) authorization_response=request.url, verify=(not app.debug))
userinfo = client.get(app.config['OAUTH2_USERINFO_URL']).json() userinfo = client.get(app.config['OAUTH2_USERINFO_URL']).json()
session.clear()
session['user_id'] = userinfo['id'] session['user_id'] = userinfo['id']
session['user_name'] = userinfo['name'] session['user_name'] = userinfo['name']
session['user_nickname'] = userinfo['nickname'] session['user_nickname'] = userinfo['nickname']
session['user_email'] = userinfo['email'] session['user_email'] = userinfo['email']
session['user_ldap_dn'] = userinfo['ldap_dn'] session['user_ldap_dn'] = userinfo['ldap_dn']
session['user_groups'] = userinfo['groups'] session['user_groups'] = userinfo['groups']
return redirect(session.pop('url')) return redirect(redirect_url)
@app.route('/logout') @app.route('/logout')
def logout(): def logout():
......
...@@ -6,7 +6,6 @@ except ImportError: ...@@ -6,7 +6,6 @@ except ImportError:
import json import json
import urllib.parse import urllib.parse
from flask import session
from requests import Session, Response from requests import Session, Response
from app import create_app from app import create_app
...@@ -78,8 +77,12 @@ class TestCases(unittest.TestCase): ...@@ -78,8 +77,12 @@ class TestCases(unittest.TestCase):
def test_auth_no_session(self): def test_auth_no_session(self):
r = self.client.get(path='/auth', headers=headers) r = self.client.get(path='/auth', headers=headers)
self.assertEqual(r.status_code, 401) self.assertEqual(r.status_code, 401)
with self.client.session_transaction() as session:
self.assertEqual(session['cookies_enabled'], True)
def test_login(self): def test_login(self):
with self.client.session_transaction() as session:
session['cookies_enabled'] = True
r = self.client.get(path='/login', query_string={'url': 'https://127.0.0.123:7654/app'}, headers=headers, follow_redirects=False) r = self.client.get(path='/login', query_string={'url': 'https://127.0.0.123:7654/app'}, headers=headers, follow_redirects=False)
self.assertEqual(r.status_code, 302) self.assertEqual(r.status_code, 302)
url = urllib.parse.urlparse(r.location) url = urllib.parse.urlparse(r.location)
...@@ -91,9 +94,14 @@ class TestCases(unittest.TestCase): ...@@ -91,9 +94,14 @@ class TestCases(unittest.TestCase):
self.assertEqual(qs['client_id'], ['test_client_id']) self.assertEqual(qs['client_id'], ['test_client_id'])
self.assertEqual(qs['redirect_uri'], ['https://127.0.0.123:7654/callback']) self.assertEqual(qs['redirect_uri'], ['https://127.0.0.123:7654/callback'])
self.assertGreater(len(qs['state'][0]), 8) self.assertGreater(len(qs['state'][0]), 8)
with self.client.session_transaction() as session:
self.assertEqual(session['state'], qs['state'][0]) self.assertEqual(session['state'], qs['state'][0])
self.assertEqual(session['url'], 'https://127.0.0.123:7654/app') self.assertEqual(session['url'], 'https://127.0.0.123:7654/app')
def test_login_no_cookies(self):
r = self.client.get(path='/login', query_string={'url': 'https://127.0.0.123:7654/app'}, headers=headers, follow_redirects=False)
self.assertEqual(r.status_code, 400)
def test_callback(self): def test_callback(self):
code = 'testcode' code = 'testcode'
state = 'teststate' state = 'teststate'
...@@ -113,6 +121,16 @@ class TestCases(unittest.TestCase): ...@@ -113,6 +121,16 @@ class TestCases(unittest.TestCase):
self.assertNotIn('state', session) self.assertNotIn('state', session)
self.assertNotIn('url', session) self.assertNotIn('url', session)
def test_callback_no_session(self):
code = 'testcode'
state = 'teststate'
r = self.client.get(path='/callback', headers=headers, query_string={'code': code, 'state': state}, follow_redirects=False)
self.assertEqual(r.status_code, 302)
url = urllib.parse.urlparse(r.location)
self.assertEqual(url.path, '/login')
with self.client.session_transaction() as session:
self.assertEqual(session['cookies_enabled'], True)
def test_auth_session(self): def test_auth_session(self):
with self.client.session_transaction() as session: with self.client.session_transaction() as session:
session['user_id'] = 1234 session['user_id'] = 1234
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment