Skip to content
Snippets Groups Projects
Commit 46076016 authored by Julian's avatar Julian
Browse files

Refactored relations

parent 97804bb9
No related branches found
No related tags found
No related merge requests found
......@@ -40,6 +40,7 @@ class LDAPSession:
def __init__(self):
self.__objects = {} # dn -> instance
self.__to_delete = []
self.__relations = {} # (srccls, srcattr, dn) -> {srcobj, ...}
def lookup(self, dn):
return self.__objects.get(dn)
......@@ -50,6 +51,20 @@ class LDAPSession:
self.__objects[obj.dn] = obj
return obj
def lookup_relations(self, srccls, srcattr, dn):
key = (srccls, srcattr, dn)
return self.__relations.get(key, set())
def update_relations(self, srcobj, srcattr, delete_dns=None, add_dns=None):
for dn in (delete_dns or []):
key = (type(srcobj), srcattr, dn)
self.__relations[key] = self.__relations.get(key, set())
self.__relations[key].discard(srcobj)
for dn in (add_dns or []):
key = (type(srcobj), srcattr, dn)
self.__relations[key] = self.__relations.get(key, set())
self.__relations[key].add(srcobj)
def add(self, obj):
if obj.ldap_created:
raise Exception()
......@@ -147,64 +162,46 @@ class LDAPAttribute:
values = [values]
obj.ldap_setattr(self.name, [self.encode(value) for value in values])
class LDAPRelationBackref:
def __init__(self, src, srcattr):
self.src = src
class LDAPBackref:
def __init__(self, srccls, srcattr):
self.srccls = srccls
self.srcattr = srcattr
self.key = (self.src, self.srcattr)
def fetch(self, obj):
if self.key not in obj._relation_data:
kwargs = {getattr(self.src, self.srcattr).name: obj.dn}
values = self.src.ldap_filter_by(**kwargs)
obj._relation_data[self.key] = set(values)
return obj._relation_data[self.key]
def add(self, obj, value):
self.fetch(obj)
obj._relation_data[self.key].add(value)
if srccls.ldap_relations is None:
srccls.ldap_relations = set()
srccls.ldap_relations.add(srcattr)
def discard(self, obj, value):
self.fetch(obj)
obj._relation_data[self.key].discard(value)
def init(self, obj):
if self.srcattr in obj.ldap_relation_data:
return
# The query instanciates all related objects that in turn add their relations to session
self.srccls.ldap_filter_by(**{self.srcattr: obj.dn})
obj.ldap_relation_data.add(self.srcattr)
def __get__(self, obj, objtype=None):
if obj is None:
return self
return LDAPSet(getitems=lambda: self.fetch(obj),
additem=lambda value: getattr(value, self.srcattr).add(obj),
delitem=lambda value: getattr(value, self.srcattr).discard(obj))
self.init(obj)
return LDAPSet(getitems=lambda: ldap.session.lookup_relations(self.srccls, self.srcattr, obj.dn),
additem=lambda value: value.ldap_attradd(self.srcattr, obj.dn),
delitem=lambda value: value.ldap_attrdel(self.srcattr, obj.dn))
def __set__(self, obj, values):
current = self.__get__(obj)
current.clear()
for value in values:
current.add(value)
class LDAPRelation:
class LDAPRelation(LDAPAttribute):
def __init__(self, name, dest, backref=None):
super().__init__(name, multi=True, encode=lambda value: value.dn,
decode=lambda value: dest.ldap_get(value))
self.name = name
self.dest = dest
self.backref = backref
def __set_name__(self, cls, name):
if self.backref is not None:
setattr(self.dest, self.backref, LDAPRelationBackref(cls, name))
def add(self, obj, destdn):
destobj = self.dest.ldap_get(destdn) # Always from cache!
obj.ldap_attradd(self.name, destdn)
if self.backref is not None:
getattr(self.dest, self.backref).add(destobj, obj)
def discard(self, obj, destdn):
destobj = self.dest.ldap_get(destdn) # Always from cache!
obj.ldap_attrdel(self.name, destdn)
if self.backref is not None:
getattr(self.dest, self.backref).discard(destobj, obj)
def __get__(self, obj, objtype=None):
if obj is None:
return self
return LDAPSet(getitems=lambda: obj.ldap_getattr(self.name),
additem=lambda value: self.add(obj, value),
delitem=lambda value: self.discard(obj, value),
encode=lambda value: value.dn,
decode=lambda value: self.dest.ldap_get(value))
setattr(self.dest, self.backref, LDAPBackref(cls, self.name))
class LDAPModel:
ldap_dn_attribute = None
......@@ -213,9 +210,10 @@ class LDAPModel:
ldap_object_classes = None
ldap_filter = None
ldap_defaults = None # Populated by LDAPAttribute
ldap_relations = None # Populated by LDAPBackref
def __init__(self, _ldap_dn=None, _ldap_attributes=None, **kwargs):
self._relation_data = {}
self.ldap_relation_data = set()
self.__ldap_dn = _ldap_dn
self.__ldap_attributes = {}
for key, values in (_ldap_attributes or {}).items():
......@@ -229,22 +227,32 @@ class LDAPModel:
if not hasattr(self, key):
raise Exception()
setattr(self, key, value)
for name in (self.ldap_relations or []):
self.__update_relations(name, add_dns=self.__attributes.get(name, []))
def __update_relations(self, name, delete_dns=None, add_dns=None):
if name in (self.ldap_relations or []):
ldap.session.update_relations(self, name, delete_dns, add_dns)
def ldap_getattr(self, name):
return self.__attributes.get(name, [])
def ldap_setattr(self, name, values):
self.__update_relations(name, delete_dns=self.__attributes.get(name, []))
self.__changes[name] = [(MODIFY_REPLACE, values)]
self.__attributes[name] = values
self.__update_relations(name, add_dns=values)
def ldap_attradd(self, name, item):
self.__changes[name] = self.__changes.get(name, []) + [(MODIFY_ADD, [item])]
self.__attributes[name].append(item)
def ldap_attradd(self, name, value):
self.__changes[name] = self.__changes.get(name, []) + [(MODIFY_ADD, [value])]
self.__attributes[name].append(value)
self.__update_relations(name, add_dns=[value])
def ldap_attrdel(self, name, item):
self.__changes[name] = self.__changes.get(name, []) + [(MODIFY_DELETE, [item])]
if item in self.__attributes.get(name, []):
self.__attributes[name].remove(item)
def ldap_attrdel(self, name, value):
self.__changes[name] = self.__changes.get(name, []) + [(MODIFY_DELETE, [value])]
if value in self.__attributes.get(name, []):
self.__attributes[name].remove(value)
self.__update_relations(name, delete_dns=[value])
def __repr__(self):
name = '%s.%s'%(type(self).__module__, type(self).__name__)
......@@ -302,8 +310,12 @@ class LDAPModel:
return res
def ldap_reset(self):
for name in (self.ldap_relations or []):
self.__update_relations(name, delete_dns=self.__attributes.get(name, []))
self.__changes = {}
self.__attributes = deepcopy(self.__ldap_attributes)
for name in (self.ldap_relations or {}):
self.__update_relations(name, add_dns=self.__attributes.get(name, []))
@property
def ldap_dirty(self):
......@@ -374,6 +386,7 @@ class Group(LDAPModel):
gid = LDAPAttribute('gidNumber')
name = LDAPAttribute('cn')
description = LDAPAttribute('description', default='')
member_dns= LDAPAttribute('uniqueMember', multi=True)
members = LDAPRelation('uniqueMember', User, backref='groups')
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment