From 58540c20361638652688786bc2ddb67c07cc248a Mon Sep 17 00:00:00 2001 From: Julian Rother <julian@jrother.eu> Date: Mon, 8 Mar 2021 22:43:26 +0100 Subject: [PATCH] Added SASL DIGEST-MD5 support --- ldap.py | 51 ++++++++++--- sasl.py | 48 ++++++++++++ server.py | 223 ++++++++++++++++++++++++++++++++++++++++++++---------- 3 files changed, 272 insertions(+), 50 deletions(-) create mode 100644 sasl.py diff --git a/ldap.py b/ldap.py index eadc75e..a72fefa 100644 --- a/ldap.py +++ b/ldap.py @@ -442,6 +442,13 @@ class SimpleAuthentication(Wrapper, AuthenticationChoice): return '<%s(EMPTY PASSWORD)>'%(type(self).__name__) return '<%s(PASSWORD HIDDEN)>'%(type(self).__name__) +class SaslCredentials(Sequence, AuthenticationChoice): + ber_tag = (2, True, 3) + sequence_fields = [ + (LDAPString, 'mechanism', None, False), + (OctetString, 'credentials', None, True), + ] + class AttributeValueSet(Set): set_type = OctetString @@ -475,8 +482,14 @@ class BindRequest(Sequence, ProtocolOp): (AuthenticationChoice, 'authentication', lambda: SimpleAuthentication(), False) ] -class BindResponse(LDAPResult, ProtocolOp): +class BindResponse(Sequence, ProtocolOp): ber_tag = (1, True, 1) + sequence_fields = [ + (wrapenum(LDAPResultCode), 'resultCode', None, False), + (LDAPString, 'matchedDN', '', False), + (LDAPString, 'diagnosticMessage', '', False), + (retag(OctetString, (2, False, 7)), 'serverSaslCreds', None, True) + ] class UnbindRequest(Sequence, ProtocolOp): ber_tag = (1, False, 2) @@ -621,20 +634,36 @@ class LDAPMessage(Sequence): (Controls, 'controls', None, True) ] -class ShallowProtocolOp: +class ShallowLDAPMessage(BERType): + ber_tag = (0, True, 16) + + def __init__(self, messageID=None, protocolOpType=None, data=None): + self.messageID = messageID + self.protocolOpType = protocolOpType + self.data = data + + def decode(self): + return LDAPMessage.from_ber(self.data) + @classmethod def from_ber(cls, data): - obj, rest = decode_ber(data) + seq, rest = decode_ber(data) + data = data[:len(data)-len(rest)] + if seq.tag != cls.ber_tag: + raise ValueError() + content = seq.content + messageID, content = Integer.from_ber(content) + op, content = decode_ber(content) for subcls in ProtocolOp.__subclasses__(): - if subcls.ber_tag == obj.tag: - return subcls, rest - return None, rest + if subcls.ber_tag == op.tag: + return cls(messageID, subcls, data), rest + return cls(messageID, None, data), rest -class ShallowLDAPMessage(Sequence): - sequence_fields = [ - (Integer, 'messageID', None, False), - (ShallowProtocolOp, 'protocolOp', None, False) - ] + @classmethod + def to_ber(cls, obj): + if not isinstance(obj, cls): + raise TypeError() + return obj.data # Extended Operation Values diff --git a/sasl.py b/sasl.py new file mode 100644 index 0000000..48febc7 --- /dev/null +++ b/sasl.py @@ -0,0 +1,48 @@ +SEP = [b'(', b')', b'<', b'>', b'@', b',', b';', b':', b'\\', b'\'', b'/', b'[', b']', b'?', b'=', b'{', b'}', b' ', b'\t'] +CTL = [bytes([c]) for c in range(0, 31)] + [b'127'] + +def parse_token(s): + for index in range(len(s)): + c = bytes([s[index]]) + if c in SEP + CTL: + return bytes(s[:index]), bytes(s[index:]) + return s, b'' + +def parse_qstr(s): + if s[0] != b'"'[0]: + raise ValueError() + res = b'' + escaped = False + for index in range(1, len(s)): + c = bytes([s[index]]) + if escaped: + res += c + escaped = False + elif c == b'\\': + escaped = True + elif c == b'"': + return res, bytes(s[index+1:]) + else: + res += c + raise ValueError() + +def parse_token_qstr(s): + if s[0] == b'"'[0]: + return parse_qstr(s) + return parse_token(s) + +def parse_kwargs(s): + res = [] + while True: + key, s = parse_token(s) + if s[0] != b'='[0]: + raise ValueError() + value, s = parse_token_qstr(bytes(s[1:])) + res.append((key, value)) + if not s: + return res + if s[0] != b','[0]: + raise ValueError() + s = bytes(s[1:]) + return res + diff --git a/server.py b/server.py index 00a2762..2a65f50 100644 --- a/server.py +++ b/server.py @@ -1,7 +1,10 @@ import traceback +import hashlib +import secrets from socketserver import BaseRequestHandler -from ldap import LDAPMessage, ShallowLDAPMessage, BindRequest, BindResponse, SearchRequest, SearchResultEntry, PartialAttribute, SearchResultDone, UnbindRequest, LDAPResultCode, IncompleteBERError, SimpleAuthentication, ModifyRequest, ModifyResponse, AddRequest, AddResponse, DelRequest, DelResponse, ModifyDNRequest, ModifyDNResponse, CompareRequest, CompareResponse, AbandonRequest, ExtendedRequest, ExtendedResponse, PasswdModifyRequestValue, PasswdModifyResponseValue +import sasl +from ldap import LDAPMessage, ShallowLDAPMessage, BindRequest, BindResponse, SearchRequest, SearchResultEntry, PartialAttribute, SearchResultDone, UnbindRequest, LDAPResultCode, IncompleteBERError, SimpleAuthentication, ModifyRequest, ModifyResponse, AddRequest, AddResponse, DelRequest, DelResponse, ModifyDNRequest, ModifyDNResponse, CompareRequest, CompareResponse, AbandonRequest, ExtendedRequest, ExtendedResponse, PasswdModifyRequestValue, PasswdModifyResponseValue, SaslCredentials class LDAPError(Exception): def __init__(self, code=LDAPResultCode.other, message=''): @@ -32,20 +35,129 @@ class LDAPAuthMethodNotSupported(LDAPError): def __init__(self, message=''): super().__init__(LDAPResultCode.authMethodNotSupported, message) +class LDAPConfidentialityRequired(LDAPError): + def __init__(self, message=''): + super().__init__(LDAPResultCode.confidentialityRequired, message) + +def decode_msg(shallowmsg): + try: + return shallowmsg.decode()[0] + except: + traceback.print_exc() + raise LDAPProtocolError() + class LDAPRequestHandler(BaseRequestHandler): + ssl_context = None + + def setup(self): + super().setup() + self.keep_running = True + self.bind_object = None + self.bind_sasl_state = None + def handle_bind(self, req): op = req.protocolOp if op.version != 3: raise LDAPProtocolError('Unsupported protocol version') - if isinstance(op.authentication, SimpleAuthentication): - self.bind_object = self.do_bind(op.name, op.authentication.password) + auth = op.authentication + # Resume ongoing SASL dialog + if self.bind_sasl_state and isinstance(auth, SaslCredentials) \ + and auth.mechanism == self.bind_sasl_state[0]: + iterator = self.bind_sasl_state[1] + resp_code = LDAPResultCode.saslBindInProgress + try: + resp = iterator.send(auth.credentials) + except StopIteration as e: + resp_code = LDAPResultCode.success + self.bind_sasl_state = None + self.bind_object, resp = e.value + self.send_msg(LDAPMessage(req.messageID, BindResponse(resp_code, serverSaslCreds=resp))) + return + # If auth type or SASL method changed, abort SASL dialog + self.bind_sasl_state = None + if isinstance(auth, SimpleAuthentication): + self.bind_object = self.do_bind(op.name, auth.password) + self.send_msg(LDAPMessage(req.messageID, BindResponse(LDAPResultCode.success))) + elif isinstance(auth, SaslCredentials): + ret = self.do_bind_sasl(auth.mechanism, auth.credentials) + if isinstance(ret, tuple): + self.bind_object, resp = ret + self.send_msg(LDAPMessage(req.messageID, BindResponse(LDAPResultCode.success, serverSaslCreds=resp))) + return + iterator = iter(ret) + resp_code = LDAPResultCode.saslBindInProgress + try: + resp = next(iterator) + except StopIteration as e: + resp_code = LDAPResultCode.success + self.bind_sasl_state = None + self.bind_object, resp = e.value + self.send_msg(LDAPMessage(req.messageID, BindResponse(resp_code, serverSaslCreds=resp))) + self.bind_sasl_state = (auth.mechanism, iterator) else: raise LDAPAuthMethodNotSupported() - self.send_msg(LDAPMessage(req.messageID, BindResponse(LDAPResultCode.success))) def do_bind(self, user, password): raise LDAPInvalidCredentials() + def do_bind_sasl(self, mechanism, credentials): + if mechanism == 'DIGEST-MD5': + return self.do_bind_sasl_digest_md5(mechanism, credentials) + raise LDAPAuthMethodNotSupported() + + def do_bind_sasl_digest_md5_password(self, username, realm): + return None, 'foo' + # should return (bind_obj, password string) for username + raise LDAPAuthMethodNotSupported() + + def do_bind_sasl_digest_md5_pwdigest(self, username, realm, charset): + # charset is either 'utf-8' or 'latin_1', it should only affect username and password + bind_obj, password = self.do_bind_sasl_digest_md5_password(username, realm) + ctx = hashlib.md5() + ctx.update(username.encode(charset) + b':' + realm + b':' + password.encode(charset)) + return bind_obj, ctx.digest() + + def do_bind_sasl_digest_md5(self, mechanism, credentials): + nonce = secrets.token_urlsafe(1024).encode() + challenge = b'nonce="%s",charset="utf-8",algorithm="md5-sess"'%(nonce) + resp = yield challenge + args = {key: value for key, value in sasl.parse_kwargs(resp)} + if args[b'nonce'] != nonce: + raise LDAPProtocolError() + try: + charset = 'utf-8' if args.get(b'charset', b'utf-8') == b'utf-8' else 'latin_1' + username = args[b'username'].decode(charset) + realm = args.get(b'realm', b'') + cnonce = args[b'cnonce'] + nc = args.get(b'nc', b'00000001') + qop = args.get(b'qop', b'auth') + digest_uri = args[b'digest-uri'] + response = args[b'response'] + except KeyError: + raise LDAPProtocolError() + except UnicodeError: + raise LDAPProtocolError() + bind_obj, pwdigest = self.do_bind_sasl_digest_md5_pwdigest(username, realm, charset) + def md5digest(data): + ctx = hashlib.md5() + ctx.update(data) + return ctx.hexdigest().lower().encode() + a1 = pwdigest + b':' + nonce + b':' + cnonce + a2 = b'AUTHENTICATE:' + digest_uri + key = md5digest(a1) + data = nonce + b':' + nc + b':' + cnonce + b':' + qop + b':' + md5digest(a2) + expected_response = md5digest(key + b':' + data) + if expected_response != response: + raise LDAPInvalidCredentials() + # We don't support subsequent authentication so according to RFC 2829 the + # serverSaslCreds field in our response should be absent and we should + # return (bind_obj, None). But this seems to confuse some clients (e.g. + # openldap's ldapsearch) so we return serverSaslCreds with rspauth instead. + a2 = b':' + digest_uri + data = nonce + b':' + nc + b':' + cnonce + b':' + qop + b':' + md5digest(a2) + response = b'rspauth=%s'%md5digest(key + b':' + data) + return bind_obj, response + def handle_search(self, req): search = req.protocolOp for dn, attributes in self.do_search(search.baseObject, search.scope, search.filter): @@ -54,7 +166,7 @@ class LDAPRequestHandler(BaseRequestHandler): self.send_msg(LDAPMessage(req.messageID, SearchResultDone(LDAPResultCode.success))) def do_search(self, baseobj, scope, filter): - yield from [] + return [] def handle_unbind(self, req): self.keep_running = False @@ -63,7 +175,13 @@ class LDAPRequestHandler(BaseRequestHandler): op = req.protocolOp if op.requestName == '1.3.6.1.4.1.1466.20037': # StartTLS (RFC 4511) - raise LDAPProtocolError() + sent_response = False + for _ in self.do_starttls(): + if not sent_response: + self.send_msg(LDAPMessage(req.messageID, ExtendedResponse(LDAPResultCode.success, responseName='1.3.6.1.4.1.1466.20037'))) + sent_response = True + if not sent_response: + raise LDAPProtocolError() elif op.requestName == '1.3.6.1.4.1.4203.1.11.1': # Password Modify Extended Operation (RFC 3062) newpw = None @@ -80,11 +198,39 @@ class LDAPRequestHandler(BaseRequestHandler): else: raise LDAPProtocolError() + def do_starttls(self): + if self.ssl_context is None: + raise LDAPProtocolError() + yield None + try: + self.request = self.ssl_context.wrap_socket(self.request, server_side=True) + except Exception as e: + traceback.print_exc() + self.keep_running = False + def do_passwd(self, user=None, oldpasswd=None, newpasswd=None): raise LDAPUnwillingToPerform('Password change is not supported') - def handle_message(self, data): - handlers = { + def handle_modify(self, req): + raise LDAPInsufficientAccessRights() + + def handle_add(self, req): + raise LDAPInsufficientAccessRights() + + def handle_delete(self, req): + raise LDAPInsufficientAccessRights() + + def handle_modifydn(self, req): + raise LDAPInsufficientAccessRights() + + def handle_compare(self, req): + raise LDAPInsufficientAccessRights() + + def handle_abandon(self, req): + pass + + def handle_message(self, shallowmsg): + msgtypes = { BindRequest: (self.handle_bind, BindResponse), UnbindRequest: (self.handle_unbind, None), SearchRequest: (self.handle_search, SearchResultDone), @@ -96,50 +242,49 @@ class LDAPRequestHandler(BaseRequestHandler): AbandonRequest: (None, None), ExtendedRequest: (self.handle_extended, ExtendedResponse), } - shallowmsg, rest = ShallowLDAPMessage.from_ber(data) - if shallowmsg.protocolOp is None: - print('Ignoring unknown message') - return rest - func, errfunc = handlers[shallowmsg.protocolOp] - msg = None - try: - msg, _ = LDAPMessage.from_ber(data) - except Exception as e: - if errfunc: - self.send_msg(LDAPMessage(shallowmsg.messageID, errfunc(LDAPResultCode.protocolError))) - traceback.print_exc() - return rest - print('received', msg) + + responses = { + BindRequest: BindResponse, + SearchRequest: SearchResultDone, + ModifyRequest: ModifyResponse, + AddRequest: AddResponse, + DelRequest: DelResponse, + ModifyDNRequest: ModifyDNResponse, + CompareRequest: CompareResponse, + ExtendedRequest: ExtendedResponse, + } + handler, response = msgtypes.get(shallowmsg.protocolOpType, (None, None)) try: - if func: - func(msg) - elif errfunc: - self.send_msg(LDAPMessage(shallowmsg.messageID, errfunc(LDAPResultCode.insufficientAccessRights))) + if handler is None: + raise LDAPProtocolError() + try: + msg = decode_msg(shallowmsg) + except ValueError: + raise LDAPProtocolError() + print('recved', msg) + handler(msg) except LDAPError as e: - if errfunc: - self.send_msg(LDAPMessage(shallowmsg.messageID, errfunc(e.code, diagnosticMessage=e.message))) - return rest + if response is not None: + self.send_msg(LDAPMessage(shallowmsg.messageID, response(e.code, diagnosticMessage=e.message))) except Exception as e: - if errfunc: - self.send_msg(LDAPMessage(shallowmsg.messageID, errfunc(LDAPResultCode.other))) + if response is not None: + self.send_msg(LDAPMessage(shallowmsg.messageID, response(LDAPResultCode.other))) traceback.print_exc() - return rest - return rest def send_msg(self, msg): print('sending', msg) self.request.sendall(LDAPMessage.to_ber(msg)) def handle(self): - self.keep_running = True - self.bind_object = None - data = b'' + buf = b'' while self.keep_running: try: - data = self.handle_message(data) + shallowmsg, buf = ShallowLDAPMessage.from_ber(buf) + self.handle_message(shallowmsg) except IncompleteBERError: chunk = self.request.recv(5) if not chunk: - return - data += chunk + self.keep_running = False + return None + buf += chunk -- GitLab