From 8a6ca93c67c7d41d88c24f19c53957b9ab0f2f52 Mon Sep 17 00:00:00 2001
From: Julian Rother <julian@cccv.de>
Date: Mon, 28 Feb 2022 02:42:23 +0100
Subject: [PATCH] Refactor mapper object flushing/expiring in tests

---
 tests/test_mfa.py    | 27 ++++++++++++++++-----------
 tests/test_signup.py |  2 +-
 tests/utils.py       |  2 +-
 3 files changed, 18 insertions(+), 13 deletions(-)

diff --git a/tests/test_mfa.py b/tests/test_mfa.py
index a5d92cde..c1131eaf 100644
--- a/tests/test_mfa.py
+++ b/tests/test_mfa.py
@@ -45,12 +45,14 @@ class TestMfaMethodModels(UffdTestCase):
 		method = RecoveryCodeMethod(user=self.get_user())
 		db.session.add(method)
 		db.session.commit()
-		db.session = db.create_scoped_session() # Ensure the next query does not return the cached method object
-		_method = RecoveryCodeMethod.query.get(method.id)
-		self.assertFalse(hasattr(_method, 'code'))
-		self.assertFalse(_method.verify(''))
-		self.assertFalse(_method.verify('A'*8))
-		self.assertTrue(_method.verify(method.code))
+		method_id = method.id
+		method_code = method.code
+		db.session.expunge(method)
+		method = RecoveryCodeMethod.query.get(method_id)
+		self.assertFalse(hasattr(method, 'code'))
+		self.assertFalse(method.verify(''))
+		self.assertFalse(method.verify('A'*8))
+		self.assertTrue(method.verify(method_code))
 
 	def test_totp_method_attributes(self):
 		method = TOTPMethod(user=self.get_user(), name='testname')
@@ -68,9 +70,10 @@ class TestMfaMethodModels(UffdTestCase):
 		self.assertEqual(_method.key_uri, key_uri)
 		db.session.add(method)
 		db.session.commit()
-		db.session = db.create_scoped_session() # Ensure the next query does not return the cached method object
+		_method_id = _method.id
+		db.session.expunge(_method)
 		# Restore method from db
-		_method = TOTPMethod.query.get(method.id)
+		_method = TOTPMethod.query.get(_method_id)
 		self.assertEqual(_method.name, 'testname')
 		self.assertEqual(_method.raw_key, raw_key)
 		self.assertEqual(_method.issuer, issuer)
@@ -91,10 +94,12 @@ class TestMfaMethodModels(UffdTestCase):
 		self.assertEqual(method.name, 'testname')
 		db.session.add(method)
 		db.session.commit()
-		db.session = db.create_scoped_session() # Ensure the next query does not return the cached method object
-		_method = WebauthnMethod.query.get(method.id)
+		method_id = method.id
+		method_cred = method.cred
+		db.session.expunge(method)
+		_method = WebauthnMethod.query.get(method_id)
 		self.assertEqual(_method.name, 'testname')
-		self.assertEqual(bytes(method.cred), bytes(_method.cred))
+		self.assertEqual(bytes(method_cred), bytes(_method.cred))
 		self.assertEqual(data.credential_id, _method.cred.credential_id)
 		self.assertEqual(data.public_key, _method.cred.public_key)
 		# We only test (de-)serialization here, as everything else is currently implemented in the views
diff --git a/tests/test_signup.py b/tests/test_signup.py
index e64d0123..c202a501 100644
--- a/tests/test_signup.py
+++ b/tests/test_signup.py
@@ -19,7 +19,7 @@ def refetch_signup(signup):
 	db.session.add(signup)
 	db.session.commit()
 	id = signup.id
-	db_flush()
+	db.session.expunge(signup)
 	return Signup.query.get(id)
 
 # We assume in all tests that Signup.validate and Signup.password.verify do
diff --git a/tests/utils.py b/tests/utils.py
index c9ec3e92..6a2e2cb8 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -22,7 +22,7 @@ def dump(basename, resp):
 
 def db_flush():
 	db.session.rollback()
-	db.session = db.create_scoped_session()
+	db.session.expire_all()
 
 class UffdTestCase(unittest.TestCase):
 	def get_user(self):
-- 
GitLab