Skip to content
Snippets Groups Projects
Commit dd20dcc0 authored by Julian's avatar Julian
Browse files

Gracefully handle missing session keys

Fixes #2
parent edb31385
Branches
No related tags found
No related merge requests found
import os
import secrets
from flask import Flask, session, request, redirect, abort, Response
from flask import Flask, session, request, redirect, abort, Response, url_for
from requests_oauthlib import OAuth2Session
......@@ -19,15 +19,23 @@ def create_app(test_config=None):
@app.route('/auth')
def auth():
if not session.get('user_id'):
try:
user_id = session['user_id']
user_name = session['user_name']
user_nickname = session['user_nickname']
user_email = session['user_email']
user_ldap_dn = session['user_ldap_dn']
user_groups = session['user_groups']
except KeyError:
session['cookies_enabled'] = True
abort(401)
resp = Response('Ok', 200)
resp.headers['OAUTH-USER-ID'] = session['user_id']
resp.headers['OAUTH-USER-NAME'] = session['user_name']
resp.headers['OAUTH-USER-NICKNAME'] = session['user_nickname']
resp.headers['OAUTH-USER-EMAIL'] = session['user_email']
resp.headers['OAUTH-USER-LDAP-DN'] = session['user_ldap_dn']
resp.headers['OAUTH-USER-GROUPS'] = ','.join(session['user_groups'])
resp.headers['OAUTH-USER-ID'] = user_id
resp.headers['OAUTH-USER-NAME'] = user_name
resp.headers['OAUTH-USER-NICKNAME'] = user_nickname
resp.headers['OAUTH-USER-EMAIL'] = user_email
resp.headers['OAUTH-USER-LDAP-DN'] = user_ldap_dn
resp.headers['OAUTH-USER-GROUPS'] = ','.join(user_groups)
return resp
def get_oauth(**kwargs):
......@@ -36,8 +44,14 @@ def create_app(test_config=None):
@app.route('/login')
def login():
# The cookies_enabled check prevents redirect loops:
# login (sets state) -> idp_authorize -> callback (no state set) -> login
if not session.get('cookies_enabled'):
session['cookies_enabled'] = True
abort(400, description='Enable cookies and reload two times to continue')
client = get_oauth()
url, state = client.authorization_url(app.config['OAUTH2_AUTH_URL'])
session.clear()
session['state'] = state
parts = request.full_path.split('?rawurl=', 1)
if len(parts) == 2:
......@@ -48,18 +62,27 @@ def create_app(test_config=None):
@app.route('/callback')
def callback():
client = get_oauth(state=session.pop('state'))
redirect_url = session.get('url', '/')
if 'state' not in session:
session.clear()
session['cookies_enabled'] = True
return redirect(url_for('login', url=redirect_url))
state = session['state']
client = get_oauth(state=state)
client.fetch_token(app.config['OAUTH2_TOKEN_URL'],
client_secret=request.headers['X-CLIENT-SECRET'],
authorization_response=request.url, verify=(not app.debug))
userinfo = client.get(app.config['OAUTH2_USERINFO_URL']).json()
session.clear()
session['user_id'] = userinfo['id']
session['user_name'] = userinfo['name']
session['user_nickname'] = userinfo['nickname']
session['user_email'] = userinfo['email']
session['user_ldap_dn'] = userinfo['ldap_dn']
session['user_groups'] = userinfo['groups']
return redirect(session.pop('url'))
return redirect(redirect_url)
@app.route('/logout')
def logout():
......
......@@ -6,7 +6,6 @@ except ImportError:
import json
import urllib.parse
from flask import session
from requests import Session, Response
from app import create_app
......@@ -78,8 +77,12 @@ class TestCases(unittest.TestCase):
def test_auth_no_session(self):
r = self.client.get(path='/auth', headers=headers)
self.assertEqual(r.status_code, 401)
with self.client.session_transaction() as session:
self.assertEqual(session['cookies_enabled'], True)
def test_login(self):
with self.client.session_transaction() as session:
session['cookies_enabled'] = True
r = self.client.get(path='/login', query_string={'url': 'https://127.0.0.123:7654/app'}, headers=headers, follow_redirects=False)
self.assertEqual(r.status_code, 302)
url = urllib.parse.urlparse(r.location)
......@@ -91,9 +94,14 @@ class TestCases(unittest.TestCase):
self.assertEqual(qs['client_id'], ['test_client_id'])
self.assertEqual(qs['redirect_uri'], ['https://127.0.0.123:7654/callback'])
self.assertGreater(len(qs['state'][0]), 8)
with self.client.session_transaction() as session:
self.assertEqual(session['state'], qs['state'][0])
self.assertEqual(session['url'], 'https://127.0.0.123:7654/app')
def test_login_no_cookies(self):
r = self.client.get(path='/login', query_string={'url': 'https://127.0.0.123:7654/app'}, headers=headers, follow_redirects=False)
self.assertEqual(r.status_code, 400)
def test_callback(self):
code = 'testcode'
state = 'teststate'
......@@ -113,6 +121,16 @@ class TestCases(unittest.TestCase):
self.assertNotIn('state', session)
self.assertNotIn('url', session)
def test_callback_no_session(self):
code = 'testcode'
state = 'teststate'
r = self.client.get(path='/callback', headers=headers, query_string={'code': code, 'state': state}, follow_redirects=False)
self.assertEqual(r.status_code, 302)
url = urllib.parse.urlparse(r.location)
self.assertEqual(url.path, '/login')
with self.client.session_transaction() as session:
self.assertEqual(session['cookies_enabled'], True)
def test_auth_session(self):
with self.client.session_transaction() as session:
session['user_id'] = 1234
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment