diff --git a/tests/test_role.py b/tests/test_role.py index 6fef9547df589eca9aa3fc382bac471d9bdf7b11..cdcfcf679367f252583e976f804a35cd6b809b09 100644 --- a/tests/test_role.py +++ b/tests/test_role.py @@ -9,32 +9,12 @@ from uffd.ldap import ldap from uffd import user from uffd.user.models import User, Group -from uffd.role.models import flatten_recursive, Role, RoleGroup +from uffd.role.models import Role, RoleGroup from uffd.mfa.models import TOTPMethod from uffd import create_app, db from utils import dump, UffdTestCase -class TestPrimitives(unittest.TestCase): - def test_flatten_recursive(self): - class Node: - def __init__(self, *neighbors): - self.neighbors = set(neighbors or set()) - - cycle = Node() - cycle.neighbors.add(cycle) - common = Node(cycle) - intermediate1 = Node(common) - intermediate2 = Node(common, intermediate1) - stub = Node() - backref = Node() - start1 = Node(intermediate1, intermediate2, stub, backref) - backref.neighbors.add(start1) - start2 = Node() - self.assertSetEqual(flatten_recursive({start1, start2}, 'neighbors'), - {start1, start2, backref, stub, intermediate1, intermediate2, common, cycle}) - self.assertSetEqual(flatten_recursive(set(), 'neighbors'), set()) - class TestUserRoleAttributes(UffdTestCase): def test_roles_effective(self): for user in User.query.filter_by(loginname='service').all(): @@ -103,6 +83,8 @@ class TestRoleModel(UffdTestCase): included_role = Role(name='included') direct_role = Role(name='direct', members=[user1, user2, service], included_roles=[included_role]) empty_role = Role(name='empty', included_roles=[included_role]) + db.session.add_all([included_by_default_role, default_role, included_role, direct_role, empty_role]) + db.session.commit() self.assertSetEqual(included_by_default_role.members_effective, {user1, user2}) self.assertSetEqual(default_role.members_effective, {user1, user2}) self.assertSetEqual(included_role.members_effective, {user1, user2, service}) @@ -116,10 +98,13 @@ class TestRoleModel(UffdTestCase): role1 = Role(name='role1', included_roles=[baserole]) role2 = Role(name='role2', included_roles=[baserole]) role3 = Role(name='role3', included_roles=[role1, role2]) + db.session.add_all([baserole, role1, role2, role3]) + db.session.commit() self.assertSetEqual(role1.included_roles_recursive, {baserole}) self.assertSetEqual(role2.included_roles_recursive, {baserole}) self.assertSetEqual(role3.included_roles_recursive, {baserole, role1, role2}) baserole.included_roles.append(role1) + db.session.commit() self.assertSetEqual(role3.included_roles_recursive, {baserole, role1, role2}) def test_groups_effective(self): @@ -127,6 +112,8 @@ class TestRoleModel(UffdTestCase): group2 = self.get_access_group() baserole = Role(name='base', groups={group1: RoleGroup(group=group1)}) role1 = Role(name='role1', groups={group2: RoleGroup(group=group2)}, included_roles=[baserole]) + db.session.add_all([baserole, role1]) + db.session.commit() self.assertSetEqual(baserole.groups_effective, {group1}) self.assertSetEqual(role1.groups_effective, {group1, group2}) @@ -141,6 +128,7 @@ class TestRoleModel(UffdTestCase): baserole = Role(name='base', members=[user1], groups={group1: RoleGroup(group=group1)}) role1 = Role(name='role1', members=[user2], groups={group2: RoleGroup(group=group2)}, included_roles=[baserole]) db.session.add_all([baserole, role1]) + db.session.commit() baserole.update_member_groups() role1.update_member_groups() self.assertSetEqual(set(user1.groups), {group1}) diff --git a/uffd/role/models.py b/uffd/role/models.py index e5b40190cf2b3ded9d3a732a4317fd756ab4e28c..dd39caa1d9b1dae5c5f353d589cb8ea0c1af189d 100644 --- a/uffd/role/models.py +++ b/uffd/role/models.py @@ -35,22 +35,17 @@ role_inclusion = db.Table('role-inclusion', Column('included_role_id', Integer, ForeignKey('role.id'), primary_key=True) ) -def flatten_recursive(objs, attr): - '''Returns a set of objects and all objects included in object.`attr` recursivly while avoiding loops''' - objs = set(objs) - new_objs = set(objs) - while new_objs: - for obj in getattr(new_objs.pop(), attr): - if obj not in objs: - objs.add(obj) - new_objs.add(obj) - return objs - def get_user_roles_effective(user): - base = set(user.roles) - if not user.is_service_user: - base.update(Role.query.filter_by(is_default=True)) - return flatten_recursive(base, 'included_roles') + direct_roles = db.session.query(Role).join(RoleUser)\ + .filter(RoleUser.dn == user.dn) + # pylint: disable=singleton-comparison + base_roles = db.session.query(Role).filter(db.and_(Role.is_default == True, + user.is_service_user is False)) + cte = direct_roles.union(base_roles).cte('cte', recursive=True) + rquery = cte.union(db.session.query(Role)\ + .join(role_inclusion, Role.id == role_inclusion.c.included_role_id)\ + .join(cte, role_inclusion.c.role_id == cte.c.role_id)) + return set(Role.query.join(rquery, rquery.c.role_id == Role.id).all()) User.roles_effective = property(get_user_roles_effective) @@ -115,8 +110,14 @@ class Role(db.Model): @property def members_effective(self): + cte = db.session.query(Role).filter(Role.id == self.id)\ + .cte('cte', recursive=True) + rquery = cte.union(db.session.query(Role)\ + .join(role_inclusion, Role.id == role_inclusion.c.role_id)\ + .join(cte, role_inclusion.c.included_role_id == cte.c.id)) + including_roles_recursive = Role.query.join(rquery, rquery.c.id == Role.id).all() members = set() - for role in flatten_recursive([self], 'including_roles'): + for role in including_roles_recursive: members.update(role.members) if role.is_default: members.update([user for user in User.query.all() if not user.is_service_user]) @@ -124,7 +125,14 @@ class Role(db.Model): @property def included_roles_recursive(self): - return flatten_recursive(self.included_roles, 'included_roles') + cte = db.session.query(Role)\ + .join(role_inclusion, Role.id == role_inclusion.c.included_role_id)\ + .filter(role_inclusion.c.role_id == self.id)\ + .cte('cte', recursive=True) + rquery = cte.union(db.session.query(Role)\ + .join(role_inclusion, Role.id == role_inclusion.c.included_role_id)\ + .join(cte, role_inclusion.c.role_id == cte.c.id)) + return set(Role.query.join(rquery, rquery.c.id == Role.id).all()) @property def groups_effective(self):