diff --git a/migrations/versions/5a07d4a63b64_role_inclusion.py b/migrations/versions/5a07d4a63b64_role_inclusion.py new file mode 100644 index 0000000000000000000000000000000000000000..f10e4bba07f8c84ca45628bb3f7d3039bdd0842e --- /dev/null +++ b/migrations/versions/5a07d4a63b64_role_inclusion.py @@ -0,0 +1,30 @@ +"""Role inclusion + +Revision ID: 5a07d4a63b64 +Revises: a29870f95175 +Create Date: 2021-04-05 15:00:26.205433 + +""" +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = '5a07d4a63b64' +down_revision = 'a29870f95175' +branch_labels = None +depends_on = None + + +def upgrade(): + op.create_table('role-inclusion', + sa.Column('role_id', sa.Integer(), nullable=False), + sa.Column('included_role_id', sa.Integer(), nullable=False), + sa.ForeignKeyConstraint(['included_role_id'], ['role.id'], ), + sa.ForeignKeyConstraint(['role_id'], ['role.id'], ), + sa.PrimaryKeyConstraint('role_id', 'included_role_id') + ) + + +def downgrade(): + op.drop_table('role-inclusion') diff --git a/tests/test_role.py b/tests/test_role.py index c6a42bdc09a0f7e0e87643c84dd75cb7b7699fee..8dca42503767c1e3303bc3732d02a4486ef32c8e 100644 --- a/tests/test_role.py +++ b/tests/test_role.py @@ -6,12 +6,89 @@ from flask import url_for, session # These imports are required, because otherwise we get circular imports?! from uffd import ldap, user -from uffd.user.models import Group +from uffd.user.models import User, Group from uffd.role.models import Role from uffd import create_app, db from utils import dump, UffdTestCase +class TestUserRoleAttributes(UffdTestCase): + def test_roles_recursive(self): + user1 = User.query.get('uid=testuser,ou=users,dc=example,dc=com') + user1.update_groups() + baserole = Role(name='base') + role1 = Role(name='role1', members=[user1], included_roles=[baserole]) + role2 = Role(name='role2', included_roles=[baserole]) + db.session.add_all([baserole, role1, role2]) + self.assertSetEqual(user1.roles_recursive, {baserole, role1}) + baserole.included_roles.append(role2) + self.assertSetEqual(user1.roles_recursive, {baserole, role1, role2}) + + def test_update_groups(self): + user1 = User.query.get('uid=testuser,ou=users,dc=example,dc=com') + user1.update_groups() + self.assertSetEqual(set(user1.groups), set()) + group1 = Group.query.get('cn=users,ou=groups,dc=example,dc=com') + group2 = Group.query.get('cn=uffd_access,ou=groups,dc=example,dc=com') + baserole = Role(name='base', groups=[group1]) + role1 = Role(name='role1', groups=[group2], members=[user1]) + db.session.add_all([baserole, role1]) + user1.update_groups() + self.assertSetEqual(set(user1.groups), {group2}) + role1.included_roles.append(baserole) + user1.update_groups() + self.assertSetEqual(set(user1.groups), {group1, group2}) + +class TestRoleModel(UffdTestCase): + def test_indirect_members(self): + user1 = User.query.get('uid=testuser,ou=users,dc=example,dc=com') + user1.update_groups() + user2 = User.query.get('uid=testadmin,ou=users,dc=example,dc=com') + user2.update_groups() + baserole = Role(name='base', members=[user1]) + role1 = Role(name='role1', included_roles=[baserole], members=[user2]) + self.assertSetEqual(baserole.indirect_members, {user2}) + self.assertSetEqual(role1.indirect_members, set()) + + def test_included_roles_recursive(self): + baserole = Role(name='base') + role1 = Role(name='role1', included_roles=[baserole]) + role2 = Role(name='role2', included_roles=[baserole]) + role3 = Role(name='role3', included_roles=[role1, role2]) + self.assertSetEqual(role1.included_roles_recursive, {baserole}) + self.assertSetEqual(role2.included_roles_recursive, {baserole}) + self.assertSetEqual(role3.included_roles_recursive, {baserole, role1, role2}) + baserole.included_roles.append(role1) + self.assertSetEqual(role3.included_roles_recursive, {baserole, role1, role2}) + + def test_included_groups(self): + group1 = Group.query.get('cn=users,ou=groups,dc=example,dc=com') + group2 = Group.query.get('cn=uffd_access,ou=groups,dc=example,dc=com') + baserole = Role(name='base', groups=[group1]) + role1 = Role(name='role1', groups=[group2], included_roles=[baserole]) + self.assertSetEqual(baserole.included_groups, set()) + self.assertSetEqual(role1.included_groups, {group1}) + + def test_update_member_groups(self): + user1 = User.query.get('uid=testuser,ou=users,dc=example,dc=com') + user1.update_groups() + user2 = User.query.get('uid=testadmin,ou=users,dc=example,dc=com') + user2.update_groups() + group1 = Group.query.get('cn=users,ou=groups,dc=example,dc=com') + group2 = Group.query.get('cn=uffd_access,ou=groups,dc=example,dc=com') + group3 = Group.query.get('cn=uffd_admin,ou=groups,dc=example,dc=com') + baserole = Role(name='base', members=[user1], groups=[group1]) + role1 = Role(name='role1', members=[user2], groups=[group2], included_roles=[baserole]) + db.session.add_all([baserole, role1]) + baserole.update_member_groups() + role1.update_member_groups() + self.assertSetEqual(set(user1.groups), {group1}) + self.assertSetEqual(set(user2.groups), {group1, group2}) + baserole.groups.add(group3) + baserole.update_member_groups() + self.assertSetEqual(set(user1.groups), {group1, group3}) + self.assertSetEqual(set(user2.groups), {group1, group2, group3}) + class TestRoleViews(UffdTestCase): def setUp(self): super().setUp() diff --git a/uffd/role/models.py b/uffd/role/models.py index 4738ce91c5f66c3f701186f9f04ea656b38598e5..ade3870261f18634c45484041c87d4a0fcfe8bda 100644 --- a/uffd/role/models.py +++ b/uffd/role/models.py @@ -23,10 +23,32 @@ class RoleGroup(LdapMapping, db.Model): class RoleUser(LdapMapping, db.Model): __tablename__ = 'role-user' +# pylint: disable=E1101 +role_inclusion = db.Table('role-inclusion', + Column('role_id', Integer, ForeignKey('role.id'), primary_key=True), + Column('included_role_id', Integer, ForeignKey('role.id'), primary_key=True) +) + +def flatten_recursive(objs, attr): + '''Returns a set of objects and all objects included in object.`attr` recursivly while avoiding loops''' + objs = set(objs) + new_objs = set(objs) + while new_objs: + for obj in getattr(new_objs.pop(), attr): + if obj not in objs: + objs.add(obj) + new_objs.add(obj) + return objs + +def get_roles_recursive(user): + return flatten_recursive(user.roles, 'included_roles') + +User.roles_recursive = property(get_roles_recursive) + def update_user_groups(user): current_groups = set(user.groups) groups = set() - for role in user.roles: + for role in user.roles_recursive: groups.update(role.groups) if groups == current_groups: return set(), set() @@ -44,6 +66,11 @@ class Role(db.Model): id = Column(Integer(), primary_key=True, autoincrement=True) name = Column(String(32), unique=True) description = Column(Text(), default='') + included_roles = relationship('Role', secondary=role_inclusion, + primaryjoin=id == role_inclusion.c.role_id, + secondaryjoin=id == role_inclusion.c.included_role_id, + backref='including_roles') + including_roles = [] # overwritten by backref db_members = relationship("RoleUser", backref="role", cascade="all, delete-orphan") members = DBRelationship('db_members', User, RoleUser, backattr='role', backref='roles') @@ -51,6 +78,24 @@ class Role(db.Model): db_groups = relationship("RoleGroup", backref="role", cascade="all, delete-orphan") groups = DBRelationship('db_groups', Group, RoleGroup, backattr='role', backref='roles') + @property + def indirect_members(self): + users = set() + for role in flatten_recursive(self.including_roles, 'including_roles'): + users.update(role.members) + return users + + @property + def included_roles_recursive(self): + return flatten_recursive(self.included_roles, 'included_roles') + + @property + def included_groups(self): + groups = set() + for role in self.included_roles_recursive: + groups.update(role.groups) + return groups + def update_member_groups(self): - for user in self.members: + for user in set(self.members).union(self.indirect_members): user.update_groups() diff --git a/uffd/role/templates/role.html b/uffd/role/templates/role.html index 3b2fa8e1ac9a3ae481070aa926c30edf40c13be0..e7c494b74c4db76f74f402c13b7f58a9fdfc2088 100644 --- a/uffd/role/templates/role.html +++ b/uffd/role/templates/role.html @@ -15,6 +15,47 @@ <small class="form-text text-muted"> </small> </div> + + <div class="form-group col"> + <span>Roles to include groups from recursively</span> + <table class="table table-striped table-sm"> + <thead> + <tr> + <th scope="col"></th> + <th scope="col">name</th> + <th scope="col">description</th> + <th scope="col">currently includes groups</th> + </tr> + </thead> + <tbody> + {% for r in roles|sort(attribute="name")|sort(attribute='name') %} + <tr id="include-role-{{ role.id }}"> + <td> + <div class="form-check"> + <input class="form-check-input" type="checkbox" id="include-role-{{ r.id }}-checkbox" name="include-role-{{ r.id }}" value="1" aria-label="enabled" + {% if r == role %}disabled{% endif %} + {% if r in role.included_roles %}checked{% endif %}> + </div> + </td> + <td> + <a href="{{ url_for("role.show", roleid=r.id) }}"> + {{ r.name }} + </a> + </td> + <td> + {{ r.description }} + </td> + <td> + {% for group in r.included_groups.union(r.groups)|sort(attribute='name') %} + <a href="{{ url_for("group.show", gid=group.gid) }}">{{ group.name }}</a>{{ ', ' if not loop.last }} + {% endfor %} + </td> + </tr> + {% endfor %} + </tbody> + </table> + </div> + <div class="form-group col"> <span>Included groups</span> <table class="table table-striped table-sm"> @@ -46,6 +87,7 @@ </tbody> </table> </div> + <div class="form-group col"> <p> Members diff --git a/uffd/role/views.py b/uffd/role/views.py index 86c94145743821b547b89d87dfbd4f36299a3e0b..a5e11be0242f260cfaf3f0e78f7fe813daf56924 100644 --- a/uffd/role/views.py +++ b/uffd/role/views.py @@ -58,7 +58,7 @@ def show(roleid=False): role = Role() else: role = Role.query.filter_by(id=roleid).one() - return render_template('role.html', role=role, groups=Group.query.all()) + return render_template('role.html', role=role, groups=Group.query.all(), roles=Role.query.all()) @bp.route("/<int:roleid>/update", methods=['POST']) @bp.route("/new", methods=['POST']) @@ -72,6 +72,11 @@ def update(roleid=False): role = Role.query.filter_by(id=roleid).one() role.name = request.values['name'] role.description = request.values['description'] + for included_role in Role.query.all(): + if included_role != role and request.values.get('include-role-{}'.format(included_role.id)): + role.included_roles.append(included_role) + elif included_role in role.included_roles: + role.included_roles.remove(included_role) for group in Group.query.all(): if request.values.get('group-{}'.format(group.gid), False): role.groups.add(group) @@ -86,7 +91,7 @@ def update(roleid=False): @csrf_protect(blueprint=bp) def delete(roleid): role = Role.query.filter_by(id=roleid).one() - oldmembers = list(role.members) + oldmembers = set(role.members).union(role.indirect_members) role.members.clear() db.session.delete(role) for user in oldmembers: