import sys
import json
import socketserver
import requests

from ldapserver import SimpleLDAPRequestHandler
from ldapserver.dn import DN, RDN
from ldapserver.ldap import FilterEqual, FilterAnd
from ldapserver.directory import BaseDirectory, SimpleFilterMixin, StaticDirectory, eval_ldap_filter
from ldapserver.schema.rfc2307bis import rfc2307bis_subschema
from ldapserver.util import encode_attribute, CaseInsensitiveDict
from ldapserver.exceptions import LDAPInvalidCredentials

class UffdAPI:
	"""Small HTTP client for the backing API.

	Every request goes to ``<baseurl>/<endpoint>`` and carries the API key
	as a bearer token in the ``Authorization`` header.
	"""

	def __init__(self, baseurl, key):
		"""Remember the API base URL and the bearer key used for all requests."""
		self.baseurl = baseurl
		self.key = key

	def get(self, endpoint, **kwargs):
		"""Send a GET request and return the decoded JSON body.

		Keyword arguments are sent as URL query parameters.

		Raises ``requests.HTTPError`` if the server answers with an error status.
		"""
		resp = requests.get(self.baseurl + '/' + endpoint, params=kwargs,
		                    headers={'Authorization': 'Bearer '+self.key})
		# Explicit check instead of ``assert``: asserts are removed under
		# ``python -O``, which would let error responses pass as results.
		resp.raise_for_status()
		return resp.json()

	def post(self, endpoint, **kwargs):
		"""Send a POST request and return the decoded JSON body.

		Keyword arguments are sent as form data in the request body.

		Raises ``requests.HTTPError`` if the server answers with an error status.
		"""
		resp = requests.post(self.baseurl + '/' + endpoint, data=kwargs,
		                     headers={'Authorization': 'Bearer '+self.key})
		# See get(): must not depend on assertions being enabled.
		resp.raise_for_status()
		return resp.json()

class UserDirectory(SimpleFilterMixin, BaseDirectory):
	"""Directory exposing API users as entries below ``ou=users,<dn_base>``.

	Entries are not stored locally; each search asks the API for users and
	converts the returned records into LDAP entries.
	"""

	def __init__(self, api, dn_base):
		"""Store the API client and derive the user and group base DNs."""
		self.structuralobjectclass = b'inetorgperson'
		self.objectclasses = [b'top', b'inetorgperson', b'organizationalperson', b'person', b'posixaccount']
		self.rdn_attr = 'uid'
		self.api = api
		self.dn_base = DN('ou=users') + DN(dn_base)
		self.group_dn_base = DN('ou=groups') + DN(dn_base)

	def generate_result(self, user):
		"""Convert one API user record into a ``(dn, attributes)`` pair.

		The record is read through the keys ``displayname``, ``loginname``,
		``email``, ``id`` and ``groups``.
		"""
		display_name = encode_attribute(user['displayname'])
		login = user['loginname']
		fields = {
			'structuralObjectClass': [self.structuralobjectclass],
			'objectClass': self.objectclasses,
			'subschemaSubentry': [b'cn=Subschema'],
			# The display name is used for all three name attributes.
			'cn': [display_name],
			'displayname': [display_name],
			'givenname': [display_name],
			'homeDirectory': [encode_attribute('/home/' + login)],
			'mail': [encode_attribute(user['email'])],
			'sn': [encode_attribute(' ')],
			'uid': [encode_attribute(login)],
			'uidNumber': [encode_attribute(user['id'])],
			'memberOf': [
				encode_attribute(DN(RDN(cn=group_name)) + self.group_dn_base)
				for group_name in user['groups']
			],
		}
		attributes = CaseInsensitiveDict(**fields)
		entry_dn = DN(RDN(uid=login)) + self.dn_base
		return str(entry_dn), attributes

	def get_best_api_param(self, expr):
		"""Pick one API query parameter that narrows the request for ``expr``.

		Returns a ``(key, value)`` pair, or ``(None, None)`` if the filter
		cannot be mapped to an API parameter.
		"""
		if isinstance(expr, FilterEqual):
			attr = expr.attribute.lower()
			if attr == 'uid':
				return 'loginname', expr.value
			if attr == 'uidnumber':
				return 'id', expr.value
			if attr == 'mail':
				return 'email', expr.value
			if attr == 'memberof':
				group_dn = DN(expr.value.decode())
				# Only a DN of the form cn=<name>,<group base> maps to a group name.
				is_group_entry = (
					group_dn.is_direct_child_of(self.group_dn_base)
					and len(group_dn[0]) == 1
					and group_dn[0][0].attribute == 'cn'
				)
				if is_group_entry:
					return 'group', group_dn[0][0].value
		if isinstance(expr, FilterAnd):
			candidates = dict(self.get_best_api_param(part) for part in expr.filters)
			# Fixed preference order among the parameters found in the conjunction.
			for name in ('loginname', 'id', 'email', 'group'):
				if name in candidates:
					return name, candidates[name]
		return None, None

	def search_fetch(self, expr):
		"""Yield ``(dn, attributes)`` for every API user that matches ``expr``."""
		if expr is not False:
			param, param_value = self.get_best_api_param(expr)
			query = {}
			if param is not None:
				query[param] = param_value
			for record in self.api.get('getusers', **query):
				entry_dn, entry = self.generate_result(record)
				# The API parameter only narrows the request; the full filter
				# is still evaluated on each generated entry.
				if eval_ldap_filter(entry, expr):
					yield entry_dn, entry

	def filter_equal(self, attribute, value):
		"""Re-serialise ``memberof`` values through ``DN`` before delegating to the parent."""
		if attribute != 'memberof':
			return super().filter_equal(attribute, value)
		normalised = str(DN(value.decode())).encode()
		return super().filter_equal(attribute, normalised)

class GroupDirectory(SimpleFilterMixin, BaseDirectory):
	"""Directory exposing API groups as entries below ``ou=groups,<dn_base>``.

	Entries are not stored locally; each search asks the API for groups and
	converts the returned records into LDAP entries.
	"""

	def __init__(self, api, dn_base):
		"""Store the API client and derive the group and user base DNs."""
		self.structuralobjectclass = b'groupOfUniqueNames'
		self.objectclasses = [b'top', b'groupOfUniqueNames', b'posixGroup']
		self.rdn_attr = 'cn'
		self.api = api
		self.dn_base = DN('ou=groups') + DN(dn_base)
		self.user_dn_base = DN('ou=users') + DN(dn_base)

	def generate_result(self, group):
		"""Convert one API group record into a ``(dn, attributes)`` pair.

		The record is read through the keys ``name``, ``id`` and ``members``.
		"""
		group_name = group['name']
		fields = {
			'structuralObjectClass': [self.structuralobjectclass],
			'objectClass': self.objectclasses,
			'subschemaSubentry': [b'cn=Subschema'],
			'cn': [encode_attribute(group_name)],
			'description': [encode_attribute(' ')],
			'gidNumber': [encode_attribute(group['id'])],
			'uniqueMember': [
				encode_attribute(DN(RDN(uid=member)) + self.user_dn_base)
				for member in group['members']
			],
		}
		attributes = CaseInsensitiveDict(**fields)
		entry_dn = DN(RDN(cn=group_name)) + self.dn_base
		return str(entry_dn), attributes

	def get_best_api_param(self, expr):
		"""Pick one API query parameter that narrows the request for ``expr``.

		Returns a ``(key, value)`` pair, or ``(None, None)`` if the filter
		cannot be mapped to an API parameter.
		"""
		if isinstance(expr, FilterEqual):
			attr = expr.attribute.lower()
			if attr == 'cn':
				return 'name', expr.value
			if attr == 'gidnumber':
				return 'id', expr.value
			if attr == 'uniquemember':
				user_dn = DN(expr.value.decode())
				# Only a DN of the form uid=<name>,<user base> maps to a member name.
				is_user_entry = (
					user_dn.is_direct_child_of(self.user_dn_base)
					and len(user_dn[0]) == 1
					and user_dn[0][0].attribute == 'uid'
				)
				if is_user_entry:
					return 'member', user_dn[0][0].value
		if isinstance(expr, FilterAnd):
			candidates = dict(self.get_best_api_param(part) for part in expr.filters)
			# Fixed preference order among the parameters found in the conjunction.
			for name in ('name', 'id', 'member'):
				if name in candidates:
					return name, candidates[name]
		return None, None

	def search_fetch(self, expr):
		"""Yield ``(dn, attributes)`` for every API group that matches ``expr``."""
		if expr is not False:
			param, param_value = self.get_best_api_param(expr)
			query = {}
			if param is not None:
				query[param] = param_value
			for record in self.api.get('getgroups', **query):
				entry_dn, entry = self.generate_result(record)
				# The API parameter only narrows the request; the full filter
				# is still evaluated on each generated entry.
				if eval_ldap_filter(entry, expr):
					yield entry_dn, entry

	def filter_equal(self, attribute, value):
		"""Re-serialise ``uniquemember`` values through ``DN`` before delegating to the parent."""
		if attribute != 'uniquemember':
			return super().filter_equal(attribute, value)
		normalised = str(DN(value.decode())).encode()
		return super().filter_equal(attribute, normalised)

class RequestHandler(SimpleLDAPRequestHandler):
	"""LDAP request handler that authenticates and searches against the API.

	The class attributes below are placeholders; they must be set to real
	objects before the handler is used to serve connections.
	"""

	# Overwritten before use
	api = None
	dn_base = None
	static_directory = None
	user_directory = None
	group_directory = None
	bind_dn = None
	bind_password = None

	def setup(self):
		"""Run the parent setup, then advertise the subschema entry in the root DSE."""
		super().setup()
		self.rootdse['subschemaSubentry'] = [b'cn=Subschema']
		print('CONNECT')

	def do_bind_simple_authenticated(self, dn, password):
		"""Handle a simple bind with DN and password.

		Accepts either the configured service account (``bind_dn`` with
		``bind_password``) or a user DN of the form
		``uid=<loginname>,ou=users,<dn_base>`` whose password the API confirms.

		Returns ``True`` on success and raises ``LDAPInvalidCredentials`` otherwise.
		"""
		print('BIND plain', dn)
		dn = DN(dn)
		if dn == self.bind_dn and password == self.bind_password:
			return True
		# Anything that is not a single-valued uid= RDN directly below the
		# users OU cannot be a user entry.
		if not dn.is_direct_child_of(DN('ou=users') + self.dn_base) or len(dn[0]) != 1 or dn[0][0].attribute != 'uid':
			raise LDAPInvalidCredentials()
		if self.api.post('checkpassword', loginname=dn[0][0].value, password=password):
			return True
		raise LDAPInvalidCredentials()

	supports_sasl_plain = True

	def do_bind_sasl_plain(self, identity, password, authzid=None):
		"""Handle a SASL PLAIN bind.

		``identity`` is used as the login name for the API password check.
		An ``authzid`` that differs from ``identity`` is rejected.

		Returns the API response on success and raises
		``LDAPInvalidCredentials`` if the API answers with ``None``.
		"""
		# Log the identity: there is no ``dn`` in this method, so the
		# previous log line raised NameError on every SASL bind.
		print('BIND sasl', identity)
		if authzid is not None and identity != authzid:
			raise LDAPInvalidCredentials()
		user = self.api.post('checkpassword', loginname=identity, password=password)
		if user is None:
			raise LDAPInvalidCredentials()
		return user

	def do_search(self, baseobj, scope, filter):
		"""Yield search results from the root DSE, the subschema and the directories.

		The static, user and group directories are only searched when
		``self.bind_object`` is truthy.
		"""
		print('SEARCH %s "%s" %s'%(scope.name, baseobj, filter.get_filter_string()))
		yield from self.rootdse.search(baseobj, scope, filter)
		yield from rfc2307bis_subschema.search(baseobj, scope, filter)
		# NOTE(review): bind_object is presumably set by the base class after
		# a successful bind - confirm against ldapserver.
		if self.bind_object:
			yield from self.static_directory.search(baseobj, scope, filter)
			yield from self.user_directory.search(baseobj, scope, filter)
			yield from self.group_directory.search(baseobj, scope, filter)

def main(config):
	"""Build the directories and request handler from ``config`` and serve forever.

	``config`` is a mapping with the keys ``dn_base``, ``api_baseurl``,
	``api_key``, ``bind_password`` and ``listen_addr``.

	``listen_addr`` accepts these forms:

	* ``unix:PATH`` - listen on a unix stream socket at ``PATH``
	* ``[ADDR]:PORT`` or ``[ADDR]`` - bracketed address, optional port
	* ``HOST:PORT`` or ``HOST`` - host with optional port

	The TCP port defaults to 389 when none is given. This function does not
	return; it blocks in ``serve_forever()``.
	"""
	dn_base = DN(config['dn_base'])
	api = UffdAPI(config['api_baseurl'], config['api_key'])
	user_directory = UserDirectory(api, dn_base)
	group_directory = GroupDirectory(api, dn_base)

	static_directory = StaticDirectory()
	base_attrs = {
		'objectClass': ['top', 'dcObject', 'organization'],
		'structuralObjectClass': ['organization'],
		'subschemaSubentry': ['cn=Subschema'],
	}
	# The base entry carries the attribute(s) of its own leading RDN.
	for attr, value in dn_base[0]:
		base_attrs[attr] = [value]
	static_directory.add(dn_base, base_attrs)
	# Container entries for users, groups and system accounts.
	for ou in ('users', 'groups', 'system'):
		static_directory.add(DN('ou=' + ou) + dn_base, {
			'ou': [ou],
			'objectClass': ['top', 'organizationalUnit'],
			'structuralObjectClass': ['organizationalUnit'],
			'subschemaSubentry': ['cn=Subschema'],
		})
	# Entry for the service account that binds with the configured password.
	static_directory.add(DN('cn=service,ou=system') + dn_base, {
		'cn': ['service'],
		'objectClass': ['top', 'organizationalRole', 'simpleSecurityObject'],
		'structuralObjectClass': ['organizationalRole'],
		'subschemaSubentry': ['cn=Subschema'],
	})

	# Configure a subclass rather than RequestHandler itself, so the
	# module-level class keeps its placeholder attributes.
	class CustomRequestHandler(RequestHandler):
		pass

	CustomRequestHandler.api = api
	CustomRequestHandler.dn_base = dn_base
	CustomRequestHandler.bind_dn = DN('cn=service,ou=system') + dn_base
	CustomRequestHandler.bind_password = config['bind_password'].encode()
	CustomRequestHandler.static_directory = static_directory
	CustomRequestHandler.user_directory = user_directory
	CustomRequestHandler.group_directory = group_directory

	if config['listen_addr'].startswith('unix:'):
		class ForkingUnixStreamServer(socketserver.ForkingMixIn, socketserver.UnixStreamServer):
			pass
		# Must use the configured handler: the plain RequestHandler has
		# api, directories and bind credentials all set to None.
		ForkingUnixStreamServer(config['listen_addr'][5:], CustomRequestHandler).serve_forever()
	else:
		addr = config['listen_addr']
		port = '389'
		if addr.startswith('['):
			# Bracketed form: "[addr]" or "[addr]:port".
			addr, remainder = addr[1:].split(']', 1)
			if remainder.startswith(':'):
				port = remainder[1:]
		elif ':' in addr:
			addr, port = addr.split(':')
		socketserver.ForkingTCPServer((addr, int(port)), CustomRequestHandler).serve_forever()

if __name__ == '__main__':
	# Expect exactly one argument: the path of the JSON configuration file.
	if len(sys.argv) != 2:
		print('usage: server.py CONFIG_PATH')
		# sys.exit() rather than the exit() helper, which is injected by
		# the site module and is not guaranteed to exist (e.g. python -S).
		sys.exit(1)
	with open(sys.argv[1], 'r') as f:
		config = json.load(f)
	main(config)