From 97804bb980b48ecbc8c26082f8d0a4d159903fb7 Mon Sep 17 00:00:00 2001 From: Julian Rother <julianr@fsmpi.rwth-aachen.de> Date: Fri, 19 Feb 2021 01:41:15 +0100 Subject: [PATCH] Implemented (incomplete) relation mechanism --- ldap_server_entries.json | 1 - uffd/ldaporm.py | 132 +++++++++++++++++++++++++++------------ 2 files changed, 93 insertions(+), 40 deletions(-) diff --git a/ldap_server_entries.json b/ldap_server_entries.json index a6b5d8c5..4c76279f 100644 --- a/ldap_server_entries.json +++ b/ldap_server_entries.json @@ -242,7 +242,6 @@ "cn=Subschema" ], "uniqueMember": [ - "cn=dummy,ou=system,dc=example,dc=com", "uid=testuser,ou=users,dc=example,dc=com", "uid=testadmin,ou=users,dc=example,dc=com" ] diff --git a/uffd/ldaporm.py b/uffd/ldaporm.py index 01e72929..f9805fbd 100644 --- a/uffd/ldaporm.py +++ b/uffd/ldaporm.py @@ -1,4 +1,5 @@ from copy import deepcopy +from collections.abc import MutableSet from flask import current_app, request @@ -84,44 +85,32 @@ class FlaskLDAPMapper: ldap = FlaskLDAPMapper() -class LDAPList: - def __init__(self, obj, name, encode=None, decode=None): - self.__obj = obj - self.__name = name +class LDAPSet(MutableSet): + def __init__(self, getitems, additem, delitem, encode=None, decode=None): + self.__getitems = getitems + self.__additem = additem + self.__delitem = delitem self.__encode = encode or (lambda x: x) self.__decode = decode or (lambda x: x) - @property - def __list(self): - return self.__obj.ldap_getattr(self.__name) + def __repr__(self): + return repr(set(self)) def __contains__(self, value): - return self.__encode(value) in self.__list - - def __iadd__(self, values): - self.extend(values) - + return self.__encode(value) in self.__getitems() + def __iter__(self): - return iter(map(self.__decode, self.__list)) + return iter(map(self.__decode, self.__getitems())) def __len__(self): - return len(self.__list) - - def append(self, value): - self.__obj.ldap_attradd(self.__name, self.__encode(value)) - - def clear(self): - self.__obj.ldap_setattr(self.__name, []) - - def count(self, value): - return __list.count(self.__encode(value)) + return len(self.__getitems()) - def extend(self, values): - for value in values: - self.__obj.ldap_attradd(self.__name, self.__encode(value)) + def add(self, value): + if value not in self: + self.__additem(self.__encode(value)) - def remove(self, value): - self.__obj.ldap_attrdel(self.__name, self.__encode(value)) + def discard(self, value): + self.__delitem(self.__encode(value)) class LDAPAttribute: def __init__(self, name, multi=False, default=None, encode=None, decode=None): @@ -136,27 +125,86 @@ class LDAPAttribute: return [self.encode(value) for value in values] self.default = default_wrapper - def __set_name__(self, obj, name): + def __set_name__(self, cls, name): if self.default is None: return - if not obj.ldap_defaults: - obj.ldap_defaults = {} - obj.ldap_defaults[self.name] = self.default + if not cls.ldap_defaults: + cls.ldap_defaults = {} + cls.ldap_defaults[self.name] = self.default def __get__(self, obj, objtype=None): if obj is None: return self - values = obj.ldap_getattr(self.name) if self.multi: - return LDAPList(obj, self.name, encode=self.encode, decode=self.decode) - return self.decode((values or [None])[0]) + return LDAPSet(getitems=lambda: obj.ldap_getattr(self.name), + additem=lambda value: obj.ldap_attradd(self.name, value), + delitem=lambda value: obj.ldap_attrdel(self.name, value), + encode=self.encode, decode=self.decode) + return self.decode((obj.ldap_getattr(self.name) or [None])[0]) def __set__(self, obj, values): if not self.multi: values = [values] obj.ldap_setattr(self.name, [self.encode(value) for value in values]) +class LDAPRelationBackref: + def __init__(self, src, srcattr): + self.src = src + self.srcattr = srcattr + self.key = (self.src, self.srcattr) + + def fetch(self, obj): + if self.key not in obj._relation_data: + kwargs = {getattr(self.src, self.srcattr).name: obj.dn} + values = self.src.ldap_filter_by(**kwargs) + obj._relation_data[self.key] = set(values) + return obj._relation_data[self.key] + + def add(self, obj, value): + self.fetch(obj) + obj._relation_data[self.key].add(value) + + def discard(self, obj, value): + self.fetch(obj) + obj._relation_data[self.key].discard(value) + + def __get__(self, obj, objtype=None): + if obj is None: + return self + return LDAPSet(getitems=lambda: self.fetch(obj), + additem=lambda value: getattr(value, self.srcattr).add(obj), + delitem=lambda value: getattr(value, self.srcattr).discard(obj)) + class LDAPRelation: + def __init__(self, name, dest, backref=None): + self.name = name + self.dest = dest + self.backref = backref + + def __set_name__(self, cls, name): + if self.backref is not None: + setattr(self.dest, self.backref, LDAPRelationBackref(cls, name)) + + def add(self, obj, destdn): + destobj = self.dest.ldap_get(destdn) # Always from cache! + obj.ldap_attradd(self.name, destdn) + if self.backref is not None: + getattr(self.dest, self.backref).add(destobj, obj) + + def discard(self, obj, destdn): + destobj = self.dest.ldap_get(destdn) # Always from cache! + obj.ldap_attrdel(self.name, destdn) + if self.backref is not None: + getattr(self.dest, self.backref).discard(destobj, obj) + + def __get__(self, obj, objtype=None): + if obj is None: + return self + return LDAPSet(getitems=lambda: obj.ldap_getattr(self.name), + additem=lambda value: self.add(obj, value), + delitem=lambda value: self.discard(obj, value), + encode=lambda value: value.dn, + decode=lambda value: self.dest.ldap_get(value)) class LDAPModel: ldap_dn_attribute = None @@ -167,6 +215,7 @@ class LDAPModel: ldap_defaults = None # Populated by LDAPAttribute def __init__(self, _ldap_dn=None, _ldap_attributes=None, **kwargs): + self._relation_data = {} self.__ldap_dn = _ldap_dn self.__ldap_attributes = {} for key, values in (_ldap_attributes or {}).items(): @@ -197,6 +246,12 @@ class LDAPModel: if item in self.__attributes.get(name, []): self.__attributes[name].remove(item) + def __repr__(self): + name = '%s.%s'%(type(self).__module__, type(self).__name__) + if self.__ldap_dn is None: + return '<%s>'%name + return '<%s %s>'%(name, self.__ldap_dn) + def build_dn(self): return '%s=%s,%s'%(self.ldap_dn_attribute, self.__attributes[self.ldap_dn_attribute][0], self.ldap_dn_base) @@ -237,7 +292,7 @@ class LDAPModel: for key, value in kwargs.items(): filters.append('(%s=%s)'%(key, escape_filter_chars(value))) conn = get_conn() - conn.search(cls.ldap_dn_base, '(&%s)'%(''.join(filters))) + conn.search(cls.ldap_base, '(&%s)'%(''.join(filters))) res = [] for entry in conn.response: obj = ldap.session.lookup(entry['dn']) @@ -307,22 +362,21 @@ class User(LDAPModel): displayname = LDAPAttribute('cn') mail = LDAPAttribute('mail') pwhash = LDAPAttribute('userPassword', default=lambda: hashed(HASHED_SALTED_SHA512, secrets.token_hex(128))) - groups = LDAPAttribute('memberOf', multi=True, default=[], encode=lambda obj: obj.dn, decode=lambda dn: Group.ldap_get(dn)) def password(self, value): self.pwhash = hashed(HASHED_SALTED_SHA512, value) password = property(fset=password) - class Group(LDAPModel): ldap_base = 'ou=groups,dc=example,dc=com' ldap_filter = '(objectClass=groupOfUniqueNames)' gid = LDAPAttribute('gidNumber') name = LDAPAttribute('cn') - members = LDAPAttribute('uniqueMember', multi=True, default=[], encode=lambda obj: obj.dn, decode=lambda dn: User.ldap_get(dn)) description = LDAPAttribute('description', default='') + members = LDAPRelation('uniqueMember', User, backref='groups') + class Mail(LDAPModel): ldap_base = 'ou=postfix,dc=example,dc=com' ldap_dn_attribute = 'uid' -- GitLab