diff --git a/ldapserver/objects.py b/ldapserver/objects.py index dc87113d380615405e3e00fded0e77aa5adcd0d7..ae267bea67b2f790af47b2297ffbcbcd54434bef 100644 --- a/ldapserver/objects.py +++ b/ldapserver/objects.py @@ -51,6 +51,13 @@ def all_3value(iterable): return result class AttributeDict(dict): + '''Special dictionary holding LDAP attribute values + + Attribute values can be set and accessed by their attribute type's numeric + OID or short descriptive name. Attribute types must be defined within the + schema to be used. Attribute values are always lists. Accessing an unset + attribute behaves the same as accessing an empty attribute. List items must + conform to the attribute's syntax.''' def __init__(self, schema, **attributes): super().__init__() self.schema = schema @@ -58,38 +65,38 @@ class AttributeDict(dict): self[key] = value def __contains__(self, key): - return super().__contains__(self.schema.lookup_attribute(key)) + try: + return super().__contains__(self.schema.get_attribute_type(key).name) + except KeyError: + return False - def __setitem__(self, key, value): - super().__setitem__(self.schema.lookup_attribute(key, fail_if_not_found=True), value) + def __setitem__(self, key, values): + super().__setitem__(self.schema.get_attribute_type(key).name, values) def __getitem__(self, key): - key = self.schema.lookup_attribute(key, fail_if_not_found=True) - if key not in self: - super().__setitem__(key, []) - result = super().__getitem__(key) + canonical_oid = self.schema.get_attribute_type(key).name + if not super().__contains__(canonical_oid): + super().__setitem__(canonical_oid, []) + result = super().__getitem__(canonical_oid) if callable(result): return result() return result def setdefault(self, key, default=None): - key = self.schema.lookup_attribute(key, fail_if_not_found=True) - return super().setdefault(key, default) + canonical_oid = self.schema.get_attribute_type(key).name + return super().setdefault(canonical_oid, default) def get(self, key, default=None): - key = self.schema.lookup_attribute(key, fail_if_not_found=True) - if key in self: - return self[key] - return default + return self[key] or default - def get_all(self, key): + def get_with_subtypes(self, key): result = [] - for attr in self.schema.lookup_attribute_list(key): - result += self[attr] + for attribute_type in self.schema.get_attribute_type_and_subtypes(key): + result += self[attribute_type.name] return result def match_present(self, key): - attribute_type = self.schema.lookup_attribute(key) + attribute_type = self.schema.get_attribute_type(key) if attribute_type is None: return FilterResult.UNDEFINED if self[attribute_type] != []: @@ -98,57 +105,72 @@ class AttributeDict(dict): return FilterResult.FALSE def match_equal(self, key, assertion_value): - attribute_type = self.schema.lookup_attribute(key) - if attribute_type is None or attribute_type.equality is None: + try: + attribute_type = self.schema.get_attribute_type(key) + except KeyError: return FilterResult.UNDEFINED - assertion_value = attribute_type.equality.syntax.decode(assertion_value) + if attribute_type.equality is None: + return FilterResult.UNDEFINED + assertion_value = attribute_type.equality.syntax.decode(self.schema, assertion_value) if assertion_value is None: return FilterResult.UNDEFINED - return any_3value(map(lambda attrval: match_to_filter_result(attribute_type.equality.match_equal(attrval, assertion_value)), self.get_all(key))) + return any_3value(map(lambda attrval: match_to_filter_result(attribute_type.equality.match_equal(self.schema, attrval, assertion_value)), self.get_with_subtypes(key))) def match_substr(self, key, inital_substring, any_substrings, final_substring): - attribute_type = self.schema.lookup_attribute(key) - if attribute_type is None or attribute_type.equality is None or attribute_type.substr is None: + try: + attribute_type = self.schema.get_attribute_type(key) + except KeyError: + return FilterResult.UNDEFINED + if attribute_type.equality is None or attribute_type.substr is None: return FilterResult.UNDEFINED if inital_substring: - inital_substring = attribute_type.equality.syntax.decode(inital_substring) + inital_substring = attribute_type.equality.syntax.decode(self.schema, inital_substring) if inital_substring is None: return FilterResult.UNDEFINED - any_substrings = [attribute_type.equality.syntax.decode(substring) for substring in any_substrings] + any_substrings = [attribute_type.equality.syntax.decode(self.schema, substring) for substring in any_substrings] if None in any_substrings: return FilterResult.UNDEFINED if final_substring: - final_substring = attribute_type.equality.syntax.decode(final_substring) + final_substring = attribute_type.equality.syntax.decode(self.schema, final_substring) if final_substring is None: return FilterResult.UNDEFINED - return any_3value(map(lambda attrval: match_to_filter_result(attribute_type.substr.match_substr(attrval, inital_substring, any_substrings, final_substring)), self.get_all(key))) + return any_3value(map(lambda attrval: match_to_filter_result(attribute_type.substr.match_substr(self.schema, attrval, inital_substring, any_substrings, final_substring)), self.get_with_subtypes(key))) def match_approx(self, key, assertion_value): - attribute_type = self.schema.lookup_attribute(key) - if attribute_type is None or attribute_type.equality is None: + try: + attribute_type = self.schema.get_attribute_type(key) + except KeyError: return FilterResult.UNDEFINED - assertion_value = attribute_type.equality.syntax.decode(assertion_value) + if attribute_type.equality is None: + return FilterResult.UNDEFINED + assertion_value = attribute_type.equality.syntax.decode(self.schema, assertion_value) if assertion_value is None: return FilterResult.UNDEFINED - return any_3value(map(lambda attrval: match_to_filter_result(attribute_type.equality.match_approx(attrval, assertion_value)), self.get_all(key))) + return any_3value(map(lambda attrval: match_to_filter_result(attribute_type.equality.match_approx(self.schema, attrval, assertion_value)), self.get_with_subtypes(key))) def match_greater_or_equal(self, key, assertion_value): - attribute_type = self.schema.lookup_attribute(key) - if attribute_type is None or attribute_type.ordering is None: + try: + attribute_type = self.schema.get_attribute_type(key) + except KeyError: + return FilterResult.UNDEFINED + if attribute_type.ordering is None: return FilterResult.UNDEFINED - assertion_value = attribute_type.ordering.syntax.decode(assertion_value) + assertion_value = attribute_type.ordering.syntax.decode(self.schema, assertion_value) if assertion_value is None: return FilterResult.UNDEFINED - return any_3value(map(lambda attrval: match_to_filter_result(attribute_type.ordering.match_greater_or_equal(attrval, assertion_value)), self.get_all(key))) + return any_3value(map(lambda attrval: match_to_filter_result(attribute_type.ordering.match_greater_or_equal(self.schema, attrval, assertion_value)), self.get_with_subtypes(key))) def match_less(self, key, assertion_value): - attribute_type = self.schema.lookup_attribute(key) - if attribute_type is None or attribute_type.ordering is None: + try: + attribute_type = self.schema.get_attribute_type(key) + except KeyError: return FilterResult.UNDEFINED - assertion_value = attribute_type.ordering.syntax.decode(assertion_value) + if attribute_type.ordering is None: + return FilterResult.UNDEFINED + assertion_value = attribute_type.ordering.syntax.decode(self.schema, assertion_value) if assertion_value is None: return FilterResult.UNDEFINED - return any_3value(map(lambda attrval: match_to_filter_result(attribute_type.ordering.match_less(attrval, assertion_value)), self.get_all(key))) + return any_3value(map(lambda attrval: match_to_filter_result(attribute_type.ordering.match_less(self.schema, attrval, assertion_value)), self.get_with_subtypes(key))) def match_less_or_equal(self, key, assertion_value): return any_3value((self.match_equal(key, assertion_value), @@ -205,22 +227,21 @@ class Object(AttributeDict): selected_attributes = set() for selector in attributes or ['*']: if selector == '*': - selected_attributes |= self.schema.user_attribute_types + selected_attributes |= set(self.schema.user_attribute_types) elif selector == '1.1': continue else: - attribute = self.schema.lookup_attribute(selector) - if attribute is not None: - selected_attributes.add(attribute) + try: + selected_attributes.add(self.schema.get_attribute_type(selector)) + except KeyError: + pass partial_attributes = [] - for attribute in self: - if attribute not in selected_attributes: - continue - values = self[attribute] + for attribute_type in selected_attributes: + values = self[attribute_type.name] if values != []: if types_only: values = [] - partial_attributes.append(ldap.PartialAttribute(attribute.name, [attribute.syntax.encode(value) for value in values])) + partial_attributes.append(ldap.PartialAttribute(attribute_type.name, [attribute_type.syntax.encode(self.schema, value) for value in values])) return ldap.SearchResultEntry(str(self.dn), partial_attributes) class RootDSE(Object): @@ -244,8 +265,9 @@ class ObjectTemplate(AttributeDict): self.rdn_attribute = rdn_attribute def match_present(self, key): - attribute_type = self.schema.lookup_attribute(key) - if attribute_type is None: + try: + attribute_type = self.schema.get_attribute_type(key) + except KeyError: return FilterResult.UNDEFINED values = self[attribute_type] if values == []: @@ -256,72 +278,87 @@ class ObjectTemplate(AttributeDict): return FilterResult.TRUE def match_equal(self, key, assertion_value): - attribute_type = self.schema.lookup_attribute(key) - if attribute_type is None or attribute_type.equality is None: + try: + attribute_type = self.schema.get_attribute_type(key) + except KeyError: + return FilterResult.UNDEFINED + if attribute_type.equality is None: return FilterResult.UNDEFINED - assertion_value = attribute_type.equality.syntax.decode(assertion_value) + assertion_value = attribute_type.equality.syntax.decode(self.schema, assertion_value) if assertion_value is None: return FilterResult.UNDEFINED - values = self.get_all(key) + values = self.get_with_subtypes(key) if WILDCARD_VALUE in values: return FilterResult.MAYBE_TRUE - return any_3value(map(lambda attrval: match_to_filter_result(attribute_type.equality.match_equal(attrval, assertion_value)), values)) + return any_3value(map(lambda attrval: match_to_filter_result(attribute_type.equality.match_equal(self.schema, attrval, assertion_value)), values)) def match_substr(self, key, inital_substring, any_substrings, final_substring): - attribute_type = self.schema.lookup_attribute(key) - if attribute_type is None or attribute_type.equality is None or attribute_type.substr is None: + try: + attribute_type = self.schema.get_attribute_type(key) + except KeyError: + return FilterResult.UNDEFINED + if attribute_type.equality is None or attribute_type.substr is None: return FilterResult.UNDEFINED if inital_substring: - inital_substring = attribute_type.equality.syntax.decode(inital_substring) + inital_substring = attribute_type.equality.syntax.decode(self.schema, inital_substring) if inital_substring is None: return FilterResult.UNDEFINED - any_substrings = [attribute_type.equality.syntax.decode(substring) for substring in any_substrings] + any_substrings = [attribute_type.equality.syntax.decode(self.schema, substring) for substring in any_substrings] if None in any_substrings: return FilterResult.UNDEFINED if final_substring: - final_substring = attribute_type.equality.syntax.decode(final_substring) + final_substring = attribute_type.equality.syntax.decode(self.schema, final_substring) if final_substring is None: return FilterResult.UNDEFINED - values = self.get_all(key) + values = self.get_with_subtypes(key) if WILDCARD_VALUE in values: return FilterResult.MAYBE_TRUE - return any_3value(map(lambda attrval: match_to_filter_result(attribute_type.substr.match_substr(attrval, inital_substring, any_substrings, final_substring)), values)) + return any_3value(map(lambda attrval: match_to_filter_result(attribute_type.substr.match_substr(self.schema, attrval, inital_substring, any_substrings, final_substring)), values)) def match_approx(self, key, assertion_value): - attribute_type = self.schema.lookup_attribute(key) - if attribute_type is None or attribute_type.equality is None: + try: + attribute_type = self.schema.get_attribute_type(key) + except KeyError: + return FilterResult.UNDEFINED + if attribute_type.equality is None: return FilterResult.UNDEFINED - assertion_value = attribute_type.equality.syntax.decode(assertion_value) + assertion_value = attribute_type.equality.syntax.decode(self.schema, assertion_value) if assertion_value is None: return FilterResult.UNDEFINED - values = self.get_all(key) + values = self.get_with_subtypes(key) if WILDCARD_VALUE in values: return FilterResult.MAYBE_TRUE - return any_3value(map(lambda attrval: match_to_filter_result(attribute_type.equality.match_approx(attrval, assertion_value)), values)) + return any_3value(map(lambda attrval: match_to_filter_result(attribute_type.equality.match_approx(self.schema, attrval, assertion_value)), values)) def match_greater_or_equal(self, key, assertion_value): - attribute_type = self.schema.lookup_attribute(key) - if attribute_type is None or attribute_type.ordering is None: + try: + attribute_type = self.schema.get_attribute_type(key) + except KeyError: return FilterResult.UNDEFINED - assertion_value = attribute_type.ordering.syntax.decode(assertion_value) + if attribute_type.ordering is None: + return FilterResult.UNDEFINED + assertion_value = attribute_type.ordering.syntax.decode(self.schema, assertion_value) if assertion_value is None: return FilterResult.UNDEFINED - values = self.get_all(key) + values = self.get_with_subtypes(key) if WILDCARD_VALUE in values: return FilterResult.MAYBE_TRUE - return any_3value(map(lambda attrval: match_to_filter_result(attribute_type.ordering.match_greater_or_equal(attrval, assertion_value)), values)) + return any_3value(map(lambda attrval: match_to_filter_result(attribute_type.ordering.match_greater_or_equal(self.schema, attrval, assertion_value)), values)) def match_less(self, key, assertion_value): - attribute_type = self.schema.lookup_attribute(key) - if attribute_type is None or attribute_type.ordering is None: + try: + attribute_type = self.schema.get_attribute_type(key) + except KeyError: + return FilterResult.UNDEFINED + if attribute_type.ordering is None: return FilterResult.UNDEFINED - assertion_value = attribute_type.ordering.syntax.decode(assertion_value) + assertion_value = attribute_type.ordering.syntax.decode(self.schema, assertion_value) if assertion_value is None: return FilterResult.UNDEFINED - values = self.get_all(key) + values = self.get_with_subtypes(key) if WILDCARD_VALUE in values: return FilterResult.MAYBE_TRUE - return any_3value(map(lambda attrval: match_to_filter_result(attribute_type.ordering.match_less(attrval, assertion_value)), values)) + return any_3value(map(lambda attrval: match_to_filter_result(attribute_type.ordering.match_less(self.schema, attrval, assertion_value)), values)) def __extract_dn_constraints(self, basedn, scope): if scope == ldap.SearchScope.baseObject: @@ -348,10 +385,13 @@ class ObjectTemplate(AttributeDict): def extract_filter_constraints(self, filter_obj): if isinstance(filter_obj, ldap.FilterEqual): - attribute_type = self.schema.lookup_attribute(filter_obj.attribute) - if attribute_type is None or attribute_type.equality is None: + try: + attribute_type = self.schema.get_attribute_type(filter_obj.attribute) + except KeyError: + return AttributeDict(self.schema) + if attribute_type.equality is None: return AttributeDict(self.schema) - assertion_value = attribute_type.equality.syntax.decode(filter_obj.value) + assertion_value = attribute_type.equality.syntax.decode(self.schema, filter_obj.value) if assertion_value is None: return AttributeDict(self.schema) return AttributeDict(self.schema, **{filter_obj.attribute: [assertion_value]}) @@ -391,10 +431,10 @@ class SubschemaSubentry(Object): self['subschemaSubentry'] = [self.dn] self['structuralObjectClass'] = ['subtree'] self['objectClass'] = ['top', 'subtree', 'subschema'] - self['objectClasses'] = schema.object_class_definitions - self['ldapSyntaxes'] = schema.syntax_definitions - self['matchingRules'] = schema.matching_rule_definitions - self['attributeTypes'] = schema.attribute_type_definitions + self['objectClasses'] = [item.to_definition() for item in schema.object_classes] + self['ldapSyntaxes'] = [item.to_definition() for item in schema.syntaxes] + self['matchingRules'] = [item.to_definition() for item in schema.matching_rules] + self['attributeTypes'] = [item.to_definition() for item in schema.attribute_types] # pylint: disable=invalid-name self.AttributeDict = lambda **attributes: AttributeDict(schema, **attributes) self.Object = lambda *args, **attributes: Object(schema, dn, subschemaSubentry=[dn], **attributes) diff --git a/ldapserver/schema/__init__.py b/ldapserver/schema/__init__.py index 11c676514af1561b5f402229dfc8244f14698ef4..0e4298dd7fa8beddcbdee95426b25d1408d14798 100644 --- a/ldapserver/schema/__init__.py +++ b/ldapserver/schema/__init__.py @@ -2,13 +2,13 @@ from .types import * from . import rfc4517, rfc4512, rfc4519, rfc4524, rfc3112, rfc2307bis, rfc2079, rfc2252, rfc2798, rfc4523, rfc1274 # Core LDAP Schema -RFC4519_SUBSCHEMA = Schema( rfc4519.object_classes.ALL, rfc4519.attribute_types.ALL, rfc4519.matching_rules.ALL, rfc4519.matching_rules.ALL) +RFC4519_SUBSCHEMA = Schema( rfc4519.object_classes.ALL, rfc4519.attribute_types.ALL, rfc4519.matching_rules.ALL, rfc4519.syntaxes.ALL) # COSINE LDAP/X.500 Schema -RFC4524_SUBSCHEMA = Schema(rfc4524.object_classes.ALL, rfc4524.attribute_types.ALL, rfc4524.matching_rules.ALL, rfc4524.matching_rules.ALL) +RFC4524_SUBSCHEMA = Schema(rfc4524.object_classes.ALL, rfc4524.attribute_types.ALL, rfc4524.matching_rules.ALL, rfc4524.syntaxes.ALL) # inetOrgPerson Schema -RFC2798_SUBSCHEMA = Schema(rfc2798.object_classes.ALL, rfc2798.attribute_types.ALL, rfc2798.matching_rules.ALL, rfc2798.matching_rules.ALL) +RFC2798_SUBSCHEMA = Schema(rfc2798.object_classes.ALL, rfc2798.attribute_types.ALL, rfc2798.matching_rules.ALL, rfc2798.syntaxes.ALL) # Extended RFC2307 (NIS) Schema -RFC2307BIS_SUBSCHEMA = Schema(rfc2307bis.object_classes.ALL, rfc2307bis.attribute_types.ALL, rfc2307bis.matching_rules.ALL, rfc2307bis.matching_rules.ALL) +RFC2307BIS_SUBSCHEMA = Schema(rfc2307bis.object_classes.ALL, rfc2307bis.attribute_types.ALL, rfc2307bis.matching_rules.ALL, rfc2307bis.syntaxes.ALL) diff --git a/ldapserver/schema/rfc2252/syntaxes.py b/ldapserver/schema/rfc2252/syntaxes.py index d47d85f2c5d72a7c9d0e11b757b22576c70b7160..cf2d2492dd021dfe34c5d69453aa0862db34de15 100644 --- a/ldapserver/schema/rfc2252/syntaxes.py +++ b/ldapserver/schema/rfc2252/syntaxes.py @@ -7,11 +7,11 @@ class Binary(Syntax): desc = 'Binary' @staticmethod - def encode(value): + def encode(schema, value): return value @staticmethod - def decode(raw_value): + def decode(schema, raw_value): return raw_value ALL = ( diff --git a/ldapserver/schema/rfc4517/matching_rules.py b/ldapserver/schema/rfc4517/matching_rules.py index ea0d98927b9344f073011ddde156c709f4ea1584..ea211434fcf32f889b97ed9c3261ffce6c3314f3 100644 --- a/ldapserver/schema/rfc4517/matching_rules.py +++ b/ldapserver/schema/rfc4517/matching_rules.py @@ -3,13 +3,13 @@ from ... import rfc4518_stringprep from . import syntaxes class GenericMatchingRule(MatchingRule): - def match_equal(self, attribute_value, assertion_value): + def match_equal(self, schema, attribute_value, assertion_value): return attribute_value == assertion_value - def match_less(self, attribute_value, assertion_value): + def match_less(self, schema, attribute_value, assertion_value): return attribute_value < assertion_value - def match_greater_or_equal(self, attribute_value, assertion_value): + def match_greater_or_equal(self, schema, attribute_value, assertion_value): return attribute_value >= assertion_value class StringMatchingRule(MatchingRule): @@ -17,7 +17,7 @@ class StringMatchingRule(MatchingRule): super().__init__(oid, name, syntax) self.matching_type = matching_type - def match_equal(self, attribute_value, assertion_value): + def match_equal(self, schema, attribute_value, assertion_value): try: attribute_value = rfc4518_stringprep.prepare(attribute_value, self.matching_type) assertion_value = rfc4518_stringprep.prepare(assertion_value, self.matching_type) @@ -25,7 +25,7 @@ class StringMatchingRule(MatchingRule): return None return attribute_value == assertion_value - def match_less(self, attribute_value, assertion_value): + def match_less(self, schema, attribute_value, assertion_value): try: attribute_value = rfc4518_stringprep.prepare(attribute_value, self.matching_type) assertion_value = rfc4518_stringprep.prepare(assertion_value, self.matching_type) @@ -33,7 +33,7 @@ class StringMatchingRule(MatchingRule): return None return attribute_value < assertion_value - def match_greater_or_equal(self, attribute_value, assertion_value): + def match_greater_or_equal(self, schema, attribute_value, assertion_value): try: attribute_value = rfc4518_stringprep.prepare(attribute_value, self.matching_type) assertion_value = rfc4518_stringprep.prepare(assertion_value, self.matching_type) @@ -41,7 +41,7 @@ class StringMatchingRule(MatchingRule): return None return attribute_value >= assertion_value - def match_substr(self, attribute_value, inital_substring, any_substrings, final_substring): + def match_substr(self, schema, attribute_value, inital_substring, any_substrings, final_substring): try: attribute_value = rfc4518_stringprep.prepare(attribute_value, self.matching_type) if inital_substring: @@ -72,7 +72,7 @@ class StringListMatchingRule(MatchingRule): self.matching_type = matching_type # Values are both lists of str - def match_equal(self, attribute_value, assertion_value): + def match_equal(self, schema, attribute_value, assertion_value): try: attribute_value = [rfc4518_stringprep.prepare(line, self.matching_type) for line in attribute_value] assertion_value = [rfc4518_stringprep.prepare(line, self.matching_type) for line in assertion_value] @@ -81,14 +81,23 @@ class StringListMatchingRule(MatchingRule): return attribute_value == assertion_value class FirstComponentMatchingRule(MatchingRule): - def __init__(self, oid, name, syntax, attribute_name): + def __init__(self, oid, name, syntax, attribute_name, matching_rule): super().__init__(oid, name, syntax) self.attribute_name = attribute_name + self.matching_rule = matching_rule - def match_equal(self, attribute_value, assertion_value): + def match_equal(self, schema, attribute_value, assertion_value): if not hasattr(attribute_value, self.attribute_name): return None - return getattr(attribute_value, self.attribute_name)() == assertion_value + return self.matching_rule.match_equal(schema, getattr(attribute_value, self.attribute_name)(), assertion_value) + +class OIDMatchingRule(MatchingRule): + def match_equal(self, schema, attribute_value, assertion_value): + attribute_value = schema.get_numeric_oid(attribute_value) + assertion_value = schema.get_numeric_oid(assertion_value) + if assertion_value is None: + return None + return attribute_value == assertion_value bitStringMatch = GenericMatchingRule('2.5.13.16', name='bitStringMatch', syntax=syntaxes.BitString()) booleanMatch = GenericMatchingRule('2.5.13.13', name='booleanMatch', syntax=syntaxes.Boolean()) @@ -103,20 +112,20 @@ caseIgnoreListSubstringsMatch = StringListMatchingRule('2.5.13.12', name='caseIg caseIgnoreMatch = StringMatchingRule('2.5.13.2', name='caseIgnoreMatch', syntax=syntaxes.DirectoryString(), matching_type=rfc4518_stringprep.MatchingType.CASE_IGNORE_STRING) caseIgnoreOrderingMatch = StringMatchingRule('2.5.13.3', name='caseIgnoreOrderingMatch', syntax=syntaxes.DirectoryString(), matching_type=rfc4518_stringprep.MatchingType.CASE_IGNORE_STRING) caseIgnoreSubstringsMatch = StringMatchingRule('2.5.13.4', name='caseIgnoreSubstringsMatch', syntax=syntaxes.SubstringAssertion(), matching_type=rfc4518_stringprep.MatchingType.CASE_IGNORE_STRING) -directoryStringFirstComponentMatch = FirstComponentMatchingRule('2.5.13.31', name='directoryStringFirstComponentMatch', syntax=syntaxes.DirectoryString(), attribute_name='get_first_component_string') +directoryStringFirstComponentMatch = FirstComponentMatchingRule('2.5.13.31', name='directoryStringFirstComponentMatch', syntax=syntaxes.DirectoryString(), attribute_name='get_first_component_string', matching_rule=caseIgnoreMatch) distinguishedNameMatch = GenericMatchingRule('2.5.13.1', name='distinguishedNameMatch', syntax=syntaxes.DN()) generalizedTimeMatch = GenericMatchingRule('2.5.13.27', name='generalizedTimeMatch', syntax=syntaxes.GeneralizedTime()) generalizedTimeOrderingMatch = GenericMatchingRule('2.5.13.28', name='generalizedTimeOrderingMatch', syntax=syntaxes.GeneralizedTime()) -integerFirstComponentMatch = FirstComponentMatchingRule('2.5.13.29', name='integerFirstComponentMatch', syntax=syntaxes.INTEGER(), attribute_name='get_first_component_integer') integerMatch = GenericMatchingRule('2.5.13.14', name='integerMatch', syntax=syntaxes.INTEGER()) +integerFirstComponentMatch = FirstComponentMatchingRule('2.5.13.29', name='integerFirstComponentMatch', syntax=syntaxes.INTEGER(), attribute_name='get_first_component_integer', matching_rule=integerMatch) integerOrderingMatch = GenericMatchingRule('2.5.13.15', name='integerOrderingMatch', syntax=syntaxes.INTEGER()) # Optional and implementation-specific, we simply never match keywordMatch = MatchingRule('2.5.13.33', name='keywordMatch', syntax=syntaxes.DirectoryString()) numericStringMatch = StringMatchingRule('2.5.13.8', name='numericStringMatch', syntax=syntaxes.NumericString(), matching_type=rfc4518_stringprep.MatchingType.NUMERIC_STRING) numericStringOrderingMatch = StringMatchingRule('2.5.13.9', name='numericStringOrderingMatch', syntax=syntaxes.NumericString(), matching_type=rfc4518_stringprep.MatchingType.NUMERIC_STRING) numericStringSubstringsMatch = StringMatchingRule('2.5.13.10', name='numericStringSubstringsMatch', syntax=syntaxes.SubstringAssertion(), matching_type=rfc4518_stringprep.MatchingType.NUMERIC_STRING) -objectIdentifierFirstComponentMatch = FirstComponentMatchingRule('2.5.13.30', name='objectIdentifierFirstComponentMatch', syntax=syntaxes.OID(), attribute_name='get_first_component_oid') -objectIdentifierMatch = StringMatchingRule('2.5.13.0', name='objectIdentifierMatch', syntax=syntaxes.OID(), matching_type=rfc4518_stringprep.MatchingType.CASE_IGNORE_STRING) +objectIdentifierMatch = OIDMatchingRule('2.5.13.0', name='objectIdentifierMatch', syntax=syntaxes.OID()) +objectIdentifierFirstComponentMatch = FirstComponentMatchingRule('2.5.13.30', name='objectIdentifierFirstComponentMatch', syntax=syntaxes.OID(), attribute_name='get_first_component_oid', matching_rule=objectIdentifierMatch) octetStringMatch = GenericMatchingRule('2.5.13.17', name='octetStringMatch', syntax=syntaxes.OctetString()) octetStringOrderingMatch = GenericMatchingRule('2.5.13.18', name='octetStringOrderingMatch', syntax=syntaxes.OctetString()) telephoneNumberMatch = StringMatchingRule('2.5.13.20', name='telephoneNumberMatch', syntax=syntaxes.TelephoneNumber(), matching_type=rfc4518_stringprep.MatchingType.TELEPHONE_NUMBER) diff --git a/ldapserver/schema/rfc4517/syntaxes.py b/ldapserver/schema/rfc4517/syntaxes.py index a25c4a0aa3909f5937d5221e7c5b2ac51d0c1cb3..cf63fd1e3be5648b1aea679486b7bfb9b5197c83 100644 --- a/ldapserver/schema/rfc4517/syntaxes.py +++ b/ldapserver/schema/rfc4517/syntaxes.py @@ -7,20 +7,20 @@ from ... import dn # Base classes class StringSyntax(Syntax): @staticmethod - def encode(value): + def encode(schema, value): return value.encode('utf8') @staticmethod - def decode(raw_value): + def decode(schema, raw_value): return raw_value.decode('utf8') class BytesSyntax(Syntax): @staticmethod - def encode(value): + def encode(schema, value): return value @staticmethod - def decode(raw_value): + def decode(schema, raw_value): return raw_value # Syntax definitions @@ -37,11 +37,11 @@ class Boolean(Syntax): desc = 'Boolean' @staticmethod - def encode(value): + def encode(schema, value): return b'TRUE' if value else b'FALSE' @staticmethod - def decode(raw_value): + def decode(schema, raw_value): if raw_value == b'TRUE': return True elif raw_value == b'FALSE': @@ -74,11 +74,11 @@ class DN(Syntax): desc = 'DN' @staticmethod - def encode(value): + def encode(schema, value): return str(value).encode('utf8') @staticmethod - def decode(raw_value): + def decode(schema, raw_value): try: return dn.DN.from_str(raw_value.decode('utf8')) except (UnicodeDecodeError, TypeError, ValueError): @@ -101,7 +101,7 @@ class GeneralizedTime(Syntax): desc = 'Generalized Time' @staticmethod - def encode(value): + def encode(schema, value): str_value = value.strftime('%Y%m%d%H%M%S.%f') if value.tzinfo == datetime.timezone.utc: str_value += 'Z' @@ -118,7 +118,7 @@ class GeneralizedTime(Syntax): return str_value.encode('ascii') @staticmethod - def decode(raw_value): + def decode(schema, raw_value): try: raw_value = raw_value.decode('utf8') except UnicodeDecodeError: @@ -161,11 +161,11 @@ class INTEGER(Syntax): desc = 'INTEGER' @staticmethod - def encode(value): + def encode(schema, value): return str(value).encode('utf8') @staticmethod - def decode(raw_value): + def decode(schema, raw_value): if not raw_value or not raw_value.split(b'-', 1)[-1].isdigit(): return None return int(raw_value.decode('utf8')) @@ -191,11 +191,11 @@ class NameAndOptionalUID(StringSyntax): desc = 'Name And Optional UID' @staticmethod - def encode(value): - return DN.encode(value) + def encode(schema, value): + return DN.encode(schema, value) @staticmethod - def decode(raw_value): + def decode(schema, raw_value): escaped = False dn_part = raw_value bitstr_part = b'' # pylint: disable=unused-variable @@ -213,7 +213,7 @@ class NameAndOptionalUID(StringSyntax): # of dn.DN that carries the bitstring part as an attribute. #if bitstr_part: # return DN.decode(dn_part), BitString.decode(bitstr_part) - return DN.decode(dn_part) + return DN.decode(schema, dn_part) class NameFormDescription(StringSyntax): oid = '1.3.6.1.4.1.1466.115.121.1.35' @@ -290,11 +290,11 @@ class PostalAddress(Syntax): # Native values are lists of str @staticmethod - def encode(value): + def encode(schema, value): return '$'.join([line.replace('\\', '\\5C').replace('$', '\\24') for line in value]).encode('utf8') @staticmethod - def decode(raw_value): + def decode(schema, raw_value): return [line.replace('\\24', '$').replace('\\5C', '\\') for line in raw_value.decode('utf8').split('$')] class PrintableString(StringSyntax): @@ -322,7 +322,7 @@ class UTCTime(Syntax): desc = 'UTC Time' @staticmethod - def encode(value): + def encode(schema, value): str_value = value.strftime('%y%m%d%H%M%S') if value.tzinfo == datetime.timezone.utc: str_value += 'Z' @@ -339,7 +339,7 @@ class UTCTime(Syntax): return str_value.encode('ascii') @staticmethod - def decode(raw_value): + def decode(schema, raw_value): try: raw_value = raw_value.decode('utf8') except UnicodeDecodeError: diff --git a/ldapserver/schema/types.py b/ldapserver/schema/types.py index 628f1a8c969de85698d4f0a99b0d38fca86cf1aa..cac17ae3660050a50e50ea468fd93da690a28269 100644 --- a/ldapserver/schema/types.py +++ b/ldapserver/schema/types.py @@ -1,4 +1,5 @@ import enum +import re def escape(string): result = '' @@ -27,11 +28,10 @@ class Syntax: return cls.oid @classmethod - def encode_syntax_definition(cls): + def to_definition(cls): return f"( {cls.oid} DESC '{escape(cls.desc)}' )" - @staticmethod - def decode(raw_value): + def decode(self, schema, raw_value): '''Decode LDAP-specific encoding of a value to a native value :param raw_value: LDAP-specific encoding of the value @@ -41,8 +41,7 @@ class Syntax: :rtype: any or None''' return None - @staticmethod - def encode(value): + def encode(self, schema, value): '''Encode native value to its LDAP-specific encoding :param value: native value (depends on syntax) @@ -56,29 +55,30 @@ class MatchingRule: def __init__(self, oid, name, syntax, **kwargs): self.oid = oid self.name = name + self.names = [name] self.syntax = syntax for key, value in kwargs.items(): setattr(self, key, value) - def encode_syntax_definition(self): + def to_definition(self): return f"( {self.oid} NAME '{escape(self.name)}' SYNTAX {self.syntax.ref} )" def __repr__(self): - return f'<ldapserver.schema.MatchingRule {self.encode_syntax_definition()}>' + return f'<ldapserver.schema.MatchingRule {self.oid}>' - def match_equal(self, attribute_value, assertion_value): + def match_equal(self, schema, attribute_value, assertion_value): return None - def match_approx(self, attribute_value, assertion_value): - return self.match_equal(attribute_value, assertion_value) + def match_approx(self, schema, attribute_value, assertion_value): + return self.match_equal(schema, attribute_value, assertion_value) - def match_less(self, attribute_value, assertion_value): + def match_less(self, schema, attribute_value, assertion_value): return None - def match_greater_or_equal(self, attribute_value, assertion_value): + def match_greater_or_equal(self, schema, attribute_value, assertion_value): return None - def match_substr(self, attribute_value, inital_substring, any_substrings, final_substring): + def match_substr(self, schema, attribute_value, inital_substring, any_substrings, final_substring): return None class AttributeTypeUsage(enum.Enum): @@ -129,13 +129,12 @@ class AttributeType: self.schema_encoding = ' '.join(tokens) self.oid = oid self.name = name - self.names = set() - if name is not None: - self.names.add(name) + self.names = [name] if name is not None else [] + self.inherited_names = set(self.names) self.obsolete = obsolete or False self.sup = sup if self.sup is not None: - self.names |= self.sup.names + self.inherited_names |= self.sup.inherited_names self.equality = equality if self.equality is None and self.sup is not None: self.equality = self.sup.equality @@ -153,11 +152,14 @@ class AttributeType: self.no_user_modification = no_user_modification or False self.usage = usage or AttributeTypeUsage.userApplications + def to_definition(self): + return self.schema_encoding + def get_first_component_oid(self): return self.oid def __repr__(self): - return f'<ldapserver.schema.AttributeType {self.schema_encoding}>' + return f'<ldapserver.schema.AttributeType {self.oid}>' class ObjectClassKind(enum.Enum): ABSTRACT = enum.auto() @@ -201,6 +203,7 @@ class ObjectClass: self.schema_encoding = ' '.join(tokens) self.oid = oid self.name = name + self.names = [name] self.desc = desc self.obsolete = obsolete or False self.sup = sup @@ -208,83 +211,161 @@ class ObjectClass: self.must = must or [] self.may = may or [] + def to_definition(self): + return self.schema_encoding + def get_first_component_oid(self): return self.oid def __repr__(self): return f'<ldapserver.schema.ObjectClass {self.schema_encoding}>' +DOTTED_DECIMAL_RE = re.compile(r'[0-9]+(\.[0-9]+)*') + +def normalize_oid(oid): + return oid.lower().strip() + class Schema: '''Collection of LDAP syntaxes, matching rules, attribute types and object classes forming an LDAP schema.''' def __init__(self, object_classes=None, attribute_types=None, matching_rules=None, syntaxes=None): - attribute_types = list(attribute_types or []) - matching_rules = list(matching_rules or []) - syntaxes = list(syntaxes or []) - self.object_classes = {} - for objectclass in object_classes or []: - self.object_classes[objectclass.oid] = objectclass - attribute_types += objectclass.must + objectclass.may - self.attribute_types = {} - self.attribute_types_by_name = {} - self.attribute_types_by_unique_name = {} - self.user_attribute_types = set() - for attribute_type in attribute_types: - self.attribute_types[attribute_type.oid] = attribute_type - for name in attribute_type.names: - name = name.lower() - self.attribute_types_by_name[name] = \ - self.attribute_types_by_name.get(name, set()) | {attribute_type} - self.attribute_types_by_unique_name[attribute_type.name.lower()] = attribute_type - if attribute_type.usage == AttributeTypeUsage.userApplications: - self.user_attribute_types.add(attribute_type) - if attribute_type.equality is not None: - matching_rules += [attribute_type.equality] - if attribute_type.ordering is not None: - matching_rules += [attribute_type.ordering] - if attribute_type.substr is not None: - matching_rules += [attribute_type.substr] - syntaxes += [type(attribute_type.syntax)] - self.matching_rules = {} - for matching_rule in matching_rules: - self.matching_rules[matching_rule.oid] = matching_rule - syntaxes += [type(matching_rule.syntax)] - self.syntaxes = {} + self.syntaxes = [] + self.matching_rules = [] + self.attribute_types = [] + self.user_attribute_types = [] + self.object_classes = [] + self.__attribute_type_by_oid = {} + self.__attribute_type_and_subtypes_by_oid = {} + self.__oid_names = {} for syntax in syntaxes: - self.syntaxes[syntax.oid] = syntax - self.object_class_definitions = [objectclass.schema_encoding for objectclass in self.object_classes.values()] - self.syntax_definitions = [syntax.encode_syntax_definition() for syntax in self.syntaxes.values()] - self.matching_rule_definitions = [matching_rule.encode_syntax_definition() for matching_rule in self.matching_rules.values()] - self.attribute_type_definitions = [attribute_type.schema_encoding for attribute_type in self.attribute_types.values()] + self.register_syntax(syntax) + for matching_rule in matching_rules: + self.register_matching_rule(matching_rule) + for attribute_type in attribute_types: + self.register_attribute_type(attribute_type) + for object_class in object_classes: + self.register_object_class(object_class) def extend(self, *schemas, object_classes=None, attribute_types=None, matching_rules=None, syntaxes=None): - object_classes = list(self.object_classes.values()) + list(object_classes or []) - attribute_types = list(self.attribute_types.values()) + list(attribute_types or []) - matching_rules = list(self.matching_rules.values()) + list(matching_rules or []) - syntaxes = list(self.syntaxes.values()) + list(syntaxes or []) + object_classes = self.object_classes + (object_classes or []) + attribute_types = self.attribute_types + (attribute_types or []) + matching_rules = self.matching_rules + (matching_rules or []) + syntaxes = self.syntaxes + (syntaxes or []) for schema in schemas: - object_classes += list(schema.object_classes.values()) - attribute_types += list(schema.attribute_types.values()) - matching_rules += list(schema.matching_rules.values()) - syntaxes += list(schema.syntaxes.values()) + object_classes += schema.object_classes + attribute_types += schema.attribute_types + matching_rules += schema.matching_rules + syntaxes += schema.syntaxes return Schema(object_classes, attribute_types, matching_rules, syntaxes) - def lookup_attribute(self, oid_or_name, fail_if_not_found=False): - if isinstance(oid_or_name, AttributeType): - if self.attribute_types.get(oid_or_name.oid) != oid_or_name: - raise Exception() - return oid_or_name - if oid_or_name in self.attribute_types: - return self.attribute_types[oid_or_name] - result = self.attribute_types_by_unique_name.get(oid_or_name.lower()) - if result is None and fail_if_not_found: - raise Exception(f'Attribute "{oid_or_name}" not in schema') - return result - - def lookup_attribute_list(self, oid_or_name): - if oid_or_name in self.attribute_types_by_name: - return list(self.attribute_types_by_name[oid_or_name]) - result = self.lookup_attribute(oid_or_name) - if result is None: - return [] - return [result] + def register_syntax(self, syntax): + '''Add syntax to schema + + :param syntax: Syntax (subclass, not instance!) to add + :type syntax: Syntax subclass''' + if syntax in self.syntaxes: + return + self.register_oid(syntax.oid) + self.syntaxes.append(syntax) + + def register_matching_rule(self, matching_rule): + '''Add matching rule and the referenced syntax to schema + + :param matching_rule: Matching rule to add + :type matching_rule: MatchingRule''' + if matching_rule in self.matching_rules: + return + self.register_syntax(matching_rule.syntax) + self.register_oid(matching_rule.oid, *matching_rule.names) + self.matching_rules.append(matching_rule) + + def register_attribute_type(self, attribute_type): + '''Add attribute type and all referenced matching rules and syntaxes to schema + + :param attribute_type: Attribute type to add + :type attribute_type: AttributeType''' + if attribute_type in self.attribute_types: + return + self.register_syntax(attribute_type.syntax) + if attribute_type.equality: + self.register_matching_rule(attribute_type.equality) + if attribute_type.ordering: + self.register_matching_rule(attribute_type.ordering) + if attribute_type.substr: + self.register_matching_rule(attribute_type.substr) + self.register_oid(attribute_type.oid, *attribute_type.names) + for oid in [attribute_type.oid] + attribute_type.names: + self.__attribute_type_by_oid[normalize_oid(oid)] = attribute_type + self.__attribute_type_and_subtypes_by_oid.setdefault(normalize_oid(oid), []) + self.__attribute_type_and_subtypes_by_oid[normalize_oid(oid)].append(attribute_type) + for oid in attribute_type.inherited_names: + self.__attribute_type_and_subtypes_by_oid.setdefault(normalize_oid(oid), []) + self.__attribute_type_and_subtypes_by_oid[normalize_oid(oid)].append(attribute_type) + self.attribute_types.append(attribute_type) + if attribute_type.usage == AttributeTypeUsage.userApplications: + self.user_attribute_types.append(attribute_type) + + def register_object_class(self, object_class): + '''Add object class and all referenced attribute types, matching rules and syntaxes to schema + + :param object_class: Object class to add + :type object_class: ObjectClass''' + if object_class in self.object_classes: + return + for attribute_type in object_class.may + object_class.must: + self.register_attribute_type(attribute_type) + self.register_oid(object_class.oid, *object_class.names) + self.object_classes.append(object_class) + + def get_attribute_type(self, oid): + '''Get attribute type by its OID + + :param oid: Numeric OID or short descriptive name of attribute type + :type oid: str + :return: Attribute type identified by OID + :rtype: AttributeType + :raises KeyError: if attribute type is not found''' + attribute_type = self.__attribute_type_by_oid.get(normalize_oid(oid)) + if attribute_type is None: + raise KeyError(f'Attribute "{oid}" not in schema') + return attribute_type + + def get_attribute_type_and_subtypes(self, oid): + '''Get attribute type by its OID with all subtypes (if any) + + :param oid: Numeric OID or short descriptive name of attribute type + :type oid: str + :return: List containing the attribute type and its subtypes or empty list + if attribute type is not found + :rtype: List[AttributeType]''' + return self.__attribute_type_and_subtypes_by_oid.get(normalize_oid(oid), []) + + def register_oid(self, numeric_oid, *names): + '''Register numeric OID and optionally short descriptive names with the schema + + Both the numeric OID and all names must be unique within the schema. + + :param numeric_oid: Numeric OID + :type numeric_oid: str + :param names: Short descriptive names for OID + :type names: str''' + for name in names: + if self.__oid_names.get(normalize_oid(name), numeric_oid) != numeric_oid: + raise Exception(f'OID short descriptive name "{name}" is already used in schema') + for name in names: + self.__oid_names[normalize_oid(name)] = numeric_oid + + def get_numeric_oid(self, oid): + '''Return numeric OID for a given OID short descriptive name + + If `oid` is a numeric OID it is returned normalized. + + :param oid: OID short descriptive name or numeric OID + :type oid: str + :return: Numeric OID in dotted-decimal form or None if `oid` is not + recognized and is not a numeric OID. + :rtype: str or None''' + oid = normalize_oid(oid) + if DOTTED_DECIMAL_RE.fullmatch(oid): + return oid + return self.__oid_names.get(oid)