Skip to content
Snippets Groups Projects
Commit 778edcc0 authored by Julian's avatar Julian
Browse files

Implemented role inherintance (more precisely: "inclusion")

parent e4b1b075
No related branches found
No related tags found
No related merge requests found
"""Role inclusion
Revision ID: 5a07d4a63b64
Revises: a29870f95175
Create Date: 2021-04-05 15:00:26.205433
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = '5a07d4a63b64'
down_revision = 'a29870f95175'
branch_labels = None
depends_on = None
def upgrade():
op.create_table('role-inclusion',
sa.Column('role_id', sa.Integer(), nullable=False),
sa.Column('included_role_id', sa.Integer(), nullable=False),
sa.ForeignKeyConstraint(['included_role_id'], ['role.id'], ),
sa.ForeignKeyConstraint(['role_id'], ['role.id'], ),
sa.PrimaryKeyConstraint('role_id', 'included_role_id')
)
def downgrade():
op.drop_table('role-inclusion')
...@@ -6,12 +6,89 @@ from flask import url_for, session ...@@ -6,12 +6,89 @@ from flask import url_for, session
# These imports are required, because otherwise we get circular imports?! # These imports are required, because otherwise we get circular imports?!
from uffd import ldap, user from uffd import ldap, user
from uffd.user.models import Group from uffd.user.models import User, Group
from uffd.role.models import Role from uffd.role.models import Role
from uffd import create_app, db from uffd import create_app, db
from utils import dump, UffdTestCase from utils import dump, UffdTestCase
class TestUserRoleAttributes(UffdTestCase):
def test_roles_recursive(self):
user1 = User.query.get('uid=testuser,ou=users,dc=example,dc=com')
user1.update_groups()
baserole = Role(name='base')
role1 = Role(name='role1', members=[user1], included_roles=[baserole])
role2 = Role(name='role2', included_roles=[baserole])
db.session.add_all([baserole, role1, role2])
self.assertSetEqual(user1.roles_recursive, {baserole, role1})
baserole.included_roles.append(role2)
self.assertSetEqual(user1.roles_recursive, {baserole, role1, role2})
def test_update_groups(self):
user1 = User.query.get('uid=testuser,ou=users,dc=example,dc=com')
user1.update_groups()
self.assertSetEqual(set(user1.groups), set())
group1 = Group.query.get('cn=users,ou=groups,dc=example,dc=com')
group2 = Group.query.get('cn=uffd_access,ou=groups,dc=example,dc=com')
baserole = Role(name='base', groups=[group1])
role1 = Role(name='role1', groups=[group2], members=[user1])
db.session.add_all([baserole, role1])
user1.update_groups()
self.assertSetEqual(set(user1.groups), {group2})
role1.included_roles.append(baserole)
user1.update_groups()
self.assertSetEqual(set(user1.groups), {group1, group2})
class TestRoleModel(UffdTestCase):
def test_indirect_members(self):
user1 = User.query.get('uid=testuser,ou=users,dc=example,dc=com')
user1.update_groups()
user2 = User.query.get('uid=testadmin,ou=users,dc=example,dc=com')
user2.update_groups()
baserole = Role(name='base', members=[user1])
role1 = Role(name='role1', included_roles=[baserole], members=[user2])
self.assertSetEqual(baserole.indirect_members, {user2})
self.assertSetEqual(role1.indirect_members, set())
def test_included_roles_recursive(self):
baserole = Role(name='base')
role1 = Role(name='role1', included_roles=[baserole])
role2 = Role(name='role2', included_roles=[baserole])
role3 = Role(name='role3', included_roles=[role1, role2])
self.assertSetEqual(role1.included_roles_recursive, {baserole})
self.assertSetEqual(role2.included_roles_recursive, {baserole})
self.assertSetEqual(role3.included_roles_recursive, {baserole, role1, role2})
baserole.included_roles.append(role1)
self.assertSetEqual(role3.included_roles_recursive, {baserole, role1, role2})
def test_included_groups(self):
group1 = Group.query.get('cn=users,ou=groups,dc=example,dc=com')
group2 = Group.query.get('cn=uffd_access,ou=groups,dc=example,dc=com')
baserole = Role(name='base', groups=[group1])
role1 = Role(name='role1', groups=[group2], included_roles=[baserole])
self.assertSetEqual(baserole.included_groups, set())
self.assertSetEqual(role1.included_groups, {group1})
def test_update_member_groups(self):
user1 = User.query.get('uid=testuser,ou=users,dc=example,dc=com')
user1.update_groups()
user2 = User.query.get('uid=testadmin,ou=users,dc=example,dc=com')
user2.update_groups()
group1 = Group.query.get('cn=users,ou=groups,dc=example,dc=com')
group2 = Group.query.get('cn=uffd_access,ou=groups,dc=example,dc=com')
group3 = Group.query.get('cn=uffd_admin,ou=groups,dc=example,dc=com')
baserole = Role(name='base', members=[user1], groups=[group1])
role1 = Role(name='role1', members=[user2], groups=[group2], included_roles=[baserole])
db.session.add_all([baserole, role1])
baserole.update_member_groups()
role1.update_member_groups()
self.assertSetEqual(set(user1.groups), {group1})
self.assertSetEqual(set(user2.groups), {group1, group2})
baserole.groups.add(group3)
baserole.update_member_groups()
self.assertSetEqual(set(user1.groups), {group1, group3})
self.assertSetEqual(set(user2.groups), {group1, group2, group3})
class TestRoleViews(UffdTestCase): class TestRoleViews(UffdTestCase):
def setUp(self): def setUp(self):
super().setUp() super().setUp()
......
...@@ -23,10 +23,32 @@ class RoleGroup(LdapMapping, db.Model): ...@@ -23,10 +23,32 @@ class RoleGroup(LdapMapping, db.Model):
class RoleUser(LdapMapping, db.Model): class RoleUser(LdapMapping, db.Model):
__tablename__ = 'role-user' __tablename__ = 'role-user'
# pylint: disable=E1101
role_inclusion = db.Table('role-inclusion',
Column('role_id', Integer, ForeignKey('role.id'), primary_key=True),
Column('included_role_id', Integer, ForeignKey('role.id'), primary_key=True)
)
def flatten_recursive(objs, attr):
'''Returns a set of objects and all objects included in object.`attr` recursivly while avoiding loops'''
objs = set(objs)
new_objs = set(objs)
while new_objs:
for obj in getattr(new_objs.pop(), attr):
if obj not in objs:
objs.add(obj)
new_objs.add(obj)
return objs
def get_roles_recursive(user):
return flatten_recursive(user.roles, 'included_roles')
User.roles_recursive = property(get_roles_recursive)
def update_user_groups(user): def update_user_groups(user):
current_groups = set(user.groups) current_groups = set(user.groups)
groups = set() groups = set()
for role in user.roles: for role in user.roles_recursive:
groups.update(role.groups) groups.update(role.groups)
if groups == current_groups: if groups == current_groups:
return set(), set() return set(), set()
...@@ -44,6 +66,11 @@ class Role(db.Model): ...@@ -44,6 +66,11 @@ class Role(db.Model):
id = Column(Integer(), primary_key=True, autoincrement=True) id = Column(Integer(), primary_key=True, autoincrement=True)
name = Column(String(32), unique=True) name = Column(String(32), unique=True)
description = Column(Text(), default='') description = Column(Text(), default='')
included_roles = relationship('Role', secondary=role_inclusion,
primaryjoin=id == role_inclusion.c.role_id,
secondaryjoin=id == role_inclusion.c.included_role_id,
backref='including_roles')
including_roles = [] # overwritten by backref
db_members = relationship("RoleUser", backref="role", cascade="all, delete-orphan") db_members = relationship("RoleUser", backref="role", cascade="all, delete-orphan")
members = DBRelationship('db_members', User, RoleUser, backattr='role', backref='roles') members = DBRelationship('db_members', User, RoleUser, backattr='role', backref='roles')
...@@ -51,6 +78,24 @@ class Role(db.Model): ...@@ -51,6 +78,24 @@ class Role(db.Model):
db_groups = relationship("RoleGroup", backref="role", cascade="all, delete-orphan") db_groups = relationship("RoleGroup", backref="role", cascade="all, delete-orphan")
groups = DBRelationship('db_groups', Group, RoleGroup, backattr='role', backref='roles') groups = DBRelationship('db_groups', Group, RoleGroup, backattr='role', backref='roles')
@property
def indirect_members(self):
users = set()
for role in flatten_recursive(self.including_roles, 'including_roles'):
users.update(role.members)
return users
@property
def included_roles_recursive(self):
return flatten_recursive(self.included_roles, 'included_roles')
@property
def included_groups(self):
groups = set()
for role in self.included_roles_recursive:
groups.update(role.groups)
return groups
def update_member_groups(self): def update_member_groups(self):
for user in self.members: for user in set(self.members).union(self.indirect_members):
user.update_groups() user.update_groups()
...@@ -15,6 +15,47 @@ ...@@ -15,6 +15,47 @@
<small class="form-text text-muted"> <small class="form-text text-muted">
</small> </small>
</div> </div>
<div class="form-group col">
<span>Roles to include groups from recursively</span>
<table class="table table-striped table-sm">
<thead>
<tr>
<th scope="col"></th>
<th scope="col">name</th>
<th scope="col">description</th>
<th scope="col">currently includes groups</th>
</tr>
</thead>
<tbody>
{% for r in roles|sort(attribute="name")|sort(attribute='name') %}
<tr id="include-role-{{ role.id }}">
<td>
<div class="form-check">
<input class="form-check-input" type="checkbox" id="include-role-{{ r.id }}-checkbox" name="include-role-{{ r.id }}" value="1" aria-label="enabled"
{% if r == role %}disabled{% endif %}
{% if r in role.included_roles %}checked{% endif %}>
</div>
</td>
<td>
<a href="{{ url_for("role.show", roleid=r.id) }}">
{{ r.name }}
</a>
</td>
<td>
{{ r.description }}
</td>
<td>
{% for group in r.included_groups.union(r.groups)|sort(attribute='name') %}
<a href="{{ url_for("group.show", gid=group.gid) }}">{{ group.name }}</a>{{ ', ' if not loop.last }}
{% endfor %}
</td>
</tr>
{% endfor %}
</tbody>
</table>
</div>
<div class="form-group col"> <div class="form-group col">
<span>Included groups</span> <span>Included groups</span>
<table class="table table-striped table-sm"> <table class="table table-striped table-sm">
...@@ -46,6 +87,7 @@ ...@@ -46,6 +87,7 @@
</tbody> </tbody>
</table> </table>
</div> </div>
<div class="form-group col"> <div class="form-group col">
<p> <p>
Members Members
......
...@@ -58,7 +58,7 @@ def show(roleid=False): ...@@ -58,7 +58,7 @@ def show(roleid=False):
role = Role() role = Role()
else: else:
role = Role.query.filter_by(id=roleid).one() role = Role.query.filter_by(id=roleid).one()
return render_template('role.html', role=role, groups=Group.query.all()) return render_template('role.html', role=role, groups=Group.query.all(), roles=Role.query.all())
@bp.route("/<int:roleid>/update", methods=['POST']) @bp.route("/<int:roleid>/update", methods=['POST'])
@bp.route("/new", methods=['POST']) @bp.route("/new", methods=['POST'])
...@@ -72,6 +72,11 @@ def update(roleid=False): ...@@ -72,6 +72,11 @@ def update(roleid=False):
role = Role.query.filter_by(id=roleid).one() role = Role.query.filter_by(id=roleid).one()
role.name = request.values['name'] role.name = request.values['name']
role.description = request.values['description'] role.description = request.values['description']
for included_role in Role.query.all():
if included_role != role and request.values.get('include-role-{}'.format(included_role.id)):
role.included_roles.append(included_role)
elif included_role in role.included_roles:
role.included_roles.remove(included_role)
for group in Group.query.all(): for group in Group.query.all():
if request.values.get('group-{}'.format(group.gid), False): if request.values.get('group-{}'.format(group.gid), False):
role.groups.add(group) role.groups.add(group)
...@@ -86,7 +91,7 @@ def update(roleid=False): ...@@ -86,7 +91,7 @@ def update(roleid=False):
@csrf_protect(blueprint=bp) @csrf_protect(blueprint=bp)
def delete(roleid): def delete(roleid):
role = Role.query.filter_by(id=roleid).one() role = Role.query.filter_by(id=roleid).one()
oldmembers = list(role.members) oldmembers = set(role.members).union(role.indirect_members)
role.members.clear() role.members.clear()
db.session.delete(role) db.session.delete(role)
for user in oldmembers: for user in oldmembers:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment