Skip to content
Snippets Groups Projects
Commit 46076016 authored by Julian's avatar Julian
Browse files

Refactored relations

parent 97804bb9
No related branches found
No related tags found
No related merge requests found
...@@ -40,6 +40,7 @@ class LDAPSession: ...@@ -40,6 +40,7 @@ class LDAPSession:
def __init__(self): def __init__(self):
self.__objects = {} # dn -> instance self.__objects = {} # dn -> instance
self.__to_delete = [] self.__to_delete = []
self.__relations = {} # (srccls, srcattr, dn) -> {srcobj, ...}
def lookup(self, dn): def lookup(self, dn):
return self.__objects.get(dn) return self.__objects.get(dn)
...@@ -50,6 +51,20 @@ class LDAPSession: ...@@ -50,6 +51,20 @@ class LDAPSession:
self.__objects[obj.dn] = obj self.__objects[obj.dn] = obj
return obj return obj
def lookup_relations(self, srccls, srcattr, dn):
key = (srccls, srcattr, dn)
return self.__relations.get(key, set())
def update_relations(self, srcobj, srcattr, delete_dns=None, add_dns=None):
for dn in (delete_dns or []):
key = (type(srcobj), srcattr, dn)
self.__relations[key] = self.__relations.get(key, set())
self.__relations[key].discard(srcobj)
for dn in (add_dns or []):
key = (type(srcobj), srcattr, dn)
self.__relations[key] = self.__relations.get(key, set())
self.__relations[key].add(srcobj)
def add(self, obj): def add(self, obj):
if obj.ldap_created: if obj.ldap_created:
raise Exception() raise Exception()
...@@ -147,64 +162,46 @@ class LDAPAttribute: ...@@ -147,64 +162,46 @@ class LDAPAttribute:
values = [values] values = [values]
obj.ldap_setattr(self.name, [self.encode(value) for value in values]) obj.ldap_setattr(self.name, [self.encode(value) for value in values])
class LDAPRelationBackref: class LDAPBackref:
def __init__(self, src, srcattr): def __init__(self, srccls, srcattr):
self.src = src self.srccls = srccls
self.srcattr = srcattr self.srcattr = srcattr
self.key = (self.src, self.srcattr) if srccls.ldap_relations is None:
srccls.ldap_relations = set()
def fetch(self, obj): srccls.ldap_relations.add(srcattr)
if self.key not in obj._relation_data:
kwargs = {getattr(self.src, self.srcattr).name: obj.dn}
values = self.src.ldap_filter_by(**kwargs)
obj._relation_data[self.key] = set(values)
return obj._relation_data[self.key]
def add(self, obj, value):
self.fetch(obj)
obj._relation_data[self.key].add(value)
def discard(self, obj, value): def init(self, obj):
self.fetch(obj) if self.srcattr in obj.ldap_relation_data:
obj._relation_data[self.key].discard(value) return
# The query instanciates all related objects that in turn add their relations to session
self.srccls.ldap_filter_by(**{self.srcattr: obj.dn})
obj.ldap_relation_data.add(self.srcattr)
def __get__(self, obj, objtype=None): def __get__(self, obj, objtype=None):
if obj is None: if obj is None:
return self return self
return LDAPSet(getitems=lambda: self.fetch(obj), self.init(obj)
additem=lambda value: getattr(value, self.srcattr).add(obj), return LDAPSet(getitems=lambda: ldap.session.lookup_relations(self.srccls, self.srcattr, obj.dn),
delitem=lambda value: getattr(value, self.srcattr).discard(obj)) additem=lambda value: value.ldap_attradd(self.srcattr, obj.dn),
delitem=lambda value: value.ldap_attrdel(self.srcattr, obj.dn))
def __set__(self, obj, values):
current = self.__get__(obj)
current.clear()
for value in values:
current.add(value)
class LDAPRelation: class LDAPRelation(LDAPAttribute):
def __init__(self, name, dest, backref=None): def __init__(self, name, dest, backref=None):
super().__init__(name, multi=True, encode=lambda value: value.dn,
decode=lambda value: dest.ldap_get(value))
self.name = name self.name = name
self.dest = dest self.dest = dest
self.backref = backref self.backref = backref
def __set_name__(self, cls, name): def __set_name__(self, cls, name):
if self.backref is not None: if self.backref is not None:
setattr(self.dest, self.backref, LDAPRelationBackref(cls, name)) setattr(self.dest, self.backref, LDAPBackref(cls, self.name))
def add(self, obj, destdn):
destobj = self.dest.ldap_get(destdn) # Always from cache!
obj.ldap_attradd(self.name, destdn)
if self.backref is not None:
getattr(self.dest, self.backref).add(destobj, obj)
def discard(self, obj, destdn):
destobj = self.dest.ldap_get(destdn) # Always from cache!
obj.ldap_attrdel(self.name, destdn)
if self.backref is not None:
getattr(self.dest, self.backref).discard(destobj, obj)
def __get__(self, obj, objtype=None):
if obj is None:
return self
return LDAPSet(getitems=lambda: obj.ldap_getattr(self.name),
additem=lambda value: self.add(obj, value),
delitem=lambda value: self.discard(obj, value),
encode=lambda value: value.dn,
decode=lambda value: self.dest.ldap_get(value))
class LDAPModel: class LDAPModel:
ldap_dn_attribute = None ldap_dn_attribute = None
...@@ -213,9 +210,10 @@ class LDAPModel: ...@@ -213,9 +210,10 @@ class LDAPModel:
ldap_object_classes = None ldap_object_classes = None
ldap_filter = None ldap_filter = None
ldap_defaults = None # Populated by LDAPAttribute ldap_defaults = None # Populated by LDAPAttribute
ldap_relations = None # Populated by LDAPBackref
def __init__(self, _ldap_dn=None, _ldap_attributes=None, **kwargs): def __init__(self, _ldap_dn=None, _ldap_attributes=None, **kwargs):
self._relation_data = {} self.ldap_relation_data = set()
self.__ldap_dn = _ldap_dn self.__ldap_dn = _ldap_dn
self.__ldap_attributes = {} self.__ldap_attributes = {}
for key, values in (_ldap_attributes or {}).items(): for key, values in (_ldap_attributes or {}).items():
...@@ -229,22 +227,32 @@ class LDAPModel: ...@@ -229,22 +227,32 @@ class LDAPModel:
if not hasattr(self, key): if not hasattr(self, key):
raise Exception() raise Exception()
setattr(self, key, value) setattr(self, key, value)
for name in (self.ldap_relations or []):
self.__update_relations(name, add_dns=self.__attributes.get(name, []))
def __update_relations(self, name, delete_dns=None, add_dns=None):
if name in (self.ldap_relations or []):
ldap.session.update_relations(self, name, delete_dns, add_dns)
def ldap_getattr(self, name): def ldap_getattr(self, name):
return self.__attributes.get(name, []) return self.__attributes.get(name, [])
def ldap_setattr(self, name, values): def ldap_setattr(self, name, values):
self.__update_relations(name, delete_dns=self.__attributes.get(name, []))
self.__changes[name] = [(MODIFY_REPLACE, values)] self.__changes[name] = [(MODIFY_REPLACE, values)]
self.__attributes[name] = values self.__attributes[name] = values
self.__update_relations(name, add_dns=values)
def ldap_attradd(self, name, item): def ldap_attradd(self, name, value):
self.__changes[name] = self.__changes.get(name, []) + [(MODIFY_ADD, [item])] self.__changes[name] = self.__changes.get(name, []) + [(MODIFY_ADD, [value])]
self.__attributes[name].append(item) self.__attributes[name].append(value)
self.__update_relations(name, add_dns=[value])
def ldap_attrdel(self, name, item): def ldap_attrdel(self, name, value):
self.__changes[name] = self.__changes.get(name, []) + [(MODIFY_DELETE, [item])] self.__changes[name] = self.__changes.get(name, []) + [(MODIFY_DELETE, [value])]
if item in self.__attributes.get(name, []): if value in self.__attributes.get(name, []):
self.__attributes[name].remove(item) self.__attributes[name].remove(value)
self.__update_relations(name, delete_dns=[value])
def __repr__(self): def __repr__(self):
name = '%s.%s'%(type(self).__module__, type(self).__name__) name = '%s.%s'%(type(self).__module__, type(self).__name__)
...@@ -302,8 +310,12 @@ class LDAPModel: ...@@ -302,8 +310,12 @@ class LDAPModel:
return res return res
def ldap_reset(self): def ldap_reset(self):
for name in (self.ldap_relations or []):
self.__update_relations(name, delete_dns=self.__attributes.get(name, []))
self.__changes = {} self.__changes = {}
self.__attributes = deepcopy(self.__ldap_attributes) self.__attributes = deepcopy(self.__ldap_attributes)
for name in (self.ldap_relations or {}):
self.__update_relations(name, add_dns=self.__attributes.get(name, []))
@property @property
def ldap_dirty(self): def ldap_dirty(self):
...@@ -374,6 +386,7 @@ class Group(LDAPModel): ...@@ -374,6 +386,7 @@ class Group(LDAPModel):
gid = LDAPAttribute('gidNumber') gid = LDAPAttribute('gidNumber')
name = LDAPAttribute('cn') name = LDAPAttribute('cn')
description = LDAPAttribute('description', default='') description = LDAPAttribute('description', default='')
member_dns= LDAPAttribute('uniqueMember', multi=True)
members = LDAPRelation('uniqueMember', User, backref='groups') members = LDAPRelation('uniqueMember', User, backref='groups')
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment