from copy import deepcopy

from ldap3.utils.conv import escape_filter_chars
from ldap3 import MODIFY_REPLACE, MODIFY_DELETE, MODIFY_ADD, ALL_ATTRIBUTES, ALL_OPERATIONAL_ATTRIBUTES

from .types import LDAPSet, LDAPCommitError

class Session:
	"""Identity map and unit of work for mapped LDAP objects.

	Tracks one instance per DN, a queue of objects pending deletion and an
	index of relations so that back-references can be resolved without
	another LDAP query.
	"""

	def __init__(self):
		self.__to_delete = []
		self.__relations = {} # (srccls, srcattr, dn) -> {srcobj, ...}
		self.__objects = {} # dn -> instance

	def lookup(self, dn):
		"""Return the instance registered for dn, or None if there is none."""
		return self.__objects.get(dn)

	def register(self, obj):
		"""Track obj under its DN and return it.

		Raises Exception if a different object is already registered for the
		same DN.
		"""
		dn = obj.dn
		if dn in self.__objects:
			if self.__objects[dn] != obj:
				raise Exception()
		self.__objects[dn] = obj
		return obj

	def lookup_relations(self, srccls, srcattr, dn):
		"""Return the set of srccls objects whose srcattr references dn.

		An empty set is returned when nothing is recorded for that key.
		"""
		return self.__relations.get((srccls, srcattr, dn), set())

	def update_relations(self, srcobj, srcattr, delete_dns=None, add_dns=None):
		"""Record that srcobj stopped/started referencing DNs via srcattr.

		delete_dns are processed before add_dns; either may be None.
		"""
		owner = type(srcobj)
		for dn in (delete_dns or []):
			self.__relations.setdefault((owner, srcattr, dn), set()).discard(srcobj)
		for dn in (add_dns or []):
			self.__relations.setdefault((owner, srcattr, dn), set()).add(srcobj)

	def add(self, obj):
		"""Schedule obj for creation/modification on the next commit."""
		self.register(obj)

	def delete(self, obj):
		"""Stop tracking obj and queue it for deletion on the next commit."""
		self.__objects.pop(obj.dn, None)
		self.__to_delete.append(obj)

	def commit(self):
		"""Apply pending deletions, then create or modify tracked objects."""
		pending = self.__to_delete
		while pending:
			# The object is dequeued before the LDAP call is made
			victim = pending.pop(0)
			victim.ldap_delete()
		# Iterate over a snapshot of the tracked objects
		for obj in tuple(self.__objects.values()):
			if not obj.ldap_created:
				obj.ldap_create()
				continue
			if obj.ldap_dirty:
				obj.ldap_modify()

	def rollback(self):
		"""Discard pending deletions, drop uncreated objects and reset dirty ones."""
		self.__to_delete.clear()
		kept = {}
		for dn, obj in self.__objects.items():
			if obj.ldap_created:
				kept[dn] = obj
		self.__objects = kept
		for obj in kept.values():
			if obj.ldap_dirty:
				obj.ldap_reset()

class Model:
	"""Base class for objects that are mapped to LDAP entries.

	Attribute values are kept as lists keyed by LDAP attribute name. Two
	copies exist: the state last known to be stored in LDAP and the working
	state including uncommitted changes. Pending changes are additionally
	recorded in the format passed to the ldap3 connection's modify call.
	"""
	ldap_mapper = None # Overwritten by LDAP3Mapper

	ldap_dn_attribute = None
	ldap_dn_base = None
	ldap_base = None
	ldap_object_classes = None
	ldap_filter = None
	# Caution: Never mutate ldap_pre_create_hooks and ldap_relations, always reassign!
	ldap_pre_create_hooks = []
	ldap_relations = []

	def __init__(self, _ldap_response=None, **kwargs):
		"""Create an instance, optionally from a search response entry.

		_ldap_response: mapping with a 'dn' key and an optional 'attributes'
			mapping. Attribute values that are not lists are wrapped in a list.
		kwargs: initial values assigned with setattr. Raises Exception if a
			key is not already an attribute of the instance.
		"""
		self.ldap_session = self.ldap_mapper.session
		# Names of relation attributes for which Backref.init already ran
		self.ldap_relation_data = set()
		self.__ldap_dn = None if _ldap_response is None else _ldap_response['dn']
		# State as stored in LDAP; empty for objects that were not created yet
		self.__ldap_attributes = {}
		for key, values in (_ldap_response or {}).get('attributes', {}).items():
			self.__ldap_attributes[key] = values if isinstance(values, list) else [values]
		# Working copy that receives uncommitted changes
		self.__attributes = deepcopy(self.__ldap_attributes)
		self.__changes = {} # name -> [(operation, values), ...]
		for key, value in kwargs.items():
			if not hasattr(self, key):
				raise Exception()
			setattr(self, key, value)
		for name in self.ldap_relations:
			self.__update_relations(name, add_dns=self.__attributes.get(name, []))

	def __update_relations(self, name, delete_dns=None, add_dns=None):
		"""Forward relation changes to the session if name is a relation attribute."""
		if name in self.ldap_relations:
			self.ldap_session.update_relations(self, name, delete_dns, add_dns)

	def ldap_getattr(self, name):
		"""Return the list of current values of name ([] if it has none)."""
		return self.__attributes.get(name, [])

	def ldap_setattr(self, name, values):
		"""Replace all values of name with the list values.

		Any changes previously recorded for name are superseded by a single
		replace operation.
		"""
		self.__update_relations(name, delete_dns=self.__attributes.get(name, []))
		self.__changes[name] = [(MODIFY_REPLACE, values)]
		self.__attributes[name] = values
		self.__update_relations(name, add_dns=values)

	def ldap_attradd(self, name, value):
		"""Add a single value to name and record the change."""
		self.__changes[name] = self.__changes.get(name, []) + [(MODIFY_ADD, [value])]
		# The attribute may not have any values yet, so create the list on demand
		self.__attributes.setdefault(name, []).append(value)
		self.__update_relations(name, add_dns=[value])

	def ldap_attrdel(self, name, value):
		"""Remove a single value from name and record the change.

		The delete operation is recorded even if value is not currently present.
		"""
		self.__changes[name] = self.__changes.get(name, []) + [(MODIFY_DELETE, [value])]
		if value in self.__attributes.get(name, []):
			self.__attributes[name].remove(value)
		self.__update_relations(name, delete_dns=[value])

	def __repr__(self):
		name = '%s.%s'%(type(self).__module__, type(self).__name__)
		if self.__ldap_dn is None:
			return '<%s>'%name
		return '<%s %s>'%(name, self.__ldap_dn)

	def build_dn(self):
		"""Build the DN from ldap_dn_attribute, its first value and ldap_dn_base.

		Returns None if the class does not define ldap_dn_attribute or
		ldap_dn_base or if the DN attribute has no value.
		"""
		if self.ldap_dn_attribute is None or self.ldap_dn_base is None:
			return None
		values = self.__attributes.get(self.ldap_dn_attribute)
		# Covers both a missing attribute and an empty value list
		if not values:
			return None
		# NOTE(review): escape_filter_chars implements search filter escaping,
		# not DN escaping. Confirm this is intended for values with special characters.
		return '%s=%s,%s'%(self.ldap_dn_attribute, escape_filter_chars(values[0]), self.ldap_dn_base)

	@property
	def dn(self):
		"""DN from the search response if there was one, else the built DN."""
		if self.__ldap_dn is not None:
			return self.__ldap_dn
		return self.build_dn()

	@classmethod
	def _ldap_wrap_entries(cls, entries):
		"""Return the session objects for response entries, registering new ones."""
		res = []
		for entry in entries:
			obj = cls.ldap_mapper.session.lookup(entry['dn'])
			if obj is None:
				obj = cls.ldap_mapper.session.register(cls(_ldap_response=entry))
			res.append(obj)
		return res

	@classmethod
	def ldap_get(cls, dn):
		"""Return the object for dn from the session or from LDAP.

		Returns None if the search yields no result. Raises Exception if it
		yields more than one.
		"""
		obj = cls.ldap_mapper.session.lookup(dn)
		if obj is not None:
			return obj
		conn = cls.ldap_mapper.connect()
		conn.search(dn, cls.ldap_filter, attributes=[ALL_ATTRIBUTES, ALL_OPERATIONAL_ATTRIBUTES])
		if not conn.response:
			return None
		if len(conn.response) != 1:
			raise Exception()
		return cls.ldap_mapper.session.register(cls(_ldap_response=conn.response[0]))

	@classmethod
	def ldap_all(cls):
		"""Return all objects below ldap_base that match ldap_filter."""
		conn = cls.ldap_mapper.connect()
		conn.search(cls.ldap_base, cls.ldap_filter, attributes=[ALL_ATTRIBUTES, ALL_OPERATIONAL_ATTRIBUTES])
		return cls._ldap_wrap_entries(conn.response)

	@classmethod
	def ldap_filter_by_raw(cls, **kwargs):
		"""Return all objects that match ldap_filter and every name=value pair.

		Keys are LDAP attribute names, values are the already encoded strings.
		Values are escaped for use in the search filter.
		"""
		filters = [cls.ldap_filter]
		filters.extend('(%s=%s)'%(key, escape_filter_chars(value)) for key, value in kwargs.items())
		conn = cls.ldap_mapper.connect()
		conn.search(cls.ldap_base, '(&%s)'%(''.join(filters)), attributes=[ALL_ATTRIBUTES, ALL_OPERATIONAL_ATTRIBUTES])
		return cls._ldap_wrap_entries(conn.response)

	@classmethod
	def ldap_filter_by(cls, **kwargs):
		"""Like ldap_filter_by_raw, but keyed by the class's attribute names.

		Each key is resolved with getattr on the class; the result's name and
		encode are used to build the raw filter.
		"""
		_kwargs = {}
		for key, value in kwargs.items():
			attr = getattr(cls, key)
			_kwargs[attr.name] = attr.encode(value)
		return cls.ldap_filter_by_raw(**_kwargs)

	def ldap_reset(self):
		"""Discard uncommitted changes and restore the state stored in LDAP."""
		for name in self.ldap_relations:
			self.__update_relations(name, delete_dns=self.__attributes.get(name, []))
		self.__changes = {}
		self.__attributes = deepcopy(self.__ldap_attributes)
		for name in self.ldap_relations:
			self.__update_relations(name, add_dns=self.__attributes.get(name, []))

	@property
	def ldap_dirty(self):
		"""True if there are recorded changes that were not committed yet."""
		return bool(self.__changes)

	@property
	def ldap_created(self):
		"""True if the stored state is non-empty.

		The stored state is set from a search response or after a successful
		create/modify and is cleared by ldap_delete.
		"""
		return bool(self.__ldap_attributes)

	def ldap_modify(self):
		"""Send the recorded changes to LDAP.

		Does nothing if there are no changes. Raises Exception if the object
		was not created yet and LDAPCommitError if the modify operation fails.
		"""
		if not self.ldap_created:
			raise Exception()
		if not self.ldap_dirty:
			return
		conn = self.ldap_mapper.connect()
		success = conn.modify(self.dn, self.__changes)
		if not success:
			# Same error type as ldap_create uses for a failed LDAP operation
			raise LDAPCommitError()
		self.__changes = {}
		self.__ldap_attributes = deepcopy(self.__attributes)

	def ldap_create(self):
		"""Add the object to LDAP after running the pre-create hooks.

		Raises Exception if the object was already created and
		LDAPCommitError if the add operation fails.
		"""
		if self.ldap_created:
			raise Exception()
		conn = self.ldap_mapper.connect()
		for func in self.ldap_pre_create_hooks:
			func(self)
		success = conn.add(self.dn, self.ldap_object_classes, self.__attributes)
		if not success:
			raise LDAPCommitError()
		self.__changes = {}
		self.__ldap_attributes = deepcopy(self.__attributes)

	def ldap_delete(self):
		"""Delete the object from LDAP and mark it as not created.

		Raises LDAPCommitError if the delete operation fails.
		"""
		conn = self.ldap_mapper.connect()
		success = conn.delete(self.dn)
		if not success:
			# Same error type as ldap_create uses for a failed LDAP operation
			raise LDAPCommitError()
		self.__ldap_attributes = {}

class Attribute:
	"""Descriptor that maps a Python attribute to an LDAP attribute.

	name: LDAP attribute name.
	multi: if true, the attribute is exposed as an LDAPSet of all values,
		otherwise as its first value (or decode(None) if it has none).
	default: value or callable producing the value that is assigned before
		object creation if the attribute is still empty.
	encode/decode: conversion between Python values and stored values,
		identity if omitted.
	aliases: additional LDAP attribute names that receive every write.
	"""
	ldap_mapper = None # Overwritten by LDAP3Mapper

	def __init__(self, name, multi=False, default=None, encode=None, decode=None, aliases=None):
		self.name = name
		self.multi = multi
		self.encode = encode or (lambda x: x)
		self.decode = decode or (lambda x: x)
		self.default_values = default
		self.aliases = aliases or []

	def default(self, obj):
		"""Pre-create hook: assign the default if obj has no values for name."""
		if obj.ldap_getattr(self.name) == []:
			values = self.default_values
			if callable(values):
				values = values()
			self.__set__(obj, values)

	def additem(self, obj, value):
		"""Add the (already encoded) value to the attribute and all aliases."""
		obj.ldap_attradd(self.name, value)
		for name in self.aliases:
			obj.ldap_attradd(name, value)

	def delitem(self, obj, value):
		"""Remove the (already encoded) value from the attribute and all aliases."""
		obj.ldap_attrdel(self.name, value)
		for name in self.aliases:
			obj.ldap_attrdel(name, value)

	def __set_name__(self, cls, name):
		# Reassign instead of mutating so the list of the base class stays untouched
		if self.default_values is not None:
			cls.ldap_pre_create_hooks = cls.ldap_pre_create_hooks + [self.default]

	def __get__(self, obj, objtype=None):
		if obj is None:
			return self
		if self.multi:
			return LDAPSet(getitems=lambda: obj.ldap_getattr(self.name),
			               additem=lambda value: self.additem(obj, value),
			               delitem=lambda value: self.delitem(obj, value),
			               encode=self.encode, decode=self.decode)
		return self.decode((obj.ldap_getattr(self.name) or [None])[0])

	def __set__(self, obj, values):
		if not self.multi:
			values = [values]
		# Each target gets its own list: ldap_setattr stores the list it is
		# given, so sharing one list would couple the attribute and its aliases
		for name in [self.name] + self.aliases:
			obj.ldap_setattr(name, [self.encode(value) for value in values])

