import sys
import json
import hmac
import socketserver

import requests
from cachecontrol import CacheControl
from cachecontrol.heuristics import ExpiresAfter

from ldapserver import SimpleLDAPRequestHandler
from ldapserver.dn import DN
from ldapserver.ldap import FilterEqual, FilterAnd
from ldapserver.directory import BaseDirectory, SimpleFilterMixin, StaticDirectory, eval_ldap_filter
from ldapserver.util import encode_attribute, CaseInsensitiveDict
from ldapserver.exceptions import LDAPInvalidCredentials
from ldapserver.schema import RFC2307BIS_SUBSCHEMA

class UffdAPI:
	"""Minimal client for the uffd HTTP API.

	Every request carries the API key as a bearer token. The session is
	wrapped with CacheControl using an ``ExpiresAfter(seconds=cache_ttl)``
	heuristic, so cacheable responses are reused for up to ``cache_ttl``
	seconds.
	"""

	def __init__(self, baseurl, key, cache_ttl=60):
		self.baseurl = baseurl
		self.key = key
		self.session = CacheControl(requests.Session(), heuristic=ExpiresAfter(seconds=cache_ttl))
		self.session.headers['Authorization'] = 'Bearer '+self.key

	def _url(self, endpoint):
		"""Join the base URL and *endpoint*, tolerating a trailing slash on the base."""
		return self.baseurl.rstrip('/') + '/' + endpoint

	def get(self, endpoint, **kwargs):
		"""GET *endpoint* with *kwargs* as query parameters and return the decoded JSON.

		Raises ``requests.HTTPError`` on a 4xx/5xx response.
		"""
		resp = self.session.get(self._url(endpoint), params=kwargs)
		# Not an assert: asserts are stripped under ``python -O``, which would
		# silently let error responses through.
		resp.raise_for_status()
		return resp.json()

	def post(self, endpoint, **kwargs):
		"""POST *kwargs* form-encoded to *endpoint* and return the decoded JSON.

		Raises ``requests.HTTPError`` on a 4xx/5xx response.
		"""
		resp = self.session.post(self._url(endpoint), data=kwargs)
		resp.raise_for_status()
		return resp.json()

