From 0d48e6cbc52acf70ff811cdab0edce060d0ef74d Mon Sep 17 00:00:00 2001
From: Julian Rother <julianr@fsmpi.rwth-aachen.de>
Date: Sun, 13 Jun 2021 22:51:13 +0200
Subject: [PATCH] Replace get_current_user/is_valid_session with request.user

---
 tests/test_invite.py                   |  2 +-
 tests/test_mfa.py                      | 79 ++++++++++++-------------
 tests/test_oauth2.py                   |  1 -
 tests/test_rolemod.py                  |  1 -
 tests/test_selfservice.py              | 39 ++++++------
 tests/test_session.py                  | 18 +++---
 tests/test_signup.py                   | 16 ++---
 tests/test_user.py                     |  1 -
 uffd/invite/templates/invite/list.html |  6 +-
 uffd/invite/templates/invite/use.html  |  4 +-
 uffd/invite/views.py                   | 30 ++++------
 uffd/mail/views.py                     |  4 +-
 uffd/mfa/views.py                      | 82 ++++++++++++--------------
 uffd/oauth2/views.py                   |  6 +-
 uffd/role/views.py                     |  4 +-
 uffd/rolemod/views.py                  | 12 ++--
 uffd/selfservice/views.py              |  8 +--
 uffd/services/views.py                 | 17 ++----
 uffd/session/__init__.py               |  2 +-
 uffd/session/views.py                  | 48 +++++++--------
 uffd/templates/base.html               |  2 +-
 uffd/user/views_group.py               |  6 +-
 uffd/user/views_user.py                |  4 +-
 23 files changed, 180 insertions(+), 212 deletions(-)

diff --git a/tests/test_invite.py b/tests/test_invite.py
index d3691d05..c21a15c1 100644
--- a/tests/test_invite.py
+++ b/tests/test_invite.py
@@ -12,7 +12,7 @@ from uffd import create_app, db
 from uffd.invite.models import Invite, InviteGrant, InviteSignup
 from uffd.user.models import User, Group
 from uffd.role.models import Role
-from uffd.session.views import get_current_user, is_valid_session, login_get_user
+from uffd.session.views import login_get_user
 
 from utils import dump, UffdTestCase, db_flush
 
diff --git a/tests/test_mfa.py b/tests/test_mfa.py
index 4ea98102..dc0c657c 100644
--- a/tests/test_mfa.py
+++ b/tests/test_mfa.py
@@ -2,13 +2,12 @@ import unittest
 import datetime
 import time
 
-from flask import url_for, session
+from flask import url_for, session, request
 
 # These imports are required, because otherwise we get circular imports?!
 from uffd import ldap, user
 
 from uffd.user.models import User
-from uffd.session.views import get_current_user, is_valid_session
 from uffd.mfa.models import MFAMethod, MFAType, RecoveryCodeMethod, TOTPMethod, WebauthnMethod, _hotp
 from uffd import create_app, db
 
@@ -176,21 +175,21 @@ class TestMfaViews(UffdTestCase):
 		r = self.client.post(path=url_for('mfa.disable_confirm'), follow_redirects=True)
 		dump('mfa_disable_submit', r)
 		self.assertEqual(r.status_code, 200)
-		self.assertEqual(len(MFAMethod.query.filter_by(dn=get_current_user().dn).all()), 0)
+		self.assertEqual(len(MFAMethod.query.filter_by(dn=request.user.dn).all()), 0)
 		self.assertEqual(len(MFAMethod.query.filter_by(dn=get_admin().dn).all()), admin_methods)
 
 	def test_disable_recovery_only(self):
 		self.login()
 		self.add_recovery_codes()
 		admin_methods = len(MFAMethod.query.filter_by(dn=get_admin().dn).all())
-		self.assertNotEqual(len(MFAMethod.query.filter_by(dn=get_current_user().dn).all()), 0)
+		self.assertNotEqual(len(MFAMethod.query.filter_by(dn=request.user.dn).all()), 0)
 		r = self.client.get(path=url_for('mfa.disable'), follow_redirects=True)
 		dump('mfa_disable_recovery_only', r)
 		self.assertEqual(r.status_code, 200)
 		r = self.client.post(path=url_for('mfa.disable_confirm'), follow_redirects=True)
 		dump('mfa_disable_recovery_only_submit', r)
 		self.assertEqual(r.status_code, 200)
-		self.assertEqual(len(MFAMethod.query.filter_by(dn=get_current_user().dn).all()), 0)
+		self.assertEqual(len(MFAMethod.query.filter_by(dn=request.user.dn).all()), 0)
 		self.assertEqual(len(MFAMethod.query.filter_by(dn=get_admin().dn).all()), admin_methods)
 
 	def test_admin_disable(self):
@@ -202,7 +201,7 @@ class TestMfaViews(UffdTestCase):
 		self.add_totp()
 		self.client.post(path=url_for('session.login'),
 			data={'loginname': 'testadmin', 'password': 'adminpassword'}, follow_redirects=True)
-		self.assertTrue(is_valid_session())
+		self.assertIsNotNone(request.user)
 		admin_methods = len(MFAMethod.query.filter_by(dn=get_admin().dn).all())
 		r = self.client.get(path=url_for('mfa.admin_disable', uid=get_user().uid), follow_redirects=True)
 		dump('mfa_admin_disable', r)
@@ -212,11 +211,11 @@ class TestMfaViews(UffdTestCase):
 
 	def test_setup_recovery(self):
 		self.login()
-		self.assertEqual(len(RecoveryCodeMethod.query.filter_by(dn=get_current_user().dn).all()), 0)
+		self.assertEqual(len(RecoveryCodeMethod.query.filter_by(dn=request.user.dn).all()), 0)
 		r = self.client.post(path=url_for('mfa.setup_recovery'), follow_redirects=True)
 		dump('mfa_setup_recovery', r)
 		self.assertEqual(r.status_code, 200)
-		methods = RecoveryCodeMethod.query.filter_by(dn=get_current_user().dn).all()
+		methods = RecoveryCodeMethod.query.filter_by(dn=request.user.dn).all()
 		self.assertNotEqual(len(methods), 0)
 		r = self.client.post(path=url_for('mfa.setup_recovery'), follow_redirects=True)
 		dump('mfa_setup_recovery_reset', r)
@@ -241,62 +240,62 @@ class TestMfaViews(UffdTestCase):
 	def test_setup_totp_finish(self):
 		self.login()
 		self.add_recovery_codes()
-		self.assertEqual(len(TOTPMethod.query.filter_by(dn=get_current_user().dn).all()), 0)
+		self.assertEqual(len(TOTPMethod.query.filter_by(dn=request.user.dn).all()), 0)
 		r = self.client.get(path=url_for('mfa.setup_totp', name='My TOTP Authenticator'), follow_redirects=True)
-		method = TOTPMethod(get_current_user(), key=session.get('mfa_totp_key', ''))
+		method = TOTPMethod(request.user, key=session.get('mfa_totp_key', ''))
 		code = _hotp(int(time.time()/30), method.raw_key)
 		r = self.client.post(path=url_for('mfa.setup_totp_finish', name='My TOTP Authenticator'), data={'code': code}, follow_redirects=True)
 		dump('mfa_setup_totp_finish', r)
 		self.assertEqual(r.status_code, 200)
-		self.assertEqual(len(TOTPMethod.query.filter_by(dn=get_current_user().dn).all()), 1)
+		self.assertEqual(len(TOTPMethod.query.filter_by(dn=request.user.dn).all()), 1)
 
 	def test_setup_totp_finish_without_recovery(self):
 		self.login()
-		self.assertEqual(len(TOTPMethod.query.filter_by(dn=get_current_user().dn).all()), 0)
+		self.assertEqual(len(TOTPMethod.query.filter_by(dn=request.user.dn).all()), 0)
 		r = self.client.get(path=url_for('mfa.setup_totp', name='My TOTP Authenticator'), follow_redirects=True)
-		method = TOTPMethod(get_current_user(), key=session.get('mfa_totp_key', ''))
+		method = TOTPMethod(request.user, key=session.get('mfa_totp_key', ''))
 		code = _hotp(int(time.time()/30), method.raw_key)
 		r = self.client.post(path=url_for('mfa.setup_totp_finish', name='My TOTP Authenticator'), data={'code': code}, follow_redirects=True)
 		dump('mfa_setup_totp_finish_without_recovery', r)
 		self.assertEqual(r.status_code, 200)
-		self.assertEqual(len(TOTPMethod.query.filter_by(dn=get_current_user().dn).all()), 0)
+		self.assertEqual(len(TOTPMethod.query.filter_by(dn=request.user.dn).all()), 0)
 
 	def test_setup_totp_finish_wrong_code(self):
 		self.login()
 		self.add_recovery_codes()
-		self.assertEqual(len(TOTPMethod.query.filter_by(dn=get_current_user().dn).all()), 0)
+		self.assertEqual(len(TOTPMethod.query.filter_by(dn=request.user.dn).all()), 0)
 		r = self.client.get(path=url_for('mfa.setup_totp', name='My TOTP Authenticator'), follow_redirects=True)
-		method = TOTPMethod(get_current_user(), key=session.get('mfa_totp_key', ''))
+		method = TOTPMethod(request.user, key=session.get('mfa_totp_key', ''))
 		code = _hotp(int(time.time()/30), method.raw_key)
 		code = str(int(code[0])+1)[-1] + code[1:]
 		r = self.client.post(path=url_for('mfa.setup_totp_finish', name='My TOTP Authenticator'), data={'code': code}, follow_redirects=True)
 		dump('mfa_setup_totp_finish_wrong_code', r)
 		self.assertEqual(r.status_code, 200)
-		self.assertEqual(len(TOTPMethod.query.filter_by(dn=get_current_user().dn).all()), 0)
+		self.assertEqual(len(TOTPMethod.query.filter_by(dn=request.user.dn).all()), 0)
 
 	def test_setup_totp_finish_empty_code(self):
 		self.login()
 		self.add_recovery_codes()
-		self.assertEqual(len(TOTPMethod.query.filter_by(dn=get_current_user().dn).all()), 0)
+		self.assertEqual(len(TOTPMethod.query.filter_by(dn=request.user.dn).all()), 0)
 		r = self.client.get(path=url_for('mfa.setup_totp', name='My TOTP Authenticator'), follow_redirects=True)
 		r = self.client.post(path=url_for('mfa.setup_totp_finish', name='My TOTP Authenticator'), data={'code': ''}, follow_redirects=True)
 		dump('mfa_setup_totp_finish_empty_code', r)
 		self.assertEqual(r.status_code, 200)
-		self.assertEqual(len(TOTPMethod.query.filter_by(dn=get_current_user().dn).all()), 0)
+		self.assertEqual(len(TOTPMethod.query.filter_by(dn=request.user.dn).all()), 0)
 
 	def test_delete_totp(self):
 		self.login()
 		self.add_recovery_codes()
 		self.add_totp()
-		method = TOTPMethod(get_current_user(), name='test')
+		method = TOTPMethod(request.user, name='test')
 		db.session.add(method)
 		db.session.commit()
-		self.assertEqual(len(TOTPMethod.query.filter_by(dn=get_current_user().dn).all()), 2)
+		self.assertEqual(len(TOTPMethod.query.filter_by(dn=request.user.dn).all()), 2)
 		r = self.client.get(path=url_for('mfa.delete_totp', id=method.id), follow_redirects=True)
 		dump('mfa_delete_totp', r)
 		self.assertEqual(r.status_code, 200)
 		self.assertEqual(len(TOTPMethod.query.filter_by(id=method.id).all()), 0)
-		self.assertEqual(len(TOTPMethod.query.filter_by(dn=get_current_user().dn).all()), 1)
+		self.assertEqual(len(TOTPMethod.query.filter_by(dn=request.user.dn).all()), 1)
 
 	# TODO: webauthn setup tests
 
@@ -304,36 +303,36 @@ class TestMfaViews(UffdTestCase):
 		self.add_recovery_codes()
 		self.add_totp()
 		db.session.commit()
-		self.assertFalse(is_valid_session())
+		self.assertIsNone(request.user)
 		r = self.client.post(path=url_for('session.login'),
 			data={'loginname': 'testuser', 'password': 'userpassword'}, follow_redirects=True)
 		dump('mfa_auth_redirected', r)
 		self.assertEqual(r.status_code, 200)
 		self.assertIn(b'/mfa/auth', r.data)
-		self.assertFalse(is_valid_session())
+		self.assertIsNone(request.user)
 		r = self.client.get(path=url_for('mfa.auth'), follow_redirects=False)
 		dump('mfa_auth', r)
 		self.assertEqual(r.status_code, 200)
-		self.assertFalse(is_valid_session())
+		self.assertIsNone(request.user)
 
 	def test_auth_disabled(self):
-		self.assertFalse(is_valid_session())
+		self.assertIsNone(request.user)
 		r = self.client.post(path=url_for('session.login'),
 			data={'loginname': 'testuser', 'password': 'userpassword'}, follow_redirects=False)
 		r = self.client.get(path=url_for('mfa.auth', ref='/redirecttarget'), follow_redirects=False)
 		self.assertEqual(r.status_code, 302)
 		self.assertTrue(r.location.endswith('/redirecttarget'))
-		self.assertTrue(is_valid_session())
+		self.assertIsNotNone(request.user)
 
 	def test_auth_recovery_only(self):
 		self.add_recovery_codes()
-		self.assertFalse(is_valid_session())
+		self.assertIsNone(request.user)
 		r = self.client.post(path=url_for('session.login'),
 			data={'loginname': 'testuser', 'password': 'userpassword'}, follow_redirects=False)
 		r = self.client.get(path=url_for('mfa.auth', ref='/redirecttarget'), follow_redirects=False)
 		self.assertEqual(r.status_code, 302)
 		self.assertTrue(r.location.endswith('/redirecttarget'))
-		self.assertTrue(is_valid_session())
+		self.assertIsNotNone(request.user)
 
 	def test_auth_recovery_code(self):
 		self.add_recovery_codes()
@@ -346,11 +345,11 @@ class TestMfaViews(UffdTestCase):
 		r = self.client.get(path=url_for('mfa.auth'), follow_redirects=False)
 		dump('mfa_auth_recovery_code', r)
 		self.assertEqual(r.status_code, 200)
-		self.assertFalse(is_valid_session())
+		self.assertIsNone(request.user)
 		r = self.client.post(path=url_for('mfa.auth_finish', ref='/redirecttarget'), data={'code': method.code})
 		self.assertEqual(r.status_code, 302)
 		self.assertTrue(r.location.endswith('/redirecttarget'))
-		self.assertTrue(is_valid_session())
+		self.assertIsNotNone(request.user)
 		self.assertEqual(len(RecoveryCodeMethod.query.filter_by(id=method_id).all()), 0)
 
 	def test_auth_totp_code(self):
@@ -364,12 +363,12 @@ class TestMfaViews(UffdTestCase):
 		r = self.client.get(path=url_for('mfa.auth'), follow_redirects=False)
 		dump('mfa_auth_totp_code', r)
 		self.assertEqual(r.status_code, 200)
-		self.assertFalse(is_valid_session())
+		self.assertIsNone(request.user)
 		code = _hotp(int(time.time()/30), raw_key)
 		r = self.client.post(path=url_for('mfa.auth_finish'), data={'code': code}, follow_redirects=True)
 		dump('mfa_auth_totp_code_submit', r)
 		self.assertEqual(r.status_code, 200)
-		self.assertTrue(is_valid_session())
+		self.assertIsNotNone(request.user)
 
 	def test_auth_empty_code(self):
 		self.add_recovery_codes()
@@ -377,11 +376,11 @@ class TestMfaViews(UffdTestCase):
 		self.login()
 		r = self.client.get(path=url_for('mfa.auth'), follow_redirects=False)
 		self.assertEqual(r.status_code, 200)
-		self.assertFalse(is_valid_session())
+		self.assertIsNone(request.user)
 		r = self.client.post(path=url_for('mfa.auth_finish'), data={'code': ''}, follow_redirects=True)
 		dump('mfa_auth_empty_code', r)
 		self.assertEqual(r.status_code, 200)
-		self.assertFalse(is_valid_session())
+		self.assertIsNone(request.user)
 
 	def test_auth_invalid_code(self):
 		self.add_recovery_codes()
@@ -393,13 +392,13 @@ class TestMfaViews(UffdTestCase):
 		self.login()
 		r = self.client.get(path=url_for('mfa.auth'), follow_redirects=False)
 		self.assertEqual(r.status_code, 200)
-		self.assertFalse(is_valid_session())
+		self.assertIsNone(request.user)
 		code = _hotp(int(time.time()/30), raw_key)
 		code = str(int(code[0])+1)[-1] + code[1:]
 		r = self.client.post(path=url_for('mfa.auth_finish'), data={'code': code}, follow_redirects=True)
 		dump('mfa_auth_invalid_code', r)
 		self.assertEqual(r.status_code, 200)
-		self.assertFalse(is_valid_session())
+		self.assertIsNone(request.user)
 
 	def test_auth_ratelimit(self):
 		self.add_recovery_codes()
@@ -409,17 +408,17 @@ class TestMfaViews(UffdTestCase):
 		db.session.add(method)
 		db.session.commit()
 		self.login()
-		self.assertFalse(is_valid_session())
+		self.assertIsNone(request.user)
 		code = _hotp(int(time.time()/30), raw_key)
 		inv_code = str(int(code[0])+1)[-1] + code[1:]
 		for i in range(20):
 			r = self.client.post(path=url_for('mfa.auth_finish'), data={'code': inv_code}, follow_redirects=True)
 			self.assertEqual(r.status_code, 200)
-			self.assertFalse(is_valid_session())
+			self.assertIsNone(request.user)
 		r = self.client.post(path=url_for('mfa.auth_finish'), data={'code': code}, follow_redirects=True)
 		dump('mfa_auth_ratelimit', r)
 		self.assertEqual(r.status_code, 200)
-		self.assertFalse(is_valid_session())
+		self.assertIsNone(request.user)
 
 	# TODO: webauthn auth tests
 
diff --git a/tests/test_oauth2.py b/tests/test_oauth2.py
index 1dc7e1c5..b77a0992 100644
--- a/tests/test_oauth2.py
+++ b/tests/test_oauth2.py
@@ -6,7 +6,6 @@ from flask import url_for
 # These imports are required, because otherwise we get circular imports?!
 from uffd import ldap, user
 
-from uffd.session.views import get_current_user
 from uffd.user.models import User
 from uffd.oauth2.models import OAuth2Client
 from uffd import create_app, db, ldap
diff --git a/tests/test_rolemod.py b/tests/test_rolemod.py
index f8598ab6..f1baba1a 100644
--- a/tests/test_rolemod.py
+++ b/tests/test_rolemod.py
@@ -1,7 +1,6 @@
 from flask import url_for
 
 from uffd.user.models import User, Group
-from uffd.session import get_current_user
 from uffd.role.models import Role
 from uffd.database import db
 from uffd.ldap import ldap
diff --git a/tests/test_selfservice.py b/tests/test_selfservice.py
index 1d6ab94b..6d5d9698 100644
--- a/tests/test_selfservice.py
+++ b/tests/test_selfservice.py
@@ -1,12 +1,11 @@
 import datetime
 import unittest
 
-from flask import url_for
+from flask import url_for, request
 
 # These imports are required, because otherwise we get circular imports?!
 from uffd import ldap, user
 
-from uffd.session.views import get_current_user
 from uffd.selfservice.models import MailToken, PasswordToken
 from uffd.user.models import User
 from uffd import create_app, db
@@ -28,88 +27,88 @@ class TestSelfservice(UffdTestCase):
 		r = self.client.get(path=url_for('selfservice.index'))
 		dump('selfservice_index', r)
 		self.assertEqual(r.status_code, 200)
-		user = get_current_user()
+		user = request.user
 		self.assertIn(user.displayname.encode(), r.data)
 		self.assertIn(user.loginname.encode(), r.data)
 		self.assertIn(user.mail.encode(), r.data)
 
 	def test_update_displayname(self):
 		self.login()
-		user = get_current_user()
+		user = request.user
 		r = self.client.post(path=url_for('selfservice.update'),
 			data={'displayname': 'New Display Name', 'mail': user.mail, 'password': '', 'password1': ''},
 			follow_redirects=True)
 		dump('update_displayname', r)
 		self.assertEqual(r.status_code, 200)
-		_user = get_current_user()
+		_user = request.user
 		self.assertEqual(_user.displayname, 'New Display Name')
 
 	def test_update_displayname_invalid(self):
 		self.login()
-		user = get_current_user()
+		user = request.user
 		r = self.client.post(path=url_for('selfservice.update'),
 			data={'displayname': '', 'mail': user.mail, 'password': '', 'password1': ''},
 			follow_redirects=True)
 		dump('update_displayname_invalid', r)
 		self.assertEqual(r.status_code, 200)
-		_user = get_current_user()
+		_user = request.user
 		self.assertNotEqual(_user.displayname, '')
 
 	def test_update_mail(self):
 		self.login()
-		user = get_current_user()
+		user = request.user
 		r = self.client.post(path=url_for('selfservice.update'),
 			data={'displayname': user.displayname, 'mail': 'newemail@example.com', 'password': '', 'password1': ''},
 			follow_redirects=True)
 		dump('update_mail', r)
 		self.assertEqual(r.status_code, 200)
-		_user = get_current_user()
+		_user = request.user
 		self.assertNotEqual(_user.mail, 'newemail@example.com')
 		token = MailToken.query.filter(MailToken.loginname == user.loginname).first()
 		self.assertEqual(token.newmail, 'newemail@example.com')
 		self.assertIn(token.token, str(self.app.last_mail.get_content()))
 		r = self.client.get(path=url_for('selfservice.token_mail', token=token.token), follow_redirects=True)
 		self.assertEqual(r.status_code, 200)
-		_user = get_current_user()
+		_user = request.user
 		self.assertEqual(_user.mail, 'newemail@example.com')
 
 	def test_update_mail_sendfailure(self):
 		self.app.config['MAIL_SKIP_SEND'] = 'fail'
 		self.login()
-		user = get_current_user()
+		user = request.user
 		r = self.client.post(path=url_for('selfservice.update'),
 			data={'displayname': user.displayname, 'mail': 'newemail@example.com', 'password': '', 'password1': ''},
 			follow_redirects=True)
 		dump('update_mail_sendfailure', r)
 		self.assertEqual(r.status_code, 200)
-		_user = get_current_user()
+		_user = request.user
 		self.assertNotEqual(_user.mail, 'newemail@example.com')
 		# Maybe also check that there is no new token in the db
 
 	def test_token_mail_emptydb(self):
 		self.login()
-		user = get_current_user()
+		user = request.user
 		r = self.client.get(path=url_for('selfservice.token_mail', token='A'*128), follow_redirects=True)
 		dump('token_mail_emptydb', r)
 		self.assertEqual(r.status_code, 200)
-		_user = get_current_user()
+		_user = request.user
 		self.assertEqual(_user.mail, user.mail)
 
 	def test_token_mail_invalid(self):
 		self.login()
-		user = get_current_user()
+		user = request.user
 		db.session.add(MailToken(loginname=user.loginname, newmail='newusermail@example.com'))
 		db.session.commit()
 		r = self.client.get(path=url_for('selfservice.token_mail', token='A'*128), follow_redirects=True)
 		dump('token_mail_invalid', r)
 		self.assertEqual(r.status_code, 200)
-		_user = get_current_user()
+		_user = request.user
 		self.assertEqual(_user.mail, user.mail)
 
 	@unittest.skip('See #26')
 	def test_token_mail_wrong_user(self):
 		self.login()
-		user = get_current_user()
+		user = request.user
 		admin_user = User.query.get('uid=testadmin,ou=users,dc=example,dc=com')
 		db.session.add(MailToken(loginname=user.loginname, newmail='newusermail@example.com'))
 		admin_token = MailToken(loginname='testadmin', newmail='newadminmail@example.com')
@@ -118,14 +117,14 @@ class TestSelfservice(UffdTestCase):
 		r = self.client.get(path=url_for('selfservice.token_mail', token=admin_token.token), follow_redirects=True)
 		dump('token_mail_wrong_user', r)
 		self.assertEqual(r.status_code, 200)
-		_user = get_current_user()
+		_user = request.user
 		_admin_user = User.query.get('uid=testadmin,ou=users,dc=example,dc=com')
 		self.assertEqual(_user.mail, user.mail)
 		self.assertEqual(_admin_user.mail, admin_user.mail)
 
 	def test_token_mail_expired(self):
 		self.login()
-		user = get_current_user()
+		user = request.user
 		token = MailToken(loginname=user.loginname, newmail='newusermail@example.com',
 			created=(datetime.datetime.now() - datetime.timedelta(days=10)))
 		db.session.add(token)
@@ -133,7 +132,7 @@ class TestSelfservice(UffdTestCase):
 		r = self.client.get(path=url_for('selfservice.token_mail', token=token.token), follow_redirects=True)
 		dump('token_mail_expired', r)
 		self.assertEqual(r.status_code, 200)
-		_user = get_current_user()
+		_user = request.user
 		self.assertEqual(_user.mail, user.mail)
 		tokens = MailToken.query.filter(MailToken.loginname == user.loginname).all()
 		self.assertEqual(len(tokens), 0)
diff --git a/tests/test_session.py b/tests/test_session.py
index dae41ab5..0882b08e 100644
--- a/tests/test_session.py
+++ b/tests/test_session.py
@@ -1,12 +1,12 @@
 import time
 import unittest
 
-from flask import url_for
+from flask import url_for, request
 
 # These imports are required, because otherwise we get circular imports?!
 from uffd import ldap, user
 
-from uffd.session.views import get_current_user, login_required, is_valid_session
+from uffd.session.views import login_required
 from uffd import create_app, db
 
 from utils import dump, UffdTestCase
@@ -32,24 +32,24 @@ class TestSession(UffdTestCase):
 
 	def setUp(self):
 		super().setUp()
-		self.assertFalse(is_valid_session())
+		self.assertIsNone(request.user)
 
 	def login(self):
 		self.client.post(path=url_for('session.login'),
 			data={'loginname': 'testuser', 'password': 'userpassword'}, follow_redirects=True)
-		self.assertTrue(is_valid_session())
+		self.assertIsNotNone(request.user)
 
 	def assertLogin(self):
-		self.assertTrue(is_valid_session())
+		self.assertIsNotNone(request.user)
 		self.assertEqual(self.client.get(path=url_for('test_login_required'),
 			follow_redirects=True).data, b'SUCCESS')
-		self.assertEqual(get_current_user().loginname, 'testuser')
+		self.assertEqual(request.user.loginname, 'testuser')
 
 	def assertLogout(self):
-		self.assertFalse(is_valid_session())
+		self.assertIsNone(request.user)
 		self.assertNotEqual(self.client.get(path=url_for('test_login_required'),
 			follow_redirects=True).data, b'SUCCESS')
-		self.assertEqual(get_current_user(), None)
+		self.assertEqual(request.user, None)
 
 	def test_login(self):
 		self.assertLogout()
@@ -131,7 +131,7 @@ class TestSession(UffdTestCase):
 			data={'loginname': 'testuser', 'password': 'userpassword'}, follow_redirects=True)
 		dump('login_ratelimit', r)
 		self.assertEqual(r.status_code, 200)
-		self.assertFalse(is_valid_session())
+		self.assertIsNone(request.user)
 
 class TestSessionOL(TestSession):
 	use_openldap = True
diff --git a/tests/test_signup.py b/tests/test_signup.py
index 501e8896..65007789 100644
--- a/tests/test_signup.py
+++ b/tests/test_signup.py
@@ -2,7 +2,7 @@ import unittest
 import datetime
 import time
 
-from flask import url_for, session
+from flask import url_for, session, request
 
 # These imports are required, because otherwise we get circular imports?!
 from uffd import user
@@ -11,7 +11,7 @@ from uffd.ldap import ldap
 from uffd import create_app, db
 from uffd.signup.models import Signup
 from uffd.user.models import User
-from uffd.session.views import get_current_user, is_valid_session, login_get_user
+from uffd.session.views import login_get_user
 
 from utils import dump, UffdTestCase, db_flush
 
@@ -345,8 +345,8 @@ class TestSignupViews(UffdTestCase):
 		self.assertEqual(signup.user.mail, 'test@example.com')
 		if self.use_openldap:
 			self.assertIsNotNone(login_get_user('newuser', 'notsecret'))
-		self.assertTrue(is_valid_session())
-		self.assertEqual(get_current_user().loginname, 'newuser')
+		self.assertIsNotNone(request.user)
+		self.assertEqual(request.user.loginname, 'newuser')
 
 	def test_confirm_loggedin(self):
 		signup = Signup(loginname='newuser', displayname='New User', mail='test@example.com', password='notsecret')
@@ -354,16 +354,16 @@ class TestSignupViews(UffdTestCase):
 		self.client.post(path=url_for('session.login'),
 			data={'loginname': 'testuser', 'password': 'userpassword'}, follow_redirects=True)
 		self.assertFalse(signup.completed)
-		self.assertTrue(is_valid_session())
-		self.assertEqual(get_current_user().loginname, 'testuser')
+		self.assertIsNotNone(request.user)
+		self.assertEqual(request.user.loginname, 'testuser')
 		r = self.client.get(path=url_for('signup.signup_confirm', token=signup.token), follow_redirects=True)
 		self.assertEqual(r.status_code, 200)
 		r = self.client.post(path=url_for('signup.signup_confirm_submit', token=signup.token), follow_redirects=True, data={'password': 'notsecret'})
 		self.assertEqual(r.status_code, 200)
 		signup = refetch_signup(signup)
 		self.assertTrue(signup.completed)
-		self.assertTrue(is_valid_session())
-		self.assertEqual(get_current_user().loginname, 'newuser')
+		self.assertIsNotNone(request.user)
+		self.assertEqual(request.user.loginname, 'newuser')
 
 	def test_confirm_notfound(self):
 		r = self.client.get(path=url_for('signup.signup_confirm', token='notasignuptoken'), follow_redirects=True)
diff --git a/tests/test_user.py b/tests/test_user.py
index baedf5bc..f9fb574b 100644
--- a/tests/test_user.py
+++ b/tests/test_user.py
@@ -6,7 +6,6 @@ from flask import url_for, session
 # These imports are required, because otherwise we get circular imports?!
 from uffd import ldap, user
 
-from uffd.session.views import get_current_user
 from uffd.user.models import User
 from uffd.role.models import Role
 from uffd import create_app, db
diff --git a/uffd/invite/templates/invite/list.html b/uffd/invite/templates/invite/list.html
index a371f9e4..00fd0ee4 100644
--- a/uffd/invite/templates/invite/list.html
+++ b/uffd/invite/templates/invite/list.html
@@ -21,7 +21,7 @@
 			{% for invite in invites|sort(attribute='created', reverse=True)|sort(attribute='active', reverse=True) %}
 			<tr>
 				<td>
-					{% if invite.creator == get_current_user() and invite.active %}
+					{% if invite.creator == request.user and invite.active %}
 					<a href="{{ url_for('invite.use', token=invite.token) }}"><code>{{ invite.short_token }}</code></a>
 					<button type="button" class="btn btn-link btn-sm p-0 copy-clipboard" data-copy="{{ url_for('invite.use', token=invite.token, _external=True) }}" title="Copy link to clipboard"><i class="fas fa-clipboard"></i></button>
 					<button type="button" class="btn btn-link btn-sm p-0" data-toggle="modal" data-target="#modal-{{ invite.id }}-qrcode" title="Show link as QR code"><i class="fas fa-qrcode"></i></button>
@@ -121,7 +121,7 @@
 				<form action="{{ url_for('invite.disable', invite_id=invite.id) }}" method="POST">
 				<button type="submit" class="btn btn-primary">Disable Link</button>
 				</form>
-				{% elif invite.creator == get_current_user() and not invite.expired and invite.permitted %}
+				{% elif invite.creator == request.user and not invite.expired and invite.permitted %}
 				<form action="{{ url_for('invite.reset', invite_id=invite.id) }}" method="POST">
 				<button type="submit" class="btn btn-primary">Reenable Link</button>
 				</form>
@@ -132,7 +132,7 @@
 </div>
 {% endfor %}
 
-{% for invite in invites if invite.creator == get_current_user() %}
+{% for invite in invites if invite.creator == request.user %}
 <div class="modal" tabindex="-1" id="modal-{{ invite.id }}-qrcode">
 	<div class="modal-dialog">
 		<div class="modal-content">
diff --git a/uffd/invite/templates/invite/use.html b/uffd/invite/templates/invite/use.html
index a51db885..4d5b53ad 100644
--- a/uffd/invite/templates/invite/use.html
+++ b/uffd/invite/templates/invite/use.html
@@ -10,7 +10,7 @@
 		<div class="col-12 mb-3">
 			<h2 class="text-center">Invite Link</h2>
 		</div>
-		{% if not is_valid_session() %}
+		{% if not request.user %}
 		<p>Welcome to the CCCV Single-Sign-On!</p>
 		{% endif %}
 
@@ -28,7 +28,7 @@
 			{% endfor %}
 		</ul>
 		{% endif %}
-		{% if is_valid_session() %}
+		{% if request.user %}
 			{% if invite.roles %}
 				<form method="POST" action="{{ url_for("invite.grant", token=invite.token) }}" class="mb-2">
 					<button type="submit" class="btn btn-primary btn-block">Add the roles to your account now</button>
diff --git a/uffd/invite/views.py b/uffd/invite/views.py
index 518a64cb..ece81c64 100644
--- a/uffd/invite/views.py
+++ b/uffd/invite/views.py
@@ -7,7 +7,7 @@ import sqlalchemy
 from uffd.csrf import csrf_protect
 from uffd.database import db
 from uffd.ldap import ldap
-from uffd.session import get_current_user, login_required, is_valid_session
+from uffd.session import login_required
 from uffd.role.models import Role
 from uffd.invite.models import Invite, InviteSignup, InviteGrant
 from uffd.user.models import User
@@ -16,18 +16,16 @@ from uffd.navbar import register_navbar
 from uffd.ratelimit import host_ratelimit, format_delay
 from uffd.signup.views import signup_ratelimit
 
-
 bp = Blueprint('invite', __name__, template_folder='templates', url_prefix='/invite/')
 
 def invite_acl():
-	if not is_valid_session():
+	if not request.user:
 		return False
-	user = get_current_user()
-	if user.is_in_group(current_app.config['ACL_ADMIN_GROUP']):
+	if request.user.is_in_group(current_app.config['ACL_ADMIN_GROUP']):
 		return True
-	if user.is_in_group(current_app.config['ACL_SIGNUP_GROUP']):
+	if request.user.is_in_group(current_app.config['ACL_SIGNUP_GROUP']):
 		return True
-	if Role.query.filter(Role.moderator_group_dn.in_(user.group_dns)).count():
+	if Role.query.filter(Role.moderator_group_dn.in_(request.user.group_dns)).count():
 		return True
 	return False
 
@@ -57,27 +55,25 @@ def reset_acl_filter(user):
 @register_navbar('Invites', icon='link', blueprint=bp, visible=invite_acl)
 @invite_acl_required
 def index():
-	invites = Invite.query.filter(view_acl_filter(get_current_user())).all()
+	invites = Invite.query.filter(view_acl_filter(request.user)).all()
 	return render_template('invite/list.html', invites=invites)
 
 @bp.route('/new')
 @invite_acl_required
 def new():
-	user = get_current_user()
-	if user.is_in_group(current_app.config['ACL_ADMIN_GROUP']):
+	if request.user.is_in_group(current_app.config['ACL_ADMIN_GROUP']):
 		allow_signup = True
 		roles = Role.query.all()
 	else:
-		allow_signup = user.is_in_group(current_app.config['ACL_SIGNUP_GROUP'])
-		roles = Role.query.filter(Role.moderator_group_dn.in_(user.group_dns)).all()
+		allow_signup = request.user.is_in_group(current_app.config['ACL_SIGNUP_GROUP'])
+		roles = Role.query.filter(Role.moderator_group_dn.in_(request.user.group_dns)).all()
 	return render_template('invite/new.html', roles=roles, allow_signup=allow_signup)
 
 @bp.route('/new', methods=['POST'])
 @invite_acl_required
 @csrf_protect(blueprint=bp)
 def new_submit():
