diff --git a/ldapserver/server.py b/ldapserver/server.py index 9c342801b33ec928a40524fbdb732ef3d9cd4e80..a69721382e144156432bad3a9c299cb394577fab 100644 --- a/ldapserver/server.py +++ b/ldapserver/server.py @@ -5,18 +5,9 @@ import typing from . import asn1, exceptions, ldap, schema, directory -def decode_msg(shallowmsg): - try: - return shallowmsg.decode()[0] - except Exception as e: - traceback.print_exc() - raise exceptions.LDAPProtocolError() from e - -def reject_critical_controls(controls=None, supported_oids=tuple()): +def reject_critical_controls(controls=None): for control in controls or []: - if not control.criticality: - continue - if control.controlType not in supported_oids: + if control.criticality: raise exceptions.LDAPUnavailableCriticalExtension() class BaseLDAPRequestHandler(socketserver.BaseRequestHandler): @@ -25,7 +16,6 @@ class BaseLDAPRequestHandler(socketserver.BaseRequestHandler): self.keep_running = True def handle(self): - self.on_connect() buf = b'' while self.keep_running: try: @@ -33,14 +23,12 @@ class BaseLDAPRequestHandler(socketserver.BaseRequestHandler): for respmsg in self.handle_message(shallowmsg): self.request.sendall(ldap.LDAPMessage.to_ber(respmsg)) except asn1.IncompleteBERError: - chunk = self.request.recv(5) + chunk = self.request.recv(4096) if not chunk: self.keep_running = False - self.on_disconnect() self.request.close() else: buf += chunk - self.on_disconnect() self.request.close() def handle_message(self, shallowmsg: ldap.ShallowLDAPMessage) -> typing.Iterable[ldap.LDAPMessage]: @@ -66,38 +54,23 @@ class BaseLDAPRequestHandler(socketserver.BaseRequestHandler): if handler is None: raise exceptions.LDAPProtocolError() try: - msg = decode_msg(shallowmsg) + msg = shallowmsg.decode()[0] except ValueError as e: self.on_recv_invalid(shallowmsg) raise exceptions.LDAPProtocolError() from e - self.on_recv(msg) for args in handler(msg.protocolOp, msg.controls): response, controls = args if isinstance(args, tuple) else (args, None) yield ldap.LDAPMessage(shallowmsg.messageID, response, controls) except exceptions.LDAPError as e: if response_type is not None: respmsg = ldap.LDAPMessage(shallowmsg.messageID, response_type(e.code, diagnosticMessage=e.message)) - self.on_send(respmsg) yield respmsg except Exception as e: # pylint: disable=broad-except if response_type is not None: respmsg = ldap.LDAPMessage(shallowmsg.messageID, response_type(ldap.LDAPResultCode.other)) - self.on_send(respmsg) yield respmsg self.on_exception(e) - def on_connect(self): - pass - - def on_disconnect(self): - pass - - def on_send(self, msg): - pass - - def on_recv(self, msg): - pass - def on_recv_invalid(self, shallowmsg): pass @@ -105,7 +78,6 @@ class BaseLDAPRequestHandler(socketserver.BaseRequestHandler): traceback.print_exc() def handle_bind(self, op: ldap.BindRequest, controls=None) -> typing.Iterable[ldap.ProtocolOp]: - '''Handle bind bla bla bla''' reject_critical_controls(controls) raise exceptions.LDAPAuthMethodNotSupported()