Skip to content
Snippets Groups Projects
Commit 97804bb9 authored by Julian's avatar Julian
Browse files

Implemented (incomplete) relation mechanism

parent c82ed372
Branches
Tags
No related merge requests found
...@@ -242,7 +242,6 @@ ...@@ -242,7 +242,6 @@
"cn=Subschema" "cn=Subschema"
], ],
"uniqueMember": [ "uniqueMember": [
"cn=dummy,ou=system,dc=example,dc=com",
"uid=testuser,ou=users,dc=example,dc=com", "uid=testuser,ou=users,dc=example,dc=com",
"uid=testadmin,ou=users,dc=example,dc=com" "uid=testadmin,ou=users,dc=example,dc=com"
] ]
......
from copy import deepcopy from copy import deepcopy
from collections.abc import MutableSet
from flask import current_app, request from flask import current_app, request
...@@ -84,44 +85,32 @@ class FlaskLDAPMapper: ...@@ -84,44 +85,32 @@ class FlaskLDAPMapper:
ldap = FlaskLDAPMapper() ldap = FlaskLDAPMapper()
class LDAPList: class LDAPSet(MutableSet):
def __init__(self, obj, name, encode=None, decode=None): def __init__(self, getitems, additem, delitem, encode=None, decode=None):
self.__obj = obj self.__getitems = getitems
self.__name = name self.__additem = additem
self.__delitem = delitem
self.__encode = encode or (lambda x: x) self.__encode = encode or (lambda x: x)
self.__decode = decode or (lambda x: x) self.__decode = decode or (lambda x: x)
@property def __repr__(self):
def __list(self): return repr(set(self))
return self.__obj.ldap_getattr(self.__name)
def __contains__(self, value): def __contains__(self, value):
return self.__encode(value) in self.__list return self.__encode(value) in self.__getitems()
def __iadd__(self, values):
self.extend(values)
def __iter__(self): def __iter__(self):
return iter(map(self.__decode, self.__list)) return iter(map(self.__decode, self.__getitems()))
def __len__(self): def __len__(self):
return len(self.__list) return len(self.__getitems())
def append(self, value):
self.__obj.ldap_attradd(self.__name, self.__encode(value))
def clear(self): def add(self, value):
self.__obj.ldap_setattr(self.__name, []) if value not in self:
self.__additem(self.__encode(value))
def count(self, value): def discard(self, value):
return __list.count(self.__encode(value)) self.__delitem(self.__encode(value))
def extend(self, values):
for value in values:
self.__obj.ldap_attradd(self.__name, self.__encode(value))
def remove(self, value):
self.__obj.ldap_attrdel(self.__name, self.__encode(value))
class LDAPAttribute: class LDAPAttribute:
def __init__(self, name, multi=False, default=None, encode=None, decode=None): def __init__(self, name, multi=False, default=None, encode=None, decode=None):
...@@ -136,27 +125,86 @@ class LDAPAttribute: ...@@ -136,27 +125,86 @@ class LDAPAttribute:
return [self.encode(value) for value in values] return [self.encode(value) for value in values]
self.default = default_wrapper self.default = default_wrapper
def __set_name__(self, obj, name): def __set_name__(self, cls, name):
if self.default is None: if self.default is None:
return return
if not obj.ldap_defaults: if not cls.ldap_defaults:
obj.ldap_defaults = {} cls.ldap_defaults = {}
obj.ldap_defaults[self.name] = self.default cls.ldap_defaults[self.name] = self.default
def __get__(self, obj, objtype=None): def __get__(self, obj, objtype=None):
if obj is None: if obj is None:
return self return self
values = obj.ldap_getattr(self.name)
if self.multi: if self.multi:
return LDAPList(obj, self.name, encode=self.encode, decode=self.decode) return LDAPSet(getitems=lambda: obj.ldap_getattr(self.name),
return self.decode((values or [None])[0]) additem=lambda value: obj.ldap_attradd(self.name, value),
delitem=lambda value: obj.ldap_attrdel(self.name, value),
encode=self.encode, decode=self.decode)
return self.decode((obj.ldap_getattr(self.name) or [None])[0])
def __set__(self, obj, values): def __set__(self, obj, values):
if not self.multi: if not self.multi:
values = [values] values = [values]
obj.ldap_setattr(self.name, [self.encode(value) for value in values]) obj.ldap_setattr(self.name, [self.encode(value) for value in values])
class LDAPRelationBackref:
def __init__(self, src, srcattr):
self.src = src
self.srcattr = srcattr
self.key = (self.src, self.srcattr)
def fetch(self, obj):
if self.key not in obj._relation_data:
kwargs = {getattr(self.src, self.srcattr).name: obj.dn}
values = self.src.ldap_filter_by(**kwargs)
obj._relation_data[self.key] = set(values)
return obj._relation_data[self.key]
def add(self, obj, value):
self.fetch(obj)
obj._relation_data[self.key].add(value)
def discard(self, obj, value):
self.fetch(obj)
obj._relation_data[self.key].discard(value)
def __get__(self, obj, objtype=None):
if obj is None:
return self
return LDAPSet(getitems=lambda: self.fetch(obj),
additem=lambda value: getattr(value, self.srcattr).add(obj),
delitem=lambda value: getattr(value, self.srcattr).discard(obj))
class LDAPRelation: class LDAPRelation:
def __init__(self, name, dest, backref=None):
self.name = name
self.dest = dest
self.backref = backref
def __set_name__(self, cls, name):
if self.backref is not None:
setattr(self.dest, self.backref, LDAPRelationBackref(cls, name))
def add(self, obj, destdn):
destobj = self.dest.ldap_get(destdn) # Always from cache!
obj.ldap_attradd(self.name, destdn)
if self.backref is not None:
getattr(self.dest, self.backref).add(destobj, obj)
def discard(self, obj, destdn):
destobj = self.dest.ldap_get(destdn) # Always from cache!
obj.ldap_attrdel(self.name, destdn)
if self.backref is not None:
getattr(self.dest, self.backref).discard(destobj, obj)
def __get__(self, obj, objtype=None):
if obj is None:
return self
return LDAPSet(getitems=lambda: obj.ldap_getattr(self.name),
additem=lambda value: self.add(obj, value),
delitem=lambda value: self.discard(obj, value),
encode=lambda value: value.dn,
decode=lambda value: self.dest.ldap_get(value))
class LDAPModel: class LDAPModel:
ldap_dn_attribute = None ldap_dn_attribute = None
...@@ -167,6 +215,7 @@ class LDAPModel: ...@@ -167,6 +215,7 @@ class LDAPModel:
ldap_defaults = None # Populated by LDAPAttribute ldap_defaults = None # Populated by LDAPAttribute
def __init__(self, _ldap_dn=None, _ldap_attributes=None, **kwargs): def __init__(self, _ldap_dn=None, _ldap_attributes=None, **kwargs):
self._relation_data = {}
self.__ldap_dn = _ldap_dn self.__ldap_dn = _ldap_dn
self.__ldap_attributes = {} self.__ldap_attributes = {}
for key, values in (_ldap_attributes or {}).items(): for key, values in (_ldap_attributes or {}).items():
...@@ -197,6 +246,12 @@ class LDAPModel: ...@@ -197,6 +246,12 @@ class LDAPModel:
if item in self.__attributes.get(name, []): if item in self.__attributes.get(name, []):
self.__attributes[name].remove(item) self.__attributes[name].remove(item)
def __repr__(self):
name = '%s.%s'%(type(self).__module__, type(self).__name__)
if self.__ldap_dn is None:
return '<%s>'%name
return '<%s %s>'%(name, self.__ldap_dn)
def build_dn(self): def build_dn(self):
return '%s=%s,%s'%(self.ldap_dn_attribute, self.__attributes[self.ldap_dn_attribute][0], self.ldap_dn_base) return '%s=%s,%s'%(self.ldap_dn_attribute, self.__attributes[self.ldap_dn_attribute][0], self.ldap_dn_base)
...@@ -237,7 +292,7 @@ class LDAPModel: ...@@ -237,7 +292,7 @@ class LDAPModel:
for key, value in kwargs.items(): for key, value in kwargs.items():
filters.append('(%s=%s)'%(key, escape_filter_chars(value))) filters.append('(%s=%s)'%(key, escape_filter_chars(value)))
conn = get_conn() conn = get_conn()
conn.search(cls.ldap_dn_base, '(&%s)'%(''.join(filters))) conn.search(cls.ldap_base, '(&%s)'%(''.join(filters)))
res = [] res = []
for entry in conn.response: for entry in conn.response:
obj = ldap.session.lookup(entry['dn']) obj = ldap.session.lookup(entry['dn'])
...@@ -307,22 +362,21 @@ class User(LDAPModel): ...@@ -307,22 +362,21 @@ class User(LDAPModel):
displayname = LDAPAttribute('cn') displayname = LDAPAttribute('cn')
mail = LDAPAttribute('mail') mail = LDAPAttribute('mail')
pwhash = LDAPAttribute('userPassword', default=lambda: hashed(HASHED_SALTED_SHA512, secrets.token_hex(128))) pwhash = LDAPAttribute('userPassword', default=lambda: hashed(HASHED_SALTED_SHA512, secrets.token_hex(128)))
groups = LDAPAttribute('memberOf', multi=True, default=[], encode=lambda obj: obj.dn, decode=lambda dn: Group.ldap_get(dn))
def password(self, value): def password(self, value):
self.pwhash = hashed(HASHED_SALTED_SHA512, value) self.pwhash = hashed(HASHED_SALTED_SHA512, value)
password = property(fset=password) password = property(fset=password)
class Group(LDAPModel): class Group(LDAPModel):
ldap_base = 'ou=groups,dc=example,dc=com' ldap_base = 'ou=groups,dc=example,dc=com'
ldap_filter = '(objectClass=groupOfUniqueNames)' ldap_filter = '(objectClass=groupOfUniqueNames)'
gid = LDAPAttribute('gidNumber') gid = LDAPAttribute('gidNumber')
name = LDAPAttribute('cn') name = LDAPAttribute('cn')
members = LDAPAttribute('uniqueMember', multi=True, default=[], encode=lambda obj: obj.dn, decode=lambda dn: User.ldap_get(dn))
description = LDAPAttribute('description', default='') description = LDAPAttribute('description', default='')
members = LDAPRelation('uniqueMember', User, backref='groups')
class Mail(LDAPModel): class Mail(LDAPModel):
ldap_base = 'ou=postfix,dc=example,dc=com' ldap_base = 'ou=postfix,dc=example,dc=com'
ldap_dn_attribute = 'uid' ldap_dn_attribute = 'uid'
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment