Skip to content
Snippets Groups Projects
Commit dd20dcc0 authored by Julian's avatar Julian
Browse files

Gracefully handle missing session keys

Fixes #2
parent edb31385
No related branches found
No related tags found
No related merge requests found
Pipeline #7773 failed
import os
import secrets
from flask import Flask, session, request, redirect, abort, Response
from flask import Flask, session, request, redirect, abort, Response, url_for
from requests_oauthlib import OAuth2Session
......@@ -19,15 +19,23 @@ def create_app(test_config=None):
@app.route('/auth')
def auth():
if not session.get('user_id'):
try:
user_id = session['user_id']
user_name = session['user_name']
user_nickname = session['user_nickname']
user_email = session['user_email']
user_ldap_dn = session['user_ldap_dn']
user_groups = session['user_groups']
except KeyError:
session['cookies_enabled'] = True
abort(401)
resp = Response('Ok', 200)
resp.headers['OAUTH-USER-ID'] = session['user_id']
resp.headers['OAUTH-USER-NAME'] = session['user_name']
resp.headers['OAUTH-USER-NICKNAME'] = session['user_nickname']
resp.headers['OAUTH-USER-EMAIL'] = session['user_email']
resp.headers['OAUTH-USER-LDAP-DN'] = session['user_ldap_dn']
resp.headers['OAUTH-USER-GROUPS'] = ','.join(session['user_groups'])
resp.headers['OAUTH-USER-ID'] = user_id
resp.headers['OAUTH-USER-NAME'] = user_name
resp.headers['OAUTH-USER-NICKNAME'] = user_nickname
resp.headers['OAUTH-USER-EMAIL'] = user_email
resp.headers['OAUTH-USER-LDAP-DN'] = user_ldap_dn
resp.headers['OAUTH-USER-GROUPS'] = ','.join(user_groups)
return resp
def get_oauth(**kwargs):
......@@ -36,8 +44,14 @@ def create_app(test_config=None):
@app.route('/login')
def login():
# The cookies_enabled check prevents redirect loops:
# login (sets state) -> idp_authorize -> callback (no state set) -> login
if not session.get('cookies_enabled'):
session['cookies_enabled'] = True
abort(400, description='Enable cookies and reload two times to continue')
client = get_oauth()
url, state = client.authorization_url(app.config['OAUTH2_AUTH_URL'])
session.clear()
session['state'] = state
parts = request.full_path.split('?rawurl=', 1)
if len(parts) == 2:
......@@ -48,18 +62,27 @@ def create_app(test_config=None):
@app.route('/callback')
def callback():
client = get_oauth(state=session.pop('state'))
redirect_url = session.get('url', '/')
if 'state' not in session:
session.clear()
session['cookies_enabled'] = True
return redirect(url_for('login', url=redirect_url))
state = session['state']
client = get_oauth(state=state)
client.fetch_token(app.config['OAUTH2_TOKEN_URL'],
client_secret=request.headers['X-CLIENT-SECRET'],
authorization_response=request.url, verify=(not app.debug))
userinfo = client.get(app.config['OAUTH2_USERINFO_URL']).json()
session.clear()
session['user_id'] = userinfo['id']
session['user_name'] = userinfo['name']
session['user_nickname'] = userinfo['nickname']
session['user_email'] = userinfo['email']
session['user_ldap_dn'] = userinfo['ldap_dn']
session['user_groups'] = userinfo['groups']
return redirect(session.pop('url'))
return redirect(redirect_url)
@app.route('/logout')
def logout():
......
......@@ -6,7 +6,6 @@ except ImportError:
import json
import urllib.parse
from flask import session
from requests import Session, Response
from app import create_app
......@@ -78,8 +77,12 @@ class TestCases(unittest.TestCase):
def test_auth_no_session(self):
r = self.client.get(path='/auth', headers=headers)
self.assertEqual(r.status_code, 401)
with self.client.session_transaction() as session:
self.assertEqual(session['cookies_enabled'], True)
def test_login(self):
with self.client.session_transaction() as session:
session['cookies_enabled'] = True
r = self.client.get(path='/login', query_string={'url': 'https://127.0.0.123:7654/app'}, headers=headers, follow_redirects=False)
self.assertEqual(r.status_code, 302)
url = urllib.parse.urlparse(r.location)
......@@ -91,9 +94,14 @@ class TestCases(unittest.TestCase):
self.assertEqual(qs['client_id'], ['test_client_id'])
self.assertEqual(qs['redirect_uri'], ['https://127.0.0.123:7654/callback'])
self.assertGreater(len(qs['state'][0]), 8)
with self.client.session_transaction() as session:
self.assertEqual(session['state'], qs['state'][0])
self.assertEqual(session['url'], 'https://127.0.0.123:7654/app')
def test_login_no_cookies(self):
r = self.client.get(path='/login', query_string={'url': 'https://127.0.0.123:7654/app'}, headers=headers, follow_redirects=False)
self.assertEqual(r.status_code, 400)
def test_callback(self):
code = 'testcode'
state = 'teststate'
......@@ -113,6 +121,16 @@ class TestCases(unittest.TestCase):
self.assertNotIn('state', session)
self.assertNotIn('url', session)
def test_callback_no_session(self):
code = 'testcode'
state = 'teststate'
r = self.client.get(path='/callback', headers=headers, query_string={'code': code, 'state': state}, follow_redirects=False)
self.assertEqual(r.status_code, 302)
url = urllib.parse.urlparse(r.location)
self.assertEqual(url.path, '/login')
with self.client.session_transaction() as session:
self.assertEqual(session['cookies_enabled'], True)
def test_auth_session(self):
with self.client.session_transaction() as session:
session['user_id'] = 1234
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment