From 39f0d293187293a64ef11bcd26f17335bab947c9 Mon Sep 17 00:00:00 2001 From: Julian Rother <julian@jrother.eu> Date: Fri, 5 Mar 2021 21:02:35 +0100 Subject: [PATCH] Implemented custom asn.1/ldap parser --- dn.py | 3 + ldap.py | 488 ++++++++++++++++++++++++++++++++++++++++++++++++++++++ server.py | 65 ++++++++ 3 files changed, 556 insertions(+) create mode 100644 ldap.py create mode 100644 server.py diff --git a/dn.py b/dn.py index 1ce21a3..020e7d0 100644 --- a/dn.py +++ b/dn.py @@ -119,3 +119,6 @@ class DNScope(Enum): singleLevel = 1 # The scope is constrained to the immediate subordinates of the entry named by baseObject. wholeSubtree = 2 # The scope is constrained to the entry named by baseObject and to all its subordinates. + @classmethod + def from_bytes(cls, data): + return self(data) diff --git a/ldap.py b/ldap.py new file mode 100644 index 0000000..48d3091 --- /dev/null +++ b/ldap.py @@ -0,0 +1,488 @@ +from collections import namedtuple +import enum + +BERObject = namedtuple('BERObject', ['tag', 'content']) + +class IncompleteBERError(ValueError): + def __init__(self, expected_length=-1): + super().__init__() + self.expected_length = expected_length + +def decode_ber(data): + index = 0 + if len(data) < 2: + raise IncompleteBERError(2) + identifier = data[index] + ber_class = identifier >> 6 + ber_constructed = bool(identifier & 0x20) + ber_type = identifier & 0x1f + index += 1 + if not data[index] & 0x80: + length = data[index] + elif data[index] == 0x80: + raise ValueError('Indefinite form not implemented') + elif data[index] == 0xff: + return ValueError('BER length invalid') + else: + num = data[index] & ~0x80 + index += 1 + if len(data) < index + num: + raise IncompleteBERError(index + num) + length = 0 + for octet in data[index:index + num]: + length = length << 8 | octet + index += num + if len(data) < index + length + 1: + raise IncompleteBERError(index + length + 1) + ber_content = data[index + 1: index + length + 1] + rest = data[index + length + 1:] + return BERObject((ber_class, ber_constructed, ber_type), ber_content), rest + +def decode_ber_integer(data): + if not data: + return 0 + value = -1 if data[0] & 0x80 else 0 + for octet in data: + value = value << 8 | octet + return value + +def encode_ber(obj): + tag = (obj.tag[0] & 0b11) << 6 | (obj.tag[1] & 1) << 5 | (obj.tag[2] & 0b11111) + length = len(obj.content) + if length >= 127: + raise NotImplementedError('Long form length encoding not implemented') + return bytes([tag, length]) + obj.content + +def encode_ber_integer(value): + if value < 0 or value > 255: + raise NotImplementedError('Encoding of integers greater than 255 is not implemented') + return bytes([value]) + +class BERType: + @classmethod + def from_ber(cls, data): + raise NotImplementedError() + + @classmethod + def to_ber(cls, obj): + raise NotImplementedError() + + def __bytes__(self): + return type(self).to_ber(self) + +class OctetString(BERType): + ber_tag = (0, False, 4) + + @classmethod + def from_ber(cls, data): + obj, rest = decode_ber(data) + if obj.tag != cls.ber_tag: + raise ValueError('Expected tag %s but found %s'%(cls.ber_tag, obj.tag)) + return obj.content, rest + + @classmethod + def to_ber(cls, obj): + if not isinstance(obj, bytes): + raise TypeError() + return encode_ber(BERObject(cls.ber_tag, obj)) + +class Integer(BERType): + ber_tag = (0, False, 2) + + @classmethod + def from_ber(cls, data): + obj, rest = decode_ber(data) + if obj.tag != cls.ber_tag: + raise ValueError() + return decode_ber_integer(obj.content), rest + + @classmethod + def to_ber(cls, obj): + if not isinstance(obj, int): + raise TypeError() + return encode_ber(BERObject(cls.ber_tag, encode_ber_integer(obj))) + +class Boolean(BERType): + ber_tag = (0, False, 1) + + @classmethod + def from_ber(cls, data): + obj, rest = decode_ber(data) + if obj.tag != cls.ber_tag: + raise ValueError() + return bool(decode_ber_integer(obj.content)), rest + + @classmethod + def to_ber(cls, obj): + if not isinstance(obj, bool): + raise TypeError() + content = b'\xff' if obj else b'\x00' + return encode_ber(BERObject(cls.ber_tag, content)) + +class LDAPString(OctetString): + @classmethod + def from_ber(cls, data): + raw, rest = super().from_ber(data) + return raw.decode(), rest + + @classmethod + def to_ber(cls, obj): + if not isinstance(obj, str): + raise TypeError() + return super().to_ber(obj.encode()) + +class Set(BERType): + ber_tag = (0, True, 17) + set_type = OctetString + + @classmethod + def from_ber(cls, data): + setobj, rest = decode_ber(data) + if setobj.tag != cls.ber_tag: + raise ValueError() + objs = [] + data = setobj.content + while data: + obj, data = cls.set_type.from_ber(data) + objs.append(obj) + return list(objs), rest + + @classmethod + def to_ber(cls, obj): + content = b'' + for item in obj: + content += cls.set_type.to_ber(item) + return encode_ber(BERObject(cls.ber_tag, content)) + +class SequenceOf(Set): + ber_tag = (0, True, 16) + +class Sequence(BERType): + ber_tag = (0, True, 16) + sequence_fields = [ + #(Type, attr_name, default_value), + ] + + def __init__(self, *args, **kwargs): + for index, spec in enumerate(type(self).sequence_fields): + field_type, name, default = spec + if index < len(args): + value = args[index] + elif name in kwargs: + value = kwargs[name] + else: + value = default() if callable(default) else default + setattr(self, name, value) + + def __repr__(self): + args = [] + for field_type, name, default in type(self).sequence_fields: + args.append('%s=%s'%(name, repr(getattr(self, name)))) + return '<%s(%s)>'%(type(self).__name__, ', '.join(args)) + + @classmethod + def from_ber(cls, data): + seqobj, rest = decode_ber(data) + if seqobj.tag != cls.ber_tag: + raise ValueError() + args = [] + data = seqobj.content + for field_type, name, default in cls.sequence_fields: + obj, data = field_type.from_ber(data) + args.append(obj) + return cls(*args), rest + + @classmethod + def to_ber(cls, obj): + if not isinstance(obj, cls): + raise TypeError() + content = b'' + for field_type, name, default in cls.sequence_fields: + content += field_type.to_ber(getattr(obj, name)) + return encode_ber(BERObject(cls.ber_tag, content)) + +class Choice(BERType): + ber_tag = None + + @classmethod + def from_ber(cls, data): + obj, rest = decode_ber(data) + for subcls in cls.__subclasses__(): + if subcls.ber_tag == obj.tag: + return subcls.from_ber(data) + return None, rest + + @classmethod + def to_ber(cls, obj): + for subcls in cls.__subclasses__(): + if isinstance(obj, subcls): + return subcls.to_ber(obj) + raise TypeError() + +class Wrapper(BERType): + ber_tag = None + wrapped_attribute = None + wrapped_type = None + wrapped_default = None + wrapped_clsattrs = {} + + def __init__(self, *args, **kwargs): + cls = type(self) + attribute = cls.wrapped_attribute + if args: + setattr(self, attribute, args[0]) + elif kwargs: + setattr(self, attribute, kwargs[attribute]) + else: + setattr(self, attribute, cls.wrapped_default() if callable(cls.wrapped_default) else cls.wrapped_default) + + def __repr__(self): + return '<%s(%s)>'%(type(self).__name__, repr(getattr(self, type(self).wrapped_attribute))) + + @classmethod + def from_ber(cls, data): + class WrappedType(cls.wrapped_type): + ber_tag = cls.ber_tag + for key, value in cls.wrapped_clsattrs.items(): + setattr(WrappedType, key, value) + value, rest = WrappedType.from_ber(data) + return cls(value), rest + + @classmethod + def to_ber(cls, obj): + class WrappedType(cls.wrapped_type): + ber_tag = cls.ber_tag + for key, value in cls.wrapped_clsattrs.items(): + setattr(WrappedType, key, value) + if not isinstance(obj, cls): + raise TypeError() + return WrappedType.to_ber(getattr(obj, cls.wrapped_attribute)) + +class Filter(Choice): + pass + +class FilterAnd(Wrapper, Filter): + ber_tag = (2, True, 0) + wrapped_attribute = 'filters' + wrapped_type = Set + wrapped_clsattrs = {'set_type': Filter} + +class FilterOr(Wrapper, Filter): + ber_tag = (2, True, 1) + wrapped_attribute = 'filters' + wrapped_type = Set + wrapped_clsattrs = {'set_type': Filter} + +class FilterNot(Sequence, Filter): + ber_tag = (2, True, 2) + sequence_fields = [ + (Filter, 'filter', None) + ] + +class FilterEqual(Sequence, Filter): + ber_tag = (2, True, 3) + sequence_fields = [ + (LDAPString, 'attribute', None), + (OctetString, 'value', None) + ] + +class FilterPresent(Wrapper, Filter): + ber_tag = (2, False, 7) + wrapped_attribute = 'attribute' + wrapped_type = LDAPString + wrapped_default = None + +class Enum(BERType): + ber_tag = (0, False, 10) + enum_type = None + + @classmethod + def from_ber(cls, data): + obj, rest = decode_ber(data) + if obj.tag != cls.ber_tag: + raise ValueError() + value = decode_ber_integer(obj.content) + return cls.enum_type(value), rest + + @classmethod + def to_ber(cls, obj): + if not isinstance(obj, cls.enum_type): + raise TypeError() + return encode_ber(BERObject(cls.ber_tag, encode_ber_integer(obj.value))) + +class SearchScope(enum.Enum): + baseObject = 0 # The scope is constrained to the entry named by baseObject. + singleLevel = 1 # The scope is constrained to the immediate subordinates of the entry named by baseObject. + wholeSubtree = 2 # The scope is constrained to the entry named by baseObject and to all its subordinates. + +class SearchScopeEnum(Enum): + enum_type = SearchScope + +class DerefAliases(enum.Enum): + neverDerefAliases = 0 + derefInSearching = 1 + derefFindingBaseObj = 2 + derefAlways = 3 + +class DerefAliasesEnum(Enum): + enum_type = DerefAliases + +class LDAPResultCode(enum.Enum): + success = 0 + operationsError = 1 + protocolError = 2 + timeLimitExceeded = 3 + sizeLimitExceeded = 4 + compareFalse = 5 + compareTrue = 6 + authMethodNotSupported = 7 + strongerAuthRequired = 8 + # -- 9 reserved -- + referral = 10 + adminLimitExceeded = 11 + unavailableCriticalExtension = 12 + confidentialityRequired = 13 + saslBindInProgress = 14 + noSuchAttribute = 16 + undefinedAttributeType = 17 + inappropriateMatching = 18 + constraintViolation = 19 + attributeOrValueExists = 20 + invalidAttributeSyntax = 21 + # -- 22-31 unused -- + noSuchObject = 32 + aliasProblem = 33 + invalidDNSyntax = 34 + # -- 35 reserved for undefined isLeaf -- + aliasDereferencingProblem = 36 + # -- 37-47 unused -- + inappropriateAuthentication = 48 + invalidCredentials = 49 + insufficientAccessRights = 50 + busy = 51 + unavailable = 52 + unwillingToPerform = 53 + loopDetect = 54 + # -- 55-63 unused -- + namingViolation = 64 + objectClassViolation = 65 + notAllowedOnNonLeaf = 66 + notAllowedOnRDN = 67 + entryAlreadyExists = 68 + objectClassModsProhibited = 69 + # -- 70 reserved for CLDAP -- + affectsMultipleDSAs = 71 + # -- 72-79 unused -- + other = 80 + +class LDAPResultCodeEnum(Enum): + enum_type = LDAPResultCode + +class LDAPResult(Sequence): + ber_tag = (5, True, 1) + sequence_fields = [ + (LDAPResultCodeEnum, 'resultCode', None), + (LDAPString, 'matchedDN', ''), + (LDAPString, 'diagnosticMessage', ''), + ] + +class AttributeSelection(SequenceOf): + set_type = LDAPString + +class AuthenticationChoice(Choice): + pass + +class SimpleAuthentication(Wrapper, AuthenticationChoice): + ber_tag = (2, False, 0) + wrapped_attribute = 'password' + wrapped_type = OctetString + wrapped_default = b'' + + def __repr__(self): + if not self.password: + return '<%s(EMPTY PASSWORD)>'%(type(self).__name__) + return '<%s(PASSWORD HIDDEN)>'%(type(self).__name__) + +class AttributeValueSet(Set): + set_type = OctetString + +class PartialAttribute(Sequence): + sequence_fields = [ + (LDAPString, 'type', None), + (AttributeValueSet, 'vals', lambda: []), + ] + +class PartialAttributeList(SequenceOf): + set_type = PartialAttribute + +class ProtocolOp(Choice): + pass + +class BindRequest(Sequence, ProtocolOp): + ber_tag = (1, True, 0) + sequence_fields = [ + (Integer, 'version', 3), + (LDAPString, 'name', ''), + (AuthenticationChoice, 'authentication', lambda: SimpleAuthentication()) + ] + +class BindResponse(LDAPResult, ProtocolOp): + ber_tag = (1, True, 1) + +class SearchRequest(Sequence, ProtocolOp): + ber_tag = (1, True, 3) + sequence_fields = [ + (LDAPString, 'baseObject', ''), + (SearchScopeEnum, 'scope', SearchScope.wholeSubtree), + (DerefAliasesEnum, 'derefAliases', DerefAliases.neverDerefAliases), + (Integer, 'sizeLimit', 0), + (Integer, 'timeLimit', 0), + (Boolean, 'typesOnly', False), + (Filter, 'filter', lambda: FilterPresent('objectClass')), + (AttributeSelection, 'attributes', lambda: []) + ] + + @classmethod + def from_ber(cls, data): + return super().from_ber(data) + +class SearchResultEntry(Sequence, ProtocolOp): + ber_tag = (1, True, 4) + sequence_fields = [ + (LDAPString, 'objectName', ''), + (PartialAttributeList, 'attributes', lambda: []), + ] + +class SearchResultDone(LDAPResult, ProtocolOp): + ber_tag = (1, True, 5) + +class UnbindRequest(Sequence, ProtocolOp): + ber_tag = (1, False, 2) + +class LDAPMessage(Sequence): + sequence_fields = [ + (Integer, 'messageID', None), + (ProtocolOp, 'protocolOp', None) + ] + +class ShallowProtocolOp: + @classmethod + def from_ber(cls, data): + obj, rest = decode_ber(data) + for subcls in ProtocolOp.__subclasses__(): + if subcls.ber_tag == obj.tag: + return subcls, rest + return None, rest + +class ShallowLDAPMessage(Sequence): + sequence_fields = [ + (Integer, 'messageID', None), + (ShallowProtocolOp, 'protocolOp', None) + ] + +bind1 = b'0\x0c\x02\x01\x01`\x07\x02\x01\x03\x04\x00\x80\x00' +bind2 = b'0\x1c\x02\x01\x01`\x17\x02\x01\x03\x04\nuid=foobar\x80\x06foobar' +search1 = b'0?\x02\x01\x02c:\x04\x1aou=users,dc=example,dc=com\n\x01\x01\n\x01\x00\x02\x01\x00\x02\x01\x00\x01\x01\x00\x87\x0bobjectclass0\x00' +search2 = b'0@\x02\x01\x02c;\x04\x00\n\x01\x02\n\x01\x00\x02\x01\x00\x02\x01\x00\x01\x01\x00\xa0&\xa3\x12\x04\x0bobjectClass\x04\x03top\xa3\x10\x04\x03uid\x04\ttestadmin0\x00' +unbind1 = b'0\x05\x02\x01\x03B\x00' + diff --git a/server.py b/server.py new file mode 100644 index 0000000..10abdbe --- /dev/null +++ b/server.py @@ -0,0 +1,65 @@ +import traceback +from socketserver import ForkingTCPServer, BaseRequestHandler + +from ldap import LDAPMessage, ShallowLDAPMessage, BindRequest, BindResponse, SearchRequest, SearchResultDone, UnbindRequest, LDAPResultCode, IncompleteBERError + +class Handler(BaseRequestHandler): + def setup(self): + self.bind_dn = b'' + self.keep_running = True + + def handle_bind(self, req): + self.bind_dn = req.protocolOp.name + self.send_msg(LDAPMessage(req.messageID, BindResponse(LDAPResultCode.success))) + + def handle_search(self, req): + self.send_msg(LDAPMessage(req.messageID, SearchResultDone(LDAPResultCode.success))) + + def handle_unbind(self, req): + self.keep_running = False + + def handle_message(self, data): + handlers = { + BindRequest: (self.handle_bind, BindResponse), + SearchRequest: (self.handle_search, SearchResultDone), + UnbindRequest: (self.handle_unbind, None), + } + shallowmsg, rest = ShallowLDAPMessage.from_ber(data) + if shallowmsg.protocolOp is None: + print('Ignoring unknown message') + return rest + func, errfunc = handlers[shallowmsg.protocolOp] + msg = None + try: + msg, _ = LDAPMessage.from_ber(data) + except Exception as e: + if errfunc: + self.send_msg(LDAPMessage(shallowmsg.messageID, errfunc(LDAPResultCode.protocolError))) + traceback.print_exc() + return rest + print(msg) + try: + if func: + func(msg) + except Exception as e: + if errfunc: + self.send_msg(LDAPMessage(shallowmsg.messageID, errfunc(LDAPResultCode.other))) + traceback.print_exc() + return rest + return rest + + def send_msg(self, msg): + self.request.sendall(LDAPMessage.to_ber(msg)) + + def handle(self): + data = b'' + while self.keep_running: + try: + data = self.handle_message(data) + except IncompleteBERError: + chunk = self.request.recv(5) + if not chunk: + return + data += chunk + +ForkingTCPServer(('127.0.0.1', 1338), Handler).serve_forever() -- GitLab