diff --git a/tests/utils.py b/tests/utils.py index d66ce40408fca465b3e93b08b3e64e3ea726c32f..c1456ef9dacdc9df2fb82aaec741375f93ad601b 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -40,7 +40,7 @@ class UffdTestCase(unittest.TestCase): os.system("ldapdelete -c -D 'cn=uffd,ou=system,dc=example,dc=com' -w 'uffd-ldap-password' -H 'ldap://localhost' -f ldap_server_entries_cleanup.ldif > /dev/null 2>&1") os.system("ldapadd -c -D 'cn=uffd,ou=system,dc=example,dc=com' -w 'uffd-ldap-password' -H 'ldap://localhost' -f ldap_server_entries_add.ldif") os.system("ldapmodify -c -D 'cn=uffd,ou=system,dc=example,dc=com' -w 'uffd-ldap-password' -H 'ldap://localhost' -f ldap_server_entries_modify.ldif") - os.system("/usr/sbin/slapcat -n 1 -l /dev/stdout") + #os.system("/usr/sbin/slapcat -n 1 -l /dev/stdout") self.app = create_app(config) self.setUpApp() self.client = self.app.test_client() diff --git a/uffd/ldap.py b/uffd/ldap.py index 4ae9a330da5a18035d9a627fecece0eb97cc9dd5..d4bc071e968d1355746b1294390bc9f01388c2cc 100644 --- a/uffd/ldap.py +++ b/uffd/ldap.py @@ -149,24 +149,27 @@ class LDAPSet(MutableSet): self.__delitem(self.__encode(value)) class LDAPAttribute: - def __init__(self, name, multi=False, default=None, encode=None, decode=None): + def __init__(self, name, multi=False, default=None, encode=None, decode=None, aliases=None): self.name = name self.multi = multi self.encode = encode or (lambda x: x) self.decode = decode or (lambda x: x) - def default_wrapper(): - values = default() if callable(default) else default - if not isinstance(values, list): - values = [values] - return [self.encode(value) for value in values] - self.default = default_wrapper + self.default_values = default + self.aliases = aliases or [] + + def default(self, obj): + if obj.ldap_getattr(self.name) == []: + values = self.default_values + if callable(values): + values = values() + self.__set__(obj, values) + for name in self.aliases: + obj.ldap_setattr(name, obj.ldap_getattr(self.name)) def __set_name__(self, cls, name): if self.default is None: return - if not cls.ldap_defaults: - cls.ldap_defaults = {} - cls.ldap_defaults[self.name] = self.default + cls.ldap_defaults = cls.ldap_defaults + [self.default] def __get__(self, obj, objtype=None): if obj is None: @@ -187,9 +190,7 @@ class LDAPBackref: def __init__(self, srccls, srcattr): self.srccls = srccls self.srcattr = srcattr - if srccls.ldap_relations is None: - srccls.ldap_relations = set() - srccls.ldap_relations.add(srcattr) + srccls.ldap_relations = srccls.ldap_relations + [srcattr] def init(self, obj): if self.srcattr not in obj.ldap_relation_data and obj.ldap_created: @@ -229,8 +230,9 @@ class LDAPModel: ldap_base = None ldap_object_classes = None ldap_filter = None - ldap_defaults = None # Populated by LDAPAttribute - ldap_relations = None # Populated by LDAPBackref + # Caution: Never mutate ldap_defaults and ldap_relations, always reassign! + ldap_defaults = [] + ldap_relations = [] def __init__(self, _ldap_dn=None, _ldap_attributes=None, **kwargs): self.ldap_relation_data = set() @@ -247,11 +249,11 @@ class LDAPModel: if not hasattr(self, key): raise Exception() setattr(self, key, value) - for name in (self.ldap_relations or []): + for name in self.ldap_relations: self.__update_relations(name, add_dns=self.__attributes.get(name, [])) def __update_relations(self, name, delete_dns=None, add_dns=None): - if name in (self.ldap_relations or []): + if name in self.ldap_relations: ldap.session.update_relations(self, name, delete_dns, add_dns) def ldap_getattr(self, name): @@ -344,11 +346,11 @@ class LDAPModel: return cls.ldap_filter_by_raw(**_kwargs) def ldap_reset(self): - for name in (self.ldap_relations or []): + for name in self.ldap_relations: self.__update_relations(name, delete_dns=self.__attributes.get(name, [])) self.__changes = {} self.__attributes = deepcopy(self.__ldap_attributes) - for name in (self.ldap_relations or {}): + for name in self.ldap_relations: self.__update_relations(name, add_dns=self.__attributes.get(name, [])) @property @@ -375,13 +377,11 @@ class LDAPModel: if self.ldap_created: raise Exception() conn = get_conn() - for key, func in (self.ldap_defaults or {}).items(): - if key not in self.__attributes: - values = func() - self.__attributes[key] = values - self.__changes[key] = [(MODIFY_REPLACE, values)] + for func in self.ldap_defaults: + func(self) success = conn.add(self.dn, self.ldap_object_classes, self.__attributes) if not success: + print('commit error', success, conn.result) raise LDAPCommitError() self.__changes = {} self.__ldap_attributes = deepcopy(self.__attributes) diff --git a/uffd/user/models.py b/uffd/user/models.py index 96cc8c65667b81c690c2a199183e6aed1b9982d6..a56e687ab9e46a81081ef1f7869b613ba878653e 100644 --- a/uffd/user/models.py +++ b/uffd/user/models.py @@ -25,10 +25,20 @@ class User(LDAPModel): uid = LDAPAttribute('uidNumber', default=get_next_uid) loginname = LDAPAttribute('uid') - displayname = LDAPAttribute('cn') + displayname = LDAPAttribute('cn', aliases=['givenName', 'displayName']) mail = LDAPAttribute('mail') pwhash = LDAPAttribute('userPassword', default=lambda: hashed(HASHED_SALTED_SHA512, secrets.token_hex(128))) + def dummy_attribute_defaults(self): + if self.ldap_getattr('sn') == []: + self.ldap_setattr('sn', [' ']) + if self.ldap_getattr('homeDirectory') == []: + self.ldap_setattr('homeDirectory', ['/home/%s'%self.loginname]) + if self.ldap_getattr('gidNumber') == []: + self.ldap_setattr('gidNumber', [current_app.config['LDAP_USER_GID']]) + + ldap_defaults = LDAPModel.ldap_defaults + [dummy_attribute_defaults] + # Write-only property def password(self, value): self.pwhash = hashed(HASHED_SALTED_SHA512, value)