diff --git a/ldapserver/dn.py b/ldapserver/dn.py index 0ede9a1db5ad4858eb6e840749b2bdbf002a7644..c6f6bb869d699dda23d8d21855c4e22284e63c9a 100644 --- a/ldapserver/dn.py +++ b/ldapserver/dn.py @@ -1,129 +1,44 @@ -from string import hexdigits as HEXDIGITS +from string import hexdigits as HEXDIGITS, ascii_letters as ASCII_LETTERS, digits as DIGITS from .util import encode_attribute DN_ESCAPED = ('"', '+', ',', ';', '<', '>') DN_SPECIAL = DN_ESCAPED + (' ', '#', '=') -def parse_assertion(expr, case_ignore_attrs=None): - case_ignore_attrs = case_ignore_attrs or [] - hexdigit = None - escaped = False - tokens = [] - token = b'' - for char in expr: - if hexdigit is not None: - if char not in HEXDIGITS: - raise ValueError('Invalid hexpair: \\%s%s'%(hexdigit, char)) - token += bytes.fromhex('%s%s'%(hexdigit, char)) - hexdigit = None - elif escaped: - escaped = False - if char in DN_SPECIAL or char == '\\': - token += char.encode() - elif char in HEXDIGITS: - hexdigit = char - else: - raise ValueError('Invalid escape: \\%s'%char) - elif char == '\\': - escaped = True - elif char == '=': - tokens.append(token) - token = b'' - else: - token += char.encode() - tokens.append(token) - if len(tokens) != 2: - raise ValueError('Invalid assertion in RDN: "%s"'%expr) - name = tokens[0].decode().lower() - value = tokens[1] - if not name or not value: - raise ValueError('Invalid assertion in RDN: "%s"'%expr) - if name in case_ignore_attrs: - value = value.lower() - return (name, value) - -def parse_rdn(rdn, case_ignore_attrs=None): - escaped = False - assertions = [] - token = '' - for char in rdn: - if escaped: - escaped = False - token += char - elif char == '+': - assertions.append(parse_assertion(token, case_ignore_attrs=case_ignore_attrs)) - token = '' - else: - if char == '\\': - escaped = True - token += char - assertions.append(parse_assertion(token, case_ignore_attrs=case_ignore_attrs)) - if not assertions: - raise ValueError('Invalid RDN "%s"'%rdn) - return tuple(sorted(assertions)) - -def parse_dn(dn, case_ignore_attrs=None): - if not dn: - return tuple() - escaped = False - rdns = [] - rdn = '' - for char in dn: - if escaped: - escaped = False - rdn += char - elif char == ',': - rdns.append(parse_rdn(rdn, case_ignore_attrs=case_ignore_attrs)) - rdn = '' - else: - if char == '\\': - escaped = True - rdn += char - rdns.append(parse_rdn(rdn, case_ignore_attrs=case_ignore_attrs)) - return tuple(rdns) - -def escape_dn_value(value): - if isinstance(value, int): - value = str(value) - if isinstance(value, str): - value = value.encode() - res = '' - for byte in value: - byte = bytes((byte,)) - try: - chars = byte.decode() - except UnicodeDecodeError: - chars = '\\'+byte.hex() - if chars in DN_SPECIAL: - chars = '\\'+chars - res += chars - return res - -def build_assertion(assertion): - name, value = assertion - return '%s=%s'%(escape_dn_value(name.encode()), escape_dn_value(value)) - -def build_rdn(assertions): - return '+'.join(map(build_assertion, assertions)) - -def build_dn(rdns): - return ','.join(map(build_rdn, rdns)) - class DN(tuple): def __new__(cls, *args): - if len(args) == 1 and isinstance(args[0], str): - return super().__new__(cls, [RDN(*rdn) for rdn in parse_dn(args[0])]) - elif len(args) == 1 and isinstance(args[0], DN): + if len(args) == 1 and isinstance(args[0], DN): return args[0] - else: - return super().__new__(cls, [RDN(rdn) for rdn in args]) + if len(args) == 1 and isinstance(args[0], str): + return cls.from_str(args[0]) + return super().__new__(cls, [RDN(rdn) for rdn in args]) def __repr__(self): return '<DN(%s)>'%repr(str(self)) + @classmethod + def from_str(cls, dn, case_ignore_attrs=None): + if not dn: + return tuple() + escaped = False + rdns = [] + rdn = '' + for char in dn: + if escaped: + escaped = False + rdn += char + elif char == ',': + rdns.append(RDN.from_str(rdn, case_ignore_attrs=case_ignore_attrs)) + rdn = '' + else: + if char == '\\': + escaped = True + rdn += char + rdns.append(RDN.from_str(rdn, case_ignore_attrs=case_ignore_attrs)) + return cls(*rdns) + def __str__(self): - return build_dn(self) + return ','.join(map(str, self)) def __bytes__(self): return str(self).encode() @@ -131,6 +46,8 @@ class DN(tuple): def __add__(self, value): if isinstance(value, DN): return DN(*(tuple(self) + tuple(value))) + elif isinstance(value, RDN): + return self + DN(value) else: raise ValueError() @@ -156,46 +73,130 @@ class DN(tuple): return not rbase class RDN(tuple): - def __new__(cls, *args, **kwargs): - if not kwargs and len(args) == 1 and isinstance(args[0], str): - return super().__new__(cls, [RDNAssertion(ava) for ava in parse_rdn(args[0])]) - elif not kwargs and len(args) == 1 and isinstance(args[0], cls): + '''Group of one or more `RDNAssertion` objects''' + def __new__(cls, *args): + if len(args) == 1 and isinstance(args[0], cls): return args[0] + if len(args) == 1 and isinstance(args[0], str): + return cls.from_str(args[0]) + if len(args) == 1 and isinstance(args[0], dict): + args = args[0].items() + assertions = [] + for value in args: + assertion = RDNAssertion(value) + if assertion not in assertions: + assertions.append(assertion) + assertions.sort() + if not assertions: + raise ValueError('Invalid RDN "%s"'%repr(args)) + return super().__new__(cls, assertions) + + def __add__(self, value): + if isinstance(value, RDN): + return DN(self, value) + elif isinstance(value, DN): + return DN(self) + value else: - assertions = [] - for key, value in args + tuple(kwargs.items()): - assertion = RDNAssertion(key, value) - if assertion not in assertions: - assertions.append(assertion) - assertions.sort() - return super().__new__(cls, assertions) + raise ValueError() def __repr__(self): return '<RDN(%s)>'%repr(str(self)) + @classmethod + def from_str(cls, rdn, case_ignore_attrs=None): + escaped = False + assertions = [] + token = '' + for char in rdn: + if escaped: + escaped = False + token += char + elif char == '+': + assertions.append(RDNAssertion.from_str(token, case_ignore_attrs=case_ignore_attrs)) + token = '' + else: + if char == '\\': + escaped = True + token += char + assertions.append(RDNAssertion.from_str(token, case_ignore_attrs=case_ignore_attrs)) + return cls(*assertions) + def __str__(self): - return build_rdn(self) + return '+'.join(map(str, self)) class RDNAssertion(tuple): + '''A single assertion (attribute=value)''' def __new__(cls, *args): - if len(args) == 1 and isinstance(args[0], RDNAssertion): - return args[0] - elif len(args) == 1 and isinstance(args[0], str): - return super().__new__(cls, parse_assertion(args[0])) - else: - key, value = args - value = encode_attribute(value) - if not isinstance(key, str): - raise ValueError('Key in RDN assertion "%s=%s" has invalid type'%(repr(key), repr(value))) - if not isinstance(value, bytes): - raise ValueError('Value in RDN assertion "%s=%s" has invalid type'%(key, repr(value))) - return super().__new__(cls, (key.lower(), value)) + if len(args) == 1: + args = args[0] + if isinstance(args, RDNAssertion): + return args + if isinstance(args, str): + return cls.from_str(args) + attribute, value = args + if not isinstance(attribute, str): + raise TypeError('Attribute name in RDN assertion %s=%s must be str not %s'%(repr(attribute), repr(value), repr(type(attribute)))) + for index, char in enumerate(attribute): + if char not in ASCII_LETTERS+DIGITS+'-': + raise ValueError('Invalid character in attribute name %s at position %d'%(repr(attribute), index+1)) + attribute = attribute.lower() + value = encode_attribute(value) + return super().__new__(cls, (attribute, value)) def __repr__(self): return '<RDNAssertion(%s)>'%repr(str(self)) + @classmethod + def from_str(cls, expr, case_ignore_attrs=None): + case_ignore_attrs = case_ignore_attrs or [] + hexdigit = None + escaped = False + tokens = [] + token = b'' + for char in expr: + if hexdigit is not None: + if char not in HEXDIGITS: + raise ValueError('Invalid hexpair: \\%s%s'%(hexdigit, char)) + token += bytes.fromhex('%s%s'%(hexdigit, char)) + hexdigit = None + elif escaped: + escaped = False + if char in DN_SPECIAL or char == '\\': + token += char.encode() + elif char in HEXDIGITS: + hexdigit = char + else: + raise ValueError('Invalid escape: \\%s'%char) + elif char == '\\': + escaped = True + elif char == '=': + tokens.append(token) + token = b'' + else: + token += char.encode() + tokens.append(token) + if len(tokens) != 2: + raise ValueError('Invalid assertion in RDN: "%s"'%expr) + name = tokens[0].decode().lower() + value = tokens[1] + if not name or not value: + raise ValueError('Invalid assertion in RDN: "%s"'%expr) + if name in case_ignore_attrs: + value = value.lower() + return cls(name, value) + def __str__(self): - return build_assertion(self) + valuestr = '' + for byte in self.value: + byte = bytes((byte,)) + try: + chars = byte.decode() + except UnicodeDecodeError: + chars = '\\'+byte.hex() + if chars in DN_SPECIAL: + chars = '\\'+chars + valuestr += chars + return '%s=%s'%(self.attribute, valuestr) @property def attribute(self):