From 1306790eccb9b757f8e1a369fd6f7fba48d969ab Mon Sep 17 00:00:00 2001 From: Julian Rother <julianr@fsmpi.rwth-aachen.de> Date: Sun, 21 Feb 2021 01:55:28 +0100 Subject: [PATCH] Added reference tracking --- ldap3_mapper_new/base.py | 50 ++++++++++++++++++++++++++++++++-------- 1 file changed, 41 insertions(+), 9 deletions(-) diff --git a/ldap3_mapper_new/base.py b/ldap3_mapper_new/base.py index db0188d2..bf80f408 100644 --- a/ldap3_mapper_new/base.py +++ b/ldap3_mapper_new/base.py @@ -1,12 +1,24 @@ from copy import deepcopy class SessionState: - def __init__(self, objects=None, deleted_objects=None): + def __init__(self, objects=None, deleted_objects=None, references=None): self.objects = objects or {} self.deleted_objects = deleted_objects or {} + self.references = references or {} # {(attr_name, value): {srcobj, ...}, ...} def copy(self): - return SessionState(objects=deepcopy(self.objects), deleted_objects=deepcopy(self.deleted_objects)) + return SessionState(deepcopy(self.objects), deepcopy(self.deleted_objects), deepcopy(self.references)) + + def ref(self, obj, attr, values): + for value in values: + if key not in self.references: + self.references[key] = {self.obj} + else: + self.references[key].add(self.obj) + + def unref(self, obj, attr, values): + for value in values: + self.references.get((name, value), set()).discard(obj) class ObjectState: def __init__(self, session=None, attributes=None, dn=None): @@ -31,6 +43,8 @@ class AddOperation: def apply_session(self, session_state): assert self.dn not in session_state.objects session_state.objects[self.dn] = self.obj + for name, values in self.attributes.items(): + session_state.ref(self.obj, name, values) def apply_ldap(self, conn): success = conn.add(self.dn, self.object_classes, self.attributes) @@ -41,6 +55,7 @@ class DeleteOperation: def __init__(self, obj): self.dn = obj.state.dn self.obj = obj + self.attributes = deepcopy(obj.state.attributes) def apply_object(self, obj_state): obj_state.dn = None @@ -49,6 +64,8 @@ class DeleteOperation: assert self.dn in session_state.objects del session_state.objects[self.dn] session_state.deleted_objects[self.dn] = self.obj + for name, values in self.attributes.items(): + session_state.unref(self.obj, name, values) def apply_ldap(self, conn): success = conn.delete(self.dn) @@ -58,6 +75,7 @@ class DeleteOperation: class ModifyOperation: def __init__(self, obj, changes): self.obj = obj + self.attributes = deepcopy(obj.state.attributes) self.changes = deepcopy(changes) def apply_object(self, obj_state): @@ -72,7 +90,15 @@ class ModifyOperation: obj_state.attributes[attr].remove(value) def apply_session(self, session_state): - pass + for attr, changes in self.changes.items(): + for action, values in changes: + if action == MODIFY_REPLACE: + session_state.unref(self.obj, attr, self.attributes.get(attr, []) + session_state.ref(self.obj, attr, values) + elif action == MODIFY_ADD: + session_state.ref(self.obj, attr, values) + elif action == MODIFY_DELETE: + session_state.unref(self.obj, attr, values) def apply_ldap(self, conn): success = conn.modify(self.obj.state.dn, self.changes) @@ -141,9 +167,12 @@ class Session: return None assert len(conn.response) == 1 assert conn.response[0]['dn'] == dn - self.state.objects[dn] = Object(self, conn.response[0]) - self.committed_state.objects[dn] = self.state.objects[dn] - return self.state.objects[dn] + obj = Object(self, conn.response[0]) + self.state.objects[dn] = obj + self.committed_state.objects[dn] = obj + for attr, values in obj.state.attributes.items(): + self.state.ref(obj, attr, values) + return obj def search(self, search_base, search_filter): conn = self.get_connection() @@ -156,9 +185,12 @@ class Session: elif dn in self.state.deleted_objects: continue else: - self.state.objects[dn] = Object(self, response) - self.committed_state.objects[dn] = self.state.objects[dn] - res.append(self.state.objects[dn]) + obj = Object(self, response) + self.state.objects[dn] = obj + self.committed_state.objects[dn] = obj + for attr, values in obj.state.attributes.items(): + self.state.ref(obj, attr, values) + res.append(obj) return res class Object: -- GitLab