From bb9c49cc37c011307d44ef0011bc40fd1c50d537 Mon Sep 17 00:00:00 2001
From: Julian Rother <julianr@fsmpi.rwth-aachen.de>
Date: Sun, 21 Feb 2021 00:55:50 +0100
Subject: [PATCH] Rewrote everything

---
 ldap3_mapper_new/base.py | 298 ++++++++++++++++++---------------------
 1 file changed, 134 insertions(+), 164 deletions(-)

diff --git a/ldap3_mapper_new/base.py b/ldap3_mapper_new/base.py
index ff3ad0e5..db0188d2 100644
--- a/ldap3_mapper_new/base.py
+++ b/ldap3_mapper_new/base.py
@@ -1,225 +1,195 @@
-from enum import Enum
 from copy import deepcopy
 
-class Status(Enum):
-	NEW
-	ADDED
-	DELETED
+class SessionState:
+	def __init__(self, objects=None, deleted_objects=None):
+		self.objects = objects or {}
+		self.deleted_objects = deleted_objects or {}
 
-class State:
-	def __init__(self, status=Status.NEW, attributes=None):
-		self.status = status
+	def copy(self):
+		return SessionState(objects=deepcopy(self.objects), deleted_objects=deepcopy(self.deleted_objects))
+
+class ObjectState:
+	def __init__(self, session=None, attributes=None, dn=None):
+		self.session = session
 		self.attributes = attributes or {}
+		self.dn = dn
 
 	def copy(self):
-		return State(self.status, deepcopy(self.attributes))
+		return ObjectState(attributes=deepcopy(self.attributes), dn=self.dn, session=self.session)
 
-class Operation:
-	def __init__(self, obj):
+class AddOperation:
+	def __init__(self, obj, dn, object_classes):
 		self.obj = obj
+		self.dn = dn
+		self.object_classes = object_classes
+		self.attributes = deepcopy(obj.state.attributes)
 
-	def apply(self, state):
-		raise NotImplemented()
-
-	def execute(self, conn):
-		raise NotImplemented()
-
-	def extend(self, oper):
-		return False
-
-class AddOperation(Operation):
-	def __init__(self, obj, attributes, ldap_object_classes):
-		super().__init__(obj)
-		self.attributes = deepcopy(attributes)
-		self.ldap_object_classes = ldap_object_classes
+	def apply_object(self, obj_state):
+		obj_state.dn = self.dn
+		obj_state.attributes = deepcopy(self.attributes)
 
-	def apply(self, state):
-		state.status = Status.ADDED
-		state.attributes = self.attributes
+	def apply_session(self, session_state):
+		assert self.dn not in session_state.objects
+		session_state.objects[self.dn] = self.obj
 
-	def execute(self, conn):
-		success = conn.add(self.obj.dn, self.ldap_object_classes, self.attributes)
+	def apply_ldap(self, conn):
+		success = conn.add(self.dn, self.object_classes, self.attributes)
 		if not success:
 			raise LDAPCommitError()
 
-class DeleteOperation(Operation):
-	def apply(self, state):
-		state.status = Status.DELETED
+class DeleteOperation:
+	def __init__(self, obj):
+		self.dn = obj.state.dn
+		self.obj = obj
+
+	def apply_object(self, obj_state):
+		obj_state.dn = None
+
+	def apply_session(self, session_state):
+		assert self.dn in session_state.objects
+		del session_state.objects[self.dn]
+		session_state.deleted_objects[self.dn] = self.obj
 
-	def execute(self, conn):
-		success = conn.delete(self.obj.dn)
+	def apply_ldap(self, conn):
+		success = conn.delete(self.dn)
 		if not success:
 			raise LDAPCommitError()
 
-class ModifyOperation(Operation):
+class ModifyOperation:
 	def __init__(self, obj, changes):
-		super().__init__(obj)
+		self.obj = obj
 		self.changes = deepcopy(changes)
 
-	def apply(self, state):
+	def apply_object(self, obj_state):
 		for attr, changes in self.changes.items():
 			for action, values in changes:
 				if action == MODIFY_REPLACE:
-					state.attributes[attr] = values
+					obj_state.attributes[attr] = values
 				elif action == MODIFY_ADD:
-					state.attributes[attr] += values
+					obj_state.attributes[attr] += values
 				elif action == MODIFY_DELETE:
 					for value in values:
-						state.attributes[attr].remove(value)
+						obj_state.attributes[attr].remove(value)
+
+	def apply_session(self, session_state):
+		pass
 
-	def execute(self, conn):
-		success = conn.modify(self.obj.dn, self.changes)
+	def apply_ldap(self, conn):
+		success = conn.modify(self.obj.state.dn, self.changes)
 		if not success:
 			raise LDAPCommitError()
 
 class Session:
-	ldap_mapper = None
+	def __init__(self, get_connection):
+		self.get_connection = get_connection
+		self.committed_state = SessionState()
+		self.state = SessionState()
+		self.changes = []
+
+	def add(self, obj, dn, object_classes):
+		if self.state.objects.get(dn) == obj:
+			return
+		assert obj.state.session is None
+		oper = AddOperation(obj, dn, object_classes)
+		oper.apply_object(obj.state)
+		oper.apply_session(self.state)
+		self.changes.append(oper)
 
-	def __init__(self):
-		self.__objects = {}
-		self.__deleted_objects = {}
-		self.__operations = []
+	def delete(self, obj):
+		if obj.state.dn not in self.state.objects:
+			return
+		assert obj.state.session == self
+		oper = DeleteOperation(obj)
+		oper.apply_object(obj.state)
+		obj.state.session = None
+		oper.apply_session(self.state)
+		self.changes.append(oper)
 
-	# Never called directly!
 	def record(self, oper):
-		if isinstance(oper, AddOperation):
-			if oper.obj.ldap_state.session == self:
-				return
-			if oper.obj.ldap_state.session is not None:
-				raise Exception()
-			if oper.obj.dn in self.__objects:
-				raise Exception()
-			self.__objects[oper.obj.dn] = oper.obj
-		elif isinstance(oper, DeleteOperation):
-			if oper.obj.ldap_state.session is None:
-				return
-			if oper.obj.ldap_state.session != self:
-				raise Exception()
-			if oper.obj.dn not in self.__objects:
-				raise Exception()
-			if oper.obj.dn in self.__deleted_objects:
-				raise Exception()
-			self.__deleted_objects[oper.obj.dn] = oper.obj
-			del self.__objects[oper.obj.dn]
-		else:
-			if oper.obj.ldap_state.session is None:
-				return
-		if not self.__operations or not self.__operations[-1].extend(oper):
-			self.__operations.append(oper)
-
-	def add(self, obj):
-		obj.ldap_state.add(self)
-
-	def delete(self, obj):
-		obj.ldap_state.delete()
+		assert oper.obj.state.session == self
+		self.changes.append(oper)
 
 	def commit(self):
-		conn = self.mapper.connect()
-		while self.__operations:
-			obj, oper = self.__operations.pop(0)
+		conn = self.get_connection()
+		while self.changes:
+			oper = self.changes.pop(0)
 			try:
-				oper.execute(obj.dn, conn)
+				oper.apply_ldap(conn)
 			except e:
-				self.__operations.insert(0, (obj, oper))
+				self.changes.insert(0, oper)
 				raise e
+			oper.apply_object(oper.obj.committed_state)
+			oper.apply_session(self.committed_state)
+		self.committed_state = self.state.copy()
 
 	def rollback(self):
-		while self.__operations:
-			obj, oper = self.__operations.pop(0)
-			obj.ldap_state.current = obj.ldap_state.committed.copy()
-
-	def query_get(self, cls, dn):
-		if dn in self.__objects:
-			return self.__objects[dn]
-		if dn in self.__deleted_objects:
+		for obj in self.state.objects.values():
+			obj.state = obj.committed_state.copy()
+		for obj in self.state.deleted_objects.values():
+			obj.state = obj.committed_state.copy()
+		self.state = self.committed_state.copy()
+		self.changes.clear()
+
+	def get(self, dn, search_filter):
+		if dn in self.state.objects:
+			return self.state.objects[dn]
+		if dn in self.state.deleted_objects:
 			return None
-		conn = self.mapper.connect()
-		conn.search(dn, cls.ldap_filter, attributes=[ALL_ATTRIBUTES, ALL_OPERATIONAL_ATTRIBUTES])
+		conn = self.get_connection()
+		conn.search(dn, search_filter, attributes=[ALL_ATTRIBUTES, ALL_OPERATIONAL_ATTRIBUTES])
 		if not conn.response:
 			return None
-		self.__objects[dn] = cls(__ldap_response=conn.response[0])
-		return self.__objects[dn]
-
-	def query_search(self, cls, filters=None):
-		filters = [cls.ldap_filter] + (filters or [])
-		if len(filters) == 1:
-			expr = filters[0]
-		else:
-			expr = '(&%s)'%(''.join(filters))
-		conn = self.mapper.connect()
-		conn.search(cls.ldap_base, cls.ldap_filter, attributes=[ALL_ATTRIBUTES, ALL_OPERATIONAL_ATTRIBUTES])
+		assert len(conn.response) == 1
+		assert conn.response[0]['dn'] == dn
+		self.state.objects[dn] = Object(self, conn.response[0])
+		self.committed_state.objects[dn] = self.state.objects[dn]
+		return self.state.objects[dn]
+
+	def search(self, search_base, search_filter):
+		conn = self.get_connection()
+		conn.search(search_base, search_filter, attributes=[ALL_ATTRIBUTES, ALL_OPERATIONAL_ATTRIBUTES])
 		res = []
 		for response in conn.response:
 			dn = response['dn']
-			if dn in self.__objects:
-				res.append(self.__objects[dn])
-			elif dn in self.__deleted_objects:
+			if dn in self.state.objects:
+				res.append(self.state.objects[dn])
+			elif dn in self.state.deleted_objects:
 				continue
 			else:
-				self.__objects[dn] = cls(__ldap_response=response)
-				res.append(self.__objects[dn])
+				self.state.objects[dn] = Object(self, response)
+				self.committed_state.objects[dn] = self.state.objects[dn]
+				res.append(self.state.objects[dn])
 		return res
 
-# This is only a seperate class to keep SessionObject's namespace cleaner
-class SessionObjectState:
-	def __init__(self, obj, response=None):
-		self.obj = obj
-		self.session = None
-		if response is not None:
-			self.commited = State()
+class Object:
+	def __init__(self, session=None, response=None):
+		if response is None:
+			self.committed_state = ObjectState()
 		else:
-			self.commited = State(Status.ADDED, response['attributes'])
-		self.current = self.commited.copy()
-
-	def add(self, session):
-		if self.session is not None:
-			return
-		# TODO: call hook functions
-		oper = AddOperation(self.current.attributes, self.obj.ldap_object_classes)
-		self.session.record(self.obj, oper)
-		self.session = session
-		oper.apply(self.current)
-
-	def delete(self):
-		if self.session is None:
-			return
-		oper = DeleteOperation()
-		self.session.record(self.obj, oper)
-		self.session = None
-		oper.apply(self.current)
+			assert session is not None
+			self.committed_state = ObjectState(session, response['attributes'], response['dn'])
+		self.state = self.committed_state.copy()
 
 	def getattr(self, name):
-		return self.current.attributes.get(name, [])
+		return self.state.attributes.get(name, [])
 
 	def setattr(self, name, values):
-		oper = ModifyOperation({name: [(MODIFY_REPLACE, [values])]})
-		if self.session is not None:
-			self.session.record(self.obj, oper)
-		oper.apply(self.current)
+		oper = ModifyOperation(self, {name: [(MODIFY_REPLACE, [values])]})
+		oper.apply_object(obj.state)
+		if self.state.session:
+			oper.apply_session(self.state.session.state)
+			self.state.session.changes.append(oper)
 
 	def attr_append(self, name, value):
-		oper = ModifyOperation({name: [(MODIFY_ADD, [value])]})
-		if self.session is not None:
-			self.session.record(self.obj, oper)
-		oper.apply(self.current)
+		oper = ModifyOperation(self, {name: [(MODIFY_ADD, [value])]})
+		oper.apply_object(obj.state)
+		if self.state.session:
+			oper.apply_session(self.state.session.state)
+			self.state.session.changes.append(oper)
 
 	def attr_remove(self, name, value):
-		# TODO: how does LDAP handle MODIFY_DELETE ops with non-existant values?
-		oper = ModifyOperation({name: [(MODIFY_DELETE, [value])]})
-		if self.session is not None:
-			self.session.record(self.obj, oper)
-		oper.apply(self.current)
-
-# This is only a seperate class to keep SessionObject's namespace cleaner
-class SessionObject:
-	ldap_mapper = None
-	ldap_object_classes = None
-	ldap_base = None
-	ldap_filter = None
-
-	def __init__(self, __ldap_response=None):
-		self.ldap_state = SessionObjectState(self, __ldap_response)
-
-	@property
-	def dn(self):
-		raise NotImplemented()
+		oper = ModifyOperation(self, {name: [(MODIFY_DELETE, [value])]})
+		oper.apply_object(obj.state)
+		if self.state.session:
+			oper.apply_session(self.state.session.state)
+			self.state.session.changes.append(oper)
-- 
GitLab