import socket
from ldap3.protocol.rfc4511 import LDAPMessage, ResultCode, ProtocolOp, BindResponse, SearchResultDone, ModifyResponse, AddResponse, DelResponse, ModifyDNResponse, CompareResponse, ExtendedResponse, SearchResultEntry, PartialAttribute
from ldap3.utils.asn1 import decoder, encode
from ldap_decode import decode_message_fast, compute_ldap_message_size

class LDAPObject:
	"""Base model that turns a decoded LDAP search filter into a model-specific result.

	A filter is passed around as an operator name followed by its operands,
	e.g. ``('and', ('equal', b'uid', b'x'), ('present', b'mail'))``.
	Every translation step yields one of four things:

	* ``True``  - the filter matches every object,
	* ``False`` - the filter matches no object,
	* ``None``  - the filter is undefined / not supported,
	* anything else - a custom result object produced by the ``ldap_filter_*``
	  hooks, later handed to ``ldap_search``.

	Subclasses override the public ``ldap_filter_*`` hooks and ``ldap_search``;
	the ``_ldap_filter*`` methods implement the constant folding around them.
	"""

	@classmethod
	def _ldap_filter(cls, operator, *args):
		"""Dispatch on the operator name; unknown operators are undefined (None)."""
		if operator == 'and':
			return cls._ldap_filter_and(*args)
		elif operator == 'or':
			return cls._ldap_filter_or(*args)
		elif operator == 'not':
			return cls._ldap_filter_not(*args)
		elif operator == 'equal':
			return cls._ldap_filter_equal(*args)
		elif operator == 'present':
			return cls._ldap_filter_present(*args)
		else:
			return None

	@classmethod
	def _ldap_filter_and(cls, *components):
		"""Combine sub-filters with AND.

		False if any component is False (regardless of position), otherwise
		None if any component is undefined, otherwise True when nothing but
		constant-true components remain (this includes the empty AND),
		otherwise whatever ``ldap_filter_and`` makes of the custom results.
		"""
		results = []
		undefined = False
		for component in components:
			subres = cls._ldap_filter(*component)
			if subres is False:
				# One false operand decides the whole conjunction.
				return False
			if subres is None:
				undefined = True
			elif subres is not True:
				results.append(subres)
		if undefined:
			return None
		if not results:
			# Only always-true operands (or none at all): never call the hook with no results.
			return True
		return cls.ldap_filter_and(*results)

	@classmethod
	def _ldap_filter_or(cls, *components):
		"""Combine sub-filters with OR.

		True if any component is True (regardless of position), otherwise
		None if any component is undefined, otherwise False when nothing but
		constant-false components remain (this includes the empty OR),
		otherwise whatever ``ldap_filter_or`` makes of the custom results.
		"""
		results = []
		undefined = False
		for component in components:
			subres = cls._ldap_filter(*component)
			if subres is True:
				# One true operand decides the whole disjunction.
				return True
			if subres is None:
				undefined = True
			elif subres is not False:
				results.append(subres)
		if undefined:
			return None
		if not results:
			# Only always-false operands (or none at all): never call the hook with no results.
			return False
		return cls.ldap_filter_or(*results)

	@classmethod
	def _ldap_filter_not(cls, subfilter):
		"""Negate a sub-filter; constants are flipped, undefined stays undefined."""
		subres = cls._ldap_filter(*subfilter)
		if subres is True:
			return False
		elif subres is False:
			return True
		elif subres is None:
			return None
		else:
			return cls.ldap_filter_not(subres)

	@classmethod
	def _ldap_filter_present(cls, name):
		"""Decode and lower-case the attribute name, then ask the ``ldap_filter_present`` hook."""
		return cls.ldap_filter_present(name.decode().lower())

	@classmethod
	def _ldap_filter_equal(cls, name, value):
		"""Decode and lower-case the attribute name, coerce the value to bytes, then ask the hook."""
		return cls.ldap_filter_equal(name.decode().lower(), bytes(value))

	@classmethod
	def ldap_filter_present(cls, name):
		"""Hook: result for "attribute ``name`` is present"; the base model reports undefined."""
		# return True if always true, False if always false, None if undefined, custom object else
		return None

	@classmethod
	def ldap_filter_equal(cls, name, value):
		"""Hook: result for "attribute ``name`` equals ``value``" (``value`` is bytes)."""
		return None

	@classmethod
	def ldap_filter_and(cls, *results):
		"""Hook: combine one or more custom results with AND."""
		return None

	@classmethod
	def ldap_filter_or(cls, *results):
		"""Hook: combine one or more custom results with OR."""
		return None

	@classmethod
	def ldap_filter_not(cls, result):
		"""Hook: negate a custom result."""
		return None

	@classmethod
	def ldap_search(cls, dn_base, filter_res):
		"""Hook: return the objects below ``dn_base`` selected by ``filter_res``; none by default."""
		return []

def split_dn(dn):
	"""Split a DN string into its comma-separated components.

	Returns an empty list for an empty or missing DN. A comma preceded by a
	backslash escape (``cn=Smith\\, John``) belongs to the attribute value
	and does not start a new component; escape sequences are kept verbatim
	in the returned parts. For DNs without backslashes the result equals
	``dn.split(',')``.
	"""
	if not dn:
		return []
	parts = []
	current = []
	escaped = False
	for char in dn:
		if escaped:
			# Character following a backslash is taken literally, even a comma.
			current.append(char)
			escaped = False
		elif char == '\\':
			current.append(char)
			escaped = True
		elif char == ',':
			parts.append(''.join(current))
			current = []
		else:
			current.append(char)
	parts.append(''.join(current))
	return parts

class StaticLDAPObject(LDAPObject):
	"""In-memory directory entry that indexes itself on construction.

	Filter results for this model are plain sets of entries. The index
	containers below are class attributes, so every instance (and every
	subclass that does not define its own containers) registers in the same
	shared indexes.
	"""
	present_map = {} # name -> set of objs
	value_map = {} # (name, value) -> set of objs
	all_objects = set() # set of all objs
	dn_map = {tuple(): set()} # (dn part tuples) -> set of objs

	def __init__(self, dn, **kwargs):
		"""Create an entry and register it in the class-level indexes.

		dn: DN string of the entry; the entry is indexed under every suffix
			of the DN (including the empty one) so it is found by a search
			based at any ancestor.
		kwargs: attribute name -> value or list of values. Names are
			lower-cased; ints are stringified, strings are encoded to bytes;
			None and empty values are dropped, and attributes left without
			values are omitted entirely.
		"""
		self.ldap_dn = dn
		cls = type(self)
		# setdefault keeps working when a subclass supplies its own, initially empty, dn_map.
		cls.dn_map.setdefault(tuple(), set()).add(self)
		dn_parts = split_dn(dn)
		for start in range(len(dn_parts)):
			# Every suffix of the DN is a possible search base for this entry.
			cls.dn_map.setdefault(tuple(dn_parts[start:]), set()).add(self)
		cls.all_objects.add(self)
		self.ldap_attributes = {}
		for name, _values in kwargs.items():
			name = name.lower()
			if not isinstance(_values, list):
				_values = [_values]
			values = []
			for value in _values:
				if value is None or value == '' or value == b'':
					continue
				if isinstance(value, int):
					value = str(value)
				if isinstance(value, str):
					value = value.encode()
				values.append(value)
			if not values:
				continue
			self.ldap_attributes[name] = values
			cls.present_map.setdefault(name, set()).add(self)
			for value in values:
				cls.value_map.setdefault((name, value), set()).add(self)

	@classmethod
	def ldap_filter_present(cls, name):
		"""Return the set of entries that have attribute ``name`` (empty set if none)."""
		return cls.present_map.get(name, set())

	@classmethod
	def ldap_filter_equal(cls, name, value):
		"""Return the set of entries whose attribute ``name`` contains the bytes ``value``."""
		return cls.value_map.get((name, value), set())

	@classmethod
	def ldap_filter_and(cls, *results):
		"""Intersect result sets; with no operands every entry matches.

		Always returns a new set, so callers never get hold of an index set.
		"""
		if not results:
			return set(cls.all_objects)
		return set(results[0]).intersection(*results[1:])

	@classmethod
	def ldap_filter_or(cls, *results):
		"""Unite result sets; with no operands no entry matches. Returns a new set."""
		return set().union(*results)
		
	@classmethod
	def ldap_filter_not(cls, result):
		"""Return all registered entries that are not in ``result``."""
		return cls.all_objects.difference(result)

	@classmethod
	def ldap_search(cls, dn_base, filter_res):
		"""Return the entries at or below ``dn_base`` selected by ``filter_res``.

		filter_res: True selects every entry under the base, None/False select
			nothing (an empty list is returned), any other value is treated
			as a set of candidate entries.
		"""
		key = tuple(split_dn(dn_base))
		objs = cls.dn_map.get(key, set())
		if filter_res is None or filter_res is False:
			return []
		if filter_res is True:
			# Copy so the caller cannot mutate the DN index through the result.
			return set(objs)
		return objs.intersection(filter_res)

