From 0d48e6cbc52acf70ff811cdab0edce060d0ef74d Mon Sep 17 00:00:00 2001 From: Julian Rother <julianr@fsmpi.rwth-aachen.de> Date: Sun, 13 Jun 2021 22:51:13 +0200 Subject: [PATCH] Replace get_current_user/is_valid_session with request.user --- tests/test_invite.py | 2 +- tests/test_mfa.py | 79 ++++++++++++------------- tests/test_oauth2.py | 1 - tests/test_rolemod.py | 1 - tests/test_selfservice.py | 39 ++++++------ tests/test_session.py | 18 +++--- tests/test_signup.py | 16 ++--- tests/test_user.py | 1 - uffd/invite/templates/invite/list.html | 6 +- uffd/invite/templates/invite/use.html | 4 +- uffd/invite/views.py | 30 ++++------ uffd/mail/views.py | 4 +- uffd/mfa/views.py | 82 ++++++++++++-------------- uffd/oauth2/views.py | 6 +- uffd/role/views.py | 4 +- uffd/rolemod/views.py | 12 ++-- uffd/selfservice/views.py | 8 +-- uffd/services/views.py | 17 ++---- uffd/session/__init__.py | 2 +- uffd/session/views.py | 48 +++++++-------- uffd/templates/base.html | 2 +- uffd/user/views_group.py | 6 +- uffd/user/views_user.py | 4 +- 23 files changed, 180 insertions(+), 212 deletions(-) diff --git a/tests/test_invite.py b/tests/test_invite.py index d3691d05..c21a15c1 100644 --- a/tests/test_invite.py +++ b/tests/test_invite.py @@ -12,7 +12,7 @@ from uffd import create_app, db from uffd.invite.models import Invite, InviteGrant, InviteSignup from uffd.user.models import User, Group from uffd.role.models import Role -from uffd.session.views import get_current_user, is_valid_session, login_get_user +from uffd.session.views import login_get_user from utils import dump, UffdTestCase, db_flush diff --git a/tests/test_mfa.py b/tests/test_mfa.py index 4ea98102..dc0c657c 100644 --- a/tests/test_mfa.py +++ b/tests/test_mfa.py @@ -2,13 +2,12 @@ import unittest import datetime import time -from flask import url_for, session +from flask import url_for, session, request # These imports are required, because otherwise we get circular imports?! from uffd import ldap, user from uffd.user.models import User -from uffd.session.views import get_current_user, is_valid_session from uffd.mfa.models import MFAMethod, MFAType, RecoveryCodeMethod, TOTPMethod, WebauthnMethod, _hotp from uffd import create_app, db @@ -176,21 +175,21 @@ class TestMfaViews(UffdTestCase): r = self.client.post(path=url_for('mfa.disable_confirm'), follow_redirects=True) dump('mfa_disable_submit', r) self.assertEqual(r.status_code, 200) - self.assertEqual(len(MFAMethod.query.filter_by(dn=get_current_user().dn).all()), 0) + self.assertEqual(len(MFAMethod.query.filter_by(dn=request.user.dn).all()), 0) self.assertEqual(len(MFAMethod.query.filter_by(dn=get_admin().dn).all()), admin_methods) def test_disable_recovery_only(self): self.login() self.add_recovery_codes() admin_methods = len(MFAMethod.query.filter_by(dn=get_admin().dn).all()) - self.assertNotEqual(len(MFAMethod.query.filter_by(dn=get_current_user().dn).all()), 0) + self.assertNotEqual(len(MFAMethod.query.filter_by(dn=request.user.dn).all()), 0) r = self.client.get(path=url_for('mfa.disable'), follow_redirects=True) dump('mfa_disable_recovery_only', r) self.assertEqual(r.status_code, 200) r = self.client.post(path=url_for('mfa.disable_confirm'), follow_redirects=True) dump('mfa_disable_recovery_only_submit', r) self.assertEqual(r.status_code, 200) - self.assertEqual(len(MFAMethod.query.filter_by(dn=get_current_user().dn).all()), 0) + self.assertEqual(len(MFAMethod.query.filter_by(dn=request.user.dn).all()), 0) self.assertEqual(len(MFAMethod.query.filter_by(dn=get_admin().dn).all()), admin_methods) def test_admin_disable(self): @@ -202,7 +201,7 @@ class TestMfaViews(UffdTestCase): self.add_totp() self.client.post(path=url_for('session.login'), data={'loginname': 'testadmin', 'password': 'adminpassword'}, follow_redirects=True) - self.assertTrue(is_valid_session()) + self.assertIsNotNone(request.user) admin_methods = len(MFAMethod.query.filter_by(dn=get_admin().dn).all()) r = self.client.get(path=url_for('mfa.admin_disable', uid=get_user().uid), follow_redirects=True) dump('mfa_admin_disable', r) @@ -212,11 +211,11 @@ class TestMfaViews(UffdTestCase): def test_setup_recovery(self): self.login() - self.assertEqual(len(RecoveryCodeMethod.query.filter_by(dn=get_current_user().dn).all()), 0) + self.assertEqual(len(RecoveryCodeMethod.query.filter_by(dn=request.user.dn).all()), 0) r = self.client.post(path=url_for('mfa.setup_recovery'), follow_redirects=True) dump('mfa_setup_recovery', r) self.assertEqual(r.status_code, 200) - methods = RecoveryCodeMethod.query.filter_by(dn=get_current_user().dn).all() + methods = RecoveryCodeMethod.query.filter_by(dn=request.user.dn).all() self.assertNotEqual(len(methods), 0) r = self.client.post(path=url_for('mfa.setup_recovery'), follow_redirects=True) dump('mfa_setup_recovery_reset', r) @@ -241,62 +240,62 @@ class TestMfaViews(UffdTestCase): def test_setup_totp_finish(self): self.login() self.add_recovery_codes() - self.assertEqual(len(TOTPMethod.query.filter_by(dn=get_current_user().dn).all()), 0) + self.assertEqual(len(TOTPMethod.query.filter_by(dn=request.user.dn).all()), 0) r = self.client.get(path=url_for('mfa.setup_totp', name='My TOTP Authenticator'), follow_redirects=True) - method = TOTPMethod(get_current_user(), key=session.get('mfa_totp_key', '')) + method = TOTPMethod(request.user, key=session.get('mfa_totp_key', '')) code = _hotp(int(time.time()/30), method.raw_key) r = self.client.post(path=url_for('mfa.setup_totp_finish', name='My TOTP Authenticator'), data={'code': code}, follow_redirects=True) dump('mfa_setup_totp_finish', r) self.assertEqual(r.status_code, 200) - self.assertEqual(len(TOTPMethod.query.filter_by(dn=get_current_user().dn).all()), 1) + self.assertEqual(len(TOTPMethod.query.filter_by(dn=request.user.dn).all()), 1) def test_setup_totp_finish_without_recovery(self): self.login() - self.assertEqual(len(TOTPMethod.query.filter_by(dn=get_current_user().dn).all()), 0) + self.assertEqual(len(TOTPMethod.query.filter_by(dn=request.user.dn).all()), 0) r = self.client.get(path=url_for('mfa.setup_totp', name='My TOTP Authenticator'), follow_redirects=True) - method = TOTPMethod(get_current_user(), key=session.get('mfa_totp_key', '')) + method = TOTPMethod(request.user, key=session.get('mfa_totp_key', '')) code = _hotp(int(time.time()/30), method.raw_key) r = self.client.post(path=url_for('mfa.setup_totp_finish', name='My TOTP Authenticator'), data={'code': code}, follow_redirects=True) dump('mfa_setup_totp_finish_without_recovery', r) self.assertEqual(r.status_code, 200) - self.assertEqual(len(TOTPMethod.query.filter_by(dn=get_current_user().dn).all()), 0) + self.assertEqual(len(TOTPMethod.query.filter_by(dn=request.user.dn).all()), 0) def test_setup_totp_finish_wrong_code(self): self.login() self.add_recovery_codes() - self.assertEqual(len(TOTPMethod.query.filter_by(dn=get_current_user().dn).all()), 0) + self.assertEqual(len(TOTPMethod.query.filter_by(dn=request.user.dn).all()), 0) r = self.client.get(path=url_for('mfa.setup_totp', name='My TOTP Authenticator'), follow_redirects=True) - method = TOTPMethod(get_current_user(), key=session.get('mfa_totp_key', '')) + method = TOTPMethod(request.user, key=session.get('mfa_totp_key', '')) code = _hotp(int(time.time()/30), method.raw_key) code = str(int(code[0])+1)[-1] + code[1:] r = self.client.post(path=url_for('mfa.setup_totp_finish', name='My TOTP Authenticator'), data={'code': code}, follow_redirects=True) dump('mfa_setup_totp_finish_wrong_code', r) self.assertEqual(r.status_code, 200) - self.assertEqual(len(TOTPMethod.query.filter_by(dn=get_current_user().dn).all()), 0) + self.assertEqual(len(TOTPMethod.query.filter_by(dn=request.user.dn).all()), 0) def test_setup_totp_finish_empty_code(self): self.login() self.add_recovery_codes() - self.assertEqual(len(TOTPMethod.query.filter_by(dn=get_current_user().dn).all()), 0) + self.assertEqual(len(TOTPMethod.query.filter_by(dn=request.user.dn).all()), 0) r = self.client.get(path=url_for('mfa.setup_totp', name='My TOTP Authenticator'), follow_redirects=True) r = self.client.post(path=url_for('mfa.setup_totp_finish', name='My TOTP Authenticator'), data={'code': ''}, follow_redirects=True) dump('mfa_setup_totp_finish_empty_code', r) self.assertEqual(r.status_code, 200) - self.assertEqual(len(TOTPMethod.query.filter_by(dn=get_current_user().dn).all()), 0) + self.assertEqual(len(TOTPMethod.query.filter_by(dn=request.user.dn).all()), 0) def test_delete_totp(self): self.login() self.add_recovery_codes() self.add_totp() - method = TOTPMethod(get_current_user(), name='test') + method = TOTPMethod(request.user, name='test') db.session.add(method) db.session.commit() - self.assertEqual(len(TOTPMethod.query.filter_by(dn=get_current_user().dn).all()), 2) + self.assertEqual(len(TOTPMethod.query.filter_by(dn=request.user.dn).all()), 2) r = self.client.get(path=url_for('mfa.delete_totp', id=method.id), follow_redirects=True) dump('mfa_delete_totp', r) self.assertEqual(r.status_code, 200) self.assertEqual(len(TOTPMethod.query.filter_by(id=method.id).all()), 0) - self.assertEqual(len(TOTPMethod.query.filter_by(dn=get_current_user().dn).all()), 1) + self.assertEqual(len(TOTPMethod.query.filter_by(dn=request.user.dn).all()), 1) # TODO: webauthn setup tests @@ -304,36 +303,36 @@ class TestMfaViews(UffdTestCase): self.add_recovery_codes() self.add_totp() db.session.commit() - self.assertFalse(is_valid_session()) + self.assertIsNone(request.user) r = self.client.post(path=url_for('session.login'), data={'loginname': 'testuser', 'password': 'userpassword'}, follow_redirects=True) dump('mfa_auth_redirected', r) self.assertEqual(r.status_code, 200) self.assertIn(b'/mfa/auth', r.data) - self.assertFalse(is_valid_session()) + self.assertIsNone(request.user) r = self.client.get(path=url_for('mfa.auth'), follow_redirects=False) dump('mfa_auth', r) self.assertEqual(r.status_code, 200) - self.assertFalse(is_valid_session()) + self.assertIsNone(request.user) def test_auth_disabled(self): - self.assertFalse(is_valid_session()) + self.assertIsNone(request.user) r = self.client.post(path=url_for('session.login'), data={'loginname': 'testuser', 'password': 'userpassword'}, follow_redirects=False) r = self.client.get(path=url_for('mfa.auth', ref='/redirecttarget'), follow_redirects=False) self.assertEqual(r.status_code, 302) self.assertTrue(r.location.endswith('/redirecttarget')) - self.assertTrue(is_valid_session()) + self.assertIsNotNone(request.user) def test_auth_recovery_only(self): self.add_recovery_codes() - self.assertFalse(is_valid_session()) + self.assertIsNone(request.user) r = self.client.post(path=url_for('session.login'), data={'loginname': 'testuser', 'password': 'userpassword'}, follow_redirects=False) r = self.client.get(path=url_for('mfa.auth', ref='/redirecttarget'), follow_redirects=False) self.assertEqual(r.status_code, 302) self.assertTrue(r.location.endswith('/redirecttarget')) - self.assertTrue(is_valid_session()) + self.assertIsNotNone(request.user) def test_auth_recovery_code(self): self.add_recovery_codes() @@ -346,11 +345,11 @@ class TestMfaViews(UffdTestCase): r = self.client.get(path=url_for('mfa.auth'), follow_redirects=False) dump('mfa_auth_recovery_code', r) self.assertEqual(r.status_code, 200) - self.assertFalse(is_valid_session()) + self.assertIsNone(request.user) r = self.client.post(path=url_for('mfa.auth_finish', ref='/redirecttarget'), data={'code': method.code}) self.assertEqual(r.status_code, 302) self.assertTrue(r.location.endswith('/redirecttarget')) - self.assertTrue(is_valid_session()) + self.assertIsNotNone(request.user) self.assertEqual(len(RecoveryCodeMethod.query.filter_by(id=method_id).all()), 0) def test_auth_totp_code(self): @@ -364,12 +363,12 @@ class TestMfaViews(UffdTestCase): r = self.client.get(path=url_for('mfa.auth'), follow_redirects=False) dump('mfa_auth_totp_code', r) self.assertEqual(r.status_code, 200) - self.assertFalse(is_valid_session()) + self.assertIsNone(request.user) code = _hotp(int(time.time()/30), raw_key) r = self.client.post(path=url_for('mfa.auth_finish'), data={'code': code}, follow_redirects=True) dump('mfa_auth_totp_code_submit', r) self.assertEqual(r.status_code, 200) - self.assertTrue(is_valid_session()) + self.assertIsNotNone(request.user) def test_auth_empty_code(self): self.add_recovery_codes() @@ -377,11 +376,11 @@ class TestMfaViews(UffdTestCase): self.login() r = self.client.get(path=url_for('mfa.auth'), follow_redirects=False) self.assertEqual(r.status_code, 200) - self.assertFalse(is_valid_session()) + self.assertIsNone(request.user) r = self.client.post(path=url_for('mfa.auth_finish'), data={'code': ''}, follow_redirects=True) dump('mfa_auth_empty_code', r) self.assertEqual(r.status_code, 200) - self.assertFalse(is_valid_session()) + self.assertIsNone(request.user) def test_auth_invalid_code(self): self.add_recovery_codes() @@ -393,13 +392,13 @@ class TestMfaViews(UffdTestCase): self.login() r = self.client.get(path=url_for('mfa.auth'), follow_redirects=False) self.assertEqual(r.status_code, 200) - self.assertFalse(is_valid_session()) + self.assertIsNone(request.user) code = _hotp(int(time.time()/30), raw_key) code = str(int(code[0])+1)[-1] + code[1:] r = self.client.post(path=url_for('mfa.auth_finish'), data={'code': code}, follow_redirects=True) dump('mfa_auth_invalid_code', r) self.assertEqual(r.status_code, 200) - self.assertFalse(is_valid_session()) + self.assertIsNone(request.user) def test_auth_ratelimit(self): self.add_recovery_codes() @@ -409,17 +408,17 @@ class TestMfaViews(UffdTestCase): db.session.add(method) db.session.commit() self.login() - self.assertFalse(is_valid_session()) + self.assertIsNone(request.user) code = _hotp(int(time.time()/30), raw_key) inv_code = str(int(code[0])+1)[-1] + code[1:] for i in range(20): r = self.client.post(path=url_for('mfa.auth_finish'), data={'code': inv_code}, follow_redirects=True) self.assertEqual(r.status_code, 200) - self.assertFalse(is_valid_session()) + self.assertIsNone(request.user) r = self.client.post(path=url_for('mfa.auth_finish'), data={'code': code}, follow_redirects=True) dump('mfa_auth_ratelimit', r) self.assertEqual(r.status_code, 200) - self.assertFalse(is_valid_session()) + self.assertIsNone(request.user) # TODO: webauthn auth tests diff --git a/tests/test_oauth2.py b/tests/test_oauth2.py index 1dc7e1c5..b77a0992 100644 --- a/tests/test_oauth2.py +++ b/tests/test_oauth2.py @@ -6,7 +6,6 @@ from flask import url_for # These imports are required, because otherwise we get circular imports?! from uffd import ldap, user -from uffd.session.views import get_current_user from uffd.user.models import User from uffd.oauth2.models import OAuth2Client from uffd import create_app, db, ldap diff --git a/tests/test_rolemod.py b/tests/test_rolemod.py index f8598ab6..f1baba1a 100644 --- a/tests/test_rolemod.py +++ b/tests/test_rolemod.py @@ -1,7 +1,6 @@ from flask import url_for from uffd.user.models import User, Group -from uffd.session import get_current_user from uffd.role.models import Role from uffd.database import db from uffd.ldap import ldap diff --git a/tests/test_selfservice.py b/tests/test_selfservice.py index 1d6ab94b..6d5d9698 100644 --- a/tests/test_selfservice.py +++ b/tests/test_selfservice.py @@ -1,12 +1,11 @@ import datetime import unittest -from flask import url_for +from flask import url_for, request # These imports are required, because otherwise we get circular imports?! from uffd import ldap, user -from uffd.session.views import get_current_user from uffd.selfservice.models import MailToken, PasswordToken from uffd.user.models import User from uffd import create_app, db @@ -28,88 +27,88 @@ class TestSelfservice(UffdTestCase): r = self.client.get(path=url_for('selfservice.index')) dump('selfservice_index', r) self.assertEqual(r.status_code, 200) - user = get_current_user() + user = request.user self.assertIn(user.displayname.encode(), r.data) self.assertIn(user.loginname.encode(), r.data) self.assertIn(user.mail.encode(), r.data) def test_update_displayname(self): self.login() - user = get_current_user() + user = request.user r = self.client.post(path=url_for('selfservice.update'), data={'displayname': 'New Display Name', 'mail': user.mail, 'password': '', 'password1': ''}, follow_redirects=True) dump('update_displayname', r) self.assertEqual(r.status_code, 200) - _user = get_current_user() + _user = request.user self.assertEqual(_user.displayname, 'New Display Name') def test_update_displayname_invalid(self): self.login() - user = get_current_user() + user = request.user r = self.client.post(path=url_for('selfservice.update'), data={'displayname': '', 'mail': user.mail, 'password': '', 'password1': ''}, follow_redirects=True) dump('update_displayname_invalid', r) self.assertEqual(r.status_code, 200) - _user = get_current_user() + _user = request.user self.assertNotEqual(_user.displayname, '') def test_update_mail(self): self.login() - user = get_current_user() + user = request.user r = self.client.post(path=url_for('selfservice.update'), data={'displayname': user.displayname, 'mail': 'newemail@example.com', 'password': '', 'password1': ''}, follow_redirects=True) dump('update_mail', r) self.assertEqual(r.status_code, 200) - _user = get_current_user() + _user = request.user self.assertNotEqual(_user.mail, 'newemail@example.com') token = MailToken.query.filter(MailToken.loginname == user.loginname).first() self.assertEqual(token.newmail, 'newemail@example.com') self.assertIn(token.token, str(self.app.last_mail.get_content())) r = self.client.get(path=url_for('selfservice.token_mail', token=token.token), follow_redirects=True) self.assertEqual(r.status_code, 200) - _user = get_current_user() + _user = request.user self.assertEqual(_user.mail, 'newemail@example.com') def test_update_mail_sendfailure(self): self.app.config['MAIL_SKIP_SEND'] = 'fail' self.login() - user = get_current_user() + user = request.user r = self.client.post(path=url_for('selfservice.update'), data={'displayname': user.displayname, 'mail': 'newemail@example.com', 'password': '', 'password1': ''}, follow_redirects=True) dump('update_mail_sendfailure', r) self.assertEqual(r.status_code, 200) - _user = get_current_user() + _user = request.user self.assertNotEqual(_user.mail, 'newemail@example.com') # Maybe also check that there is no new token in the db def test_token_mail_emptydb(self): self.login() - user = get_current_user() + user = request.user r = self.client.get(path=url_for('selfservice.token_mail', token='A'*128), follow_redirects=True) dump('token_mail_emptydb', r) self.assertEqual(r.status_code, 200) - _user = get_current_user() + _user = request.user self.assertEqual(_user.mail, user.mail) def test_token_mail_invalid(self): self.login() - user = get_current_user() + user = request.user db.session.add(MailToken(loginname=user.loginname, newmail='newusermail@example.com')) db.session.commit() r = self.client.get(path=url_for('selfservice.token_mail', token='A'*128), follow_redirects=True) dump('token_mail_invalid', r) self.assertEqual(r.status_code, 200) - _user = get_current_user() + _user = request.user self.assertEqual(_user.mail, user.mail) @unittest.skip('See #26') def test_token_mail_wrong_user(self): self.login() - user = get_current_user() + user = request.user admin_user = User.query.get('uid=testadmin,ou=users,dc=example,dc=com') db.session.add(MailToken(loginname=user.loginname, newmail='newusermail@example.com')) admin_token = MailToken(loginname='testadmin', newmail='newadminmail@example.com') @@ -118,14 +117,14 @@ class TestSelfservice(UffdTestCase): r = self.client.get(path=url_for('selfservice.token_mail', token=admin_token.token), follow_redirects=True) dump('token_mail_wrong_user', r) self.assertEqual(r.status_code, 200) - _user = get_current_user() + _user = request.user _admin_user = User.query.get('uid=testadmin,ou=users,dc=example,dc=com') self.assertEqual(_user.mail, user.mail) self.assertEqual(_admin_user.mail, admin_user.mail) def test_token_mail_expired(self): self.login() - user = get_current_user() + user = request.user token = MailToken(loginname=user.loginname, newmail='newusermail@example.com', created=(datetime.datetime.now() - datetime.timedelta(days=10))) db.session.add(token) @@ -133,7 +132,7 @@ class TestSelfservice(UffdTestCase): r = self.client.get(path=url_for('selfservice.token_mail', token=token.token), follow_redirects=True) dump('token_mail_expired', r) self.assertEqual(r.status_code, 200) - _user = get_current_user() + _user = request.user self.assertEqual(_user.mail, user.mail) tokens = MailToken.query.filter(MailToken.loginname == user.loginname).all() self.assertEqual(len(tokens), 0) diff --git a/tests/test_session.py b/tests/test_session.py index dae41ab5..0882b08e 100644 --- a/tests/test_session.py +++ b/tests/test_session.py @@ -1,12 +1,12 @@ import time import unittest -from flask import url_for +from flask import url_for, request # These imports are required, because otherwise we get circular imports?! from uffd import ldap, user -from uffd.session.views import get_current_user, login_required, is_valid_session +from uffd.session.views import login_required from uffd import create_app, db from utils import dump, UffdTestCase @@ -32,24 +32,24 @@ class TestSession(UffdTestCase): def setUp(self): super().setUp() - self.assertFalse(is_valid_session()) + self.assertIsNone(request.user) def login(self): self.client.post(path=url_for('session.login'), data={'loginname': 'testuser', 'password': 'userpassword'}, follow_redirects=True) - self.assertTrue(is_valid_session()) + self.assertIsNotNone(request.user) def assertLogin(self): - self.assertTrue(is_valid_session()) + self.assertIsNotNone(request.user) self.assertEqual(self.client.get(path=url_for('test_login_required'), follow_redirects=True).data, b'SUCCESS') - self.assertEqual(get_current_user().loginname, 'testuser') + self.assertEqual(request.user.loginname, 'testuser') def assertLogout(self): - self.assertFalse(is_valid_session()) + self.assertIsNone(request.user) self.assertNotEqual(self.client.get(path=url_for('test_login_required'), follow_redirects=True).data, b'SUCCESS') - self.assertEqual(get_current_user(), None) + self.assertEqual(request.user, None) def test_login(self): self.assertLogout() @@ -131,7 +131,7 @@ class TestSession(UffdTestCase): data={'loginname': 'testuser', 'password': 'userpassword'}, follow_redirects=True) dump('login_ratelimit', r) self.assertEqual(r.status_code, 200) - self.assertFalse(is_valid_session()) + self.assertIsNone(request.user) class TestSessionOL(TestSession): use_openldap = True diff --git a/tests/test_signup.py b/tests/test_signup.py index 501e8896..65007789 100644 --- a/tests/test_signup.py +++ b/tests/test_signup.py @@ -2,7 +2,7 @@ import unittest import datetime import time -from flask import url_for, session +from flask import url_for, session, request # These imports are required, because otherwise we get circular imports?! from uffd import user @@ -11,7 +11,7 @@ from uffd.ldap import ldap from uffd import create_app, db from uffd.signup.models import Signup from uffd.user.models import User -from uffd.session.views import get_current_user, is_valid_session, login_get_user +from uffd.session.views import login_get_user from utils import dump, UffdTestCase, db_flush @@ -345,8 +345,8 @@ class TestSignupViews(UffdTestCase): self.assertEqual(signup.user.mail, 'test@example.com') if self.use_openldap: self.assertIsNotNone(login_get_user('newuser', 'notsecret')) - self.assertTrue(is_valid_session()) - self.assertEqual(get_current_user().loginname, 'newuser') + self.assertIsNotNone(request.user) + self.assertEqual(request.user.loginname, 'newuser') def test_confirm_loggedin(self): signup = Signup(loginname='newuser', displayname='New User', mail='test@example.com', password='notsecret') @@ -354,16 +354,16 @@ class TestSignupViews(UffdTestCase): self.client.post(path=url_for('session.login'), data={'loginname': 'testuser', 'password': 'userpassword'}, follow_redirects=True) self.assertFalse(signup.completed) - self.assertTrue(is_valid_session()) - self.assertEqual(get_current_user().loginname, 'testuser') + self.assertIsNotNone(request.user) + self.assertEqual(request.user.loginname, 'testuser') r = self.client.get(path=url_for('signup.signup_confirm', token=signup.token), follow_redirects=True) self.assertEqual(r.status_code, 200) r = self.client.post(path=url_for('signup.signup_confirm_submit', token=signup.token), follow_redirects=True, data={'password': 'notsecret'}) self.assertEqual(r.status_code, 200) signup = refetch_signup(signup) self.assertTrue(signup.completed) - self.assertTrue(is_valid_session()) - self.assertEqual(get_current_user().loginname, 'newuser') + self.assertIsNotNone(request.user) + self.assertEqual(request.user.loginname, 'newuser') def test_confirm_notfound(self): r = self.client.get(path=url_for('signup.signup_confirm', token='notasignuptoken'), follow_redirects=True) diff --git a/tests/test_user.py b/tests/test_user.py index baedf5bc..f9fb574b 100644 --- a/tests/test_user.py +++ b/tests/test_user.py @@ -6,7 +6,6 @@ from flask import url_for, session # These imports are required, because otherwise we get circular imports?! from uffd import ldap, user -from uffd.session.views import get_current_user from uffd.user.models import User from uffd.role.models import Role from uffd import create_app, db diff --git a/uffd/invite/templates/invite/list.html b/uffd/invite/templates/invite/list.html index a371f9e4..00fd0ee4 100644 --- a/uffd/invite/templates/invite/list.html +++ b/uffd/invite/templates/invite/list.html @@ -21,7 +21,7 @@ {% for invite in invites|sort(attribute='created', reverse=True)|sort(attribute='active', reverse=True) %} <tr> <td> - {% if invite.creator == get_current_user() and invite.active %} + {% if invite.creator == request.user and invite.active %} <a href="{{ url_for('invite.use', token=invite.token) }}"><code>{{ invite.short_token }}</code></a> <button type="button" class="btn btn-link btn-sm p-0 copy-clipboard" data-copy="{{ url_for('invite.use', token=invite.token, _external=True) }}" title="Copy link to clipboard"><i class="fas fa-clipboard"></i></button> <button type="button" class="btn btn-link btn-sm p-0" data-toggle="modal" data-target="#modal-{{ invite.id }}-qrcode" title="Show link as QR code"><i class="fas fa-qrcode"></i></button> @@ -121,7 +121,7 @@ <form action="{{ url_for('invite.disable', invite_id=invite.id) }}" method="POST"> <button type="submit" class="btn btn-primary">Disable Link</button> </form> - {% elif invite.creator == get_current_user() and not invite.expired and invite.permitted %} + {% elif invite.creator == request.user and not invite.expired and invite.permitted %} <form action="{{ url_for('invite.reset', invite_id=invite.id) }}" method="POST"> <button type="submit" class="btn btn-primary">Reenable Link</button> </form> @@ -132,7 +132,7 @@ </div> {% endfor %} -{% for invite in invites if invite.creator == get_current_user() %} +{% for invite in invites if invite.creator == request.user %} <div class="modal" tabindex="-1" id="modal-{{ invite.id }}-qrcode"> <div class="modal-dialog"> <div class="modal-content"> diff --git a/uffd/invite/templates/invite/use.html b/uffd/invite/templates/invite/use.html index a51db885..4d5b53ad 100644 --- a/uffd/invite/templates/invite/use.html +++ b/uffd/invite/templates/invite/use.html @@ -10,7 +10,7 @@ <div class="col-12 mb-3"> <h2 class="text-center">Invite Link</h2> </div> - {% if not is_valid_session() %} + {% if not request.user %} <p>Welcome to the CCCV Single-Sign-On!</p> {% endif %} @@ -28,7 +28,7 @@ {% endfor %} </ul> {% endif %} - {% if is_valid_session() %} + {% if request.user %} {% if invite.roles %} <form method="POST" action="{{ url_for("invite.grant", token=invite.token) }}" class="mb-2"> <button type="submit" class="btn btn-primary btn-block">Add the roles to your account now</button> diff --git a/uffd/invite/views.py b/uffd/invite/views.py index 518a64cb..ece81c64 100644 --- a/uffd/invite/views.py +++ b/uffd/invite/views.py @@ -7,7 +7,7 @@ import sqlalchemy from uffd.csrf import csrf_protect from uffd.database import db from uffd.ldap import ldap -from uffd.session import get_current_user, login_required, is_valid_session +from uffd.session import login_required from uffd.role.models import Role from uffd.invite.models import Invite, InviteSignup, InviteGrant from uffd.user.models import User @@ -16,18 +16,16 @@ from uffd.navbar import register_navbar from uffd.ratelimit import host_ratelimit, format_delay from uffd.signup.views import signup_ratelimit - bp = Blueprint('invite', __name__, template_folder='templates', url_prefix='/invite/') def invite_acl(): - if not is_valid_session(): + if not request.user: return False - user = get_current_user() - if user.is_in_group(current_app.config['ACL_ADMIN_GROUP']): + if request.user.is_in_group(current_app.config['ACL_ADMIN_GROUP']): return True - if user.is_in_group(current_app.config['ACL_SIGNUP_GROUP']): + if request.user.is_in_group(current_app.config['ACL_SIGNUP_GROUP']): return True - if Role.query.filter(Role.moderator_group_dn.in_(user.group_dns)).count(): + if Role.query.filter(Role.moderator_group_dn.in_(request.user.group_dns)).count(): return True return False @@ -57,27 +55,25 @@ def reset_acl_filter(user): @register_navbar('Invites', icon='link', blueprint=bp, visible=invite_acl) @invite_acl_required def index(): - invites = Invite.query.filter(view_acl_filter(get_current_user())).all() + invites = Invite.query.filter(view_acl_filter(request.user)).all() return render_template('invite/list.html', invites=invites) @bp.route('/new') @invite_acl_required def new(): - user = get_current_user() - if user.is_in_group(current_app.config['ACL_ADMIN_GROUP']): + if request.user.is_in_group(current_app.config['ACL_ADMIN_GROUP']): allow_signup = True roles = Role.query.all() else: - allow_signup = user.is_in_group(current_app.config['ACL_SIGNUP_GROUP']) - roles = Role.query.filter(Role.moderator_group_dn.in_(user.group_dns)).all() + allow_signup = request.user.is_in_group(current_app.config['ACL_SIGNUP_GROUP']) + roles = Role.query.filter(Role.moderator_group_dn.in_(request.user.group_dns)).all() return render_template('invite/new.html', roles=roles, allow_signup=allow_signup) @bp.route('/new', methods=['POST']) @invite_acl_required @csrf_protect(blueprint=bp) def new_submit(): - user = get_current_user() - invite = Invite(creator=user, + invite = Invite(creator=request.user, single_use=(request.values['single-use'] == '1'), valid_until=datetime.datetime.fromisoformat(request.values['valid-until']), allow_signup=(request.values.get('allow-signup', '0') == '1')) @@ -101,7 +97,7 @@ def new_submit(): @invite_acl_required @csrf_protect(blueprint=bp) def disable(invite_id): - invite = Invite.query.filter(view_acl_filter(get_current_user())).filter_by(id=invite_id).first_or_404() + invite = Invite.query.filter(view_acl_filter(request.user)).filter_by(id=invite_id).first_or_404() invite.disable() db.session.commit() return redirect(url_for('.index')) @@ -110,7 +106,7 @@ def disable(invite_id): @invite_acl_required @csrf_protect(blueprint=bp) def reset(invite_id): - invite = Invite.query.filter(reset_acl_filter(get_current_user())).filter_by(id=invite_id).first_or_404() + invite = Invite.query.filter(reset_acl_filter(request.user)).filter_by(id=invite_id).first_or_404() invite.reset() db.session.commit() return redirect(url_for('.index')) @@ -128,7 +124,7 @@ def use(token): @csrf_protect(blueprint=bp) def grant(token): invite = Invite.query.filter_by(token=token).first_or_404() - invite_grant = InviteGrant(invite=invite, user=get_current_user()) + invite_grant = InviteGrant(invite=invite, user=request.user) db.session.add(invite_grant) success, msg = invite_grant.apply() if not success: diff --git a/uffd/mail/views.py b/uffd/mail/views.py index 2a7a1432..6ba83d10 100644 --- a/uffd/mail/views.py +++ b/uffd/mail/views.py @@ -3,7 +3,7 @@ from flask import Blueprint, render_template, request, url_for, redirect, flash, from uffd.navbar import register_navbar from uffd.csrf import csrf_protect from uffd.ldap import ldap -from uffd.session import login_required, is_valid_session, get_current_user +from uffd.session import login_required from uffd.mail.models import Mail @@ -16,7 +16,7 @@ def mail_acl(): #pylint: disable=inconsistent-return-statements return redirect(url_for('index')) def mail_acl_check(): - return is_valid_session() and get_current_user().is_in_group(current_app.config['ACL_ADMIN_GROUP']) + return request.user and request.user.is_in_group(current_app.config['ACL_ADMIN_GROUP']) @bp.route("/") @register_navbar('Mail', icon='envelope', blueprint=bp, visible=mail_acl_check) diff --git a/uffd/mfa/views.py b/uffd/mfa/views.py index d384b95c..7b95479f 100644 --- a/uffd/mfa/views.py +++ b/uffd/mfa/views.py @@ -5,7 +5,7 @@ from flask import Blueprint, render_template, session, request, redirect, url_fo from uffd.database import db from uffd.mfa.models import MFAMethod, TOTPMethod, WebauthnMethod, RecoveryCodeMethod -from uffd.session.views import get_current_user, login_required, pre_mfa_login_required +from uffd.session.views import login_required, login_required_pre_mfa, set_request_user from uffd.user.models import User from uffd.csrf import csrf_protect from uffd.ratelimit import Ratelimit, format_delay @@ -17,10 +17,9 @@ mfa_ratelimit = Ratelimit('mfa', 1*60, 3) @bp.route('/', methods=['GET']) @login_required() def setup(): - user = get_current_user() - recovery_methods = RecoveryCodeMethod.query.filter_by(dn=user.dn).all() - totp_methods = TOTPMethod.query.filter_by(dn=user.dn).all() - webauthn_methods = WebauthnMethod.query.filter_by(dn=user.dn).all() + recovery_methods = RecoveryCodeMethod.query.filter_by(dn=request.user.dn).all() + totp_methods = TOTPMethod.query.filter_by(dn=request.user.dn).all() + webauthn_methods = WebauthnMethod.query.filter_by(dn=request.user.dn).all() return render_template('mfa/setup.html', totp_methods=totp_methods, webauthn_methods=webauthn_methods, recovery_methods=recovery_methods) @bp.route('/setup/disable', methods=['GET']) @@ -32,8 +31,7 @@ def disable(): @login_required() @csrf_protect(blueprint=bp) def disable_confirm(): - user = get_current_user() - MFAMethod.query.filter_by(dn=user.dn).delete() + MFAMethod.query.filter_by(dn=request.user.dn).delete() db.session.commit() return redirect(url_for('mfa.setup')) @@ -43,7 +41,7 @@ def disable_confirm(): def admin_disable(uid): # Group cannot be checked with login_required kwarg, because the config # variable is not available when the decorator is processed - if not get_current_user().is_in_group(current_app.config['ACL_ADMIN_GROUP']): + if not request.user.is_in_group(current_app.config['ACL_ADMIN_GROUP']): flash('Access denied') return redirect(url_for('index')) user = User.query.filter_by(uid=uid).one() @@ -56,12 +54,11 @@ def admin_disable(uid): @login_required() @csrf_protect(blueprint=bp) def setup_recovery(): - user = get_current_user() - for method in RecoveryCodeMethod.query.filter_by(dn=user.dn).all(): + for method in RecoveryCodeMethod.query.filter_by(dn=request.user.dn).all(): db.session.delete(method) methods = [] for _ in range(10): - method = RecoveryCodeMethod(user) + method = RecoveryCodeMethod(request.user) methods.append(method) db.session.add(method) db.session.commit() @@ -70,8 +67,7 @@ def setup_recovery(): @bp.route('/setup/totp', methods=['GET']) @login_required() def setup_totp(): - user = get_current_user() - method = TOTPMethod(user) + method = TOTPMethod(request.user) session['mfa_totp_key'] = method.key return render_template('mfa/setup_totp.html', method=method, name=request.values['name']) @@ -79,11 +75,10 @@ def setup_totp(): @login_required() @csrf_protect(blueprint=bp) def setup_totp_finish(): - user = get_current_user() - if not RecoveryCodeMethod.query.filter_by(dn=user.dn).all(): + if not RecoveryCodeMethod.query.filter_by(dn=request.user.dn).all(): flash('Generate recovery codes first!') return redirect(url_for('mfa.setup')) - method = TOTPMethod(user, name=request.values['name'], key=session.pop('mfa_totp_key')) + method = TOTPMethod(request.user, name=request.values['name'], key=session.pop('mfa_totp_key')) if method.verify(request.form['code']): db.session.add(method) db.session.commit() @@ -95,8 +90,7 @@ def setup_totp_finish(): @login_required() @csrf_protect(blueprint=bp) def delete_totp(id): #pylint: disable=redefined-builtin - user = get_current_user() - method = TOTPMethod.query.filter_by(dn=user.dn, id=id).first_or_404() + method = TOTPMethod.query.filter_by(dn=request.user.dn, id=id).first_or_404() db.session.delete(method) db.session.commit() return redirect(url_for('mfa.setup')) @@ -124,17 +118,16 @@ if WEBAUTHN_SUPPORTED: @login_required() @csrf_protect(blueprint=bp) def setup_webauthn_begin(): - user = get_current_user() - if not RecoveryCodeMethod.query.filter_by(dn=user.dn).all(): + if not RecoveryCodeMethod.query.filter_by(dn=request.user.dn).all(): abort(403) - methods = WebauthnMethod.query.filter_by(dn=user.dn).all() + methods = WebauthnMethod.query.filter_by(dn=request.user.dn).all() creds = [method.cred for method in methods] server = get_webauthn_server() registration_data, state = server.register_begin( { - "id": user.dn.encode(), - "name": user.loginname, - "displayName": user.displayname, + "id": request.user.dn.encode(), + "name": request.user.loginname, + "displayName": request.user.displayname, }, creds, user_verification='discouraged', @@ -146,23 +139,21 @@ if WEBAUTHN_SUPPORTED: @login_required() @csrf_protect(blueprint=bp) def setup_webauthn_complete(): - user = get_current_user() server = get_webauthn_server() data = cbor.loads(request.get_data())[0] client_data = ClientData(data["clientDataJSON"]) att_obj = AttestationObject(data["attestationObject"]) auth_data = server.register_complete(session["webauthn-state"], client_data, att_obj) - method = WebauthnMethod(user, auth_data.credential_data, name=data['name']) + method = WebauthnMethod(request.user, auth_data.credential_data, name=data['name']) db.session.add(method) db.session.commit() return cbor.dumps({"status": "OK"}) @bp.route("/auth/webauthn/begin", methods=["POST"]) - @pre_mfa_login_required(no_redirect=True) + @login_required_pre_mfa(no_redirect=True) def auth_webauthn_begin(): - user = get_current_user() server = get_webauthn_server() - methods = WebauthnMethod.query.filter_by(dn=user.dn).all() + methods = WebauthnMethod.query.filter_by(dn=request.user_pre_mfa.dn).all() creds = [method.cred for method in methods] if not creds: abort(404) @@ -171,11 +162,10 @@ if WEBAUTHN_SUPPORTED: return cbor.dumps(auth_data) @bp.route("/auth/webauthn/complete", methods=["POST"]) - @pre_mfa_login_required(no_redirect=True) + @login_required_pre_mfa(no_redirect=True) def auth_webauthn_complete(): - user = get_current_user() server = get_webauthn_server() - methods = WebauthnMethod.query.filter_by(dn=user.dn).all() + methods = WebauthnMethod.query.filter_by(dn=request.user_pre_mfa.dn).all() creds = [method.cred for method in methods] if not creds: abort(404) @@ -195,51 +185,53 @@ if WEBAUTHN_SUPPORTED: signature, ) session['user_mfa'] = True + set_request_user() return cbor.dumps({"status": "OK"}) @bp.route('/setup/webauthn/<int:id>/delete') @login_required() @csrf_protect(blueprint=bp) def delete_webauthn(id): #pylint: disable=redefined-builtin - user = get_current_user() - method = WebauthnMethod.query.filter_by(dn=user.dn, id=id).first_or_404() + method = WebauthnMethod.query.filter_by(dn=request.user.dn, id=id).first_or_404() db.session.delete(method) db.session.commit() return redirect(url_for('mfa.setup')) @bp.route('/auth', methods=['GET']) -@pre_mfa_login_required() +@login_required_pre_mfa() def auth(): - user = get_current_user() - recovery_methods = RecoveryCodeMethod.query.filter_by(dn=user.dn).all() - totp_methods = TOTPMethod.query.filter_by(dn=user.dn).all() - webauthn_methods = WebauthnMethod.query.filter_by(dn=user.dn).all() + recovery_methods = RecoveryCodeMethod.query.filter_by(dn=request.user_pre_mfa.dn).all() + totp_methods = TOTPMethod.query.filter_by(dn=request.user_pre_mfa.dn).all() + webauthn_methods = WebauthnMethod.query.filter_by(dn=request.user_pre_mfa.dn).all() if not totp_methods and not webauthn_methods: session['user_mfa'] = True + set_request_user() + if session.get('user_mfa'): return redirect(request.values.get('ref', url_for('index'))) return render_template('mfa/auth.html', ref=request.values.get('ref'), totp_methods=totp_methods, webauthn_methods=webauthn_methods, recovery_methods=recovery_methods) @bp.route('/auth', methods=['POST']) -@pre_mfa_login_required() +@login_required_pre_mfa() def auth_finish(): - user = get_current_user() - delay = mfa_ratelimit.get_delay(user.dn) + delay = mfa_ratelimit.get_delay(request.user_pre_mfa.dn) if delay: flash('We received too many invalid attempts! Please wait at least %s.'%format_delay(delay)) return redirect(url_for('mfa.auth', ref=request.values.get('ref'))) - recovery_methods = RecoveryCodeMethod.query.filter_by(dn=user.dn).all() - totp_methods = TOTPMethod.query.filter_by(dn=user.dn).all() + recovery_methods = RecoveryCodeMethod.query.filter_by(dn=request.user_pre_mfa.dn).all() + totp_methods = TOTPMethod.query.filter_by(dn=request.user_pre_mfa.dn).all() for method in totp_methods: if method.verify(request.form['code']): session['user_mfa'] = True + set_request_user() return redirect(request.values.get('ref', url_for('index'))) for method in recovery_methods: if method.verify(request.form['code']): db.session.delete(method) db.session.commit() session['user_mfa'] = True + set_request_user() if len(recovery_methods) <= 1: flash('You have exhausted your recovery codes. Please generate new ones now!') return redirect(url_for('mfa.setup')) @@ -247,6 +239,6 @@ def auth_finish(): flash('You only have a few recovery codes remaining. Make sure to generate new ones before they run out.') return redirect(url_for('mfa.setup')) return redirect(request.values.get('ref', url_for('index'))) - mfa_ratelimit.log(user.dn) + mfa_ratelimit.log(request.user_pre_mfa.dn) flash('Two-factor authentication failed') return redirect(url_for('mfa.auth', ref=request.values.get('ref'))) diff --git a/uffd/oauth2/views.py b/uffd/oauth2/views.py index b712a82c..12e62feb 100644 --- a/uffd/oauth2/views.py +++ b/uffd/oauth2/views.py @@ -7,7 +7,7 @@ from flask import Blueprint, request, jsonify, render_template, session, redirec from flask_oauthlib.provider import OAuth2Provider from uffd.database import db -from uffd.session.views import get_current_user, login_required +from uffd.session.views import login_required from .models import OAuth2Client, OAuth2Grant, OAuth2Token oauth = OAuth2Provider() @@ -23,7 +23,7 @@ def load_grant(client_id, code): @oauth.grantsetter def save_grant(client_id, code, oauthreq, *args, **kwargs): # pylint: disable=unused-argument expires = datetime.datetime.utcnow() + datetime.timedelta(seconds=100) - grant = OAuth2Grant(user_dn=get_current_user().dn, client_id=client_id, + grant = OAuth2Grant(user_dn=request.user.dn, client_id=client_id, code=code['code'], redirect_uri=oauthreq.redirect_uri, expires=expires, _scopes=' '.join(oauthreq.scopes)) db.session.add(grant) db.session.commit() @@ -89,7 +89,7 @@ def authorize(*args, **kwargs): # pylint: disable=unused-argument session['oauth2-clients'] = session.get('oauth2-clients', []) if client.client_id not in session['oauth2-clients']: session['oauth2-clients'].append(client.client_id) - return client.access_allowed(get_current_user()) + return client.access_allowed(request.user) @bp.route('/token', methods=['GET', 'POST']) @oauth.token_handler diff --git a/uffd/role/views.py b/uffd/role/views.py index d2c3d846..9cf9eb9b 100644 --- a/uffd/role/views.py +++ b/uffd/role/views.py @@ -7,7 +7,7 @@ from uffd.navbar import register_navbar from uffd.csrf import csrf_protect from uffd.role.models import Role from uffd.user.models import User, Group -from uffd.session import get_current_user, login_required, is_valid_session +from uffd.session import login_required from uffd.database import db from uffd.ldap import ldap @@ -44,7 +44,7 @@ def role_acl(): #pylint: disable=inconsistent-return-statements return redirect(url_for('index')) def role_acl_check(): - return is_valid_session() and get_current_user().is_in_group(current_app.config['ACL_ADMIN_GROUP']) + return request.user and request.user.is_in_group(current_app.config['ACL_ADMIN_GROUP']) @bp.route("/") @register_navbar('Roles', icon='key', blueprint=bp, visible=role_acl_check) diff --git a/uffd/rolemod/views.py b/uffd/rolemod/views.py index 174d923b..d9c5f834 100644 --- a/uffd/rolemod/views.py +++ b/uffd/rolemod/views.py @@ -4,14 +4,14 @@ from uffd.navbar import register_navbar from uffd.csrf import csrf_protect from uffd.role.models import Role from uffd.user.models import User -from uffd.session import get_current_user, login_required, is_valid_session +from uffd.session import login_required from uffd.database import db from uffd.ldap import ldap bp = Blueprint('rolemod', __name__, template_folder='templates', url_prefix='/rolemod/') def user_is_rolemod(): - return is_valid_session() and Role.query.filter(Role.moderator_group_dn.in_(get_current_user().group_dns)).count() + return request.user and Role.query.filter(Role.moderator_group_dn.in_(request.user.group_dns)).count() @bp.before_request @login_required() @@ -23,7 +23,7 @@ def acl_check(): #pylint: disable=inconsistent-return-statements @bp.route("/") @register_navbar('Moderation', icon='user-lock', blueprint=bp, visible=user_is_rolemod) def index(): - roles = Role.query.filter(Role.moderator_group_dn.in_(get_current_user().group_dns)).all() + roles = Role.query.filter(Role.moderator_group_dn.in_(request.user.group_dns)).all() return render_template('rolemod/list.html', roles=roles) @bp.route("/<int:role_id>") @@ -31,7 +31,7 @@ def show(role_id): # prefetch all users so the ldap orm can cache them and doesn't run one ldap query per user User.query.all() role = Role.query.get_or_404(role_id) - if role.moderator_group not in get_current_user().groups: + if role.moderator_group not in request.user.groups: flash('Access denied') return redirect(url_for('index')) return render_template('rolemod/show.html', role=role) @@ -40,7 +40,7 @@ def show(role_id): @csrf_protect(blueprint=bp) def update(role_id): role = Role.query.get_or_404(role_id) - if role.moderator_group not in get_current_user().groups: + if role.moderator_group not in request.user.groups: flash('Access denied') return redirect(url_for('index')) if request.form['description'] != role.description: @@ -55,7 +55,7 @@ def update(role_id): @csrf_protect(blueprint=bp) def delete_member(role_id, member_dn): role = Role.query.get_or_404(role_id) - if role.moderator_group not in get_current_user().groups: + if role.moderator_group not in request.user.groups: flash('Access denied') return redirect(url_for('index')) member = User.query.get_or_404(member_dn) diff --git a/uffd/selfservice/views.py b/uffd/selfservice/views.py index 6bc07b25..d469ef69 100644 --- a/uffd/selfservice/views.py +++ b/uffd/selfservice/views.py @@ -9,7 +9,7 @@ from flask import Blueprint, render_template, request, url_for, redirect, flash, from uffd.navbar import register_navbar from uffd.csrf import csrf_protect from uffd.user.models import User -from uffd.session import get_current_user, login_required, is_valid_session +from uffd.session import login_required from uffd.selfservice.models import PasswordToken, MailToken from uffd.database import db from uffd.ldap import ldap @@ -20,17 +20,17 @@ bp = Blueprint("selfservice", __name__, template_folder='templates', url_prefix= reset_ratelimit = Ratelimit('passwordreset', 1*60*60, 3) @bp.route("/") -@register_navbar('Selfservice', icon='portrait', blueprint=bp, visible=is_valid_session) +@register_navbar('Selfservice', icon='portrait', blueprint=bp, visible=lambda: bool(request.user)) @login_required() def index(): - return render_template('selfservice/self.html', user=get_current_user()) + return render_template('selfservice/self.html', user=request.user) @bp.route("/update", methods=(['POST'])) @csrf_protect(blueprint=bp) @login_required() def update(): password_changed = False - user = get_current_user() + user = request.user if request.values['displayname'] != user.displayname: if user.set_displayname(request.values['displayname']): flash('Display name changed.') diff --git a/uffd/services/views.py b/uffd/services/views.py index f55be57f..585de42c 100644 --- a/uffd/services/views.py +++ b/uffd/services/views.py @@ -1,7 +1,6 @@ -from flask import Blueprint, render_template, current_app, abort +from flask import Blueprint, render_template, current_app, abort, request from uffd.navbar import register_navbar -from uffd.session import is_valid_session, get_current_user bp = Blueprint("services", __name__, template_folder='templates', url_prefix='/services') @@ -69,25 +68,19 @@ def get_services(user=None): return services def services_visible(): - user = None - if is_valid_session(): - user = get_current_user() - return len(get_services(user)) > 0 + return len(get_services(request.user)) > 0 @bp.route("/") @register_navbar('Services', icon='sitemap', blueprint=bp, visible=services_visible) def index(): - user = None - if is_valid_session(): - user = get_current_user() - services = get_services(user) + services = get_services(request.user) if not current_app.config['SERVICES']: abort(404) banner = current_app.config.get('SERVICES_BANNER') # Set the banner to None if it is not public and no user is logged in - if not (current_app.config["SERVICES_BANNER_PUBLIC"] or user): + if not (current_app.config["SERVICES_BANNER_PUBLIC"] or request.user): banner = None - return render_template('services/overview.html', user=user, services=services, banner=banner) + return render_template('services/overview.html', user=request.user, services=services, banner=banner) diff --git a/uffd/session/__init__.py b/uffd/session/__init__.py index 5cddcdc6..0e571f3a 100644 --- a/uffd/session/__init__.py +++ b/uffd/session/__init__.py @@ -1,3 +1,3 @@ -from .views import bp as bp_ui, get_current_user, login_required, is_valid_session, set_session +from .views import bp as bp_ui, login_required, set_session bp = [bp_ui] diff --git a/uffd/session/views.py b/uffd/session/views.py index 1b76519d..2c57bc6d 100644 --- a/uffd/session/views.py +++ b/uffd/session/views.py @@ -12,6 +12,21 @@ bp = Blueprint("session", __name__, template_folder='templates', url_prefix='/') login_ratelimit = Ratelimit('login', 1*60, 3) +@bp.before_app_request +def set_request_user(): + request.user = None + request.user_pre_mfa = None + if 'user_dn' not in session: + return + if 'logintime' not in session: + return + if datetime.datetime.now().timestamp() > session['logintime'] + current_app.config['SESSION_LIFETIME_SECONDS']: + return + user = User.query.get(session['user_dn']) + request.user_pre_mfa = user + if session.get('user_mfa'): + request.user = user + def login_get_user(loginname, password): dn = User(loginname=loginname).dn @@ -84,34 +99,11 @@ def login(): set_session(user, password=password) return redirect(url_for('mfa.auth', ref=request.values.get('ref', url_for('index')))) -def get_current_user(): - if 'user_dn' not in session: - return None - return User.query.get(session['user_dn']) -bp.add_app_template_global(get_current_user) - -def login_valid(): - user = get_current_user() - if user is None: - return False - if datetime.datetime.now().timestamp() > session['logintime'] + current_app.config['SESSION_LIFETIME_SECONDS']: - return False - return True - -def is_valid_session(): - if not login_valid(): - return False - if not session.get('user_mfa'): - return False - return True -bp.add_app_template_global(is_valid_session) - -def pre_mfa_login_required(no_redirect=False): +def login_required_pre_mfa(no_redirect=False): def wrapper(func): @functools.wraps(func) def decorator(*args, **kwargs): - if not login_valid() or datetime.datetime.now().timestamp() > session['logintime'] + 10*60: - session.clear() + if not request.user_pre_mfa: if no_redirect: abort(403) flash('You need to login first') @@ -124,12 +116,12 @@ def login_required(group=None): def wrapper(func): @functools.wraps(func) def decorator(*args, **kwargs): - if not login_valid(): + if not request.user_pre_mfa: flash('You need to login first') return redirect(url_for('session.login', ref=request.url)) - if not session.get('user_mfa'): + if not request.user: return redirect(url_for('mfa.auth', ref=request.url)) - if not get_current_user().is_in_group(group): + if not request.user.is_in_group(group): flash('Access denied') return redirect(url_for('index')) return func(*args, **kwargs) diff --git a/uffd/templates/base.html b/uffd/templates/base.html index 5ab3cba8..ec80472d 100644 --- a/uffd/templates/base.html +++ b/uffd/templates/base.html @@ -67,7 +67,7 @@ </li> {% endfor %} </ul> - {% if is_valid_session() %} + {% if request.user %} <ul class="navbar-nav ml-auto"> <li class="nav-item"> <a class="nav-link" href="{{ url_for("session.logout") }}"> diff --git a/uffd/user/views_group.py b/uffd/user/views_group.py index dca984e7..d4318b3c 100644 --- a/uffd/user/views_group.py +++ b/uffd/user/views_group.py @@ -1,7 +1,7 @@ -from flask import Blueprint, render_template, url_for, redirect, flash, current_app +from flask import Blueprint, render_template, url_for, redirect, flash, current_app, request from uffd.navbar import register_navbar -from uffd.session import login_required, is_valid_session, get_current_user +from uffd.session import login_required from .models import Group @@ -14,7 +14,7 @@ def group_acl(): #pylint: disable=inconsistent-return-statements return redirect(url_for('index')) def group_acl_check(): - return is_valid_session() and get_current_user().is_in_group(current_app.config['ACL_ADMIN_GROUP']) + return request.user and request.user.is_in_group(current_app.config['ACL_ADMIN_GROUP']) @bp.route("/") @register_navbar('Groups', icon='layer-group', blueprint=bp, visible=group_acl_check) diff --git a/uffd/user/views_user.py b/uffd/user/views_user.py index aadaf5af..a76beceb 100644 --- a/uffd/user/views_user.py +++ b/uffd/user/views_user.py @@ -6,7 +6,7 @@ from flask import Blueprint, render_template, request, url_for, redirect, flash, from uffd.navbar import register_navbar from uffd.csrf import csrf_protect from uffd.selfservice import send_passwordreset -from uffd.session import login_required, is_valid_session, get_current_user +from uffd.session import login_required from uffd.role.models import Role from uffd.database import db from uffd.ldap import ldap, LDAPCommitError @@ -22,7 +22,7 @@ def user_acl(): #pylint: disable=inconsistent-return-statements return redirect(url_for('index')) def user_acl_check(): - return is_valid_session() and get_current_user().is_in_group(current_app.config['ACL_ADMIN_GROUP']) + return request.user and request.user.is_in_group(current_app.config['ACL_ADMIN_GROUP']) @bp.route("/") @register_navbar('Users', icon='users', blueprint=bp, visible=user_acl_check) -- GitLab