Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found

Target

Select target project
  • uffd/uffd
  • rixx/uffd
  • thies/uffd
  • leona/uffd
  • enbewe/uffd
  • strifel/uffd
  • thies/uffd-2
7 results
Show changes
Showing
with 3882 additions and 0 deletions
import unittest
import datetime
import time
from uffd.database import db
from uffd.models import RecoveryCodeMethod, TOTPMethod, WebauthnMethod
from uffd.models.mfa import _hotp
from tests.utils import UffdTestCase
class TestMfaPrimitives(unittest.TestCase):
def test_hotp(self):
self.assertEqual(_hotp(5555555, b'\xae\xa3T\x05\x89\xd6\xb76\xf61r\x92\xcc\xb5WZ\xe6)\x05q'), '458290')
self.assertEqual(_hotp(5555555, b'\xae\xa3T\x05\x89\xd6\xb76\xf61r\x92\xcc\xb5WZ\xe6)\x05q', digits=8), '20458290')
for digits in range(1, 10):
self.assertEqual(len(_hotp(1, b'abcd', digits=digits)), digits)
self.assertEqual(_hotp(1234, b''), '161024')
self.assertEqual(_hotp(0, b'\x04\x8fM\xcc\x7f\x82\x9c$a\x1b\xb3'), '279354')
self.assertEqual(_hotp(2**64-1, b'abcde'), '899292')
def get_fido2_test_cred(self):
try:
from uffd.fido2_compat import AttestedCredentialData
except ImportError:
self.skipTest('fido2 could not be imported')
# Example public key from webauthn spec 6.5.1.1
return AttestedCredentialData(bytes.fromhex('00000000000000000000000000000000'+'0040'+'053cbcc9d37a61d3bac87cdcc77ee326256def08ab15775d3a720332e4101d14fae95aeee3bc9698781812e143c0597dc6e180595683d501891e9dd030454c0a'+'A501020326200121582065eda5a12577c2bae829437fe338701a10aaa375e1bb5b5de108de439c08551d2258201e52ed75701163f7f9e40ddf9f341b3dc9ba860af7e0ca7ca7e9eecd0084d19c'))
class TestMfaMethodModels(UffdTestCase):
def test_common_attributes(self):
method = TOTPMethod(user=self.get_user(), name='testname')
self.assertTrue(method.created <= datetime.datetime.utcnow())
self.assertEqual(method.name, 'testname')
self.assertEqual(method.user.loginname, 'testuser')
method.user = self.get_admin()
self.assertEqual(method.user.loginname, 'testadmin')
def test_recovery_code_method(self):
method = RecoveryCodeMethod(user=self.get_user())
db.session.add(method)
db.session.commit()
method_id = method.id
method_code = method.code_value
db.session.expunge(method)
method = RecoveryCodeMethod.query.get(method_id)
self.assertFalse(hasattr(method, 'code_value'))
self.assertFalse(method.verify(''))
self.assertFalse(method.verify('A'*8))
self.assertTrue(method.verify(method_code))
def test_totp_method_attributes(self):
method = TOTPMethod(user=self.get_user(), name='testname')
raw_key = method.raw_key
issuer = method.issuer
accountname = method.accountname
key_uri = method.key_uri
self.assertEqual(method.name, 'testname')
# Restore method with key parameter
_method = TOTPMethod(user=self.get_user(), key=method.key, name='testname')
self.assertEqual(_method.name, 'testname')
self.assertEqual(_method.raw_key, raw_key)
self.assertEqual(_method.issuer, issuer)
self.assertEqual(_method.accountname, accountname)
self.assertEqual(_method.key_uri, key_uri)
db.session.add(method)
db.session.commit()
_method_id = _method.id
db.session.expunge(_method)
# Restore method from db
_method = TOTPMethod.query.get(_method_id)
self.assertEqual(_method.name, 'testname')
self.assertEqual(_method.raw_key, raw_key)
self.assertEqual(_method.issuer, issuer)
self.assertEqual(_method.accountname, accountname)
self.assertEqual(_method.key_uri, key_uri)
def test_totp_method_verify(self):
method = TOTPMethod(user=self.get_user())
counter = int(time.time()/30)
self.assertFalse(method.verify(''))
self.assertFalse(method.verify(_hotp(counter-2, method.raw_key)))
self.assertTrue(method.verify(_hotp(counter, method.raw_key)))
self.assertFalse(method.verify(_hotp(counter+2, method.raw_key)))
def test_totp_method_verify_reuse(self):
method = TOTPMethod(user=self.get_user())
counter = int(time.time()/30)
self.assertFalse(method.verify(_hotp(counter-2, method.raw_key)))
self.assertTrue(method.verify(_hotp(counter-1, method.raw_key)))
self.assertTrue(method.verify(_hotp(counter, method.raw_key)))
self.assertFalse(method.verify(_hotp(counter-1, method.raw_key)))
self.assertFalse(method.verify(_hotp(counter, method.raw_key)))
def test_webauthn_method(self):
data = get_fido2_test_cred(self)
method = WebauthnMethod(user=self.get_user(), cred=data, name='testname')
self.assertEqual(method.name, 'testname')
db.session.add(method)
db.session.commit()
method_id = method.id
method_cred = method.cred
db.session.expunge(method)
_method = WebauthnMethod.query.get(method_id)
self.assertEqual(_method.name, 'testname')
self.assertEqual(bytes(method_cred), bytes(_method.cred))
self.assertEqual(data.credential_id, _method.cred.credential_id)
self.assertEqual(data.public_key, _method.cred.public_key)
# We only test (de-)serialization here, as everything else is currently implemented in the views
import time
import threading
from sqlalchemy.exc import IntegrityError
from uffd.database import db
from uffd.models import FeatureFlag, Lock
from uffd.models.misc import feature_flag_table
from tests.utils import ModelTestCase
class TestFeatureFlag(ModelTestCase):
def test_disabled(self):
flag = FeatureFlag('foo')
self.assertFalse(flag)
self.assertFalse(db.session.execute(db.select([flag.expr])).scalar())
def test_enabled(self):
db.session.execute(db.insert(feature_flag_table).values(name='foo'))
flag = FeatureFlag('foo')
self.assertTrue(flag)
self.assertTrue(db.session.execute(db.select([flag.expr])).scalar())
def test_toggle(self):
flag = FeatureFlag('foo')
hooks_called = []
@flag.enable_hook
def enable_hook1():
hooks_called.append('enable1')
@flag.enable_hook
def enable_hook2():
hooks_called.append('enable2')
@flag.disable_hook
def disable_hook1():
hooks_called.append('disable1')
@flag.disable_hook
def disable_hook2():
hooks_called.append('disable2')
hooks_called.clear()
flag.enable()
self.assertTrue(flag)
self.assertEqual(hooks_called, ['enable1', 'enable2'])
hooks_called.clear()
flag.disable()
self.assertFalse(flag)
self.assertEqual(hooks_called, ['disable1', 'disable2'])
flag.disable() # does nothing
self.assertFalse(flag)
flag.enable()
self.assertTrue(flag)
with self.assertRaises(IntegrityError):
flag.enable()
self.assertTrue(flag)
class TestLock(ModelTestCase):
DISABLE_SQLITE_MEMORY_DB = True
def setUpApp(self):
self.lock = Lock('testlock')
def run_lock_test(self):
result = []
def func():
with self.app.test_request_context():
self.lock.acquire()
result.append('bar')
t = threading.Thread(target=func)
t.start()
time.sleep(1)
result.append('foo')
time.sleep(1)
db.session.rollback()
t.join()
return result
def test_lock2(self):
self.assertEqual(self.run_lock_test(), ['bar', 'foo'])
self.lock.acquire()
self.assertEqual(self.run_lock_test(), ['foo', 'bar'])
import unittest
import datetime
import jwt
from uffd.database import db
from uffd.models import OAuth2Key
from tests.utils import UffdTestCase
TEST_JWK = dict(
id='HvOn74G7njK1GoFNe8Dta087casdWMsm06pNhOXRgJU',
created=datetime.datetime(2023, 11, 9, 0, 21, 10),
active=True,
algorithm='RS256',
private_key_jwk='''{
"kty": "RSA",
"key_ops": ["sign"],
"n": "vrznqUy8Xamph6s0Z02fFMIyjwLAMio35i9DXYjXP1ZQwSZ3SsIh3m2ablMnlu8PVlnYUzoj8rXyAWND0FSfWoQQxv1rq15pllKueddLoJsv321N_NRB8beGsLrsndw8QO0q3RWqV9O3kqhlTMjgj6bquX42wLaXrPLJyfbT3zObBsToG4UxpOyly84aklJXU5wIs0cbmjbfd8Xld38BG8Oh7Ozy5b93vPpJW6rudZRxU6QYC0r9bFFLIHJWrR4bzQMLGoJ63xjPOCl4WNpOYc9B7PNgnWTLXlFd51Hw9CaT2MRWsKNCSU77f6nZkfjWa1IsQdF0I48m46qgq7bEOOl9DbThbCnpblWrctdyg6du-OvCyVmkAo1KGtANl0027pgqUI_9HBMi33y3UPQm1ALHXIyIDBZtExH3lD6MMK3XGJfUxZuIOBndK-PXm5Fed52bgLOcf-24X6aHFn-8oyDVIj9OHkKWjy7jtKdmqZc4pBdVuCaMCYzj8iERWA3H",
"e": "AQAB",
"d": "G7yoH5mLcZTA6ia-byCoN-zpofGvdga9AZnxPO0vsq6K_cY_O2gxuVZ3n6reAKKbuLNGCbb_D_Dffs4q8rprlfkgi3TCLzXX5Zv5HWTD7a4Y7xpxEzQ2sWo-iagVIqZVPh0pyjliqnTyUWnFmWiY0gBe9UHianHjFVZqe8E2HFOKgW3UUbQz0keg8JtJ3T9gzZrM38KWbqhOJO0VVSRAoANPTSnumfRsUCyWywrMtIfgAbQaKazqX3xkOsAF1L-iNfd6slzPvRyIQVflVDMdfKnsu-lHiKJ0DK_lg9f55T5FymgcXsq43EKBQ2H4v2dafIm-vtWx_TRZWj_msD32BEPBA-zTqh_oP1r6a3DZh4DBtWY3vzSiuhAC0erlRs-hRTX_e9ET5fUbJnmNxjnxQD9zZmwq4ujMK6KFnHct8t77Qxj3a-wDR_XyDJ4_EKYqHlcVHfxGNBSvIdjuZJkPJnVpVtfCtpyamQIR4u5oNV7fIwYe_tFnw0Y90rGoJMzB",
"p": "-A-FnH21HJ7GPWUm9k3mxsxSchy89QEUCZZiH6EcB4ZP8wJsxrQsUSIHCR74YmZEI3Ulsum1Ql4x50k7Q2sNh9SnwKvmctjksehGy4yCrdunAqjqyz3wFwGaKWnhn3frkiqH5ATjkOoc8qHz8saa7reeVClj47ZWyy-Nl559ycLMs0rI1N_THzO07C3jSbJhyPj0yeygAflsRqqnNvEQ6ps1VLiqf9G5jfSvUUn5DyKIpep9iGo29caGSIPIy_2h",
"q": "xNe1-QWskxOcY_GiHpFWdvzqr1o9fxg5whgpNcGi3caokw2iNHRYut4cbVvFFBlv_9B5QCl9WVfR2ADG0AtvkvUxEZqCdxEvcqjIANeRLKHDjW5kMuPS0_fcskFP-r7mCM9SBfPplfMVCF5nuNWf5LzNopWfsTChIDD1rSpPjItNYuwLXszm_3R81HHHeQLcyvoMxLCmeLy5TXX2hXOMHh2IMZCXAHopJmLJUVnQ48kr5jd2l0kLbmx3aBqdccJn",
"dp": "MLS7g1KbcRcrzXpDADGjkn0j4wwJfgHMMWW5toQnwMJ6iDh9qzZNTVDlGMFf-9IgpuWllU-WK4XbPpJ-dGpcqcLzfT1DbmFv5g65d9YLAqASVs9b6rQqpBnIb0E-79TYCEcZj4f2NsoBDRMHly-v1BdxmwzVdCylNhgMMS0Jfcgl8T5J2KJqDcJVT9piumGwGYnoZo1zjW-v9uAjHQKQU8BN5Git8ZL4YAsfMVLY-EPLmOhF5bcVO4TTcQGPN56B",
"dq": "HiiSl-G3rB0QE_v8g8Ruw_JCHrWrwGI8zzEWd0cApgv-3fDzzieZRKAtKNArpMW09DPDsAHrU5nx669KxqtJ3_EzIGhU3ttCMsYLRp3Af18VcADe1zEypwlNxf3dvCQtaGIjRgg13KSOr2aPa7FHOyt2MhfMjMBPn3gA3BQkdfsN0z8pCtBIABGf4ojAMBkxLOQcurH5_3uixGxzZcTrTd3mdPmbORZ-YYQ3JgCl0ZCL6kzLHaiyWKvDq66QOtK3",
"qi": "ySqD9cUxbq3wkCsPQId_YfQLIqb5RK_JJIMjtBOdTdo4aT5tmodYCSmjBmhrYXjDWtyJdelvPfdSfgncHJhf8VgkZ8TPvUeaQwsQFBwB5llwpdb72eEEJrmG1SVwNMoFCLXdNT3ACad16cUDMnWmklH0X07OzdxGOBnGhgLZUs4RbPjLH7OpYTyQqVy2L8vofqJR42cfePZw8WQM4k0PPbhralhybExIkSCmaQyYbACZ5k0OVQErEqnj4elglA0h"
}''',
public_key_jwk='''{
"kty": "RSA",
"key_ops": ["verify"],
"n": "vrznqUy8Xamph6s0Z02fFMIyjwLAMio35i9DXYjXP1ZQwSZ3SsIh3m2ablMnlu8PVlnYUzoj8rXyAWND0FSfWoQQxv1rq15pllKueddLoJsv321N_NRB8beGsLrsndw8QO0q3RWqV9O3kqhlTMjgj6bquX42wLaXrPLJyfbT3zObBsToG4UxpOyly84aklJXU5wIs0cbmjbfd8Xld38BG8Oh7Ozy5b93vPpJW6rudZRxU6QYC0r9bFFLIHJWrR4bzQMLGoJ63xjPOCl4WNpOYc9B7PNgnWTLXlFd51Hw9CaT2MRWsKNCSU77f6nZkfjWa1IsQdF0I48m46qgq7bEOOl9DbThbCnpblWrctdyg6du-OvCyVmkAo1KGtANl0027pgqUI_9HBMi33y3UPQm1ALHXIyIDBZtExH3lD6MMK3XGJfUxZuIOBndK-PXm5Fed52bgLOcf-24X6aHFn-8oyDVIj9OHkKWjy7jtKdmqZc4pBdVuCaMCYzj8iERWA3H",
"e": "AQAB"
}''',
)
class TestOAuth2Key(UffdTestCase):
def setUp(self):
super().setUp()
db.session.add(OAuth2Key(**TEST_JWK))
db.session.add(OAuth2Key(
id='1e9gdk7',
created=datetime.datetime(2014, 11, 8, 0, 0, 0),
active=True,
algorithm='RS256',
private_key_jwk='invalid',
public_key_jwk='''{
"kty":"RSA",
"n":"w7Zdfmece8iaB0kiTY8pCtiBtzbptJmP28nSWwtdjRu0f2GFpajvWE4VhfJAjEsOcwYzay7XGN0b-X84BfC8hmCTOj2b2eHT7NsZegFPKRUQzJ9wW8ipn_aDJWMGDuB1XyqT1E7DYqjUCEOD1b4FLpy_xPn6oV_TYOfQ9fZdbE5HGxJUzekuGcOKqOQ8M7wfYHhHHLxGpQVgL0apWuP2gDDOdTtpuld4D2LK1MZK99s9gaSjRHE8JDb1Z4IGhEcEyzkxswVdPndUWzfvWBBWXWxtSUvQGBRkuy1BHOa4sP6FKjWEeeF7gm7UMs2Nm2QUgNZw6xvEDGaLk4KASdIxRQ",
"e":"AQAB"
}'''
))
db.session.commit()
self.key = OAuth2Key.query.get('HvOn74G7njK1GoFNe8Dta087casdWMsm06pNhOXRgJU')
self.key_oidc_spec = OAuth2Key.query.get('1e9gdk7')
def test_private_key(self):
self.key.private_key
def test_public_key(self):
self.key.private_key
def test_public_key_jwks_dict(self):
self.assertEqual(self.key.public_key_jwks_dict, {
"kid": "HvOn74G7njK1GoFNe8Dta087casdWMsm06pNhOXRgJU",
"kty": "RSA",
"alg": "RS256",
"use": "sig",
"n": "vrznqUy8Xamph6s0Z02fFMIyjwLAMio35i9DXYjXP1ZQwSZ3SsIh3m2ablMnlu8PVlnYUzoj8rXyAWND0FSfWoQQxv1rq15pllKueddLoJsv321N_NRB8beGsLrsndw8QO0q3RWqV9O3kqhlTMjgj6bquX42wLaXrPLJyfbT3zObBsToG4UxpOyly84aklJXU5wIs0cbmjbfd8Xld38BG8Oh7Ozy5b93vPpJW6rudZRxU6QYC0r9bFFLIHJWrR4bzQMLGoJ63xjPOCl4WNpOYc9B7PNgnWTLXlFd51Hw9CaT2MRWsKNCSU77f6nZkfjWa1IsQdF0I48m46qgq7bEOOl9DbThbCnpblWrctdyg6du-OvCyVmkAo1KGtANl0027pgqUI_9HBMi33y3UPQm1ALHXIyIDBZtExH3lD6MMK3XGJfUxZuIOBndK-PXm5Fed52bgLOcf-24X6aHFn-8oyDVIj9OHkKWjy7jtKdmqZc4pBdVuCaMCYzj8iERWA3H",
"e": "AQAB"
})
def test_encode_jwt(self):
jwtdata = self.key.encode_jwt({'aud': 'test', 'foo': 'bar'})
self.assertIsInstance(jwtdata, str) # Regression check for #165
self.assertEqual(
jwt.get_unverified_header(jwtdata),
# typ is optional, x5u/x5c/jku/jwk are discoraged by OIDC Core 1.0 spec section 2
{'kid': self.key.id, 'alg': self.key.algorithm, 'typ': 'JWT'}
)
self.assertEqual(
OAuth2Key.decode_jwt(jwtdata, audience='test'),
{'aud': 'test', 'foo': 'bar'}
)
self.key.active = False
with self.assertRaises(jwt.exceptions.InvalidKeyError):
self.key.encode_jwt({'aud': 'test', 'foo': 'bar'})
def test_oidc_hash(self):
# Example from OIDC Core 1.0 spec A.3
self.assertEqual(
self.key.oidc_hash(b'jHkWEdUXMU1BwAsC4vtUsZwnNvTIxEl0z9K3vx5KF0Y'),
'77QmUPtjPfzWtF2AnpK9RQ'
)
# Example from OIDC Core 1.0 spec A.4
self.assertEqual(
self.key.oidc_hash(b'Qcb0Orv1zh30vL1MPRsbm-diHiMwcLyZvn1arpZv-Jxf_11jnpEX3Tgfvk'),
'LDktKdoQak3Pk0cnXxCltA'
)
# Example from OIDC Core 1.0 spec A.6
self.assertEqual(
self.key.oidc_hash(b'jHkWEdUXMU1BwAsC4vtUsZwnNvTIxEl0z9K3vx5KF0Y'),
'77QmUPtjPfzWtF2AnpK9RQ'
)
self.assertEqual(
self.key.oidc_hash(b'Qcb0Orv1zh30vL1MPRsbm-diHiMwcLyZvn1arpZv-Jxf_11jnpEX3Tgfvk'),
'LDktKdoQak3Pk0cnXxCltA'
)
def test_decode_jwt(self):
# Example from OIDC Core 1.0 spec A.2
jwt_data = (
'eyJraWQiOiIxZTlnZGs3IiwiYWxnIjoiUlMyNTYifQ.ewogImlz'
'cyI6ICJodHRwOi8vc2VydmVyLmV4YW1wbGUuY29tIiwKICJzdWIiOiAiMjQ4'
'Mjg5NzYxMDAxIiwKICJhdWQiOiAiczZCaGRSa3F0MyIsCiAibm9uY2UiOiAi'
'bi0wUzZfV3pBMk1qIiwKICJleHAiOiAxMzExMjgxOTcwLAogImlhdCI6IDEz'
'MTEyODA5NzAsCiAibmFtZSI6ICJKYW5lIERvZSIsCiAiZ2l2ZW5fbmFtZSI6'
'ICJKYW5lIiwKICJmYW1pbHlfbmFtZSI6ICJEb2UiLAogImdlbmRlciI6ICJm'
'ZW1hbGUiLAogImJpcnRoZGF0ZSI6ICIwMDAwLTEwLTMxIiwKICJlbWFpbCI6'
'ICJqYW5lZG9lQGV4YW1wbGUuY29tIiwKICJwaWN0dXJlIjogImh0dHA6Ly9l'
'eGFtcGxlLmNvbS9qYW5lZG9lL21lLmpwZyIKfQ.rHQjEmBqn9Jre0OLykYNn'
'spA10Qql2rvx4FsD00jwlB0Sym4NzpgvPKsDjn_wMkHxcp6CilPcoKrWHcip'
'R2iAjzLvDNAReF97zoJqq880ZD1bwY82JDauCXELVR9O6_B0w3K-E7yM2mac'
'AAgNCUwtik6SjoSUZRcf-O5lygIyLENx882p6MtmwaL1hd6qn5RZOQ0TLrOY'
'u0532g9Exxcm-ChymrB4xLykpDj3lUivJt63eEGGN6DH5K6o33TcxkIjNrCD'
'4XB1CKKumZvCedgHHF3IAK4dVEDSUoGlH9z4pP_eWYNXvqQOjGs-rDaQzUHl'
'6cQQWNiDpWOl_lxXjQEvQ'
)
self.assertEqual(
OAuth2Key.decode_jwt(jwt_data, options={'verify_exp': False, 'verify_aud': False}),
{
"iss": "http://server.example.com",
"sub": "248289761001",
"aud": "s6BhdRkqt3",
"nonce": "n-0S6_WzA2Mj",
"exp": 1311281970,
"iat": 1311280970,
"name": "Jane Doe",
"given_name": "Jane",
"family_name": "Doe",
"gender": "female",
"birthdate": "0000-10-31",
"email": "janedoe@example.com",
"picture": "http://example.com/janedoe/me.jpg"
}
)
with self.assertRaises(jwt.exceptions.InvalidKeyError):
# {"alg":"RS256"} -> no key id
OAuth2Key.decode_jwt('eyJhbGciOiJSUzI1NiJ9.' + jwt_data.split('.', 1)[-1])
with self.assertRaises(jwt.exceptions.InvalidKeyError):
# {"kid":"XXXXX","alg":"RS256"} -> unknown key id
OAuth2Key.decode_jwt('eyJraWQiOiJYWFhYWCIsImFsZyI6IlJTMjU2In0.' + jwt_data.split('.', 1)[-1])
OAuth2Key.query.get('1e9gdk7').active = False
with self.assertRaises(jwt.exceptions.InvalidKeyError):
# not active
OAuth2Key.decode_jwt(jwt_data)
def test_generate_rsa_key(self):
key = OAuth2Key.generate_rsa_key()
self.assertEqual(key.algorithm, 'RS256')
import unittest
from uffd.database import db
from uffd.models import User, Role, RoleGroup, TOTPMethod
from uffd.models.role import flatten_recursive
from tests.utils import UffdTestCase
class TestPrimitives(unittest.TestCase):
def test_flatten_recursive(self):
class Node:
def __init__(self, *neighbors):
self.neighbors = set(neighbors or set())
cycle = Node()
cycle.neighbors.add(cycle)
common = Node(cycle)
intermediate1 = Node(common)
intermediate2 = Node(common, intermediate1)
stub = Node()
backref = Node()
start1 = Node(intermediate1, intermediate2, stub, backref)
backref.neighbors.add(start1)
start2 = Node()
self.assertSetEqual(flatten_recursive({start1, start2}, 'neighbors'),
{start1, start2, backref, stub, intermediate1, intermediate2, common, cycle})
self.assertSetEqual(flatten_recursive(set(), 'neighbors'), set())
class TestUserRoleAttributes(UffdTestCase):
def test_roles_effective(self):
db.session.add(User(loginname='service', is_service_user=True, primary_email_address='service@example.com', displayname='Service'))
db.session.commit()
user = self.get_user()
service_user = User.query.filter_by(loginname='service').one_or_none()
included_by_default_role = Role(name='included_by_default')
default_role = Role(name='default', is_default=True, included_roles=[included_by_default_role])
included_role = Role(name='included')
cycle_role = Role(name='cycle')
direct_role1 = Role(name='role1', members=[user, service_user], included_roles=[included_role, cycle_role])
direct_role2 = Role(name='role2', members=[user, service_user], included_roles=[included_role])
cycle_role.included_roles.append(direct_role1)
db.session.add_all([included_by_default_role, default_role, included_role, cycle_role, direct_role1, direct_role2])
self.assertSetEqual(user.roles_effective, {direct_role1, direct_role2, cycle_role, included_role, default_role, included_by_default_role})
self.assertSetEqual(service_user.roles_effective, {direct_role1, direct_role2, cycle_role, included_role})
def test_compute_groups(self):
user = self.get_user()
group1 = self.get_users_group()
group2 = self.get_access_group()
role1 = Role(name='role1', groups={group1: RoleGroup(group=group1)})
role2 = Role(name='role2', groups={group1: RoleGroup(group=group1), group2: RoleGroup(group=group2)})
db.session.add_all([role1, role2])
self.assertSetEqual(user.compute_groups(), set())
role1.members.append(user)
role2.members.append(user)
self.assertSetEqual(user.compute_groups(), {group1, group2})
role2.groups[group2].requires_mfa = True
self.assertSetEqual(user.compute_groups(), {group1})
db.session.add(TOTPMethod(user=user))
db.session.commit()
self.assertSetEqual(user.compute_groups(), {group1, group2})
def test_update_groups(self):
user = self.get_user()
group1 = self.get_users_group()
group2 = self.get_access_group()
role1 = Role(name='role1', members=[user], groups={group1: RoleGroup(group=group1)})
role2 = Role(name='role2', groups={group2: RoleGroup(group=group2)})
db.session.add_all([role1, role2])
user.groups = [group2]
groups_added, groups_removed = user.update_groups()
self.assertSetEqual(groups_added, {group1})
self.assertSetEqual(groups_removed, {group2})
self.assertSetEqual(set(user.groups), {group1})
groups_added, groups_removed = user.update_groups()
self.assertSetEqual(groups_added, set())
self.assertSetEqual(groups_removed, set())
self.assertSetEqual(set(user.groups), {group1})
class TestRoleModel(UffdTestCase):
def test_members_effective(self):
db.session.add(User(loginname='service', is_service_user=True, primary_email_address='service@example.com', displayname='Service'))
db.session.commit()
user1 = self.get_user()
user2 = self.get_admin()
service = User.query.filter_by(loginname='service').one_or_none()
included_by_default_role = Role(name='included_by_default')
default_role = Role(name='default', is_default=True, included_roles=[included_by_default_role])
included_role = Role(name='included')
direct_role = Role(name='direct', members=[user1, user2, service], included_roles=[included_role])
empty_role = Role(name='empty', included_roles=[included_role])
self.assertSetEqual(included_by_default_role.members_effective, {user1, user2})
self.assertSetEqual(default_role.members_effective, {user1, user2})
self.assertSetEqual(included_role.members_effective, {user1, user2, service})
self.assertSetEqual(direct_role.members_effective, {user1, user2, service})
self.assertSetEqual(empty_role.members_effective, set())
def test_included_roles_recursive(self):
baserole = Role(name='base')
role1 = Role(name='role1', included_roles=[baserole])
role2 = Role(name='role2', included_roles=[baserole])
role3 = Role(name='role3', included_roles=[role1, role2])
self.assertSetEqual(role1.included_roles_recursive, {baserole})
self.assertSetEqual(role2.included_roles_recursive, {baserole})
self.assertSetEqual(role3.included_roles_recursive, {baserole, role1, role2})
baserole.included_roles.append(role1)
self.assertSetEqual(role3.included_roles_recursive, {baserole, role1, role2})
def test_groups_effective(self):
group1 = self.get_users_group()
group2 = self.get_access_group()
baserole = Role(name='base', groups={group1: RoleGroup(group=group1)})
role1 = Role(name='role1', groups={group2: RoleGroup(group=group2)}, included_roles=[baserole])
self.assertSetEqual(baserole.groups_effective, {group1})
self.assertSetEqual(role1.groups_effective, {group1, group2})
def test_update_member_groups(self):
user1 = self.get_user()
user1.update_groups()
user2 = self.get_admin()
user2.update_groups()
group1 = self.get_users_group()
group2 = self.get_access_group()
group3 = self.get_admin_group()
baserole = Role(name='base', members=[user1], groups={group1: RoleGroup(group=group1)})
role1 = Role(name='role1', members=[user2], groups={group2: RoleGroup(group=group2)}, included_roles=[baserole])
db.session.add_all([baserole, role1])
baserole.update_member_groups()
role1.update_member_groups()
self.assertSetEqual(set(user1.groups), {group1})
self.assertSetEqual(set(user2.groups), {group1, group2})
baserole.groups[group3] = RoleGroup()
baserole.update_member_groups()
self.assertSetEqual(set(user1.groups), {group1, group3})
self.assertSetEqual(set(user2.groups), {group1, group2, group3})
import itertools
from uffd.remailer import remailer
from uffd.tasks import cleanup_task
from uffd.database import db
from uffd.models import Service, ServiceUser, User, UserEmail, RemailerMode
from tests.utils import UffdTestCase
class TestServiceUser(UffdTestCase):
def setUp(self):
super().setUp()
db.session.add_all([Service(name='service1', limit_access=False), Service(name='service2', remailer_mode=RemailerMode.ENABLED_V1, limit_access=False)])
db.session.commit()
def test_auto_create(self):
service_count = Service.query.count()
user_count = User.query.count()
self.assertEqual(ServiceUser.query.count(), service_count * user_count)
db.session.add(User(loginname='newuser1', displayname='New User', primary_email_address='new1@example.com'))
db.session.commit()
self.assertEqual(ServiceUser.query.count(), service_count * (user_count + 1))
db.session.add(Service(name='service3'))
db.session.commit()
self.assertEqual(ServiceUser.query.count(), (service_count + 1) * (user_count + 1))
db.session.add(User(loginname='newuser2', displayname='New User', primary_email_address='new2@example.com'))
db.session.add(User(loginname='newuser3', displayname='New User', primary_email_address='new3@example.com'))
db.session.add(Service(name='service4'))
db.session.add(Service(name='service5'))
db.session.commit()
self.assertEqual(ServiceUser.query.count(), (service_count + 3) * (user_count + 3))
def test_create_missing(self):
service_count = Service.query.count()
user_count = User.query.count()
self.assertEqual(ServiceUser.query.count(), service_count * user_count)
db.session.delete(ServiceUser.query.first())
db.session.commit()
self.assertEqual(ServiceUser.query.count(), service_count * user_count - 1)
cleanup_task.run()
db.session.commit()
self.assertEqual(ServiceUser.query.count(), service_count * user_count)
def test_effective_remailer_mode(self):
self.app.config['REMAILER_DOMAIN'] = 'remailer.example.com'
user = self.get_user()
service = Service.query.filter_by(name='service1').first()
service.remailer_mode = RemailerMode.ENABLED_V2
service_user = ServiceUser.query.get((service.id, user.id))
self.assertEqual(service_user.effective_remailer_mode, RemailerMode.ENABLED_V2)
self.app.config['REMAILER_LIMIT_TO_USERS'] = ['testadmin']
self.assertEqual(service_user.effective_remailer_mode, RemailerMode.DISABLED)
self.app.config['REMAILER_LIMIT_TO_USERS'] = ['testuser']
self.assertEqual(service_user.effective_remailer_mode, RemailerMode.ENABLED_V2)
self.app.config['REMAILER_LIMIT_TO_USERS'] = None
service_user.remailer_overwrite_mode = RemailerMode.ENABLED_V1
service.remailer_mode = RemailerMode.DISABLED
self.assertEqual(service_user.effective_remailer_mode, RemailerMode.ENABLED_V1)
self.app.config['REMAILER_DOMAIN'] = ''
self.assertEqual(service_user.effective_remailer_mode, RemailerMode.DISABLED)
def test_service_email(self):
user = self.get_user()
service = Service.query.filter_by(name='service1').first()
service_user = ServiceUser.query.get((service.id, user.id))
self.assertEqual(service_user.service_email, None)
service_user.service_email = UserEmail(user=user, address='foo@bar', verified=True)
with self.assertRaises(Exception):
service_user.service_email = UserEmail(user=user, address='foo2@bar', verified=False)
with self.assertRaises(Exception):
service_user.service_email = UserEmail(user=self.get_admin(), address='foo3@bar', verified=True)
def test_real_email(self):
user = self.get_user()
service = Service.query.filter_by(name='service1').first()
service_user = ServiceUser.query.get((service.id, user.id))
self.assertEqual(service_user.real_email, user.primary_email.address)
service_user.service_email = UserEmail(user=user, address='foo@bar', verified=True)
self.assertEqual(service_user.real_email, user.primary_email.address)
service.enable_email_preferences = True
self.assertEqual(service_user.real_email, service_user.service_email.address)
service.limit_access = True
self.assertEqual(service_user.real_email, user.primary_email.address)
service.access_group = self.get_admin_group()
self.assertEqual(service_user.real_email, user.primary_email.address)
service.access_group = self.get_users_group()
self.assertEqual(service_user.real_email, service_user.service_email.address)
def test_get_by_remailer_email(self):
user = self.get_user()
service = Service.query.filter_by(name='service1').first()
service_user = ServiceUser.query.get((service.id, user.id))
self.app.config['REMAILER_DOMAIN'] = 'remailer.example.com'
remailer_email = remailer.build_v1_address(service.id, user.id)
# 1. remailer not setup
self.app.config['REMAILER_DOMAIN'] = ''
self.assertIsNone(ServiceUser.get_by_remailer_email(user.primary_email.address))
self.assertIsNone(ServiceUser.get_by_remailer_email(remailer_email))
self.assertIsNone(ServiceUser.get_by_remailer_email('invalid'))
# 2. remailer setup
self.app.config['REMAILER_DOMAIN'] = 'remailer.example.com'
self.assertIsNone(ServiceUser.get_by_remailer_email(user.primary_email.address))
self.assertEqual(ServiceUser.get_by_remailer_email(remailer_email), service_user)
self.assertIsNone(ServiceUser.get_by_remailer_email('invalid'))
def test_email(self):
user = self.get_user()
service = Service.query.filter_by(name='service1').first()
service_user = ServiceUser.query.get((service.id, user.id))
self.app.config['REMAILER_DOMAIN'] = 'remailer.example.com'
remailer_email = remailer.build_v1_address(service.id, user.id)
# 1. remailer not setup
self.app.config['REMAILER_DOMAIN'] = ''
self.assertEqual(service_user.email, user.primary_email.address)
# 2. remailer setup + remailer disabled
self.app.config['REMAILER_DOMAIN'] = 'remailer.example.com'
self.assertEqual(service_user.email, user.primary_email.address)
# 3. remailer setup + remailer enabled + REMAILER_LIMIT_TO_USERS unset
service.remailer_mode = RemailerMode.ENABLED_V1
db.session.commit()
self.assertEqual(service_user.email, remailer_email)
# 4. remailer setup + remailer enabled + REMAILER_LIMIT_TO_USERS does not include user
self.app.config['REMAILER_LIMIT_TO_USERS'] = ['testadmin']
self.assertEqual(service_user.email, user.primary_email.address)
# 5. remailer setup + remailer enabled + REMAILER_LIMIT_TO_USERS includes user
self.app.config['REMAILER_LIMIT_TO_USERS'] = ['testuser']
self.assertEqual(service_user.email, remailer_email)
# 6. remailer setup + remailer disabled + user overwrite
self.app.config['REMAILER_LIMIT_TO_USERS'] = None
service.remailer_mode = RemailerMode.DISABLED
service_user.remailer_overwrite_mode = RemailerMode.ENABLED_V1
self.assertEqual(service_user.email, remailer_email)
# 7. remailer setup + remailer enabled + user overwrite
self.app.config['REMAILER_LIMIT_TO_USERS'] = None
service.remailer_mode = RemailerMode.ENABLED_V1
service_user.remailer_overwrite_mode = RemailerMode.DISABLED
self.assertEqual(service_user.email, user.primary_email.address)
def test_filter_query_by_email(self):
service = Service.query.filter_by(name='service1').first()
user = self.get_user()
self.app.config['REMAILER_DOMAIN'] = 'remailer.example.com'
remailer_email_v1 = remailer.build_v1_address(service.id, user.id)
remailer_email_v2 = remailer.build_v2_address(service.id, user.id)
email1 = user.primary_email
email2 = UserEmail(user=user, address='test2@example.com', verified=True)
db.session.add(email2)
service_user = ServiceUser.query.get((service.id, user.id))
all_service_users = ServiceUser.query.all()
cases = itertools.product(
# Input values
[
'test@example.com',
'test2@example.com',
'other@example.com',
remailer_email_v1,
remailer_email_v2,
],
# REMAILER_DOMAIN config
[None, 'remailer.example.com'],
# REMAILER_LIMIT config
[None, ['testuser', 'otheruser'], ['testadmin', 'otheruser']],
# service.remailer_mode
[RemailerMode.DISABLED, RemailerMode.ENABLED_V1, RemailerMode.ENABLED_V2],
# service.enable_email_preferences
[True, False],
# service.limit_access, service.access_group
[(False, None), (True, None), (True, self.get_admin_group()), (True, self.get_users_group())],
# service_user.service_email
[None, email1, email2],
# service_user.remailer_overwrite_mode
[None, RemailerMode.DISABLED, RemailerMode.ENABLED_V1, RemailerMode.ENABLED_V2],
)
for options in cases:
value = options[0]
self.app.config['REMAILER_DOMAIN'] = options[1]
self.app.config['REMAILER_LIMIT_TO_USERS'] = options[2]
service.remailer_mode = options[3]
service.enable_email_preferences = options[4]
service.limit_access, service.access_group = options[5]
service_user.service_email = options[6]
service_user.remailer_overwrite_mode = options[7]
a = {result for result in all_service_users if result.email == value}
b = set(ServiceUser.filter_query_by_email(ServiceUser.query, value).all())
if a != b:
self.fail(f'{a} != {b} with ' + repr(options))
import unittest
import datetime
from uffd.database import db
from uffd.models.session import Session, USER_AGENT_PARSER_SUPPORTED
from tests.utils import UffdTestCase
class TestSession(UffdTestCase):
def test_expire(self):
self.app.config['SESSION_LIFETIME_SECONDS'] = 100
self.app.config['PERMANENT_SESSION_LIFETIME'] = 10
user = self.get_user()
def make_session(created_age, last_used_age):
return Session(
user=user,
created=datetime.datetime.utcnow() - datetime.timedelta(seconds=created_age),
last_used=datetime.datetime.utcnow() - datetime.timedelta(seconds=last_used_age),
)
session1 = Session(user=user)
self.assertFalse(session1.expired)
session2 = make_session(0, 0)
self.assertFalse(session2.expired)
session3 = make_session(50, 5)
self.assertFalse(session3.expired)
session4 = make_session(50, 15)
self.assertTrue(session4.expired)
session5 = make_session(105, 5)
self.assertTrue(session5.expired)
session6 = make_session(105, 15)
self.assertTrue(session6.expired)
db.session.add_all([session1, session2, session3, session4, session5, session6])
db.session.commit()
self.assertEqual(set(Session.query.filter_by(expired=False).all()), {session1, session2, session3})
self.assertEqual(set(Session.query.filter_by(expired=True).all()), {session4, session5, session6})
def test_useragent_ua_parser(self):
if not USER_AGENT_PARSER_SUPPORTED:
self.skipTest('ua_parser not available')
session = Session(user_agent='Mozilla/5.0 (Windows NT 6.1; Win64; x64; rv:47.0) Gecko/20100101 Firefox/47.0')
self.assertEqual(session.user_agent_browser, 'Firefox')
self.assertEqual(session.user_agent_platform, 'Windows')
def test_useragent_no_ua_parser(self):
session = Session(user_agent='Mozilla/5.0 (Windows NT 6.1; Win64; x64; rv:47.0) Gecko/20100101 Firefox/47.0')
session.DISABLE_USER_AGENT_PARSER = True
self.assertEqual(session.user_agent_browser, 'Firefox')
self.assertEqual(session.user_agent_platform, 'Windows')
import datetime
from uffd.database import db
from uffd.models import Signup, User, FeatureFlag
from tests.utils import UffdTestCase, db_flush
def refetch_signup(signup):
db.session.add(signup)
db.session.commit()
id = signup.id
db.session.expunge(signup)
return Signup.query.get(id)
# We assume in all tests that Signup.validate and Signup.password.verify do
# not alter any state
class TestSignupModel(UffdTestCase):
def assert_validate_valid(self, signup):
valid, msg = signup.validate()
self.assertTrue(valid)
self.assertIsInstance(msg, str)
def assert_validate_invalid(self, signup):
valid, msg = signup.validate()
self.assertFalse(valid)
self.assertIsInstance(msg, str)
self.assertNotEqual(msg, '')
def assert_finish_success(self, signup, password):
self.assertIsNone(signup.user)
user, msg = signup.finish(password)
db.session.commit()
self.assertIsNotNone(user)
self.assertIsInstance(msg, str)
self.assertIsNotNone(signup.user)
def assert_finish_failure(self, signup, password):
prev_id = signup.user_id
user, msg = signup.finish(password)
self.assertIsNone(user)
self.assertIsInstance(msg, str)
self.assertNotEqual(msg, '')
self.assertEqual(signup.user_id, prev_id)
def test_password(self):
signup = Signup(loginname='newuser', displayname='New User', mail='new@example.com')
self.assertFalse(signup.password.verify('notsecret'))
self.assertFalse(signup.password.verify(''))
self.assertFalse(signup.password.verify('wrongpassword'))
self.assertTrue(signup.set_password('notsecret'))
self.assertTrue(signup.password.verify('notsecret'))
self.assertFalse(signup.password.verify('wrongpassword'))
def test_expired(self):
# TODO: Find a better way to test this!
signup = Signup(loginname='newuser', displayname='New User', mail='new@example.com', password='notsecret')
self.assertFalse(signup.expired)
signup.created = created=datetime.datetime.utcnow() - datetime.timedelta(hours=49)
self.assertTrue(signup.expired)
def test_completed(self):
signup = Signup(loginname='newuser', displayname='New User', mail='new@example.com', password='notsecret')
self.assertFalse(signup.completed)
signup.finish('notsecret')
db.session.commit()
self.assertTrue(signup.completed)
signup = refetch_signup(signup)
self.assertTrue(signup.completed)
def test_validate(self):
signup = Signup(loginname='newuser', displayname='New User', mail='new@example.com', password='notsecret')
self.assert_validate_valid(signup)
self.assert_validate_valid(refetch_signup(signup))
def test_validate_completed(self):
signup = Signup(loginname='newuser', displayname='New User', mail='new@example.com', password='notsecret')
self.assert_finish_success(signup, 'notsecret')
self.assert_validate_invalid(signup)
self.assert_validate_invalid(refetch_signup(signup))
def test_validate_expired(self):
signup = Signup(loginname='newuser', displayname='New User', mail='new@example.com',
password='notsecret', created=datetime.datetime.utcnow()-datetime.timedelta(hours=49))
self.assert_validate_invalid(signup)
self.assert_validate_invalid(refetch_signup(signup))
def test_validate_loginname(self):
signup = Signup(loginname='', displayname='New User', mail='new@example.com', password='notsecret')
self.assert_validate_invalid(signup)
self.assert_validate_invalid(refetch_signup(signup))
def test_validate_displayname(self):
signup = Signup(loginname='newuser', displayname='', mail='new@example.com', password='notsecret')
self.assert_validate_invalid(signup)
self.assert_validate_invalid(refetch_signup(signup))
def test_validate_mail(self):
signup = Signup(loginname='newuser', displayname='New User', mail='', password='notsecret')
self.assert_validate_invalid(signup)
self.assert_validate_invalid(refetch_signup(signup))
def test_validate_password(self):
signup = Signup(loginname='newuser', displayname='New User', mail='new@example.com')
self.assertFalse(signup.set_password(''))
self.assert_validate_invalid(signup)
self.assert_validate_invalid(refetch_signup(signup))
def test_validate_exists(self):
signup = Signup(loginname='testuser', displayname='New User', mail='new@example.com', password='notsecret')
self.assert_validate_invalid(signup)
self.assert_validate_invalid(refetch_signup(signup))
def test_finish(self):
signup = Signup(loginname='newuser', displayname='New User', mail='new@example.com', password='notsecret')
self.assert_finish_success(signup, 'notsecret')
user = User.query.filter_by(loginname='newuser').one_or_none()
self.assertEqual(user.loginname, 'newuser')
self.assertEqual(user.displayname, 'New User')
self.assertEqual(user.primary_email.address, 'new@example.com')
def test_finish_completed(self):
signup = Signup(loginname='newuser', displayname='New User', mail='new@example.com', password='notsecret')
self.assert_finish_success(signup, 'notsecret')
self.assert_finish_failure(refetch_signup(signup), 'notsecret')
def test_finish_expired(self):
# TODO: Find a better way to test this!
signup = Signup(loginname='newuser', displayname='New User', mail='new@example.com',
password='notsecret', created=datetime.datetime.utcnow()-datetime.timedelta(hours=49))
self.assert_finish_failure(signup, 'notsecret')
self.assert_finish_failure(refetch_signup(signup), 'notsecret')
def test_finish_wrongpassword(self):
signup = Signup(loginname='newuser', displayname='New User', mail='new@example.com')
self.assert_finish_failure(signup, '')
self.assert_finish_failure(signup, 'wrongpassword')
signup = refetch_signup(signup)
self.assert_finish_failure(signup, '')
self.assert_finish_failure(signup, 'wrongpassword')
signup = Signup(loginname='newuser', displayname='New User', mail='new@example.com', password='notsecret')
self.assert_finish_failure(signup, 'wrongpassword')
self.assert_finish_failure(refetch_signup(signup), 'wrongpassword')
def test_finish_duplicate(self):
signup = Signup(loginname='testuser', displayname='New User', mail='new@example.com', password='notsecret')
self.assert_finish_failure(signup, 'notsecret')
self.assert_finish_failure(refetch_signup(signup), 'notsecret')
def test_finish_duplicate_email_strict_uniqueness(self):
FeatureFlag.unique_email_addresses.enable()
db.session.commit()
signup = Signup(loginname='newuser', displayname='New User', mail='test@example.com', password='notsecret')
self.assert_finish_failure(signup, 'notsecret')
def test_duplicate(self):
signup = Signup(loginname='newuser', displayname='New User', mail='test1@example.com', password='notsecret')
self.assert_validate_valid(signup)
db.session.add(signup)
db.session.commit()
signup1_id = signup.id
signup = Signup(loginname='newuser', displayname='New User', mail='test2@example.com', password='notsecret')
self.assert_validate_valid(signup)
db.session.add(signup)
db.session.commit()
signup2_id = signup.id
db_flush()
signup = Signup.query.get(signup2_id)
self.assert_finish_success(signup, 'notsecret')
db.session.commit()
db_flush()
signup = Signup.query.get(signup1_id)
self.assert_finish_failure(signup, 'notsecret')
user = User.query.filter_by(loginname='newuser').one_or_none()
self.assertEqual(user.primary_email.address, 'test2@example.com')
import datetime
import sqlalchemy
from uffd.database import db
from uffd.models import User, UserEmail, Group, FeatureFlag, IDAlreadyAllocatedError, IDRangeExhaustedError
from tests.utils import UffdTestCase, ModelTestCase
class TestUserModel(UffdTestCase):
def test_has_permission(self):
user_ = self.get_user() # has 'users' and 'uffd_access' group
admin = self.get_admin() # has 'users', 'uffd_access' and 'uffd_admin' group
self.assertTrue(user_.has_permission(None))
self.assertTrue(admin.has_permission(None))
self.assertTrue(user_.has_permission('users'))
self.assertTrue(admin.has_permission('users'))
self.assertFalse(user_.has_permission('notagroup'))
self.assertFalse(admin.has_permission('notagroup'))
self.assertFalse(user_.has_permission('uffd_admin'))
self.assertTrue(admin.has_permission('uffd_admin'))
self.assertFalse(user_.has_permission(['uffd_admin']))
self.assertTrue(admin.has_permission(['uffd_admin']))
self.assertFalse(user_.has_permission(['uffd_admin', 'notagroup']))
self.assertTrue(admin.has_permission(['uffd_admin', 'notagroup']))
self.assertFalse(user_.has_permission(['notagroup', 'uffd_admin']))
self.assertTrue(admin.has_permission(['notagroup', 'uffd_admin']))
self.assertTrue(user_.has_permission(['uffd_admin', 'users']))
self.assertTrue(admin.has_permission(['uffd_admin', 'users']))
self.assertTrue(user_.has_permission([['uffd_admin', 'users'], ['users', 'uffd_access']]))
self.assertTrue(admin.has_permission([['uffd_admin', 'users'], ['users', 'uffd_access']]))
self.assertFalse(user_.has_permission(['uffd_admin', ['users', 'notagroup']]))
self.assertTrue(admin.has_permission(['uffd_admin', ['users', 'notagroup']]))
def test_unix_uid_generation(self):
self.app.config['USER_MIN_UID'] = 10000
self.app.config['USER_MAX_UID'] = 18999
self.app.config['USER_SERVICE_MIN_UID'] = 19000
self.app.config['USER_SERVICE_MAX_UID'] = 19999
db.drop_all()
db.create_all()
user0 = User(loginname='user0', displayname='user0', primary_email_address='user0@example.com')
user1 = User(loginname='user1', displayname='user1', primary_email_address='user1@example.com')
user2 = User(loginname='user2', displayname='user2', primary_email_address='user2@example.com')
db.session.add_all([user0, user1, user2])
db.session.commit()
self.assertEqual(user0.unix_uid, 10000)
self.assertEqual(user1.unix_uid, 10001)
self.assertEqual(user2.unix_uid, 10002)
db.session.delete(user1)
db.session.commit()
user3 = User(loginname='user3', displayname='user3', primary_email_address='user3@example.com')
db.session.add(user3)
db.session.commit()
self.assertEqual(user3.unix_uid, 10003)
db.session.delete(user2)
db.session.commit()
user4 = User(loginname='user4', displayname='user4', primary_email_address='user4@example.com')
db.session.add(user4)
db.session.commit()
self.assertEqual(user4.unix_uid, 10004)
service0 = User(loginname='service0', displayname='service0', primary_email_address='service0@example.com', is_service_user=True)
service1 = User(loginname='service1', displayname='service1', primary_email_address='service1@example.com', is_service_user=True)
db.session.add_all([service0, service1])
db.session.commit()
self.assertEqual(service0.unix_uid, 19000)
self.assertEqual(service1.unix_uid, 19001)
def test_unix_uid_generation_overlapping(self):
self.app.config['USER_MIN_UID'] = 10000
self.app.config['USER_MAX_UID'] = 19999
self.app.config['USER_SERVICE_MIN_UID'] = 10000
self.app.config['USER_SERVICE_MAX_UID'] = 19999
db.drop_all()
db.create_all()
user0 = User(loginname='user0', displayname='user0', primary_email_address='user0@example.com')
service0 = User(loginname='service0', displayname='service0', primary_email_address='service0@example.com', is_service_user=True)
user1 = User(loginname='user1', displayname='user1', primary_email_address='user1@example.com')
db.session.add_all([user0, service0, user1])
db.session.commit()
self.assertEqual(user0.unix_uid, 10000)
self.assertEqual(service0.unix_uid, 10001)
self.assertEqual(user1.unix_uid, 10002)
def test_unix_uid_generation_overflow(self):
self.app.config['USER_MIN_UID'] = 10000
self.app.config['USER_MAX_UID'] = 10001
db.drop_all()
db.create_all()
user0 = User(loginname='user0', displayname='user0', primary_email_address='user0@example.com')
user1 = User(loginname='user1', displayname='user1', primary_email_address='user1@example.com')
db.session.add_all([user0, user1])
db.session.commit()
self.assertEqual(user0.unix_uid, 10000)
self.assertEqual(user1.unix_uid, 10001)
with self.assertRaises(sqlalchemy.exc.StatementError):
user2 = User(loginname='user2', displayname='user2', primary_email_address='user2@example.com')
db.session.add(user2)
db.session.commit()
def test_init_primary_email_address(self):
user = User(primary_email_address='foobar@example.com')
self.assertEqual(user.primary_email.address, 'foobar@example.com')
self.assertEqual(user.primary_email.verified, True)
self.assertEqual(user.primary_email.user, user)
user = User(primary_email_address='invalid')
self.assertEqual(user.primary_email.address, 'invalid')
self.assertEqual(user.primary_email.verified, True)
self.assertEqual(user.primary_email.user, user)
def test_set_primary_email_address(self):
user = User()
self.assertFalse(user.set_primary_email_address('invalid'))
self.assertIsNone(user.primary_email)
self.assertEqual(len(user.all_emails), 0)
self.assertTrue(user.set_primary_email_address('foobar@example.com'))
self.assertEqual(user.primary_email.address, 'foobar@example.com')
self.assertEqual(len(user.all_emails), 1)
self.assertFalse(user.set_primary_email_address('invalid'))
self.assertEqual(user.primary_email.address, 'foobar@example.com')
self.assertEqual(len(user.all_emails), 1)
self.assertTrue(user.set_primary_email_address('other@example.com'))
self.assertEqual(user.primary_email.address, 'other@example.com')
self.assertEqual(len(user.all_emails), 2)
self.assertEqual({user.all_emails[0].address, user.all_emails[1].address}, {'foobar@example.com', 'other@example.com'})
class TestUserEmailModel(UffdTestCase):
def test_normalize_address(self):
ref = UserEmail.normalize_address('foo@example.com')
self.assertEqual(ref, UserEmail.normalize_address('foo@example.com'))
self.assertEqual(ref, UserEmail.normalize_address('Foo@Example.Com'))
self.assertEqual(ref, UserEmail.normalize_address(' foo@example.com '))
self.assertNotEqual(ref, UserEmail.normalize_address('bar@example.com'))
self.assertNotEqual(ref, UserEmail.normalize_address('foo @example.com'))
# "No-Break Space" instead of SPACE (Unicode normalization + stripping)
self.assertEqual(ref, UserEmail.normalize_address('\u00A0foo@example.com '))
# Pre-composed "Angstrom Sign" vs. "A" + "Combining Ring Above" (Unicode normalization)
self.assertEqual(UserEmail.normalize_address('\u212B@example.com'), UserEmail.normalize_address('A\u030A@example.com'))
def test_address(self):
email = UserEmail()
self.assertIsNone(email.address)
self.assertIsNone(email.address_normalized)
email.address = 'Foo@example.com'
self.assertEqual(email.address, 'Foo@example.com')
self.assertEqual(email.address_normalized, UserEmail.normalize_address('Foo@example.com'))
with self.assertRaises(ValueError):
email.address = 'bar@example.com'
with self.assertRaises(ValueError):
email.address = None
def test_set_address(self):
email = UserEmail()
self.assertFalse(email.set_address('invalid'))
self.assertIsNone(email.address)
self.assertFalse(email.set_address(''))
self.assertFalse(email.set_address('@'))
self.app.config['REMAILER_DOMAIN'] = 'remailer.example.com'
self.assertFalse(email.set_address('foobar@remailer.example.com'))
self.assertFalse(email.set_address('v1-1-testuser@remailer.example.com'))
self.assertFalse(email.set_address('v1-1-testuser @ remailer.example.com'))
self.assertFalse(email.set_address('v1-1-testuser@REMAILER.example.com'))
self.assertFalse(email.set_address('v1-1-testuser@foobar@remailer.example.com'))
self.assertTrue(email.set_address('foobar@example.com'))
self.assertEqual(email.address, 'foobar@example.com')
def test_verified(self):
email = UserEmail(user=self.get_user(), address='foo@example.com')
db.session.add(email)
self.assertEqual(email.verified, False)
self.assertEqual(UserEmail.query.filter_by(address='foo@example.com', verified=True).count(), 0)
self.assertEqual(UserEmail.query.filter_by(address='foo@example.com', verified=False).count(), 1)
email.verified = True
self.assertEqual(email.verified, True)
self.assertEqual(UserEmail.query.filter_by(address='foo@example.com', verified=True).count(), 1)
self.assertEqual(UserEmail.query.filter_by(address='foo@example.com', verified=False).count(), 0)
with self.assertRaises(ValueError):
email.verified = False
self.assertEqual(email.verified, True)
with self.assertRaises(ValueError):
email.verified = None
self.assertEqual(email.verified, True)
def test_verification(self):
email = UserEmail(address='foo@example.com')
self.assertFalse(email.finish_verification('test'))
secret = email.start_verification()
self.assertTrue(email.verification_secret)
self.assertTrue(email.verification_secret.verify(secret))
self.assertFalse(email.verification_expired)
self.assertFalse(email.finish_verification('test'))
orig_expires = email.verification_expires
email.verification_expires = datetime.datetime.utcnow() - datetime.timedelta(days=1)
self.assertFalse(email.finish_verification(secret))
email.verification_expires = orig_expires
self.assertTrue(email.finish_verification(secret))
self.assertFalse(email.verification_secret)
self.assertTrue(email.verification_expired)
def test_enable_strict_constraints(self):
email = UserEmail(address='foo@example.com', user=self.get_user())
db.session.add(email)
db.session.commit()
self.assertIsNone(email.enable_strict_constraints)
FeatureFlag.unique_email_addresses.enable()
self.assertTrue(email.enable_strict_constraints)
FeatureFlag.unique_email_addresses.disable()
self.assertIsNone(email.enable_strict_constraints)
def assert_can_add_address(self, **kwargs):
user_email = UserEmail(**kwargs)
db.session.add(user_email)
db.session.commit()
db.session.delete(user_email)
db.session.commit()
def assert_cannot_add_address(self, **kwargs):
with self.assertRaises(sqlalchemy.exc.IntegrityError):
db.session.add(UserEmail(**kwargs))
db.session.commit()
db.session.rollback()
def test_unique_constraints_old(self):
# The same user cannot add the same exact address multiple times, but
# different users can have the same address
user = self.get_user()
admin = self.get_admin()
db.session.add(UserEmail(user=user, address='foo@example.com'))
db.session.add(UserEmail(user=user, address='bar@example.com', verified=True))
db.session.commit()
self.assert_can_add_address(user=user, address='foobar@example.com')
self.assert_can_add_address(user=user, address='foobar@example.com', verified=True)
self.assert_cannot_add_address(user=user, address='foo@example.com')
self.assert_can_add_address(user=user, address='FOO@example.com')
self.assert_cannot_add_address(user=user, address='bar@example.com')
self.assert_can_add_address(user=user, address='BAR@example.com')
self.assert_cannot_add_address(user=user, address='foo@example.com', verified=True)
self.assert_can_add_address(user=user, address='FOO@example.com', verified=True)
self.assert_cannot_add_address(user=user, address='bar@example.com', verified=True)
self.assert_can_add_address(user=user, address='BAR@example.com', verified=True)
self.assert_can_add_address(user=admin, address='foobar@example.com')
self.assert_can_add_address(user=admin, address='foobar@example.com', verified=True)
self.assert_can_add_address(user=admin, address='foo@example.com')
self.assert_can_add_address(user=admin, address='FOO@example.com')
self.assert_can_add_address(user=admin, address='bar@example.com')
self.assert_can_add_address(user=admin, address='BAR@example.com')
self.assert_can_add_address(user=admin, address='foo@example.com', verified=True)
self.assert_can_add_address(user=admin, address='FOO@example.com', verified=True)
self.assert_can_add_address(user=admin, address='bar@example.com', verified=True)
self.assert_can_add_address(user=admin, address='BAR@example.com', verified=True)
def test_unique_constraints_strict(self):
FeatureFlag.unique_email_addresses.enable()
# The same user cannot add the same (normalized) address multiple times,
# and different users cannot have the same verified (normalized) address
user = self.get_user()
admin = self.get_admin()
db.session.add(UserEmail(user=user, address='foo@example.com'))
db.session.add(UserEmail(user=user, address='bar@example.com', verified=True))
db.session.commit()
self.assert_can_add_address(user=user, address='foobar@example.com')
self.assert_can_add_address(user=user, address='foobar@example.com', verified=True)
self.assert_cannot_add_address(user=user, address='foo@example.com')
self.assert_cannot_add_address(user=user, address='FOO@example.com')
self.assert_cannot_add_address(user=user, address='bar@example.com')
self.assert_cannot_add_address(user=user, address='BAR@example.com')
self.assert_cannot_add_address(user=user, address='foo@example.com', verified=True)
self.assert_cannot_add_address(user=user, address='FOO@example.com', verified=True)
self.assert_cannot_add_address(user=user, address='bar@example.com', verified=True)
self.assert_cannot_add_address(user=user, address='BAR@example.com', verified=True)
self.assert_can_add_address(user=admin, address='foobar@example.com')
self.assert_can_add_address(user=admin, address='foobar@example.com', verified=True)
self.assert_can_add_address(user=admin, address='foo@example.com')
self.assert_can_add_address(user=admin, address='FOO@example.com')
self.assert_can_add_address(user=admin, address='bar@example.com')
self.assert_can_add_address(user=admin, address='BAR@example.com')
self.assert_can_add_address(user=admin, address='foo@example.com', verified=True)
self.assert_can_add_address(user=admin, address='FOO@example.com', verified=True)
self.assert_cannot_add_address(user=admin, address='bar@example.com', verified=True)
self.assert_cannot_add_address(user=admin, address='BAR@example.com', verified=True)
class TestIDAllocator(ModelTestCase):
def allocate_gids(self, *gids):
for gid in gids:
Group.unix_gid_allocator.allocate(gid)
def fetch_gid_allocations(self):
return [row[0] for row in db.session.execute(
db.select([Group.unix_gid_allocator.allocation_table])
.order_by(Group.unix_gid_allocator.allocation_table.c.id)
).fetchall()]
def test_empty(self):
self.assertEqual(Group.unix_gid_allocator.auto(20000, 20005), 20000)
self.assertEqual(self.fetch_gid_allocations(), [20000])
def test_first(self):
self.allocate_gids(20000)
self.assertEqual(Group.unix_gid_allocator.auto(20000, 20005), 20001)
self.assertEqual(self.fetch_gid_allocations(), [20000, 20001])
def test_out_of_range_before(self):
self.allocate_gids(19998)
self.assertEqual(Group.unix_gid_allocator.auto(20000, 20005), 20000)
self.assertEqual(self.fetch_gid_allocations(), [19998, 20000])
def test_out_of_range_right_before(self):
self.allocate_gids(19999)
self.assertEqual(Group.unix_gid_allocator.auto(20000, 20005), 20000)
self.assertEqual(self.fetch_gid_allocations(), [19999, 20000])
def test_out_of_range_after(self):
self.allocate_gids(20006)
self.assertEqual(Group.unix_gid_allocator.auto(20000, 20005), 20000)
self.assertEqual(self.fetch_gid_allocations(), [20000, 20006])
def test_gap_at_beginning(self):
self.allocate_gids(20001)
self.assertEqual(Group.unix_gid_allocator.auto(20000, 20005), 20000)
self.assertEqual(self.fetch_gid_allocations(), [20000, 20001])
def test_multiple_gaps(self):
self.allocate_gids(20000, 20001, 20003, 20005)
self.assertEqual(Group.unix_gid_allocator.auto(20000, 20005), 20002)
self.assertEqual(self.fetch_gid_allocations(), [20000, 20001, 20002, 20003, 20005])
self.assertEqual(Group.unix_gid_allocator.auto(20000, 20005), 20004)
self.assertEqual(self.fetch_gid_allocations(), [20000, 20001, 20002, 20003, 20004, 20005])
def test_last(self):
self.allocate_gids(20000, 20001, 20002, 20003, 20004)
self.assertEqual(Group.unix_gid_allocator.auto(20000, 20005), 20005)
self.assertEqual(self.fetch_gid_allocations(), [20000, 20001, 20002, 20003, 20004, 20005])
def test_overflow(self):
self.allocate_gids(20000, 20001, 20002, 20003, 20004, 20005)
with self.assertRaises(IDRangeExhaustedError):
Group.unix_gid_allocator.auto(20000, 20005)
self.assertEqual(self.fetch_gid_allocations(), [20000, 20001, 20002, 20003, 20004, 20005])
def test_conflict(self):
self.allocate_gids(20000)
with self.assertRaises(IDAlreadyAllocatedError):
self.allocate_gids(20000)
self.assertEqual(self.fetch_gid_allocations(), [20000])
class TestGroup(ModelTestCase):
def test_unix_gid_generation(self):
self.app.config['GROUP_MIN_GID'] = 20000
self.app.config['GROUP_MAX_GID'] = 49999
group0 = Group(name='group0', description='group0')
group1 = Group(name='group1', description='group1')
group2 = Group(name='group2', description='group2')
group3 = Group(name='group3', description='group3', unix_gid=20004)
db.session.add_all([group0, group1, group2, group3])
db.session.commit()
self.assertEqual(group0.unix_gid, 20000)
self.assertEqual(group1.unix_gid, 20001)
self.assertEqual(group2.unix_gid, 20002)
self.assertEqual(group3.unix_gid, 20004)
db.session.delete(group2)
db.session.commit()
group4 = Group(name='group4', description='group4')
group5 = Group(name='group5', description='group5')
db.session.add_all([group4, group5])
db.session.commit()
self.assertEqual(group4.unix_gid, 20003)
self.assertEqual(group5.unix_gid, 20005)
def test_unix_gid_generation_conflict(self):
self.app.config['GROUP_MIN_GID'] = 20000
self.app.config['GROUP_MAX_GID'] = 49999
group0 = Group(name='group0', description='group0', unix_gid=20023)
db.session.add(group0)
db.session.commit()
with self.assertRaises(IDAlreadyAllocatedError):
Group(name='group1', description='group1', unix_gid=20023)
def test_unix_gid_generation_overflow(self):
self.app.config['GROUP_MIN_GID'] = 20000
self.app.config['GROUP_MAX_GID'] = 20001
group0 = Group(name='group0', description='group0')
group1 = Group(name='group1', description='group1')
db.session.add_all([group0, group1])
db.session.commit()
self.assertEqual(group0.unix_gid, 20000)
self.assertEqual(group1.unix_gid, 20001)
db.session.commit()
with self.assertRaises(sqlalchemy.exc.StatementError):
group2 = Group(name='group2', description='group2')
db.session.add(group2)
db.session.commit()
import unittest
from flask import Flask, Blueprint, session, url_for
from uffd.csrf import bp as csrf_bp, csrf_protect
uid_counter = 0
class TestCSRF(unittest.TestCase):
unprotected_ep = 'foo'
protected_ep = 'bar'
def setUp(self):
self.app = Flask(__name__)
self.app.testing = True
self.app.config['SECRET_KEY'] = 'DEBUGKEY'
self.app.register_blueprint(csrf_bp)
@self.app.route('/', methods=['GET', 'POST'])
def index():
return 'SUCCESS', 200
@self.app.route('/login', methods=['GET', 'POST'])
def login():
global uid_counter
session['_csrf_token'] = 'secret_csrf_token%d'%uid_counter
uid_counter += 1
return 'Ok', 200
@self.app.route('/logout', methods=['GET', 'POST'])
def logout():
session.clear()
return 'Ok', 200
@self.app.route('/foo', methods=['GET', 'POST'])
def foo():
return 'SUCCESS', 200
@self.app.route('/bar', methods=['GET', 'POST'])
@csrf_protect()
def bar():
return 'SUCCESS', 200
self.bp = Blueprint('bp', __name__)
@self.bp.route('/foo', methods=['GET', 'POST'])
@csrf_protect(blueprint=self.bp) # This time on .foo and not on .bar!
def foo():
return 'SUCCESS', 200
@self.bp.route('/bar', methods=['GET', 'POST'])
def bar():
return 'SUCCESS', 200
self.app.register_blueprint(self.bp, url_prefix='/bp/')
self.client = self.app.test_client()
self.client.__enter__()
# Just do some request so that we can use url_for
self.client.get(path='/')
def tearDown(self):
self.client.__exit__(None, None, None)
def set_token(self):
self.client.get(path='/login')
def clear_token(self):
self.client.get(path='/logout')
def test_notoken_unprotected(self):
url = url_for(self.unprotected_ep)
self.assertTrue('csrf' not in url)
self.assertEqual(self.client.get(path=url).data, b'SUCCESS')
def test_token_unprotected(self):
self.set_token()
self.test_notoken_unprotected()
def test_notoken_protected(self):
url = url_for(self.protected_ep)
self.assertNotEqual(self.client.get(path=url).data, b'SUCCESS')
def test_token_protected(self):
self.set_token()
url = url_for(self.protected_ep)
self.assertEqual(self.client.get(path=url).data, b'SUCCESS')
def test_wrong_token_protected(self):
self.set_token()
url = url_for(self.protected_ep)
self.set_token()
self.assertNotEqual(self.client.get(path=url).data, b'SUCCESS')
def test_deleted_token_protected(self):
self.set_token()
url = url_for(self.protected_ep)
self.clear_token()
self.assertNotEqual(self.client.get(path=url).data, b'SUCCESS')
class TestBlueprintCSRF(TestCSRF):
unprotected_ep = 'bp.bar'
protected_ep = 'bp.foo'
import unittest
from uffd.password_hash import *
class TestPasswordHashRegistry(unittest.TestCase):
def test(self):
registry = PasswordHashRegistry()
@registry.register
class TestPasswordHash:
METHOD_NAME = 'test'
def __init__(self, value, **kwargs):
self.value = value
self.kwargs = kwargs
@registry.register
class Test2PasswordHash:
METHOD_NAME = 'test2'
result = registry.parse('{test}data', key='value')
self.assertIsInstance(result, TestPasswordHash)
self.assertEqual(result.value, '{test}data')
self.assertEqual(result.kwargs, {'key': 'value'})
with self.assertRaises(ValueError):
registry.parse('{invalid}data')
with self.assertRaises(ValueError):
registry.parse('invalid')
with self.assertRaises(ValueError):
registry.parse('{invalid')
class TestPasswordHash(unittest.TestCase):
def setUp(self):
class TestPasswordHash(PasswordHash):
@classmethod
def from_password(cls, password):
cls(build_value(cls.METHOD_NAME, password))
def verify(self, password):
return self.data == password
class TestPasswordHash1(TestPasswordHash):
METHOD_NAME = 'test1'
class TestPasswordHash2(TestPasswordHash):
METHOD_NAME = 'test2'
self.TestPasswordHash1 = TestPasswordHash1
self.TestPasswordHash2 = TestPasswordHash2
def test(self):
obj = self.TestPasswordHash1('{test1}data')
self.assertEqual(obj.value, '{test1}data')
self.assertEqual(obj.data, 'data')
self.assertIs(obj.target_cls, self.TestPasswordHash1)
self.assertFalse(obj.needs_rehash)
def test_invalid(self):
with self.assertRaises(ValueError):
self.TestPasswordHash1('invalid')
with self.assertRaises(ValueError):
self.TestPasswordHash1('{invalid}data')
with self.assertRaises(ValueError):
self.TestPasswordHash1('{test2}data')
def test_target_cls(self):
obj = self.TestPasswordHash1('{test1}data', target_cls=self.TestPasswordHash1)
self.assertEqual(obj.value, '{test1}data')
self.assertEqual(obj.data, 'data')
self.assertIs(obj.target_cls, self.TestPasswordHash1)
self.assertFalse(obj.needs_rehash)
obj = self.TestPasswordHash1('{test1}data', target_cls=self.TestPasswordHash2)
self.assertEqual(obj.value, '{test1}data')
self.assertEqual(obj.data, 'data')
self.assertIs(obj.target_cls, self.TestPasswordHash2)
self.assertTrue(obj.needs_rehash)
obj = self.TestPasswordHash1('{test1}data', target_cls=PasswordHash)
self.assertEqual(obj.value, '{test1}data')
self.assertEqual(obj.data, 'data')
self.assertIs(obj.target_cls, PasswordHash)
self.assertFalse(obj.needs_rehash)
class TestPlaintextPasswordHash(unittest.TestCase):
def test_verify(self):
obj = PlaintextPasswordHash('{plain}password')
self.assertTrue(obj.verify('password'))
self.assertFalse(obj.verify('notpassword'))
def test_from_password(self):
obj = PlaintextPasswordHash.from_password('password')
self.assertEqual(obj.value, '{plain}password')
self.assertTrue(obj.verify('password'))
self.assertFalse(obj.verify('notpassword'))
class TestHashlibPasswordHash(unittest.TestCase):
def test_verify(self):
obj = SHA512PasswordHash('{sha512}sQnzu7wkTrgkQZF+0G1hi5AI3Qmzvv0bXgc5THBqi7mAsdd4Xll27ASbRt9fEyavWi6m0QP9B8lThf+rDKy8hg==')
self.assertTrue(obj.verify('password'))
self.assertFalse(obj.verify('notpassword'))
def test_from_password(self):
obj = SHA512PasswordHash.from_password('password')
self.assertIsNotNone(obj.value)
self.assertTrue(obj.value.startswith('{sha512}'))
self.assertTrue(obj.verify('password'))
self.assertFalse(obj.verify('notpassword'))
class TestSaltedHashlibPasswordHash(unittest.TestCase):
def test_verify(self):
obj = SaltedSHA512PasswordHash('{ssha512}dOeDLmVpHJThhHeag10Hm2g4T7s3SBE6rGHcXUolXJHVufY4qT782rwZ/0XE6cuLcBZ0KpnwmUzRpAEtZBdv+JYEEtZQs/uC')
self.assertTrue(obj.verify('password'))
self.assertFalse(obj.verify('notpassword'))
def test_from_password(self):
obj = SaltedSHA512PasswordHash.from_password('password')
self.assertIsNotNone(obj.value)
self.assertTrue(obj.value.startswith('{ssha512}'))
self.assertTrue(obj.verify('password'))
self.assertFalse(obj.verify('notpassword'))
class TestCryptPasswordHash(unittest.TestCase):
def test_verify(self):
obj = CryptPasswordHash('{crypt}$5$UbTTMBH9NRurlQcX$bUiUTyedvmArlVt.62ZLRV80e2v3DjcBp/tSDkP2imD')
self.assertTrue(obj.verify('password'))
self.assertFalse(obj.verify('notpassword'))
def test_from_password(self):
obj = CryptPasswordHash.from_password('password')
self.assertIsNotNone(obj.value)
self.assertTrue(obj.value.startswith('{crypt}'))
self.assertTrue(obj.verify('password'))
self.assertFalse(obj.verify('notpassword'))
class TestArgon2PasswordHash(unittest.TestCase):
def test_verify(self):
obj = Argon2PasswordHash('{argon2}$argon2id$v=19$m=102400,t=2,p=8$Jc8LpCgPLjwlN/7efHLvwQ$ZqSg3CFb2/hBb3X8hOq4aw')
self.assertTrue(obj.verify('password'))
self.assertFalse(obj.verify('notpassword'))
obj = Argon2PasswordHash('{argon2}$invalid$')
self.assertFalse(obj.verify('password'))
def test_from_password(self):
obj = Argon2PasswordHash.from_password('password')
self.assertIsNotNone(obj.value)
self.assertTrue(obj.value.startswith('{argon2}'))
self.assertTrue(obj.verify('password'))
self.assertFalse(obj.verify('notpassword'))
def test_needs_rehash(self):
obj = Argon2PasswordHash('{argon2}$argon2id$v=19$m=102400,t=2,p=8$Jc8LpCgPLjwlN/7efHLvwQ$ZqSg3CFb2/hBb3X8hOq4aw')
self.assertFalse(obj.needs_rehash)
obj = Argon2PasswordHash('{argon2}$argon2id$v=19$m=102400,t=2,p=8$Jc8LpCgPLjwlN/7efHLvwQ$ZqSg3CFb2/hBb3X8hOq4aw', target_cls=PlaintextPasswordHash)
self.assertTrue(obj.needs_rehash)
obj = Argon2PasswordHash('{argon2}$argon2d$v=19$m=102400,t=2,p=8$kshPgLU1+h72l/Z8QWh8Ig$tYerKCe/5I2BCPKu8hCl2w')
self.assertTrue(obj.needs_rehash)
obj = Argon2PasswordHash('{argon2}$argon2id$v=19$m=102400,t=1,p=8$aa6i4vg/szKX5xHVGFaAeQ$v6j0ltuVqQaZlmuepaVJ1A')
self.assertTrue(obj.needs_rehash)
class TestInvalidPasswordHash(unittest.TestCase):
def test(self):
obj = InvalidPasswordHash('test')
self.assertEqual(obj.value, 'test')
self.assertFalse(obj.verify('test'))
self.assertTrue(obj.needs_rehash)
self.assertFalse(obj)
obj = InvalidPasswordHash(None)
self.assertIsNone(obj.value)
self.assertFalse(obj.verify('test'))
self.assertTrue(obj.needs_rehash)
self.assertFalse(obj)
class TestPasswordWrapper(unittest.TestCase):
def setUp(self):
class Test:
password_hash = None
password = PasswordHashAttribute('password_hash', PlaintextPasswordHash)
self.test = Test()
def test_get_none(self):
self.test.password_hash = None
obj = self.test.password
self.assertIsInstance(obj, InvalidPasswordHash)
self.assertEqual(obj.value, None)
self.assertTrue(obj.needs_rehash)
def test_get_valid(self):
self.test.password_hash = '{plain}password'
obj = self.test.password
self.assertIsInstance(obj, PlaintextPasswordHash)
self.assertEqual(obj.value, '{plain}password')
self.assertFalse(obj.needs_rehash)
def test_get_needs_rehash(self):
self.test.password_hash = '{ssha512}dOeDLmVpHJThhHeag10Hm2g4T7s3SBE6rGHcXUolXJHVufY4qT782rwZ/0XE6cuLcBZ0KpnwmUzRpAEtZBdv+JYEEtZQs/uC'
obj = self.test.password
self.assertIsInstance(obj, SaltedSHA512PasswordHash)
self.assertEqual(obj.value, '{ssha512}dOeDLmVpHJThhHeag10Hm2g4T7s3SBE6rGHcXUolXJHVufY4qT782rwZ/0XE6cuLcBZ0KpnwmUzRpAEtZBdv+JYEEtZQs/uC')
self.assertTrue(obj.needs_rehash)
def test_set(self):
self.test.password = 'password'
self.assertEqual(self.test.password_hash, '{plain}password')
def test_set_none(self):
self.test.password = None
self.assertIsNone(self.test.password_hash)
from uffd.models.ratelimit import get_addrkey, format_delay, Ratelimit
from tests.utils import UffdTestCase
class TestRatelimit(UffdTestCase):
def test_limiting(self):
cases = [
(1*60, 3),
(1*60*60, 3),
(1*60*60, 25),
]
for index, case in enumerate(cases):
interval, limit = case
key = str(index)
ratelimit = Ratelimit('test', interval, limit)
for i in range(limit):
ratelimit.log(key)
self.assertLessEqual(ratelimit.get_delay(key), interval)
ratelimit.log(key)
self.assertGreater(ratelimit.get_delay(key), interval)
def test_addrkey(self):
self.assertEqual(get_addrkey('192.168.0.1'), get_addrkey('192.168.0.99'))
self.assertNotEqual(get_addrkey('192.168.0.1'), get_addrkey('192.168.1.1'))
self.assertEqual(get_addrkey('fdee:707a:f38a:c369::'), get_addrkey('fdee:707a:f38a:ffff::'))
self.assertNotEqual(get_addrkey('fdee:707a:f38a:c369::'), get_addrkey('fdee:707a:f38b:c369::'))
cases = [
'',
'192.168.0.',
':',
'::',
'192.168.0.1/24',
'192.168.0.1/24',
'host.example.com',
]
for case in cases:
self.assertIsInstance(get_addrkey(case), str)
def test_format_delay(self):
self.assertIsInstance(format_delay(0), str)
self.assertIsInstance(format_delay(1), str)
self.assertIsInstance(format_delay(30), str)
self.assertIsInstance(format_delay(60), str)
self.assertIsInstance(format_delay(120), str)
self.assertIsInstance(format_delay(3600), str)
self.assertIsInstance(format_delay(4000), str)
from uffd.remailer import remailer
from tests.utils import UffdTestCase
USER_ID = 1234
SERVICE1_ID = 4223
SERVICE2_ID = 3242
ADDR_V1_S1 = 'v1-WzQyMjMsMTIzNF0.MeO6bHGTgIyPvvq2r3xriokLMCU@remailer.example.com'
ADDR_V1_S2 = 'v1-WzMyNDIsMTIzNF0.p2a_RkJc0oHBc9u4_S8G9METflA@remailer.example.com'
ADDR_V2_S1 = 'v2-lm2demrtfqytemzulu-ghr3u3drsoaizd567k3k67dlrkeqwmbf@remailer.example.com'
ADDR_V2_S2 = 'v2-lmztenbsfqytemzulu-u5tl6rscltjidqlt3o4p2lyg6targ7sq@remailer.example.com'
class TestRemailer(UffdTestCase):
def test_is_remailer_domain(self):
self.app.config['REMAILER_DOMAIN'] = 'remailer.example.com'
self.assertTrue(remailer.is_remailer_domain('remailer.example.com'))
self.assertTrue(remailer.is_remailer_domain('REMAILER.EXAMPLE.COM'))
self.assertTrue(remailer.is_remailer_domain(' remailer.example.com '))
self.assertFalse(remailer.is_remailer_domain('other.remailer.example.com'))
self.assertFalse(remailer.is_remailer_domain('example.com'))
self.app.config['REMAILER_OLD_DOMAINS'] = [' OTHER.remailer.example.com ']
self.assertTrue(remailer.is_remailer_domain(' OTHER.remailer.example.com '))
self.assertTrue(remailer.is_remailer_domain('remailer.example.com'))
self.assertTrue(remailer.is_remailer_domain('other.remailer.example.com'))
self.assertFalse(remailer.is_remailer_domain('example.com'))
def test_build_v1_address(self):
self.app.config['REMAILER_DOMAIN'] = 'remailer.example.com'
self.assertEqual(remailer.build_v1_address(SERVICE1_ID, USER_ID), ADDR_V1_S1)
self.assertEqual(remailer.build_v1_address(SERVICE2_ID, USER_ID), ADDR_V1_S2)
long_addr = remailer.build_v1_address(1000, 1000000)
self.assertLessEqual(len(long_addr.split('@')[0]), 64)
self.assertLessEqual(len(long_addr), 256)
self.app.config['REMAILER_OLD_DOMAINS'] = ['old.remailer.example.com']
self.assertEqual(remailer.build_v1_address(SERVICE1_ID, USER_ID), ADDR_V1_S1)
self.app.config['REMAILER_SECRET_KEY'] = self.app.config['SECRET_KEY']
self.assertEqual(remailer.build_v1_address(SERVICE1_ID, USER_ID), ADDR_V1_S1)
self.app.config['REMAILER_SECRET_KEY'] = 'REMAILER-DEBUGKEY'
self.assertNotEqual(remailer.build_v1_address(SERVICE1_ID, USER_ID), ADDR_V1_S1)
def test_build_v2_address(self):
self.app.config['REMAILER_DOMAIN'] = 'remailer.example.com'
self.assertEqual(remailer.build_v2_address(SERVICE1_ID, USER_ID), ADDR_V2_S1)
self.assertEqual(remailer.build_v2_address(SERVICE2_ID, USER_ID), ADDR_V2_S2)
long_addr = remailer.build_v2_address(1000, 1000000)
self.assertLessEqual(len(long_addr.split('@')[0]), 64)
self.assertLessEqual(len(long_addr), 256)
self.app.config['REMAILER_OLD_DOMAINS'] = ['old.remailer.example.com']
self.assertEqual(remailer.build_v2_address(SERVICE1_ID, USER_ID), ADDR_V2_S1)
self.app.config['REMAILER_SECRET_KEY'] = self.app.config['SECRET_KEY']
self.assertEqual(remailer.build_v2_address(SERVICE1_ID, USER_ID), ADDR_V2_S1)
self.app.config['REMAILER_SECRET_KEY'] = 'REMAILER-DEBUGKEY'
self.assertNotEqual(remailer.build_v2_address(SERVICE1_ID, USER_ID), ADDR_V2_S1)
def test_parse_address(self):
# REMAILER_DOMAIN behaviour
self.app.config['REMAILER_DOMAIN'] = None
self.assertIsNone(remailer.parse_address(ADDR_V1_S2))
self.assertIsNone(remailer.parse_address(ADDR_V2_S2))
self.assertIsNone(remailer.parse_address('foo@example.com'))
self.app.config['REMAILER_DOMAIN'] = 'remailer.example.com'
self.assertEqual(remailer.parse_address(ADDR_V1_S2), (SERVICE2_ID, USER_ID))
self.assertEqual(remailer.parse_address(ADDR_V2_S2), (SERVICE2_ID, USER_ID))
self.assertIsNone(remailer.parse_address('foo@example.com'))
self.assertIsNone(remailer.parse_address('foo@remailer.example.com'))
self.assertIsNone(remailer.parse_address('v1-foo@remailer.example.com'))
self.assertIsNone(remailer.parse_address('v2-foo@remailer.example.com'))
self.assertIsNone(remailer.parse_address('v2-foo-bar@remailer.example.com'))
self.app.config['REMAILER_DOMAIN'] = 'new-remailer.example.com'
self.assertIsNone(remailer.parse_address(ADDR_V1_S2))
self.assertIsNone(remailer.parse_address(ADDR_V2_S2))
self.app.config['REMAILER_OLD_DOMAINS'] = ['remailer.example.com']
self.assertEqual(remailer.parse_address(ADDR_V1_S2), (SERVICE2_ID, USER_ID))
self.assertEqual(remailer.parse_address(ADDR_V2_S2), (SERVICE2_ID, USER_ID))
# REMAILER_SECRET_KEY behaviour
self.app.config['REMAILER_DOMAIN'] = 'remailer.example.com'
self.app.config['REMAILER_OLD_DOMAINS'] = []
self.assertEqual(remailer.parse_address(ADDR_V1_S2), (SERVICE2_ID, USER_ID))
self.assertEqual(remailer.parse_address(ADDR_V2_S2), (SERVICE2_ID, USER_ID))
self.app.config['REMAILER_SECRET_KEY'] = self.app.config['SECRET_KEY']
self.assertEqual(remailer.parse_address(ADDR_V1_S2), (SERVICE2_ID, USER_ID))
self.assertEqual(remailer.parse_address(ADDR_V2_S2), (SERVICE2_ID, USER_ID))
self.app.config['REMAILER_SECRET_KEY'] = 'REMAILER-DEBUGKEY'
self.assertIsNone(remailer.parse_address(ADDR_V1_S2))
self.assertIsNone(remailer.parse_address(ADDR_V2_S2))
import unittest
from flask import Flask
from flask_sqlalchemy import SQLAlchemy
from uffd.tasks import CleanupTask
class TestCleanupTask(unittest.TestCase):
def test(self):
app = Flask(__name__)
app.testing = True
app.debug = True
app.config['SQLALCHEMY_DATABASE_URI'] = 'sqlite:///:memory:'
db = SQLAlchemy(app)
cleanup_task = CleanupTask()
@cleanup_task.delete_by_attribute('delete_me')
class TestModel(db.Model):
id = db.Column(db.Integer(), primary_key=True, autoincrement=True)
delete_me = db.Column(db.Boolean(), default=False, nullable=False)
with app.test_request_context():
db.create_all()
db.session.add(TestModel(delete_me=True))
db.session.add(TestModel(delete_me=True))
db.session.add(TestModel(delete_me=True))
db.session.add(TestModel(delete_me=False))
db.session.add(TestModel(delete_me=False))
db.session.commit()
db.session.expire_all()
self.assertEqual(TestModel.query.count(), 5)
with app.test_request_context():
cleanup_task.run()
db.session.commit()
db.session.expire_all()
with app.test_request_context():
self.assertEqual(TestModel.query.count(), 2)
from uffd.utils import nopad_b32decode, nopad_b32encode, nopad_urlsafe_b64decode, nopad_urlsafe_b64encode
from tests.utils import UffdTestCase
class TestUtils(UffdTestCase):
def test_nopad_b32(self):
for n in range(0, 32):
self.assertEqual(b'X'*n, nopad_b32decode(nopad_b32encode(b'X'*n)))
def test_nopad_b64(self):
for n in range(0, 32):
self.assertEqual(b'X'*n, nopad_urlsafe_b64decode(nopad_urlsafe_b64encode(b'X'*n)))
import os
import unittest
from flask import url_for
import flask_migrate
from uffd import create_app, db
from uffd.models import User, Group, Mail
def dump(basename, resp):
basename = basename.replace('.', '_').replace('/', '_')
suffix = '.html'
root = os.environ.get('DUMP_PAGES')
if not root:
return
os.makedirs(root, exist_ok=True)
path = os.path.join(root, basename+suffix)
with open(path, 'wb') as f:
f.write(resp.data)
def db_flush():
db.session.rollback()
db.session.expire_all()
class AppTestCase(unittest.TestCase):
DISABLE_SQLITE_MEMORY_DB = False
def setUp(self):
config = {
'TESTING': True,
'DEBUG': True,
'SQLALCHEMY_DATABASE_URI': 'sqlite:///:memory:',
'SECRET_KEY': 'DEBUGKEY',
'MAIL_SKIP_SEND': True,
'SELF_SIGNUP': True,
}
if self.DISABLE_SQLITE_MEMORY_DB:
try:
os.remove('/tmp/uffd-migration-test-db.sqlite3')
except FileNotFoundError:
pass
config['SQLALCHEMY_DATABASE_URI'] = 'sqlite:////tmp/uffd-migration-test-db.sqlite3'
if os.environ.get('TEST_WITH_MYSQL'):
import MySQLdb
conn = MySQLdb.connect(user='root', unix_socket='/var/run/mysqld/mysqld.sock')
cur = conn.cursor()
try:
cur.execute('DROP DATABASE uffd_tests')
except:
pass
cur.execute('CREATE DATABASE uffd_tests CHARACTER SET utf8mb4 COLLATE utf8mb4_nopad_bin')
conn.close()
config['SQLALCHEMY_DATABASE_URI'] = 'mysql+mysqldb:///uffd_tests?unix_socket=/var/run/mysqld/mysqld.sock&charset=utf8mb4'
self.app = create_app(config)
self.setUpApp()
def setUpApp(self):
pass
def tearDown(self):
if self.DISABLE_SQLITE_MEMORY_DB:
try:
os.remove('/tmp/uffd-migration-test-db.sqlite3')
except FileNotFoundError:
pass
class MigrationTestCase(AppTestCase):
DISABLE_SQLITE_MEMORY_DB = True
REVISION = None
def setUp(self):
super().setUp()
self.request_context = self.app.test_request_context()
self.request_context.__enter__()
if self.REVISION:
flask_migrate.upgrade(revision=self.REVISION + '-1')
def upgrade(self, revision='+1'):
db.session.commit()
flask_migrate.upgrade(revision=revision)
def downgrade(self, revision='-1'):
db.session.commit()
flask_migrate.downgrade(revision=revision)
def tearDown(self):
db.session.rollback()
self.request_context.__exit__(None, None, None)
super().tearDown()
class ModelTestCase(AppTestCase):
def setUp(self):
super().setUp()
self.request_context = self.app.test_request_context()
self.request_context.__enter__()
db.create_all()
db.session.commit()
def tearDown(self):
db.session.rollback()
self.request_context.__exit__(None, None, None)
super().tearDown()
class UffdTestCase(AppTestCase):
def setUp(self):
super().setUp()
self.client = self.app.test_client()
self.client.__enter__()
# Just do some request so that we can use url_for
self.client.get(path='/')
db.create_all()
# This reflects the old LDAP example data
users_group = Group(name='users', unix_gid=20001, description='Base group for all users')
db.session.add(users_group)
access_group = Group(name='uffd_access', unix_gid=20002, description='Access to Single-Sign-On and Selfservice')
db.session.add(access_group)
admin_group = Group(name='uffd_admin', unix_gid=20003, description='Admin access to uffd')
db.session.add(admin_group)
testuser = User(loginname='testuser', unix_uid=10000, password='userpassword', primary_email_address='test@example.com', displayname='Test User', groups=[users_group, access_group])
db.session.add(testuser)
testadmin = User(loginname='testadmin', unix_uid=10001, password='adminpassword', primary_email_address='admin@example.com', displayname='Test Admin', groups=[users_group, access_group, admin_group])
db.session.add(testadmin)
testmail = Mail(uid='test', receivers=['test1@example.com', 'test2@example.com'], destinations=['testuser@mail.example.com'])
db.session.add(testmail)
self.setUpDB()
db.session.commit()
def setUpDB(self):
pass
def tearDown(self):
self.client.__exit__(None, None, None)
super().tearDown()
def get_user(self):
return User.query.filter_by(loginname='testuser').one_or_none()
def get_admin(self):
return User.query.filter_by(loginname='testadmin').one_or_none()
def get_admin_group(self):
return Group.query.filter_by(name='uffd_admin').one_or_none()
def get_access_group(self):
return Group.query.filter_by(name='uffd_access').one_or_none()
def get_users_group(self):
return Group.query.filter_by(name='users').one_or_none()
def get_mail(self):
return Mail.query.filter_by(uid='test').one_or_none()
def login_as(self, user, ref=None):
# It is currently not possible to login while already logged in as another
# user, so make sure that we are not logged in first
self.client.get(path=url_for('session.logout'), follow_redirects=True)
loginname = None
password = None
if user == 'user':
loginname = 'testuser'
password = 'userpassword'
elif user == 'admin':
loginname = 'testadmin'
password = 'adminpassword'
return self.client.post(path=url_for('session.login', ref=ref),
data={'loginname': loginname, 'password': password}, follow_redirects=True)
This diff is collapsed.
This diff is collapsed.
import unittest
from flask import url_for
from uffd.database import db
from uffd.models import Mail
from tests.utils import dump, UffdTestCase
class TestMailViews(UffdTestCase):
def setUp(self):
super().setUp()
self.login_as('admin')
def test_index(self):
r = self.client.get(path=url_for('mail.index'), follow_redirects=True)
dump('mail_index', r)
self.assertEqual(r.status_code, 200)
def test_index_empty(self):
db.session.delete(self.get_mail())
db.session.commit()
self.assertIsNone(self.get_mail())
r = self.client.get(path=url_for('mail.index'), follow_redirects=True)
dump('mail_index_empty', r)
self.assertEqual(r.status_code, 200)
def test_show(self):
r = self.client.get(path=url_for('mail.show', mai_id=self.get_mail().id), follow_redirects=True)
dump('mail_show', r)
self.assertEqual(r.status_code, 200)
def test_new(self):
r = self.client.get(path=url_for('mail.show'), follow_redirects=True)
dump('mail_new', r)
self.assertEqual(r.status_code, 200)
def test_update(self):
m = self.get_mail()
self.assertIsNotNone(m)
self.assertEqual(m.uid, 'test')
self.assertEqual(sorted(m.receivers), ['test1@example.com', 'test2@example.com'])
self.assertEqual(sorted(m.destinations), ['testuser@mail.example.com'])
r = self.client.post(path=url_for('mail.update', mail_id=m.id),
data={'mail-uid': 'test1', 'mail-receivers': 'foo@bar.com\ntest@bar.com',
'mail-destinations': 'testuser@mail.example.com\ntestadmin@mail.example.com'}, follow_redirects=True)
dump('mail_update', r)
self.assertEqual(r.status_code, 200)
m = self.get_mail()
self.assertIsNotNone(m)
self.assertEqual(m.uid, 'test')
self.assertEqual(sorted(m.receivers), ['foo@bar.com', 'test@bar.com'])
self.assertEqual(sorted(m.destinations), ['testadmin@mail.example.com', 'testuser@mail.example.com'])
def test_create(self):
r = self.client.post(path=url_for('mail.update'),
data={'mail-uid': 'test1', 'mail-receivers': 'foo@bar.com\ntest@bar.com',
'mail-destinations': 'testuser@mail.example.com\ntestadmin@mail.example.com'}, follow_redirects=True)
dump('mail_create', r)
self.assertEqual(r.status_code, 200)
m = Mail.query.filter_by(uid='test1').one()
self.assertEqual(m.uid, 'test1')
self.assertEqual(sorted(m.receivers), ['foo@bar.com', 'test@bar.com'])
self.assertEqual(sorted(m.destinations), ['testadmin@mail.example.com', 'testuser@mail.example.com'])
@unittest.skip('We do not catch DB errors at the moment!') # TODO
def test_create_error(self):
r = self.client.post(path=url_for('mail.update'),
data={'mail-uid': 'test', 'mail-receivers': 'foo@bar.com\ntest@bar.com',
'mail-destinations': 'testuser@mail.example.com\ntestadmin@mail.example.com'}, follow_redirects=True)
dump('mail_create_error', r)
self.assertEqual(r.status_code, 200)
m = self.get_mail()
self.assertIsNotNone(m)
self.assertEqual(m.uid, 'test')
self.assertEqual(sorted(m.receivers), ['test1@example.com', 'test2@example.com'])
self.assertEqual(sorted(m.destinations), ['testuser@mail.example.com'])
def test_delete(self):
self.assertIsNotNone(self.get_mail())
r = self.client.get(path=url_for('mail.delete', mail_id=self.get_mail().id), follow_redirects=True)
dump('mail_delete', r)
self.assertEqual(r.status_code, 200)
self.assertIsNone(self.get_mail())
This diff is collapsed.