From 4607601620ebfd6c0fb6f77deecea5d32bc7e460 Mon Sep 17 00:00:00 2001 From: Julian Rother <julianr@fsmpi.rwth-aachen.de> Date: Fri, 19 Feb 2021 04:09:49 +0100 Subject: [PATCH] Refactored relations --- uffd/ldaporm.py | 117 +++++++++++++++++++++++++++--------------------- 1 file changed, 65 insertions(+), 52 deletions(-) diff --git a/uffd/ldaporm.py b/uffd/ldaporm.py index f9805fbd..448b561c 100644 --- a/uffd/ldaporm.py +++ b/uffd/ldaporm.py @@ -40,6 +40,7 @@ class LDAPSession: def __init__(self): self.__objects = {} # dn -> instance self.__to_delete = [] + self.__relations = {} # (srccls, srcattr, dn) -> {srcobj, ...} def lookup(self, dn): return self.__objects.get(dn) @@ -50,6 +51,20 @@ class LDAPSession: self.__objects[obj.dn] = obj return obj + def lookup_relations(self, srccls, srcattr, dn): + key = (srccls, srcattr, dn) + return self.__relations.get(key, set()) + + def update_relations(self, srcobj, srcattr, delete_dns=None, add_dns=None): + for dn in (delete_dns or []): + key = (type(srcobj), srcattr, dn) + self.__relations[key] = self.__relations.get(key, set()) + self.__relations[key].discard(srcobj) + for dn in (add_dns or []): + key = (type(srcobj), srcattr, dn) + self.__relations[key] = self.__relations.get(key, set()) + self.__relations[key].add(srcobj) + def add(self, obj): if obj.ldap_created: raise Exception() @@ -147,64 +162,46 @@ class LDAPAttribute: values = [values] obj.ldap_setattr(self.name, [self.encode(value) for value in values]) -class LDAPRelationBackref: - def __init__(self, src, srcattr): - self.src = src +class LDAPBackref: + def __init__(self, srccls, srcattr): + self.srccls = srccls self.srcattr = srcattr - self.key = (self.src, self.srcattr) - - def fetch(self, obj): - if self.key not in obj._relation_data: - kwargs = {getattr(self.src, self.srcattr).name: obj.dn} - values = self.src.ldap_filter_by(**kwargs) - obj._relation_data[self.key] = set(values) - return obj._relation_data[self.key] - - def add(self, obj, value): - self.fetch(obj) - obj._relation_data[self.key].add(value) + if srccls.ldap_relations is None: + srccls.ldap_relations = set() + srccls.ldap_relations.add(srcattr) - def discard(self, obj, value): - self.fetch(obj) - obj._relation_data[self.key].discard(value) + def init(self, obj): + if self.srcattr in obj.ldap_relation_data: + return + # The query instanciates all related objects that in turn add their relations to session + self.srccls.ldap_filter_by(**{self.srcattr: obj.dn}) + obj.ldap_relation_data.add(self.srcattr) def __get__(self, obj, objtype=None): if obj is None: return self - return LDAPSet(getitems=lambda: self.fetch(obj), - additem=lambda value: getattr(value, self.srcattr).add(obj), - delitem=lambda value: getattr(value, self.srcattr).discard(obj)) + self.init(obj) + return LDAPSet(getitems=lambda: ldap.session.lookup_relations(self.srccls, self.srcattr, obj.dn), + additem=lambda value: value.ldap_attradd(self.srcattr, obj.dn), + delitem=lambda value: value.ldap_attrdel(self.srcattr, obj.dn)) + + def __set__(self, obj, values): + current = self.__get__(obj) + current.clear() + for value in values: + current.add(value) -class LDAPRelation: +class LDAPRelation(LDAPAttribute): def __init__(self, name, dest, backref=None): + super().__init__(name, multi=True, encode=lambda value: value.dn, + decode=lambda value: dest.ldap_get(value)) self.name = name self.dest = dest self.backref = backref def __set_name__(self, cls, name): if self.backref is not None: - setattr(self.dest, self.backref, LDAPRelationBackref(cls, name)) - - def add(self, obj, destdn): - destobj = self.dest.ldap_get(destdn) # Always from cache! - obj.ldap_attradd(self.name, destdn) - if self.backref is not None: - getattr(self.dest, self.backref).add(destobj, obj) - - def discard(self, obj, destdn): - destobj = self.dest.ldap_get(destdn) # Always from cache! - obj.ldap_attrdel(self.name, destdn) - if self.backref is not None: - getattr(self.dest, self.backref).discard(destobj, obj) - - def __get__(self, obj, objtype=None): - if obj is None: - return self - return LDAPSet(getitems=lambda: obj.ldap_getattr(self.name), - additem=lambda value: self.add(obj, value), - delitem=lambda value: self.discard(obj, value), - encode=lambda value: value.dn, - decode=lambda value: self.dest.ldap_get(value)) + setattr(self.dest, self.backref, LDAPBackref(cls, self.name)) class LDAPModel: ldap_dn_attribute = None @@ -213,9 +210,10 @@ class LDAPModel: ldap_object_classes = None ldap_filter = None ldap_defaults = None # Populated by LDAPAttribute + ldap_relations = None # Populated by LDAPBackref def __init__(self, _ldap_dn=None, _ldap_attributes=None, **kwargs): - self._relation_data = {} + self.ldap_relation_data = set() self.__ldap_dn = _ldap_dn self.__ldap_attributes = {} for key, values in (_ldap_attributes or {}).items(): @@ -229,22 +227,32 @@ class LDAPModel: if not hasattr(self, key): raise Exception() setattr(self, key, value) + for name in (self.ldap_relations or []): + self.__update_relations(name, add_dns=self.__attributes.get(name, [])) + + def __update_relations(self, name, delete_dns=None, add_dns=None): + if name in (self.ldap_relations or []): + ldap.session.update_relations(self, name, delete_dns, add_dns) def ldap_getattr(self, name): return self.__attributes.get(name, []) def ldap_setattr(self, name, values): + self.__update_relations(name, delete_dns=self.__attributes.get(name, [])) self.__changes[name] = [(MODIFY_REPLACE, values)] self.__attributes[name] = values + self.__update_relations(name, add_dns=values) - def ldap_attradd(self, name, item): - self.__changes[name] = self.__changes.get(name, []) + [(MODIFY_ADD, [item])] - self.__attributes[name].append(item) + def ldap_attradd(self, name, value): + self.__changes[name] = self.__changes.get(name, []) + [(MODIFY_ADD, [value])] + self.__attributes[name].append(value) + self.__update_relations(name, add_dns=[value]) - def ldap_attrdel(self, name, item): - self.__changes[name] = self.__changes.get(name, []) + [(MODIFY_DELETE, [item])] - if item in self.__attributes.get(name, []): - self.__attributes[name].remove(item) + def ldap_attrdel(self, name, value): + self.__changes[name] = self.__changes.get(name, []) + [(MODIFY_DELETE, [value])] + if value in self.__attributes.get(name, []): + self.__attributes[name].remove(value) + self.__update_relations(name, delete_dns=[value]) def __repr__(self): name = '%s.%s'%(type(self).__module__, type(self).__name__) @@ -302,8 +310,12 @@ class LDAPModel: return res def ldap_reset(self): + for name in (self.ldap_relations or []): + self.__update_relations(name, delete_dns=self.__attributes.get(name, [])) self.__changes = {} self.__attributes = deepcopy(self.__ldap_attributes) + for name in (self.ldap_relations or {}): + self.__update_relations(name, add_dns=self.__attributes.get(name, [])) @property def ldap_dirty(self): @@ -374,6 +386,7 @@ class Group(LDAPModel): gid = LDAPAttribute('gidNumber') name = LDAPAttribute('cn') description = LDAPAttribute('description', default='') + member_dns= LDAPAttribute('uniqueMember', multi=True) members = LDAPRelation('uniqueMember', User, backref='groups') -- GitLab