diff --git a/ldapserver/ldap.py b/ldapserver/ldap.py index 96262ac668f4422bb33f07362eeb502de1b57716..639d715abc76f9c74f81dfd7bee3176e570573cc 100644 --- a/ldapserver/ldap.py +++ b/ldapserver/ldap.py @@ -683,11 +683,14 @@ class ShallowLDAPMessage(asn1.BERType): raise TypeError() return obj.data -# Extended Operation Values +# StartTLS Extended Operation (RFC4511) +STARTTLS_OID = '1.3.6.1.4.1.1466.20037' -EXT_STARTTLS_OID = '1.3.6.1.4.1.1466.20037' -EXT_WHOAMI_OID = '1.3.6.1.4.1.4203.1.11.3' -EXT_PASSWORD_MODIFY_OID = '1.3.6.1.4.1.4203.1.11.1' +# "Who am I?" Extended Operation (RFC4532) +WHOAMI_OID = '1.3.6.1.4.1.4203.1.11.3' + +# Password Modify Extended Operation (RFC3062) +PASSWORD_MODIFY_OID = '1.3.6.1.4.1.4203.1.11.1' class PasswdModifyRequestValue(asn1.Sequence): SEQUENCE_FIELDS = [ @@ -706,3 +709,15 @@ class PasswdModifyResponseValue(asn1.Sequence): ] genPasswd: bytes + +# LDAP Control Extension for Simple Paged Results Manipulation (RFC2696) +PAGED_RESULTS_OID = '1.2.840.113556.1.4.319' + +class PagedResultsValue(asn1.Sequence): + SEQUENCE_FIELDS = [ + (asn1.Integer, 'size', 0, False), + (asn1.OctetString, 'cookie', b'', False) + ] + + size: int + cookie: bytes diff --git a/ldapserver/server.py b/ldapserver/server.py index a629a5c5929c990397e1eb65b8d93063b34cd202..4c1235bc2455d5cbee0fc25a52bf95a57c31c021 100644 --- a/ldapserver/server.py +++ b/ldapserver/server.py @@ -6,15 +6,38 @@ import logging import time import random import string +import itertools from . import asn1, exceptions, ldap, schema, objects from .dn import DN +def pop_control(controls, oid): + result = None + remaining_controls = [] + for control in controls or []: + if control.controlType == oid: + result = control + break + remaining_controls.append(control) + return result, remaining_controls + def reject_critical_controls(controls=None): for control in controls or []: if control.criticality: raise exceptions.LDAPUnavailableCriticalExtension() +def mark_last(iterable): + '''Yield (item, is_last) for all items in iterable + + is_last is True for the last items and False for other items.''' + prev_item = None + for item in iterable: + if prev_item is not None: + yield prev_item, False + prev_item = item + if prev_item is not None: + yield prev_item, True + class RequestLogAdapter(logging.LoggerAdapter): def process(self, msg, kwargs): return self.extra['trace_id'] + ': ' + msg, kwargs @@ -169,31 +192,46 @@ class LDAPRequestHandler(BaseLDAPRequestHandler): self.rootdse['objectClass'] = ['top'] self.rootdse['supportedSASLMechanisms'] = self.get_sasl_mechanisms self.rootdse['supportedExtension'] = self.get_extentions + self.rootdse['supportedControl'] = self.get_controls self.rootdse['supportedLDAPVersion'] = ['3'] self.bind_object = None self.bind_sasl_state = None + self.__paged_searches = {} # pagination cookie -> (iterator, orig_op) + self.__paged_cookie_counter = 0 def get_extentions(self): '''Get supported LDAP extentions :returns: OIDs of supported LDAP extentions - :rtype: list of bytes objects + :rtype: list of strings Called whenever the root DSE attribute "supportedExtension" is queried.''' res = [] if self.supports_starttls: - res.append(ldap.EXT_STARTTLS_OID) + res.append(ldap.STARTTLS_OID) if self.supports_whoami: - res.append(ldap.EXT_WHOAMI_OID) + res.append(ldap.WHOAMI_OID) if self.supports_password_modify: - res.append(ldap.EXT_PASSWORD_MODIFY_OID) + res.append(ldap.PASSWORD_MODIFY_OID) + return res + + def get_controls(self): + '''Get supported LDAP controls + + :returns: OIDs of supported LDAP controls + :rtype: list of strings + + Called whenever the root DSE attribute "supportedControl" is queried.''' + res = [] + if self.supports_paged_results: + res.append(ldap.PAGED_RESULTS_OID) return res def get_sasl_mechanisms(self): '''Get supported SASL mechanisms :returns: Names of supported SASL mechanisms - :rtype: list of bytes objects + :rtype: list of strings SASL mechanism name are typically all-caps, like "EXTERNAL". @@ -418,11 +456,57 @@ class LDAPRequestHandler(BaseLDAPRequestHandler): :any:`LDAPAuthMethodNotSupported` exception.''' raise exceptions.LDAPAuthMethodNotSupported() + supports_paged_results = True + + def __handle_search_paged(self, op, paged_control, controls=None): + def build_control(size=0, cookie=b''): + value = ldap.PagedResultsValue(size=size, cookie=cookie) + return ldap.Control(controlType=ldap.PAGED_RESULTS_OID, + criticality=True, controlValue=bytes(value)) + + # pylint: disable=no-member + paged_control = ldap.PagedResultsValue.from_ber(paged_control.controlValue)[0] + if not paged_control.cookie: # New paged search request + results = self.do_search(op.baseObject, op.scope, op.filter) + results = filter(lambda obj: obj.match_search(op.baseObject, op.scope, op.filter), results) + iterator = iter(mark_last(results)) + else: # Continue existing paged search + try: + iterator, orig_op = self.__paged_searches.pop(paged_control.cookie) + except KeyError as exc: + raise exceptions.LDAPUnwillingToPerform('Invalid pagination cookie') from exc + if ldap.ProtocolOp.to_ber(orig_op) != ldap.ProtocolOp.to_ber(op): + raise exceptions.LDAPUnwillingToPerform('Search parameter mismatch') + if not paged_control.size: # Cancel paged search + yield ldap.SearchResultDone(ldap.LDAPResultCode.success), [build_control()] + return + is_last = True + entries = 0 + time_start = time.perf_counter() + for obj, is_last in itertools.islice(iterator, 0, paged_control.size): + self.logger.debug('SEARCH entry %r', obj) + yield obj.get_search_result_entry(op.attributes, op.typesOnly) + entries += 1 + cookie = b'' + if not is_last: + cookie = str(self.__paged_cookie_counter).encode() + self.__paged_cookie_counter += 1 + self.__paged_searches[cookie] = iterator, op + yield ldap.SearchResultDone(ldap.LDAPResultCode.success), [build_control(cookie=cookie)] + time_end = time.perf_counter() + self.logger.info('SEARCH done page cookie=%r entries=%d duration_seconds=%.3f', cookie, entries, time_end - time_start) + def handle_search(self, op, controls=None): self.logger.info('SEARCH request dn=%r dn_scope=%s filter=%r, attributes=%r', op.baseObject, op.scope.name, op.filter.get_filter_string(), ' '.join(op.attributes)) + paged_control = None + if self.supports_paged_results: + paged_control, controls = pop_control(controls, ldap.PAGED_RESULTS_OID) reject_critical_controls(controls) + if paged_control: + yield from self.__handle_search_paged(op, paged_control, controls) + return entries = 0 time_start = time.perf_counter() for obj in self.do_search(op.baseObject, op.scope, op.filter): @@ -499,21 +583,21 @@ class LDAPRequestHandler(BaseLDAPRequestHandler): def handle_extended(self, op, controls=None): reject_critical_controls(controls) - if op.requestName == ldap.EXT_STARTTLS_OID and self.supports_starttls: + if op.requestName == ldap.STARTTLS_OID and self.supports_starttls: self.logger.info('EXTENDED STARTTLS') # StartTLS (RFC 4511) - yield ldap.ExtendedResponse(ldap.LDAPResultCode.success, responseName=ldap.EXT_STARTTLS_OID) + yield ldap.ExtendedResponse(ldap.LDAPResultCode.success, responseName=ldap.STARTTLS_OID) try: self.do_starttls() except Exception: # pylint: disable=broad-except traceback.print_exc() self.keep_running = False - elif op.requestName == ldap.EXT_WHOAMI_OID and self.supports_whoami: + elif op.requestName == ldap.WHOAMI_OID and self.supports_whoami: self.logger.info('EXTENDED WHOAMI') # "Who am I?" Operation (RFC 4532) identity = (self.do_whoami() or '').encode() yield ldap.ExtendedResponse(ldap.LDAPResultCode.success, responseValue=identity) - elif op.requestName == ldap.EXT_PASSWORD_MODIFY_OID and self.supports_password_modify: + elif op.requestName == ldap.PASSWORD_MODIFY_OID and self.supports_password_modify: self.logger.info('EXTENDED PASSWORD_MODIFY') # Password Modify Extended Operation (RFC 3062) newpw = None