Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found

Target

Select target project
  • uffd/uffd
  • rixx/uffd
  • thies/uffd
  • leona/uffd
  • enbewe/uffd
  • strifel/uffd
  • thies/uffd-2
7 results
Show changes
Showing
with 2005 additions and 859 deletions
from uffd.database import db
from uffd.models.misc import lock_table, Lock
from tests.utils import MigrationTestCase
user_table = db.table('user',
db.column('id'),
db.column('unix_uid'),
db.column('loginname'),
db.column('displayname'),
db.column('primary_email_id'),
db.column('is_service_user'),
)
user_email_table = db.table('user_email',
db.column('id'),
db.column('address'),
db.column('address_normalized'),
db.column('verified'),
)
group_table = db.table('group',
db.column('id'),
db.column('unix_gid'),
db.column('name'),
db.column('description')
)
uid_allocation_table = db.table('uid_allocation', db.column('id'))
gid_allocation_table = db.table('gid_allocation', db.column('id'))
class TestMigration(MigrationTestCase):
REVISION = 'aeb07202a6c8'
def setUpApp(self):
self.app.config['USER_MIN_UID'] = 10000
self.app.config['USER_MAX_UID'] = 10005
self.app.config['USER_SERVICE_MIN_UID'] = 10006
self.app.config['USER_SERVICE_MAX_UID'] = 10010
self.app.config['GROUP_MIN_GID'] = 20000
self.app.config['GROUP_MAX_GID'] = 20005
def create_user(self, uid):
db.session.execute(db.insert(user_email_table).values(
address=f'email{uid}@example.com',
address_normalized=f'email{uid}@example.com',
verified=True
))
email_id = db.session.execute(
db.select([user_email_table.c.id])
.where(user_email_table.c.address == f'email{uid}@example.com')
).scalar()
db.session.execute(db.insert(user_table).values(
unix_uid=uid,
loginname=f'user{uid}',
displayname='user',
primary_email_id=email_id,
is_service_user=False
))
def create_group(self, gid):
db.session.execute(db.insert(group_table).values(unix_gid=gid, name=f'group{gid}', description=''))
def fetch_uid_allocations(self):
return [row[0] for row in db.session.execute(
db.select([uid_allocation_table])
.order_by(uid_allocation_table.c.id)
).fetchall()]
def fetch_gid_allocations(self):
return [row[0] for row in db.session.execute(
db.select([gid_allocation_table])
.order_by(gid_allocation_table.c.id)
).fetchall()]
def test_empty(self):
# No users/groups
self.upgrade()
self.assertEqual(self.fetch_uid_allocations(), [])
self.assertEqual(self.fetch_gid_allocations(), [])
def test_gid_first_minus_one(self):
self.create_group(19999)
self.upgrade()
self.assertEqual(self.fetch_gid_allocations(), [19999])
def test_gid_first(self):
self.create_group(20000)
self.upgrade()
self.assertEqual(self.fetch_gid_allocations(), [20000])
def test_gid_first_plus_one(self):
self.create_group(20001)
self.upgrade()
self.assertEqual(self.fetch_gid_allocations(), [20000, 20001])
def test_gid_last_minus_one(self):
self.create_group(20004)
self.upgrade()
self.assertEqual(self.fetch_gid_allocations(), [20000, 20001, 20002, 20003, 20004])
def test_gid_last(self):
self.create_group(20005)
self.upgrade()
self.assertEqual(self.fetch_gid_allocations(), [20000, 20001, 20002, 20003, 20004, 20005])
def test_gid_last_plus_one(self):
self.create_group(20006)
self.upgrade()
self.assertEqual(self.fetch_gid_allocations(), [20006])
def test_gid_complex(self):
self.create_group(10)
self.create_group(20001)
self.create_group(20003)
self.create_group(20010)
self.upgrade()
self.assertEqual(self.fetch_gid_allocations(), [10, 20000, 20001, 20002, 20003, 20010])
# The code for UIDs is mostly the same as for GIDs, so we don't test all
# the edge cases again.
def test_uid_different_ranges(self):
self.create_user(10)
self.create_user(10000)
self.create_user(10002)
self.create_user(10007)
self.create_user(10009)
self.create_user(90000)
self.upgrade()
self.assertEqual(self.fetch_uid_allocations(), [10, 10000, 10001, 10002, 10006, 10007, 10008, 10009, 90000])
def test_uid_same_ranges(self):
self.app.config['USER_MIN_UID'] = 10000
self.app.config['USER_MAX_UID'] = 10010
self.app.config['USER_SERVICE_MIN_UID'] = 10000
self.app.config['USER_SERVICE_MAX_UID'] = 10010
self.create_user(10)
self.create_user(10000)
self.create_user(10002)
self.create_user(10007)
self.create_user(10009)
self.create_user(90000)
self.upgrade()
self.assertEqual(self.fetch_uid_allocations(), [10, 10000, 10001, 10002, 10003, 10004, 10005, 10006, 10007, 10008, 10009, 90000])
import datetime
from flask import current_app
from uffd.database import db
from uffd.models import Invite, InviteGrant, InviteSignup, User, Role, RoleGroup
from tests.utils import UffdTestCase, db_flush
class TestInviteModel(UffdTestCase):
def test_expire(self):
invite = Invite(valid_until=datetime.datetime.utcnow() + datetime.timedelta(seconds=60), creator=self.get_admin())
self.assertFalse(invite.expired)
self.assertTrue(invite.active)
invite.valid_until = datetime.datetime.utcnow() - datetime.timedelta(seconds=60)
self.assertTrue(invite.expired)
self.assertFalse(invite.active)
def test_void(self):
invite = Invite(valid_until=datetime.datetime.utcnow() + datetime.timedelta(seconds=60), single_use=False, creator=self.get_admin())
self.assertFalse(invite.voided)
self.assertTrue(invite.active)
invite.used = True
self.assertFalse(invite.voided)
self.assertTrue(invite.active)
invite = Invite(valid_until=datetime.datetime.utcnow() + datetime.timedelta(seconds=60), single_use=True, creator=self.get_admin())
self.assertFalse(invite.voided)
self.assertTrue(invite.active)
invite.used = True
self.assertTrue(invite.voided)
self.assertFalse(invite.active)
def test_permitted(self):
role = Role(name='testrole')
invite = Invite(valid_until=datetime.datetime.utcnow() + datetime.timedelta(seconds=60), allow_signup=True, roles=[role])
self.assertFalse(invite.permitted)
self.assertFalse(invite.active)
invite.creator = self.get_admin()
self.assertTrue(invite.permitted)
self.assertTrue(invite.active)
invite.creator.is_deactivated = True
self.assertFalse(invite.permitted)
self.assertFalse(invite.active)
invite.creator = self.get_user()
self.assertFalse(invite.permitted)
self.assertFalse(invite.active)
role.moderator_group = self.get_access_group()
current_app.config['ACL_SIGNUP_GROUP'] = 'uffd_access'
self.assertTrue(invite.permitted)
self.assertTrue(invite.active)
role.moderator_group = None
self.assertFalse(invite.permitted)
self.assertFalse(invite.active)
role.moderator_group = self.get_access_group()
current_app.config['ACL_SIGNUP_GROUP'] = 'uffd_admin'
self.assertFalse(invite.permitted)
self.assertFalse(invite.active)
def test_disable(self):
invite = Invite(valid_until=datetime.datetime.utcnow() + datetime.timedelta(seconds=60), creator=self.get_admin())
self.assertTrue(invite.active)
invite.disable()
self.assertFalse(invite.active)
def test_reset_disabled(self):
invite = Invite(valid_until=datetime.datetime.utcnow() + datetime.timedelta(seconds=60), creator=self.get_admin())
invite.disable()
self.assertFalse(invite.active)
invite.reset()
self.assertTrue(invite.active)
def test_reset_expired(self):
invite = Invite(valid_until=datetime.datetime.utcnow() - datetime.timedelta(seconds=60), creator=self.get_admin())
self.assertFalse(invite.active)
invite.reset()
self.assertFalse(invite.active)
def test_reset_single_use(self):
invite = Invite(valid_until=datetime.datetime.utcnow() + datetime.timedelta(seconds=60), single_use=False, creator=self.get_admin())
invite.used = True
invite.disable()
self.assertFalse(invite.active)
invite.reset()
self.assertTrue(invite.active)
def test_short_token(self):
invite = Invite(valid_until=datetime.datetime.utcnow() + datetime.timedelta(seconds=60), creator=self.get_admin())
db.session.add(invite)
db.session.commit()
self.assertTrue(len(invite.short_token) <= len(invite.token)/3)
class TestInviteGrantModel(UffdTestCase):
def test_success(self):
user = self.get_user()
group0 = self.get_access_group()
role0 = Role(name='baserole', groups={group0: RoleGroup(group=group0)})
db.session.add(role0)
user.roles.append(role0)
user.update_groups()
group1 = self.get_admin_group()
role1 = Role(name='testrole1', groups={group1: RoleGroup(group=group1)})
db.session.add(role1)
role2 = Role(name='testrole2')
db.session.add(role2)
invite = Invite(valid_until=datetime.datetime.utcnow() + datetime.timedelta(seconds=60), roles=[role1, role2], creator=self.get_admin())
self.assertIn(role0, user.roles)
self.assertNotIn(role1, user.roles)
self.assertNotIn(role2, user.roles)
self.assertIn(group0, user.groups)
self.assertNotIn(group1, user.groups)
self.assertFalse(invite.used)
grant = InviteGrant(invite=invite, user=user)
success, msg = grant.apply()
self.assertTrue(success)
self.assertIn(role0, user.roles)
self.assertIn(role1, user.roles)
self.assertIn(role2, user.roles)
self.assertIn(group0, user.groups)
self.assertIn(group1, user.groups)
self.assertTrue(invite.used)
db.session.commit()
db_flush()
user = self.get_user()
self.assertIn('baserole', [role.name for role in user.roles_effective])
self.assertIn('testrole1', [role.name for role in user.roles])
self.assertIn('testrole2', [role.name for role in user.roles])
self.assertIn(self.get_access_group(), user.groups)
self.assertIn(self.get_admin_group(), user.groups)
def test_inactive(self):
user = self.get_user()
group = self.get_admin_group()
role = Role(name='testrole1', groups={group: RoleGroup(group=group)})
db.session.add(role)
invite = Invite(valid_until=datetime.datetime.utcnow() + datetime.timedelta(seconds=60), roles=[role], single_use=True, used=True, creator=self.get_admin())
self.assertFalse(invite.active)
grant = InviteGrant(invite=invite, user=user)
success, msg = grant.apply()
self.assertFalse(success)
self.assertIsInstance(msg, str)
self.assertNotIn(role, user.roles)
self.assertNotIn(group, user.groups)
def test_no_roles(self):
user = self.get_user()
invite = Invite(valid_until=datetime.datetime.utcnow() + datetime.timedelta(seconds=60), creator=self.get_admin())
self.assertTrue(invite.active)
grant = InviteGrant(invite=invite, user=user)
success, msg = grant.apply()
self.assertFalse(success)
self.assertIsInstance(msg, str)
def test_no_new_roles(self):
user = self.get_user()
role = Role(name='testrole1')
db.session.add(role)
user.roles.append(role)
invite = Invite(valid_until=datetime.datetime.utcnow() + datetime.timedelta(seconds=60), roles=[role], creator=self.get_admin())
self.assertTrue(invite.active)
grant = InviteGrant(invite=invite, user=user)
success, msg = grant.apply()
self.assertFalse(success)
self.assertIsInstance(msg, str)
class TestInviteSignupModel(UffdTestCase):
def create_base_roles(self):
baserole = Role(name='base', is_default=True)
baserole.groups[self.get_access_group()] = RoleGroup()
baserole.groups[self.get_users_group()] = RoleGroup()
db.session.add(baserole)
db.session.commit()
def test_success(self):
self.create_base_roles()
base_role = Role.query.filter_by(name='base').one()
base_group1 = self.get_access_group()
base_group2 = self.get_users_group()
group = self.get_admin_group()
role1 = Role(name='testrole1', groups={group: RoleGroup(group=group)})
db.session.add(role1)
role2 = Role(name='testrole2')
db.session.add(role2)
invite = Invite(valid_until=datetime.datetime.utcnow() + datetime.timedelta(seconds=60), roles=[role1, role2], allow_signup=True, creator=self.get_admin())
signup = InviteSignup(invite=invite, loginname='newuser', displayname='New User', mail='test@example.com', password='notsecret')
self.assertFalse(invite.used)
valid, msg = signup.validate()
self.assertTrue(valid)
self.assertFalse(invite.used)
user, msg = signup.finish('notsecret')
self.assertIsInstance(user, User)
self.assertTrue(invite.used)
self.assertEqual(user.loginname, 'newuser')
self.assertEqual(user.displayname, 'New User')
self.assertEqual(user.primary_email.address, 'test@example.com')
self.assertEqual(signup.user, user)
self.assertIn(base_role, user.roles_effective)
self.assertIn(role1, user.roles)
self.assertIn(role2, user.roles)
self.assertIn(base_group1, user.groups)
self.assertIn(base_group2, user.groups)
self.assertIn(group, user.groups)
db.session.commit()
db_flush()
self.assertEqual(len(User.query.filter_by(loginname='newuser').all()), 1)
def test_success_no_roles(self):
self.create_base_roles()
base_role = Role.query.filter_by(name='base').one()
base_group1 = self.get_access_group()
base_group2 = self.get_users_group()
invite = Invite(valid_until=datetime.datetime.utcnow() + datetime.timedelta(seconds=60), allow_signup=True, creator=self.get_admin())
signup = InviteSignup(invite=invite, loginname='newuser', displayname='New User', mail='test@example.com', password='notsecret')
self.assertFalse(invite.used)
valid, msg = signup.validate()
self.assertTrue(valid)
self.assertFalse(invite.used)
user, msg = signup.finish('notsecret')
self.assertIsInstance(user, User)
self.assertTrue(invite.used)
self.assertEqual(user.loginname, 'newuser')
self.assertEqual(user.displayname, 'New User')
self.assertEqual(user.primary_email.address, 'test@example.com')
self.assertEqual(signup.user, user)
self.assertIn(base_role, user.roles_effective)
self.assertEqual(len(user.roles_effective), 1)
self.assertIn(base_group1, user.groups)
self.assertIn(base_group2, user.groups)
self.assertEqual(len(user.groups), 2)
db.session.commit()
db_flush()
self.assertEqual(len(User.query.filter_by(loginname='newuser').all()), 1)
def test_inactive(self):
invite = Invite(valid_until=datetime.datetime.utcnow() + datetime.timedelta(seconds=60), allow_signup=True, single_use=True, used=True, creator=self.get_admin())
self.assertFalse(invite.active)
signup = InviteSignup(invite=invite, loginname='newuser', displayname='New User', mail='test@example.com', password='notsecret')
valid, msg = signup.validate()
self.assertFalse(valid)
self.assertIsInstance(msg, str)
user, msg = signup.finish('notsecret')
self.assertIsNone(user)
self.assertIsInstance(msg, str)
def test_invalid(self):
invite = Invite(valid_until=datetime.datetime.utcnow() + datetime.timedelta(seconds=60), allow_signup=True, creator=self.get_admin())
self.assertTrue(invite.active)
signup = InviteSignup(invite=invite, loginname='', displayname='New User', mail='test@example.com', password='notsecret')
valid, msg = signup.validate()
self.assertFalse(valid)
self.assertIsInstance(msg, str)
def test_invalid2(self):
invite = Invite(valid_until=datetime.datetime.utcnow() + datetime.timedelta(seconds=60), allow_signup=True, creator=self.get_admin())
self.assertTrue(invite.active)
signup = InviteSignup(invite=invite, loginname='newuser', displayname='New User', mail='test@example.com', password='notsecret')
user, msg = signup.finish('wrongpassword')
self.assertIsNone(user)
self.assertIsInstance(msg, str)
def test_no_signup(self):
invite = Invite(valid_until=datetime.datetime.utcnow() + datetime.timedelta(seconds=60), allow_signup=False, creator=self.get_admin())
self.assertTrue(invite.active)
signup = InviteSignup(invite=invite, loginname='newuser', displayname='New User', mail='test@example.com', password='notsecret')
valid, msg = signup.validate()
self.assertFalse(valid)
self.assertIsInstance(msg, str)
user, msg = signup.finish('notsecret')
self.assertIsNone(user)
self.assertIsInstance(msg, str)
import unittest
import datetime
import time
from uffd.database import db
from uffd.models import RecoveryCodeMethod, TOTPMethod, WebauthnMethod
from uffd.models.mfa import _hotp
from tests.utils import UffdTestCase
class TestMfaPrimitives(unittest.TestCase):
def test_hotp(self):
self.assertEqual(_hotp(5555555, b'\xae\xa3T\x05\x89\xd6\xb76\xf61r\x92\xcc\xb5WZ\xe6)\x05q'), '458290')
self.assertEqual(_hotp(5555555, b'\xae\xa3T\x05\x89\xd6\xb76\xf61r\x92\xcc\xb5WZ\xe6)\x05q', digits=8), '20458290')
for digits in range(1, 10):
self.assertEqual(len(_hotp(1, b'abcd', digits=digits)), digits)
self.assertEqual(_hotp(1234, b''), '161024')
self.assertEqual(_hotp(0, b'\x04\x8fM\xcc\x7f\x82\x9c$a\x1b\xb3'), '279354')
self.assertEqual(_hotp(2**64-1, b'abcde'), '899292')
def get_fido2_test_cred(self):
try:
from uffd.fido2_compat import AttestedCredentialData
except ImportError:
self.skipTest('fido2 could not be imported')
# Example public key from webauthn spec 6.5.1.1
return AttestedCredentialData(bytes.fromhex('00000000000000000000000000000000'+'0040'+'053cbcc9d37a61d3bac87cdcc77ee326256def08ab15775d3a720332e4101d14fae95aeee3bc9698781812e143c0597dc6e180595683d501891e9dd030454c0a'+'A501020326200121582065eda5a12577c2bae829437fe338701a10aaa375e1bb5b5de108de439c08551d2258201e52ed75701163f7f9e40ddf9f341b3dc9ba860af7e0ca7ca7e9eecd0084d19c'))
class TestMfaMethodModels(UffdTestCase):
def test_common_attributes(self):
method = TOTPMethod(user=self.get_user(), name='testname')
self.assertTrue(method.created <= datetime.datetime.utcnow())
self.assertEqual(method.name, 'testname')
self.assertEqual(method.user.loginname, 'testuser')
method.user = self.get_admin()
self.assertEqual(method.user.loginname, 'testadmin')
def test_recovery_code_method(self):
method = RecoveryCodeMethod(user=self.get_user())
db.session.add(method)
db.session.commit()
method_id = method.id
method_code = method.code_value
db.session.expunge(method)
method = RecoveryCodeMethod.query.get(method_id)
self.assertFalse(hasattr(method, 'code_value'))
self.assertFalse(method.verify(''))
self.assertFalse(method.verify('A'*8))
self.assertTrue(method.verify(method_code))
def test_totp_method_attributes(self):
method = TOTPMethod(user=self.get_user(), name='testname')
raw_key = method.raw_key
issuer = method.issuer
accountname = method.accountname
key_uri = method.key_uri
self.assertEqual(method.name, 'testname')
# Restore method with key parameter
_method = TOTPMethod(user=self.get_user(), key=method.key, name='testname')
self.assertEqual(_method.name, 'testname')
self.assertEqual(_method.raw_key, raw_key)
self.assertEqual(_method.issuer, issuer)
self.assertEqual(_method.accountname, accountname)
self.assertEqual(_method.key_uri, key_uri)
db.session.add(method)
db.session.commit()
_method_id = _method.id
db.session.expunge(_method)
# Restore method from db
_method = TOTPMethod.query.get(_method_id)
self.assertEqual(_method.name, 'testname')
self.assertEqual(_method.raw_key, raw_key)
self.assertEqual(_method.issuer, issuer)
self.assertEqual(_method.accountname, accountname)
self.assertEqual(_method.key_uri, key_uri)
def test_totp_method_verify(self):
method = TOTPMethod(user=self.get_user())
counter = int(time.time()/30)
self.assertFalse(method.verify(''))
self.assertFalse(method.verify(_hotp(counter-2, method.raw_key)))
self.assertTrue(method.verify(_hotp(counter, method.raw_key)))
self.assertFalse(method.verify(_hotp(counter+2, method.raw_key)))
def test_totp_method_verify_reuse(self):
method = TOTPMethod(user=self.get_user())
counter = int(time.time()/30)
self.assertFalse(method.verify(_hotp(counter-2, method.raw_key)))
self.assertTrue(method.verify(_hotp(counter-1, method.raw_key)))
self.assertTrue(method.verify(_hotp(counter, method.raw_key)))
self.assertFalse(method.verify(_hotp(counter-1, method.raw_key)))
self.assertFalse(method.verify(_hotp(counter, method.raw_key)))
def test_webauthn_method(self):
data = get_fido2_test_cred(self)
method = WebauthnMethod(user=self.get_user(), cred=data, name='testname')
self.assertEqual(method.name, 'testname')
db.session.add(method)
db.session.commit()
method_id = method.id
method_cred = method.cred
db.session.expunge(method)
_method = WebauthnMethod.query.get(method_id)
self.assertEqual(_method.name, 'testname')
self.assertEqual(bytes(method_cred), bytes(_method.cred))
self.assertEqual(data.credential_id, _method.cred.credential_id)
self.assertEqual(data.public_key, _method.cred.public_key)
# We only test (de-)serialization here, as everything else is currently implemented in the views
import time
import threading
from sqlalchemy.exc import IntegrityError
from uffd.database import db
from uffd.models import FeatureFlag, Lock
from uffd.models.misc import feature_flag_table
from tests.utils import ModelTestCase
class TestFeatureFlag(ModelTestCase):
def test_disabled(self):
flag = FeatureFlag('foo')
self.assertFalse(flag)
self.assertFalse(db.session.execute(db.select([flag.expr])).scalar())
def test_enabled(self):
db.session.execute(db.insert(feature_flag_table).values(name='foo'))
flag = FeatureFlag('foo')
self.assertTrue(flag)
self.assertTrue(db.session.execute(db.select([flag.expr])).scalar())
def test_toggle(self):
flag = FeatureFlag('foo')
hooks_called = []
@flag.enable_hook
def enable_hook1():
hooks_called.append('enable1')
@flag.enable_hook
def enable_hook2():
hooks_called.append('enable2')
@flag.disable_hook
def disable_hook1():
hooks_called.append('disable1')
@flag.disable_hook
def disable_hook2():
hooks_called.append('disable2')
hooks_called.clear()
flag.enable()
self.assertTrue(flag)
self.assertEqual(hooks_called, ['enable1', 'enable2'])
hooks_called.clear()
flag.disable()
self.assertFalse(flag)
self.assertEqual(hooks_called, ['disable1', 'disable2'])
flag.disable() # does nothing
self.assertFalse(flag)
flag.enable()
self.assertTrue(flag)
with self.assertRaises(IntegrityError):
flag.enable()
self.assertTrue(flag)
class TestLock(ModelTestCase):
DISABLE_SQLITE_MEMORY_DB = True
def setUpApp(self):
self.lock = Lock('testlock')
def run_lock_test(self):
result = []
def func():
with self.app.test_request_context():
self.lock.acquire()
result.append('bar')
t = threading.Thread(target=func)
t.start()
time.sleep(1)
result.append('foo')
time.sleep(1)
db.session.rollback()
t.join()
return result
def test_lock2(self):
self.assertEqual(self.run_lock_test(), ['bar', 'foo'])
self.lock.acquire()
self.assertEqual(self.run_lock_test(), ['foo', 'bar'])
import unittest
import datetime
import jwt
from uffd.database import db
from uffd.models import OAuth2Key
from tests.utils import UffdTestCase
TEST_JWK = dict(
id='HvOn74G7njK1GoFNe8Dta087casdWMsm06pNhOXRgJU',
created=datetime.datetime(2023, 11, 9, 0, 21, 10),
active=True,
algorithm='RS256',
private_key_jwk='''{
"kty": "RSA",
"key_ops": ["sign"],
"n": "vrznqUy8Xamph6s0Z02fFMIyjwLAMio35i9DXYjXP1ZQwSZ3SsIh3m2ablMnlu8PVlnYUzoj8rXyAWND0FSfWoQQxv1rq15pllKueddLoJsv321N_NRB8beGsLrsndw8QO0q3RWqV9O3kqhlTMjgj6bquX42wLaXrPLJyfbT3zObBsToG4UxpOyly84aklJXU5wIs0cbmjbfd8Xld38BG8Oh7Ozy5b93vPpJW6rudZRxU6QYC0r9bFFLIHJWrR4bzQMLGoJ63xjPOCl4WNpOYc9B7PNgnWTLXlFd51Hw9CaT2MRWsKNCSU77f6nZkfjWa1IsQdF0I48m46qgq7bEOOl9DbThbCnpblWrctdyg6du-OvCyVmkAo1KGtANl0027pgqUI_9HBMi33y3UPQm1ALHXIyIDBZtExH3lD6MMK3XGJfUxZuIOBndK-PXm5Fed52bgLOcf-24X6aHFn-8oyDVIj9OHkKWjy7jtKdmqZc4pBdVuCaMCYzj8iERWA3H",
"e": "AQAB",
"d": "G7yoH5mLcZTA6ia-byCoN-zpofGvdga9AZnxPO0vsq6K_cY_O2gxuVZ3n6reAKKbuLNGCbb_D_Dffs4q8rprlfkgi3TCLzXX5Zv5HWTD7a4Y7xpxEzQ2sWo-iagVIqZVPh0pyjliqnTyUWnFmWiY0gBe9UHianHjFVZqe8E2HFOKgW3UUbQz0keg8JtJ3T9gzZrM38KWbqhOJO0VVSRAoANPTSnumfRsUCyWywrMtIfgAbQaKazqX3xkOsAF1L-iNfd6slzPvRyIQVflVDMdfKnsu-lHiKJ0DK_lg9f55T5FymgcXsq43EKBQ2H4v2dafIm-vtWx_TRZWj_msD32BEPBA-zTqh_oP1r6a3DZh4DBtWY3vzSiuhAC0erlRs-hRTX_e9ET5fUbJnmNxjnxQD9zZmwq4ujMK6KFnHct8t77Qxj3a-wDR_XyDJ4_EKYqHlcVHfxGNBSvIdjuZJkPJnVpVtfCtpyamQIR4u5oNV7fIwYe_tFnw0Y90rGoJMzB",
"p": "-A-FnH21HJ7GPWUm9k3mxsxSchy89QEUCZZiH6EcB4ZP8wJsxrQsUSIHCR74YmZEI3Ulsum1Ql4x50k7Q2sNh9SnwKvmctjksehGy4yCrdunAqjqyz3wFwGaKWnhn3frkiqH5ATjkOoc8qHz8saa7reeVClj47ZWyy-Nl559ycLMs0rI1N_THzO07C3jSbJhyPj0yeygAflsRqqnNvEQ6ps1VLiqf9G5jfSvUUn5DyKIpep9iGo29caGSIPIy_2h",
"q": "xNe1-QWskxOcY_GiHpFWdvzqr1o9fxg5whgpNcGi3caokw2iNHRYut4cbVvFFBlv_9B5QCl9WVfR2ADG0AtvkvUxEZqCdxEvcqjIANeRLKHDjW5kMuPS0_fcskFP-r7mCM9SBfPplfMVCF5nuNWf5LzNopWfsTChIDD1rSpPjItNYuwLXszm_3R81HHHeQLcyvoMxLCmeLy5TXX2hXOMHh2IMZCXAHopJmLJUVnQ48kr5jd2l0kLbmx3aBqdccJn",
"dp": "MLS7g1KbcRcrzXpDADGjkn0j4wwJfgHMMWW5toQnwMJ6iDh9qzZNTVDlGMFf-9IgpuWllU-WK4XbPpJ-dGpcqcLzfT1DbmFv5g65d9YLAqASVs9b6rQqpBnIb0E-79TYCEcZj4f2NsoBDRMHly-v1BdxmwzVdCylNhgMMS0Jfcgl8T5J2KJqDcJVT9piumGwGYnoZo1zjW-v9uAjHQKQU8BN5Git8ZL4YAsfMVLY-EPLmOhF5bcVO4TTcQGPN56B",
"dq": "HiiSl-G3rB0QE_v8g8Ruw_JCHrWrwGI8zzEWd0cApgv-3fDzzieZRKAtKNArpMW09DPDsAHrU5nx669KxqtJ3_EzIGhU3ttCMsYLRp3Af18VcADe1zEypwlNxf3dvCQtaGIjRgg13KSOr2aPa7FHOyt2MhfMjMBPn3gA3BQkdfsN0z8pCtBIABGf4ojAMBkxLOQcurH5_3uixGxzZcTrTd3mdPmbORZ-YYQ3JgCl0ZCL6kzLHaiyWKvDq66QOtK3",
"qi": "ySqD9cUxbq3wkCsPQId_YfQLIqb5RK_JJIMjtBOdTdo4aT5tmodYCSmjBmhrYXjDWtyJdelvPfdSfgncHJhf8VgkZ8TPvUeaQwsQFBwB5llwpdb72eEEJrmG1SVwNMoFCLXdNT3ACad16cUDMnWmklH0X07OzdxGOBnGhgLZUs4RbPjLH7OpYTyQqVy2L8vofqJR42cfePZw8WQM4k0PPbhralhybExIkSCmaQyYbACZ5k0OVQErEqnj4elglA0h"
}''',
public_key_jwk='''{
"kty": "RSA",
"key_ops": ["verify"],
"n": "vrznqUy8Xamph6s0Z02fFMIyjwLAMio35i9DXYjXP1ZQwSZ3SsIh3m2ablMnlu8PVlnYUzoj8rXyAWND0FSfWoQQxv1rq15pllKueddLoJsv321N_NRB8beGsLrsndw8QO0q3RWqV9O3kqhlTMjgj6bquX42wLaXrPLJyfbT3zObBsToG4UxpOyly84aklJXU5wIs0cbmjbfd8Xld38BG8Oh7Ozy5b93vPpJW6rudZRxU6QYC0r9bFFLIHJWrR4bzQMLGoJ63xjPOCl4WNpOYc9B7PNgnWTLXlFd51Hw9CaT2MRWsKNCSU77f6nZkfjWa1IsQdF0I48m46qgq7bEOOl9DbThbCnpblWrctdyg6du-OvCyVmkAo1KGtANl0027pgqUI_9HBMi33y3UPQm1ALHXIyIDBZtExH3lD6MMK3XGJfUxZuIOBndK-PXm5Fed52bgLOcf-24X6aHFn-8oyDVIj9OHkKWjy7jtKdmqZc4pBdVuCaMCYzj8iERWA3H",
"e": "AQAB"
}''',
)
class TestOAuth2Key(UffdTestCase):
def setUp(self):
super().setUp()
db.session.add(OAuth2Key(**TEST_JWK))
db.session.add(OAuth2Key(
id='1e9gdk7',
created=datetime.datetime(2014, 11, 8, 0, 0, 0),
active=True,
algorithm='RS256',
private_key_jwk='invalid',
public_key_jwk='''{
"kty":"RSA",
"n":"w7Zdfmece8iaB0kiTY8pCtiBtzbptJmP28nSWwtdjRu0f2GFpajvWE4VhfJAjEsOcwYzay7XGN0b-X84BfC8hmCTOj2b2eHT7NsZegFPKRUQzJ9wW8ipn_aDJWMGDuB1XyqT1E7DYqjUCEOD1b4FLpy_xPn6oV_TYOfQ9fZdbE5HGxJUzekuGcOKqOQ8M7wfYHhHHLxGpQVgL0apWuP2gDDOdTtpuld4D2LK1MZK99s9gaSjRHE8JDb1Z4IGhEcEyzkxswVdPndUWzfvWBBWXWxtSUvQGBRkuy1BHOa4sP6FKjWEeeF7gm7UMs2Nm2QUgNZw6xvEDGaLk4KASdIxRQ",
"e":"AQAB"
}'''
))
db.session.commit()
self.key = OAuth2Key.query.get('HvOn74G7njK1GoFNe8Dta087casdWMsm06pNhOXRgJU')
self.key_oidc_spec = OAuth2Key.query.get('1e9gdk7')
def test_private_key(self):
self.key.private_key
def test_public_key(self):
self.key.private_key
def test_public_key_jwks_dict(self):
self.assertEqual(self.key.public_key_jwks_dict, {
"kid": "HvOn74G7njK1GoFNe8Dta087casdWMsm06pNhOXRgJU",
"kty": "RSA",
"alg": "RS256",
"use": "sig",
"n": "vrznqUy8Xamph6s0Z02fFMIyjwLAMio35i9DXYjXP1ZQwSZ3SsIh3m2ablMnlu8PVlnYUzoj8rXyAWND0FSfWoQQxv1rq15pllKueddLoJsv321N_NRB8beGsLrsndw8QO0q3RWqV9O3kqhlTMjgj6bquX42wLaXrPLJyfbT3zObBsToG4UxpOyly84aklJXU5wIs0cbmjbfd8Xld38BG8Oh7Ozy5b93vPpJW6rudZRxU6QYC0r9bFFLIHJWrR4bzQMLGoJ63xjPOCl4WNpOYc9B7PNgnWTLXlFd51Hw9CaT2MRWsKNCSU77f6nZkfjWa1IsQdF0I48m46qgq7bEOOl9DbThbCnpblWrctdyg6du-OvCyVmkAo1KGtANl0027pgqUI_9HBMi33y3UPQm1ALHXIyIDBZtExH3lD6MMK3XGJfUxZuIOBndK-PXm5Fed52bgLOcf-24X6aHFn-8oyDVIj9OHkKWjy7jtKdmqZc4pBdVuCaMCYzj8iERWA3H",
"e": "AQAB"
})
def test_encode_jwt(self):
jwtdata = self.key.encode_jwt({'aud': 'test', 'foo': 'bar'})
self.assertIsInstance(jwtdata, str) # Regression check for #165
self.assertEqual(
jwt.get_unverified_header(jwtdata),
# typ is optional, x5u/x5c/jku/jwk are discoraged by OIDC Core 1.0 spec section 2
{'kid': self.key.id, 'alg': self.key.algorithm, 'typ': 'JWT'}
)
self.assertEqual(
OAuth2Key.decode_jwt(jwtdata, audience='test'),
{'aud': 'test', 'foo': 'bar'}
)
self.key.active = False
with self.assertRaises(jwt.exceptions.InvalidKeyError):
self.key.encode_jwt({'aud': 'test', 'foo': 'bar'})
def test_oidc_hash(self):
# Example from OIDC Core 1.0 spec A.3
self.assertEqual(
self.key.oidc_hash(b'jHkWEdUXMU1BwAsC4vtUsZwnNvTIxEl0z9K3vx5KF0Y'),
'77QmUPtjPfzWtF2AnpK9RQ'
)
# Example from OIDC Core 1.0 spec A.4
self.assertEqual(
self.key.oidc_hash(b'Qcb0Orv1zh30vL1MPRsbm-diHiMwcLyZvn1arpZv-Jxf_11jnpEX3Tgfvk'),
'LDktKdoQak3Pk0cnXxCltA'
)
# Example from OIDC Core 1.0 spec A.6
self.assertEqual(
self.key.oidc_hash(b'jHkWEdUXMU1BwAsC4vtUsZwnNvTIxEl0z9K3vx5KF0Y'),
'77QmUPtjPfzWtF2AnpK9RQ'
)
self.assertEqual(
self.key.oidc_hash(b'Qcb0Orv1zh30vL1MPRsbm-diHiMwcLyZvn1arpZv-Jxf_11jnpEX3Tgfvk'),
'LDktKdoQak3Pk0cnXxCltA'
)
def test_decode_jwt(self):
# Example from OIDC Core 1.0 spec A.2
jwt_data = (
'eyJraWQiOiIxZTlnZGs3IiwiYWxnIjoiUlMyNTYifQ.ewogImlz'
'cyI6ICJodHRwOi8vc2VydmVyLmV4YW1wbGUuY29tIiwKICJzdWIiOiAiMjQ4'
'Mjg5NzYxMDAxIiwKICJhdWQiOiAiczZCaGRSa3F0MyIsCiAibm9uY2UiOiAi'
'bi0wUzZfV3pBMk1qIiwKICJleHAiOiAxMzExMjgxOTcwLAogImlhdCI6IDEz'
'MTEyODA5NzAsCiAibmFtZSI6ICJKYW5lIERvZSIsCiAiZ2l2ZW5fbmFtZSI6'
'ICJKYW5lIiwKICJmYW1pbHlfbmFtZSI6ICJEb2UiLAogImdlbmRlciI6ICJm'
'ZW1hbGUiLAogImJpcnRoZGF0ZSI6ICIwMDAwLTEwLTMxIiwKICJlbWFpbCI6'
'ICJqYW5lZG9lQGV4YW1wbGUuY29tIiwKICJwaWN0dXJlIjogImh0dHA6Ly9l'
'eGFtcGxlLmNvbS9qYW5lZG9lL21lLmpwZyIKfQ.rHQjEmBqn9Jre0OLykYNn'
'spA10Qql2rvx4FsD00jwlB0Sym4NzpgvPKsDjn_wMkHxcp6CilPcoKrWHcip'
'R2iAjzLvDNAReF97zoJqq880ZD1bwY82JDauCXELVR9O6_B0w3K-E7yM2mac'
'AAgNCUwtik6SjoSUZRcf-O5lygIyLENx882p6MtmwaL1hd6qn5RZOQ0TLrOY'
'u0532g9Exxcm-ChymrB4xLykpDj3lUivJt63eEGGN6DH5K6o33TcxkIjNrCD'
'4XB1CKKumZvCedgHHF3IAK4dVEDSUoGlH9z4pP_eWYNXvqQOjGs-rDaQzUHl'
'6cQQWNiDpWOl_lxXjQEvQ'
)
self.assertEqual(
OAuth2Key.decode_jwt(jwt_data, options={'verify_exp': False, 'verify_aud': False}),
{
"iss": "http://server.example.com",
"sub": "248289761001",
"aud": "s6BhdRkqt3",
"nonce": "n-0S6_WzA2Mj",
"exp": 1311281970,
"iat": 1311280970,
"name": "Jane Doe",
"given_name": "Jane",
"family_name": "Doe",
"gender": "female",
"birthdate": "0000-10-31",
"email": "janedoe@example.com",
"picture": "http://example.com/janedoe/me.jpg"
}
)
with self.assertRaises(jwt.exceptions.InvalidKeyError):
# {"alg":"RS256"} -> no key id
OAuth2Key.decode_jwt('eyJhbGciOiJSUzI1NiJ9.' + jwt_data.split('.', 1)[-1])
with self.assertRaises(jwt.exceptions.InvalidKeyError):
# {"kid":"XXXXX","alg":"RS256"} -> unknown key id
OAuth2Key.decode_jwt('eyJraWQiOiJYWFhYWCIsImFsZyI6IlJTMjU2In0.' + jwt_data.split('.', 1)[-1])
OAuth2Key.query.get('1e9gdk7').active = False
with self.assertRaises(jwt.exceptions.InvalidKeyError):
# not active
OAuth2Key.decode_jwt(jwt_data)
def test_generate_rsa_key(self):
key = OAuth2Key.generate_rsa_key()
self.assertEqual(key.algorithm, 'RS256')
import unittest
from uffd.database import db
from uffd.models import User, Role, RoleGroup, TOTPMethod
from uffd.models.role import flatten_recursive
from tests.utils import UffdTestCase
class TestPrimitives(unittest.TestCase):
def test_flatten_recursive(self):
class Node:
def __init__(self, *neighbors):
self.neighbors = set(neighbors or set())
cycle = Node()
cycle.neighbors.add(cycle)
common = Node(cycle)
intermediate1 = Node(common)
intermediate2 = Node(common, intermediate1)
stub = Node()
backref = Node()
start1 = Node(intermediate1, intermediate2, stub, backref)
backref.neighbors.add(start1)
start2 = Node()
self.assertSetEqual(flatten_recursive({start1, start2}, 'neighbors'),
{start1, start2, backref, stub, intermediate1, intermediate2, common, cycle})
self.assertSetEqual(flatten_recursive(set(), 'neighbors'), set())
class TestUserRoleAttributes(UffdTestCase):
def test_roles_effective(self):
db.session.add(User(loginname='service', is_service_user=True, primary_email_address='service@example.com', displayname='Service'))
db.session.commit()
user = self.get_user()
service_user = User.query.filter_by(loginname='service').one_or_none()
included_by_default_role = Role(name='included_by_default')
default_role = Role(name='default', is_default=True, included_roles=[included_by_default_role])
included_role = Role(name='included')
cycle_role = Role(name='cycle')
direct_role1 = Role(name='role1', members=[user, service_user], included_roles=[included_role, cycle_role])
direct_role2 = Role(name='role2', members=[user, service_user], included_roles=[included_role])
cycle_role.included_roles.append(direct_role1)
db.session.add_all([included_by_default_role, default_role, included_role, cycle_role, direct_role1, direct_role2])
self.assertSetEqual(user.roles_effective, {direct_role1, direct_role2, cycle_role, included_role, default_role, included_by_default_role})
self.assertSetEqual(service_user.roles_effective, {direct_role1, direct_role2, cycle_role, included_role})
def test_compute_groups(self):
user = self.get_user()
group1 = self.get_users_group()
group2 = self.get_access_group()
role1 = Role(name='role1', groups={group1: RoleGroup(group=group1)})
role2 = Role(name='role2', groups={group1: RoleGroup(group=group1), group2: RoleGroup(group=group2)})
db.session.add_all([role1, role2])
self.assertSetEqual(user.compute_groups(), set())
role1.members.append(user)
role2.members.append(user)
self.assertSetEqual(user.compute_groups(), {group1, group2})
role2.groups[group2].requires_mfa = True
self.assertSetEqual(user.compute_groups(), {group1})
db.session.add(TOTPMethod(user=user))
db.session.commit()
self.assertSetEqual(user.compute_groups(), {group1, group2})
def test_update_groups(self):
user = self.get_user()
group1 = self.get_users_group()
group2 = self.get_access_group()
role1 = Role(name='role1', members=[user], groups={group1: RoleGroup(group=group1)})
role2 = Role(name='role2', groups={group2: RoleGroup(group=group2)})
db.session.add_all([role1, role2])
user.groups = [group2]
groups_added, groups_removed = user.update_groups()
self.assertSetEqual(groups_added, {group1})
self.assertSetEqual(groups_removed, {group2})
self.assertSetEqual(set(user.groups), {group1})
groups_added, groups_removed = user.update_groups()
self.assertSetEqual(groups_added, set())
self.assertSetEqual(groups_removed, set())
self.assertSetEqual(set(user.groups), {group1})
class TestRoleModel(UffdTestCase):
def test_members_effective(self):
db.session.add(User(loginname='service', is_service_user=True, primary_email_address='service@example.com', displayname='Service'))
db.session.commit()
user1 = self.get_user()
user2 = self.get_admin()
service = User.query.filter_by(loginname='service').one_or_none()
included_by_default_role = Role(name='included_by_default')
default_role = Role(name='default', is_default=True, included_roles=[included_by_default_role])
included_role = Role(name='included')
direct_role = Role(name='direct', members=[user1, user2, service], included_roles=[included_role])
empty_role = Role(name='empty', included_roles=[included_role])
self.assertSetEqual(included_by_default_role.members_effective, {user1, user2})
self.assertSetEqual(default_role.members_effective, {user1, user2})
self.assertSetEqual(included_role.members_effective, {user1, user2, service})
self.assertSetEqual(direct_role.members_effective, {user1, user2, service})
self.assertSetEqual(empty_role.members_effective, set())
def test_included_roles_recursive(self):
baserole = Role(name='base')
role1 = Role(name='role1', included_roles=[baserole])
role2 = Role(name='role2', included_roles=[baserole])
role3 = Role(name='role3', included_roles=[role1, role2])
self.assertSetEqual(role1.included_roles_recursive, {baserole})
self.assertSetEqual(role2.included_roles_recursive, {baserole})
self.assertSetEqual(role3.included_roles_recursive, {baserole, role1, role2})
baserole.included_roles.append(role1)
self.assertSetEqual(role3.included_roles_recursive, {baserole, role1, role2})
def test_groups_effective(self):
group1 = self.get_users_group()
group2 = self.get_access_group()
baserole = Role(name='base', groups={group1: RoleGroup(group=group1)})
role1 = Role(name='role1', groups={group2: RoleGroup(group=group2)}, included_roles=[baserole])
self.assertSetEqual(baserole.groups_effective, {group1})
self.assertSetEqual(role1.groups_effective, {group1, group2})
def test_update_member_groups(self):
user1 = self.get_user()
user1.update_groups()
user2 = self.get_admin()
user2.update_groups()
group1 = self.get_users_group()
group2 = self.get_access_group()
group3 = self.get_admin_group()
baserole = Role(name='base', members=[user1], groups={group1: RoleGroup(group=group1)})
role1 = Role(name='role1', members=[user2], groups={group2: RoleGroup(group=group2)}, included_roles=[baserole])
db.session.add_all([baserole, role1])
baserole.update_member_groups()
role1.update_member_groups()
self.assertSetEqual(set(user1.groups), {group1})
self.assertSetEqual(set(user2.groups), {group1, group2})
baserole.groups[group3] = RoleGroup()
baserole.update_member_groups()
self.assertSetEqual(set(user1.groups), {group1, group3})
self.assertSetEqual(set(user2.groups), {group1, group2, group3})
import itertools
from uffd.remailer import remailer
from uffd.tasks import cleanup_task
from uffd.database import db
from uffd.models import Service, ServiceUser, User, UserEmail, RemailerMode
from tests.utils import UffdTestCase
class TestServiceUser(UffdTestCase):
def setUp(self):
super().setUp()
db.session.add_all([Service(name='service1', limit_access=False), Service(name='service2', remailer_mode=RemailerMode.ENABLED_V1, limit_access=False)])
db.session.commit()
def test_auto_create(self):
service_count = Service.query.count()
user_count = User.query.count()
self.assertEqual(ServiceUser.query.count(), service_count * user_count)
db.session.add(User(loginname='newuser1', displayname='New User', primary_email_address='new1@example.com'))
db.session.commit()
self.assertEqual(ServiceUser.query.count(), service_count * (user_count + 1))
db.session.add(Service(name='service3'))
db.session.commit()
self.assertEqual(ServiceUser.query.count(), (service_count + 1) * (user_count + 1))
db.session.add(User(loginname='newuser2', displayname='New User', primary_email_address='new2@example.com'))
db.session.add(User(loginname='newuser3', displayname='New User', primary_email_address='new3@example.com'))
db.session.add(Service(name='service4'))
db.session.add(Service(name='service5'))
db.session.commit()
self.assertEqual(ServiceUser.query.count(), (service_count + 3) * (user_count + 3))
def test_create_missing(self):
service_count = Service.query.count()
user_count = User.query.count()
self.assertEqual(ServiceUser.query.count(), service_count * user_count)
db.session.delete(ServiceUser.query.first())
db.session.commit()
self.assertEqual(ServiceUser.query.count(), service_count * user_count - 1)
cleanup_task.run()
db.session.commit()
self.assertEqual(ServiceUser.query.count(), service_count * user_count)
def test_effective_remailer_mode(self):
self.app.config['REMAILER_DOMAIN'] = 'remailer.example.com'
user = self.get_user()
service = Service.query.filter_by(name='service1').first()
service.remailer_mode = RemailerMode.ENABLED_V2
service_user = ServiceUser.query.get((service.id, user.id))
self.assertEqual(service_user.effective_remailer_mode, RemailerMode.ENABLED_V2)
self.app.config['REMAILER_LIMIT_TO_USERS'] = ['testadmin']
self.assertEqual(service_user.effective_remailer_mode, RemailerMode.DISABLED)
self.app.config['REMAILER_LIMIT_TO_USERS'] = ['testuser']
self.assertEqual(service_user.effective_remailer_mode, RemailerMode.ENABLED_V2)
self.app.config['REMAILER_LIMIT_TO_USERS'] = None
service_user.remailer_overwrite_mode = RemailerMode.ENABLED_V1
service.remailer_mode = RemailerMode.DISABLED
self.assertEqual(service_user.effective_remailer_mode, RemailerMode.ENABLED_V1)
self.app.config['REMAILER_DOMAIN'] = ''
self.assertEqual(service_user.effective_remailer_mode, RemailerMode.DISABLED)
def test_service_email(self):
user = self.get_user()
service = Service.query.filter_by(name='service1').first()
service_user = ServiceUser.query.get((service.id, user.id))
self.assertEqual(service_user.service_email, None)
service_user.service_email = UserEmail(user=user, address='foo@bar', verified=True)
with self.assertRaises(Exception):
service_user.service_email = UserEmail(user=user, address='foo2@bar', verified=False)
with self.assertRaises(Exception):
service_user.service_email = UserEmail(user=self.get_admin(), address='foo3@bar', verified=True)
def test_real_email(self):
user = self.get_user()
service = Service.query.filter_by(name='service1').first()
service_user = ServiceUser.query.get((service.id, user.id))
self.assertEqual(service_user.real_email, user.primary_email.address)
service_user.service_email = UserEmail(user=user, address='foo@bar', verified=True)
self.assertEqual(service_user.real_email, user.primary_email.address)
service.enable_email_preferences = True
self.assertEqual(service_user.real_email, service_user.service_email.address)
service.limit_access = True
self.assertEqual(service_user.real_email, user.primary_email.address)
service.access_group = self.get_admin_group()
self.assertEqual(service_user.real_email, user.primary_email.address)
service.access_group = self.get_users_group()
self.assertEqual(service_user.real_email, service_user.service_email.address)
def test_get_by_remailer_email(self):
user = self.get_user()
service = Service.query.filter_by(name='service1').first()
service_user = ServiceUser.query.get((service.id, user.id))
self.app.config['REMAILER_DOMAIN'] = 'remailer.example.com'
remailer_email = remailer.build_v1_address(service.id, user.id)
# 1. remailer not setup
self.app.config['REMAILER_DOMAIN'] = ''
self.assertIsNone(ServiceUser.get_by_remailer_email(user.primary_email.address))
self.assertIsNone(ServiceUser.get_by_remailer_email(remailer_email))
self.assertIsNone(ServiceUser.get_by_remailer_email('invalid'))
# 2. remailer setup
self.app.config['REMAILER_DOMAIN'] = 'remailer.example.com'
self.assertIsNone(ServiceUser.get_by_remailer_email(user.primary_email.address))
self.assertEqual(ServiceUser.get_by_remailer_email(remailer_email), service_user)
self.assertIsNone(ServiceUser.get_by_remailer_email('invalid'))
def test_email(self):
user = self.get_user()
service = Service.query.filter_by(name='service1').first()
service_user = ServiceUser.query.get((service.id, user.id))
self.app.config['REMAILER_DOMAIN'] = 'remailer.example.com'
remailer_email = remailer.build_v1_address(service.id, user.id)
# 1. remailer not setup
self.app.config['REMAILER_DOMAIN'] = ''
self.assertEqual(service_user.email, user.primary_email.address)
# 2. remailer setup + remailer disabled
self.app.config['REMAILER_DOMAIN'] = 'remailer.example.com'
self.assertEqual(service_user.email, user.primary_email.address)
# 3. remailer setup + remailer enabled + REMAILER_LIMIT_TO_USERS unset
service.remailer_mode = RemailerMode.ENABLED_V1
db.session.commit()
self.assertEqual(service_user.email, remailer_email)
# 4. remailer setup + remailer enabled + REMAILER_LIMIT_TO_USERS does not include user
self.app.config['REMAILER_LIMIT_TO_USERS'] = ['testadmin']
self.assertEqual(service_user.email, user.primary_email.address)
# 5. remailer setup + remailer enabled + REMAILER_LIMIT_TO_USERS includes user
self.app.config['REMAILER_LIMIT_TO_USERS'] = ['testuser']
self.assertEqual(service_user.email, remailer_email)
# 6. remailer setup + remailer disabled + user overwrite
self.app.config['REMAILER_LIMIT_TO_USERS'] = None
service.remailer_mode = RemailerMode.DISABLED
service_user.remailer_overwrite_mode = RemailerMode.ENABLED_V1
self.assertEqual(service_user.email, remailer_email)
# 7. remailer setup + remailer enabled + user overwrite
self.app.config['REMAILER_LIMIT_TO_USERS'] = None
service.remailer_mode = RemailerMode.ENABLED_V1
service_user.remailer_overwrite_mode = RemailerMode.DISABLED
self.assertEqual(service_user.email, user.primary_email.address)
def test_filter_query_by_email(self):
service = Service.query.filter_by(name='service1').first()
user = self.get_user()
self.app.config['REMAILER_DOMAIN'] = 'remailer.example.com'
remailer_email_v1 = remailer.build_v1_address(service.id, user.id)
remailer_email_v2 = remailer.build_v2_address(service.id, user.id)
email1 = user.primary_email
email2 = UserEmail(user=user, address='test2@example.com', verified=True)
db.session.add(email2)
service_user = ServiceUser.query.get((service.id, user.id))
all_service_users = ServiceUser.query.all()
cases = itertools.product(
# Input values
[
'test@example.com',
'test2@example.com',
'other@example.com',
remailer_email_v1,
remailer_email_v2,
],
# REMAILER_DOMAIN config
[None, 'remailer.example.com'],
# REMAILER_LIMIT config
[None, ['testuser', 'otheruser'], ['testadmin', 'otheruser']],
# service.remailer_mode
[RemailerMode.DISABLED, RemailerMode.ENABLED_V1, RemailerMode.ENABLED_V2],
# service.enable_email_preferences
[True, False],
# service.limit_access, service.access_group
[(False, None), (True, None), (True, self.get_admin_group()), (True, self.get_users_group())],
# service_user.service_email
[None, email1, email2],
# service_user.remailer_overwrite_mode
[None, RemailerMode.DISABLED, RemailerMode.ENABLED_V1, RemailerMode.ENABLED_V2],
)
for options in cases:
value = options[0]
self.app.config['REMAILER_DOMAIN'] = options[1]
self.app.config['REMAILER_LIMIT_TO_USERS'] = options[2]
service.remailer_mode = options[3]
service.enable_email_preferences = options[4]
service.limit_access, service.access_group = options[5]
service_user.service_email = options[6]
service_user.remailer_overwrite_mode = options[7]
a = {result for result in all_service_users if result.email == value}
b = set(ServiceUser.filter_query_by_email(ServiceUser.query, value).all())
if a != b:
self.fail(f'{a} != {b} with ' + repr(options))
import unittest
import datetime
from uffd.database import db
from uffd.models.session import Session, USER_AGENT_PARSER_SUPPORTED
from tests.utils import UffdTestCase
class TestSession(UffdTestCase):
def test_expire(self):
self.app.config['SESSION_LIFETIME_SECONDS'] = 100
self.app.config['PERMANENT_SESSION_LIFETIME'] = 10
user = self.get_user()
def make_session(created_age, last_used_age):
return Session(
user=user,
created=datetime.datetime.utcnow() - datetime.timedelta(seconds=created_age),
last_used=datetime.datetime.utcnow() - datetime.timedelta(seconds=last_used_age),
)
session1 = Session(user=user)
self.assertFalse(session1.expired)
session2 = make_session(0, 0)
self.assertFalse(session2.expired)
session3 = make_session(50, 5)
self.assertFalse(session3.expired)
session4 = make_session(50, 15)
self.assertTrue(session4.expired)
session5 = make_session(105, 5)
self.assertTrue(session5.expired)
session6 = make_session(105, 15)
self.assertTrue(session6.expired)
db.session.add_all([session1, session2, session3, session4, session5, session6])
db.session.commit()
self.assertEqual(set(Session.query.filter_by(expired=False).all()), {session1, session2, session3})
self.assertEqual(set(Session.query.filter_by(expired=True).all()), {session4, session5, session6})
def test_useragent_ua_parser(self):
if not USER_AGENT_PARSER_SUPPORTED:
self.skipTest('ua_parser not available')
session = Session(user_agent='Mozilla/5.0 (Windows NT 6.1; Win64; x64; rv:47.0) Gecko/20100101 Firefox/47.0')
self.assertEqual(session.user_agent_browser, 'Firefox')
self.assertEqual(session.user_agent_platform, 'Windows')
def test_useragent_no_ua_parser(self):
session = Session(user_agent='Mozilla/5.0 (Windows NT 6.1; Win64; x64; rv:47.0) Gecko/20100101 Firefox/47.0')
session.DISABLE_USER_AGENT_PARSER = True
self.assertEqual(session.user_agent_browser, 'Firefox')
self.assertEqual(session.user_agent_platform, 'Windows')
import datetime
from uffd.database import db
from uffd.models import Signup, User, FeatureFlag
from tests.utils import UffdTestCase, db_flush
def refetch_signup(signup):
db.session.add(signup)
db.session.commit()
id = signup.id
db.session.expunge(signup)
return Signup.query.get(id)
# We assume in all tests that Signup.validate and Signup.password.verify do
# not alter any state
class TestSignupModel(UffdTestCase):
def assert_validate_valid(self, signup):
valid, msg = signup.validate()
self.assertTrue(valid)
self.assertIsInstance(msg, str)
def assert_validate_invalid(self, signup):
valid, msg = signup.validate()
self.assertFalse(valid)
self.assertIsInstance(msg, str)
self.assertNotEqual(msg, '')
def assert_finish_success(self, signup, password):
self.assertIsNone(signup.user)
user, msg = signup.finish(password)
db.session.commit()
self.assertIsNotNone(user)
self.assertIsInstance(msg, str)
self.assertIsNotNone(signup.user)
def assert_finish_failure(self, signup, password):
prev_id = signup.user_id
user, msg = signup.finish(password)
self.assertIsNone(user)
self.assertIsInstance(msg, str)
self.assertNotEqual(msg, '')
self.assertEqual(signup.user_id, prev_id)
def test_password(self):
signup = Signup(loginname='newuser', displayname='New User', mail='new@example.com')
self.assertFalse(signup.password.verify('notsecret'))
self.assertFalse(signup.password.verify(''))
self.assertFalse(signup.password.verify('wrongpassword'))
self.assertTrue(signup.set_password('notsecret'))
self.assertTrue(signup.password.verify('notsecret'))
self.assertFalse(signup.password.verify('wrongpassword'))
def test_expired(self):
# TODO: Find a better way to test this!
signup = Signup(loginname='newuser', displayname='New User', mail='new@example.com', password='notsecret')
self.assertFalse(signup.expired)
signup.created = created=datetime.datetime.utcnow() - datetime.timedelta(hours=49)
self.assertTrue(signup.expired)
def test_completed(self):
signup = Signup(loginname='newuser', displayname='New User', mail='new@example.com', password='notsecret')
self.assertFalse(signup.completed)
signup.finish('notsecret')
db.session.commit()
self.assertTrue(signup.completed)
signup = refetch_signup(signup)
self.assertTrue(signup.completed)
def test_validate(self):
signup = Signup(loginname='newuser', displayname='New User', mail='new@example.com', password='notsecret')
self.assert_validate_valid(signup)
self.assert_validate_valid(refetch_signup(signup))
def test_validate_completed(self):
signup = Signup(loginname='newuser', displayname='New User', mail='new@example.com', password='notsecret')
self.assert_finish_success(signup, 'notsecret')
self.assert_validate_invalid(signup)
self.assert_validate_invalid(refetch_signup(signup))
def test_validate_expired(self):
signup = Signup(loginname='newuser', displayname='New User', mail='new@example.com',
password='notsecret', created=datetime.datetime.utcnow()-datetime.timedelta(hours=49))
self.assert_validate_invalid(signup)
self.assert_validate_invalid(refetch_signup(signup))
def test_validate_loginname(self):
signup = Signup(loginname='', displayname='New User', mail='new@example.com', password='notsecret')
self.assert_validate_invalid(signup)
self.assert_validate_invalid(refetch_signup(signup))
def test_validate_displayname(self):
signup = Signup(loginname='newuser', displayname='', mail='new@example.com', password='notsecret')
self.assert_validate_invalid(signup)
self.assert_validate_invalid(refetch_signup(signup))
def test_validate_mail(self):
signup = Signup(loginname='newuser', displayname='New User', mail='', password='notsecret')
self.assert_validate_invalid(signup)
self.assert_validate_invalid(refetch_signup(signup))
def test_validate_password(self):
signup = Signup(loginname='newuser', displayname='New User', mail='new@example.com')
self.assertFalse(signup.set_password(''))
self.assert_validate_invalid(signup)
self.assert_validate_invalid(refetch_signup(signup))
def test_validate_exists(self):
signup = Signup(loginname='testuser', displayname='New User', mail='new@example.com', password='notsecret')
self.assert_validate_invalid(signup)
self.assert_validate_invalid(refetch_signup(signup))
def test_finish(self):
signup = Signup(loginname='newuser', displayname='New User', mail='new@example.com', password='notsecret')
self.assert_finish_success(signup, 'notsecret')
user = User.query.filter_by(loginname='newuser').one_or_none()
self.assertEqual(user.loginname, 'newuser')
self.assertEqual(user.displayname, 'New User')
self.assertEqual(user.primary_email.address, 'new@example.com')
def test_finish_completed(self):
signup = Signup(loginname='newuser', displayname='New User', mail='new@example.com', password='notsecret')
self.assert_finish_success(signup, 'notsecret')
self.assert_finish_failure(refetch_signup(signup), 'notsecret')
def test_finish_expired(self):
# TODO: Find a better way to test this!
signup = Signup(loginname='newuser', displayname='New User', mail='new@example.com',
password='notsecret', created=datetime.datetime.utcnow()-datetime.timedelta(hours=49))
self.assert_finish_failure(signup, 'notsecret')
self.assert_finish_failure(refetch_signup(signup), 'notsecret')
def test_finish_wrongpassword(self):
signup = Signup(loginname='newuser', displayname='New User', mail='new@example.com')
self.assert_finish_failure(signup, '')
self.assert_finish_failure(signup, 'wrongpassword')
signup = refetch_signup(signup)
self.assert_finish_failure(signup, '')
self.assert_finish_failure(signup, 'wrongpassword')
signup = Signup(loginname='newuser', displayname='New User', mail='new@example.com', password='notsecret')
self.assert_finish_failure(signup, 'wrongpassword')
self.assert_finish_failure(refetch_signup(signup), 'wrongpassword')
def test_finish_duplicate(self):
signup = Signup(loginname='testuser', displayname='New User', mail='new@example.com', password='notsecret')
self.assert_finish_failure(signup, 'notsecret')
self.assert_finish_failure(refetch_signup(signup), 'notsecret')
def test_finish_duplicate_email_strict_uniqueness(self):
FeatureFlag.unique_email_addresses.enable()
db.session.commit()
signup = Signup(loginname='newuser', displayname='New User', mail='test@example.com', password='notsecret')
self.assert_finish_failure(signup, 'notsecret')
def test_duplicate(self):
signup = Signup(loginname='newuser', displayname='New User', mail='test1@example.com', password='notsecret')
self.assert_validate_valid(signup)
db.session.add(signup)
db.session.commit()
signup1_id = signup.id
signup = Signup(loginname='newuser', displayname='New User', mail='test2@example.com', password='notsecret')
self.assert_validate_valid(signup)
db.session.add(signup)
db.session.commit()
signup2_id = signup.id
db_flush()
signup = Signup.query.get(signup2_id)
self.assert_finish_success(signup, 'notsecret')
db.session.commit()
db_flush()
signup = Signup.query.get(signup1_id)
self.assert_finish_failure(signup, 'notsecret')
user = User.query.filter_by(loginname='newuser').one_or_none()
self.assertEqual(user.primary_email.address, 'test2@example.com')
import datetime
import sqlalchemy
from uffd.database import db
from uffd.models import User, UserEmail, Group, FeatureFlag, IDAlreadyAllocatedError, IDRangeExhaustedError
from tests.utils import UffdTestCase, ModelTestCase
class TestUserModel(UffdTestCase):
def test_has_permission(self):
user_ = self.get_user() # has 'users' and 'uffd_access' group
admin = self.get_admin() # has 'users', 'uffd_access' and 'uffd_admin' group
self.assertTrue(user_.has_permission(None))
self.assertTrue(admin.has_permission(None))
self.assertTrue(user_.has_permission('users'))
self.assertTrue(admin.has_permission('users'))
self.assertFalse(user_.has_permission('notagroup'))
self.assertFalse(admin.has_permission('notagroup'))
self.assertFalse(user_.has_permission('uffd_admin'))
self.assertTrue(admin.has_permission('uffd_admin'))
self.assertFalse(user_.has_permission(['uffd_admin']))
self.assertTrue(admin.has_permission(['uffd_admin']))
self.assertFalse(user_.has_permission(['uffd_admin', 'notagroup']))
self.assertTrue(admin.has_permission(['uffd_admin', 'notagroup']))
self.assertFalse(user_.has_permission(['notagroup', 'uffd_admin']))
self.assertTrue(admin.has_permission(['notagroup', 'uffd_admin']))
self.assertTrue(user_.has_permission(['uffd_admin', 'users']))
self.assertTrue(admin.has_permission(['uffd_admin', 'users']))
self.assertTrue(user_.has_permission([['uffd_admin', 'users'], ['users', 'uffd_access']]))
self.assertTrue(admin.has_permission([['uffd_admin', 'users'], ['users', 'uffd_access']]))
self.assertFalse(user_.has_permission(['uffd_admin', ['users', 'notagroup']]))
self.assertTrue(admin.has_permission(['uffd_admin', ['users', 'notagroup']]))
def test_unix_uid_generation(self):
self.app.config['USER_MIN_UID'] = 10000
self.app.config['USER_MAX_UID'] = 18999
self.app.config['USER_SERVICE_MIN_UID'] = 19000
self.app.config['USER_SERVICE_MAX_UID'] = 19999
db.drop_all()
db.create_all()
user0 = User(loginname='user0', displayname='user0', primary_email_address='user0@example.com')
user1 = User(loginname='user1', displayname='user1', primary_email_address='user1@example.com')
user2 = User(loginname='user2', displayname='user2', primary_email_address='user2@example.com')
db.session.add_all([user0, user1, user2])
db.session.commit()
self.assertEqual(user0.unix_uid, 10000)
self.assertEqual(user1.unix_uid, 10001)
self.assertEqual(user2.unix_uid, 10002)
db.session.delete(user1)
db.session.commit()
user3 = User(loginname='user3', displayname='user3', primary_email_address='user3@example.com')
db.session.add(user3)
db.session.commit()
self.assertEqual(user3.unix_uid, 10003)
db.session.delete(user2)
db.session.commit()
user4 = User(loginname='user4', displayname='user4', primary_email_address='user4@example.com')
db.session.add(user4)
db.session.commit()
self.assertEqual(user4.unix_uid, 10004)
service0 = User(loginname='service0', displayname='service0', primary_email_address='service0@example.com', is_service_user=True)
service1 = User(loginname='service1', displayname='service1', primary_email_address='service1@example.com', is_service_user=True)
db.session.add_all([service0, service1])
db.session.commit()
self.assertEqual(service0.unix_uid, 19000)
self.assertEqual(service1.unix_uid, 19001)
def test_unix_uid_generation_overlapping(self):
self.app.config['USER_MIN_UID'] = 10000
self.app.config['USER_MAX_UID'] = 19999
self.app.config['USER_SERVICE_MIN_UID'] = 10000
self.app.config['USER_SERVICE_MAX_UID'] = 19999
db.drop_all()
db.create_all()
user0 = User(loginname='user0', displayname='user0', primary_email_address='user0@example.com')
service0 = User(loginname='service0', displayname='service0', primary_email_address='service0@example.com', is_service_user=True)
user1 = User(loginname='user1', displayname='user1', primary_email_address='user1@example.com')
db.session.add_all([user0, service0, user1])
db.session.commit()
self.assertEqual(user0.unix_uid, 10000)
self.assertEqual(service0.unix_uid, 10001)
self.assertEqual(user1.unix_uid, 10002)
def test_unix_uid_generation_overflow(self):
self.app.config['USER_MIN_UID'] = 10000
self.app.config['USER_MAX_UID'] = 10001
db.drop_all()
db.create_all()
user0 = User(loginname='user0', displayname='user0', primary_email_address='user0@example.com')
user1 = User(loginname='user1', displayname='user1', primary_email_address='user1@example.com')
db.session.add_all([user0, user1])
db.session.commit()
self.assertEqual(user0.unix_uid, 10000)
self.assertEqual(user1.unix_uid, 10001)
with self.assertRaises(sqlalchemy.exc.StatementError):
user2 = User(loginname='user2', displayname='user2', primary_email_address='user2@example.com')
db.session.add(user2)
db.session.commit()
def test_init_primary_email_address(self):
user = User(primary_email_address='foobar@example.com')
self.assertEqual(user.primary_email.address, 'foobar@example.com')
self.assertEqual(user.primary_email.verified, True)
self.assertEqual(user.primary_email.user, user)
user = User(primary_email_address='invalid')
self.assertEqual(user.primary_email.address, 'invalid')
self.assertEqual(user.primary_email.verified, True)
self.assertEqual(user.primary_email.user, user)
def test_set_primary_email_address(self):
user = User()
self.assertFalse(user.set_primary_email_address('invalid'))
self.assertIsNone(user.primary_email)
self.assertEqual(len(user.all_emails), 0)
self.assertTrue(user.set_primary_email_address('foobar@example.com'))
self.assertEqual(user.primary_email.address, 'foobar@example.com')
self.assertEqual(len(user.all_emails), 1)
self.assertFalse(user.set_primary_email_address('invalid'))
self.assertEqual(user.primary_email.address, 'foobar@example.com')
self.assertEqual(len(user.all_emails), 1)
self.assertTrue(user.set_primary_email_address('other@example.com'))
self.assertEqual(user.primary_email.address, 'other@example.com')
self.assertEqual(len(user.all_emails), 2)
self.assertEqual({user.all_emails[0].address, user.all_emails[1].address}, {'foobar@example.com', 'other@example.com'})
class TestUserEmailModel(UffdTestCase):
def test_normalize_address(self):
ref = UserEmail.normalize_address('foo@example.com')
self.assertEqual(ref, UserEmail.normalize_address('foo@example.com'))
self.assertEqual(ref, UserEmail.normalize_address('Foo@Example.Com'))
self.assertEqual(ref, UserEmail.normalize_address(' foo@example.com '))
self.assertNotEqual(ref, UserEmail.normalize_address('bar@example.com'))
self.assertNotEqual(ref, UserEmail.normalize_address('foo @example.com'))
# "No-Break Space" instead of SPACE (Unicode normalization + stripping)
self.assertEqual(ref, UserEmail.normalize_address('\u00A0foo@example.com '))
# Pre-composed "Angstrom Sign" vs. "A" + "Combining Ring Above" (Unicode normalization)
self.assertEqual(UserEmail.normalize_address('\u212B@example.com'), UserEmail.normalize_address('A\u030A@example.com'))
def test_address(self):
email = UserEmail()
self.assertIsNone(email.address)
self.assertIsNone(email.address_normalized)
email.address = 'Foo@example.com'
self.assertEqual(email.address, 'Foo@example.com')
self.assertEqual(email.address_normalized, UserEmail.normalize_address('Foo@example.com'))
with self.assertRaises(ValueError):
email.address = 'bar@example.com'
with self.assertRaises(ValueError):
email.address = None
def test_set_address(self):
email = UserEmail()
self.assertFalse(email.set_address('invalid'))
self.assertIsNone(email.address)
self.assertFalse(email.set_address(''))
self.assertFalse(email.set_address('@'))
self.app.config['REMAILER_DOMAIN'] = 'remailer.example.com'
self.assertFalse(email.set_address('foobar@remailer.example.com'))
self.assertFalse(email.set_address('v1-1-testuser@remailer.example.com'))
self.assertFalse(email.set_address('v1-1-testuser @ remailer.example.com'))
self.assertFalse(email.set_address('v1-1-testuser@REMAILER.example.com'))
self.assertFalse(email.set_address('v1-1-testuser@foobar@remailer.example.com'))
self.assertTrue(email.set_address('foobar@example.com'))
self.assertEqual(email.address, 'foobar@example.com')
def test_verified(self):
email = UserEmail(user=self.get_user(), address='foo@example.com')
db.session.add(email)
self.assertEqual(email.verified, False)
self.assertEqual(UserEmail.query.filter_by(address='foo@example.com', verified=True).count(), 0)
self.assertEqual(UserEmail.query.filter_by(address='foo@example.com', verified=False).count(), 1)
email.verified = True
self.assertEqual(email.verified, True)
self.assertEqual(UserEmail.query.filter_by(address='foo@example.com', verified=True).count(), 1)
self.assertEqual(UserEmail.query.filter_by(address='foo@example.com', verified=False).count(), 0)
with self.assertRaises(ValueError):
email.verified = False
self.assertEqual(email.verified, True)
with self.assertRaises(ValueError):
email.verified = None
self.assertEqual(email.verified, True)
def test_verification(self):
email = UserEmail(address='foo@example.com')
self.assertFalse(email.finish_verification('test'))
secret = email.start_verification()
self.assertTrue(email.verification_secret)
self.assertTrue(email.verification_secret.verify(secret))
self.assertFalse(email.verification_expired)
self.assertFalse(email.finish_verification('test'))
orig_expires = email.verification_expires
email.verification_expires = datetime.datetime.utcnow() - datetime.timedelta(days=1)
self.assertFalse(email.finish_verification(secret))
email.verification_expires = orig_expires
self.assertTrue(email.finish_verification(secret))
self.assertFalse(email.verification_secret)
self.assertTrue(email.verification_expired)
def test_enable_strict_constraints(self):
email = UserEmail(address='foo@example.com', user=self.get_user())
db.session.add(email)
db.session.commit()
self.assertIsNone(email.enable_strict_constraints)
FeatureFlag.unique_email_addresses.enable()
self.assertTrue(email.enable_strict_constraints)
FeatureFlag.unique_email_addresses.disable()
self.assertIsNone(email.enable_strict_constraints)
def assert_can_add_address(self, **kwargs):
user_email = UserEmail(**kwargs)
db.session.add(user_email)
db.session.commit()
db.session.delete(user_email)
db.session.commit()
def assert_cannot_add_address(self, **kwargs):
with self.assertRaises(sqlalchemy.exc.IntegrityError):
db.session.add(UserEmail(**kwargs))
db.session.commit()
db.session.rollback()
def test_unique_constraints_old(self):
# The same user cannot add the same exact address multiple times, but
# different users can have the same address
user = self.get_user()
admin = self.get_admin()
db.session.add(UserEmail(user=user, address='foo@example.com'))
db.session.add(UserEmail(user=user, address='bar@example.com', verified=True))
db.session.commit()
self.assert_can_add_address(user=user, address='foobar@example.com')
self.assert_can_add_address(user=user, address='foobar@example.com', verified=True)
self.assert_cannot_add_address(user=user, address='foo@example.com')
self.assert_can_add_address(user=user, address='FOO@example.com')
self.assert_cannot_add_address(user=user, address='bar@example.com')
self.assert_can_add_address(user=user, address='BAR@example.com')
self.assert_cannot_add_address(user=user, address='foo@example.com', verified=True)
self.assert_can_add_address(user=user, address='FOO@example.com', verified=True)
self.assert_cannot_add_address(user=user, address='bar@example.com', verified=True)
self.assert_can_add_address(user=user, address='BAR@example.com', verified=True)
self.assert_can_add_address(user=admin, address='foobar@example.com')
self.assert_can_add_address(user=admin, address='foobar@example.com', verified=True)
self.assert_can_add_address(user=admin, address='foo@example.com')
self.assert_can_add_address(user=admin, address='FOO@example.com')
self.assert_can_add_address(user=admin, address='bar@example.com')
self.assert_can_add_address(user=admin, address='BAR@example.com')
self.assert_can_add_address(user=admin, address='foo@example.com', verified=True)
self.assert_can_add_address(user=admin, address='FOO@example.com', verified=True)
self.assert_can_add_address(user=admin, address='bar@example.com', verified=True)
self.assert_can_add_address(user=admin, address='BAR@example.com', verified=True)
def test_unique_constraints_strict(self):
FeatureFlag.unique_email_addresses.enable()
# The same user cannot add the same (normalized) address multiple times,
# and different users cannot have the same verified (normalized) address
user = self.get_user()
admin = self.get_admin()
db.session.add(UserEmail(user=user, address='foo@example.com'))
db.session.add(UserEmail(user=user, address='bar@example.com', verified=True))
db.session.commit()
self.assert_can_add_address(user=user, address='foobar@example.com')
self.assert_can_add_address(user=user, address='foobar@example.com', verified=True)
self.assert_cannot_add_address(user=user, address='foo@example.com')
self.assert_cannot_add_address(user=user, address='FOO@example.com')
self.assert_cannot_add_address(user=user, address='bar@example.com')
self.assert_cannot_add_address(user=user, address='BAR@example.com')
self.assert_cannot_add_address(user=user, address='foo@example.com', verified=True)
self.assert_cannot_add_address(user=user, address='FOO@example.com', verified=True)
self.assert_cannot_add_address(user=user, address='bar@example.com', verified=True)
self.assert_cannot_add_address(user=user, address='BAR@example.com', verified=True)
self.assert_can_add_address(user=admin, address='foobar@example.com')
self.assert_can_add_address(user=admin, address='foobar@example.com', verified=True)
self.assert_can_add_address(user=admin, address='foo@example.com')
self.assert_can_add_address(user=admin, address='FOO@example.com')
self.assert_can_add_address(user=admin, address='bar@example.com')
self.assert_can_add_address(user=admin, address='BAR@example.com')
self.assert_can_add_address(user=admin, address='foo@example.com', verified=True)
self.assert_can_add_address(user=admin, address='FOO@example.com', verified=True)
self.assert_cannot_add_address(user=admin, address='bar@example.com', verified=True)
self.assert_cannot_add_address(user=admin, address='BAR@example.com', verified=True)
class TestIDAllocator(ModelTestCase):
def allocate_gids(self, *gids):
for gid in gids:
Group.unix_gid_allocator.allocate(gid)
def fetch_gid_allocations(self):
return [row[0] for row in db.session.execute(
db.select([Group.unix_gid_allocator.allocation_table])
.order_by(Group.unix_gid_allocator.allocation_table.c.id)
).fetchall()]
def test_empty(self):
self.assertEqual(Group.unix_gid_allocator.auto(20000, 20005), 20000)
self.assertEqual(self.fetch_gid_allocations(), [20000])
def test_first(self):
self.allocate_gids(20000)
self.assertEqual(Group.unix_gid_allocator.auto(20000, 20005), 20001)
self.assertEqual(self.fetch_gid_allocations(), [20000, 20001])
def test_out_of_range_before(self):
self.allocate_gids(19998)
self.assertEqual(Group.unix_gid_allocator.auto(20000, 20005), 20000)
self.assertEqual(self.fetch_gid_allocations(), [19998, 20000])
def test_out_of_range_right_before(self):
self.allocate_gids(19999)
self.assertEqual(Group.unix_gid_allocator.auto(20000, 20005), 20000)
self.assertEqual(self.fetch_gid_allocations(), [19999, 20000])
def test_out_of_range_after(self):
self.allocate_gids(20006)
self.assertEqual(Group.unix_gid_allocator.auto(20000, 20005), 20000)
self.assertEqual(self.fetch_gid_allocations(), [20000, 20006])
def test_gap_at_beginning(self):
self.allocate_gids(20001)
self.assertEqual(Group.unix_gid_allocator.auto(20000, 20005), 20000)
self.assertEqual(self.fetch_gid_allocations(), [20000, 20001])
def test_multiple_gaps(self):
self.allocate_gids(20000, 20001, 20003, 20005)
self.assertEqual(Group.unix_gid_allocator.auto(20000, 20005), 20002)
self.assertEqual(self.fetch_gid_allocations(), [20000, 20001, 20002, 20003, 20005])
self.assertEqual(Group.unix_gid_allocator.auto(20000, 20005), 20004)
self.assertEqual(self.fetch_gid_allocations(), [20000, 20001, 20002, 20003, 20004, 20005])
def test_last(self):
self.allocate_gids(20000, 20001, 20002, 20003, 20004)
self.assertEqual(Group.unix_gid_allocator.auto(20000, 20005), 20005)
self.assertEqual(self.fetch_gid_allocations(), [20000, 20001, 20002, 20003, 20004, 20005])
def test_overflow(self):
self.allocate_gids(20000, 20001, 20002, 20003, 20004, 20005)
with self.assertRaises(IDRangeExhaustedError):
Group.unix_gid_allocator.auto(20000, 20005)
self.assertEqual(self.fetch_gid_allocations(), [20000, 20001, 20002, 20003, 20004, 20005])
def test_conflict(self):
self.allocate_gids(20000)
with self.assertRaises(IDAlreadyAllocatedError):
self.allocate_gids(20000)
self.assertEqual(self.fetch_gid_allocations(), [20000])
class TestGroup(ModelTestCase):
def test_unix_gid_generation(self):
self.app.config['GROUP_MIN_GID'] = 20000
self.app.config['GROUP_MAX_GID'] = 49999
group0 = Group(name='group0', description='group0')
group1 = Group(name='group1', description='group1')
group2 = Group(name='group2', description='group2')
group3 = Group(name='group3', description='group3', unix_gid=20004)
db.session.add_all([group0, group1, group2, group3])
db.session.commit()
self.assertEqual(group0.unix_gid, 20000)
self.assertEqual(group1.unix_gid, 20001)
self.assertEqual(group2.unix_gid, 20002)
self.assertEqual(group3.unix_gid, 20004)
db.session.delete(group2)
db.session.commit()
group4 = Group(name='group4', description='group4')
group5 = Group(name='group5', description='group5')
db.session.add_all([group4, group5])
db.session.commit()
self.assertEqual(group4.unix_gid, 20003)
self.assertEqual(group5.unix_gid, 20005)
def test_unix_gid_generation_conflict(self):
self.app.config['GROUP_MIN_GID'] = 20000
self.app.config['GROUP_MAX_GID'] = 49999
group0 = Group(name='group0', description='group0', unix_gid=20023)
db.session.add(group0)
db.session.commit()
with self.assertRaises(IDAlreadyAllocatedError):
Group(name='group1', description='group1', unix_gid=20023)
def test_unix_gid_generation_overflow(self):
self.app.config['GROUP_MIN_GID'] = 20000
self.app.config['GROUP_MAX_GID'] = 20001
group0 = Group(name='group0', description='group0')
group1 = Group(name='group1', description='group1')
db.session.add_all([group0, group1])
db.session.commit()
self.assertEqual(group0.unix_gid, 20000)
self.assertEqual(group1.unix_gid, 20001)
db.session.commit()
with self.assertRaises(sqlalchemy.exc.StatementError):
group2 = Group(name='group2', description='group2')
db.session.add(group2)
db.session.commit()
......@@ -2,7 +2,7 @@ import unittest
from flask import Flask, Blueprint, session, url_for
from uffd.csrf import csrf_bp, csrf_protect
from uffd.csrf import bp as csrf_bp, csrf_protect
uid_counter = 0
......
import unittest
import datetime
import time
from flask import url_for, session
# These imports are required, because otherwise we get circular imports?!
from uffd import ldap, user
from uffd.user.models import User
from uffd.session.views import get_current_user, is_valid_session
from uffd.mfa.models import MFAMethod, MFAType, RecoveryCodeMethod, TOTPMethod, WebauthnMethod, _hotp
from uffd import create_app, db
from utils import dump, UffdTestCase
class TestMfaPrimitives(unittest.TestCase):
def test_hotp(self):
self.assertEqual(_hotp(5555555, b'\xae\xa3T\x05\x89\xd6\xb76\xf61r\x92\xcc\xb5WZ\xe6)\x05q'), '458290')
self.assertEqual(_hotp(5555555, b'\xae\xa3T\x05\x89\xd6\xb76\xf61r\x92\xcc\xb5WZ\xe6)\x05q', digits=8), '20458290')
for digits in range(1, 10):
self.assertEqual(len(_hotp(1, b'abcd', digits=digits)), digits)
self.assertEqual(_hotp(1234, b''), '161024')
self.assertEqual(_hotp(0, b'\x04\x8fM\xcc\x7f\x82\x9c$a\x1b\xb3'), '279354')
self.assertEqual(_hotp(2**64-1, b'abcde'), '899292')
def get_user():
return User.query.get('uid=testuser,ou=users,dc=example,dc=com')
def get_admin():
return User.query.get('uid=testadmin,ou=users,dc=example,dc=com')
def get_fido2_test_cred():
try:
from fido2.ctap2 import AttestedCredentialData
except ImportError:
self.skipTest('fido2 could not be imported')
# Example public key from webauthn spec 6.5.1.1
return AttestedCredentialData(bytes.fromhex('00000000000000000000000000000000'+'0040'+'053cbcc9d37a61d3bac87cdcc77ee326256def08ab15775d3a720332e4101d14fae95aeee3bc9698781812e143c0597dc6e180595683d501891e9dd030454c0a'+'A501020326200121582065eda5a12577c2bae829437fe338701a10aaa375e1bb5b5de108de439c08551d2258201e52ed75701163f7f9e40ddf9f341b3dc9ba860af7e0ca7ca7e9eecd0084d19c'))
class TestMfaMethodModels(UffdTestCase):
def test_common_attributes(self):
method = MFAMethod(user=get_user(), name='testname')
self.assertTrue(method.created <= datetime.datetime.now())
self.assertEqual(method.name, 'testname')
self.assertEqual(method.user.loginname, 'testuser')
method.user = get_admin()
self.assertEqual(method.user.loginname, 'testadmin')
def test_recovery_code_method(self):
method = RecoveryCodeMethod(user=get_user())
db.session.add(method)
db.session.commit()
db.session = db.create_scoped_session() # Ensure the next query does not return the cached method object
_method = RecoveryCodeMethod.query.get(method.id)
self.assertFalse(hasattr(_method, 'code'))
self.assertFalse(_method.verify(''))
self.assertFalse(_method.verify('A'*8))
self.assertTrue(_method.verify(method.code))
def test_totp_method_attributes(self):
method = TOTPMethod(user=get_user(), name='testname')
self.assertEqual(method.name, 'testname')
# Restore method with key parameter
_method = TOTPMethod(user=get_user(), key=method.key, name='testname')
self.assertEqual(_method.name, 'testname')
self.assertEqual(method.raw_key, _method.raw_key)
self.assertEqual(method.issuer, _method.issuer)
self.assertEqual(method.accountname, _method.accountname)
self.assertEqual(method.key_uri, _method.key_uri)
db.session.add(method)
db.session.commit()
db.session = db.create_scoped_session() # Ensure the next query does not return the cached method object
# Restore method from db
_method = TOTPMethod.query.get(method.id)
self.assertEqual(_method.name, 'testname')
self.assertEqual(method.raw_key, _method.raw_key)
self.assertEqual(method.issuer, _method.issuer)
self.assertEqual(method.accountname, _method.accountname)
self.assertEqual(method.key_uri, _method.key_uri)
def test_totp_method_verify(self):
method = TOTPMethod(user=get_user())
counter = int(time.time()/30)
self.assertFalse(method.verify(''))
self.assertFalse(method.verify(_hotp(counter-2, method.raw_key)))
self.assertTrue(method.verify(_hotp(counter, method.raw_key)))
self.assertFalse(method.verify(_hotp(counter+2, method.raw_key)))
def test_webauthn_method(self):
data = get_fido2_test_cred()
method = WebauthnMethod(user=get_user(), cred=data, name='testname')
self.assertEqual(method.name, 'testname')
db.session.add(method)
db.session.commit()
db.session = db.create_scoped_session() # Ensure the next query does not return the cached method object
_method = WebauthnMethod.query.get(method.id)
self.assertEqual(_method.name, 'testname')
self.assertEqual(bytes(method.cred), bytes(_method.cred))
self.assertEqual(data.credential_id, _method.cred.credential_id)
self.assertEqual(data.public_key, _method.cred.public_key)
# We only test (de-)serialization here, as everything else is currently implemented in the views
class TestMfaViews(UffdTestCase):
def setUp(self):
super().setUp()
db.session.add(RecoveryCodeMethod(user=get_admin()))
db.session.add(TOTPMethod(user=get_admin(), name='Admin Phone'))
# We don't want to skip all tests only because fido2 is not installed!
#db.session.add(WebauthnMethod(user=get_admin(), cred=get_fido2_test_cred(), name='Admin FIDO2 dongle'))
db.session.commit()
def login(self):
self.client.post(path=url_for('session.login'),
data={'loginname': 'testuser', 'password': 'userpassword'}, follow_redirects=True)
def add_recovery_codes(self, count=10):
user = get_user()
for _ in range(count):
db.session.add(RecoveryCodeMethod(user=user))
db.session.commit()
def add_totp(self):
db.session.add(TOTPMethod(user=get_user(), name='My phone'))
db.session.commit()
def add_webauthn(self):
db.session.add(WebauthnMethod(user=get_user(), cred=get_fido2_test_cred(), name='My FIDO2 dongle'))
db.session.commit()
def test_setup_disabled(self):
self.login()
r = self.client.get(path=url_for('mfa.setup'), follow_redirects=True)
dump('mfa_setup_disabled', r)
self.assertEqual(r.status_code, 200)
def test_setup_recovery_codes(self):
self.login()
self.add_recovery_codes()
r = self.client.get(path=url_for('mfa.setup'), follow_redirects=True)
dump('mfa_setup_only_recovery_codes', r)
self.assertEqual(r.status_code, 200)
def test_setup_enabled(self):
self.login()
self.add_recovery_codes()
self.add_totp()
self.add_webauthn()
r = self.client.get(path=url_for('mfa.setup'), follow_redirects=True)
dump('mfa_setup_enabled', r)
self.assertEqual(r.status_code, 200)
def test_setup_few_recovery_codes(self):
self.login()
self.add_totp()
self.add_recovery_codes(1)
r = self.client.get(path=url_for('mfa.setup'), follow_redirects=True)
dump('mfa_setup_few_recovery_codes', r)
self.assertEqual(r.status_code, 200)
def test_setup_no_recovery_codes(self):
self.login()
self.add_totp()
r = self.client.get(path=url_for('mfa.setup'), follow_redirects=True)
dump('mfa_setup_no_recovery_codes', r)
self.assertEqual(r.status_code, 200)
def test_disable(self):
self.login()
self.add_recovery_codes()
self.add_totp()
admin_methods = len(MFAMethod.query.filter_by(dn=get_admin().dn).all())
r = self.client.get(path=url_for('mfa.disable'), follow_redirects=True)
dump('mfa_disable', r)
self.assertEqual(r.status_code, 200)
r = self.client.post(path=url_for('mfa.disable_confirm'), follow_redirects=True)
dump('mfa_disable_submit', r)
self.assertEqual(r.status_code, 200)
self.assertEqual(len(MFAMethod.query.filter_by(dn=get_current_user().dn).all()), 0)
self.assertEqual(len(MFAMethod.query.filter_by(dn=get_admin().dn).all()), admin_methods)
def test_disable_recovery_only(self):
self.login()
self.add_recovery_codes()
admin_methods = len(MFAMethod.query.filter_by(dn=get_admin().dn).all())
self.assertNotEqual(len(MFAMethod.query.filter_by(dn=get_current_user().dn).all()), 0)
r = self.client.get(path=url_for('mfa.disable'), follow_redirects=True)
dump('mfa_disable_recovery_only', r)
self.assertEqual(r.status_code, 200)
r = self.client.post(path=url_for('mfa.disable_confirm'), follow_redirects=True)
dump('mfa_disable_recovery_only_submit', r)
self.assertEqual(r.status_code, 200)
self.assertEqual(len(MFAMethod.query.filter_by(dn=get_current_user().dn).all()), 0)
self.assertEqual(len(MFAMethod.query.filter_by(dn=get_admin().dn).all()), admin_methods)
def test_admin_disable(self):
for method in MFAMethod.query.filter_by(dn=get_admin().dn).all():
if not isinstance(method, RecoveryCodeMethod):
db.session.delete(method)
db.session.commit()
self.add_recovery_codes()
self.add_totp()
self.client.post(path=url_for('session.login'),
data={'loginname': 'testadmin', 'password': 'adminpassword'}, follow_redirects=True)
self.assertTrue(is_valid_session())
admin_methods = len(MFAMethod.query.filter_by(dn=get_admin().dn).all())
r = self.client.get(path=url_for('mfa.admin_disable', uid=get_user().uid), follow_redirects=True)
dump('mfa_admin_disable', r)
self.assertEqual(r.status_code, 200)
self.assertEqual(len(MFAMethod.query.filter_by(dn=get_user().dn).all()), 0)
self.assertEqual(len(MFAMethod.query.filter_by(dn=get_admin().dn).all()), admin_methods)
def test_setup_recovery(self):
self.login()
self.assertEqual(len(RecoveryCodeMethod.query.filter_by(dn=get_current_user().dn).all()), 0)
r = self.client.post(path=url_for('mfa.setup_recovery'), follow_redirects=True)
dump('mfa_setup_recovery', r)
self.assertEqual(r.status_code, 200)
methods = RecoveryCodeMethod.query.filter_by(dn=get_current_user().dn).all()
self.assertNotEqual(len(methods), 0)
r = self.client.post(path=url_for('mfa.setup_recovery'), follow_redirects=True)
dump('mfa_setup_recovery_reset', r)
self.assertEqual(r.status_code, 200)
self.assertEqual(len(RecoveryCodeMethod.query.filter_by(id=methods[0].id).all()), 0)
self.assertNotEqual(len(methods), 0)
def test_setup_totp(self):
self.login()
self.add_recovery_codes()
r = self.client.get(path=url_for('mfa.setup_totp', name='My TOTP Authenticator'), follow_redirects=True)
dump('mfa_setup_totp', r)
self.assertEqual(r.status_code, 200)
self.assertNotEqual(len(session.get('mfa_totp_key', '')), 0)
def test_setup_totp_without_recovery(self):
self.login()
r = self.client.get(path=url_for('mfa.setup_totp', name='My TOTP Authenticator'), follow_redirects=True)
dump('mfa_setup_totp_without_recovery', r)
self.assertEqual(r.status_code, 200)
def test_setup_totp_finish(self):
self.login()
self.add_recovery_codes()
self.assertEqual(len(TOTPMethod.query.filter_by(dn=get_current_user().dn).all()), 0)
r = self.client.get(path=url_for('mfa.setup_totp', name='My TOTP Authenticator'), follow_redirects=True)
method = TOTPMethod(get_current_user(), key=session.get('mfa_totp_key', ''))
code = _hotp(int(time.time()/30), method.raw_key)
r = self.client.post(path=url_for('mfa.setup_totp_finish', name='My TOTP Authenticator'), data={'code': code}, follow_redirects=True)
dump('mfa_setup_totp_finish', r)
self.assertEqual(r.status_code, 200)
self.assertEqual(len(TOTPMethod.query.filter_by(dn=get_current_user().dn).all()), 1)
def test_setup_totp_finish_without_recovery(self):
self.login()
self.assertEqual(len(TOTPMethod.query.filter_by(dn=get_current_user().dn).all()), 0)
r = self.client.get(path=url_for('mfa.setup_totp', name='My TOTP Authenticator'), follow_redirects=True)
method = TOTPMethod(get_current_user(), key=session.get('mfa_totp_key', ''))
code = _hotp(int(time.time()/30), method.raw_key)
r = self.client.post(path=url_for('mfa.setup_totp_finish', name='My TOTP Authenticator'), data={'code': code}, follow_redirects=True)
dump('mfa_setup_totp_finish_without_recovery', r)
self.assertEqual(r.status_code, 200)
self.assertEqual(len(TOTPMethod.query.filter_by(dn=get_current_user().dn).all()), 0)
def test_setup_totp_finish_wrong_code(self):
self.login()
self.add_recovery_codes()
self.assertEqual(len(TOTPMethod.query.filter_by(dn=get_current_user().dn).all()), 0)
r = self.client.get(path=url_for('mfa.setup_totp', name='My TOTP Authenticator'), follow_redirects=True)
method = TOTPMethod(get_current_user(), key=session.get('mfa_totp_key', ''))
code = _hotp(int(time.time()/30), method.raw_key)
code = str(int(code[0])+1)[-1] + code[1:]
r = self.client.post(path=url_for('mfa.setup_totp_finish', name='My TOTP Authenticator'), data={'code': code}, follow_redirects=True)
dump('mfa_setup_totp_finish_wrong_code', r)
self.assertEqual(r.status_code, 200)
self.assertEqual(len(TOTPMethod.query.filter_by(dn=get_current_user().dn).all()), 0)
def test_setup_totp_finish_empty_code(self):
self.login()
self.add_recovery_codes()
self.assertEqual(len(TOTPMethod.query.filter_by(dn=get_current_user().dn).all()), 0)
r = self.client.get(path=url_for('mfa.setup_totp', name='My TOTP Authenticator'), follow_redirects=True)
r = self.client.post(path=url_for('mfa.setup_totp_finish', name='My TOTP Authenticator'), data={'code': ''}, follow_redirects=True)
dump('mfa_setup_totp_finish_empty_code', r)
self.assertEqual(r.status_code, 200)
self.assertEqual(len(TOTPMethod.query.filter_by(dn=get_current_user().dn).all()), 0)
def test_delete_totp(self):
self.login()
self.add_recovery_codes()
self.add_totp()
method = TOTPMethod(get_current_user(), name='test')
db.session.add(method)
db.session.commit()
self.assertEqual(len(TOTPMethod.query.filter_by(dn=get_current_user().dn).all()), 2)
r = self.client.get(path=url_for('mfa.delete_totp', id=method.id), follow_redirects=True)
dump('mfa_delete_totp', r)
self.assertEqual(r.status_code, 200)
self.assertEqual(len(TOTPMethod.query.filter_by(id=method.id).all()), 0)
self.assertEqual(len(TOTPMethod.query.filter_by(dn=get_current_user().dn).all()), 1)
# TODO: webauthn setup tests
def test_auth_integration(self):
self.add_recovery_codes()
self.add_totp()
db.session.commit()
self.assertFalse(is_valid_session())
r = self.client.post(path=url_for('session.login'),
data={'loginname': 'testuser', 'password': 'userpassword'}, follow_redirects=True)
dump('mfa_auth_redirected', r)
self.assertEqual(r.status_code, 200)
self.assertIn(b'/mfa/auth', r.data)
self.assertFalse(is_valid_session())
r = self.client.get(path=url_for('mfa.auth'), follow_redirects=False)
dump('mfa_auth', r)
self.assertEqual(r.status_code, 200)
self.assertFalse(is_valid_session())
def test_auth_disabled(self):
self.assertFalse(is_valid_session())
r = self.client.post(path=url_for('session.login'),
data={'loginname': 'testuser', 'password': 'userpassword'}, follow_redirects=False)
r = self.client.get(path=url_for('mfa.auth', ref='/redirecttarget'), follow_redirects=False)
self.assertEqual(r.status_code, 302)
self.assertTrue(r.location.endswith('/redirecttarget'))
self.assertTrue(is_valid_session())
def test_auth_recovery_only(self):
self.add_recovery_codes()
self.assertFalse(is_valid_session())
r = self.client.post(path=url_for('session.login'),
data={'loginname': 'testuser', 'password': 'userpassword'}, follow_redirects=False)
r = self.client.get(path=url_for('mfa.auth', ref='/redirecttarget'), follow_redirects=False)
self.assertEqual(r.status_code, 302)
self.assertTrue(r.location.endswith('/redirecttarget'))
self.assertTrue(is_valid_session())
def test_auth_recovery_code(self):
self.add_recovery_codes()
self.add_totp()
method = RecoveryCodeMethod(user=get_user())
db.session.add(method)
db.session.commit()
method_id = method.id
self.login()
r = self.client.get(path=url_for('mfa.auth'), follow_redirects=False)
dump('mfa_auth_recovery_code', r)
self.assertEqual(r.status_code, 200)
self.assertFalse(is_valid_session())
r = self.client.post(path=url_for('mfa.auth_finish', ref='/redirecttarget'), data={'code': method.code})
self.assertEqual(r.status_code, 302)
self.assertTrue(r.location.endswith('/redirecttarget'))
self.assertTrue(is_valid_session())
self.assertEqual(len(RecoveryCodeMethod.query.filter_by(id=method_id).all()), 0)
def test_auth_totp_code(self):
self.add_recovery_codes()
self.add_totp()
method = TOTPMethod(user=get_user(), name='testname')
raw_key = method.raw_key
db.session.add(method)
db.session.commit()
self.login()
r = self.client.get(path=url_for('mfa.auth'), follow_redirects=False)
dump('mfa_auth_totp_code', r)
self.assertEqual(r.status_code, 200)
self.assertFalse(is_valid_session())
code = _hotp(int(time.time()/30), raw_key)
r = self.client.post(path=url_for('mfa.auth_finish'), data={'code': code}, follow_redirects=True)
dump('mfa_auth_totp_code_submit', r)
self.assertEqual(r.status_code, 200)
self.assertTrue(is_valid_session())
def test_auth_empty_code(self):
self.add_recovery_codes()
self.add_totp()
self.login()
r = self.client.get(path=url_for('mfa.auth'), follow_redirects=False)
self.assertEqual(r.status_code, 200)
self.assertFalse(is_valid_session())
r = self.client.post(path=url_for('mfa.auth_finish'), data={'code': ''}, follow_redirects=True)
dump('mfa_auth_empty_code', r)
self.assertEqual(r.status_code, 200)
self.assertFalse(is_valid_session())
def test_auth_invalid_code(self):
self.add_recovery_codes()
self.add_totp()
method = TOTPMethod(user=get_user(), name='testname')
raw_key = method.raw_key
db.session.add(method)
db.session.commit()
self.login()
r = self.client.get(path=url_for('mfa.auth'), follow_redirects=False)
self.assertEqual(r.status_code, 200)
self.assertFalse(is_valid_session())
code = _hotp(int(time.time()/30), raw_key)
code = str(int(code[0])+1)[-1] + code[1:]
r = self.client.post(path=url_for('mfa.auth_finish'), data={'code': code}, follow_redirects=True)
dump('mfa_auth_invalid_code', r)
self.assertEqual(r.status_code, 200)
self.assertFalse(is_valid_session())
def test_auth_ratelimit(self):
self.add_recovery_codes()
self.add_totp()
method = TOTPMethod(user=get_user(), name='testname')
raw_key = method.raw_key
db.session.add(method)
db.session.commit()
self.login()
self.assertFalse(is_valid_session())
code = _hotp(int(time.time()/30), raw_key)
inv_code = str(int(code[0])+1)[-1] + code[1:]
for i in range(20):
r = self.client.post(path=url_for('mfa.auth_finish'), data={'code': inv_code}, follow_redirects=True)
self.assertEqual(r.status_code, 200)
self.assertFalse(is_valid_session())
r = self.client.post(path=url_for('mfa.auth_finish'), data={'code': code}, follow_redirects=True)
dump('mfa_auth_ratelimit', r)
self.assertEqual(r.status_code, 200)
self.assertFalse(is_valid_session())
# TODO: webauthn auth tests
class TestMfaViewsOL(TestMfaViews):
use_openldap = True
import datetime
from urllib.parse import urlparse, parse_qs
from flask import url_for
# These imports are required, because otherwise we get circular imports?!
from uffd import ldap, user
from uffd.session.views import get_current_user
from uffd.user.models import User
from uffd.oauth2.models import OAuth2Client
from uffd import create_app, db, ldap
from utils import dump, UffdTestCase
def get_user():
return User.query.get('uid=testuser,ou=users,dc=example,dc=com')
def get_admin():
return User.query.get('uid=testadmin,ou=users,dc=example,dc=com')
class TestOAuth2Client(UffdTestCase):
def setUpApp(self):
self.app.config['OAUTH2_CLIENTS'] = {
'test': {'client_secret': 'testsecret', 'redirect_uris': ['http://localhost:5009/callback', 'http://localhost:5009/callback2']},
'test1': {'client_secret': 'testsecret1', 'redirect_uris': ['http://localhost:5008/callback'], 'required_group': 'users'},
}
def test_from_id(self):
client = OAuth2Client.from_id('test')
self.assertEqual(client.client_id, 'test')
self.assertEqual(client.client_secret, 'testsecret')
self.assertEqual(client.redirect_uris, ['http://localhost:5009/callback', 'http://localhost:5009/callback2'])
self.assertEqual(client.default_redirect_uri, 'http://localhost:5009/callback')
self.assertEqual(client.default_scopes, ['profile'])
self.assertEqual(client.client_type, 'confidential')
client = OAuth2Client.from_id('test1')
self.assertEqual(client.client_id, 'test1')
self.assertEqual(client.required_group, 'users')
def test_access_allowed(self):
user = get_user() # has 'users' and 'uffd_access' group
admin = get_admin() # has 'users', 'uffd_access' and 'uffd_admin' group
client = OAuth2Client('test', '', [''], ['uffd_admin', ['users', 'notagroup']])
self.assertFalse(client.access_allowed(user))
self.assertTrue(client.access_allowed(admin))
# More required_group values are tested by TestUserModel.test_has_permission
class TestViews(UffdTestCase):
def setUpApp(self):
self.app.config['OAUTH2_CLIENTS'] = {
'test': {'client_secret': 'testsecret', 'redirect_uris': ['http://localhost:5009/callback', 'http://localhost:5009/callback2']},
'test1': {'client_secret': 'testsecret1', 'redirect_uris': ['http://localhost:5008/callback'], 'required_group': 'uffd_admin'},
}
def test_authorization(self):
self.client.post(path=url_for('session.login'),
data={'loginname': 'testuser', 'password': 'userpassword'}, follow_redirects=True)
state = 'teststate'
r = self.client.get(path=url_for('oauth2.authorize', response_type='code', client_id='test', state=state, redirect_uri='http://localhost:5009/callback'), follow_redirects=False)
while True:
if r.status_code != 302 or r.location.startswith('http://localhost:5009/callback'):
break
r = self.client.get(r.location, follow_redirects=False)
self.assertEqual(r.status_code, 302)
self.assertTrue(r.location.startswith('http://localhost:5009/callback'))
args = parse_qs(urlparse(r.location).query)
self.assertEqual(args['state'], [state])
code = args['code'][0]
r = self.client.post(path=url_for('oauth2.token'),
data={'grant_type': 'authorization_code', 'code': code, 'redirect_uri': 'http://localhost:5009/callback', 'client_id': 'test', 'client_secret': 'testsecret'}, follow_redirects=True)
self.assertEqual(r.status_code, 200)
self.assertEqual(r.content_type, 'application/json')
self.assertEqual(r.json['token_type'], 'Bearer')
self.assertEqual(r.json['scope'], 'profile')
token = r.json['access_token']
r = self.client.get(path=url_for('oauth2.userinfo'), headers=[('Authorization', 'Bearer %s'%token)], follow_redirects=True)
self.assertEqual(r.status_code, 200)
self.assertEqual(r.content_type, 'application/json')
user = get_user()
self.assertEqual(r.json['id'], user.uid)
self.assertEqual(r.json['name'], user.displayname)
self.assertEqual(r.json['nickname'], user.loginname)
self.assertEqual(r.json['email'], user.mail)
self.assertTrue(r.json.get('groups'))
class TestViewsOL(TestViews):
use_openldap = True
This diff is collapsed.
import time
from uffd.models.ratelimit import get_addrkey, format_delay, Ratelimit
from flask import Flask, Blueprint, session, url_for
from uffd.ratelimit import get_addrkey, format_delay, Ratelimit, RatelimitEvent
from utils import UffdTestCase
from tests.utils import UffdTestCase
class TestRatelimit(UffdTestCase):
def test_limiting(self):
......@@ -48,19 +44,3 @@ class TestRatelimit(UffdTestCase):
self.assertIsInstance(format_delay(120), str)
self.assertIsInstance(format_delay(3600), str)
self.assertIsInstance(format_delay(4000), str)
def test_cleanup(self):
ratelimit = Ratelimit('test', 1, 1)
ratelimit.log('')
ratelimit.log('1')
ratelimit.log('2')
ratelimit.log('3')
ratelimit.log('4')
time.sleep(1)
ratelimit.log('5')
self.assertEqual(RatelimitEvent.query.filter(RatelimitEvent.name == 'test').count(), 6)
ratelimit.cleanup()
self.assertEqual(RatelimitEvent.query.filter(RatelimitEvent.name == 'test').count(), 1)
time.sleep(1)
ratelimit.cleanup()
self.assertEqual(RatelimitEvent.query.filter(RatelimitEvent.name == 'test').count(), 0)
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.