class UserDirectory(SimpleFilterMixin, BaseDirectory):
	"""Read-only LDAP view of uffd users below ``ou=users,<dn_base>``.

	Entries are built from the uffd ``getusers`` endpoint. When the search
	filter allows it, one query parameter is forwarded to the API to narrow
	the result set. Every returned entry is still checked locally against the
	full filter with ``eval_ldap_filter``.
	"""

	# Lower-cased LDAP attribute -> ``getusers`` query parameter, for filters
	# that are a plain equality match on that attribute.
	_EQUAL_API_PARAMS = {
		'uid': 'loginname',
		'uidnumber': 'id',
		'mail': 'email',
	}
	# Preference order when an AND filter yields several usable parameters.
	_API_PARAM_PRIORITY = ('loginname', 'id', 'email', 'group')

	def __init__(self, api, dn_base):
		self.api = api
		self.rdn_attr = 'uid'
		self.dn_base = DN('ou=users') + DN(dn_base)
		self.group_dn_base = DN('ou=groups') + DN(dn_base)
		self.structuralobjectclass = b'inetorgperson'
		self.objectclasses = [b'top', b'inetorgperson', b'organizationalperson', b'person', b'posixaccount']
		# Lower-case names of every attribute generate_result() produces.
		self.attributes = ['structuralobjectclass', 'objectclass', 'cn', 'displayname', 'givenname', 'homedirectory', 'mail', 'sn', 'uid', 'uidnumber', 'memberof']

	def generate_result(self, user):
		"""Map one user object from the API to a ``(dn, attributes)`` pair.

		Reads the keys ``displayname``, ``loginname``, ``email``, ``id`` and
		``groups`` from *user*. Each group is rendered as a ``cn=`` entry below
		``ou=groups``.
		"""
		attributes = CaseInsensitiveDict(
			structuralObjectClass=[self.structuralobjectclass],
			objectClass=self.objectclasses,
			cn=[encode_attribute(user['displayname'])],
			displayname=[encode_attribute(user['displayname'])],
			# givenname reuses the display name and sn is a one-space placeholder
			# (presumably because the API exposes no separate name parts).
			givenname=[encode_attribute(user['displayname'])],
			homeDirectory=[encode_attribute('/home/'+user['loginname'])],
			mail=[encode_attribute(user['email'])],
			sn=[encode_attribute(' ')],
			uid=[encode_attribute(user['loginname'])],
			uidNumber=[encode_attribute(user['id'])],
			memberOf=[encode_attribute(DN(cn=group) + self.group_dn_base) for group in user['groups']],
		)
		dn = str(DN(uid=user['loginname']) + self.dn_base)
		return dn, attributes

	def get_best_api_param(self, expr):
		"""Pick one ``getusers`` query parameter that narrows a search for *expr*.

		Returns ``(param_name, value)``, or ``(None, None)`` if nothing in
		*expr* can be translated. Only equality matches (possibly nested in
		AND filters) are considered. A ``memberOf`` value is only used if it
		names a single-``cn`` entry directly below ``ou=groups``.
		"""
		if isinstance(expr, FilterEqual):
			attribute = expr.attribute.lower()
			if attribute in self._EQUAL_API_PARAMS:
				return self._EQUAL_API_PARAMS[attribute], expr.value
			if attribute == 'memberof':
				group_dn = DN.from_str(expr.value.decode())
				# LDAP attribute types are case-insensitive, so accept "CN=..." too.
				if (group_dn.is_direct_child_of(self.group_dn_base) and len(group_dn[0]) == 1
						and group_dn[0][0].attribute.lower() == 'cn'):
					return 'group', group_dn[0][0].value
		if isinstance(expr, FilterAnd):
			params = dict(self.get_best_api_param(subexpr) for subexpr in expr.filters)
			for key in self._API_PARAM_PRIORITY:
				if key in params:
					return key, params[key]
		return None, None

	def search_fetch(self, expr):
		"""Yield ``(dn, attributes)`` for every user matching *expr*.

		*expr* is ``False`` when the filter can never match (for example a
		presence test on an unknown attribute, see filter_present). In that
		case the API is not queried at all.
		"""
		if expr is False:
			return
		kwargs = {}
		key, value = self.get_best_api_param(expr)
		if key is not None:
			kwargs[key] = value
		for user in self.api.get('getusers', **kwargs):
			dn, obj = self.generate_result(user)
			if eval_ldap_filter(obj, expr):
				yield dn, obj

	def filter_equal(self, attribute, value):
		"""Build an equality filter, normalizing ``memberOf`` DN values.

		DN values are normalized so they compare equal to the DNs built by
		generate_result().
		"""
		if attribute.lower() == 'memberof':
			value = str(DN.from_str(value.decode())).encode()
		return super().filter_equal(attribute, value)

	def filter_present(self, attribute):
		"""Build a presence filter, or ``False`` for attributes this directory never emits."""
		if attribute.lower() not in self.attributes:
			return False
		return super().filter_present(attribute)

