From 45d4598ef47e48bdc3f6f8528ba4a73386f60d40 Mon Sep 17 00:00:00 2001 From: Julian Rother <julian@cccv.de> Date: Thu, 2 Sep 2021 12:37:17 +0200 Subject: [PATCH] Replace flask_oauthlib with plain oauthlib flask_oauthlib is no longer available in Debian Bullseye. It is only a wrapper around oauthlib, which is still available. While this change does increase the OAuth2 code size, it achieves compatability with both Debian Buster and Bullseye. Aside from error handling, this change has no noticable effects on OAuth2.0 clients. In terms of error handling, a few cases that were not properly handled before now return appropriate error pages. Fixes #101 --- README.md | 2 +- setup.py | 2 +- tests/test_oauth2.py | 91 ++++++++++ uffd/oauth2/templates/oauth2/error.html | 8 - uffd/oauth2/views.py | 221 +++++++++++++++--------- 5 files changed, 233 insertions(+), 91 deletions(-) diff --git a/README.md b/README.md index 5990afda..80d01bd0 100644 --- a/README.md +++ b/README.md @@ -16,7 +16,7 @@ Please note that we refer to Debian packages here and **not** pip packages. - python3-flask-migrate - python3-qrcode - python3-fido2 (version 0.5.0, optional) -- python3-flask-oauthlib +- python3-oauthlib - python3-flask-babel Some of the dependencies (especially fido2 and flask-oauthlib) changed their API in recent versions, so make sure to install the versions from Debian Buster. diff --git a/setup.py b/setup.py index c8196007..31353ff9 100644 --- a/setup.py +++ b/setup.py @@ -37,7 +37,7 @@ setup( 'Flask-SQLAlchemy==2.1', 'qrcode==6.1', 'fido2==0.5.0', - 'Flask-OAuthlib==0.9.5', + 'oauthlib==2.1.0', 'Flask-Migrate==2.1.1', 'Flask-Babel==0.11.2', 'alembic==1.0.0', diff --git a/tests/test_oauth2.py b/tests/test_oauth2.py index 9cb2ac3e..a41e3182 100644 --- a/tests/test_oauth2.py +++ b/tests/test_oauth2.py @@ -76,10 +76,50 @@ class TestViews(UffdTestCase): self.assertTrue(r.json.get('groups')) def test_authorization(self): + self.login_as('user') + r = self.client.get(path=url_for('oauth2.authorize', response_type='code', client_id='test', state='teststate', redirect_uri='http://localhost:5009/callback', scope='profile'), follow_redirects=False) + self.assert_authorization(r) + + def test_authorization_without_redirect_uri(self): + self.login_as('user') + r = self.client.get(path=url_for('oauth2.authorize', response_type='code', client_id='test', state='teststate', scope='profile'), follow_redirects=False) + self.assert_authorization(r) + + def test_authorization_without_scope(self): self.login_as('user') r = self.client.get(path=url_for('oauth2.authorize', response_type='code', client_id='test', state='teststate', redirect_uri='http://localhost:5009/callback'), follow_redirects=False) self.assert_authorization(r) + def test_authorization_invalid_scope(self): + self.login_as('user') + r = self.client.get(path=url_for('oauth2.authorize', response_type='code', client_id='test', state='teststate', redirect_uri='http://localhost:5009/callback', scope='invalid'), follow_redirects=False) + self.assertEqual(r.status_code, 400) + dump('oauth2_authorization_invalid_scope', r) + + def test_authorization_missing_client_id(self): + self.login_as('user') + r = self.client.get(path=url_for('oauth2.authorize', response_type='code', state='teststate', redirect_uri='http://localhost:5009/callback', scope='profile'), follow_redirects=False) + self.assertEqual(r.status_code, 400) + dump('oauth2_authorization_missing_client_id', r) + + def test_authorization_invalid_client_id(self): + self.login_as('user') + r = self.client.get(path=url_for('oauth2.authorize', response_type='code', client_id='invalid_client_id', state='teststate', redirect_uri='http://localhost:5009/callback', scope='profile'), follow_redirects=False) + self.assertEqual(r.status_code, 400) + dump('oauth2_authorization_invalid_client_id', r) + + def test_authorization_missing_response_type(self): + self.login_as('user') + r = self.client.get(path=url_for('oauth2.authorize', client_id='test', state='teststate', redirect_uri='http://localhost:5009/callback', scope='profile'), follow_redirects=False) + self.assertEqual(r.status_code, 400) + dump('oauth2_authorization_missing_response_type', r) + + def test_authorization_invalid_response_type(self): + self.login_as('user') + r = self.client.get(path=url_for('oauth2.authorize', response_type='token', client_id='test', state='teststate', redirect_uri='http://localhost:5009/callback', scope='profile'), follow_redirects=False) + self.assertEqual(r.status_code, 400) + dump('oauth2_authorization_invalid_response_type', r) + def test_authorization_devicelogin_start(self): ref = url_for('oauth2.authorize', response_type='code', client_id='test', state='teststate', redirect_uri='http://localhost:5009/callback') r = self.client.get(path=url_for('session.devicelogin_start', ref=ref), follow_redirects=True) @@ -104,3 +144,54 @@ class TestViews(UffdTestCase): ref = url_for('oauth2.authorize', response_type='code', client_id='test', state='teststate', redirect_uri='http://localhost:5009/callback') r = self.client.post(path=url_for('session.devicelogin_submit', ref=ref), data={'confirmation-code': code}, follow_redirects=False) self.assert_authorization(r) + + def get_auth_code(self): + self.login_as('user') + r = self.client.get(path=url_for('oauth2.authorize', response_type='code', client_id='test', state='teststate', redirect_uri='http://localhost:5009/callback', scope='profile'), follow_redirects=False) + while True: + if r.status_code != 302 or r.location.startswith('http://localhost:5009/callback'): + break + r = self.client.get(r.location, follow_redirects=False) + self.assertEqual(r.status_code, 302) + self.assertTrue(r.location.startswith('http://localhost:5009/callback')) + args = parse_qs(urlparse(r.location).query) + self.assertEqual(args['state'], ['teststate']) + return args['code'][0] + + def test_token_urlsecret(self): + r = self.client.post(path=url_for('oauth2.token'), + data={'grant_type': 'authorization_code', 'code': self.get_auth_code(), 'redirect_uri': 'http://localhost:5009/callback', 'client_id': 'test', 'client_secret': 'testsecret'}, follow_redirects=True) + self.assertEqual(r.status_code, 200) + self.assertEqual(r.content_type, 'application/json') + self.assertEqual(r.json['token_type'], 'Bearer') + self.assertEqual(r.json['scope'], 'profile') + + def test_token_invalid_code(self): + r = self.client.post(path=url_for('oauth2.token'), + data={'grant_type': 'authorization_code', 'code': 'abcdef', 'redirect_uri': 'http://localhost:5009/callback', 'client_id': 'test', 'client_secret': 'testsecret'}, follow_redirects=True) + self.assertEqual(r.status_code, 401) + self.assertEqual(r.content_type, 'application/json') + + def test_token_invalid_client(self): + r = self.client.post(path=url_for('oauth2.token'), + data={'grant_type': 'authorization_code', 'code': self.get_auth_code(), 'redirect_uri': 'http://localhost:5009/callback', 'client_id': 'invalid_client', 'client_secret': 'invalid_client_secret'}, follow_redirects=True) + self.assertEqual(r.status_code, 401) + self.assertEqual(r.content_type, 'application/json') + + def test_token_unauthorized_client(self): + r = self.client.post(path=url_for('oauth2.token'), + data={'grant_type': 'authorization_code', 'code': self.get_auth_code(), 'redirect_uri': 'http://localhost:5009/callback', 'client_id': 'test'}, follow_redirects=True) + self.assertEqual(r.status_code, 401) + self.assertEqual(r.content_type, 'application/json') + + def test_token_unsupported_grant_type(self): + r = self.client.post(path=url_for('oauth2.token'), + data={'grant_type': 'password', 'code': self.get_auth_code(), 'redirect_uri': 'http://localhost:5009/callback', 'client_id': 'test', 'client_secret': 'testsecret'}, follow_redirects=True) + self.assertEqual(r.status_code, 400) + self.assertEqual(r.content_type, 'application/json') + self.assertEqual(r.json['error'], 'unsupported_grant_type') + + def test_userinfo_invalid_access_token(self): + token = 'invalidtoken' + r = self.client.get(path=url_for('oauth2.userinfo'), headers=[('Authorization', 'Bearer %s'%token)], follow_redirects=True) + self.assertEqual(r.status_code, 401) diff --git a/uffd/oauth2/templates/oauth2/error.html b/uffd/oauth2/templates/oauth2/error.html index 2e7dba29..380d74b8 100644 --- a/uffd/oauth2/templates/oauth2/error.html +++ b/uffd/oauth2/templates/oauth2/error.html @@ -3,14 +3,6 @@ {% block body %} <h1>OAuth2.0 Authorization Error</h1> <p><b>Error: {{ error }}</b> {{ '(' + error_description + ')' if error_description else '' }}</p> -{% if args %} -<p>Parameters:</p> -<ul> - {% for key, value in args.items() %} - <li>{{ key }}={{ value }}</li> - {% endfor %} -</ul> -{% endif %} <hr> diff --git a/uffd/oauth2/views.py b/uffd/oauth2/views.py index c3c33af0..c52fb137 100644 --- a/uffd/oauth2/views.py +++ b/uffd/oauth2/views.py @@ -1,9 +1,9 @@ import datetime import functools -import urllib.parse +import secrets -from flask import Blueprint, request, jsonify, render_template, session, redirect, url_for, flash -from flask_oauthlib.provider import OAuth2Provider +from flask import Blueprint, request, jsonify, render_template, session, redirect, url_for, flash, abort +import oauthlib.oauth2 from flask_babel import gettext as _ from sqlalchemy.exc import IntegrityError @@ -13,81 +13,132 @@ from uffd.secure_redirect import secure_local_redirect from uffd.session.models import DeviceLoginConfirmation from .models import OAuth2Client, OAuth2Grant, OAuth2Token, OAuth2DeviceLoginInitiation -oauth = OAuth2Provider() - -@oauth.clientgetter -def load_client(client_id): - return OAuth2Client.from_id(client_id) - -@oauth.grantgetter -def load_grant(client_id, code): - return OAuth2Grant.query.filter_by(client_id=client_id, code=code).first() - -@oauth.grantsetter -def save_grant(client_id, code, oauthreq, *args, **kwargs): # pylint: disable=unused-argument - expires = datetime.datetime.utcnow() + datetime.timedelta(seconds=100) - grant = OAuth2Grant(user_dn=request.oauth2_user.dn, client_id=client_id, - code=code['code'], redirect_uri=oauthreq.redirect_uri, expires=expires, _scopes=' '.join(oauthreq.scopes)) - db.session.add(grant) - db.session.commit() - return grant - -@oauth.tokengetter -def load_token(access_token=None, refresh_token=None): - if access_token: - return OAuth2Token.query.filter_by(access_token=access_token).first() - if refresh_token: - return OAuth2Token.query.filter_by(refresh_token=refresh_token).first() - return None - -@oauth.tokensetter -def save_token(token_data, oauthreq, *args, **kwargs): # pylint: disable=unused-argument - OAuth2Token.query.filter_by(client_id=oauthreq.client.client_id, user_dn=oauthreq.user.dn).delete() - expires_in = token_data.get('expires_in') - expires = datetime.datetime.utcnow() + datetime.timedelta(seconds=expires_in) - tok = OAuth2Token( - user_dn=oauthreq.user.dn, - client_id=oauthreq.client.client_id, - token_type=token_data['token_type'], - access_token=token_data['access_token'], - refresh_token=token_data['refresh_token'], - expires=expires, - _scopes=' '.join(oauthreq.scopes) - ) - db.session.add(tok) - db.session.commit() - return tok +class UffdRequestValidator(oauthlib.oauth2.RequestValidator): + # Argument "oauthreq" is named "request" in superclass but this clashes with flask's "request" object + # Arguments "token_value" and "token_data" are named "token" in superclass but this clashs with "token" endpoint + # pylint: disable=arguments-differ,arguments-renamed,unused-argument,too-many-public-methods,abstract-method + + # In all cases (aside from validate_bearer_token), either validate_client_id or authenticate_client is called + # before anything else. authenticate_client_id would be called instead of authenticate_client for non-confidential + # clients. However, we don't support those. + def validate_client_id(self, client_id, oauthreq, *args, **kwargs): + try: + oauthreq.client = OAuth2Client.from_id(client_id) + return True + except KeyError: + return False + + def authenticate_client(self, oauthreq, *args, **kwargs): + if oauthreq.client_secret is None: + return False + try: + oauthreq.client = OAuth2Client.from_id(oauthreq.client_id) + except KeyError: + return False + return secrets.compare_digest(oauthreq.client.client_secret, oauthreq.client_secret) + + def get_default_redirect_uri(self, client_id, oauthreq, *args, **kwargs): + return oauthreq.client.default_redirect_uri + def validate_redirect_uri(self, client_id, redirect_uri, oauthreq, *args, **kwargs): + return redirect_uri in oauthreq.client.redirect_uris + + def validate_response_type(self, client_id, response_type, client, oauthreq, *args, **kwargs): + return response_type == 'code' + + def get_default_scopes(self, client_id, oauthreq, *args, **kwargs): + return oauthreq.client.default_scopes + + def validate_scopes(self, client_id, scopes, client, oauthreq, *args, **kwargs): + return set(scopes).issubset({'profile'}) + + def save_authorization_code(self, client_id, code, oauthreq, *args, **kwargs): + expires = datetime.datetime.utcnow() + datetime.timedelta(seconds=100) + grant = OAuth2Grant(user_dn=oauthreq.user.dn, client_id=client_id, code=code['code'], + redirect_uri=oauthreq.redirect_uri, expires=expires, _scopes=' '.join(oauthreq.scopes)) + db.session.add(grant) + db.session.commit() + + def validate_code(self, client_id, code, client, oauthreq, *args, **kwargs): + oauthreq.grant = OAuth2Grant.query.filter_by(client_id=client_id, code=code).first() + if not oauthreq.grant: + return False + if datetime.datetime.utcnow() > oauthreq.grant.expires: + return False + oauthreq.user = oauthreq.grant.user + oauthreq.scopes = oauthreq.grant.scopes + return True + + def invalidate_authorization_code(self, client_id, code, oauthreq, *args, **kwargs): + OAuth2Grant.query.filter_by(client_id=client_id, code=code).delete() + db.session.commit() + + def save_bearer_token(self, token_data, oauthreq, *args, **kwargs): + OAuth2Token.query.filter_by(client_id=oauthreq.client.client_id, user_dn=oauthreq.user.dn).delete() + expires_in = token_data.get('expires_in') + expires = datetime.datetime.utcnow() + datetime.timedelta(seconds=expires_in) + tok = OAuth2Token( + user_dn=oauthreq.user.dn, + client_id=oauthreq.client.client_id, + token_type=token_data['token_type'], + access_token=token_data['access_token'], + refresh_token=token_data['refresh_token'], + expires=expires, + _scopes=' '.join(oauthreq.scopes) + ) + db.session.add(tok) + db.session.commit() + return oauthreq.client.default_redirect_uri + + def validate_grant_type(self, client_id, grant_type, client, oauthreq, *args, **kwargs): + return grant_type == 'authorization_code' + + def confirm_redirect_uri(self, client_id, code, redirect_uri, client, oauthreq, *args, **kwargs): + return redirect_uri == oauthreq.grant.redirect_uri + + def validate_bearer_token(self, token_value, scopes, oauthreq): + tok = OAuth2Token.query.filter_by(access_token=token_value).first() + if not tok: + return False + if datetime.datetime.utcnow() > tok.expires: + oauthreq.error_message = 'Token expired' + return False + if not set(scopes).issubset(tok.scopes): + oauthreq.error_message = 'Scopes invalid' + return False + oauthreq.access_token = tok + oauthreq.user = tok.user + oauthreq.scopes = scopes + oauthreq.client = tok.client + oauthreq.client_id = tok.client_id + return True + + # get_original_scopes/validate_refresh_token are only used for refreshing tokens. We don't implement the refresh endpoint. + # revoke_token is only used for revoking access tokens. We don't implement the revoke endpoint. + # get_id_token/validate_silent_authorization/validate_silent_login are OpenID Connect specfic. + # validate_user/validate_user_match are not required for Authorization Code Grant flow. + +validator = UffdRequestValidator() +server = oauthlib.oauth2.WebApplicationServer(validator) bp = Blueprint('oauth2', __name__, url_prefix='/oauth2/', template_folder='templates') -@bp.record -def init(state): - state.app.config.setdefault('OAUTH2_PROVIDER_ERROR_ENDPOINT', 'oauth2.error') - oauth.init_app(state.app) - -# flask-oauthlib has the bug to require the scope parameter for authorize -# requests, which is actually optional according to the OAuth2.0 spec. -# We don't really use scopes and this requirement just complicates the -# configuration of clients. -# See also: https://github.com/lepture/flask-oauthlib/pull/320 -def inject_scope(func): +def display_oauth_errors(func): @functools.wraps(func) def decorator(*args, **kwargs): - args = request.args.to_dict() - if not args.get('scope'): - args['scope'] = 'profile' - return redirect(request.base_url+'?'+urllib.parse.urlencode(args)) - return func(*args, **kwargs) + try: + return func(*args, **kwargs) + except oauthlib.oauth2.rfc6749.errors.OAuth2Error as ex: + return render_template('oauth2/error.html', error=type(ex).__name__, error_description=ex.description), 400 return decorator @bp.route('/authorize', methods=['GET', 'POST']) -@inject_scope -@oauth.authorize_handler -def authorize(*args, **kwargs): # pylint: disable=unused-argument - client = kwargs['request'].client - request.oauth2_user = None +@display_oauth_errors +def authorize(): + scopes, credentials = server.validate_authorization_request(request.url, request.method, request.form, request.headers) + client = OAuth2Client.from_id(credentials['client_id']) + if request.user: - request.oauth2_user = request.user + credentials['user'] = request.user elif 'devicelogin_started' in session: del session['devicelogin_started'] host_delay = host_ratelimit.get_delay() @@ -115,26 +166,41 @@ def authorize(*args, **kwargs): # pylint: disable=unused-argument if not initiation or initiation.expired or not confirmation: flash('Device login failed') return redirect(url_for('session.login', ref=request.full_path, devicelogin=True)) - request.oauth2_user = confirmation.user + credentials['user'] = confirmation.user db.session.delete(initiation) db.session.commit() else: return redirect(url_for('session.login', ref=request.full_path, devicelogin=True)) + # Here we would normally ask the user, if he wants to give the requesting # service access to his data. Since we only have trusted services (the # clients defined in the server config), we don't ask for consent. session['oauth2-clients'] = session.get('oauth2-clients', []) if client.client_id not in session['oauth2-clients']: session['oauth2-clients'].append(client.client_id) - return client.access_allowed(request.oauth2_user) + + headers, body, status = server.create_authorization_response(request.url, request.method, request.form, request.headers, scopes, credentials) + return body or '', status, headers @bp.route('/token', methods=['GET', 'POST']) -@oauth.token_handler def token(): - return None + headers, body, status = server.create_token_response(request.url, request.method, request.form, request.headers) + return body, status, headers + +def oauth_required(*scopes): + def wrapper(func): + @functools.wraps(func) + def decorator(*args, **kwargs): + valid, oauthreq = server.verify_request(request.url, request.method, request.form, request.headers, scopes) + if not valid: + abort(401) + request.oauth = oauthreq + return func(*args, **kwargs) + return decorator + return wrapper @bp.route('/userinfo') -@oauth.require_oauth('profile') +@oauth_required('profile') def userinfo(): user = request.oauth.user # We once exposed the entryUUID here as "ldap_uuid" until realising that it @@ -149,13 +215,6 @@ def userinfo(): groups=[group.name for group in user.groups] ) -@bp.route('/error') -def error(): - args = dict(request.values) - err = args.pop('error', 'unknown') - error_description = args.pop('error_description', '') - return render_template('oauth2/error.html', error=err, error_description=error_description, args=args) - @bp.app_url_defaults def inject_logout_params(endpoint, values): if endpoint != 'oauth2.logout' or not session.get('oauth2-clients'): -- GitLab