diff --git a/examples/new_api.py b/examples/new_api.py new file mode 100644 index 0000000000000000000000000000000000000000..776e97a709d1066cff24e58e9e73c6e24707b571 --- /dev/null +++ b/examples/new_api.py @@ -0,0 +1,81 @@ +from ldapserver import Subtree, Server +from ldapserver.exceptions import LDAPInvalidCredentials + +api = None # Overwritten later +bind_password = None # if None anonymous reads are allowed + +server = Server(schema) + +service = server.add_entry('cn=service,ou=system,dc=example,dc=com', + objectclass=['top', 'organizationalRole', 'simpleSecurityObject'], + structuralObjectClass=['organizationalRole']) + +@service.bind +def service_bind(password): + if password == bind_password: + return + raise LDAPInvalidCredentials() + +@server.template('uid={loginname},ou=users,dc=example,dc=com', + structuralObjectClass=['inetorgperson'], + objectClass=['top', 'inetorgperson', 'organizationalperson', 'person', 'posixaccount'], + sn=[' '], + cn=['{displayname}'], + displayname=['{displayname}'], + givenname=['{displayname}'], + homeDirectory=['/home/{loginname}'], + mail=['{email}'], + uid=['{loginname}'], + uidNumber=['{uid}'], + memberOf=['cn={*group_names},ou=groups,dc=example,dc=com']) +def users(loginname=None, uid=None, group_names=None, **kwargs): + if not connection.bind_state: + return + request_params = {} + if not request_params and loginname is not None: + request_params = {'loginname': normalize_user_loginname(loginname)} + if not request_params and uid is not None: + request_params = {'id': uid} + if not request_params and group_names: + request_params = {'group': normalize_group_name(group_names[0])} + for user in api.get_users(**request_params): + yield dict(loginname=user['loginname'], displayname=user['displayname'], + email=user['email'], uid=user['id'], groups=user['groups']) + +@users.bind +def users_bind(entry, password): + if api.check_password(loginname=entry['uid'][0], password=password): + return + raise LDAPInvalidCredentials() + +@server.bind_sasl_plain +def bind_sasl_plain(identity, password, authzid=None): + if authzid is not None and identity != authzid: + raise LDAPInvalidCredentials() + user = api.check_password(loginname=identity, password=password) + if user is None: + raise LDAPInvalidCredentials() + return user + +@server.template('cn={name},ou=groups', + structuralObjectClass=['groupOfUniqueNames'], + objectClass=['top', 'groupOfUniqueNames', 'posixGroup'], + description=[' '], + cn=['{name}'], + gidNumber=['{gid}'], + uniqueMember=['uid={*member_names},ou=users,dc=example,dc=com']) +def groups(name=None, gid=None, member_names=None): + if not connection.bind_state: + return + request_params = {} + if not request_params and name is not None: + request_params = {'name': normalize_group_name(name)} + if not request_params and gid is not None: + request_params = {'id': gid} + if not request_params and member_names: + request_params = {'member': normalize_user_loginname(member_names[0])} + for group in api.get_groups(**request_params): + yield groups.create(name=group['name'], gid=group['id'], member_names=group['members']) + +if __name__ == '__main__': + server.run() diff --git a/examples/passwd.py b/examples/passwd.py index ce024a344de426a3b3073b717095b43659fbe702..1d4590fc14af8b273a88cc20512f3e90747a332e 100644 --- a/examples/passwd.py +++ b/examples/passwd.py @@ -4,38 +4,49 @@ import pwd import grp import ldapserver +import ldapserver.future -logging.basicConfig(level=logging.INFO) +server = ldapserver.future.Server(ldapserver.schema.RFC2307BIS_SCHEMA) +server.add_entry('dc=example,dc=com', + objectClass=['top', 'dcObject', 'organization'], + structuralObjectClass=['organization'], +) +server.add_entry('ou=users,dc=example,dc=com', + objectClass=['top', 'organizationalUnit'], + structuralObjectClass=['organizationalUnit'], + ou=['users'], +) +server.add_entry('ou=groups,dc=example,dc=com', + objectClass=['top', 'organizationalUnit'], + structuralObjectClass=['organizationalUnit'], + ou=['groups'], +) -class RequestHandler(ldapserver.LDAPRequestHandler): - subschema = ldapserver.SubschemaSubentry(ldapserver.schema.RFC2307BIS_SCHEMA, 'cn=Subschema') +@server.template('uid={name},ou=users,dc=example,dc=com', + objectClass=['top', 'organizationalperson', 'person', 'posixaccount'], + structuralObjectClass=['organizationalperson'], + uid=['{name}'], + uidNumber=['{uid}'], + gidNumber=['{gid}'], + cn=['{gecos}'], + homeDirectory=['{home}'] +) +def users(name=None, gecos=None, uid=None, gid=None, home=None): + print('users', repr(name), repr(gecos), repr(uid), repr(gid), repr(home)) + for user in pwd.getpwall(): + yield {'name': user.pw_name, 'gecos': user.pw_gecos or ' ', 'uid': user.pw_uid, 'gid': user.pw_gid, 'home': user.pw_dir} - def do_search(self, basedn, scope, filterobj): - yield from super().do_search(basedn, scope, filterobj) - yield self.subschema.ObjectEntry('dc=example,dc=com', **{ - 'objectClass': ['top', 'dcObject', 'organization'], - 'structuralObjectClass': ['organization'], - }) - user_gids = {} - for user in pwd.getpwall(): - user_gids[user.pw_gid] = user_gids.get(user.pw_gid, set()) | {user.pw_name} - yield self.subschema.ObjectEntry(self.subschema.DN('ou=users,dc=example,dc=com', uid=user.pw_name), **{ - 'objectClass': ['top', 'organizationalperson', 'person', 'posixaccount'], - 'structuralObjectClass': ['organizationalperson'], - 'uid': [user.pw_name], - 'uidNumber': [user.pw_uid], - 'gidNumber': [user.pw_gid], - 'cn': [user.pw_gecos], - }) - for group in grp.getgrall(): - members = set(group.gr_mem) | user_gids.get(group.gr_gid, set()) - yield self.subschema.ObjectEntry(self.subschema.DN('ou=groups,dc=example,dc=com', cn=group.gr_name), **{ - 'objectClass': ['top', 'groupOfUniqueNames', 'posixGroup'], - 'structuralObjectClass': ['groupOfUniqueNames'], - 'cn': [group.gr_name], - 'gidNumber': [group.gr_gid], - 'uniqueMember': [self.subschema.DN('ou=user,dc=example,dc=com', uid=name) for name in members], - }) +@server.template('cn={name},ou=groups,dc=example,dc=com', + objectClass=['top', 'groupOfUniqueNames', 'posixGroup'], + structuralObjectClass=['groupOfUniqueNames'], + cn=['{name}'], + gidNumber=['{gid}'], + uniqueMember=['uid={*member_names},ou=users,dc=example,dc=com'], +) +def groups(name=None, gid=None, member_names=None): + print('groups', repr(name), repr(gid), repr(member_names)) + for group in grp.getgrall(): + yield {'name': group.gr_name, 'gid': group.gr_gid, 'member_names': group.gr_mem} if __name__ == '__main__': - socketserver.ThreadingTCPServer(('127.0.0.1', 3890), RequestHandler).serve_forever() + socketserver.ThreadingTCPServer(('127.0.0.1', 3890), server.socketserver_handler).serve_forever() diff --git a/ldapserver/future.py b/ldapserver/future.py new file mode 100644 index 0000000000000000000000000000000000000000..93c6bb37f4133705337ede15cd30e45df85a1d0c --- /dev/null +++ b/ldapserver/future.py @@ -0,0 +1,478 @@ +import traceback +import re +import ssl +import socketserver +import typing +import logging +import time +import random +import string +import itertools +import collections.abc +import enum +import contextvars + +from . import ldap, exceptions, entries, asn1, schema, value_templates +from .dn import DN, RDN, RDNAssertion + +''' +server-level +- simple anonymous bind +- SASL PLAIN bind +- SASL EXTERNAL bind +- SASL ANONYMOUS bind +- other SASL bind +- unbind +- add +- abandon +- extended + +special +- search + +object-level +- simple authenticated bind (with dn) +- simple unauthenticated bind (with dn) +- modify +- modifydn +- compare +- del +''' + +class Server: + def __init__(self, schema): + self.schema = schema + self.lookup_handlers = [] + self.search_handlers = [] + self.sasl_handlers = {} # 'METHOD': func + self.extended_handlers = {} # 'OID': func + + class ServerRequestHandler(RequestHandler): + ldap_server = self + + self.socketserver_handler = ServerRequestHandler + + def sasl_handler(self, method): + def decorator(func): + assert method not in self.sasl_bind_handler + self.sasl_bind_handler[method] = func + return func + return decorator + + def sasl_plain_handler(self, func): + @self.bind_sasl('PLAIN') + def wrapper(credentials, dn): + if credentials is None: + raise exceptions.LDAPProtocolError('Unsupported protocol version') + authzid, authcid, password = credentials.split(b'\0', 2) + return func(authcid.decode(), password.decode(), authzid.decode() or None), None + return func + + def extended_handler(self, oid): + def decorator(func): + assert oid not in self.extended_handlers + self.extended_handlers[oid] = func + return func + return decorator + + def add_entry(self, dn, **attributes): + entry = Entry(self.schema, dn, attributes) + self.lookup_handlers.append(entry.lookup) + self.search_handlers.append(entry.search) + return entry + + def template(self, dn_template, **attribute_templates): + def decorator(func): + templ = Template(self.schema, dn_template, func, attribute_templates) + self.lookup_handlers.append(templ.lookup) + self.search_handlers.append(templ.search) + return templ + return decorator + + def lookup_entry(self, dn): + for func in self.lookup_handlers: + entry = func(dn) + if entry: + return entry + raise exceptions.LDAPNoSuchObject() + + def process_connect(self): + connection.bind_object = None + + def process_message(self, shallowmsg: ldap.ShallowLDAPMessage) -> typing.Iterable[ldap.LDAPMessage]: + msgtypes = { + ldap.BindRequest: (self.process_bind, ldap.BindResponse), + #ldap.UnbindRequest: (self.process_unbind, None), + ldap.SearchRequest: (self.process_search, ldap.SearchResultDone), + #ldap.ModifyRequest: (self.process_modify, ldap.ModifyResponse), + #ldap.AddRequest: (self.process_add, ldap.AddResponse), + #ldap.DelRequest: (self.process_delete, ldap.DelResponse), + #ldap.ModifyDNRequest: (self.process_modifydn, ldap.ModifyDNResponse), + ldap.CompareRequest: (self.process_compare, ldap.CompareResponse), + #ldap.AbandonRequest: (self.process_abandon, None), + ldap.ExtendedRequest: (self.process_extended, ldap.ExtendedResponse), + } + handler, response_type = msgtypes.get(shallowmsg.protocolOpType, (None, None)) + try: + if handler is None: + raise exceptions.LDAPProtocolError() + try: + msg = shallowmsg.decode()[0] + except ValueError as e: + raise exceptions.LDAPProtocolError() from e + for args in handler(msg.protocolOp, msg.controls): + response, controls = args if isinstance(args, tuple) else (args, None) + yield ldap.LDAPMessage(shallowmsg.messageID, response, controls) + except exceptions.LDAPError as e: + if response_type is not None: + respmsg = ldap.LDAPMessage(shallowmsg.messageID, response_type(e.code, diagnosticMessage=e.message)) + yield respmsg + except Exception as e: # pylint: disable=broad-except + traceback.print_exc() + if response_type is not None: + respmsg = ldap.LDAPMessage(shallowmsg.messageID, response_type(ldap.LDAPResultCode.other)) + yield respmsg + + def process_bind(self, op, controls=None): + if op.version != 3: + raise exceptions.LDAPProtocolError('Unsupported protocol version') + auth = op.authentication + if isinstance(auth, ldap.SimpleAuthentication): + if not op.name: + connection.bind_object = None + else: + try: + entry = self.lookup_entry(op.name) + except exceptions.LDAPNoSuchObject as exc: + raise exceptions.LDAPInvalidCredentials() from exc + if auth.password: + connection.bind_object = entry.bind(auth.password) + else: + connection.bind_object = entry.bind_unauthenticated() + yield ldap.BindResponse(ldap.LDAPResultCode.success) + elif isinstance(auth, ldap.SaslCredentials): + if auth.mechanism not in self.sasl_handlers: + raise exceptions.LDAPAuthMethodNotSupported() + connection.bind_object, resp = self.sasl_handlers[auth.mechanism](auth.credentials) + yield ldap.BindResponse(ldap.LDAPResultCode.success, serverSaslCreds=resp) + else: + raise exceptions.LDAPAuthMethodNotSupported() + + def process_search(self, op, controls=None): + for func in self.search_handlers: + yield from func(op.baseObject, op.scope, op.filter, op.attributes, op.typesOnly) + yield ldap.SearchResultDone(ldap.LDAPResultCode.success) + + def process_compare(self, op, controls=None): + if self.lookup_entry(op.entry).compare(op.ava.attributeDesc, op.ava.assertionValue): + return [ldap.CompareResponse(ldap.LDAPResultCode.compareTrue)] + else: + return [ldap.CompareResponse(ldap.LDAPResultCode.compareFalse)] + + def process_extended(self, op, controls=None): + if op.requestName not in self.extended_handlers: + exceptions.LDAPProtocolError('Unknown extended request') + resp = self.extended_handlers[op.requestName](op.requestValue) + yield ldap.ExtendedResponse(ldap.LDAPResultCode.success, responseValue=resp) + +class FilterResult(enum.Enum): + TRUE = enum.auto() + FALSE = enum.auto() + UNDEFINED = enum.auto() + +class Entry(entries.AttributeDict): + def __init__(self, schema, dn, attributes): + super().__init__(schema, **attributes) + #: Entry's distinguished name (:class:`DN`) + self.dn = DN(schema, dn) + + def __search_match_dn(self, basedn, scope): + if scope == ldap.SearchScope.baseObject: + return self.dn == basedn + elif scope == ldap.SearchScope.singleLevel: + return self.dn.is_direct_child_of(basedn) + elif scope == ldap.SearchScope.wholeSubtree: + return self.dn.in_subtree_of(basedn) + else: + return False + + def __search_match_filter(self, filter_obj): + # pylint: disable=too-many-branches,too-many-return-statements,too-many-statements,too-many-nested-blocks + if isinstance(filter_obj, ldap.FilterAnd): + result = FilterResult.TRUE + for subfilter in filter_obj.filters: + subresult = self.__search_match_filter(subfilter) + if subresult == FilterResult.FALSE: + return FilterResult.FALSE + elif subresult == FilterResult.UNDEFINED: + result = FilterResult.UNDEFINED + return result + elif isinstance(filter_obj, ldap.FilterOr): + result = FilterResult.FALSE + for subfilter in filter_obj.filters: + subresult = self.__search_match_filter(subfilter) + if subresult == FilterResult.TRUE: + return FilterResult.TRUE + elif subresult == FilterResult.UNDEFINED: + result = FilterResult.UNDEFINED + return result + elif isinstance(filter_obj, ldap.FilterNot): + subresult = self.__search_match_filter(filter_obj.filter) + if subresult == FilterResult.TRUE: + return FilterResult.FALSE + elif subresult == FilterResult.FALSE: + return FilterResult.TRUE + else: + return subresult + elif isinstance(filter_obj, (ldap.FilterPresent, + ldap.FilterEqual, + ldap.FilterSubstrings, + ldap.FilterApproxMatch, + ldap.FilterGreaterOrEqual, + ldap.FilterLessOrEqual)): + try: + attribute_type = self.schema.attribute_types[filter_obj.attribute] + except KeyError: + return FilterResult.UNDEFINED + values = self.get(filter_obj.attribute, subtypes=True) + try: + if isinstance(filter_obj, ldap.FilterPresent): + result = values != [] + elif isinstance(filter_obj, ldap.FilterEqual): + result = attribute_type.match_equal(values, filter_obj.value) + elif isinstance(filter_obj, ldap.FilterSubstrings): + result = attribute_type.match_substr(values, + filter_obj.initial_substring, filter_obj.any_substrings, filter_obj.final_substring) + elif isinstance(filter_obj, ldap.FilterApproxMatch): + result = attribute_type.match_approx(values, filter_obj.value) + elif isinstance(filter_obj, ldap.FilterGreaterOrEqual): + result = attribute_type.match_greater_or_equal(values, filter_obj.value) + elif isinstance(filter_obj, ldap.FilterLessOrEqual): + result = attribute_type.match_less_or_equal(values, filter_obj.value) + else: + return FilterResult.UNDEFINED + return FilterResult.TRUE if result else FilterResult.FALSE + except exceptions.LDAPError: + return FilterResult.UNDEFINED + elif isinstance(filter_obj, ldap.FilterExtensibleMatch): + attribute_types = [] + matching_rule = None + try: + if filter_obj.type is not None and filter_obj.matchingRule is not None: + attribute_types = [self.schema.attribute_types[filter_obj.type]] + matching_rule = self.schema.matching_rules[filter_obj.matchingRule] + elif filter_obj.type is not None: + attribute_types = [self.schema.attribute_types[filter_obj.type]] + elif filter_obj.matchingRule is not None: + matching_rule = self.schema.matching_rules[filter_obj.matchingRule] + attribute_types = matching_rule.compatible_attribute_types + except KeyError: + pass + result = FilterResult.FALSE + for attribute_type in attribute_types: + values = self.get(attribute_type.oid, subtypes=True) + if filter_obj.dnAttributes: + for rdn in self.dn: + for assertion in rdn: + if assertion.attribute.lower() == attribute_type.ref.lower(): + values.append(assertion.value) + try: + if attribute_type.match_extensible(values, filter_obj.matchValue, matching_rule): + return FilterResult.TRUE + except exceptions.LDAPError: + result = FilterResult.UNDEFINED + return result + else: + return FilterResult.UNDEFINED + + def search(self, base_obj, scope, filter_obj, attributes, types_only): + if not self.__search_match_dn(DN.from_str(self.schema, base_obj), scope): + return [] + if not self.__search_match_filter(filter_obj) == FilterResult.TRUE: + return [] + selected_attributes = set() + for selector in attributes or ['*']: + if selector == '*': + selected_attributes |= self.schema.user_attribute_types + elif selector == '1.1': + continue + elif selector in self.schema.attribute_types: + selected_attributes.add(self.schema.attribute_types[selector]) + partial_attributes = [] + for attribute_type, values in self.items(types=True): + if attribute_type in selected_attributes: + if types_only: + values = [] + encoded_values = [attribute_type.encode(value) for value in values] + partial_attributes.append(ldap.PartialAttribute(attribute_type.ref, encoded_values)) + return [ldap.SearchResultEntry(str(self.dn), partial_attributes)] + + def lookup(self, dn): + return self if self.dn == DN.from_str(self.schema, dn) else None + + def bind(self, password): + raise exceptions.LDAPInvalidCredentials() + + def bind_handler(self, func): + self.bind = func + return func + + def compare(self, attribute, value): + try: + attribute_type = self.schema.attribute_types[attribute] + except KeyError as exc: + raise exceptions.LDAPUndefinedAttributeType() from exc + return attribute_type.match_equal(self.get(attribute_type, subtypes=True), value) + +class AttributeTemplate: + def __init__(self, attribute_type, exprs): + self.attribute_type = attribute_type + self.value_templates = [] + for expr in exprs: + value_template = attribute_type.syntax.definition.parse_template(attribute_type.schema, expr) + self.value_templates.append(value_template) + + def render(self, params): + results = [] + for template in self.value_templates: + if any([params.get(param) is None for param in template.optional_params]): + continue + if not template.starred_param: + results.append(template.render(params)) + continue + for value in params.get(template.starred_param, []): + results.append(template.render(params | {template.starred_param: value})) + return results + + def match_equal(self, assertion_value): + if not self.attribute_type.equality: + return {} + return self.attribute_type.equality.definition.match_equal_template(self.attribute_type.schema, self.value_templates, assertion_value) + +class Template: + def __init__(self, schema, dn_template, lookup_func, attribute_templates): + self.schema = schema + self.dn_template = value_templates.DNTemplate.from_str(schema, dn_template) + if self.dn_template.starred_param is not None: + raise ValueError() + if self.dn_template.optional_params: + raise ValueError() + self.lookup_func = lookup_func + self.attribute_templates = {schema.attribute_types[key].ref: AttributeTemplate(schema.attribute_types[key], values) for key, values in attribute_templates.items()} + self.bind_handler_func = None + + def bind_handler(self, func): + self.bind_handler_func = func + + def __lookup_entries(self, input_params): + for output_params in self.lookup_func(**input_params): + entry = Entry(self.schema, self.dn_template.render(output_params), + {key: value.render(output_params) for key, value in self.attribute_templates.items()}) + if self.bind_handler_func: + entry.bind = lambda password: self.bind_handler_func(entry, password) + yield entry + + def __extract_filter_constraints(self, filter_obj): + if isinstance(filter_obj, ldap.FilterEqual): + try: + attribute_type = self.schema.attribute_types[filter_obj.attribute] + except KeyError: + return entries.AttributeDict(self.schema) + if attribute_type.equality is None: + return entries.AttributeDict(self.schema) + assertion_value = attribute_type.equality.syntax.decode(filter_obj.value) + if assertion_value is None: + return entries.AttributeDict(self.schema) + return entries.AttributeDict(self.schema, **{filter_obj.attribute: [assertion_value]}) + if isinstance(filter_obj, ldap.FilterAnd): + result = entries.AttributeDict(self.schema) + for subfilter in filter_obj.filters: + for name, values in self.__extract_filter_constraints(subfilter).items(): + result[name] += values + return result + return entries.AttributeDict(self.schema) + + def search(self, base_obj, scope, filter_obj, attributes, types_only): + constraints = self.dn_template.match(DN.from_str(self.schema, base_obj), scope=scope) + if constraints is False: + return + if constraints is True: + constraints = {} + constraints = {key: {value} for key, value in constraints.items()} + filter_constraints = self.__extract_filter_constraints(filter_obj) + for attribute_type, values in filter_constraints.items(types=True): + if attribute_type.ref not in self.attribute_templates: + continue + templates = self.attribute_templates[attribute_type.ref].value_templates + for value in values: + subresult = attribute_type.equality.definition.match_equal_template(self.schema, templates, value) + if subresult is False: + return + if subresult is True: + continue + for key, value in subresult.items(): + constraints.setdefault(key, set()).add(value) + constraints = {key: list(values)[0] for key, values in constraints.items() if len(values) == 1} + for entry in self.__lookup_entries(constraints): + yield from entry.search(base_obj, scope, filter_obj, attributes, types_only) + + def lookup(self, dn): + dn = DN.from_str(self.schema, dn) + constraints = self.dn_template.match(dn) + if constraints is False: + return None + if constraints is True: + constraints = {} + for entry in self.__lookup_entries(constraints): + if entry.dn == dn: + return entry + +class Connection: + pass + +_connection: contextvars.ContextVar[Connection] = contextvars.ContextVar('connection') + +internal_conn = _connection + +class Proxy: + def __init__(self, var): + super().__setattr__('__var', var) + + def __getattr__(self, key): + return getattr(super().__getattribute__('__var').get(), key) + + def __setattr__(self, key, value): + return setattr(super().__getattribute__('__var').get(), key, value) + + def __delattr__(self, key): + return delattr(super().__getattribute__('__var').get(), key) + +connection = Proxy(_connection) + +class RequestHandler(socketserver.BaseRequestHandler): + ldap_server = None + + def setup(self): + super().setup() + self.keep_running = True + conn = Connection() + self.ctxtoken = _connection.set(conn) + + def handle(self): + buf = b'' + while self.keep_running: + try: + shallowmsg, buf = ldap.ShallowLDAPMessage.from_ber(buf) + for respmsg in self.ldap_server.process_message(shallowmsg): + self.request.sendall(ldap.LDAPMessage.to_ber(respmsg)) + except asn1.IncompleteBERError: + chunk = self.request.recv(4096) + if not chunk: + self.keep_running = False + self.request.close() + else: + buf += chunk + self.request.close() + + def finish(self): + _connection.reset(self.ctxtoken) diff --git a/ldapserver/schema/definitions.py b/ldapserver/schema/definitions.py index 3943cda98db423ae94b21c969a7b493967b3f2f0..abda2011546d56514c76d25703395bd299b1d91e 100644 --- a/ldapserver/schema/definitions.py +++ b/ldapserver/schema/definitions.py @@ -1,7 +1,7 @@ import enum import re -from .. import exceptions +from .. import exceptions, value_templates __all__ = [ 'SyntaxDefinition', @@ -207,6 +207,9 @@ class SyntaxDefinition: :raises exceptions.LDAPError: if raw_value is invalid''' raise exceptions.LDAPInvalidAttributeSyntax() + def parse_template(self, schema, expr): + return value_templates.SimpleValueTemplate(expr, decode=lambda expr: self.decode(schema, expr.encode())) + class MatchingRuleKind(enum.Enum): '''Values for :any:`MatchingRuleDefinition.kind`''' #: @@ -329,6 +332,48 @@ class MatchingRuleDefinition: :raises exceptions.LDAPError: if the result is undefined''' raise exceptions.LDAPInappropriateMatching() + def match_equal_template(self, schema, templates, assertion_value): + static_values = [] + params = set() + starred_params = set() + for template in templates: + if isinstance(template, value_templates.SimpleValueTemplate): + if not template.param: + static_values.append(template.static_value) + else: + params.add(template.param) + if template.starred_param: + starred_params.add(template.starred_param) + elif isinstance(template, value_templates.SubstrValueTemplate): + if len(template.tokens) != 1: + return {} + text, key = template.tokens[0] + if not text and key: + params.add(key) + if template.starred_param: + starred_params.add(template.starred_param) + elif text and not key: + static_values.append(text) + else: + return {} + else: + return {} + try: + result = self.match_equal(schema, static_values, assertion_value) + if result: + return True + except exceptions.LDAPInappropriateMatching: + pass + if starred_params: + if len(starred_params) == 1: + return {params.pop(): [assertion_value]} + return {} + elif params: + if len(params) == 1: + return {params.pop(): assertion_value} + else: + return False + class MatchingRuleUseDefinition: def __init__(self, oid, name=None, desc='', obsolete=False, applies=None, extensions=None): #: Numeric OID (string) diff --git a/ldapserver/schema/matching_rules.py b/ldapserver/schema/matching_rules.py index 851eb6f759e7d15ad82777711b0e85be7ee6eb6b..7716061ecd6ab0c01114e3165ba99ff96cf9d8db 100644 --- a/ldapserver/schema/matching_rules.py +++ b/ldapserver/schema/matching_rules.py @@ -23,6 +23,11 @@ class GenericMatchingRuleDefinition(MatchingRuleDefinition): return True return False +class DistinguishedNameMatchingRuleDefinition(GenericMatchingRuleDefinition): + def match_equal_template(self, schema, templates, assertion_value): + for template in templates: + pass + def _substr_match(attribute_value, inital_substring, any_substrings, final_substring): if inital_substring: if not attribute_value.startswith(inital_substring): diff --git a/ldapserver/schema/syntaxes.py b/ldapserver/schema/syntaxes.py index 2660a03f25fbd990d7e951d691ff9ad17ed831b0..6eae0af3a7ae294cc95d6d65d66072cbb54a245a 100644 --- a/ldapserver/schema/syntaxes.py +++ b/ldapserver/schema/syntaxes.py @@ -2,7 +2,7 @@ import re import datetime from .definitions import SyntaxDefinition -from .. import dn, exceptions +from .. import dn, exceptions, value_templates class BytesSyntaxDefinition(SyntaxDefinition): def encode(self, schema, value): @@ -29,6 +29,9 @@ class StringSyntaxDefinition(SyntaxDefinition): raise exceptions.LDAPInvalidAttributeSyntax() return value + def parse_template(self, schema, expr): + return value_templates.SubstrValueTemplate(expr) + class IntegerSyntaxDefinition(StringSyntaxDefinition): def __init__(self, oid, **kwargs): super().__init__(oid, encoding='ascii', re_pattern='([0-9]|-?[1-9][0-9]+)', **kwargs) @@ -42,6 +45,10 @@ class IntegerSyntaxDefinition(StringSyntaxDefinition): except ValueError as exc: raise exceptions.LDAPInvalidAttributeSyntax() from exc + def parse_template(self, schema, expr): + return value_templates.SimpleValueTemplate(expr, decode=lambda expr: self.decode(schema, expr.encode())) + + class SchemaElementSyntaxDefinition(SyntaxDefinition): def encode(self, schema, value): return str(value).encode('utf8') @@ -68,6 +75,9 @@ class DNSyntaxDefinition(SyntaxDefinition): except (UnicodeDecodeError, TypeError, ValueError) as exc: raise exceptions.LDAPInvalidAttributeSyntax() from exc + def parse_template(self, schema, expr): + return value_templates.DNTemplate.from_str(schema, expr) + class NameAndOptionalUIDSyntaxDefinition(StringSyntaxDefinition): def encode(self, schema, value): return str(value).encode('utf8') @@ -78,6 +88,9 @@ class NameAndOptionalUIDSyntaxDefinition(StringSyntaxDefinition): except (UnicodeDecodeError, TypeError, ValueError) as exc: raise exceptions.LDAPInvalidAttributeSyntax() from exc + def parse_template(self, schema, expr): + return value_templates.DNTemplate.from_str(schema, expr) + class GeneralizedTimeSyntaxDefinition(SyntaxDefinition): def encode(self, schema, value): if value.microsecond: diff --git a/ldapserver/value_templates.py b/ldapserver/value_templates.py new file mode 100644 index 0000000000000000000000000000000000000000..159fb795bb925c615e970f22b84a507576b25638 --- /dev/null +++ b/ldapserver/value_templates.py @@ -0,0 +1,244 @@ +import re + +from . import ldap +from .dn import DN, RDN, RDNAssertion + +class SimpleValueTemplate: + def __init__(self, expr, decode): + self.param = None + self.starred_param = None + self.optional_params = set() + self.static_value = None + if not isinstance(expr, str): + self.static_value = decode(expr) + return + match = re.fullmatch(r'{([*?]?)([A-Za-z0-9_]+)}', expr) + if not match: + self.static_value = decode(expr.replace('{{', '{').replace('}}', '}')) + return + option, key = match.groups() + option = option or None + self.param = key + if option is None: + pass + elif option == '*': + self.starred_param = key + elif option == '?': + self.optional_params.add(key) + else: + raise ValueError(f'Invalid template option {option!r}') + + def render(self, params): + if self.param is None: + return self.static_value + return params[self.param] + +class SubstrValueTemplate: + def __init__(self, expr): + self.tokens = [] + self.starred_param = None + self.optional_params = set() + while expr: + match = re.match(r'^(([^{}]|{{|}})*)({([*?]?)([A-Za-z0-9_]+)})?', expr) + if not match or not match.end(): + raise ValueError('Invalid token') + text, _, _, option, key, = match.groups() + text = text.replace('{{', '{').replace('}}', '}') + option = option or None + expr = expr[match.end():] + self.tokens.append((text, key)) + if key is None: + pass + elif option is None: + pass + elif option == '?': + self.optional_params.add(key) + elif option == '*': + if self.starred_param is not None and self.starred_param != key: + raise ValueError('Only one parameter may be starred') + self.starred_param = key + else: + raise ValueError(f'Invalid template option {option!r}') + + def render(self, params): + result = '' + for text, key in self.tokens: + result += text + if key is not None: + value = params[key] + result += str(value) + return result + +class DNTemplate(tuple): + def __new__(cls, schema, *args): + for rdn in args: + if not isinstance(rdn, RDNTemplate): + raise TypeError(f'Argument {repr(rdn)} is of type {repr(type(rdn))}, expected ldapserver.RDNTemplate object') + dn = super().__new__(cls, args) + dn.schema = schema + dn.starred_param = None + dn.optional_params = set() + for rdn in args: + dn.optional_params |= rdn.optional_params + if rdn.starred_param: + if dn.starred_param and rdn.starred_param != dn.starred_param: + raise ValueError() + dn.starred_param = rdn.starred_param + return dn + + @classmethod + def from_str(cls, schema, expr): + escaped = False + rdns = [] + token = '' + for char in expr: + if escaped: + escaped = False + token += char + elif char == ',': + rdns.append(RDNTemplate.from_str(schema, token)) + token = '' + else: + if char == '\\': + escaped = True + token += char + if token: + rdns.append(RDNTemplate.from_str(schema, token)) + return cls(schema, *rdns) + + def render(self, params): + return DN(self.schema, *[rdn.render(params) for rdn in self]) + + def match(self, dn, scope=ldap.SearchScope.baseObject): + if scope == ldap.SearchScope.baseObject: + if len(self) != len(dn): + return False + elif scope == ldap.SearchScope.singleLevel: + if len(self) != len(dn) + 1: + return False + elif scope == ldap.SearchScope.wholeSubtree: + if len(self) < len(dn): + return False + result = True + for rdn_template, rdn in zip(self[::-1], dn[::-1]): + subresult = rdn_template.match(rdn) + if subresult is False: + return False + if subresult is True: + continue + if result is True: + result = {} + for key, value in subresult.items(): + result.setdefault(key, []).append(value) + if result is True: + return True + return {key: values[0] for key, values in result.items() if len(values) == 1} + +class RDNTemplate(tuple): + def __new__(cls, schema, *assertions): + for assertion in assertions: + if not isinstance(assertion, RDNAssertionTemplate): + raise TypeError(f'Argument {repr(assertion)} is of type {repr(type(assertion))}, expected ldapserver.RDNAssertionTemplate') + if assertion.attribute_type.schema is not schema: + raise ValueError('RDNAssertion has different schema') + assertions = list(assertions) + if not assertions: + raise ValueError('RDN must have at least one assertion') + rdn = super().__new__(cls, assertions) + rdn.schema = schema + rdn.starred_param = None + rdn.optional_params = set() + for assertion in assertions: + rdn.optional_params |= assertion.optional_params + if assertion.starred_param: + if rdn.starred_param and assertion.starred_param != rdn.starred_param: + raise ValueError() + rdn.starred_param = assertion.starred_param + return rdn + + @classmethod + def from_str(cls, schema, expr): + escaped = False + assertions = [] + token = '' + for char in expr: + if escaped: + escaped = False + token += char + elif char == '+': + assertions.append(RDNAssertionTemplate.from_str(schema, token)) + token = '' + else: + if char == '\\': + escaped = True + token += char + if token: + assertions.append(RDNAssertionTemplate.from_str(schema, token)) + return cls(schema, *assertions) + + def render(self, params): + return RDN(self.schema, *[assertion.render(params) for assertion in self]) + + def match(self, rdn): + if len(self) != len(rdn): + return False + if len(self) != 1: + return {} + return self[0].match(rdn[0]) + +class RDNAssertionTemplate: + def __init__(self, schema, attribute, value_template): + self.schema = schema + try: + self.attribute_type = schema.attribute_types[attribute] + except KeyError as exc: + raise ValueError('Invalid RDN attribute type: Attribute type undefined in schema') from exc + self.attribute = attribute + if not self.attribute_type.equality: + raise ValueError('Invalid RDN attribute type: Attribute type has no EQUALITY matching rule') + self.value_template = value_template + self.starred_param = value_template.starred_param + self.optional_params = value_template.optional_params + + @classmethod + def from_str(cls, schema, expr): + attribute, escaped_value = expr.split('=', 1) + if escaped_value.startswith('#'): + # The "#..." form is used for unknown attribute types and those without + # an LDAP string encoding. Supporting it would require us to somehow + # handle the hex-encoded BER encoding of the data. We'll stay away from + # this mess for now. + raise ValueError('Hex-encoded RDN assertion values are not supported') + escaped = False + hexdigit = None + encoded_value = b'' + for char in escaped_value: + if hexdigit is not None: + encoded_value += bytes.fromhex('%s%s'%(hexdigit, char)) + hexdigit = None + elif escaped: + if ord(char) in DN_SPECIAL + (b'\\'[0],): + encoded_value += char.encode('utf8') + elif ord(char) in HEXDIGITS: + hexdigit = char + else: + raise ValueError('Invalid escape: \\%s'%char) + escaped = False + elif char == '\\': + escaped = True + else: + encoded_value += char.encode('utf8') + try: + attribute_type = schema.attribute_types[attribute] + except KeyError as exc: + raise ValueError('Invalid RDN attribute type: Attribute type undefined in schema') from exc + value_template = attribute_type.syntax.definition.parse_template(schema, encoded_value.decode()) + return cls(schema, attribute, value_template) + + def render(self, params): + return RDNAssertion(self.schema, self.attribute, self.value_template.render(params)) + + def match(self, rdnassertion): + if self.attribute_type != rdnassertion.attribute_type: + return False + return self.attribute_type.equality.definition.match_equal_template(self.schema, [self.value_template], rdnassertion.value)