class GroupDirectory(SimpleFilterMixin, BaseDirectory):
	"""Read-only LDAP view of uffd groups below ``ou=groups,<dn_base>``.

	Entries are built from the uffd ``getgroups`` endpoint. Members are
	referenced by their ``uid=`` DN below ``ou=users,<dn_base>``. When the
	search filter allows it, one query parameter is forwarded to the API.
	Results are always checked locally with ``eval_ldap_filter``.
	"""

	# Lower-cased LDAP attribute -> ``getgroups`` query parameter, for filters
	# that are a plain equality match on that attribute.
	_EQUAL_API_PARAMS = {
		'cn': 'name',
		'gidnumber': 'id',
	}
	# Preference order when an AND filter yields several usable parameters.
	_API_PARAM_PRIORITY = ('name', 'id', 'member')

	def __init__(self, api, dn_base):
		self.api = api
		self.rdn_attr = 'cn'
		self.dn_base = DN('ou=groups') + DN(dn_base)
		self.user_dn_base = DN('ou=users') + DN(dn_base)
		self.structuralobjectclass = b'groupOfUniqueNames'
		self.objectclasses = [b'top', b'groupOfUniqueNames', b'posixGroup']
		# Lower-case names of every attribute generate_result() produces.
		self.attributes = ['structuralobjectclass', 'objectclass', 'cn', 'description', 'gidnumber', 'uniquemember']

	def generate_result(self, group):
		"""Map one group object from the API to a ``(dn, attributes)`` pair.

		Reads the keys ``name``, ``id`` and ``members`` from *group*. Each
		member is rendered as a ``uid=`` entry below ``ou=users``.
		"""
		attributes = CaseInsensitiveDict(
			structuralObjectClass=[self.structuralobjectclass],
			objectClass=self.objectclasses,
			cn=[encode_attribute(group['name'])],
			# Placeholder value; the API data used here has no description.
			description=[encode_attribute(' ')],
			gidNumber=[encode_attribute(group['id'])],
			uniqueMember=[encode_attribute(DN(uid=user) + self.user_dn_base) for user in group['members']],
		)
		dn = str(DN(cn=group['name']) + self.dn_base)
		return dn, attributes

	def get_best_api_param(self, expr):
		"""Pick one ``getgroups`` query parameter that narrows a search for *expr*.

		Returns ``(param_name, value)``, or ``(None, None)`` if nothing in
		*expr* can be translated. A ``uniqueMember`` value is only used if it
		names a single-``uid`` entry directly below ``ou=users``.
		"""
		if isinstance(expr, FilterEqual):
			attribute = expr.attribute.lower()
			if attribute in self._EQUAL_API_PARAMS:
				return self._EQUAL_API_PARAMS[attribute], expr.value
			if attribute == 'uniquemember':
				user_dn = DN.from_str(expr.value.decode())
				# LDAP attribute types are case-insensitive, so accept "UID=..." too.
				if (user_dn.is_direct_child_of(self.user_dn_base) and len(user_dn[0]) == 1
						and user_dn[0][0].attribute.lower() == 'uid'):
					return 'member', user_dn[0][0].value
		if isinstance(expr, FilterAnd):
			params = dict(self.get_best_api_param(subexpr) for subexpr in expr.filters)
			for key in self._API_PARAM_PRIORITY:
				if key in params:
					return key, params[key]
		return None, None

	def search_fetch(self, expr):
		"""Yield ``(dn, attributes)`` for every group matching *expr*.

		*expr* is ``False`` when the filter can never match. In that case the
		API is not queried.
		"""
		if expr is False:
			return
		kwargs = {}
		key, value = self.get_best_api_param(expr)
		if key is not None:
			kwargs[key] = value
		for group in self.api.get('getgroups', **kwargs):
			dn, obj = self.generate_result(group)
			if eval_ldap_filter(obj, expr):
				yield dn, obj

	def filter_equal(self, attribute, value):
		"""Build an equality filter, normalizing ``uniqueMember`` DN values.

		DN values are normalized so they compare equal to the DNs built by
		generate_result().
		"""
		if attribute.lower() == 'uniquemember':
			value = str(DN.from_str(value.decode())).encode()
		return super().filter_equal(attribute, value)

	def filter_present(self, attribute):
		"""Build a presence filter, or ``False`` for attributes this directory never emits."""
		if attribute.lower() not in self.attributes:
			return False
		return super().filter_present(attribute)

class MailDirectory(SimpleFilterMixin, BaseDirectory):
	"""Read-only LDAP view of uffd mail forwardings below ``ou=postfix,<dn_base>``.

	Entries are built from the uffd ``getmails`` endpoint using the
	``postfixVirtual`` object class. When the search filter allows it, one
	query parameter is forwarded to the API. Results are always checked
	locally with ``eval_ldap_filter``.
	"""

	# Lower-cased LDAP attribute -> ``getmails`` query parameter, for filters
	# that are a plain equality match on that attribute.
	_EQUAL_API_PARAMS = {
		'uid': 'name',
		'mailacceptinggeneralid': 'receive_address',
		'maildrop': 'destination_address',
	}
	# Preference order when an AND filter yields several usable parameters.
	_API_PARAM_PRIORITY = ('name', 'receive_address', 'destination_address')

	def __init__(self, api, dn_base):
		self.api = api
		self.rdn_attr = 'uid'
		self.dn_base = DN('ou=postfix') + DN(dn_base)
		self.structuralobjectclass = b'postfixVirtual'
		self.objectclasses = [b'top', b'postfixVirtual']
		# Lower-case names of every attribute generate_result() produces.
		self.attributes = ['structuralobjectclass', 'objectclass', 'uid', 'mailacceptinggeneralid', 'maildrop']

	def generate_result(self, mail):
		"""Map one mail object from the API to a ``(dn, attributes)`` pair.

		Reads the keys ``name``, ``receive_addresses`` and
		``destination_addresses`` from *mail*.
		"""
		attributes = CaseInsensitiveDict(
			structuralObjectClass=[self.structuralobjectclass],
			objectClass=self.objectclasses,
			uid=[encode_attribute(mail['name'])],
			mailacceptinggeneralid=[encode_attribute(address) for address in mail['receive_addresses']],
			maildrop=[encode_attribute(address) for address in mail['destination_addresses']],
		)
		dn = str(DN(uid=mail['name']) + self.dn_base)
		return dn, attributes

	def get_best_api_param(self, expr):
		"""Pick one ``getmails`` query parameter that narrows a search for *expr*.

		Returns ``(param_name, value)``, or ``(None, None)`` if nothing in
		*expr* can be translated.
		"""
		if isinstance(expr, FilterEqual):
			attribute = expr.attribute.lower()
			if attribute in self._EQUAL_API_PARAMS:
				return self._EQUAL_API_PARAMS[attribute], expr.value
		if isinstance(expr, FilterAnd):
			params = dict(self.get_best_api_param(subexpr) for subexpr in expr.filters)
			for key in self._API_PARAM_PRIORITY:
				if key in params:
					return key, params[key]
		return None, None

	def search_fetch(self, expr):
		"""Yield ``(dn, attributes)`` for every mail entry matching *expr*.

		*expr* is ``False`` when the filter can never match. In that case the
		API is not queried.
		"""
		if expr is False:
			return
		kwargs = {}
		key, value = self.get_best_api_param(expr)
		if key is not None:
			kwargs[key] = value
		for mail in self.api.get('getmails', **kwargs):
			dn, obj = self.generate_result(mail)
			if eval_ldap_filter(obj, expr):
				yield dn, obj

	def filter_present(self, attribute):
		"""Build a presence filter, or ``False`` for attributes this directory never emits."""
		if attribute.lower() not in self.attributes:
			return False
		return super().filter_present(attribute)

class RequestHandler(SimpleLDAPRequestHandler):
	"""LDAP request handler that serves uffd users, groups and, optionally, mail entries.

	The class attributes below are placeholders. They must be assigned
	(``main`` does this on a per-server subclass) before the server accepts
	connections.
	"""
	subschema = RFC2307BIS_SUBSCHEMA

	# Overwritten before use
	api = None
	dn_base = None
	static_directory = None
	user_directory = None
	group_directory = None
	mail_directory = None
	bind_dn = None
	bind_password = None

	def handle(self):
		print('CONNECT')
		try:
			super().handle()
		finally:
			# Log the disconnect even when the connection ends with an exception.
			print('DISCONNECT')

	def handle_message(self, shallowmsg):
		# Do not print shallowmsg.data itself: bind requests carry the client's
		# password (simple or SASL PLAIN) in cleartext. Only log its size.
		print('MSG', len(shallowmsg.data))
		return super().handle_message(shallowmsg)

	@staticmethod
	def _password_matches(given, expected):
		"""Compare two secrets, in constant time when both are bytes-like."""
		if isinstance(given, (bytes, bytearray)) and isinstance(expected, (bytes, bytearray)):
			return hmac.compare_digest(given, expected)
		return given == expected

	def do_bind_simple_authenticated(self, dn, password):
		"""Authenticate a simple bind.

		The bind succeeds for the configured service DN with its configured
		password, or for ``uid=<loginname>`` directly below ``ou=users`` when
		the API's ``checkpassword`` endpoint returns a truthy value.

		Raises LDAPInvalidCredentials otherwise.
		"""
		print('BIND plain', dn)
		dn = DN.from_str(dn)
		if dn == self.bind_dn and self._password_matches(password, self.bind_password):
			return True
		# LDAP attribute types are case-insensitive, so accept "UID=..." too.
		if not dn.is_direct_child_of(DN('ou=users') + self.dn_base) or len(dn[0]) != 1 or dn[0][0].attribute.lower() != 'uid':
			raise LDAPInvalidCredentials()
		if self.api.post('checkpassword', loginname=dn[0][0].value, password=password):
			return True
		raise LDAPInvalidCredentials()

	supports_sasl_plain = True

	def do_bind_sasl_plain(self, identity, password, authzid=None):
		"""Authenticate a SASL PLAIN bind against the API's ``checkpassword`` endpoint.

		Returns the API response as the bind object. Raises
		LDAPInvalidCredentials if an authzid other than *identity* is
		requested, or if the API returns ``None``.
		"""
		print('BIND sasl', identity, authzid)
		if authzid is not None and identity != authzid:
			raise LDAPInvalidCredentials()
		user = self.api.post('checkpassword', loginname=identity, password=password)
		if user is None:
			raise LDAPInvalidCredentials()
		return user

	def do_search(self, baseobj, scope, filter):
		"""Yield search results from the base handler and, for bound clients, from the directories."""
		print('SEARCH %s "%s" %s'%(scope.name, baseobj, filter.get_filter_string()))
		yield from super().do_search(baseobj, scope, filter)
		# Directory content is only visible to clients that completed a bind.
		if self.bind_object:
			yield from self.static_directory.search(baseobj, scope, filter)
			yield from self.user_directory.search(baseobj, scope, filter)
			yield from self.group_directory.search(baseobj, scope, filter)
			if self.mail_directory is not None:
				yield from self.mail_directory.search(baseobj, scope, filter)

def _parse_listen_addr(listen_addr):
	"""Split a TCP listen address into ``(host, port)``, defaulting to port 389.

	Accepts ``host``, ``host:port``, ``[v6addr]`` and ``[v6addr]:port``.
	Raises ValueError for an unbracketed address containing several colons
	or for a non-numeric port.
	"""
	addr = listen_addr
	port = '389'
	if addr.startswith('['):
		addr, remainder = addr[1:].split(']', 1)
		if remainder.startswith(':'):
			port = remainder[1:]
	elif ':' in addr:
		addr, port = addr.split(':')
	return addr, int(port)

def main(config):
	"""Build the directory tree from *config* and serve LDAP until interrupted.

	Reads the config keys ``dn_base``, ``api_baseurl``, ``api_key``,
	``bind_password`` and ``listen_addr``, plus the optional ``cache_ttl``
	(default 60) and ``enable_mail``. A ``listen_addr`` starting with
	``unix:`` selects a Unix stream socket at the path that follows.
	"""
	enable_mail = config.get('enable_mail')
	dn_base = DN.from_str(config['dn_base'])
	api = UffdAPI(config['api_baseurl'], config['api_key'], config.get('cache_ttl', 60))
	user_directory = UserDirectory(api, dn_base)
	group_directory = GroupDirectory(api, dn_base)
	mail_directory = MailDirectory(api, dn_base)

	static_directory = StaticDirectory()
	base_attrs = {
		'objectClass': ['top', 'dcObject', 'organization'],
		'structuralObjectClass': ['organization'],
	}
	# The base entry carries the attribute(s) of its own first RDN, e.g. dc=...
	for rdnassertion in dn_base[0]:
		base_attrs[rdnassertion.attribute] = [rdnassertion.value]
	static_directory.add(dn_base, base_attrs)
	# Same insertion order as before: users, groups, [postfix], system.
	organizational_units = ['users', 'groups']
	if enable_mail:
		organizational_units.append('postfix')
	organizational_units.append('system')
	for ou in organizational_units:
		static_directory.add(DN('ou=' + ou) + dn_base, {
			'ou': [ou],
			'objectClass': ['top', 'organizationalUnit'],
			'structuralObjectClass': ['organizationalUnit'],
		})
	static_directory.add(DN('cn=service,ou=system') + dn_base, {
		'cn': ['service'],
		'objectClass': ['top', 'organizationalRole', 'simpleSecurityObject'],
		'structuralObjectClass': ['organizationalRole'],
	})

	# A fresh subclass per call keeps the class-level configuration from
	# leaking into RequestHandler itself.
	class CustomRequestHandler(RequestHandler):
		pass

	CustomRequestHandler.api = api
	CustomRequestHandler.dn_base = dn_base
	CustomRequestHandler.bind_dn = DN('cn=service,ou=system') + dn_base
	CustomRequestHandler.bind_password = config['bind_password'].encode()
	CustomRequestHandler.static_directory = static_directory
	CustomRequestHandler.user_directory = user_directory
	CustomRequestHandler.group_directory = group_directory
	if enable_mail:
		CustomRequestHandler.mail_directory = mail_directory

	listen_addr = config['listen_addr']
	if listen_addr.startswith('unix:'):
		server = socketserver.ThreadingUnixStreamServer(listen_addr[5:], CustomRequestHandler)
	else:
		server = socketserver.ThreadingTCPServer(_parse_listen_addr(listen_addr), CustomRequestHandler)
	# The context manager closes the listening socket once serve_forever()
	# returns or raises (e.g. KeyboardInterrupt).
	with server:
		server.serve_forever()

if __name__ == '__main__':
	# Entry point: server.py CONFIG_PATH, where CONFIG_PATH is a JSON file
	# holding the settings read by main().
	if len(sys.argv) != 2:
		print('usage: server.py CONFIG_PATH', file=sys.stderr)
		# sys.exit rather than the site-provided exit(), which is absent under `python -S`.
		sys.exit(1)
	with open(sys.argv[1], 'r', encoding='utf-8') as f:
		config = json.load(f)
	main(config)