diff --git a/README.md b/README.md index 5990afdadb53afce9a5ba32646a3d4ffa86578cb..80d01bd0686fa8aaee9cda762c35bd449b280afb 100644 --- a/README.md +++ b/README.md @@ -16,7 +16,7 @@ Please note that we refer to Debian packages here and **not** pip packages. - python3-flask-migrate - python3-qrcode - python3-fido2 (version 0.5.0, optional) -- python3-flask-oauthlib +- python3-oauthlib - python3-flask-babel Some of the dependencies (especially fido2 and flask-oauthlib) changed their API in recent versions, so make sure to install the versions from Debian Buster. diff --git a/setup.py b/setup.py index c81960073a4d95aebcaf09c0ec9c708653382b2e..31353ff9393d26873e329ada692dd24f6e4224b6 100644 --- a/setup.py +++ b/setup.py @@ -37,7 +37,7 @@ setup( 'Flask-SQLAlchemy==2.1', 'qrcode==6.1', 'fido2==0.5.0', - 'Flask-OAuthlib==0.9.5', + 'oauthlib==2.1.0', 'Flask-Migrate==2.1.1', 'Flask-Babel==0.11.2', 'alembic==1.0.0', diff --git a/tests/test_oauth2.py b/tests/test_oauth2.py index 9cb2ac3e58bf6d4c7c321caf88621bc042ccd7df..a41e31828fe2b0f2125a5d10350ceb7482ac169e 100644 --- a/tests/test_oauth2.py +++ b/tests/test_oauth2.py @@ -76,10 +76,50 @@ class TestViews(UffdTestCase): self.assertTrue(r.json.get('groups')) def test_authorization(self): + self.login_as('user') + r = self.client.get(path=url_for('oauth2.authorize', response_type='code', client_id='test', state='teststate', redirect_uri='http://localhost:5009/callback', scope='profile'), follow_redirects=False) + self.assert_authorization(r) + + def test_authorization_without_redirect_uri(self): + self.login_as('user') + r = self.client.get(path=url_for('oauth2.authorize', response_type='code', client_id='test', state='teststate', scope='profile'), follow_redirects=False) + self.assert_authorization(r) + + def test_authorization_without_scope(self): self.login_as('user') r = self.client.get(path=url_for('oauth2.authorize', response_type='code', client_id='test', state='teststate', redirect_uri='http://localhost:5009/callback'), follow_redirects=False) self.assert_authorization(r) + def test_authorization_invalid_scope(self): + self.login_as('user') + r = self.client.get(path=url_for('oauth2.authorize', response_type='code', client_id='test', state='teststate', redirect_uri='http://localhost:5009/callback', scope='invalid'), follow_redirects=False) + self.assertEqual(r.status_code, 400) + dump('oauth2_authorization_invalid_scope', r) + + def test_authorization_missing_client_id(self): + self.login_as('user') + r = self.client.get(path=url_for('oauth2.authorize', response_type='code', state='teststate', redirect_uri='http://localhost:5009/callback', scope='profile'), follow_redirects=False) + self.assertEqual(r.status_code, 400) + dump('oauth2_authorization_missing_client_id', r) + + def test_authorization_invalid_client_id(self): + self.login_as('user') + r = self.client.get(path=url_for('oauth2.authorize', response_type='code', client_id='invalid_client_id', state='teststate', redirect_uri='http://localhost:5009/callback', scope='profile'), follow_redirects=False) + self.assertEqual(r.status_code, 400) + dump('oauth2_authorization_invalid_client_id', r) + + def test_authorization_missing_response_type(self): + self.login_as('user') + r = self.client.get(path=url_for('oauth2.authorize', client_id='test', state='teststate', redirect_uri='http://localhost:5009/callback', scope='profile'), follow_redirects=False) + self.assertEqual(r.status_code, 400) + dump('oauth2_authorization_missing_response_type', r) + + def test_authorization_invalid_response_type(self): + self.login_as('user') + r = self.client.get(path=url_for('oauth2.authorize', response_type='token', client_id='test', state='teststate', redirect_uri='http://localhost:5009/callback', scope='profile'), follow_redirects=False) + self.assertEqual(r.status_code, 400) + dump('oauth2_authorization_invalid_response_type', r) + def test_authorization_devicelogin_start(self): ref = url_for('oauth2.authorize', response_type='code', client_id='test', state='teststate', redirect_uri='http://localhost:5009/callback') r = self.client.get(path=url_for('session.devicelogin_start', ref=ref), follow_redirects=True) @@ -104,3 +144,54 @@ class TestViews(UffdTestCase): ref = url_for('oauth2.authorize', response_type='code', client_id='test', state='teststate', redirect_uri='http://localhost:5009/callback') r = self.client.post(path=url_for('session.devicelogin_submit', ref=ref), data={'confirmation-code': code}, follow_redirects=False) self.assert_authorization(r) + + def get_auth_code(self): + self.login_as('user') + r = self.client.get(path=url_for('oauth2.authorize', response_type='code', client_id='test', state='teststate', redirect_uri='http://localhost:5009/callback', scope='profile'), follow_redirects=False) + while True: + if r.status_code != 302 or r.location.startswith('http://localhost:5009/callback'): + break + r = self.client.get(r.location, follow_redirects=False) + self.assertEqual(r.status_code, 302) + self.assertTrue(r.location.startswith('http://localhost:5009/callback')) + args = parse_qs(urlparse(r.location).query) + self.assertEqual(args['state'], ['teststate']) + return args['code'][0] + + def test_token_urlsecret(self): + r = self.client.post(path=url_for('oauth2.token'), + data={'grant_type': 'authorization_code', 'code': self.get_auth_code(), 'redirect_uri': 'http://localhost:5009/callback', 'client_id': 'test', 'client_secret': 'testsecret'}, follow_redirects=True) + self.assertEqual(r.status_code, 200) + self.assertEqual(r.content_type, 'application/json') + self.assertEqual(r.json['token_type'], 'Bearer') + self.assertEqual(r.json['scope'], 'profile') + + def test_token_invalid_code(self): + r = self.client.post(path=url_for('oauth2.token'), + data={'grant_type': 'authorization_code', 'code': 'abcdef', 'redirect_uri': 'http://localhost:5009/callback', 'client_id': 'test', 'client_secret': 'testsecret'}, follow_redirects=True) + self.assertEqual(r.status_code, 401) + self.assertEqual(r.content_type, 'application/json') + + def test_token_invalid_client(self): + r = self.client.post(path=url_for('oauth2.token'), + data={'grant_type': 'authorization_code', 'code': self.get_auth_code(), 'redirect_uri': 'http://localhost:5009/callback', 'client_id': 'invalid_client', 'client_secret': 'invalid_client_secret'}, follow_redirects=True) + self.assertEqual(r.status_code, 401) + self.assertEqual(r.content_type, 'application/json') + + def test_token_unauthorized_client(self): + r = self.client.post(path=url_for('oauth2.token'), + data={'grant_type': 'authorization_code', 'code': self.get_auth_code(), 'redirect_uri': 'http://localhost:5009/callback', 'client_id': 'test'}, follow_redirects=True) + self.assertEqual(r.status_code, 401) + self.assertEqual(r.content_type, 'application/json') + + def test_token_unsupported_grant_type(self): + r = self.client.post(path=url_for('oauth2.token'), + data={'grant_type': 'password', 'code': self.get_auth_code(), 'redirect_uri': 'http://localhost:5009/callback', 'client_id': 'test', 'client_secret': 'testsecret'}, follow_redirects=True) + self.assertEqual(r.status_code, 400) + self.assertEqual(r.content_type, 'application/json') + self.assertEqual(r.json['error'], 'unsupported_grant_type') + + def test_userinfo_invalid_access_token(self): + token = 'invalidtoken' + r = self.client.get(path=url_for('oauth2.userinfo'), headers=[('Authorization', 'Bearer %s'%token)], follow_redirects=True) + self.assertEqual(r.status_code, 401) diff --git a/uffd/oauth2/templates/oauth2/error.html b/uffd/oauth2/templates/oauth2/error.html index 2e7dba29205623d85560aca11eb53b4eabaf063b..380d74b8773ed6538b6a7705d79e0d0d4de54605 100644 --- a/uffd/oauth2/templates/oauth2/error.html +++ b/uffd/oauth2/templates/oauth2/error.html @@ -3,14 +3,6 @@ {% block body %} <h1>OAuth2.0 Authorization Error</h1> <p><b>Error: {{ error }}</b> {{ '(' + error_description + ')' if error_description else '' }}</p> -{% if args %} -<p>Parameters:</p> -<ul> - {% for key, value in args.items() %} - <li>{{ key }}={{ value }}</li> - {% endfor %} -</ul> -{% endif %} <hr> diff --git a/uffd/oauth2/views.py b/uffd/oauth2/views.py index c3c33af0f537c2f8ef116de82f024eeb2f6071ac..c52fb137b2fa71ee954aad6aad3315e7327c5eb0 100644 --- a/uffd/oauth2/views.py +++ b/uffd/oauth2/views.py @@ -1,9 +1,9 @@ import datetime import functools -import urllib.parse +import secrets -from flask import Blueprint, request, jsonify, render_template, session, redirect, url_for, flash -from flask_oauthlib.provider import OAuth2Provider +from flask import Blueprint, request, jsonify, render_template, session, redirect, url_for, flash, abort +import oauthlib.oauth2 from flask_babel import gettext as _ from sqlalchemy.exc import IntegrityError @@ -13,81 +13,132 @@ from uffd.secure_redirect import secure_local_redirect from uffd.session.models import DeviceLoginConfirmation from .models import OAuth2Client, OAuth2Grant, OAuth2Token, OAuth2DeviceLoginInitiation -oauth = OAuth2Provider() - -@oauth.clientgetter -def load_client(client_id): - return OAuth2Client.from_id(client_id) - -@oauth.grantgetter -def load_grant(client_id, code): - return OAuth2Grant.query.filter_by(client_id=client_id, code=code).first() - -@oauth.grantsetter -def save_grant(client_id, code, oauthreq, *args, **kwargs): # pylint: disable=unused-argument - expires = datetime.datetime.utcnow() + datetime.timedelta(seconds=100) - grant = OAuth2Grant(user_dn=request.oauth2_user.dn, client_id=client_id, - code=code['code'], redirect_uri=oauthreq.redirect_uri, expires=expires, _scopes=' '.join(oauthreq.scopes)) - db.session.add(grant) - db.session.commit() - return grant - -@oauth.tokengetter -def load_token(access_token=None, refresh_token=None): - if access_token: - return OAuth2Token.query.filter_by(access_token=access_token).first() - if refresh_token: - return OAuth2Token.query.filter_by(refresh_token=refresh_token).first() - return None - -@oauth.tokensetter -def save_token(token_data, oauthreq, *args, **kwargs): # pylint: disable=unused-argument - OAuth2Token.query.filter_by(client_id=oauthreq.client.client_id, user_dn=oauthreq.user.dn).delete() - expires_in = token_data.get('expires_in') - expires = datetime.datetime.utcnow() + datetime.timedelta(seconds=expires_in) - tok = OAuth2Token( - user_dn=oauthreq.user.dn, - client_id=oauthreq.client.client_id, - token_type=token_data['token_type'], - access_token=token_data['access_token'], - refresh_token=token_data['refresh_token'], - expires=expires, - _scopes=' '.join(oauthreq.scopes) - ) - db.session.add(tok) - db.session.commit() - return tok +class UffdRequestValidator(oauthlib.oauth2.RequestValidator): + # Argument "oauthreq" is named "request" in superclass but this clashes with flask's "request" object + # Arguments "token_value" and "token_data" are named "token" in superclass but this clashs with "token" endpoint + # pylint: disable=arguments-differ,arguments-renamed,unused-argument,too-many-public-methods,abstract-method + + # In all cases (aside from validate_bearer_token), either validate_client_id or authenticate_client is called + # before anything else. authenticate_client_id would be called instead of authenticate_client for non-confidential + # clients. However, we don't support those. + def validate_client_id(self, client_id, oauthreq, *args, **kwargs): + try: + oauthreq.client = OAuth2Client.from_id(client_id) + return True + except KeyError: + return False + + def authenticate_client(self, oauthreq, *args, **kwargs): + if oauthreq.client_secret is None: + return False + try: + oauthreq.client = OAuth2Client.from_id(oauthreq.client_id) + except KeyError: + return False + return secrets.compare_digest(oauthreq.client.client_secret, oauthreq.client_secret) + + def get_default_redirect_uri(self, client_id, oauthreq, *args, **kwargs): + return oauthreq.client.default_redirect_uri + def validate_redirect_uri(self, client_id, redirect_uri, oauthreq, *args, **kwargs): + return redirect_uri in oauthreq.client.redirect_uris + + def validate_response_type(self, client_id, response_type, client, oauthreq, *args, **kwargs): + return response_type == 'code' + + def get_default_scopes(self, client_id, oauthreq, *args, **kwargs): + return oauthreq.client.default_scopes + + def validate_scopes(self, client_id, scopes, client, oauthreq, *args, **kwargs): + return set(scopes).issubset({'profile'}) + + def save_authorization_code(self, client_id, code, oauthreq, *args, **kwargs): + expires = datetime.datetime.utcnow() + datetime.timedelta(seconds=100) + grant = OAuth2Grant(user_dn=oauthreq.user.dn, client_id=client_id, code=code['code'], + redirect_uri=oauthreq.redirect_uri, expires=expires, _scopes=' '.join(oauthreq.scopes)) + db.session.add(grant) + db.session.commit() + + def validate_code(self, client_id, code, client, oauthreq, *args, **kwargs): + oauthreq.grant = OAuth2Grant.query.filter_by(client_id=client_id, code=code).first() + if not oauthreq.grant: + return False + if datetime.datetime.utcnow() > oauthreq.grant.expires: + return False + oauthreq.user = oauthreq.grant.user + oauthreq.scopes = oauthreq.grant.scopes + return True + + def invalidate_authorization_code(self, client_id, code, oauthreq, *args, **kwargs): + OAuth2Grant.query.filter_by(client_id=client_id, code=code).delete() + db.session.commit() + + def save_bearer_token(self, token_data, oauthreq, *args, **kwargs): + OAuth2Token.query.filter_by(client_id=oauthreq.client.client_id, user_dn=oauthreq.user.dn).delete() + expires_in = token_data.get('expires_in') + expires = datetime.datetime.utcnow() + datetime.timedelta(seconds=expires_in) + tok = OAuth2Token( + user_dn=oauthreq.user.dn, + client_id=oauthreq.client.client_id, + token_type=token_data['token_type'], + access_token=token_data['access_token'], + refresh_token=token_data['refresh_token'], + expires=expires, + _scopes=' '.join(oauthreq.scopes) + ) + db.session.add(tok) + db.session.commit() + return oauthreq.client.default_redirect_uri + + def validate_grant_type(self, client_id, grant_type, client, oauthreq, *args, **kwargs): + return grant_type == 'authorization_code' + + def confirm_redirect_uri(self, client_id, code, redirect_uri, client, oauthreq, *args, **kwargs): + return redirect_uri == oauthreq.grant.redirect_uri + + def validate_bearer_token(self, token_value, scopes, oauthreq): + tok = OAuth2Token.query.filter_by(access_token=token_value).first() + if not tok: + return False + if datetime.datetime.utcnow() > tok.expires: + oauthreq.error_message = 'Token expired' + return False + if not set(scopes).issubset(tok.scopes): + oauthreq.error_message = 'Scopes invalid' + return False + oauthreq.access_token = tok + oauthreq.user = tok.user + oauthreq.scopes = scopes + oauthreq.client = tok.client + oauthreq.client_id = tok.client_id + return True + + # get_original_scopes/validate_refresh_token are only used for refreshing tokens. We don't implement the refresh endpoint. + # revoke_token is only used for revoking access tokens. We don't implement the revoke endpoint. + # get_id_token/validate_silent_authorization/validate_silent_login are OpenID Connect specfic. + # validate_user/validate_user_match are not required for Authorization Code Grant flow. + +validator = UffdRequestValidator() +server = oauthlib.oauth2.WebApplicationServer(validator) bp = Blueprint('oauth2', __name__, url_prefix='/oauth2/', template_folder='templates') -@bp.record -def init(state): - state.app.config.setdefault('OAUTH2_PROVIDER_ERROR_ENDPOINT', 'oauth2.error') - oauth.init_app(state.app) - -# flask-oauthlib has the bug to require the scope parameter for authorize -# requests, which is actually optional according to the OAuth2.0 spec. -# We don't really use scopes and this requirement just complicates the -# configuration of clients. -# See also: https://github.com/lepture/flask-oauthlib/pull/320 -def inject_scope(func): +def display_oauth_errors(func): @functools.wraps(func) def decorator(*args, **kwargs): - args = request.args.to_dict() - if not args.get('scope'): - args['scope'] = 'profile' - return redirect(request.base_url+'?'+urllib.parse.urlencode(args)) - return func(*args, **kwargs) + try: + return func(*args, **kwargs) + except oauthlib.oauth2.rfc6749.errors.OAuth2Error as ex: + return render_template('oauth2/error.html', error=type(ex).__name__, error_description=ex.description), 400 return decorator @bp.route('/authorize', methods=['GET', 'POST']) -@inject_scope -@oauth.authorize_handler -def authorize(*args, **kwargs): # pylint: disable=unused-argument - client = kwargs['request'].client - request.oauth2_user = None +@display_oauth_errors +def authorize(): + scopes, credentials = server.validate_authorization_request(request.url, request.method, request.form, request.headers) + client = OAuth2Client.from_id(credentials['client_id']) + if request.user: - request.oauth2_user = request.user + credentials['user'] = request.user elif 'devicelogin_started' in session: del session['devicelogin_started'] host_delay = host_ratelimit.get_delay() @@ -115,26 +166,41 @@ def authorize(*args, **kwargs): # pylint: disable=unused-argument if not initiation or initiation.expired or not confirmation: flash('Device login failed') return redirect(url_for('session.login', ref=request.full_path, devicelogin=True)) - request.oauth2_user = confirmation.user + credentials['user'] = confirmation.user db.session.delete(initiation) db.session.commit() else: return redirect(url_for('session.login', ref=request.full_path, devicelogin=True)) + # Here we would normally ask the user, if he wants to give the requesting # service access to his data. Since we only have trusted services (the # clients defined in the server config), we don't ask for consent. session['oauth2-clients'] = session.get('oauth2-clients', []) if client.client_id not in session['oauth2-clients']: session['oauth2-clients'].append(client.client_id) - return client.access_allowed(request.oauth2_user) + + headers, body, status = server.create_authorization_response(request.url, request.method, request.form, request.headers, scopes, credentials) + return body or '', status, headers @bp.route('/token', methods=['GET', 'POST']) -@oauth.token_handler def token(): - return None + headers, body, status = server.create_token_response(request.url, request.method, request.form, request.headers) + return body, status, headers + +def oauth_required(*scopes): + def wrapper(func): + @functools.wraps(func) + def decorator(*args, **kwargs): + valid, oauthreq = server.verify_request(request.url, request.method, request.form, request.headers, scopes) + if not valid: + abort(401) + request.oauth = oauthreq + return func(*args, **kwargs) + return decorator + return wrapper @bp.route('/userinfo') -@oauth.require_oauth('profile') +@oauth_required('profile') def userinfo(): user = request.oauth.user # We once exposed the entryUUID here as "ldap_uuid" until realising that it @@ -149,13 +215,6 @@ def userinfo(): groups=[group.name for group in user.groups] ) -@bp.route('/error') -def error(): - args = dict(request.values) - err = args.pop('error', 'unknown') - error_description = args.pop('error_description', '') - return render_template('oauth2/error.html', error=err, error_description=error_description, args=args) - @bp.app_url_defaults def inject_logout_params(endpoint, values): if endpoint != 'oauth2.logout' or not session.get('oauth2-clients'):