diff --git a/ldapserver/server.py b/ldapserver/server.py index 03789fc7bbb5d26dd26a680de2c59621dbbbe80e..d20ed66f6ec495155e17588f9e2ac553d66893be 100644 --- a/ldapserver/server.py +++ b/ldapserver/server.py @@ -39,6 +39,12 @@ def mark_last(iterable): if prev_item is not None: yield prev_item, True +def enforce_size_limit(iterable, limit): + for index, item in enumerate(iterable): + if index >= limit: + raise exceptions.LDAPSizeLimitExceeded() + yield item + class RequestLogAdapter(logging.LoggerAdapter): def process(self, msg, kwargs): return self.extra['trace_id'] + ': ' + msg, kwargs @@ -427,7 +433,10 @@ class LDAPRequestHandler(BaseLDAPRequestHandler): results = self.do_search(op.baseObject, op.scope, op.filter) results = map(lambda obj: obj.search(op.baseObject, op.scope, op.filter, op.attributes, op.typesOnly), results) results = filter(None, results) - iterator = iter(mark_last(results)) + results = mark_last(results) + if op.sizeLimit: + results = enforce_size_limit(results, op.sizeLimit) + iterator = iter(results) else: # Continue existing paged search try: iterator, orig_op = self.__paged_searches.pop(paged_control.cookie) @@ -470,6 +479,8 @@ class LDAPRequestHandler(BaseLDAPRequestHandler): for obj in self.do_search(op.baseObject, op.scope, op.filter): entry = obj.search(op.baseObject, op.scope, op.filter, op.attributes, op.typesOnly) if entry: + if op.sizeLimit and result_count >= op.sizeLimit: + raise exceptions.LDAPSizeLimitExceeded() self.logger.debug('SEARCH entry %r', entry) result_count += 1 yield entry