from collections.abc import MutableSet

from ldap3.utils.conv import escape_filter_chars

from . import base

class Session:
	"""Model-level session that forwards every operation to a ``base.Session``.

	:param get_connection: handed unchanged to ``base.Session``; presumably a
		factory for LDAP connections -- confirm against base.Session.
	"""

	def __init__(self, get_connection):
		# Exposed under the public name ``ldap_session``: ModelQuery reads
		# ``ldap_mapper.session.ldap_session`` directly, so the name matters.
		self.ldap_session = base.Session(get_connection)

	def add(self, obj):
		"""Pass *obj*'s LDAP object, DN and object classes to the backend's add()."""
		ldap_object, dn, object_classes = obj.ldap_object, obj.dn, obj.object_classes
		self.ldap_session.add(ldap_object, dn, object_classes)

	def delete(self, obj):
		"""Pass *obj*'s LDAP object to the backend's delete()."""
		target = obj.ldap_object
		self.ldap_session.delete(target)

	def commit(self):
		"""Forward to the backend's commit()."""
		backend = self.ldap_session
		backend.commit()

	def rollback(self):
		"""Forward to the backend's rollback()."""
		backend = self.ldap_session
		backend.rollback()

def make_modelobj(obj, model):
	"""Return the *model* instance bound to the LDAP object *obj*.

	On first use an instance is created with ``model()``, its ``ldap_object``
	is pointed back at *obj*, and it is cached as ``obj.model``. Returns None
	when *obj* is None or when *obj* is already bound to an instance that is
	not a *model*.
	"""
	if obj is None:
		return None
	if not hasattr(obj, 'model'):
		instance = obj.model = model()
		instance.ldap_object = obj
	bound = obj.model
	return bound if isinstance(bound, model) else None

def make_modelobjs(objs, model):
	"""Apply make_modelobj to every item of *objs*, dropping None results.

	The input order is preserved; the result is always a new list.
	"""
	converted = (make_modelobj(obj, model) for obj in objs)
	return [modelobj for modelobj in converted if modelobj is not None]

class ModelQuery:
	"""Query interface for one model class (handed out as ``Model.query``).

	Lookups go through the low-level session reachable as
	``model.ldap_mapper.session.ldap_session``. Raw results are converted into
	model instances with make_modelobj / make_modelobjs.
	"""

	def __init__(self, model):
		self.model = model

	def _ldap_session(self):
		# Resolved on every call, so a query always uses whatever session the
		# mapper holds at that moment.
		return self.model.ldap_mapper.session.ldap_session

	def get(self, dn):
		"""Return the model object with DN *dn*, or None.

		None is returned when the session yields None, or when the object is
		already bound to an instance of a different model class.
		"""
		obj = self._ldap_session().get(dn, self.model.ldap_filter_params)
		return make_modelobj(obj, self.model)

	def all(self):
		"""Return model objects for everything the session's filter() yields
		under the model's search base and filter params."""
		objs = self._ldap_session().filter(self.model.ldap_search_base, self.model.ldap_filter_params)
		return make_modelobjs(objs, self.model)

	def filter_by(self, **kwargs):
		"""Like all(), but with the (name, value) pairs from *kwargs* appended to
		the model's filter params.

		``ldap_filter_params`` may be None (the Model default) or any iterable
		of pairs; it is normalized to a list before the extra pairs are added,
		so plain list concatenation cannot fail on None or a tuple.
		"""
		base_params = self.model.ldap_filter_params or []
		filter_params = list(base_params) + list(kwargs.items())
		objs = self._ldap_session().filter(self.model.ldap_search_base, filter_params)
		return make_modelobjs(objs, self.model)

class ModelQueryWrapper:
	"""Descriptor producing a ModelQuery for the class it is accessed through.

	Installed as ``Model.query`` so that ``SomeModel.query`` queries
	``SomeModel``.
	"""
	def __get__(self, obj, objtype=None):
		# Normal attribute access always supplies the owner class, but the
		# descriptor protocol also allows __get__(instance) with the owner
		# omitted; fall back to the instance's type instead of ModelQuery(None).
		if objtype is None:
			objtype = type(obj)
		return ModelQuery(objtype)

class Model:
	"""Base class for LDAP-backed models.

	Each instance wraps an LDAP object in ``self.ldap_object``. Subclasses set
	the ``ldap_*`` class attributes below; ``ldap_mapper`` is filled in by the
	mapper.
	"""
	# Overwritten by mapper
	ldap_mapper = None

	# Overwritten by models
	ldap_search_base = None
	ldap_filter_params = None
	ldap_dn_base = None
	ldap_dn_attribute = None

	# Class-level query interface (``SomeModel.query.get(...)`` etc.)
	query = ModelQueryWrapper()

	def __init__(self, **kwargs):
		"""Create a model around a fresh ``base.Object`` and assign *kwargs*.

		:raises TypeError: if a keyword names nothing that exists on the
			instance or its class (e.g. a typo'd attribute name).
		"""
		self.ldap_object = base.Object()
		for key, value in kwargs.items():
			if not hasattr(self, key):
				# TypeError subclasses Exception, so existing ``except Exception``
				# handlers keep working.
				raise TypeError('%s() got an unexpected keyword argument %r'%(type(self).__name__, key))
			setattr(self, key, value)

	@property
	def dn(self):
		"""Distinguished name of this object, or None if it cannot be built.

		An explicit DN on the LDAP object wins. Otherwise the DN is composed of
		``ldap_dn_attribute``, its first value, and ``ldap_dn_base``.
		"""
		if self.ldap_object.dn is not None:
			return self.ldap_object.dn
		if self.ldap_dn_base is None or self.ldap_dn_attribute is None:
			return None
		values = self.ldap_object.getattr(self.ldap_dn_attribute)
		if not values:
			return None
		# The value becomes part of a DN, so it needs RFC 4514 DN escaping.
		# Search-filter escaping (RFC 4515, escape_filter_chars) leaves ',', '+',
		# '=', ';', '<', '>' and '"' untouched and can yield a malformed DN.
		from ldap3.utils.dn import escape_rdn
		value = values[0]
		# escape_rdn inspects the leading/trailing character, so pass '' through.
		rdn_value = escape_rdn(value) if value else value
		return '%s=%s,%s'%(self.ldap_dn_attribute, rdn_value, self.ldap_dn_base)

	def __repr__(self):
		cls_name = '%s.%s'%(type(self).__module__, type(self).__name__)
		# dn may be computed on the fly; evaluate it only once.
		dn = self.dn
		if dn is not None:
			return '<%s %s>'%(cls_name, dn)
		return '<%s>'%cls_name

class SetView(MutableSet):
	"""Mutable set facade over a collection of encoded values stored elsewhere.

	The view holds no items itself; every operation goes through the callables
	given to the constructor.

	:param getitems: zero-argument callable returning the current encoded items.
	:param additem: called with one encoded value to add it.
	:param delitem: called with one encoded value to remove it.
	:param encode: maps a set element to its stored form (default: identity).
	:param decode: maps a stored item back to a set element (default: identity);
		items that decode to None are hidden from iteration and len().
	"""
	def __init__(self, getitems, additem, delitem, encode=None, decode=None):
		self.__getitems = getitems
		self.__additem = additem
		self.__delitem = delitem
		self.__encode = encode or (lambda x: x)
		self.__decode = decode or (lambda x: x)

	def __repr__(self):
		return repr(set(self))

	def __contains__(self, value):
		# Membership is checked on the encoded form; None is never a member.
		return value is not None and self.__encode(value) in self.__getitems()

	def __iter__(self):
		return iter(filter(lambda obj: obj is not None, map(self.__decode, self.__getitems())))

	def __len__(self):
		# Number of distinct decoded, non-None elements.
		return len(set(self))

	def add(self, value):
		if value not in self:
			self.__additem(self.__encode(value))

	def discard(self, value):
		# MutableSet.discard must be a no-op for absent values. Mirror add() and
		# only forward the removal when the value is actually present, so the
		# backing store is never asked to delete something it does not hold.
		if value in self:
			self.__delitem(self.__encode(value))

	def update(self, values):
		"""Add every element of *values* (set.update-style convenience)."""
		for value in values:
			self.add(value)

class Attribute:
	"""Descriptor exposing an attribute of ``obj.ldap_object`` on a model.

	:param name: attribute name that reads come from.
	:param multi: when true, reads return a SetView and writes take an
		iterable; otherwise reads return decode(first value), or decode(None)
		when there is no value.
	:param encode: applied to each value before it is written (default: identity).
	:param decode: applied to values when read (default: identity).
	:param aliases: extra attribute names that every write is mirrored to.
	"""
	def __init__(self, name, multi=False, encode=None, decode=None, aliases=None):
		self.name = name
		self.multi = multi
		# Missing converters fall back to the identity function.
		self.encode = encode or (lambda raw: raw)
		self.decode = decode or (lambda raw: raw)
		# Primary name first; writes below iterate over all of these.
		self.aliases = [name] + (aliases or [])

	def additem(self, obj, value):
		"""Add *value* unchanged under the primary name and every alias
		(SetView encodes before calling this)."""
		for attr_name in self.aliases:
			obj.ldap_object.attradd(attr_name, value)

	def delitem(self, obj, value):
		"""Remove *value* unchanged under the primary name and every alias."""
		for attr_name in self.aliases:
			obj.ldap_object.attrdel(attr_name, value)

	def __get__(self, obj, objtype=None):
		if obj is None:
			# Class-level access yields the descriptor itself.
			return self
		if not self.multi:
			stored = obj.ldap_object.getattr(self.name)
			return self.decode(stored[0] if stored else None)
		# The callables look up obj.ldap_object on every call, so the view
		# always follows the model's current LDAP object.
		return SetView(
			getitems=lambda: obj.ldap_object.getattr(self.name),
			additem=lambda item: self.additem(obj, item),
			delitem=lambda item: self.delitem(obj, item),
			encode=self.encode,
			decode=self.decode,
		)

	def __set__(self, obj, values):
		# A single-valued attribute is stored as a one-element list.
		raw_values = values if self.multi else [values]
		encoded = [self.encode(item) for item in raw_values]
		for attr_name in self.aliases:
			obj.ldap_object.setattr(attr_name, encoded)