From 19a484ffa8027ce540a0ff3120cd2379389dc35a Mon Sep 17 00:00:00 2001
From: Julian Rother <julianr@fsmpi.rwth-aachen.de>
Date: Sat, 20 Feb 2021 17:54:15 +0100
Subject: [PATCH] Some restructuring

---
 ldap3_mapper_new/base.py | 91 +++++++++++++++++++++++++++-------------
 1 file changed, 63 insertions(+), 28 deletions(-)

diff --git a/ldap3_mapper_new/base.py b/ldap3_mapper_new/base.py
index 80a38a1f..ff3ad0e5 100644
--- a/ldap3_mapper_new/base.py
+++ b/ldap3_mapper_new/base.py
@@ -15,6 +15,9 @@ class State:
 		return State(self.status, deepcopy(self.attributes))
 
 class Operation:
+	def __init__(self, obj):
+		self.obj = obj
+
 	def apply(self, state):
 		raise NotImplemented()
 
@@ -25,7 +28,8 @@ class Operation:
 		return False
 
 class AddOperation(Operation):
-	def __init__(self, attributes, ldap_object_classes):
+	def __init__(self, obj, attributes, ldap_object_classes):
+		super().__init__(obj)
 		self.attributes = deepcopy(attributes)
 		self.ldap_object_classes = ldap_object_classes
 
@@ -33,8 +37,8 @@ class AddOperation(Operation):
 		state.status = Status.ADDED
 		state.attributes = self.attributes
 
-	def execute(self, dn, conn):
-		success = conn.add(dn, self.ldap_object_classes, self.attributes)
+	def execute(self, conn):
+		success = conn.add(self.obj.dn, self.ldap_object_classes, self.attributes)
 		if not success:
 			raise LDAPCommitError()
 
@@ -42,13 +46,14 @@ class DeleteOperation(Operation):
 	def apply(self, state):
 		state.status = Status.DELETED
 
-	def execute(self, dn, conn):
-		success = conn.delete(dn)
+	def execute(self, conn):
+		success = conn.delete(self.obj.dn)
 		if not success:
 			raise LDAPCommitError()
 
 class ModifyOperation(Operation):
-	def __init__(self, changes):
+	def __init__(self, obj, changes):
+		super().__init__(obj)
 		self.changes = deepcopy(changes)
 
 	def apply(self, state):
@@ -62,8 +67,8 @@ class ModifyOperation(Operation):
 					for value in values:
 						state.attributes[attr].remove(value)
 
-	def execute(self, dn, conn):
-		success = conn.modify(dn, self.changes)
+	def execute(self, conn):
+		success = conn.modify(self.obj.dn, self.changes)
 		if not success:
 			raise LDAPCommitError()
 
@@ -72,27 +77,41 @@ class Session:
 
 	def __init__(self):
 		self.__objects = {}
+		self.__deleted_objects = {}
 		self.__operations = []
 
-	def record(self, obj, oper):
-		if not self.__operations or self.__operations[0][0] != obj or not self.__operations[0][1].extend(oper):
-			self.__operations.append((obj, oper))
+	# Never called directly!
+	def record(self, oper):
+		if isinstance(oper, AddOperation):
+			if oper.obj.ldap_state.session == self:
+				return
+			if oper.obj.ldap_state.session is not None:
+				raise Exception()
+			if oper.obj.dn in self.__objects:
+				raise Exception()
+			self.__objects[oper.obj.dn] = oper.obj
+		elif isinstance(oper, DeleteOperation):
+			if oper.obj.ldap_state.session is None:
+				return
+			if oper.obj.ldap_state.session != self:
+				raise Exception()
+			if oper.obj.dn not in self.__objects:
+				raise Exception()
+			if oper.obj.dn in self.__deleted_objects:
+				raise Exception()
+			self.__deleted_objects[oper.obj.dn] = oper.obj
+			del self.__objects[oper.obj.dn]
+		else:
+			if oper.obj.ldap_state.session is None:
+				return
+		if not self.__operations or not self.__operations[-1].extend(oper):
+			self.__operations.append(oper)
 
-	# TODO: maybe move the implementation to SessionObjectState?
 	def add(self, obj):
-		if obj.ldap_state.current.status != Status.NEW:
-			return
-		oper = AddOperation(obj.ldap_state.current.attributes, obj.ldap_object_classes)
-		oper.apply(obj.ldap_state.current)
-		self.__operations.append((obj, oper))
+		obj.ldap_state.add(self)
 
-	# TODO: maybe move the implementation to SessionObjectState?
 	def delete(self, obj):
-		if obj.ldap_state.current.status != Status.ADDED:
-			return
-		oper = DeleteOperation()
-		oper.apply(obj.ldap_state.current)
-		self.__operations.append((obj, oper))
+		obj.ldap_state.delete()
 
 	def commit(self):
 		conn = self.mapper.connect()
@@ -103,7 +122,6 @@ class Session:
 			except e:
 				self.__operations.insert(0, (obj, oper))
 				raise e
-			oper.apply(obj.ldap_state.committed)
 
 	def rollback(self):
 		while self.__operations:
@@ -146,32 +164,49 @@ class Session:
 class SessionObjectState:
 	def __init__(self, obj, response=None):
 		self.obj = obj
-		self.session = obj.ldap_mapper.session
+		self.session = None
 		if response is not None:
 			self.commited = State()
 		else:
 			self.commited = State(Status.ADDED, response['attributes'])
 		self.current = self.commited.copy()
 
+	def add(self, session):
+		if self.session is not None:
+			return
+		# TODO: call hook functions
+		oper = AddOperation(self.current.attributes, self.obj.ldap_object_classes)
+		self.session.record(self.obj, oper)
+		self.session = session
+		oper.apply(self.current)
+
+	def delete(self):
+		if self.session is None:
+			return
+		oper = DeleteOperation()
+		self.session.record(self.obj, oper)
+		self.session = None
+		oper.apply(self.current)
+
 	def getattr(self, name):
 		return self.current.attributes.get(name, [])
 
 	def setattr(self, name, values):
 		oper = ModifyOperation({name: [(MODIFY_REPLACE, [values])]})
-		if self.current.status == Status.ADDED:
+		if self.session is not None:
 			self.session.record(self.obj, oper)
 		oper.apply(self.current)
 
 	def attr_append(self, name, value):
 		oper = ModifyOperation({name: [(MODIFY_ADD, [value])]})
-		if self.current.status == Status.ADDED:
+		if self.session is not None:
 			self.session.record(self.obj, oper)
 		oper.apply(self.current)
 
 	def attr_remove(self, name, value):
 		# TODO: how does LDAP handle MODIFY_DELETE ops with non-existant values?
 		oper = ModifyOperation({name: [(MODIFY_DELETE, [value])]})
-		if self.current.status == Status.ADDED:
+		if self.session is not None:
 			self.session.record(self.obj, oper)
 		oper.apply(self.current)
 
-- 
GitLab