diff --git a/ldap_decode.py b/ldap_decode.py new file mode 100644 index 0000000000000000000000000000000000000000..52b5edca4c53c21323e8095063d96c3111d9d053 --- /dev/null +++ b/ldap_decode.py @@ -0,0 +1,200 @@ +# Decoding code based on python-ldap3 + +def compute_ldap_message_size(data): # from ldap3 + ret_value = -1 + if len(data) > 2: + if data[1] <= 127: # BER definite length - short form. Highest bit of byte 1 is 0, message length is in the last 7 bits - Value can be up to 127 bytes long + ret_value = data[1] + 2 + else: # BER definite length - long form. Highest bit of byte 1 is 1, last 7 bits counts the number of following octets containing the value length + bytes_length = data[1] - 128 + if len(data) >= bytes_length + 2: + value_length = 0 + cont = bytes_length + for byte in data[2:2 + bytes_length]: + cont -= 1 + value_length += byte * (256 ** cont) + ret_value = value_length + 2 + bytes_length + return ret_value + +CLASSES = {(False, False): 0, # Universal + (False, True): 1, # Application + (True, False): 2, # Context + (True, True): 3} # Private + +# a fast BER decoder for LDAP responses only +def compute_ber_size(data): + """ + Compute size according to BER definite length rules + Returns size of value and value offset + """ + + if data[1] <= 127: # BER definite length - short form. Highest bit of byte 1 is 0, message length is in the last 7 bits - Value can be up to 127 bytes long + return data[1], 2 + else: # BER definite length - long form. Highest bit of byte 1 is 1, last 7 bits counts the number of following octets containing the value length + bytes_length = data[1] - 128 + value_length = 0 + cont = bytes_length + for byte in data[2: 2 + bytes_length]: + cont -= 1 + value_length += byte * (256 ** cont) + return value_length, bytes_length + 2 + + +def decode_message_fast(message): + ber_len, ber_value_offset = compute_ber_size(get_bytes(message[:10])) # get start of sequence, at maximum 3 bytes for length + decoded = decode_sequence(message, ber_value_offset, ber_len + ber_value_offset, LDAP_MESSAGE_CONTEXT) + return { + 'messageID': decoded[0][3], + 'protocolOp': decoded[1][2], + 'payload': decoded[1][3], + 'controls': decoded[2][3] if len(decoded) == 3 else None + } + +def decode_sequence(message, start, stop, context_decoders=None): + decoded = [] + while start < stop: + octet = get_byte(message[start]) + ber_class = CLASSES[(bool(octet & 0b10000000), bool(octet & 0b01000000))] + ber_constructed = bool(octet & 0b00100000) + ber_type = octet & 0b00011111 + ber_decoder = DECODERS[(ber_class, octet & 0b00011111)] if ber_class < 2 else None + ber_len, ber_value_offset = compute_ber_size(get_bytes(message[start: start + 10])) + start += ber_value_offset + if ber_decoder: + value = ber_decoder(message, start, start + ber_len, context_decoders) # call value decode function + else: + # try: + value = context_decoders[ber_type](message, start, start + ber_len) # call value decode function for context class + # except KeyError: + # if ber_type == 3: # Referral in result + # value = decode_sequence(message, start, start + ber_len) + # else: + # raise # re-raise, should never happen + decoded.append((ber_class, ber_constructed, ber_type, value)) + start += ber_len + + return decoded + + +def decode_integer(message, start, stop, context_decoders=None): + first = message[start] + value = -1 if get_byte(first) & 0x80 else 0 + for octet in message[start: stop]: + value = value << 8 | get_byte(octet) + + return value + + +def decode_octet_string(message, start, stop, context_decoders=None): + return message[start: stop] + + +def decode_boolean(message, start, stop, context_decoders=None): + return False if message[start: stop] == 0 else True + +def decode_search_request(message, start, stop, context_decoders=None): + decoded = decode_sequence(message, start, stop, SEARCH_FILTER_CONTEXT) + return { + 'baseObject': decoded[0][3], + 'scope': decoded[1][3], + 'filter': decoded[6][3], + 'attributes': [item[3] for item in decoded[7][3]], + } + +def decode_bind_request(message, start, stop, context_decoders=None): + decoded = decode_sequence(message, start, stop, BIND_REQUEST_CONTEXT) + return { + 'version': decoded[0][3], + 'name': decoded[1][3], + 'authentication': decoded[2][3], + } + +def decode_filter_and(message, start, stop, context_decoders=None): + decoded = decode_sequence(message, start, stop, SEARCH_FILTER_CONTEXT) + return ('and',) + tuple(item[3] for item in decoded) + +def decode_filter_or(message, start, stop, context_decoders=None): + decoded = decode_sequence(message, start, stop, SEARCH_FILTER_CONTEXT) + return ('or',) + tuple(item[3] for item in decoded) + +def decode_filter_not(message, start, stop, context_decoders=None): + decoded = decode_sequence(message, start, stop, SEARCH_FILTER_CONTEXT) + return ('not', decoded[0][3]) + +def decode_filter_equal(message, start, stop, context_decoders=None): + decoded = decode_sequence(message, start, stop, SEARCH_FILTER_CONTEXT) + return ('equal', decoded[0][3], decoded[1][3]) + +def decode_filter_present(message, start, stop, context_decoders=None): + #decoded = decode_sequence(message, start, stop, SEARCH_FILTER_CONTEXT) + #return ('present', ecoded[0][3]) + return ('present', decode_octet_string(message, start, stop)) + +def decode_filter_unknown(message, start, stop, context_decoders=None): + return ('unknown',) + +def decode_controls(message, start, stop, context_decoders=None): + return decode_sequence(message, start, stop, CONTROLS_CONTEXT) + +###### + +if str is not bytes: # Python 3 + def get_byte(x): + return x + + def get_bytes(x): + return x +else: # Python 2 + def get_byte(x): + return ord(x) + + def get_bytes(x): + return bytearray(x) + +DECODERS = { + # Universal + (0, 1): decode_boolean, # Boolean + (0, 2): decode_integer, # Integer + (0, 4): decode_octet_string, # Octet String + (0, 10): decode_integer, # Enumerated + (0, 16): decode_sequence, # Sequence + (0, 17): decode_sequence, # Set + # Application + (1, 0): decode_bind_request, # Bind + (1, 2): decode_sequence, # Unbind + (1, 3): decode_search_request, # Search request + (1, 6): decode_octet_string, # Modify request + (1, 8): decode_octet_string, # Add request + (1, 10): decode_octet_string, # Del request + (1, 12): decode_octet_string, # ModDN request + (1, 14): decode_octet_string, # Compare request + (1, 16): decode_sequence, # Abandon + (1, 23): decode_octet_string, # Extended Request +} + +BIND_REQUEST_CONTEXT = { + 0: decode_octet_string, # simple + 3: decode_octet_string, # SaslCredentials +} + +SEARCH_FILTER_CONTEXT = { + 0: decode_filter_and, # and + 1: decode_filter_or, # or + 2: decode_filter_not, # not + 3: decode_filter_equal, # equalityMatch + 4: decode_filter_unknown, # substrings + 5: decode_filter_unknown, # greaterOrEqual + 6: decode_filter_unknown, # lessOrEqual + 7: decode_filter_present, # present + 8: decode_filter_unknown, # approxMatch + 9: decode_filter_unknown, # extensibleMatch +} + +LDAP_MESSAGE_CONTEXT = { + 0: decode_controls, # Controls + 3: decode_sequence # Referral +} + +CONTROLS_CONTEXT = { + 0: decode_sequence # Control +} diff --git a/test.py b/test.py index 72275b09e78070b175c78eb5a811f77ec5e84509..5584b39076f4220ab2524014b08f7c4b0854aff3 100644 --- a/test.py +++ b/test.py @@ -1,25 +1,187 @@ import socket -from ldap3.protocol.rfc4511 import LDAPMessage, ResultCode, ProtocolOp, BindResponse, SearchResultDone, ModifyResponse, AddResponse, DelResponse, ModifyDNResponse, CompareResponse, ExtendedResponse +from ldap3.protocol.rfc4511 import LDAPMessage, ResultCode, ProtocolOp, BindResponse, SearchResultDone, ModifyResponse, AddResponse, DelResponse, ModifyDNResponse, CompareResponse, ExtendedResponse, SearchResultEntry, PartialAttribute from ldap3.utils.asn1 import decoder, encode +from ldap_decode import decode_message_fast, compute_ldap_message_size -def compute_ldap_message_size(data): # from ldap3 - ret_value = -1 - if len(data) > 2: - if data[1] <= 127: # BER definite length - short form. Highest bit of byte 1 is 0, message length is in the last 7 bits - Value can be up to 127 bytes long - ret_value = data[1] + 2 - else: # BER definite length - long form. Highest bit of byte 1 is 1, last 7 bits counts the number of following octets containing the value length - bytes_length = data[1] - 128 - if len(data) >= bytes_length + 2: - value_length = 0 - cont = bytes_length - for byte in data[2:2 + bytes_length]: - cont -= 1 - value_length += byte * (256 ** cont) - ret_value = value_length + 2 + bytes_length - return ret_value +class LDAPObject: + @classmethod + def _ldap_filter(cls, operator, *args): + if operator == 'and': + return cls._ldap_filter_and(*args) + elif operator == 'or': + return cls._ldap_filter_or(*args) + elif operator == 'not': + return cls._ldap_filter_not(*args) + elif operator == 'equal': + return cls._ldap_filter_equal(*args) + elif operator == 'present': + return cls._ldap_filter_present(*args) + else: + return None + + @classmethod + def _ldap_filter_and(cls, *components): + results = [] + for component in components: + subres = cls._ldap_filter(*component) + if subres is True: + continue + elif subres is False: + return False + elif subres is None: + return None + else: + results.append(subres) + return cls.ldap_filter_and(*results) + + @classmethod + def _ldap_filter_or(cls, *components): + results = [] + for component in components: + subres = cls._ldap_filter(*component) + if subres is True: + return True + elif subres is False: + continue + elif subres is None: + return None + else: + results.append(subres) + return cls.ldap_filter_or(*results) + + @classmethod + def _ldap_filter_not(cls, subfilter): + subres = cls._ldap_filter(*subfilter) + if subres is True: + return False + elif subres is False: + return True + elif subres is None: + return None + else: + return cls.ldap_filter_not(subres) + + @classmethod + def _ldap_filter_present(cls, name): + return cls.ldap_filter_present(name.decode().lower()) + + @classmethod + def _ldap_filter_equal(cls, name, value): + return cls.ldap_filter_equal(name.decode().lower(), bytes(value)) + + @classmethod + def ldap_filter_present(cls, name): + # return True if always true, False if always false, None if undefined, custom object else + return None + + @classmethod + def ldap_filter_equal(cls, name, value): + return None + + @classmethod + def ldap_filter_and(cls, *results): + return None + + @classmethod + def ldap_filter_or(cls, *results): + return None + + @classmethod + def ldap_filter_not(cls, result): + return None + + @classmethod + def ldap_search(cls, dn_base, filter_res): + return [] + +def split_dn(dn): + if not dn: + return [] + return dn.split(',') + +class StaticLDAPObject(LDAPObject): + present_map = {} # name -> set of objs + value_map = {} # (name, value) -> set of objs + all_objects = set() # set of all objs + dn_map = {tuple(): set()} # (dn part tuples) -> set of objs + + def __init__(self, dn, **kwargs): + self.ldap_dn = dn + cls = type(self) + cls.dn_map[tuple()].add(self) + dn_path = [] + for part in reversed(split_dn(dn)): + dn_path.insert(0, part) + key = tuple(dn_path) + cls.dn_map[key] = cls.dn_map.get(key, set()) + cls.dn_map[key].add(self) + cls.all_objects.add(self) + self.ldap_attributes = {} + for name, _values in kwargs.items(): + name = name.lower() + if not isinstance(_values, list): + _values = [_values] + values = [] + for value in _values: + if value is None or value == '' or value == b'': + continue + if isinstance(value, int): + value = str(value) + if isinstance(value, str): + value = value.encode() + values.append(value) + if not values: + continue + self.ldap_attributes[name] = values + cls.present_map[name] = cls.present_map.get(name, set()) + cls.present_map[name].add(self) + for value in values: + key = (name, value) + cls.value_map[key] = cls.value_map.get(key, set()) + cls.value_map[key].add(self) + + @classmethod + def ldap_filter_present(cls, name): + return cls.present_map.get(name, set()) + + @classmethod + def ldap_filter_equal(cls, name, value): + return cls.value_map.get((name, value), set()) + + @classmethod + def ldap_filter_and(cls, *results): + objs = results[0] + for subres in results[1:]: + objs = objs.intersection(subres) + return objs + + @classmethod + def ldap_filter_or(cls, *results): + objs = results[0] + for subres in results[1:]: + objs = objs.union(subres) + return objs + + @classmethod + def ldap_filter_not(cls, result): + return cls.all_objects.difference(result) + + @classmethod + def ldap_search(cls, dn_base, filter_res): + key = tuple(split_dn(dn_base)) + objs = cls.dn_map.get(key, set()) + if filter_res is None or filter_res is False: + return [] + if filter_res is not True: + objs = objs.intersection(filter_res) + return objs + +StaticLDAPObject('uid=testuser,ou=users,dc=example,dc=com', objectClass=['top', 'person'], givenName='Test User', mail='testuser@example.com', uid='testuser') +StaticLDAPObject('uid=testadmin,ou=users,dc=example,dc=com', objectClass=['top', 'person'], givenName='Test Admin', mail='testadmin@example.com', uid='testadmin') class LDAPHandler: - def __init__(self, conn): + def __init__(self, conn, models=[]): + self.models = models self.conn = conn self.user = None # None -> unbound, b'' -> anonymous bind, else -> self.user == dn self.buf = b'' @@ -33,8 +195,22 @@ class LDAPHandler: self.send_response(msgid, 'bindResponse', resp) def handle_search(self, msgid, op): + objs = [] + for model in self.models: + filter_obj = model._ldap_filter(*op['filter']) + objs += model.ldap_search(op['baseObject'].decode(), filter_obj) + for obj in objs: + resp = SearchResultEntry() + resp['object'] = obj.ldap_dn + for name, values in obj.ldap_attributes.items(): + attr = PartialAttribute() + attr['type'] = name + for value in values: + attr['vals'].append(value) + resp['attributes'].append(attr) + self.send_response(msgid, 'searchResEntry', resp) resp = SearchResultDone() - resp['resultCode'] = ResultCode('other') + resp['resultCode'] = ResultCode('success') resp['matchedDN'] = '' resp['diagnosticMessage'] = '' self.send_response(msgid, 'searchResDone', resp) @@ -90,7 +266,10 @@ class LDAPHandler: if not chunk: return None self.buf += chunk - req, self.buf = decoder.decode(self.buf, asn1Spec=LDAPMessage()) + print('received ', self.buf) + req = decode_message_fast(self.buf) + print(req) + self.buf = self.buf[size:] return req def send_response(self, message_id, response_type, response): @@ -105,30 +284,26 @@ class LDAPHandler: msg = self.recv_request() if msg is None: break - print('received', msg) - msgid = msg['messageID'] - message_type = msg.getComponentByName('protocolOp').getName() - op = msg['protocolOp'].getComponent() - if message_type == 'bindRequest': - self.handle_bind(msgid, op) - elif message_type == 'unbindRequest': + if msg['protocolOp'] == 0: # bindRequest + self.handle_bind(msg['messageID'], msg['payload']) + elif msg['protocolOp'] == 2: # unbindRequest break - elif message_type == 'searchRequest': - self.handle_search(msgid, op) - elif message_type == 'modifyRequest': - self.handle_modify(msgid, op) - elif message_type == 'addRequest': - self.handle_add(msgid, op) - elif message_type == 'delRequest': - self.handle_del(msgid, op) - elif message_type == 'modDNRequest': - self.handle_moddn(msgid, op) - elif message_type == 'compareRequest': - self.handle_compare(msgid, op) - elif message_type == 'abandonRequest': + elif msg['protocolOp'] == 3: # searchRequest + self.handle_search(msg['messageID'], msg['payload']) + elif msg['protocolOp'] == 6: # modifyRequest + self.handle_modify(msg['messageID'], msg['payload']) + elif msg['protocolOp'] == 8: # addRequest + self.handle_add(msg['messageID'], msg['payload']) + elif msg['protocolOp'] == 10: # delRequest + self.handle_del(msg['messageID'], msg['payload']) + elif msg['protocolOp'] == 12: # modDNRequest + self.handle_moddn(msg['messageID'], msg['payload']) + elif msg['protocolOp'] == 14: # compareRequest + self.handle_compare(msg['messageID'], msg['payload']) + elif msg['protocolOp'] == 16: # abandonRequest pass - elif message_type == 'extendedReq': - self.handle_extended(msgid, op) + elif msg['protocolOp'] == 23: # extendedReq + self.handle_extended(msg['messageID'], msg['payload']) else: raise Exception() self.conn.close() @@ -139,6 +314,6 @@ sock.listen(1) while True: conn, addr = sock.accept() print('accepted connection from', addr) - LDAPHandler(conn).run() + LDAPHandler(conn, models=[StaticLDAPObject]).run() print('connection closed')