Skip to content
Snippets Groups Projects
Commit add1827b authored by Julian Rother's avatar Julian Rother
Browse files

Refactored server code

parent 1885d7a7
No related branches found
No related tags found
No related merge requests found
Pipeline #7068 passed
...@@ -5,18 +5,9 @@ import typing ...@@ -5,18 +5,9 @@ import typing
from . import asn1, exceptions, ldap, schema, directory from . import asn1, exceptions, ldap, schema, directory
def decode_msg(shallowmsg): def reject_critical_controls(controls=None):
try:
return shallowmsg.decode()[0]
except Exception as e:
traceback.print_exc()
raise exceptions.LDAPProtocolError() from e
def reject_critical_controls(controls=None, supported_oids=tuple()):
for control in controls or []: for control in controls or []:
if not control.criticality: if control.criticality:
continue
if control.controlType not in supported_oids:
raise exceptions.LDAPUnavailableCriticalExtension() raise exceptions.LDAPUnavailableCriticalExtension()
class BaseLDAPRequestHandler(socketserver.BaseRequestHandler): class BaseLDAPRequestHandler(socketserver.BaseRequestHandler):
...@@ -25,7 +16,6 @@ class BaseLDAPRequestHandler(socketserver.BaseRequestHandler): ...@@ -25,7 +16,6 @@ class BaseLDAPRequestHandler(socketserver.BaseRequestHandler):
self.keep_running = True self.keep_running = True
def handle(self): def handle(self):
self.on_connect()
buf = b'' buf = b''
while self.keep_running: while self.keep_running:
try: try:
...@@ -33,14 +23,12 @@ class BaseLDAPRequestHandler(socketserver.BaseRequestHandler): ...@@ -33,14 +23,12 @@ class BaseLDAPRequestHandler(socketserver.BaseRequestHandler):
for respmsg in self.handle_message(shallowmsg): for respmsg in self.handle_message(shallowmsg):
self.request.sendall(ldap.LDAPMessage.to_ber(respmsg)) self.request.sendall(ldap.LDAPMessage.to_ber(respmsg))
except asn1.IncompleteBERError: except asn1.IncompleteBERError:
chunk = self.request.recv(5) chunk = self.request.recv(4096)
if not chunk: if not chunk:
self.keep_running = False self.keep_running = False
self.on_disconnect()
self.request.close() self.request.close()
else: else:
buf += chunk buf += chunk
self.on_disconnect()
self.request.close() self.request.close()
def handle_message(self, shallowmsg: ldap.ShallowLDAPMessage) -> typing.Iterable[ldap.LDAPMessage]: def handle_message(self, shallowmsg: ldap.ShallowLDAPMessage) -> typing.Iterable[ldap.LDAPMessage]:
...@@ -66,38 +54,23 @@ class BaseLDAPRequestHandler(socketserver.BaseRequestHandler): ...@@ -66,38 +54,23 @@ class BaseLDAPRequestHandler(socketserver.BaseRequestHandler):
if handler is None: if handler is None:
raise exceptions.LDAPProtocolError() raise exceptions.LDAPProtocolError()
try: try:
msg = decode_msg(shallowmsg) msg = shallowmsg.decode()[0]
except ValueError as e: except ValueError as e:
self.on_recv_invalid(shallowmsg) self.on_recv_invalid(shallowmsg)
raise exceptions.LDAPProtocolError() from e raise exceptions.LDAPProtocolError() from e
self.on_recv(msg)
for args in handler(msg.protocolOp, msg.controls): for args in handler(msg.protocolOp, msg.controls):
response, controls = args if isinstance(args, tuple) else (args, None) response, controls = args if isinstance(args, tuple) else (args, None)
yield ldap.LDAPMessage(shallowmsg.messageID, response, controls) yield ldap.LDAPMessage(shallowmsg.messageID, response, controls)
except exceptions.LDAPError as e: except exceptions.LDAPError as e:
if response_type is not None: if response_type is not None:
respmsg = ldap.LDAPMessage(shallowmsg.messageID, response_type(e.code, diagnosticMessage=e.message)) respmsg = ldap.LDAPMessage(shallowmsg.messageID, response_type(e.code, diagnosticMessage=e.message))
self.on_send(respmsg)
yield respmsg yield respmsg
except Exception as e: # pylint: disable=broad-except except Exception as e: # pylint: disable=broad-except
if response_type is not None: if response_type is not None:
respmsg = ldap.LDAPMessage(shallowmsg.messageID, response_type(ldap.LDAPResultCode.other)) respmsg = ldap.LDAPMessage(shallowmsg.messageID, response_type(ldap.LDAPResultCode.other))
self.on_send(respmsg)
yield respmsg yield respmsg
self.on_exception(e) self.on_exception(e)
def on_connect(self):
pass
def on_disconnect(self):
pass
def on_send(self, msg):
pass
def on_recv(self, msg):
pass
def on_recv_invalid(self, shallowmsg): def on_recv_invalid(self, shallowmsg):
pass pass
...@@ -105,7 +78,6 @@ class BaseLDAPRequestHandler(socketserver.BaseRequestHandler): ...@@ -105,7 +78,6 @@ class BaseLDAPRequestHandler(socketserver.BaseRequestHandler):
traceback.print_exc() traceback.print_exc()
def handle_bind(self, op: ldap.BindRequest, controls=None) -> typing.Iterable[ldap.ProtocolOp]: def handle_bind(self, op: ldap.BindRequest, controls=None) -> typing.Iterable[ldap.ProtocolOp]:
'''Handle bind bla bla bla'''
reject_critical_controls(controls) reject_critical_controls(controls)
raise exceptions.LDAPAuthMethodNotSupported() raise exceptions.LDAPAuthMethodNotSupported()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment