diff --git a/db.py b/db.py index 8c2791bdbd6540b4fe6cab44588162cb4e541899..c6a502424f9a57a03e4449596c749daeeaa3de12 100644 --- a/db.py +++ b/db.py @@ -260,7 +260,7 @@ def ldap_bind(name, password, conn): evaluator = SQLSearchEvaluator(User, session, attributes=User.ldap_attributes, objectclasses=User.ldap_objectclasses, rdn_attr=User.ldap_rdn_attribute, dn_base=User.ldap_dn_base) - res = session.query(User).filter(evaluator.filter_dn(name, SearchScope.baseObject)).one() + res = session.query(User).filter(evaluator.filter_dn(name, SearchScope.baseObject)).one_or_none() if res: return res.check_password(password) return False diff --git a/ldap.py b/ldap.py index 5aceb6deb6757643a08150bc0a89af8f6f0d991d..a60561dd229642cbf58c6eedc9bb31659c178f51 100644 --- a/ldap.py +++ b/ldap.py @@ -19,10 +19,11 @@ def decode_ber(data): index += 1 if not data[index] & 0x80: length = data[index] + index += 1 elif data[index] == 0x80: raise ValueError('Indefinite form not implemented') elif data[index] == 0xff: - return ValueError('BER length invalid') + raise ValueError('BER length invalid') else: num = data[index] & ~0x80 index += 1 @@ -32,10 +33,10 @@ def decode_ber(data): for octet in data[index:index + num]: length = length << 8 | octet index += num - if len(data) < index + length + 1: - raise IncompleteBERError(index + length + 1) - ber_content = data[index + 1: index + length + 1] - rest = data[index + length + 1:] + if len(data) < index + length: + raise IncompleteBERError(index + length) + ber_content = data[index: index + length] + rest = data[index + length:] return BERObject((ber_class, ber_constructed, ber_type), ber_content), rest def decode_ber_integer(data): @@ -433,6 +434,9 @@ class BindRequest(Sequence, ProtocolOp): class BindResponse(LDAPResult, ProtocolOp): ber_tag = (1, True, 1) +class UnbindRequest(Sequence, ProtocolOp): + ber_tag = (1, False, 2) + class SearchRequest(Sequence, ProtocolOp): ber_tag = (1, True, 3) sequence_fields = [ @@ -460,8 +464,65 @@ class SearchResultEntry(Sequence, ProtocolOp): class SearchResultDone(LDAPResult, ProtocolOp): ber_tag = (1, True, 5) -class UnbindRequest(Sequence, ProtocolOp): - ber_tag = (1, False, 2) +class ModifyRequest(Sequence, ProtocolOp): + ber_tag = (1, True, 6) + # stub + +class ModifyResponse(LDAPResult, ProtocolOp): + ber_tag = (1, True, 7) + +class AddRequest(Sequence, ProtocolOp): + ber_tag = (1, True, 8) + # stub + +class AddResponse(LDAPResult, ProtocolOp): + ber_tag = (1, True, 9) + +class DelRequest(Wrapper, ProtocolOp): + ber_tag = (1, False, 10) + wrapped_attribute = 'dn' + wrapped_type = LDAPString + wrapped_default = None + +class DelResponse(LDAPResult, ProtocolOp): + ber_tag = (1, True, 11) + +class ModifyDNRequest(Sequence, ProtocolOp): + ber_tag = (1, True, 12) + # stub + +class ModifyDNResponse(LDAPResult, ProtocolOp): + ber_tag = (1, True, 13) + +class CompareRequest(Sequence, ProtocolOp): + ber_tag = (1, True, 14) + # stub + +class CompareResponse(LDAPResult, ProtocolOp): + ber_tag = (1, True, 15) + +class AbandonRequest(Wrapper, ProtocolOp): + ber_tag = (1, False, 16) + wrapped_attribute = 'messageID' + wrapped_type = Integer + wrapped_default = None + +class ExtendedRequest(Sequence, ProtocolOp): + ber_tag = (1, True, 23) + # stub + +class ExtendedResponse(Sequence, ProtocolOp): + ber_tag = (1, True, 24) + sequence_fields = [ + (LDAPResultCodeEnum, 'resultCode', None), + (LDAPString, 'matchedDN', ''), + (LDAPString, 'diagnosticMessage', ''), + ] + # stub + +class IntermediateResponse(Sequence, ProtocolOp): + ber_tag = (1, True, 25) + # stub class LDAPMessage(Sequence): sequence_fields = [ diff --git a/server.py b/server.py index 96c52d01ff93283ee6828d64ede26ae88f93a00c..69af473e3418a56b7e27bafe7f8e121c99e3b0ea 100644 --- a/server.py +++ b/server.py @@ -1,7 +1,7 @@ import traceback from socketserver import ForkingTCPServer, BaseRequestHandler -from ldap import LDAPMessage, ShallowLDAPMessage, BindRequest, BindResponse, SearchRequest, SearchResultEntry, PartialAttribute, SearchResultDone, UnbindRequest, LDAPResultCode, IncompleteBERError, SimpleAuthentication +from ldap import LDAPMessage, ShallowLDAPMessage, BindRequest, BindResponse, SearchRequest, SearchResultEntry, PartialAttribute, SearchResultDone, UnbindRequest, LDAPResultCode, IncompleteBERError, SimpleAuthentication, ModifyRequest, ModifyResponse, AddRequest, AddResponse, DelRequest, DelResponse, ModifyDNRequest, ModifyDNResponse, CompareRequest, CompareResponse, AbandonRequest, ExtendedRequest, ExtendedResponse class Handler(BaseRequestHandler): ldap_server = None @@ -19,6 +19,7 @@ class Handler(BaseRequestHandler): if func(name, password, self): self.bind_dn = name self.send_msg(LDAPMessage(req.messageID, BindResponse(LDAPResultCode.success))) + return self.send_msg(LDAPMessage(req.messageID, BindResponse(LDAPResultCode.invalidCredentials))) def handle_search(self, req): @@ -37,8 +38,15 @@ class Handler(BaseRequestHandler): def handle_message(self, data): handlers = { BindRequest: (self.handle_bind, BindResponse), - SearchRequest: (self.handle_search, SearchResultDone), UnbindRequest: (self.handle_unbind, None), + SearchRequest: (self.handle_search, SearchResultDone), + ModifyRequest: (None, ModifyResponse), + AddRequest: (None, AddResponse), + DelRequest: (None, DelResponse), + ModifyDNRequest: (None, ModifyDNResponse), + CompareRequest: (None, CompareResponse), + AbandonRequest: (None, None), + ExtendedRequest: (None, ExtendedResponse), # TODO } shallowmsg, rest = ShallowLDAPMessage.from_ber(data) if shallowmsg.protocolOp is None: @@ -57,6 +65,8 @@ class Handler(BaseRequestHandler): try: if func: func(msg) + elif errfunc: + self.send_msg(LDAPMessage(shallowmsg.messageID, errfunc(LDAPResultCode.insufficientAccessRights))) except Exception as e: if errfunc: self.send_msg(LDAPMessage(shallowmsg.messageID, errfunc(LDAPResultCode.other)))