class Backref:
	"""Descriptor exposing the objects of srccls that reference an object.

	An object is referenced if its DN is among the values of srcattr on the
	source object. Creating a Backref registers srcattr as a relation
	attribute of srccls.
	"""
	ldap_mapper = None # Overwritten by LDAP3Mapper

	def __init__(self, srccls, srcattr):
		self.srccls = srccls
		self.srcattr = srcattr
		# Reassigned rather than mutated in place
		known_relations = srccls.ldap_relations
		srccls.ldap_relations = known_relations + [srcattr]

	def init(self, obj):
		"""Query the referencing objects once per object and remember that it was done."""
		already_loaded = self.srcattr in obj.ldap_relation_data
		if not already_loaded and obj.ldap_created:
			# The query instantiates all related objects that in turn add their relations to session
			query = {self.srcattr: obj.dn}
			self.srccls.ldap_filter_by_raw(**query)
		obj.ldap_relation_data.add(self.srcattr)

	def __get__(self, obj, objtype=None):
		if obj is None:
			return self
		self.init(obj)

		def current_sources():
			return obj.ldap_session.lookup_relations(self.srccls, self.srcattr, obj.dn)

		def link(source):
			return source.ldap_attradd(self.srcattr, obj.dn)

		def unlink(source):
			return source.ldap_attrdel(self.srcattr, obj.dn)

		return LDAPSet(getitems=current_sources, additem=link, delitem=unlink)

	def __set__(self, obj, values):
		# Replace the whole set: drop all existing references, then add the new ones
		related = self.__get__(obj)
		related.clear()
		for item in values:
			related.add(item)

class Relation(Attribute):
	"""Multi-valued attribute whose values are the DNs of dest objects.

	Values are encoded to the object's dn and decoded with dest.ldap_get.
	If backref is given, a Backref descriptor with that name is installed on
	dest when the relation is bound to its owning class.
	"""
	ldap_mapper = None # Overwritten by LDAP3Mapper

	def __init__(self, name, dest, backref=None):
		def to_dn(value):
			return value.dn

		# The base class initializer also stores name
		super().__init__(name, multi=True, encode=to_dn, decode=dest.ldap_get)
		self.dest = dest
		self.backref = backref

	def __set_name__(self, cls, name):
		if self.backref is None:
			return
		reverse = Backref(cls, self.name)
		setattr(self.dest, self.backref, reverse)