From f062794f73bec944921bf0799ff4f6a6b16dab15 Mon Sep 17 00:00:00 2001
From: Julian Rother <julian@jrother.eu>
Date: Thu, 4 Mar 2021 01:30:39 +0100
Subject: [PATCH] decode_message_fast and partial filter support

---
 ldap_decode.py | 200 ++++++++++++++++++++++++++++++++++++++
 test.py        | 259 +++++++++++++++++++++++++++++++++++++++++--------
 2 files changed, 417 insertions(+), 42 deletions(-)
 create mode 100644 ldap_decode.py

diff --git a/ldap_decode.py b/ldap_decode.py
new file mode 100644
index 0000000..52b5edc
--- /dev/null
+++ b/ldap_decode.py
@@ -0,0 +1,200 @@
+# Decoding code based on python-ldap3
+
+def compute_ldap_message_size(data): # from ldap3
+	ret_value = -1
+	if len(data) > 2:
+		if data[1] <= 127:  # BER definite length - short form. Highest bit of byte 1 is 0, message length is in the last 7 bits - Value can be up to 127 bytes long
+			ret_value = data[1] + 2
+		else:  # BER definite length - long form. Highest bit of byte 1 is 1, last 7 bits counts the number of following octets containing the value length
+			bytes_length = data[1] - 128
+			if len(data) >= bytes_length + 2:
+				value_length = 0
+				cont = bytes_length
+				for byte in data[2:2 + bytes_length]:
+					cont -= 1
+					value_length += byte * (256 ** cont)
+				ret_value = value_length + 2 + bytes_length
+	return ret_value
+
+CLASSES = {(False, False): 0,  # Universal
+		   (False, True): 1,  # Application
+		   (True, False): 2,  # Context
+		   (True, True): 3}  # Private
+
+# a fast BER decoder for LDAP responses only
+def compute_ber_size(data):
+	"""
+	Compute size according to BER definite length rules
+	Returns size of value and value offset
+	"""
+
+	if data[1] <= 127:  # BER definite length - short form. Highest bit of byte 1 is 0, message length is in the last 7 bits - Value can be up to 127 bytes long
+		return data[1], 2
+	else:  # BER definite length - long form. Highest bit of byte 1 is 1, last 7 bits counts the number of following octets containing the value length
+		bytes_length = data[1] - 128
+		value_length = 0
+		cont = bytes_length
+		for byte in data[2: 2 + bytes_length]:
+			cont -= 1
+			value_length += byte * (256 ** cont)
+		return value_length, bytes_length + 2
+
+
+def decode_message_fast(message):
+	ber_len, ber_value_offset = compute_ber_size(get_bytes(message[:10]))  # get start of sequence, at maximum 3 bytes for length
+	decoded = decode_sequence(message, ber_value_offset, ber_len + ber_value_offset, LDAP_MESSAGE_CONTEXT)
+	return {
+		'messageID': decoded[0][3],
+		'protocolOp': decoded[1][2],
+		'payload': decoded[1][3],
+		'controls': decoded[2][3] if len(decoded) == 3 else None
+	}
+
+def decode_sequence(message, start, stop, context_decoders=None):
+	decoded = []
+	while start < stop:
+		octet = get_byte(message[start])
+		ber_class = CLASSES[(bool(octet & 0b10000000), bool(octet & 0b01000000))]
+		ber_constructed = bool(octet & 0b00100000)
+		ber_type = octet & 0b00011111
+		ber_decoder = DECODERS[(ber_class, octet & 0b00011111)] if ber_class < 2 else None
+		ber_len, ber_value_offset = compute_ber_size(get_bytes(message[start: start + 10]))
+		start += ber_value_offset
+		if ber_decoder:
+			value = ber_decoder(message, start, start + ber_len, context_decoders)  # call value decode function
+		else:
+			# try:
+			value = context_decoders[ber_type](message, start, start + ber_len)  # call value decode function for context class
+			# except KeyError:
+			#	 if ber_type == 3:  # Referral in result
+			#		 value = decode_sequence(message, start, start + ber_len)
+			#	 else:
+			#		 raise  # re-raise, should never happen
+		decoded.append((ber_class, ber_constructed, ber_type, value))
+		start += ber_len
+
+	return decoded
+
+
+def decode_integer(message, start, stop, context_decoders=None):
+	first = message[start]
+	value = -1 if get_byte(first) & 0x80 else 0
+	for octet in message[start: stop]:
+		value = value << 8 | get_byte(octet)
+
+	return value
+
+
+def decode_octet_string(message, start, stop, context_decoders=None):
+	return message[start: stop]
+
+
+def decode_boolean(message, start, stop, context_decoders=None):
+	return False if message[start: stop] == 0 else True
+
+def decode_search_request(message, start, stop, context_decoders=None):
+	decoded = decode_sequence(message, start, stop, SEARCH_FILTER_CONTEXT)
+	return {
+		'baseObject': decoded[0][3],
+		'scope': decoded[1][3],
+		'filter': decoded[6][3],
+		'attributes': [item[3] for item in decoded[7][3]],
+	}
+
+def decode_bind_request(message, start, stop, context_decoders=None):
+	decoded = decode_sequence(message, start, stop, BIND_REQUEST_CONTEXT)
+	return {
+		'version': decoded[0][3],
+		'name': decoded[1][3],
+		'authentication': decoded[2][3],
+	}
+
+def decode_filter_and(message, start, stop, context_decoders=None):
+	decoded = decode_sequence(message, start, stop, SEARCH_FILTER_CONTEXT)
+	return ('and',) + tuple(item[3] for item in decoded)
+
+def decode_filter_or(message, start, stop, context_decoders=None):
+	decoded = decode_sequence(message, start, stop, SEARCH_FILTER_CONTEXT)
+	return ('or',) + tuple(item[3] for item in decoded)
+
+def decode_filter_not(message, start, stop, context_decoders=None):
+	decoded = decode_sequence(message, start, stop, SEARCH_FILTER_CONTEXT)
+	return ('not', decoded[0][3])
+
+def decode_filter_equal(message, start, stop, context_decoders=None):
+	decoded = decode_sequence(message, start, stop, SEARCH_FILTER_CONTEXT)
+	return ('equal', decoded[0][3], decoded[1][3])
+
+def decode_filter_present(message, start, stop, context_decoders=None):
+	#decoded = decode_sequence(message, start, stop, SEARCH_FILTER_CONTEXT)
+	#return ('present', ecoded[0][3])
+	return ('present', decode_octet_string(message, start, stop))
+
+def decode_filter_unknown(message, start, stop, context_decoders=None):
+	return ('unknown',)
+
+def decode_controls(message, start, stop, context_decoders=None):
+	return decode_sequence(message, start, stop, CONTROLS_CONTEXT)
+
+######
+
+if str is not bytes:  # Python 3
+	def get_byte(x):
+		return x
+
+	def get_bytes(x):
+		return x
+else:  # Python 2
+	def get_byte(x):
+		return ord(x)
+
+	def get_bytes(x):
+		return bytearray(x)
+
+DECODERS = {
+	# Universal
+	(0, 1): decode_boolean,  # Boolean
+	(0, 2): decode_integer,  # Integer
+	(0, 4): decode_octet_string,  # Octet String
+	(0, 10): decode_integer,  # Enumerated
+	(0, 16): decode_sequence,  # Sequence
+	(0, 17): decode_sequence,  # Set
+	# Application
+	(1, 0): decode_bind_request,   # Bind
+	(1, 2): decode_sequence,       # Unbind
+	(1, 3): decode_search_request, # Search request
+	(1, 6): decode_octet_string,   # Modify request
+	(1, 8): decode_octet_string,   # Add request
+	(1, 10): decode_octet_string,  # Del request
+	(1, 12): decode_octet_string,  # ModDN request
+	(1, 14): decode_octet_string,  # Compare request
+	(1, 16): decode_sequence,      # Abandon
+	(1, 23): decode_octet_string,  # Extended Request
+}
+
+BIND_REQUEST_CONTEXT = {
+	0: decode_octet_string, # simple
+	3: decode_octet_string, # SaslCredentials
+}
+
+SEARCH_FILTER_CONTEXT = {
+	0: decode_filter_and,      # and
+	1: decode_filter_or,       # or
+	2: decode_filter_not,      # not
+	3: decode_filter_equal,    # equalityMatch
+	4: decode_filter_unknown,  # substrings
+	5: decode_filter_unknown,  # greaterOrEqual
+	6: decode_filter_unknown,  # lessOrEqual
+	7: decode_filter_present,  # present
+	8: decode_filter_unknown,  # approxMatch
+	9: decode_filter_unknown,  # extensibleMatch
+}
+
+LDAP_MESSAGE_CONTEXT = {
+	0: decode_controls,  # Controls
+	3: decode_sequence  # Referral
+}
+
+CONTROLS_CONTEXT = {
+	0: decode_sequence  # Control
+}
diff --git a/test.py b/test.py
index 72275b0..5584b39 100644
--- a/test.py
+++ b/test.py
@@ -1,25 +1,187 @@
 import socket
-from ldap3.protocol.rfc4511 import LDAPMessage, ResultCode, ProtocolOp, BindResponse, SearchResultDone, ModifyResponse, AddResponse, DelResponse, ModifyDNResponse, CompareResponse, ExtendedResponse
+from ldap3.protocol.rfc4511 import LDAPMessage, ResultCode, ProtocolOp, BindResponse, SearchResultDone, ModifyResponse, AddResponse, DelResponse, ModifyDNResponse, CompareResponse, ExtendedResponse, SearchResultEntry, PartialAttribute
 from ldap3.utils.asn1 import decoder, encode
+from ldap_decode import decode_message_fast, compute_ldap_message_size
 
-def compute_ldap_message_size(data): # from ldap3
-	ret_value = -1
-	if len(data) > 2:
-		if data[1] <= 127:  # BER definite length - short form. Highest bit of byte 1 is 0, message length is in the last 7 bits - Value can be up to 127 bytes long
-			ret_value = data[1] + 2
-		else:  # BER definite length - long form. Highest bit of byte 1 is 1, last 7 bits counts the number of following octets containing the value length
-			bytes_length = data[1] - 128
-			if len(data) >= bytes_length + 2:
-				value_length = 0
-				cont = bytes_length
-				for byte in data[2:2 + bytes_length]:
-					cont -= 1
-					value_length += byte * (256 ** cont)
-				ret_value = value_length + 2 + bytes_length
-	return ret_value
+class LDAPObject:
+	@classmethod
+	def _ldap_filter(cls, operator, *args):
+		if operator == 'and':
+			return cls._ldap_filter_and(*args)
+		elif operator == 'or':
+			return cls._ldap_filter_or(*args)
+		elif operator == 'not':
+			return cls._ldap_filter_not(*args)
+		elif operator == 'equal':
+			return cls._ldap_filter_equal(*args)
+		elif operator == 'present':
+			return cls._ldap_filter_present(*args)
+		else:
+			return None
+
+	@classmethod
+	def _ldap_filter_and(cls, *components):
+		results = []
+		for component in components:
+			subres = cls._ldap_filter(*component)
+			if subres is True:
+				continue
+			elif subres is False:
+				return False
+			elif subres is None:
+				return None
+			else:
+				results.append(subres)
+		return cls.ldap_filter_and(*results)
+
+	@classmethod
+	def _ldap_filter_or(cls, *components):
+		results = []
+		for component in components:
+			subres = cls._ldap_filter(*component)
+			if subres is True:
+				return True
+			elif subres is False:
+				continue
+			elif subres is None:
+				return None
+			else:
+				results.append(subres)
+		return cls.ldap_filter_or(*results)
+
+	@classmethod
+	def _ldap_filter_not(cls, subfilter):
+		subres = cls._ldap_filter(*subfilter)
+		if subres is True:
+			return False
+		elif subres is False:
+			return True
+		elif subres is None:
+			return None
+		else:
+			return cls.ldap_filter_not(subres)
+
+	@classmethod
+	def _ldap_filter_present(cls, name):
+		return cls.ldap_filter_present(name.decode().lower())
+
+	@classmethod
+	def _ldap_filter_equal(cls, name, value):
+		return cls.ldap_filter_equal(name.decode().lower(), bytes(value))
+
+	@classmethod
+	def ldap_filter_present(cls, name):
+		# return True if always true, False if always false, None if undefined, custom object else
+		return None
+
+	@classmethod
+	def ldap_filter_equal(cls, name, value):
+		return None
+
+	@classmethod
+	def ldap_filter_and(cls, *results):
+		return None
+
+	@classmethod
+	def ldap_filter_or(cls, *results):
+		return None
+		
+	@classmethod
+	def ldap_filter_not(cls, result):
+		return None
+
+	@classmethod
+	def ldap_search(cls, dn_base, filter_res):
+		return []
+
+def split_dn(dn):
+	if not dn:
+		return []
+	return dn.split(',')
+
+class StaticLDAPObject(LDAPObject):
+	present_map = {} # name -> set of objs
+	value_map = {} # (name, value) -> set of objs
+	all_objects = set() # set of all objs
+	dn_map = {tuple(): set()} # (dn part tuples) -> set of objs
+
+	def __init__(self, dn, **kwargs):
+		self.ldap_dn = dn
+		cls = type(self)
+		cls.dn_map[tuple()].add(self)
+		dn_path = []
+		for part in reversed(split_dn(dn)):
+			dn_path.insert(0, part)
+			key = tuple(dn_path)
+			cls.dn_map[key] = cls.dn_map.get(key, set())
+			cls.dn_map[key].add(self)
+		cls.all_objects.add(self)
+		self.ldap_attributes = {}
+		for name, _values in kwargs.items():
+			name = name.lower()
+			if not isinstance(_values, list):
+				_values = [_values]
+			values = []
+			for value in _values:
+				if value is None or value == '' or value == b'':
+					continue
+				if isinstance(value, int):
+					value = str(value)
+				if isinstance(value, str):
+					value = value.encode()
+				values.append(value)
+			if not values:
+				continue
+			self.ldap_attributes[name] = values
+			cls.present_map[name] = cls.present_map.get(name, set())
+			cls.present_map[name].add(self)
+			for value in values:
+				key = (name, value)
+				cls.value_map[key] = cls.value_map.get(key, set())
+				cls.value_map[key].add(self)
+
+	@classmethod
+	def ldap_filter_present(cls, name):
+		return cls.present_map.get(name, set())
+
+	@classmethod
+	def ldap_filter_equal(cls, name, value):
+		return cls.value_map.get((name, value), set())
+
+	@classmethod
+	def ldap_filter_and(cls, *results):
+		objs = results[0]
+		for subres in results[1:]:
+			objs = objs.intersection(subres)
+		return objs
+
+	@classmethod
+	def ldap_filter_or(cls, *results):
+		objs = results[0]
+		for subres in results[1:]:
+			objs = objs.union(subres)
+		return objs
+		
+	@classmethod
+	def ldap_filter_not(cls, result):
+		return cls.all_objects.difference(result)
+
+	@classmethod
+	def ldap_search(cls, dn_base, filter_res):
+		key = tuple(split_dn(dn_base))
+		objs = cls.dn_map.get(key, set())
+		if filter_res is None or filter_res is False:
+			return []
+		if filter_res is not True:
+			objs = objs.intersection(filter_res)
+		return objs
+
+StaticLDAPObject('uid=testuser,ou=users,dc=example,dc=com', objectClass=['top', 'person'], givenName='Test User', mail='testuser@example.com', uid='testuser')
+StaticLDAPObject('uid=testadmin,ou=users,dc=example,dc=com', objectClass=['top', 'person'], givenName='Test Admin', mail='testadmin@example.com', uid='testadmin')
 
 class LDAPHandler:
-	def __init__(self, conn):
+	def __init__(self, conn, models=[]):
+		self.models = models
 		self.conn = conn
 		self.user = None # None -> unbound, b'' -> anonymous bind, else -> self.user == dn
 		self.buf = b''
@@ -33,8 +195,22 @@ class LDAPHandler:
 		self.send_response(msgid, 'bindResponse', resp)
 
 	def handle_search(self, msgid, op):
+		objs = []
+		for model in self.models:
+			filter_obj = model._ldap_filter(*op['filter'])
+			objs += model.ldap_search(op['baseObject'].decode(), filter_obj)
+		for obj in objs:
+			resp = SearchResultEntry()
+			resp['object'] = obj.ldap_dn
+			for name, values in obj.ldap_attributes.items():
+				attr = PartialAttribute()
+				attr['type'] = name
+				for value in values:
+					attr['vals'].append(value)
+				resp['attributes'].append(attr)
+			self.send_response(msgid, 'searchResEntry', resp)
 		resp = SearchResultDone()
-		resp['resultCode'] = ResultCode('other')
+		resp['resultCode'] = ResultCode('success')
 		resp['matchedDN'] = ''
 		resp['diagnosticMessage'] = ''
 		self.send_response(msgid, 'searchResDone', resp)
@@ -90,7 +266,10 @@ class LDAPHandler:
 			if not chunk:
 				return None
 			self.buf += chunk
