Skip to content
Snippets Groups Projects
Commit 46076016 authored by Julian's avatar Julian
Browse files

Refactored relations

parent 97804bb9
Branches
Tags
No related merge requests found
...@@ -40,6 +40,7 @@ class LDAPSession: ...@@ -40,6 +40,7 @@ class LDAPSession:
def __init__(self): def __init__(self):
self.__objects = {} # dn -> instance self.__objects = {} # dn -> instance
self.__to_delete = [] self.__to_delete = []
self.__relations = {} # (srccls, srcattr, dn) -> {srcobj, ...}
def lookup(self, dn): def lookup(self, dn):
return self.__objects.get(dn) return self.__objects.get(dn)
...@@ -50,6 +51,20 @@ class LDAPSession: ...@@ -50,6 +51,20 @@ class LDAPSession:
self.__objects[obj.dn] = obj self.__objects[obj.dn] = obj
return obj return obj
def lookup_relations(self, srccls, srcattr, dn):
key = (srccls, srcattr, dn)
return self.__relations.get(key, set())
def update_relations(self, srcobj, srcattr, delete_dns=None, add_dns=None):
for dn in (delete_dns or []):
key = (type(srcobj), srcattr, dn)
self.__relations[key] = self.__relations.get(key, set())
self.__relations[key].discard(srcobj)
for dn in (add_dns or []):
key = (type(srcobj), srcattr, dn)
self.__relations[key] = self.__relations.get(key, set())
self.__relations[key].add(srcobj)
def add(self, obj): def add(self, obj):
if obj.ldap_created: if obj.ldap_created:
raise Exception() raise Exception()
...@@ -147,64 +162,46 @@ class LDAPAttribute: ...@@ -147,64 +162,46 @@ class LDAPAttribute:
values = [values] values = [values]
obj.ldap_setattr(self.name, [self.encode(value) for value in values]) obj.ldap_setattr(self.name, [self.encode(value) for value in values])
class LDAPRelationBackref: class LDAPBackref:
def __init__(self, src, srcattr): def __init__(self, srccls, srcattr):
self.src = src self.srccls = srccls
self.srcattr = srcattr self.srcattr = srcattr
self.key = (self.src, self.srcattr) if srccls.ldap_relations is None:
srccls.ldap_relations = set()
def fetch(self, obj): srccls.ldap_relations.add(srcattr)
if self.key not in obj._relation_data:
kwargs = {getattr(self.src, self.srcattr).name: obj.dn}
values = self.src.ldap_filter_by(**kwargs)
obj._relation_data[self.key] = set(values)
return obj._relation_data[self.key]
def add(self, obj, value): def init(self, obj):
self.fetch(obj) if self.srcattr in obj.ldap_relation_data:
obj._relation_data[self.key].add(value) return
# The query instanciates all related objects that in turn add their relations to session
def discard(self, obj, value): self.srccls.ldap_filter_by(**{self.srcattr: obj.dn})
self.fetch(obj) obj.ldap_relation_data.add(self.srcattr)
obj._relation_data[self.key].discard(value)
def __get__(self, obj, objtype=None): def __get__(self, obj, objtype=None):
if obj is None: if obj is None:
return self return self
return LDAPSet(getitems=lambda: self.fetch(obj), self.init(obj)
additem=lambda value: getattr(value, self.srcattr).add(obj), return LDAPSet(getitems=lambda: ldap.session.lookup_relations(self.srccls, self.srcattr, obj.dn),
delitem=lambda value: getattr(value, self.srcattr).discard(obj)) additem=lambda value: value.ldap_attradd(self.srcattr, obj.dn),
delitem=lambda value: value.ldap_attrdel(self.srcattr, obj.dn))
class LDAPRelation: def __set__(self, obj, values):
current = self.__get__(obj)
current.clear()
for value in values:
current.add(value)
class LDAPRelation(LDAPAttribute):
def __init__(self, name, dest, backref=None): def __init__(self, name, dest, backref=None):
super().__init__(name, multi=True, encode=lambda value: value.dn,
decode=lambda value: dest.ldap_get(value))
self.name = name self.name = name
self.dest = dest self.dest = dest
self.backref = backref self.backref = backref
def __set_name__(self, cls, name): def __set_name__(self, cls, name):
if self.backref is not None: if self.backref is not None:
setattr(self.dest, self.backref, LDAPRelationBackref(cls, name)) setattr(self.dest, self.backref, LDAPBackref(cls, self.name))
def add(self, obj, destdn):
destobj = self.dest.ldap_get(destdn) # Always from cache!
obj.ldap_attradd(self.name, destdn)
if self.backref is not None:
getattr(self.dest, self.backref).add(destobj, obj)
def discard(self, obj, destdn):
destobj = self.dest.ldap_get(destdn) # Always from cache!
obj.ldap_attrdel(self.name, destdn)
if self.backref is not None:
getattr(self.dest, self.backref).discard(destobj, obj)
def __get__(self, obj, objtype=None):
if obj is None:
return self
return LDAPSet(getitems=lambda: obj.ldap_getattr(self.name),
additem=lambda value: self.add(obj, value),
delitem=lambda value: self.discard(obj, value),
encode=lambda value: value.dn,
decode=lambda value: self.dest.ldap_get(value))
class LDAPModel: class LDAPModel:
ldap_dn_attribute = None ldap_dn_attribute = None
...@@ -213,9 +210,10 @@ class LDAPModel: ...@@ -213,9 +210,10 @@ class LDAPModel:
ldap_object_classes = None ldap_object_classes = None
ldap_filter = None ldap_filter = None
ldap_defaults = None # Populated by LDAPAttribute ldap_defaults = None # Populated by LDAPAttribute
ldap_relations = None # Populated by LDAPBackref
def __init__(self, _ldap_dn=None, _ldap_attributes=None, **kwargs): def __init__(self, _ldap_dn=None, _ldap_attributes=None, **kwargs):
self._relation_data = {} self.ldap_relation_data = set()
self.__ldap_dn = _ldap_dn self.__ldap_dn = _ldap_dn
self.__ldap_attributes = {} self.__ldap_attributes = {}
for key, values in (_ldap_attributes or {}).items(): for key, values in (_ldap_attributes or {}).items():
...@@ -229,22 +227,32 @@ class LDAPModel: ...@@ -229,22 +227,32 @@ class LDAPModel:
if not hasattr(self, key): if not hasattr(self, key):
raise Exception() raise Exception()
setattr(self, key, value) setattr(self, key, value)
for name in (self.ldap_relations or []):
self.__update_relations(name, add_dns=self.__attributes.get(name, []))
def __update_relations(self, name, delete_dns=None, add_dns=None):
if name in (self.ldap_relations or []):
ldap.session.update_relations(self, name, delete_dns, add_dns)
def ldap_getattr(self, name): def ldap_getattr(self, name):
return self.__attributes.get(name, []) return self.__attributes.get(name, [])
def ldap_setattr(self, name, values): def ldap_setattr(self, name, values):
self.__update_relations(name, delete_dns=self.__attributes.get(name, []))
self.__changes[name] = [(MODIFY_REPLACE, values)] self.__changes[name] = [(MODIFY_REPLACE, values)]
self.__attributes[name] = values self.__attributes[name] = values
self.__update_relations(name, add_dns=values)
def ldap_attradd(self, name, item): def ldap_attradd(self, name, value):
self.__changes[name] = self.__changes.get(name, []) + [(MODIFY_ADD, [item])] self.__changes[name] = self.__changes.get(name, []) + [(MODIFY_ADD, [value])]
self.__attributes[name].append(item) self.__attributes[name].append(value)
self.__update_relations(name, add_dns=[value])
def ldap_attrdel(self, name, item): def ldap_attrdel(self, name, value):
self.__changes[name] = self.__changes.get(name, []) + [(MODIFY_DELETE, [item])] self.__changes[name] = self.__changes.get(name, []) + [(MODIFY_DELETE, [value])]
if item in self.__attributes.get(name, []): if value in self.__attributes.get(name, []):
self.__attributes[name].remove(item) self.__attributes[name].remove(value)
self.__update_relations(name, delete_dns=[value])
def __repr__(self): def __repr__(self):
name = '%s.%s'%(type(self).__module__, type(self).__name__) name = '%s.%s'%(type(self).__module__, type(self).__name__)
...@@ -302,8 +310,12 @@ class LDAPModel: ...@@ -302,8 +310,12 @@ class LDAPModel:
return res return res
def ldap_reset(self): def ldap_reset(self):
for name in (self.ldap_relations or []):
self.__update_relations(name, delete_dns=self.__attributes.get(name, []))
self.__changes = {} self.__changes = {}
self.__attributes = deepcopy(self.__ldap_attributes) self.__attributes = deepcopy(self.__ldap_attributes)
for name in (self.ldap_relations or {}):
self.__update_relations(name, add_dns=self.__attributes.get(name, []))
@property @property
def ldap_dirty(self): def ldap_dirty(self):
...@@ -374,6 +386,7 @@ class Group(LDAPModel): ...@@ -374,6 +386,7 @@ class Group(LDAPModel):
gid = LDAPAttribute('gidNumber') gid = LDAPAttribute('gidNumber')
name = LDAPAttribute('cn') name = LDAPAttribute('cn')
description = LDAPAttribute('description', default='') description = LDAPAttribute('description', default='')
member_dns= LDAPAttribute('uniqueMember', multi=True)
members = LDAPRelation('uniqueMember', User, backref='groups') members = LDAPRelation('uniqueMember', User, backref='groups')
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment