import hmac
from crypt import crypt

from sqlalchemy import create_engine, or_, and_, Column, Integer, String
from sqlalchemy.orm import sessionmaker
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.ext.hybrid import hybrid_property
from ldap import SearchScope, FilterAnd, FilterOr, FilterNot, FilterEqual, FilterPresent
from server import Server as LDAPServer
from dn import parse_dn, build_dn

# Declarative base class for the SQLAlchemy models defined below (User, Group).
# NOTE(review): sqlalchemy.ext.declarative.declarative_base is deprecated since
# SQLAlchemy 1.4 in favour of sqlalchemy.orm.declarative_base — consider moving.
Base = declarative_base()

class BaseSearchEvaluator:
	"""Skeleton that turns an LDAP search (base, scope, filter) into results.

	Filters are folded into "condition" objects. Python ``True`` and ``False``
	are treated as constants and simplified here. Any other value is opaque
	to this class and is combined through the ``_filter_and``, ``_filter_or``
	and ``_filter_not`` hooks. Subclasses override those hooks together with
	``filter_present``, ``filter_equal``, ``filter_dn`` and ``query``. The
	defaults defined here match nothing.
	"""

	def __call__(self, base, scope, filter_expr):
		"""Evaluate a search and return whatever ``query`` produces."""
		scope_cond = self.filter_dn(base, scope)
		expr_cond = self.filter_expr(filter_expr)
		combined = self.filter_and(scope_cond, expr_cond)
		return self.query(combined)

	def filter_expr(self, expr):
		"""Recursively translate a parsed LDAP filter into a condition.

		Attribute names are lowercased before they reach ``filter_equal`` and
		``filter_present``. Filter types that are not handled evaluate to False.
		"""
		if isinstance(expr, FilterAnd):
			return self.filter_and(*map(self.filter_expr, expr.filters))
		if isinstance(expr, FilterOr):
			return self.filter_or(*map(self.filter_expr, expr.filters))
		if isinstance(expr, FilterNot):
			return self.filter_not(self.filter_expr(expr.filter))
		if isinstance(expr, FilterEqual):
			return self.filter_equal(expr.attribute.lower(), expr.value)
		if isinstance(expr, FilterPresent):
			return self.filter_present(expr.attribute.lower())
		return False

	def filter_present(self, name):
		"""Condition for the presence filter ``(name=*)``. The base class matches nothing."""
		return False

	def filter_equal(self, name, value):
		"""Condition for the equality filter ``(name=value)``. The base class matches nothing."""
		return False

	def filter_and(self, *subresults):
		"""Conjunction of *subresults* with constant folding.

		Any False yields False. True operands are dropped. If nothing remains
		the result is True. Otherwise the remaining operands go to
		``_filter_and``.
		"""
		remaining = []
		for item in subresults:
			if item is False:
				return False
			if item is not True:
				remaining.append(item)
		return self._filter_and(*remaining) if remaining else True

	def _filter_and(self, *subresults):
		"""Backend hook: combine non-constant operands with AND."""
		return False

	def filter_or(self, *subresults):
		"""Disjunction of *subresults* with constant folding.

		Any True yields True. False operands are dropped. If nothing remains
		the result is False. Otherwise the remaining operands go to
		``_filter_or``.
		"""
		remaining = []
		for item in subresults:
			if item is True:
				return True
			if item is not False:
				remaining.append(item)
		return self._filter_or(*remaining) if remaining else False

	def _filter_or(self, *subresults):
		"""Backend hook: combine non-constant operands with OR."""
		return False

	def filter_not(self, subresult):
		"""Negation with constant folding. Non-constants go to ``_filter_not``."""
		if subresult is True or subresult is False:
			return not subresult
		return self._filter_not(subresult)

	def _filter_not(self, subresult):
		"""Backend hook: negate a non-constant operand."""
		return False

	def filter_dn(self, base, scope):
		"""Condition restricting entries to *base*/*scope*. The base class matches nothing."""
		return False

	def query(self, filter_obj):
		"""Run the search for *filter_obj*. The base class returns no entries."""
		return []

class SQLSearchEvaluator(BaseSearchEvaluator):
	"""Evaluate LDAP searches against a single SQLAlchemy model.

	Every row of ``model`` is exposed as one LDAP entry whose DN is
	``<rdn_attr>=<value>,<dn_base>``. Filter evaluation produces either a
	Python bool, decided without touching the database, or a SQLAlchemy
	clause that is passed to ``Query.filter``.

	:param model: SQLAlchemy model class whose rows become entries.
	:param session: SQLAlchemy session used to run the queries.
	:param attributes: mapping of LDAP attribute name to model attribute
		name. LDAP names are stored lowercased so lookups are case-insensitive.
	:param objectclasses: objectClass values reported for every entry. They
		are stored lowercased and encoded to bytes.
	:param rdn_attr: LDAP attribute used as the entry RDN. It must be a
		lowercase key of ``attributes``.
	:param dn_base: DN under which all entries of this model live.
	"""
	def __init__(self, model, session, attributes=None, objectclasses=None, rdn_attr='uid', dn_base=''):
		self.model = model
		self.session = session
		self.attributes = {}
		for ldap_name, attr_name in (attributes or {}).items():
			self.attributes[ldap_name.lower()] = attr_name
		self.objectclasses = []
		for value in (objectclasses or []):
			value = value.lower()
			if isinstance(value, str):
				value = value.encode()
			self.objectclasses.append(value)
		self.rdn_attr = rdn_attr
		self.dn_base_path = parse_dn(dn_base)

	@staticmethod
	def _as_text(value):
		"""Return *value* decoded as UTF-8 if it is bytes, otherwise unchanged."""
		if isinstance(value, bytes):
			return value.decode()
		return value

	def filter_present(self, name):
		"""Return a condition for the presence filter ``(name=*)``.

		:param name: lowercased LDAP attribute name.
		:returns: True for objectClass, which every entry has. False for
			unmapped attributes. Otherwise an ``IS NOT NULL`` clause, or a bool
			when the mapped attribute is a plain Python value rather than a SQL
			expression.
		"""
		if name == 'objectclass':
			return True
		if name not in self.attributes:
			return False
		attr = getattr(self.model, self.attributes[name])
		if not hasattr(attr, 'isnot'):
			# Plain class-level value (e.g. an int or str constant). Decide
			# presence in Python instead of calling a SQL operator it lacks.
			return attr is not None
		return attr.isnot(None)

	def filter_equal(self, name, value):
		"""Return a condition for the equality filter ``(name=value)``.

		:param name: lowercased LDAP attribute name.
		:param value: assertion value from the client (presumably bytes; str is
			also accepted).
		:returns: a SQLAlchemy clause, or a bool when the result does not depend
			on the row. A value that cannot be converted to the attribute's type
			(undecodable text, non-numeric integer) matches nothing.
		"""
		if name == 'objectclass':
			if isinstance(value, str):
				value = value.encode()
			# objectClass matching is case-insensitive, and self.objectclasses
			# is stored lowercased, so the assertion value must be lowercased too.
			return value.lower() in self.objectclasses
		if name not in self.attributes:
			return False
		attr = getattr(self.model, self.attributes[name])
		try:
			if hasattr(attr, 'type') and isinstance(attr.type, String):
				value = self._as_text(value)
			elif hasattr(attr, 'type') and isinstance(attr.type, Integer):
				value = int(value)
			elif isinstance(attr, str):
				# Non-SQL attribute: the comparison below yields a Python bool.
				value = self._as_text(value)
			elif isinstance(attr, int):
				value = int(value)
		except ValueError:
			# UnicodeDecodeError subclasses ValueError, so this covers both bad
			# text and bad integers coming from untrusted client input.
			return False
		return attr == value

	def _filter_and(self, *subresults):
		"""Combine SQL clauses with AND."""
		return and_(*subresults)

	def _filter_or(self, *subresults):
		"""Combine SQL clauses with OR."""
		return or_(*subresults)

	def _filter_not(self, subresult):
		"""Negate a SQL clause.

		NOTE(review): SQL three-valued logic means NOT (col = x) excludes rows
		where col IS NULL. LDAP (!(attr=x)) would match entries lacking attr.
		"""
		return ~subresult

	def filter_dn(self, base, scope):
		"""Return a condition restricting entries to the search base and scope.

		All entries live directly below ``dn_base``, so:
		- baseObject matches the single entry named by ``base``;
		- singleLevel matches every entry iff ``base`` equals ``dn_base``;
		- wholeSubtree matches every entry if ``base`` is ``dn_base`` or an
		  ancestor of it, otherwise the single entry named by ``base``.
		Any other scope matches nothing.

		NOTE(review): RDNs are compared exactly as ``parse_dn`` returns them.
		LDAP treats attribute types (and most values) case-insensitively, so
		confirm whether ``parse_dn`` normalizes case.
		"""
		search_path = list(parse_dn(base))
		base_path = list(self.dn_base_path)
		# Strip the common suffix, comparing RDNs from the root end of each
		# list. Any mismatch means ``base`` is unrelated to ``dn_base``.
		while search_path and base_path:
			if search_path.pop() != base_path.pop():
				return False
		# At most one list is non-empty now. Leftover search_path RDNs lie
		# below dn_base; leftover base_path RDNs mean base is an ancestor.
		if scope == SearchScope.baseObject:
			if base_path:
				return False
			return self._filter_entry_rdn(search_path)
		elif scope == SearchScope.singleLevel:
			return not search_path and not base_path
		elif scope == SearchScope.wholeSubtree:
			if not search_path:
				return True
			return self._filter_entry_rdn(search_path)
		else:
			return False

	def _filter_entry_rdn(self, rdns):
		"""Match a DN remainder that is exactly one single-valued ``rdn_attr`` RDN.

		:returns: the equality condition on ``rdn_attr`` for that RDN's value,
			or False if ``rdns`` has any other shape.
		"""
		if len(rdns) != 1 or len(rdns[0]) != 1 or rdns[0][0][0] != self.rdn_attr:
			return False
		return self.filter_equal(self.rdn_attr, rdns[0][0][1])

	def query(self, filter_obj):
		"""Run the search and return ``(dn, attributes)`` pairs.

		:param filter_obj: True (all rows), False (no rows) or a SQLAlchemy
			clause to filter on.
		:returns: list of ``(dn, attrs)``. ``attrs`` maps LDAP attribute names to
			single-element value lists; int and str values are converted to bytes.
			Rows without a value for ``rdn_attr`` cannot be named and are skipped.
		"""
		if filter_obj is False:
			return []
		objs = self.session.query(self.model)
		if filter_obj is not True:
			objs = objs.filter(filter_obj)
		results = []
		for obj in objs:
			attrs = {}
			for ldap_name, attr_name in self.attributes.items():
				value = getattr(obj, attr_name)
				if value is None:
					continue
				if isinstance(value, int):
					value = str(value)
				if isinstance(value, str):
					value = value.encode()
				attrs[ldap_name] = [value]
			if self.rdn_attr not in attrs:
				# Without its naming attribute the entry has no DN. Skip it
				# rather than aborting the whole search with a KeyError.
				continue
			# Copy so no caller can mutate the evaluator's shared list.
			attrs['objectClass'] = list(self.objectclasses)
			dn_parts = (((self.rdn_attr, attrs[self.rdn_attr][0]),),) + self.dn_base_path
			results.append((build_dn(dn_parts), attrs))
		return results

# Module-wide database handles: a SQLite file database with SQL echo logging
# enabled, a session factory bound to it, and one shared session that the
# LDAP search and bind handlers use.
# NOTE(review): a SQLAlchemy Session is not thread-safe. Confirm the LDAP
# server handles requests on a single thread before sharing it like this.
engine = create_engine('sqlite:///db.sqlite', echo=True)
Session = sessionmaker()
Session.configure(bind=engine)
session = Session()

class LDAPViewMixin:
	"""Mixin that exposes a SQLAlchemy model as a flat LDAP subtree.

	Subclasses configure the mapping through the class attributes below.
	``ldap_search`` can be registered directly as a server search handler.
	"""
	# LDAP attribute name -> model attribute name.
	ldap_attributes = {}
	# objectClass values reported for every entry.
	ldap_objectclasses = [b'top']
	# LDAP attribute whose value forms each entry's RDN.
	ldap_rdn_attribute = 'uid'
	# DN under which all entries of the model live.
	ldap_dn_base = ''

	@classmethod
	def ldap_search(cls, base, scope, filter_expr, conn):
		"""Answer an LDAP search for this model.

		:param base: search base DN.
		:param scope: a ``SearchScope`` value.
		:param filter_expr: parsed LDAP filter tree.
		:param conn: client connection (unused here).
		:returns: list of ``(dn, attributes)`` pairs from the evaluator.
		"""
		settings = dict(
			attributes=cls.ldap_attributes,
			objectclasses=cls.ldap_objectclasses,
			rdn_attr=cls.ldap_rdn_attribute,
			dn_base=cls.ldap_dn_base,
		)
		search = SQLSearchEvaluator(cls, session, **settings)
		return search(base, scope, filter_expr)

