diff --git a/db.py b/db.py index ddace7a9ad01dbaefef7237543a193c5177c65d3..1f35511d7ce9172a452706fcde7246774da8da46 100644 --- a/db.py +++ b/db.py @@ -284,7 +284,7 @@ sqlstore.register_model( ssl_context = SSLContext() ssl_context.load_cert_chain('devcert.crt', 'devcert.key') -class RequestHandler(LDAPRequestHandler): +class RequestHandler(SimpleLDAPRequestHandler): ssl_context = ssl_context supports_sasl_digest_md5 = True diff --git a/docs/api.rst b/docs/api.rst index 9aad5b4d182198824b93a49a19b32b9fe0f47884..456b9fdaaa285423e9b4047cd4c9800340fac211 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -4,7 +4,10 @@ API Reference Request Handler --------------- -.. autoclass:: ldapserver.LDAPRequestHandler +.. autoclass:: ldapserver.BaseLDAPRequestHandler + :members: + +.. autoclass:: ldapserver.SimpleLDAPRequestHandler :members: Object Stores @@ -41,7 +44,7 @@ LDAP response messages carry a result code and an optional diagnostic message. The subclasses of :any:`ldapserver.LDAPError` represent the possible (non-success) result codes. Raising a :any:`ldapserver.LDAPError` instance in a handler method of -:any:`ldapserver.LDAPRequestHandler` cases the appropriate response message to be +:any:`ldapserver.BaseLDAPRequestHandler` cases the appropriate response message to be sent with the corresponding result code and diagnostic message. .. autoexception:: ldapserver.LDAPError diff --git a/ldapserver/__init__.py b/ldapserver/__init__.py index 81a829d58c92a7fc20dd9a76035b826d7f22faaa..c2db8d9866ecc5fd8b2b319e8d5e893a6bda2e00 100644 --- a/ldapserver/__init__.py +++ b/ldapserver/__init__.py @@ -3,4 +3,4 @@ from . import dn from . import sasl from .exceptions import * -from .server import LDAPRequestHandler +from .server import BaseLDAPRequestHandler, SimpleLDAPRequestHandler diff --git a/ldapserver/server.py b/ldapserver/server.py index 83a85c08677c90dda33f88f002f8577301b028a6..464f2e27053bca88e59ffad82aae19ae29123b85 100644 --- a/ldapserver/server.py +++ b/ldapserver/server.py @@ -77,14 +77,140 @@ class RootDSE(AttributeDict): attrs[name] = [encode_attribute(value) for value in values] return [('', attrs)] -def check_controls(controls=None, supported_oids=[]): +def reject_critical_controls(controls=None, supported_oids=[]): for control in controls or []: if not control.criticality: continue if control.controlType not in supported_oids: raise LDAPUnavailableCriticalExtension() -class LDAPRequestHandler(BaseRequestHandler): +class BaseLDAPRequestHandler(BaseRequestHandler): + def setup(self): + super().setup() + self.keep_running = True + + def handle(self): + self.on_connect() + buf = b'' + while self.keep_running: + try: + shallowmsg, buf = ShallowLDAPMessage.from_ber(buf) + for respmsg in self.handle_message(shallowmsg): + self.request.sendall(LDAPMessage.to_ber(respmsg)) + except IncompleteBERError: + chunk = self.request.recv(5) + if not chunk: + self.keep_running = False + self.on_disconnect() + self.request.close() + else: + buf += chunk + self.on_disconnect() + self.request.close() + + def handle_message(self, shallowmsg): + '''Handle an LDAP request + + :param shallowmsg: Half-decoded LDAP message to handle + :type shallowmsg: ShallowLDAPMessage + :returns: Response messages + :rtype: iterable of LDAPMessage objects + ''' + msgtypes = { + BindRequest: (self.handle_bind, BindResponse), + UnbindRequest: (self.handle_unbind, None), + SearchRequest: (self.handle_search, SearchResultDone), + ModifyRequest: (self.handle_modify, ModifyResponse), + AddRequest: (self.handle_add, AddResponse), + DelRequest: (self.handle_delete, DelResponse), + ModifyDNRequest: (self.handle_modifydn, ModifyDNResponse), + CompareRequest: (self.handle_compare, CompareResponse), + AbandonRequest: (self.handle_abandon, None), + ExtendedRequest: (self.handle_extended, ExtendedResponse), + } + handler, response_type = msgtypes.get(shallowmsg.protocolOpType, (None, None)) + try: + if handler is None: + raise LDAPProtocolError() + try: + msg = decode_msg(shallowmsg) + except ValueError: + self.on_recv_invalid(shallowmsg) + raise LDAPProtocolError() + self.on_recv(msg) + for args in handler(msg.protocolOp, msg.controls): + response, controls = args if isinstance(args, tuple) else (args, None) + yield LDAPMessage(shallowmsg.messageID, response, controls) + except LDAPError as e: + if response_type is not None: + respmsg = LDAPMessage(shallowmsg.messageID, response_type(e.code, diagnosticMessage=e.message)) + self.on_send(respmsg) + yield respmsg + except Exception as e: + if response_type is not None: + respmsg = LDAPMessage(shallowmsg.messageID, response_type(LDAPResultCode.other)) + self.on_send(respmsg) + yield respmsg + self.on_exception(e) + + def on_connect(self): + print('connected') + + def on_disconnect(self): + print('disconnected') + + def on_send(self, msg): + print('sending', msg) + + def on_recv(self, msg): + print('received', msg) + + def on_recv_invalid(self, shallowmsg): + print('received invalid', shallowmsg) + + def on_exception(self, e): + traceback.print_exc() + + def handle_bind(self, op, controls=None): + reject_critical_controls(controls) + raise LDAPAuthMethodNotSupported() + + def handle_unbind(self, op, controls=None): + reject_critical_controls(controls) + self.keep_running = False + + def handle_search(self, op, controls=None): + reject_critical_controls(controls) + yield SearchResultDone(LDAPResultCode.success) + + def handle_modify(self, op, controls=None): + reject_critical_controls(controls) + raise LDAPInsufficientAccessRights() + + def handle_add(self, op, controls=None): + reject_critical_controls(controls) + raise LDAPInsufficientAccessRights() + + def handle_delete(self, op, controls=None): + reject_critical_controls(controls) + raise LDAPInsufficientAccessRights() + + def handle_modifydn(self, op, controls=None): + reject_critical_controls(controls) + raise LDAPInsufficientAccessRights() + + def handle_compare(self, op, controls=None): + reject_critical_controls(controls) + raise LDAPInsufficientAccessRights() + + def handle_abandon(self, op, controls=None): + reject_critical_controls(controls) + + def handle_extended(self, op, controls=None): + reject_critical_controls(controls) + raise LDAPProtocolError() + +class SimpleLDAPRequestHandler(BaseLDAPRequestHandler): ''' .. py:attribute:: rootdse :type: ldapserver.server.RootDSE @@ -100,7 +226,6 @@ class LDAPRequestHandler(BaseRequestHandler): self.rootdse['supportedSASLMechanisms'] = self.get_sasl_mechanisms self.rootdse['supportedExtension'] = self.get_extentions self.rootdse['supportedLDAPVersion'] = [b'3'] - self.keep_running = True self.bind_object = None self.bind_sasl_state = None @@ -140,9 +265,8 @@ class LDAPRequestHandler(BaseRequestHandler): res.append(b'DIGEST-MD5') return res - def handle_bind(self, req): - check_controls(req.controls) - op = req.protocolOp + def handle_bind(self, op, controls=None): + reject_critical_controls(controls) if op.version != 3: raise LDAPProtocolError('Unsupported protocol version') auth = op.authentication @@ -158,18 +282,18 @@ class LDAPRequestHandler(BaseRequestHandler): except StopIteration as e: resp_code = LDAPResultCode.success self.bind_object, resp = e.value - self.send_msg(LDAPMessage(req.messageID, BindResponse(resp_code, serverSaslCreds=resp))) + yield BindResponse(resp_code, serverSaslCreds=resp) return # If auth type or SASL method changed, abort SASL dialog self.bind_sasl_state = None if isinstance(auth, SimpleAuthentication): self.bind_object = self.do_bind_simple(op.name, auth.password) - self.send_msg(LDAPMessage(req.messageID, BindResponse(LDAPResultCode.success))) + yield BindResponse(LDAPResultCode.success) elif isinstance(auth, SaslCredentials): ret = self.do_bind_sasl(auth.mechanism, auth.credentials) if isinstance(ret, tuple): self.bind_object, resp = ret - self.send_msg(LDAPMessage(req.messageID, BindResponse(LDAPResultCode.success, serverSaslCreds=resp))) + yield BindResponse(LDAPResultCode.success, serverSaslCreds=resp) return iterator = iter(ret) resp_code = LDAPResultCode.saslBindInProgress @@ -179,9 +303,9 @@ class LDAPRequestHandler(BaseRequestHandler): except StopIteration as e: resp_code = LDAPResultCode.success self.bind_object, resp = e.value - self.send_msg(LDAPMessage(req.messageID, BindResponse(resp_code, serverSaslCreds=resp))) + yield BindResponse(resp_code, serverSaslCreds=resp) else: - raise LDAPAuthMethodNotSupported() + yield from super().handle_bind(op, controls) def do_bind_simple(self, dn='', password=b''): '''Do LDAP BIND with simple authentication @@ -392,13 +516,11 @@ class LDAPRequestHandler(BaseRequestHandler): :any:`LDAPInvalidCredentials` exception.''' raise LDAPInvalidCredentials() - def handle_search(self, req): - check_controls(req.controls) - search = req.protocolOp - for dn, attributes in self.do_search(search.baseObject, search.scope, search.filter): + def handle_search(self, op, controls=None): + for dn, attributes in self.do_search(op.baseObject, op.scope, op.filter): attributes = [PartialAttribute(name, values) for name, values in attributes.items()] - self.send_msg(LDAPMessage(req.messageID, SearchResultEntry(dn, attributes))) - self.send_msg(LDAPMessage(req.messageID, SearchResultDone(LDAPResultCode.success))) + yield SearchResultEntry(dn, attributes) + yield SearchResultDone(LDAPResultCode.success) def do_search(self, baseobj, scope, filter): '''Do LDAP SEARCH operation @@ -418,16 +540,16 @@ class LDAPRequestHandler(BaseRequestHandler): The default implementation always returns an empty list.''' return self.rootdse.search(baseobj, scope, filter) - def handle_unbind(self, req): - check_controls(req.controls) + def handle_unbind(self, op, controls=None): + reject_critical_controls(controls) self.keep_running = False + return [] - def handle_extended(self, req): - check_controls(req.controls) - op = req.protocolOp + def handle_extended(self, op, controls=None): + reject_critical_controls(controls) if op.requestName == EXT_STARTTLS_OID and self.supports_starttls: # StartTLS (RFC 4511) - self.send_msg(LDAPMessage(req.messageID, ExtendedResponse(LDAPResultCode.success, responseName=EXT_STARTTLS_OID))) + yield ExtendedResponse(LDAPResultCode.success, responseName=EXT_STARTTLS_OID) try: self.do_starttls() except Exception as e: @@ -436,7 +558,7 @@ class LDAPRequestHandler(BaseRequestHandler): elif op.requestName == EXT_WHOAMI_OID and self.supports_whoami: # "Who am I?" Operation (RFC 4532) identity = (self.do_whoami() or '').encode() - self.send_msg(LDAPMessage(req.messageID, ExtendedResponse(LDAPResultCode.success, responseValue=identity))) + yield ExtendedResponse(LDAPResultCode.success, responseValue=identity) elif op.requestName == EXT_PASSWORD_MODIFY_OID and self.supports_password_modify: # Password Modify Extended Operation (RFC 3062) newpw = None @@ -446,12 +568,12 @@ class LDAPRequestHandler(BaseRequestHandler): decoded, _ = PasswdModifyRequestValue.from_ber(op.requestValue) newpw = self.do_passwd(decoded.userIdentity, decoded.oldPasswd, decoded.newPasswd) if newpw is None: - self.send_msg(LDAPMessage(req.messageID, ExtendedResponse(LDAPResultCode.success))) + yield ExtendedResponse(LDAPResultCode.success) else: encoded = PasswdModifyResponseValue.to_ber(PasswdModifyResponseValue(newpw)) - self.send_msg(LDAPMessage(req.messageID, ExtendedResponse(LDAPResultCode.success, responseValue=encoded))) + yield ExtendedResponse(LDAPResultCode.success, responseValue=encoded) else: - raise LDAPProtocolError() + yield from super().handle_extended(op, controls) #: :any:`ssl.SSLContext` for StartTLS ssl_context = None @@ -504,75 +626,3 @@ class LDAPRequestHandler(BaseRequestHandler): Called by `handle_extended()` if :any:`supports_password_modify` is True. The default implementation always raises an :any:`LDAPUnwillingToPerform` error.''' raise LDAPUnwillingToPerform() - - def handle_modify(self, req): - check_controls(req.controls) - raise LDAPInsufficientAccessRights() - - def handle_add(self, req): - check_controls(req.controls) - raise LDAPInsufficientAccessRights() - - def handle_delete(self, req): - check_controls(req.controls) - raise LDAPInsufficientAccessRights() - - def handle_modifydn(self, req): - check_controls(req.controls) - raise LDAPInsufficientAccessRights() - - def handle_compare(self, req): - check_controls(req.controls) - raise LDAPInsufficientAccessRights() - - def handle_abandon(self, req): - check_controls(req.controls) - - def handle_message(self, shallowmsg): - msgtypes = { - BindRequest: (self.handle_bind, BindResponse), - UnbindRequest: (self.handle_unbind, None), - SearchRequest: (self.handle_search, SearchResultDone), - ModifyRequest: (None, ModifyResponse), - AddRequest: (None, AddResponse), - DelRequest: (None, DelResponse), - ModifyDNRequest: (None, ModifyDNResponse), - CompareRequest: (None, CompareResponse), - AbandonRequest: (None, None), - ExtendedRequest: (self.handle_extended, ExtendedResponse), - } - handler, response = msgtypes.get(shallowmsg.protocolOpType, (None, None)) - try: - if handler is None: - raise LDAPProtocolError() - try: - msg = decode_msg(shallowmsg) - except ValueError: - raise LDAPProtocolError() - print('recved', msg) - handler(msg) - except LDAPError as e: - if response is not None: - self.send_msg(LDAPMessage(shallowmsg.messageID, response(e.code, diagnosticMessage=e.message))) - except Exception as e: - if response is not None: - self.send_msg(LDAPMessage(shallowmsg.messageID, response(LDAPResultCode.other))) - traceback.print_exc() - - def send_msg(self, msg): - print('sending', msg) - self.request.sendall(LDAPMessage.to_ber(msg)) - - def handle(self): - buf = b'' - while self.keep_running: - try: - shallowmsg, buf = ShallowLDAPMessage.from_ber(buf) - self.handle_message(shallowmsg) - except IncompleteBERError: - chunk = self.request.recv(5) - if not chunk: - self.keep_running = False - return None - buf += chunk -