from copy import deepcopy

from ldap3 import MODIFY_REPLACE, MODIFY_DELETE, MODIFY_ADD, ALL_ATTRIBUTES, ALL_OPERATIONAL_ATTRIBUTES, BASE

class LDAPCommitError(Exception):
	"""Raised by an operation's apply_ldap() when the connection reports that
	the add/delete/modify request did not succeed (see Session.commit())."""

class SessionState:
	"""Snapshot of the objects a Session tracks and of their attribute references.

	objects:         {dn: Object} for objects currently present in the session.
	deleted_objects: {dn: Object} for objects deleted within the session.
	references:      {(attr_name, value): {srcobj, ...}} reverse index of which
	                 objects carry a given attribute value.
	"""
	def __init__(self, objects=None, deleted_objects=None, references=None):
		self.objects = objects or {}
		self.deleted_objects = deleted_objects or {}
		self.references = references or {} # {(attr_name, value): {srcobj, ...}, ...}

	def copy(self):
		"""Return an independent snapshot that still refers to the same Objects.

		Only the containers are copied (including every reference set), so
		mutating one snapshot never affects the other. The Object instances
		themselves must stay shared: a deep copy would clone every tracked
		Object, and lookups through the copied state would then return objects
		different from the ones callers hold.
		"""
		return SessionState(
			dict(self.objects),
			dict(self.deleted_objects),
			{key: set(objs) for key, objs in self.references.items()},
		)

	def ref(self, obj, attr, values):
		"""Record that obj has each of values in attribute attr."""
		for value in values:
			self.references.setdefault((attr, value), set()).add(obj)

	def unref(self, obj, attr, values):
		"""Forget that obj has each of values in attribute attr (missing entries are ignored)."""
		for value in values:
			self.references.get((attr, value), set()).discard(obj)

class ObjectState:
	"""Per-object snapshot: owning session, attribute values and DN."""
	def __init__(self, session=None, attributes=None, dn=None):
		self.dn = dn
		self.session = session
		# {attr_name: [value, ...]}; a falsy argument is replaced by a fresh dict.
		self.attributes = attributes if attributes else {}

	def copy(self):
		"""Return a copy with independent attribute lists but the same session and DN."""
		cloned_attributes = deepcopy(self.attributes)
		return ObjectState(self.session, cloned_attributes, self.dn)

class AddOperation:
	"""Pending creation of a new LDAP entry for obj at dn."""
	def __init__(self, obj, dn, object_classes):
		self.obj = obj
		self.dn = dn
		self.object_classes = object_classes
		# Snapshot: later changes to obj must not alter what this operation writes.
		self.attributes = deepcopy(obj.state.attributes)

	def apply_object(self, obj_state):
		"""Give obj_state the new DN and a private copy of the snapshotted attributes."""
		obj_state.dn = self.dn
		obj_state.attributes = deepcopy(self.attributes)

	def apply_session(self, session_state):
		"""Register obj under its DN and index all of its attribute values."""
		assert self.dn not in session_state.objects
		session_state.objects[self.dn] = self.obj
		for name, values in self.attributes.items():
			session_state.ref(self.obj, name, values)

	def apply_ldap(self, conn):
		"""Send the add request; raise LDAPCommitError if the connection reports failure."""
		# An LDAP Add requires every attribute to carry at least one value
		# (RFC 4511), so attributes left empty are simply not sent.
		attributes = {name: values for name, values in self.attributes.items() if values}
		success = conn.add(self.dn, self.object_classes, attributes)
		if not success:
			raise LDAPCommitError()

class DeleteOperation:
	"""Pending removal of the entry currently backing obj."""
	def __init__(self, obj):
		self.obj = obj
		state = obj.state
		# Capture DN and attributes now; apply_object() later clears obj's DN.
		self.dn = state.dn
		self.attributes = deepcopy(state.attributes)

	def apply_object(self, obj_state):
		"""Mark obj_state as no longer backed by an entry."""
		obj_state.dn = None

	def apply_session(self, session_state):
		"""Move obj from the live to the deleted objects and drop its references."""
		assert self.dn in session_state.objects
		session_state.objects.pop(self.dn)
		session_state.deleted_objects[self.dn] = self.obj
		for attr_name in self.attributes:
			session_state.unref(self.obj, attr_name, self.attributes[attr_name])

	def apply_ldap(self, conn):
		"""Send the delete request; raise LDAPCommitError if the connection reports failure."""
		if not conn.delete(self.dn):
			raise LDAPCommitError()

class ModifyOperation:
	"""Pending change to an existing object's attributes.

	changes uses the ldap3 modify format ``{attr_name: [(action, values), ...]}``
	where action is MODIFY_REPLACE, MODIFY_ADD or MODIFY_DELETE.
	"""
	def __init__(self, obj, changes):
		self.obj = obj
		# Capture the DN now: a later DeleteOperation clears obj.state.dn, but
		# this modification must still be sent for the original entry on commit.
		self.dn = obj.state.dn
		# Attribute values before the change; used to drop the old references
		# on MODIFY_REPLACE and on a value-less MODIFY_DELETE.
		self.attributes = deepcopy(obj.state.attributes)
		self.changes = deepcopy(changes)

	def apply_object(self, obj_state):
		"""Apply the changes to obj_state.attributes.

		Always stores freshly built lists, so obj_state never shares a list with
		self.changes or with another ObjectState. Otherwise a later in-place
		edit would silently rewrite this (possibly still queued) operation.
		"""
		for attr, changes in self.changes.items():
			for action, values in changes:
				if action == MODIFY_REPLACE:
					obj_state.attributes[attr] = list(values)
				elif action == MODIFY_ADD:
					current = obj_state.attributes.get(attr, [])
					obj_state.attributes[attr] = list(current) + list(values)
				elif action == MODIFY_DELETE and attr in obj_state.attributes:
					if values:
						remaining = list(obj_state.attributes[attr])
						for value in values:
							if value in remaining:
								remaining.remove(value)
					else:
						# RFC 4511: a delete without values removes the whole attribute.
						remaining = []
					obj_state.attributes[attr] = remaining

	def apply_session(self, session_state):
		"""Update the session's reference index to match the new attribute values."""
		for attr, changes in self.changes.items():
			for action, values in changes:
				if action == MODIFY_REPLACE:
					session_state.unref(self.obj, attr, self.attributes.get(attr, []))
					session_state.ref(self.obj, attr, values)
				elif action == MODIFY_ADD:
					session_state.ref(self.obj, attr, values)
				elif action == MODIFY_DELETE:
					# A value-less delete drops every value the attribute had.
					session_state.unref(self.obj, attr, values or self.attributes.get(attr, []))

	def apply_ldap(self, conn):
		"""Send the modify request; raise LDAPCommitError if the connection reports failure."""
		success = conn.modify(self.dn, self.changes)
		if not success:
			raise LDAPCommitError()