class User(Base, LDAPViewMixin):
	"""User account stored in the ``users`` table and exposed over LDAP.

	Entries are named ``uid=<loginname>,ou=users,dc=example,dc=com``.
	``ldap_gid`` and ``ldap_sn`` are plain class constants rather than
	columns, so every user reports gidNumber 1 and an sn of a single space.
	"""
	__tablename__ = 'users'
	# LDAP attribute (lowercase) -> model attribute providing its value.
	ldap_attributes = {
		'cn': 'displayname',
		'displayname': 'displayname',
		'gidnumber': 'ldap_gid',
		'givenname': 'displayname',
		'homedirectory': 'homedirectory',
		'mail': 'email',
		'sn': 'ldap_sn',
		'uid': 'loginname',
		'uidnumber': 'id',
	}
	ldap_objectclasses = [b'top', b'inetOrgPerson', b'organizationalPerson', b'person', b'posixAccount']
	ldap_rdn_attribute = 'uid'
	ldap_dn_base = 'ou=users,dc=example,dc=com'

	id = Column(Integer, primary_key=True)
	loginname = Column(String, unique=True, nullable=False)
	displayname = Column(String, nullable=False, default='')
	email = Column(String)
	# crypt() hash of the password. None means password checks always fail.
	pwhash = Column(String)

	# Constant LDAP values shared by all users (not stored in the database).
	ldap_gid = 1
	ldap_sn = ' '

	@hybrid_property
	def homedirectory(self):
		"""Home directory derived from the login name (also usable as a SQL expression)."""
		return '/home/' + self.loginname

	# Write-only property
	# NOTE(review): the crypt module is deprecated (PEP 594) and removed in
	# Python 3.13. Plan a replacement hashing scheme.
	def password(self, value):
		"""Hash *value* with crypt() using a freshly generated salt and store it in ``pwhash``."""
		self.pwhash = crypt(value)
	password = property(fset=password)

	def check_password(self, password):
		"""Return True if *password* matches the stored hash.

		:param password: candidate password (str).
		:returns: False when no hash is stored, when hashing fails (e.g. an
			embedded NUL in the candidate or a malformed stored hash), or when
			the hashes differ.
		"""
		if self.pwhash is None:
			return False
		try:
			candidate = crypt(password, self.pwhash)
		except (OSError, ValueError):
			# ValueError: e.g. an embedded NUL character in client input.
			# OSError: crypt() failure on some Python versions/platforms.
			return False
		if candidate is None:
			# Older Python versions signal crypt() failure by returning None.
			return False
		# Constant-time comparison, so response timing does not reveal how
		# long a prefix of the computed hash matches the stored one.
		return hmac.compare_digest(candidate.encode(), self.pwhash.encode())

class Group(Base, LDAPViewMixin):
	"""Group stored in the ``groups`` table and exposed over LDAP.

	Entries are named ``cn=<name>,ou=groups,dc=example,dc=com``, and the
	primary key doubles as the gidNumber.

	NOTE(review): groupOfUniqueNames requires a uniqueMember attribute per
	RFC 4519, but no membership attribute is mapped here. Confirm clients
	tolerate this.
	"""
	__tablename__ = 'groups'

	# LDAP attribute (lowercase) -> model attribute providing its value.
	ldap_attributes = dict(
		cn='name',
		description='description',
		gidnumber='id',
	)
	ldap_objectclasses = [
		b'top',
		b'posixGroup',
		b'groupOfUniqueNames',
	]
	ldap_rdn_attribute = 'cn'
	ldap_dn_base = 'ou=groups,dc=example,dc=com'

	id = Column(Integer, primary_key=True)
	name = Column(String, nullable=False, unique=True)
	description = Column(String, default='', nullable=False)

# Create any missing tables, then build the LDAP server and register one
# search handler per exposed model (User first, then Group).
Base.metadata.create_all(bind=engine)

ldap_server = LDAPServer()
for _model in (User, Group):
	ldap_server.search_handler(_model.ldap_search)
del _model

@ldap_server.bind_handler
def ldap_bind(name, password, conn):
	"""Authenticate a simple bind against the users table.

	:param name: bind DN sent by the client.
	:param password: bind password as bytes. Anything that is not valid UTF-8 is rejected.
	:param conn: client connection (unused here).
	:returns: True if *name* identifies an existing user whose password
		matches, otherwise False.
	"""
	try:
		password = password.decode()
	except UnicodeDecodeError:
		return False
	evaluator = SQLSearchEvaluator(User, session, attributes=User.ldap_attributes,
		objectclasses=User.ldap_objectclasses, rdn_attr=User.ldap_rdn_attribute,
		dn_base=User.ldap_dn_base)
	dn_filter = evaluator.filter_dn(name, SearchScope.baseObject)
	if dn_filter is False:
		# The DN cannot name a user entry (wrong base or RDN shape). Answer
		# without a database round trip and without passing a bare Python bool
		# to Query.filter.
		return False
	users = session.query(User)
	if dn_filter is not True:
		users = users.filter(dn_filter)
	user = users.one_or_none()
	return user is not None and user.check_password(password)

# Start the LDAP server, presumably listening on host 127.0.0.1, port 1337
# (TODO confirm argument meaning in server.Server.run). This executes at import time.
# NOTE(review): consider guarding with `if __name__ == '__main__':` so the
# module can be imported (e.g. by tests) without starting the server.
ldap_server.run('127.0.0.1', 1337)