diff --git a/tests/test_role.py b/tests/test_role.py index a88baf447bde9aa68c11d979a2b999aba63ba194..2f4e0601b942fd165624c1004137733dd5d6b609 100644 --- a/tests/test_role.py +++ b/tests/test_role.py @@ -19,8 +19,8 @@ class TestRoleViews(UffdTestCase): data={'loginname': 'testadmin', 'password': 'adminpassword'}, follow_redirects=True) def test_index(self): - db.session.add(Role('base', 'Base role description')) - db.session.add(Role('test1', 'Test1 role description')) + db.session.add(Role(name='base', description='Base role description')) + db.session.add(Role(name='test1', description='Test1 role description')) db.session.commit() r = self.client.get(path=url_for('role.index'), follow_redirects=True) dump('role_index', r) @@ -32,7 +32,7 @@ class TestRoleViews(UffdTestCase): self.assertEqual(r.status_code, 200) def test_show(self): - role = Role('base', 'Base role description') + role = Role(name='base', description='Base role description') db.session.add(role) db.session.commit() r = self.client.get(path=url_for('role.show', roleid=role.id), follow_redirects=True) @@ -45,14 +45,14 @@ class TestRoleViews(UffdTestCase): self.assertEqual(r.status_code, 200) def test_update(self): - role = Role('base', 'Base role description') + role = Role(name='base', description='Base role description') db.session.add(role) db.session.commit() - role.add_group(Group.ldap_get('cn=uffd_admin,ou=groups,dc=example,dc=com')) + role.groups.add(Group.ldap_get('cn=uffd_admin,ou=groups,dc=example,dc=com')) db.session.commit() self.assertEqual(role.name, 'base') self.assertEqual(role.description, 'Base role description') - self.assertEqual(role.group_dns(), ['cn=uffd_admin,ou=groups,dc=example,dc=com']) + self.assertEqual([group.dn for group in role.groups], ['cn=uffd_admin,ou=groups,dc=example,dc=com']) r = self.client.post(path=url_for('role.update', roleid=role.id), data={'name': 'base1', 'description': 'Base role description1', 'group-20001': '1', 'group-20002': '1'}, follow_redirects=True) @@ -61,7 +61,7 @@ class TestRoleViews(UffdTestCase): role = Role.query.get(role.id) self.assertEqual(role.name, 'base1') self.assertEqual(role.description, 'Base role description1') - self.assertEqual(sorted(role.group_dns()), ['cn=uffd_access,ou=groups,dc=example,dc=com', + self.assertEqual(sorted([group.dn for group in role.groups]), ['cn=uffd_access,ou=groups,dc=example,dc=com', 'cn=users,ou=groups,dc=example,dc=com']) # TODO: verify that group memberships are updated (currently not possible with ldap mock!) @@ -76,12 +76,12 @@ class TestRoleViews(UffdTestCase): self.assertIsNotNone(role) self.assertEqual(role.name, 'base') self.assertEqual(role.description, 'Base role description') - self.assertEqual(sorted(role.group_dns()), ['cn=uffd_access,ou=groups,dc=example,dc=com', + self.assertEqual(sorted([group.dn for group in role.groups]), ['cn=uffd_access,ou=groups,dc=example,dc=com', 'cn=users,ou=groups,dc=example,dc=com']) # TODO: verify that group memberships are updated (currently not possible with ldap mock!) def test_delete(self): - role = Role('base', 'Base role description') + role = Role(name='base', description='Base role description') db.session.add(role) db.session.commit() role_id = role.id diff --git a/tests/test_user.py b/tests/test_user.py index 9130de4008c7a7a8c622c650015fc9c6cfe64c8c..5bc1e313a2f8d351015664f8652534e371a3d8d4 100644 --- a/tests/test_user.py +++ b/tests/test_user.py @@ -64,10 +64,10 @@ class TestUserViews(UffdTestCase): self.assertEqual(r.status_code, 200) def test_new(self): - db.session.add(Role('base')) - role1 = Role('role1') + db.session.add(Role(name='base')) + role1 = Role(name='role1') db.session.add(role1) - role2 = Role('role2') + role2 = Role(name='role2') db.session.add(role2) db.session.commit() role1_id = role1.id @@ -81,7 +81,7 @@ class TestUserViews(UffdTestCase): dump('user_new_submit', r) self.assertEqual(r.status_code, 200) user = User.ldap_get('uid=newuser,ou=users,dc=example,dc=com') - roles = sorted([r.name for r in Role.get_for_user(user)]) + roles = sorted([r.name for r in user.roles]) self.assertIsNotNone(user) self.assertEqual(user.loginname, 'newuser') self.assertEqual(user.displayname, 'New User') @@ -124,12 +124,12 @@ class TestUserViews(UffdTestCase): def test_update(self): user = get_user() - db.session.add(Role('base')) - role1 = Role('role1') + db.session.add(Role(name='base')) + role1 = Role(name='role1') db.session.add(role1) - role2 = Role('role2') + role2 = Role(name='role2') db.session.add(role2) - role2.add_member(user) + role2.members.add(user) db.session.commit() role1_id = role1.id oldpw = get_user_password() @@ -142,7 +142,7 @@ class TestUserViews(UffdTestCase): dump('user_update_submit', r) self.assertEqual(r.status_code, 200) _user = get_user() - roles = sorted([r.name for r in Role.get_for_user(_user)]) + roles = sorted([r.name for r in _user.roles]) self.assertEqual(_user.displayname, 'New User') self.assertEqual(_user.mail, 'newuser@example.com') self.assertEqual(_user.uid, user.uid) @@ -229,10 +229,10 @@ class TestUserViews(UffdTestCase): self.assertIsNone(get_user()) def test_csvimport(self): - db.session.add(Role('base')) - role1 = Role('role1') + db.session.add(Role(name='base')) + role1 = Role(name='role1') db.session.add(role1) - role2 = Role('role2') + role2 = Role(name='role2') db.session.add(role2) db.session.commit() data = f'''\ @@ -256,42 +256,42 @@ newuser12,newuser12@example.com,{role1.id};{role1.id} dump('user_csvimport', r) self.assertEqual(r.status_code, 200) user = User.ldap_get('uid=newuser1,ou=users,dc=example,dc=com') - roles = sorted([r.name for r in Role.get_for_user(user)]) + roles = sorted([r.name for r in user.roles]) self.assertIsNotNone(user) self.assertEqual(user.loginname, 'newuser1') self.assertEqual(user.displayname, 'newuser1') self.assertEqual(user.mail, 'newuser1@example.com') self.assertEqual(roles, ['base']) user = User.ldap_get('uid=newuser2,ou=users,dc=example,dc=com') - roles = sorted([r.name for r in Role.get_for_user(user)]) + roles = sorted([r.name for r in user.roles]) self.assertIsNotNone(user) self.assertEqual(user.loginname, 'newuser2') self.assertEqual(user.displayname, 'newuser2') self.assertEqual(user.mail, 'newuser2@example.com') self.assertEqual(roles, ['base', 'role1']) user = User.ldap_get('uid=newuser3,ou=users,dc=example,dc=com') - roles = sorted([r.name for r in Role.get_for_user(user)]) + roles = sorted([r.name for r in user.roles]) self.assertIsNotNone(user) self.assertEqual(user.loginname, 'newuser3') self.assertEqual(user.displayname, 'newuser3') self.assertEqual(user.mail, 'newuser3@example.com') self.assertEqual(roles, ['base', 'role1', 'role2']) user = User.ldap_get('uid=newuser4,ou=users,dc=example,dc=com') - roles = sorted([r.name for r in Role.get_for_user(user)]) + roles = sorted([r.name for r in user.roles]) self.assertIsNotNone(user) self.assertEqual(user.loginname, 'newuser4') self.assertEqual(user.displayname, 'newuser4') self.assertEqual(user.mail, 'newuser4@example.com') self.assertEqual(roles, ['base']) user = User.ldap_get('uid=newuser5,ou=users,dc=example,dc=com') - roles = sorted([r.name for r in Role.get_for_user(user)]) + roles = sorted([r.name for r in user.roles]) self.assertIsNotNone(user) self.assertEqual(user.loginname, 'newuser5') self.assertEqual(user.displayname, 'newuser5') self.assertEqual(user.mail, 'newuser5@example.com') self.assertEqual(roles, ['base']) user = User.ldap_get('uid=newuser6,ou=users,dc=example,dc=com') - roles = sorted([r.name for r in Role.get_for_user(user)]) + roles = sorted([r.name for r in user.roles]) self.assertIsNotNone(user) self.assertEqual(user.loginname, 'newuser6') self.assertEqual(user.displayname, 'newuser6') @@ -301,14 +301,14 @@ newuser12,newuser12@example.com,{role1.id};{role1.id} self.assertIsNone(User.ldap_get('uid=newuser8,ou=users,dc=example,dc=com')) self.assertIsNone(User.ldap_get('uid=newuser9,ou=users,dc=example,dc=com')) user = User.ldap_get('uid=newuser10,ou=users,dc=example,dc=com') - roles = sorted([r.name for r in Role.get_for_user(user)]) + roles = sorted([r.name for r in user.roles]) self.assertIsNotNone(user) self.assertEqual(user.loginname, 'newuser10') self.assertEqual(user.displayname, 'newuser10') self.assertEqual(user.mail, 'newuser10@example.com') self.assertEqual(roles, ['base']) user = User.ldap_get('uid=newuser11,ou=users,dc=example,dc=com') - roles = sorted([r.name for r in Role.get_for_user(user)]) + roles = sorted([r.name for r in user.roles]) self.assertIsNotNone(user) self.assertEqual(user.loginname, 'newuser11') self.assertEqual(user.displayname, 'newuser11') @@ -317,7 +317,7 @@ newuser12,newuser12@example.com,{role1.id};{role1.id} #self.assertEqual(roles, ['base', 'role1', 'role2']) self.assertEqual(roles, ['base', 'role2']) user = User.ldap_get('uid=newuser12,ou=users,dc=example,dc=com') - roles = sorted([r.name for r in Role.get_for_user(user)]) + roles = sorted([r.name for r in user.roles]) self.assertIsNotNone(user) self.assertEqual(user.loginname, 'newuser12') self.assertEqual(user.displayname, 'newuser12') diff --git a/uffd/ldap.py b/uffd/ldap.py index 5c3aed3e1a36d492fde11dd6955f48d3c094fe9a..cd653286b2ddc94a00ffc801fcc037e8caefc36b 100644 --- a/uffd/ldap.py +++ b/uffd/ldap.py @@ -148,6 +148,10 @@ class LDAPSet(MutableSet): def discard(self, value): self.__delitem(self.__encode(value)) + def update(self, values): + for value in values: + self.add(value) + class LDAPAttribute: def __init__(self, name, multi=False, default=None, encode=None, decode=None, aliases=None): self.name = name @@ -400,3 +404,68 @@ class LDAPModel: if not success: raise Exception() self.__ldap_attributes = {} + +class DB2LDAPBackref: + def __init__(self, baseattr, mapcls, backattr): + self.baseattr = baseattr + self.mapcls = mapcls + self.backattr = backattr + + def getitems(self, ldapobj): + return {getattr(mapobj, self.backattr) for mapobj in self.mapcls.query.filter_by(dn=ldapobj.dn)} + + def additem(self, ldapobj, dbobj): + if dbobj not in self.getitems(ldapobj): + getattr(dbobj, self.baseattr).append(self.mapcls(dn=ldapobj.dn)) + + def delitem(self, ldapobj, dbobj): + for mapobj in list(getattr(dbobj, self.baseattr)): + if mapobj.dn == ldapobj.dn: + getattr(dbobj, self.baseattr).remove(mapobj) + + def __get__(self, ldapobj, objtype=None): + if ldapobj is None: + return self + return LDAPSet(getitems=lambda: self.getitems(ldapobj), + additem=lambda dbobj: self.additem(ldapobj, dbobj), + delitem=lambda dbobj: self.delitem(ldapobj, dbobj)) + + def __set__(self, ldapobj, dbobjs): + rel = self.__get__(ldapobj) + rel.clear() + for dbobj in dbobjs: + rel.add(dbobj) + +class DB2LDAPRelation: + def __init__(self, baseattr, mapcls, ldapcls, backattr=None, backref=None): + self.baseattr = baseattr + self.mapcls = mapcls + self.ldapcls = ldapcls + if backref is not None: + setattr(ldapcls, backref, DB2LDAPBackref(baseattr, mapcls, backattr)) + + def getitems(self, dbobj): + return {mapobj.dn for mapobj in getattr(dbobj, self.baseattr)} + + def additem(self, dbobj, dn): + if dn not in self.getitems(dbobj): + getattr(dbobj, self.baseattr).append(self.mapcls(dn=dn)) + + def delitem(self, dbobj, dn): + for mapobj in list(getattr(dbobj, self.baseattr)): + if mapobj.dn == dn: + getattr(dbobj, self.baseattr).remove(mapobj) + + def __get__(self, dbobj, objtype=None): + if dbobj is None: + return self + return LDAPSet(getitems=lambda: self.getitems(dbobj), + additem=lambda dn: self.additem(dbobj, dn), + delitem=lambda dn: self.delitem(dbobj, dn), + encode=lambda ldapobj: ldapobj.dn, + decode=self.ldapcls.ldap_get) + + def __set__(self, dbobj, ldapobjs): + getattr(dbobj, self.baseattr).clear() + for ldapobj in ldapobjs: + getattr(dbobj, self.baseattr).append(self.mapcls(dn=ldapobj.dn)) diff --git a/uffd/role/models.py b/uffd/role/models.py index c00228b5a55d696803a5a5059c680baf82170a90..79445573e77c7448f236cfc6dcb2539cb7ec6c92 100644 --- a/uffd/role/models.py +++ b/uffd/role/models.py @@ -1,61 +1,12 @@ -from operator import attrgetter - from sqlalchemy import Column, String, Integer, Text, ForeignKey from sqlalchemy.orm import relationship from sqlalchemy.ext.declarative import declared_attr from uffd.database import db +from uffd.ldap import DB2LDAPRelation from uffd.user.models import User, Group -class Role(db.Model): - __tablename__ = 'role' - id = Column(Integer(), primary_key=True, autoincrement=True) - name = Column(String(32), unique=True) - description = Column(Text()) - members = relationship("RoleUser", backref="role", cascade="all, delete-orphan") - groups = relationship("RoleGroup", backref="role", cascade="all, delete-orphan") - - def __init__(self, name='', description=''): - self.name = name - self.description = description - - @classmethod - def get_for_user(cls, user): - return Role.query.join(Role.members, aliased=True).filter_by(dn=user.dn) - - def member_ldap(self): - result = [] - for dn in self.member_dns(): - result.append(User.ldap_get(dn)) - return result - - def member_dns(self): - return list(map(attrgetter('dn'), self.members)) - - def add_member(self, member): - newmapping = RoleUser(member.dn, self) - self.members.append(newmapping) - - def del_member(self, member): - for i in self.members: - if i.dn == member.dn: - self.members.remove(i) - break - - def group_dns(self): - return list(map(attrgetter('dn'), self.groups)) - - def add_group(self, group): - newmapping = RoleGroup(group.dn, self) - self.groups.append(newmapping) - - def del_group(self, group): - for i in self.groups: - if i.dn == group.dn: - self.groups.remove(i) - break - -class LdapMapping(): +class LdapMapping: id = Column(Integer(), primary_key=True, autoincrement=True) dn = Column(String(128)) __table_args__ = ( @@ -64,22 +15,32 @@ class LdapMapping(): @declared_attr def role_id(self): return Column(ForeignKey('role.id')) - ldapclass = None - def __init__(self, dn='', role=''): - self.dn = dn - self.role = role +class RoleGroup(LdapMapping, db.Model): + pass - def get_ldap(self): - return self.ldapclass.ldap_get(self.dn) +class RoleUser(LdapMapping, db.Model): + pass - def set_ldap(self, value): - self.dn = value['dn'] +def update_user_groups(user): + user.groups.clear() + for role in user.roles: + user.groups.update(role.groups) -class RoleGroup(LdapMapping, db.Model): - __tablename__ = 'role-group' - ldapclass = User +User.update_groups = update_user_groups -class RoleUser(LdapMapping, db.Model): - __tablename__ = 'role-user' - ldapclass = Group +class Role(db.Model): + __tablename__ = 'role' + id = Column(Integer(), primary_key=True, autoincrement=True) + name = Column(String(32), unique=True) + description = Column(Text(), default='') + + db_members = relationship("RoleUser", backref="role", cascade="all, delete-orphan") + members = DB2LDAPRelation('db_members', RoleUser, User, backattr='role', backref='roles') + + db_groups = relationship("RoleGroup", backref="role", cascade="all, delete-orphan") + groups = DB2LDAPRelation('db_groups', RoleGroup, Group, backattr='role', backref='roles') + + def update_member_groups(self): + for user in self.members: + user.update_groups() diff --git a/uffd/role/templates/role.html b/uffd/role/templates/role.html index 4f9094b02c25fe705c29bc3f16177d2fde878389..3b2df91fdd1ecaf3f2e9354a5efbc44d8c4397e9 100644 --- a/uffd/role/templates/role.html +++ b/uffd/role/templates/role.html @@ -30,7 +30,7 @@ <tr id="group-{{ group.gid }}"> <td> <div class="form-check"> - <input class="form-check-input" type="checkbox" id="group-{{ group.gid }}-checkbox" name="group-{{ group.gid }}" value="1" aria-label="enabled" {% if group.dn in role.group_dns() %}checked{% endif %}> + <input class="form-check-input" type="checkbox" id="group-{{ group.gid }}-checkbox" name="group-{{ group.gid }}" value="1" aria-label="enabled" {% if group in role.groups %}checked{% endif %}> </div> </td> <td> diff --git a/uffd/role/utils.py b/uffd/role/utils.py deleted file mode 100644 index de55f1ab695a2901a0e60f1adc83d7b483822dc8..0000000000000000000000000000000000000000 --- a/uffd/role/utils.py +++ /dev/null @@ -1,9 +0,0 @@ -from uffd.user.models import Group -from uffd.role.models import Role - -def recalculate_user_groups(user): - newgroups = set() - for role in Role.get_for_user(user).all(): - # TODO: improve this after finding a solution for the Role<->Group relation - newgroups.update({Group.ldap_get(dn) for dn in role.group_dns()}) - user.groups = newgroups diff --git a/uffd/role/views.py b/uffd/role/views.py index 32929a276ffe2139e81bdd0796e089dd4fa62e66..7b16b02469a9fd7c7ae28659b4afe254b2ac9edd 100644 --- a/uffd/role/views.py +++ b/uffd/role/views.py @@ -3,10 +3,10 @@ from flask import Blueprint, render_template, request, url_for, redirect, flash, from uffd.navbar import register_navbar from uffd.csrf import csrf_protect from uffd.role.models import Role -from uffd.role.utils import recalculate_user_groups from uffd.user.models import Group from uffd.session import get_current_user, login_required, is_valid_session from uffd.database import db +from uffd.ldap import ldap bp = Blueprint("role", __name__, template_folder='templates', url_prefix='/role/') @bp.before_request @@ -38,45 +38,32 @@ def show(roleid=False): @csrf_protect(blueprint=bp) def update(roleid=False): is_newrole = bool(not roleid) - session = db.session if is_newrole: role = Role() - session.add(role) + db.session.add(role) else: role = Role.query.filter_by(id=roleid).one() role.name = request.values['name'] role.description = request.values['description'] - - groups = Group.ldap_all() - role_group_dns = role.group_dns() - for group in groups: + for group in Group.ldap_all(): if request.values.get('group-{}'.format(group.gid), False): - if group.dn in role_group_dns: - continue - role.add_group(group) - elif group.dn in role_group_dns: - role.del_group(group) - - members = role.member_ldap() - for user in members: - recalculate_user_groups(user) - if not user.to_ldap(): - flash('updating group membership for user {} failed'.format(user.loginname)) - - session.commit() + role.groups.add(group) + else: + role.groups.discard(group) + role.update_member_groups() + db.session.commit() + ldap.session.commit() return redirect(url_for('role.index')) @bp.route("/<int:roleid>/del") @csrf_protect(blueprint=bp) def delete(roleid): - session = db.session role = Role.query.filter_by(id=roleid).one() - members = role.member_ldap() - session.delete(role) - session.commit() - for user in members: - recalculate_user_groups(user) - if not user.to_ldap(): - flash('updating group membership for user {} failed'.format(user.loginname)) - session.commit() + oldmembers = list(role.members) + role.members.clear() + db.session.delete(role) + for user in oldmembers: + user.update_groups() + db.session.commit() + ldap.session.commit() return redirect(url_for('role.index')) diff --git a/uffd/user/models.py b/uffd/user/models.py index 18cd4d0306ac63a70621105091f8b15d79574469..d79a8bb70b2baf86e92ed6ab15b9e295ca56a467 100644 --- a/uffd/user/models.py +++ b/uffd/user/models.py @@ -29,7 +29,8 @@ class User(LDAPModel): mail = LDAPAttribute('mail') pwhash = LDAPAttribute('userPassword', default=lambda: hashed(HASHED_SALTED_SHA512, secrets.token_hex(128))) - groups = [] # Shut up pylint, overwritten by LDAPBackref + groups = [] # Shuts up pylint, overwritten by back-reference + roles = [] # Shuts up pylint, overwritten by back-reference def dummy_attribute_defaults(self): if self.ldap_getattr('sn') == []: @@ -107,3 +108,5 @@ class Group(LDAPModel): name = LDAPAttribute('cn') description = LDAPAttribute('description', default='') members = LDAPRelation('uniqueMember', User, backref='groups') + + roles = [] # Shuts up pylint, overwritten by back-reference diff --git a/uffd/user/templates/user.html b/uffd/user/templates/user.html index 4fdcbc58df6bf3e26da09162135a0f265afc7a34..ac7f3bf72df756a3d4c1b658e6373a8b4afc2b57 100644 --- a/uffd/user/templates/user.html +++ b/uffd/user/templates/user.html @@ -84,7 +84,7 @@ <td> <div class="form-check"> <input class="form-check-input" type="checkbox" id="role-{{ role.id }}-checkbox" name="role-{{ role.id }}" value="1" aria-label="enabled" - {% if user.dn in role.member_dns() or role.name in config["ROLES_BASEROLES"] %}checked {% endif %} + {% if user in role.members or role.name in config["ROLES_BASEROLES"] %}checked {% endif %} {% if role.name in config["ROLES_BASEROLES"] %}disabled {% endif %}> </div> </td> diff --git a/uffd/user/views_user.py b/uffd/user/views_user.py index a293d554799b734818a6445f5a5eefa2fcfa25fb..0d43feabde53cab189888cafc4fb656b54a02d52 100644 --- a/uffd/user/views_user.py +++ b/uffd/user/views_user.py @@ -8,7 +8,6 @@ from uffd.csrf import csrf_protect from uffd.selfservice import send_passwordreset from uffd.session import login_required, is_valid_session, get_current_user from uffd.role.models import Role -from uffd.role.utils import recalculate_user_groups from uffd.database import db from uffd.ldap import ldap, LDAPCommitError @@ -57,15 +56,11 @@ def update(uid=None): new_password = request.form.get('password') if uid is not None and new_password: user.set_password(new_password) + user.roles.clear() for role in Role.query.all(): - role_member_dns = role.member_dns() if request.values.get('role-{}'.format(role.id), False) or role.name in current_app.config["ROLES_BASEROLES"]: - if user.dn in role_member_dns: - continue - role.add_member(user) - elif user.dn in role_member_dns: - role.del_member(user) - recalculate_user_groups(user) + user.roles.add(role) + user.update_groups() ldap.session.add(user) ldap.session.commit() db.session.commit() @@ -80,9 +75,7 @@ def update(uid=None): @csrf_protect(blueprint=bp) def delete(uid): user = User.ldap_filter_by(uid=uid)[0] - for role in Role.get_for_user(user).all(): - if user.dn in role.member_dns(): - role.del_member(user) + user.roles.clear() ldap.session.delete(user) ldap.session.commit() db.session.commit() @@ -113,12 +106,9 @@ def csvimport(): flash("invalid mail address, skipped : {}".format(row)) continue for role in roles: - role_member_dns = role.member_dns() if (str(role.id) in row[2].split(';')) or role.name in current_app.config["ROLES_BASEROLES"]: - if newuser.dn in role_member_dns: - continue - role.add_member(newuser) - recalculate_user_groups(newuser) + role.members.add(newuser) + newuser.update_groups() ldap.session.add(newuser) try: ldap.session.commit()