diff --git a/server.py b/server.py index a4dc39a3b9adedad8c5ab12dbef6a6273aeb35fc..5e203e8f85d6483e0981e6391150aca989b41bb2 100644 --- a/server.py +++ b/server.py @@ -45,6 +45,10 @@ class LDAPNoSuchObject(LDAPError): def __init__(self, message=''): super().__init__(LDAPResultCode.noSuchObject, message) +class LDAPUnavailableCriticalExtension(LDAPError): + def __init__(self, message=''): + super().__init__(LDAPResultCode.unavailableCriticalExtension, message) + def decode_msg(shallowmsg): try: return shallowmsg.decode()[0] @@ -103,6 +107,13 @@ class RootDSE(AttributeDict): attrs[name] = [encode_attribute(value) for value in values] return [('', attrs)] +def check_controls(controls=None, supported_oids=[]): + for control in controls or []: + if not control.criticality: + continue + if control.controlType not in supported_oids: + raise LDAPUnavailableCriticalExtension() + class LDAPRequestHandler(BaseRequestHandler): ssl_context = None @@ -141,6 +152,7 @@ class LDAPRequestHandler(BaseRequestHandler): return res def handle_bind(self, req): + check_controls(req.controls) op = req.protocolOp if op.version != 3: raise LDAPProtocolError('Unsupported protocol version') @@ -401,6 +413,7 @@ class LDAPRequestHandler(BaseRequestHandler): return bind_obj, response def handle_search(self, req): + check_controls(req.controls) search = req.protocolOp for dn, attributes in self.do_search(search.baseObject, search.scope, search.filter): attributes = [PartialAttribute(name, values) for name, values in attributes.items()] @@ -426,9 +439,11 @@ class LDAPRequestHandler(BaseRequestHandler): return self.rootdse.search(baseobj, scope, filter) def handle_unbind(self, req): + check_controls(req.controls) self.keep_running = False def handle_extended(self, req): + check_controls(req.controls) op = req.protocolOp if op.requestName == '1.3.6.1.4.1.1466.20037': # StartTLS (RFC 4511) @@ -482,22 +497,27 @@ class LDAPRequestHandler(BaseRequestHandler): raise LDAPUnwillingToPerform('Password change is not supported') def handle_modify(self, req): + check_controls(req.controls) raise LDAPInsufficientAccessRights() def handle_add(self, req): + check_controls(req.controls) raise LDAPInsufficientAccessRights() def handle_delete(self, req): + check_controls(req.controls) raise LDAPInsufficientAccessRights() def handle_modifydn(self, req): + check_controls(req.controls) raise LDAPInsufficientAccessRights() def handle_compare(self, req): + check_controls(req.controls) raise LDAPInsufficientAccessRights() def handle_abandon(self, req): - pass + check_controls(req.controls) def handle_message(self, shallowmsg): msgtypes = {