-	user = get_current_user()
-	invite = Invite(creator=user,
+	invite = Invite(creator=request.user,
 	                single_use=(request.values['single-use'] == '1'),
 	                valid_until=datetime.datetime.fromisoformat(request.values['valid-until']),
 	                allow_signup=(request.values.get('allow-signup', '0') == '1'))
@@ -101,7 +97,7 @@ def new_submit():
 @invite_acl_required
 @csrf_protect(blueprint=bp)
 def disable(invite_id):
-	invite = Invite.query.filter(view_acl_filter(get_current_user())).filter_by(id=invite_id).first_or_404()
+	invite = Invite.query.filter(view_acl_filter(request.user)).filter_by(id=invite_id).first_or_404()
 	invite.disable()
 	db.session.commit()
 	return redirect(url_for('.index'))
@@ -110,7 +106,7 @@ def disable(invite_id):
 @invite_acl_required
 @csrf_protect(blueprint=bp)
 def reset(invite_id):
-	invite = Invite.query.filter(reset_acl_filter(get_current_user())).filter_by(id=invite_id).first_or_404()
+	invite = Invite.query.filter(reset_acl_filter(request.user)).filter_by(id=invite_id).first_or_404()
 	invite.reset()
 	db.session.commit()
 	return redirect(url_for('.index'))
@@ -128,7 +124,7 @@ def use(token):
 @csrf_protect(blueprint=bp)
 def grant(token):
 	invite = Invite.query.filter_by(token=token).first_or_404()
-	invite_grant = InviteGrant(invite=invite, user=get_current_user())
+	invite_grant = InviteGrant(invite=invite, user=request.user)
 	db.session.add(invite_grant)
 	success, msg = invite_grant.apply()
 	if not success:
diff --git a/uffd/mail/views.py b/uffd/mail/views.py
index 2a7a1432..6ba83d10 100644
--- a/uffd/mail/views.py
+++ b/uffd/mail/views.py
@@ -3,7 +3,7 @@ from flask import Blueprint, render_template, request, url_for, redirect, flash,
 from uffd.navbar import register_navbar
 from uffd.csrf import csrf_protect
 from uffd.ldap import ldap
-from uffd.session import login_required, is_valid_session, get_current_user
+from uffd.session import login_required
 
 from uffd.mail.models import Mail
 
@@ -16,7 +16,7 @@ def mail_acl(): #pylint: disable=inconsistent-return-statements
 		return redirect(url_for('index'))
 
 def mail_acl_check():
-	return is_valid_session() and get_current_user().is_in_group(current_app.config['ACL_ADMIN_GROUP'])
+	return request.user and request.user.is_in_group(current_app.config['ACL_ADMIN_GROUP'])
 
 @bp.route("/")
 @register_navbar('Mail', icon='envelope', blueprint=bp, visible=mail_acl_check)
diff --git a/uffd/mfa/views.py b/uffd/mfa/views.py
index d384b95c..7b95479f 100644
--- a/uffd/mfa/views.py
+++ b/uffd/mfa/views.py
@@ -5,7 +5,7 @@ from flask import Blueprint, render_template, session, request, redirect, url_fo
 
 from uffd.database import db
 from uffd.mfa.models import MFAMethod, TOTPMethod, WebauthnMethod, RecoveryCodeMethod
-from uffd.session.views import get_current_user, login_required, pre_mfa_login_required
+from uffd.session.views import login_required, login_required_pre_mfa, set_request_user
 from uffd.user.models import User
 from uffd.csrf import csrf_protect
 from uffd.ratelimit import Ratelimit, format_delay
@@ -17,10 +17,9 @@ mfa_ratelimit = Ratelimit('mfa', 1*60, 3)
 @bp.route('/', methods=['GET'])
 @login_required()
 def setup():
-	user = get_current_user()
-	recovery_methods = RecoveryCodeMethod.query.filter_by(dn=user.dn).all()
-	totp_methods = TOTPMethod.query.filter_by(dn=user.dn).all()
-	webauthn_methods = WebauthnMethod.query.filter_by(dn=user.dn).all()
+	recovery_methods = RecoveryCodeMethod.query.filter_by(dn=request.user.dn).all()
+	totp_methods = TOTPMethod.query.filter_by(dn=request.user.dn).all()
+	webauthn_methods = WebauthnMethod.query.filter_by(dn=request.user.dn).all()
 	return render_template('mfa/setup.html', totp_methods=totp_methods, webauthn_methods=webauthn_methods, recovery_methods=recovery_methods)
 
 @bp.route('/setup/disable', methods=['GET'])
@@ -32,8 +31,7 @@ def disable():
 @login_required()
 @csrf_protect(blueprint=bp)
 def disable_confirm():
-	user = get_current_user()
-	MFAMethod.query.filter_by(dn=user.dn).delete()
+	MFAMethod.query.filter_by(dn=request.user.dn).delete()
 	db.session.commit()
 	return redirect(url_for('mfa.setup'))
 
@@ -43,7 +41,7 @@ def disable_confirm():
 def admin_disable(uid):
 	# Group cannot be checked with login_required kwarg, because the config
 	# variable is not available when the decorator is processed
-	if not get_current_user().is_in_group(current_app.config['ACL_ADMIN_GROUP']):
+	if not request.user.is_in_group(current_app.config['ACL_ADMIN_GROUP']):
 		flash('Access denied')
 		return redirect(url_for('index'))
 	user = User.query.filter_by(uid=uid).one()
@@ -56,12 +54,11 @@ def admin_disable(uid):
 @login_required()
 @csrf_protect(blueprint=bp)
 def setup_recovery():
-	user = get_current_user()
-	for method in RecoveryCodeMethod.query.filter_by(dn=user.dn).all():
+	for method in RecoveryCodeMethod.query.filter_by(dn=request.user.dn).all():
 		db.session.delete(method)
 	methods = []
 	for _ in range(10):
-		method = RecoveryCodeMethod(user)
+		method = RecoveryCodeMethod(request.user)
 		methods.append(method)
 		db.session.add(method)
 	db.session.commit()
@@ -70,8 +67,7 @@ def setup_recovery():
 @bp.route('/setup/totp', methods=['GET'])
 @login_required()
 def setup_totp():
-	user = get_current_user()
-	method = TOTPMethod(user)
+	method = TOTPMethod(request.user)
 	session['mfa_totp_key'] = method.key
 	return render_template('mfa/setup_totp.html', method=method, name=request.values['name'])
 
@@ -79,11 +75,10 @@ def setup_totp():
 @login_required()
 @csrf_protect(blueprint=bp)
 def setup_totp_finish():
-	user = get_current_user()
-	if not RecoveryCodeMethod.query.filter_by(dn=user.dn).all():
+	if not RecoveryCodeMethod.query.filter_by(dn=request.user.dn).all():
 		flash('Generate recovery codes first!')
 		return redirect(url_for('mfa.setup'))
-	method = TOTPMethod(user, name=request.values['name'], key=session.pop('mfa_totp_key'))
+	method = TOTPMethod(request.user, name=request.values['name'], key=session.pop('mfa_totp_key'))
 	if method.verify(request.form['code']):
 		db.session.add(method)
 		db.session.commit()
@@ -95,8 +90,7 @@ def setup_totp_finish():
 @login_required()
 @csrf_protect(blueprint=bp)
 def delete_totp(id): #pylint: disable=redefined-builtin
-	user = get_current_user()
-	method = TOTPMethod.query.filter_by(dn=user.dn, id=id).first_or_404()
+	method = TOTPMethod.query.filter_by(dn=request.user.dn, id=id).first_or_404()
 	db.session.delete(method)
 	db.session.commit()
 	return redirect(url_for('mfa.setup'))
@@ -124,17 +118,16 @@ if WEBAUTHN_SUPPORTED:
 	@login_required()
 	@csrf_protect(blueprint=bp)
 	def setup_webauthn_begin():
-		user = get_current_user()
-		if not RecoveryCodeMethod.query.filter_by(dn=user.dn).all():
+		if not RecoveryCodeMethod.query.filter_by(dn=request.user.dn).all():
 			abort(403)
-		methods = WebauthnMethod.query.filter_by(dn=user.dn).all()
+		methods = WebauthnMethod.query.filter_by(dn=request.user.dn).all()
 		creds = [method.cred for method in methods]
 		server = get_webauthn_server()
 		registration_data, state = server.register_begin(
 			{
-				"id": user.dn.encode(),
-				"name": user.loginname,
-				"displayName": user.displayname,
+				"id": request.user.dn.encode(),
+				"name": request.user.loginname,
+				"displayName": request.user.displayname,
 			},
 			creds,
 			user_verification='discouraged',
@@ -146,23 +139,21 @@ if WEBAUTHN_SUPPORTED:
 	@login_required()
 	@csrf_protect(blueprint=bp)
 	def setup_webauthn_complete():
-		user = get_current_user()
 		server = get_webauthn_server()
 		data = cbor.loads(request.get_data())[0]
 		client_data = ClientData(data["clientDataJSON"])
 		att_obj = AttestationObject(data["attestationObject"])
 		auth_data = server.register_complete(session["webauthn-state"], client_data, att_obj)
-		method = WebauthnMethod(user, auth_data.credential_data, name=data['name'])
+		method = WebauthnMethod(request.user, auth_data.credential_data, name=data['name'])
 		db.session.add(method)
 		db.session.commit()
 		return cbor.dumps({"status": "OK"})
 
 	@bp.route("/auth/webauthn/begin", methods=["POST"])
-	@pre_mfa_login_required(no_redirect=True)
+	@login_required_pre_mfa(no_redirect=True)
 	def auth_webauthn_begin():
-		user = get_current_user()
 		server = get_webauthn_server()
-		methods = WebauthnMethod.query.filter_by(dn=user.dn).all()
+		methods = WebauthnMethod.query.filter_by(dn=request.user_pre_mfa.dn).all()
 		creds = [method.cred for method in methods]
 		if not creds:
 			abort(404)
@@ -171,11 +162,10 @@ if WEBAUTHN_SUPPORTED:
 		return cbor.dumps(auth_data)
 
 	@bp.route("/auth/webauthn/complete", methods=["POST"])
-	@pre_mfa_login_required(no_redirect=True)
+	@login_required_pre_mfa(no_redirect=True)
 	def auth_webauthn_complete():
-		user = get_current_user()
 		server = get_webauthn_server()
-		methods = WebauthnMethod.query.filter_by(dn=user.dn).all()
+		methods = WebauthnMethod.query.filter_by(dn=request.user_pre_mfa.dn).all()
 		creds = [method.cred for method in methods]
 		if not creds:
 			abort(404)
@@ -195,51 +185,53 @@ if WEBAUTHN_SUPPORTED:
 			signature,
 		)
 		session['user_mfa'] = True
+		set_request_user()
 		return cbor.dumps({"status": "OK"})
 
 @bp.route('/setup/webauthn/<int:id>/delete')
 @login_required()
 @csrf_protect(blueprint=bp)
 def delete_webauthn(id): #pylint: disable=redefined-builtin
-	user = get_current_user()
-	method = WebauthnMethod.query.filter_by(dn=user.dn, id=id).first_or_404()
+	method = WebauthnMethod.query.filter_by(dn=request.user.dn, id=id).first_or_404()
 	db.session.delete(method)
 	db.session.commit()
 	return redirect(url_for('mfa.setup'))
 
 @bp.route('/auth', methods=['GET'])
-@pre_mfa_login_required()
+@login_required_pre_mfa()
 def auth():
-	user = get_current_user()
-	recovery_methods = RecoveryCodeMethod.query.filter_by(dn=user.dn).all()
-	totp_methods = TOTPMethod.query.filter_by(dn=user.dn).all()
-	webauthn_methods = WebauthnMethod.query.filter_by(dn=user.dn).all()
+	recovery_methods = RecoveryCodeMethod.query.filter_by(dn=request.user_pre_mfa.dn).all()
+	totp_methods = TOTPMethod.query.filter_by(dn=request.user_pre_mfa.dn).all()
+	webauthn_methods = WebauthnMethod.query.filter_by(dn=request.user_pre_mfa.dn).all()
 	if not totp_methods and not webauthn_methods:
 		session['user_mfa'] = True
+		set_request_user()
+
 	if session.get('user_mfa'):
 		return redirect(request.values.get('ref', url_for('index')))
 	return render_template('mfa/auth.html', ref=request.values.get('ref'), totp_methods=totp_methods,
 			webauthn_methods=webauthn_methods, recovery_methods=recovery_methods)
 
 @bp.route('/auth', methods=['POST'])
-@pre_mfa_login_required()
+@login_required_pre_mfa()
 def auth_finish():
-	user = get_current_user()
-	delay = mfa_ratelimit.get_delay(user.dn)
+	delay = mfa_ratelimit.get_delay(request.user_pre_mfa.dn)
 	if delay:
 		flash('We received too many invalid attempts! Please wait at least %s.'%format_delay(delay))
 		return redirect(url_for('mfa.auth', ref=request.values.get('ref')))
-	recovery_methods = RecoveryCodeMethod.query.filter_by(dn=user.dn).all()
-	totp_methods = TOTPMethod.query.filter_by(dn=user.dn).all()
+	recovery_methods = RecoveryCodeMethod.query.filter_by(dn=request.user_pre_mfa.dn).all()
+	totp_methods = TOTPMethod.query.filter_by(dn=request.user_pre_mfa.dn).all()
 	for method in totp_methods:
 		if method.verify(request.form['code']):
 			session['user_mfa'] = True
+			set_request_user()
 			return redirect(request.values.get('ref', url_for('index')))
 	for method in recovery_methods:
 		if method.verify(request.form['code']):
 			db.session.delete(method)
 			db.session.commit()
 			session['user_mfa'] = True
+			set_request_user()
 			if len(recovery_methods) <= 1:
 				flash('You have exhausted your recovery codes. Please generate new ones now!')
 				return redirect(url_for('mfa.setup'))
@@ -247,6 +239,6 @@ def auth_finish():
 				flash('You only have a few recovery codes remaining. Make sure to generate new ones before they run out.')
 				return redirect(url_for('mfa.setup'))
 			return redirect(request.values.get('ref', url_for('index')))
-	mfa_ratelimit.log(user.dn)
+	mfa_ratelimit.log(request.user_pre_mfa.dn)
 	flash('Two-factor authentication failed')
 	return redirect(url_for('mfa.auth', ref=request.values.get('ref')))
diff --git a/uffd/oauth2/views.py b/uffd/oauth2/views.py
index b712a82c..12e62feb 100644
--- a/uffd/oauth2/views.py
+++ b/uffd/oauth2/views.py
@@ -7,7 +7,7 @@ from flask import Blueprint, request, jsonify, render_template, session, redirec
 from flask_oauthlib.provider import OAuth2Provider
 
 from uffd.database import db
-from uffd.session.views import get_current_user, login_required
+from uffd.session.views import login_required
 from .models import OAuth2Client, OAuth2Grant, OAuth2Token
 
 oauth = OAuth2Provider()
@@ -23,7 +23,7 @@ def load_grant(client_id, code):
 @oauth.grantsetter
 def save_grant(client_id, code, oauthreq, *args, **kwargs): # pylint: disable=unused-argument
 	expires = datetime.datetime.utcnow() + datetime.timedelta(seconds=100)
-	grant = OAuth2Grant(user_dn=get_current_user().dn, client_id=client_id,
+	grant = OAuth2Grant(user_dn=request.user.dn, client_id=client_id,
 		code=code['code'], redirect_uri=oauthreq.redirect_uri, expires=expires, _scopes=' '.join(oauthreq.scopes))
 	db.session.add(grant)
 	db.session.commit()
@@ -89,7 +89,7 @@ def authorize(*args, **kwargs): # pylint: disable=unused-argument
 	session['oauth2-clients'] = session.get('oauth2-clients', [])
 	if client.client_id not in session['oauth2-clients']:
 		session['oauth2-clients'].append(client.client_id)
-	return client.access_allowed(get_current_user())
+	return client.access_allowed(request.user)
 
 @bp.route('/token', methods=['GET', 'POST'])
 @oauth.token_handler
diff --git a/uffd/role/views.py b/uffd/role/views.py
index d2c3d846..9cf9eb9b 100644
--- a/uffd/role/views.py
+++ b/uffd/role/views.py
@@ -7,7 +7,7 @@ from uffd.navbar import register_navbar
 from uffd.csrf import csrf_protect
 from uffd.role.models import Role
 from uffd.user.models import User, Group
-from uffd.session import get_current_user, login_required, is_valid_session
+from uffd.session import login_required
 from uffd.database import db
 from uffd.ldap import ldap
 
@@ -44,7 +44,7 @@ def role_acl(): #pylint: disable=inconsistent-return-statements
 		return redirect(url_for('index'))
 
 def role_acl_check():
-	return is_valid_session() and get_current_user().is_in_group(current_app.config['ACL_ADMIN_GROUP'])
+	return request.user and request.user.is_in_group(current_app.config['ACL_ADMIN_GROUP'])
 
 @bp.route("/")
 @register_navbar('Roles', icon='key', blueprint=bp, visible=role_acl_check)
diff --git a/uffd/rolemod/views.py b/uffd/rolemod/views.py
index 174d923b..d9c5f834 100644
--- a/uffd/rolemod/views.py
+++ b/uffd/rolemod/views.py
@@ -4,14 +4,14 @@ from uffd.navbar import register_navbar
 from uffd.csrf import csrf_protect
 from uffd.role.models import Role
 from uffd.user.models import User
-from uffd.session import get_current_user, login_required, is_valid_session
+from uffd.session import login_required
 from uffd.database import db
 from uffd.ldap import ldap
 
 bp = Blueprint('rolemod', __name__, template_folder='templates', url_prefix='/rolemod/')
 
 def user_is_rolemod():
-	return is_valid_session() and Role.query.filter(Role.moderator_group_dn.in_(get_current_user().group_dns)).count()
+	return request.user and Role.query.filter(Role.moderator_group_dn.in_(request.user.group_dns)).count()
 
 @bp.before_request
 @login_required()
@@ -23,7 +23,7 @@ def acl_check(): #pylint: disable=inconsistent-return-statements
 @bp.route("/")
 @register_navbar('Moderation', icon='user-lock', blueprint=bp, visible=user_is_rolemod)
 def index():
-	roles = Role.query.filter(Role.moderator_group_dn.in_(get_current_user().group_dns)).all()
+	roles = Role.query.filter(Role.moderator_group_dn.in_(request.user.group_dns)).all()
 	return render_template('rolemod/list.html', roles=roles)
 
 @bp.route("/<int:role_id>")
@@ -31,7 +31,7 @@ def show(role_id):
 	# prefetch all users so the ldap orm can cache them and doesn't run one ldap query per user
 	User.query.all()
 	role = Role.query.get_or_404(role_id)
-	if role.moderator_group not in get_current_user().groups:
+	if role.moderator_group not in request.user.groups:
 		flash('Access denied')
 		return redirect(url_for('index'))
 	return render_template('rolemod/show.html', role=role)
@@ -40,7 +40,7 @@ def show(role_id):
 @csrf_protect(blueprint=bp)
 def update(role_id):
 	role = Role.query.get_or_404(role_id)
-	if role.moderator_group not in get_current_user().groups:
+	if role.moderator_group not in request.user.groups:
 		flash('Access denied')
 		return redirect(url_for('index'))
 	if request.form['description'] != role.description:
@@ -55,7 +55,7 @@ def update(role_id):
 @csrf_protect(blueprint=bp)
 def delete_member(role_id, member_dn):
 	role = Role.query.get_or_404(role_id)
-	if role.moderator_group not in get_current_user().groups:
+	if role.moderator_group not in request.user.groups:
 		flash('Access denied')
 		return redirect(url_for('index'))
 	member = User.query.get_or_404(member_dn)
diff --git a/uffd/selfservice/views.py b/uffd/selfservice/views.py
index 6bc07b25..d469ef69 100644
--- a/uffd/selfservice/views.py
+++ b/uffd/selfservice/views.py
@@ -9,7 +9,7 @@ from flask import Blueprint, render_template, request, url_for, redirect, flash,
 from uffd.navbar import register_navbar
 from uffd.csrf import csrf_protect
 from uffd.user.models import User
-from uffd.session import get_current_user, login_required, is_valid_session
+from uffd.session import login_required
 from uffd.selfservice.models import PasswordToken, MailToken
 from uffd.database import db
 from uffd.ldap import ldap
@@ -20,17 +20,17 @@ bp = Blueprint("selfservice", __name__, template_folder='templates', url_prefix=
 reset_ratelimit = Ratelimit('passwordreset', 1*60*60, 3)
 
 @bp.route("/")
-@register_navbar('Selfservice', icon='portrait', blueprint=bp, visible=is_valid_session)
+@register_navbar('Selfservice', icon='portrait', blueprint=bp, visible=lambda: bool(request.user))
 @login_required()
 def index():
-	return render_template('selfservice/self.html', user=get_current_user())
+	return render_template('selfservice/self.html', user=request.user)
 
 @bp.route("/update", methods=(['POST']))
 @csrf_protect(blueprint=bp)
 @login_required()
 def update():
 	password_changed = False
-	user = get_current_user()
+	user = request.user
 	if request.values['displayname'] != user.displayname:
 		if user.set_displayname(request.values['displayname']):
 			flash('Display name changed.')
diff --git a/uffd/services/views.py b/uffd/services/views.py
index f55be57f..585de42c 100644
--- a/uffd/services/views.py
+++ b/uffd/services/views.py
@@ -1,7 +1,6 @@
-from flask import Blueprint, render_template, current_app, abort
+from flask import Blueprint, render_template, current_app, abort, request
 
 from uffd.navbar import register_navbar
-from uffd.session import is_valid_session, get_current_user
 
 bp = Blueprint("services", __name__, template_folder='templates', url_prefix='/services')
 
@@ -69,25 +68,19 @@ def get_services(user=None):
 	return services
 
 def services_visible():
-	user = None
-	if is_valid_session():
-		user = get_current_user()
-	return len(get_services(user)) > 0
+	return len(get_services(request.user)) > 0
 
 @bp.route("/")
 @register_navbar('Services', icon='sitemap', blueprint=bp, visible=services_visible)
 def index():
-	user = None
-	if is_valid_session():
-		user = get_current_user()
-	services = get_services(user)
+	services = get_services(request.user)
 	if not current_app.config['SERVICES']:
 		abort(404)
 
 	banner = current_app.config.get('SERVICES_BANNER')
 
 	# Set the banner to None if it is not public and no user is logged in
-	if not (current_app.config["SERVICES_BANNER_PUBLIC"] or user):
+	if not (current_app.config["SERVICES_BANNER_PUBLIC"] or request.user):
 		banner = None
 
-	return render_template('services/overview.html', user=user, services=services, banner=banner)
+	return render_template('services/overview.html', user=request.user, services=services, banner=banner)
diff --git a/uffd/session/__init__.py b/uffd/session/__init__.py
index 5cddcdc6..0e571f3a 100644
--- a/uffd/session/__init__.py
+++ b/uffd/session/__init__.py
@@ -1,3 +1,3 @@
-from .views import bp as bp_ui, get_current_user, login_required, is_valid_session, set_session
+from .views import bp as bp_ui, login_required, set_session
 
 bp = [bp_ui]
diff --git a/uffd/session/views.py b/uffd/session/views.py
index 1b76519d..2c57bc6d 100644
--- a/uffd/session/views.py
+++ b/uffd/session/views.py
@@ -12,6 +12,21 @@ bp = Blueprint("session", __name__, template_folder='templates', url_prefix='/')
 
 login_ratelimit = Ratelimit('login', 1*60, 3)
 
+@bp.before_app_request
+def set_request_user():
+	request.user = None
+	request.user_pre_mfa = None
+	if 'user_dn' not in session:
+		return
+	if 'logintime' not in session:
+		return
+	if datetime.datetime.now().timestamp() > session['logintime'] + current_app.config['SESSION_LIFETIME_SECONDS']:
+		return
+	user = User.query.get(session['user_dn'])
+	request.user_pre_mfa = user
+	if session.get('user_mfa'):
+		request.user = user
+
 def login_get_user(loginname, password):
 	dn = User(loginname=loginname).dn
 
@@ -84,34 +99,11 @@ def login():
 	set_session(user, password=password)
 	return redirect(url_for('mfa.auth', ref=request.values.get('ref', url_for('index'))))
 
-def get_current_user():
-	if 'user_dn' not in session:
-		return None
-	return User.query.get(session['user_dn'])
-bp.add_app_template_global(get_current_user)
-
-def login_valid():
-	user = get_current_user()
-	if user is None:
-		return False
-	if datetime.datetime.now().timestamp() > session['logintime'] + current_app.config['SESSION_LIFETIME_SECONDS']:
-		return False
-	return True
-
-def is_valid_session():
-	if not login_valid():
-		return False
-	if not session.get('user_mfa'):
-		return False
-	return True
-bp.add_app_template_global(is_valid_session)
-
-def pre_mfa_login_required(no_redirect=False):
+def login_required_pre_mfa(no_redirect=False):
 	def wrapper(func):
 		@functools.wraps(func)
 		def decorator(*args, **kwargs):
-			if not login_valid() or datetime.datetime.now().timestamp() > session['logintime'] + 10*60:
-				session.clear()
+			if not request.user_pre_mfa:
 				if no_redirect:
 					abort(403)
 				flash('You need to login first')
@@ -124,12 +116,12 @@ def login_required(group=None):
 	def wrapper(func):
 		@functools.wraps(func)
 		def decorator(*args, **kwargs):
-			if not login_valid():
+			if not request.user_pre_mfa:
 				flash('You need to login first')
 				return redirect(url_for('session.login', ref=request.url))
-			if not session.get('user_mfa'):
+			if not request.user:
 				return redirect(url_for('mfa.auth', ref=request.url))
-			if not get_current_user().is_in_group(group):
+			if not request.user.is_in_group(group):
 				flash('Access denied')
 				return redirect(url_for('index'))
 			return func(*args, **kwargs)
diff --git a/uffd/templates/base.html b/uffd/templates/base.html
index 5ab3cba8..ec80472d 100644
--- a/uffd/templates/base.html
+++ b/uffd/templates/base.html
@@ -67,7 +67,7 @@
 					</li>
 					{% endfor %}
 				</ul>
-				{% if is_valid_session() %}
+				{% if request.user %}
 				<ul class="navbar-nav ml-auto">
 					<li class="nav-item">
 						<a class="nav-link" href="{{ url_for("session.logout") }}">
diff --git a/uffd/user/views_group.py b/uffd/user/views_group.py
index dca984e7..d4318b3c 100644
--- a/uffd/user/views_group.py
+++ b/uffd/user/views_group.py
@@ -1,7 +1,7 @@
-from flask import Blueprint, render_template, url_for, redirect, flash, current_app
+from flask import Blueprint, render_template, url_for, redirect, flash, current_app, request
 
 from uffd.navbar import register_navbar
-from uffd.session import login_required, is_valid_session, get_current_user
+from uffd.session import login_required
 
 from .models import Group
 
@@ -14,7 +14,7 @@ def group_acl(): #pylint: disable=inconsistent-return-statements
 		return redirect(url_for('index'))
 
 def group_acl_check():
-	return is_valid_session() and get_current_user().is_in_group(current_app.config['ACL_ADMIN_GROUP'])
+	return request.user and request.user.is_in_group(current_app.config['ACL_ADMIN_GROUP'])
 
 @bp.route("/")
 @register_navbar('Groups', icon='layer-group', blueprint=bp, visible=group_acl_check)
diff --git a/uffd/user/views_user.py b/uffd/user/views_user.py
index aadaf5af..a76beceb 100644
--- a/uffd/user/views_user.py
+++ b/uffd/user/views_user.py
@@ -6,7 +6,7 @@ from flask import Blueprint, render_template, request, url_for, redirect, flash,
 from uffd.navbar import register_navbar
 from uffd.csrf import csrf_protect
 from uffd.selfservice import send_passwordreset
-from uffd.session import login_required, is_valid_session, get_current_user
+from uffd.session import login_required
 from uffd.role.models import Role
 from uffd.database import db
 from uffd.ldap import ldap, LDAPCommitError
@@ -22,7 +22,7 @@ def user_acl(): #pylint: disable=inconsistent-return-statements
 		return redirect(url_for('index'))
 
 def user_acl_check():
-	return is_valid_session() and get_current_user().is_in_group(current_app.config['ACL_ADMIN_GROUP'])
+	return request.user and request.user.is_in_group(current_app.config['ACL_ADMIN_GROUP'])
 
 @bp.route("/")
 @register_navbar('Users', icon='users', blueprint=bp, visible=user_acl_check)
-- 
GitLab