diff --git a/app.py b/app.py index de0a89c6ed9f2408fabd4cee65d2d48264c061e9..e74f17a14cc40eaed4b0eeb2a0b4c5d94a0529fb 100644 --- a/app.py +++ b/app.py @@ -1,7 +1,7 @@ import os import secrets -from flask import Flask, session, request, redirect, abort, Response +from flask import Flask, session, request, redirect, abort, Response, url_for from requests_oauthlib import OAuth2Session @@ -19,15 +19,23 @@ def create_app(test_config=None): @app.route('/auth') def auth(): - if not session.get('user_id'): + try: + user_id = session['user_id'] + user_name = session['user_name'] + user_nickname = session['user_nickname'] + user_email = session['user_email'] + user_ldap_dn = session['user_ldap_dn'] + user_groups = session['user_groups'] + except KeyError: + session['cookies_enabled'] = True abort(401) resp = Response('Ok', 200) - resp.headers['OAUTH-USER-ID'] = session['user_id'] - resp.headers['OAUTH-USER-NAME'] = session['user_name'] - resp.headers['OAUTH-USER-NICKNAME'] = session['user_nickname'] - resp.headers['OAUTH-USER-EMAIL'] = session['user_email'] - resp.headers['OAUTH-USER-LDAP-DN'] = session['user_ldap_dn'] - resp.headers['OAUTH-USER-GROUPS'] = ','.join(session['user_groups']) + resp.headers['OAUTH-USER-ID'] = user_id + resp.headers['OAUTH-USER-NAME'] = user_name + resp.headers['OAUTH-USER-NICKNAME'] = user_nickname + resp.headers['OAUTH-USER-EMAIL'] = user_email + resp.headers['OAUTH-USER-LDAP-DN'] = user_ldap_dn + resp.headers['OAUTH-USER-GROUPS'] = ','.join(user_groups) return resp def get_oauth(**kwargs): @@ -36,8 +44,14 @@ def create_app(test_config=None): @app.route('/login') def login(): + # The cookies_enabled check prevents redirect loops: + # login (sets state) -> idp_authorize -> callback (no state set) -> login + if not session.get('cookies_enabled'): + session['cookies_enabled'] = True + abort(400, description='Enable cookies and reload two times to continue') client = get_oauth() url, state = client.authorization_url(app.config['OAUTH2_AUTH_URL']) + session.clear() session['state'] = state parts = request.full_path.split('?rawurl=', 1) if len(parts) == 2: @@ -48,18 +62,27 @@ def create_app(test_config=None): @app.route('/callback') def callback(): - client = get_oauth(state=session.pop('state')) + redirect_url = session.get('url', '/') + if 'state' not in session: + session.clear() + session['cookies_enabled'] = True + return redirect(url_for('login', url=redirect_url)) + state = session['state'] + + client = get_oauth(state=state) client.fetch_token(app.config['OAUTH2_TOKEN_URL'], client_secret=request.headers['X-CLIENT-SECRET'], authorization_response=request.url, verify=(not app.debug)) userinfo = client.get(app.config['OAUTH2_USERINFO_URL']).json() + + session.clear() session['user_id'] = userinfo['id'] session['user_name'] = userinfo['name'] session['user_nickname'] = userinfo['nickname'] session['user_email'] = userinfo['email'] session['user_ldap_dn'] = userinfo['ldap_dn'] session['user_groups'] = userinfo['groups'] - return redirect(session.pop('url')) + return redirect(redirect_url) @app.route('/logout') def logout(): diff --git a/test_app.py b/test_app.py index f72e356f1885ee1fc884fb12ed823b048b41934b..cfe4f6310c350b9cc54c01748c59afe3b612cd56 100644 --- a/test_app.py +++ b/test_app.py @@ -6,7 +6,6 @@ except ImportError: import json import urllib.parse -from flask import session from requests import Session, Response from app import create_app @@ -78,8 +77,12 @@ class TestCases(unittest.TestCase): def test_auth_no_session(self): r = self.client.get(path='/auth', headers=headers) self.assertEqual(r.status_code, 401) + with self.client.session_transaction() as session: + self.assertEqual(session['cookies_enabled'], True) def test_login(self): + with self.client.session_transaction() as session: + session['cookies_enabled'] = True r = self.client.get(path='/login', query_string={'url': 'https://127.0.0.123:7654/app'}, headers=headers, follow_redirects=False) self.assertEqual(r.status_code, 302) url = urllib.parse.urlparse(r.location) @@ -91,8 +94,13 @@ class TestCases(unittest.TestCase): self.assertEqual(qs['client_id'], ['test_client_id']) self.assertEqual(qs['redirect_uri'], ['https://127.0.0.123:7654/callback']) self.assertGreater(len(qs['state'][0]), 8) - self.assertEqual(session['state'], qs['state'][0]) - self.assertEqual(session['url'], 'https://127.0.0.123:7654/app') + with self.client.session_transaction() as session: + self.assertEqual(session['state'], qs['state'][0]) + self.assertEqual(session['url'], 'https://127.0.0.123:7654/app') + + def test_login_no_cookies(self): + r = self.client.get(path='/login', query_string={'url': 'https://127.0.0.123:7654/app'}, headers=headers, follow_redirects=False) + self.assertEqual(r.status_code, 400) def test_callback(self): code = 'testcode' @@ -113,6 +121,16 @@ class TestCases(unittest.TestCase): self.assertNotIn('state', session) self.assertNotIn('url', session) + def test_callback_no_session(self): + code = 'testcode' + state = 'teststate' + r = self.client.get(path='/callback', headers=headers, query_string={'code': code, 'state': state}, follow_redirects=False) + self.assertEqual(r.status_code, 302) + url = urllib.parse.urlparse(r.location) + self.assertEqual(url.path, '/login') + with self.client.session_transaction() as session: + self.assertEqual(session['cookies_enabled'], True) + def test_auth_session(self): with self.client.session_transaction() as session: session['user_id'] = 1234