-		req, self.buf = decoder.decode(self.buf, asn1Spec=LDAPMessage())
+		print('received ', self.buf)
+		req = decode_message_fast(self.buf)
+		print(req)
+		self.buf = self.buf[size:]
 		return req
 
 	def send_response(self, message_id, response_type, response):
@@ -105,30 +284,26 @@ class LDAPHandler:
 			msg = self.recv_request()
 			if msg is None:
 				break
-			print('received', msg)
-			msgid = msg['messageID']
-			message_type = msg.getComponentByName('protocolOp').getName()
-			op = msg['protocolOp'].getComponent()
-			if message_type == 'bindRequest':
-				self.handle_bind(msgid, op)
-			elif message_type == 'unbindRequest':
+			if msg['protocolOp'] == 0: # bindRequest
+				self.handle_bind(msg['messageID'], msg['payload'])
+			elif msg['protocolOp'] == 2: # unbindRequest
 				break
-			elif message_type == 'searchRequest':
-				self.handle_search(msgid, op)
-			elif message_type == 'modifyRequest':
-				self.handle_modify(msgid, op)
-			elif message_type == 'addRequest':
-				self.handle_add(msgid, op)
-			elif message_type == 'delRequest':
-				self.handle_del(msgid, op)
-			elif message_type == 'modDNRequest':
-				self.handle_moddn(msgid, op)
-			elif message_type == 'compareRequest':
-				self.handle_compare(msgid, op)
-			elif message_type == 'abandonRequest':
+			elif msg['protocolOp'] == 3: # searchRequest
+				self.handle_search(msg['messageID'], msg['payload'])
+			elif msg['protocolOp'] == 6: # modifyRequest
+				self.handle_modify(msg['messageID'], msg['payload'])
+			elif msg['protocolOp'] == 8: # addRequest
+				self.handle_add(msg['messageID'], msg['payload'])
+			elif msg['protocolOp'] == 10: # delRequest
+				self.handle_del(msg['messageID'], msg['payload'])
+			elif msg['protocolOp'] == 12: # modDNRequest
+				self.handle_moddn(msg['messageID'], msg['payload'])
+			elif msg['protocolOp'] == 14: # compareRequest
+				self.handle_compare(msg['messageID'], msg['payload'])
+			elif msg['protocolOp'] == 16: # abandonRequest
 				pass
-			elif message_type == 'extendedReq':
-				self.handle_extended(msgid, op)
+			elif msg['protocolOp'] == 23: # extendedReq
+				self.handle_extended(msg['messageID'], msg['payload'])
 			else:
 				raise Exception()
 		self.conn.close()
@@ -139,6 +314,6 @@ sock.listen(1)
 while True:
 	conn, addr = sock.accept()
 	print('accepted connection from', addr)
-	LDAPHandler(conn).run()
+	LDAPHandler(conn, models=[StaticLDAPObject]).run()
 	print('connection closed')
 
-- 
GitLab