class Session:
	"""Unit of work tracking LDAP objects and their pending changes.

	add(), delete() and the Object mutators apply changes to ``state``
	immediately and queue them in ``changes``. commit() writes them to LDAP in
	order; rollback() discards them and restores ``committed_state``.
	"""
	def __init__(self, get_connection):
		# get_connection: zero-argument callable returning a connection with
		# search/add/delete/modify and .response (presumably ldap3.Connection).
		self.get_connection = get_connection
		self.committed_state = SessionState()
		self.state = SessionState()
		self.changes = []

	def add(self, obj, dn, object_classes):
		"""Schedule creation of obj at dn; no-op if obj is already tracked at dn."""
		if self.state.objects.get(dn) is obj:
			return
		assert obj.state.session is None
		oper = AddOperation(obj, dn, object_classes)
		oper.apply_object(obj.state)
		obj.state.session = self
		oper.apply_session(self.state)
		self.changes.append(oper)

	def delete(self, obj):
		"""Schedule deletion of obj; no-op if its DN is not currently tracked."""
		if obj.state.dn not in self.state.objects:
			return
		assert obj.state.session == self
		oper = DeleteOperation(obj)
		oper.apply_object(obj.state)
		obj.state.session = None
		oper.apply_session(self.state)
		self.changes.append(oper)

	def record(self, oper):
		"""Queue oper for the next commit() without applying it to state."""
		assert oper.obj.state.session == self
		self.changes.append(oper)

	def commit(self):
		"""Write all pending changes to LDAP in order.

		An operation leaves the queue only after apply_ldap() succeeded, so on
		failure the failing operation and all later ones stay queued and the
		exception propagates. Operations already written are reflected in
		committed_state (and in each object's committed_state).
		"""
		conn = self.get_connection()
		while self.changes:
			oper = self.changes[0]
			oper.apply_ldap(conn)
			self.changes.pop(0)
			committed = oper.obj.committed_state
			oper.apply_object(committed)
			# apply_object() only updates dn/attributes. Keep ownership in sync so
			# rollback() restores it correctly: an object that exists in LDAP
			# belongs to this session, a deleted (or never added) one to none.
			committed.session = self if committed.dn is not None else None
			oper.apply_session(self.committed_state)
		self.committed_state = self.state.copy()

	def rollback(self):
		"""Discard uncommitted changes and reset every tracked object to its committed state."""
		for obj in self.state.objects.values():
			obj.state = obj.committed_state.copy()
		for obj in self.state.deleted_objects.values():
			obj.state = obj.committed_state.copy()
		self.state = self.committed_state.copy()
		self.changes.clear()

	def get(self, dn, search_filter):
		"""Return the object at dn, loading it from LDAP if it is not tracked yet.

		Returns None if dn was deleted in this session or the search finds nothing.
		"""
		if dn in self.state.objects:
			return self.state.objects[dn]
		if dn in self.state.deleted_objects:
			return None
		conn = self.get_connection()
		# BASE scope: only the entry itself, never its children.
		conn.search(dn, search_filter, search_scope=BASE, attributes=[ALL_ATTRIBUTES, ALL_OPERATIONAL_ATTRIBUTES])
		if not conn.response:
			return None
		assert len(conn.response) == 1
		assert conn.response[0]['dn'] == dn
		return self._load(conn.response[0])

	def search(self, search_base, search_filter):
		"""Return the tracked objects matching search_filter below search_base.

		Results for DNs already tracked return the tracked (possibly modified)
		object; DNs deleted in this session are skipped.
		"""
		conn = self.get_connection()
		conn.search(search_base, search_filter, attributes=[ALL_ATTRIBUTES, ALL_OPERATIONAL_ATTRIBUTES])
		res = []
		for response in conn.response or []:
			dn = response['dn']
			if dn in self.state.objects:
				res.append(self.state.objects[dn])
			elif dn in self.state.deleted_objects:
				continue
			else:
				res.append(self._load(response))
		return res

	def _load(self, response):
		"""Wrap a search result entry in an Object and start tracking it.

		The entry already exists in LDAP, so it is registered, including its
		references, in both state and committed_state; rollback() keeps it.
		"""
		obj = Object(self, response)
		dn = response['dn']
		self.state.objects[dn] = obj
		self.committed_state.objects[dn] = obj
		for attr, values in obj.state.attributes.items():
			self.state.ref(obj, attr, values)
			self.committed_state.ref(obj, attr, values)
		return obj

class Object:
	"""An LDAP entry whose attribute changes are tracked by its Session.

	committed_state mirrors what was loaded or committed; state is the working
	copy including uncommitted changes.
	"""
	def __init__(self, session=None, response=None):
		if response is not None:
			assert session is not None
			# Normalise every attribute value to a list of values.
			attrs = {}
			for attr, value in response['attributes'].items():
				attrs[attr] = value if isinstance(value, list) else [value]
			self.committed_state = ObjectState(session, attrs, response['dn'])
		else:
			self.committed_state = ObjectState()
		self.state = self.committed_state.copy()

	def getattr(self, name):
		"""Return the current values of attribute name, or [] if it is unset.

		The returned list is the live working state; change values through
		setattr/attr_append/attr_remove so the change is recorded.
		"""
		return self.state.attributes.get(name, [])

	def _modify(self, name, action, values):
		"""Apply a single-attribute change and, if attached, queue it on the session."""
		oper = ModifyOperation(self, {name: [(action, values)]})
		oper.apply_object(self.state)
		session = self.state.session
		if session:
			oper.apply_session(session.state)
			session.changes.append(oper)

	def setattr(self, name, values):
		"""Replace all values of attribute name with values."""
		self._modify(name, MODIFY_REPLACE, values)

	def attr_append(self, name, value):
		"""Add value to attribute name."""
		self._modify(name, MODIFY_ADD, [value])

	def attr_remove(self, name, value):
		"""Remove value from attribute name."""
		self._modify(name, MODIFY_DELETE, [value])