# Sample directory content: constructing an entry registers it in the
# StaticLDAPObject indexes, so the instances need not be kept in variables.
StaticLDAPObject(
	'uid=testuser,ou=users,dc=example,dc=com',
	objectClass=['top', 'person'],
	givenName='Test User',
	mail='testuser@example.com',
	uid='testuser',
)
StaticLDAPObject(
	'uid=testadmin,ou=users,dc=example,dc=com',
	objectClass=['top', 'person'],
	givenName='Test Admin',
	mail='testadmin@example.com',
	uid='testadmin',
)

class LDAPHandler:
	"""Serves the LDAP requests arriving on one client connection.

	Requests are read from ``conn``, decoded with ``decode_message_fast`` and
	dispatched on their protocolOp number. Only bind and search are
	implemented; every other operation is answered with a fixed result code.
	"""

	# Maximum number of bytes requested from the socket per recv() call.
	# Message reassembly in recv_request() works for any positive value.
	recv_size = 4096

	def __init__(self, conn, models=None):
		"""
		conn: connected socket-like object (needs recv, sendall and close).
		models: list of model classes that are asked for search results;
			defaults to a fresh empty list per handler.
		"""
		# A mutable default in the signature would be shared by all handlers.
		self.models = [] if models is None else models
		self.conn = conn
		self.user = None # None -> unbound, b'' -> anonymous bind, else -> self.user == dn
		self.buf = b''

	def _send_result(self, msgid, response_type, response, result_code, matched_dn=''):
		"""Fill the common result fields of ``response`` and send it."""
		response['resultCode'] = ResultCode(result_code)
		response['matchedDN'] = matched_dn
		response['diagnosticMessage'] = ''
		self.send_response(msgid, response_type, response)

	def handle_bind(self, msgid, op):
		"""Record the bind name and report success.

		NOTE(review): the credentials in ``op`` are never inspected, so every
		bind succeeds - confirm this is intended outside of testing.
		"""
		self.user = op['name']
		self._send_result(msgid, 'bindResponse', BindResponse(), 'success', matched_dn=op['name'])

	def handle_search(self, msgid, op):
		"""Send one searchResEntry per object found by the models, then searchResDone."""
		objs = []
		for model in self.models:
			filter_obj = model._ldap_filter(*op['filter'])
			objs += model.ldap_search(op['baseObject'].decode(), filter_obj)
		for obj in objs:
			resp = SearchResultEntry()
			resp['object'] = obj.ldap_dn
			for name, values in obj.ldap_attributes.items():
				attr = PartialAttribute()
				attr['type'] = name
				for value in values:
					attr['vals'].append(value)
				resp['attributes'].append(attr)
			self.send_response(msgid, 'searchResEntry', resp)
		self._send_result(msgid, 'searchResDone', SearchResultDone(), 'success')

	def handle_modify(self, msgid, op):
		"""Modify is not implemented; answer with result code 'other'."""
		self._send_result(msgid, 'modifyResponse', ModifyResponse(), 'other')

	def handle_add(self, msgid, op):
		"""Add is not implemented; answer with result code 'other'."""
		self._send_result(msgid, 'addResponse', AddResponse(), 'other')

	def handle_del(self, msgid, op):
		"""Delete is not implemented; answer with result code 'other'."""
		self._send_result(msgid, 'delResponse', DelResponse(), 'other')

	def handle_moddn(self, msgid, op):
		"""ModifyDN is not implemented; answer with result code 'other'."""
		self._send_result(msgid, 'modDNResponse', ModifyDNResponse(), 'other')

	def handle_compare(self, msgid, op):
		"""Compare is not implemented; answer with result code 'other'."""
		self._send_result(msgid, 'compareResponse', CompareResponse(), 'other')

	def handle_extended(self, msgid, op):
		"""No extended operation is supported; answer with 'protocolError'."""
		self._send_result(msgid, 'extendedResp', ExtendedResponse(), 'protocolError')

	def recv_request(self):
		"""Read and decode the next complete LDAP message.

		Returns the decoded message, or None when the peer closed the
		connection before a complete message arrived. Bytes received beyond
		the current message stay in ``self.buf`` for the next call.
		"""
		while True:
			size = compute_ldap_message_size(self.buf)
			# -1 means the buffer is still too short to tell the message length.
			if size != -1 and len(self.buf) >= size:
				break
			chunk = self.conn.recv(self.recv_size)
			if not chunk:
				return None
			self.buf += chunk
		# Hand exactly one message to the decoder, even if more data is buffered.
		message = self.buf[:size]
		print('received ', message)
		req = decode_message_fast(message)
		print(req)
		self.buf = self.buf[size:]
		return req

	def send_response(self, message_id, response_type, response):
		"""Wrap ``response`` in an LDAPMessage with ``message_id`` and send it encoded."""
		msg = LDAPMessage()
		msg['messageID'] = message_id
		msg['protocolOp'] = ProtocolOp().setComponentByName(response_type, response)
		print('sent', msg)
		self.conn.sendall(encode(msg))

	def run(self):
		"""Serve requests until unbind, disconnect or an error; always close the connection.

		Raises Exception for a protocolOp number that is not handled.
		"""
		handlers = {
			0: self.handle_bind, # bindRequest
			3: self.handle_search, # searchRequest
			6: self.handle_modify, # modifyRequest
			8: self.handle_add, # addRequest
			10: self.handle_del, # delRequest
			12: self.handle_moddn, # modDNRequest
			14: self.handle_compare, # compareRequest
			23: self.handle_extended, # extendedReq
		}
		try:
			while True:
				msg = self.recv_request()
				if msg is None:
					break
				protocol_op = msg['protocolOp']
				if protocol_op == 2: # unbindRequest
					break
				if protocol_op == 16: # abandonRequest
					# Requests are served one at a time, nothing to abandon; no response is sent.
					continue
				handler = handlers.get(protocol_op)
				if handler is None:
					raise Exception('unsupported protocolOp: %r' % (protocol_op,))
				handler(msg['messageID'], msg['payload'])
		finally:
			# Release the socket on every exit path, including handler errors.
			self.conn.close()

# Accept loop: listens on localhost port 1337 and serves one client at a time.
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
try:
	sock.bind(('127.0.0.1', 1337))
	sock.listen(1)
	while True:
		conn, addr = sock.accept()
		print('accepted connection from', addr)
		try:
			LDAPHandler(conn, models=[StaticLDAPObject]).run()
		except Exception as exc:
			# A failure while serving one client must not end the accept loop.
			print('connection failed:', repr(exc))
		finally:
			# Closing again is harmless if the handler already closed the socket.
			conn.close()
		print('connection closed')
finally:
	# Release the listening socket when the loop is left (e.g. on Ctrl-C).
	sock.close()