import traceback
from socketserver import ForkingTCPServer, BaseRequestHandler

from ldap import LDAPMessage, ShallowLDAPMessage, BindRequest, BindResponse, SearchRequest, SearchResultEntry, PartialAttribute, SearchResultDone, UnbindRequest, LDAPResultCode, IncompleteBERError, SimpleAuthentication, ModifyRequest, ModifyResponse, AddRequest, AddResponse, DelRequest, DelResponse, ModifyDNRequest, ModifyDNResponse, CompareRequest, CompareResponse, AbandonRequest, ExtendedRequest, ExtendedResponse

class Handler(BaseRequestHandler):
	"""Per-connection LDAP request handler.

	Reads BER-encoded LDAP messages from the client socket, dispatches
	bind / search / unbind requests to the callbacks registered on
	``ldap_server`` and answers every other known operation with an
	error response.
	"""

	# Object providing the ``bind_handlers`` and ``search_handlers`` lists
	# consulted by handle_bind() / handle_search(). None on the base class;
	# a subclass is expected to set it before requests are served.
	ldap_server = None

	# Upper bound on the number of bytes requested from the socket per
	# recv() call while waiting for a complete message.
	recv_size = 4096

	def setup(self):
		"""Initialise per-connection state before handle() runs."""
		# DN of the currently bound identity; empty until a bind succeeds.
		self.bind_dn = b''
		# Cleared by handle_unbind() to end the receive loop in handle().
		self.keep_running = True

	def handle_bind(self, req):
		"""Answer a bind request.

		Only simple authentication is supported. Each registered bind
		handler is called as ``func(name, password, self)``; the first one
		returning a truthy value makes the bind succeed and records the
		name in ``self.bind_dn``. If none accepts, invalidCredentials is
		sent.
		"""
		if not isinstance(req.protocolOp.authentication, SimpleAuthentication):
			self.send_msg(LDAPMessage(req.messageID, BindResponse(LDAPResultCode.authMethodNotSupported)))
			# The request has been rejected; do not fall through to the
			# password check, which would emit a second response.
			return
		name = req.protocolOp.name
		password = req.protocolOp.authentication.password
		for func in self.ldap_server.bind_handlers:
			if func(name, password, self):
				self.bind_dn = name
				self.send_msg(LDAPMessage(req.messageID, BindResponse(LDAPResultCode.success)))
				return
		self.send_msg(LDAPMessage(req.messageID, BindResponse(LDAPResultCode.invalidCredentials)))

	def handle_search(self, req):
		"""Answer a search request.

		Every registered search handler is called as
		``func(baseObject, scope, filter, self)`` and must return an
		iterable of ``(dn, attributes)`` pairs, where ``attributes`` is a
		mapping of attribute name to values. All handlers are run before
		anything is sent; then one SearchResultEntry per pair is sent,
		followed by a SearchResultDone with result code success.
		"""
		search = req.protocolOp
		entries = []
		for func in self.ldap_server.search_handlers:
			entries += func(search.baseObject, search.scope, search.filter, self)
		for dn, attributes in entries:
			attributes = [PartialAttribute(name, values) for name, values in attributes.items()]
			self.send_msg(LDAPMessage(req.messageID, SearchResultEntry(dn, attributes)))
		self.send_msg(LDAPMessage(req.messageID, SearchResultDone(LDAPResultCode.success)))

	def handle_unbind(self, req):
		"""Handle an unbind request by stopping the receive loop."""
		self.keep_running = False

	def handle_message(self, data):
		"""Decode and dispatch the message at the front of ``data``.

		Returns the bytes that follow the decoded message. Exceptions
		raised by the initial ``ShallowLDAPMessage.from_ber`` call are not
		caught here; handle() treats an IncompleteBERError from this
		method as "more bytes are needed".
		"""
		# Request type -> (handler or None, response class used for error
		# replies or None). A None handler means the operation is not
		# implemented; a None response class means no reply is sent.
		handlers = {
			BindRequest: (self.handle_bind, BindResponse),
			UnbindRequest: (self.handle_unbind, None),
			SearchRequest: (self.handle_search, SearchResultDone),
			ModifyRequest: (None, ModifyResponse),
			AddRequest: (None,  AddResponse),
			DelRequest: (None, DelResponse),
			ModifyDNRequest: (None, ModifyDNResponse),
			CompareRequest: (None, CompareResponse),
			AbandonRequest: (None, None),
			ExtendedRequest: (None, ExtendedResponse), # TODO
		}
		# NOTE(review): the shallow parse presumably decodes only the
		# message ID and operation type, so that an error response can be
		# built even when the full decode below fails - confirm in ldap.
		shallowmsg, rest = ShallowLDAPMessage.from_ber(data)
		if shallowmsg.protocolOp is None:
			print('Ignoring unknown message')
			return rest
		func, errfunc = handlers[shallowmsg.protocolOp]
		try:
			msg, _ = LDAPMessage.from_ber(data)
		except Exception:
			# The message could not be fully decoded: report a protocol
			# error (when the operation has a response type) and skip it.
			if errfunc:
				self.send_msg(LDAPMessage(shallowmsg.messageID, errfunc(LDAPResultCode.protocolError)))
			traceback.print_exc()
			return rest
		print('received', msg)
		try:
			if func:
				func(msg)
			elif errfunc:
				# Known operation without an implementation.
				self.send_msg(LDAPMessage(shallowmsg.messageID, errfunc(LDAPResultCode.insufficientAccessRights)))
		except Exception:
			# A handler failed; tell the client and keep the connection.
			if errfunc:
				self.send_msg(LDAPMessage(shallowmsg.messageID, errfunc(LDAPResultCode.other)))
			traceback.print_exc()
		return rest

	def send_msg(self, msg):
		"""Log ``msg`` and send its BER encoding to the client."""
		print('sending', msg)
		self.request.sendall(LDAPMessage.to_ber(msg))

	def handle(self):
		"""Process messages until unbind or until the client disconnects."""
		data = b''
		while self.keep_running:
			try:
				data = self.handle_message(data)
			except IncompleteBERError:
				# The buffer does not hold a complete message yet: read
				# more. An empty read means the peer closed the socket.
				chunk = self.request.recv(self.recv_size)
				if not chunk:
					return
				data += chunk

class Server:
	"""Registry of bind/search callbacks plus a blocking TCP front end."""

	def __init__(self):
		"""Create a server with no registered handlers."""
		# Callables invoked as func(name, password, handler); a truthy
		# return value accepts the bind.
		self.bind_handlers = []
		# Callables invoked as func(baseObject, scope, filter, handler)
		# returning an iterable of (dn, attributes) pairs.
		self.search_handlers = []

	def bind_handler(self, func):
		"""Register ``func`` as a bind handler.

		Returns ``func`` unchanged so the method can also be used as a
		decorator without rebinding the decorated name to None.
		"""
		self.bind_handlers.append(func)
		return func

	def search_handler(self, func):
		"""Register ``func`` as a search handler.

		Returns ``func`` unchanged so the method can also be used as a
		decorator without rebinding the decorated name to None.
		"""
		self.search_handlers.append(func)
		return func

	def run(self, host='127.0.0.1', port=1337):
		"""Serve LDAP requests on ``(host, port)`` until interrupted.

		This call blocks. Connections are handled by a Handler subclass
		whose ``ldap_server`` attribute is this instance.
		"""
		class BoundHandler(Handler):
			ldap_server = self
		# The context manager closes the listening socket when
		# serve_forever() exits, including on an exception.
		with ForkingTCPServer((host, port), BoundHandler) as server:
			server.serve_forever()