From 50a326d69e533eadc63c9357ff5282a5de32076a Mon Sep 17 00:00:00 2001
From: Julian Rother <julianr@fsmpi.rwth-aachen.de>
Date: Mon, 22 Feb 2021 15:38:14 +0100
Subject: [PATCH] Implemented relationships

---
 ldap3_mapper_new/base.py         |  24 ++++--
 ldap3_mapper_new/model.py        |   3 +-
 ldap3_mapper_new/relationship.py | 128 +++++++++++++++++++++++++++++++
 3 files changed, 146 insertions(+), 9 deletions(-)
 create mode 100644 ldap3_mapper_new/relationship.py

diff --git a/ldap3_mapper_new/base.py b/ldap3_mapper_new/base.py
index ff474a55..11399220 100644
--- a/ldap3_mapper_new/base.py
+++ b/ldap3_mapper_new/base.py
@@ -3,12 +3,18 @@ from copy import deepcopy
 from ldap3 import MODIFY_REPLACE, MODIFY_DELETE, MODIFY_ADD, ALL_ATTRIBUTES, ALL_OPERATIONAL_ATTRIBUTES
 from ldap3.utils.conv import escape_filter_chars
 
-def encode_filter(params):
-	return '(&%s)'%(''.join(['(%s=%s)'%(attr, escape_filter_chars(value)) for attr, value in params]))
+def encode_filter(filter_params):
+	return '(&%s)'%(''.join(['(%s=%s)'%(attr, escape_filter_chars(value)) for attr, value in filter_params]))
 
 def match_dn(dn, base):
 	return dn.endswith(base) # Probably good enougth for all valid dns
 
+def make_cache_key(search_base, filter_params):
+	res = [search_base]
+	for attr, value in sorted(filter_params):
+		res.append((attr, value))
+	return res
+
 class LDAPCommitError(Exception):
 	pass
 
@@ -125,6 +131,7 @@ class Session:
 		self.committed_state = SessionState()
 		self.state = SessionState()
 		self.changes = []
+		self.cached_searches = set()
 
 	def add(self, obj, dn, object_classes):
 		if self.state.objects.get(dn) == obj:
@@ -190,7 +197,7 @@ class Session:
 			self.state.ref(obj, attr, values)
 		return obj
 
-	def filter_local(self, search_base, filter_params):
+	def filter(self, search_base, filter_params):
 		if not filter_params:
 			matches = self.state.objects.values()
 		else:
@@ -198,12 +205,12 @@ class Session:
 			matches = submatches.pop(0)
 			while submatches:
 				matches = matches.intersection(submatches.pop(0))
-		return [obj for obj in matches if match_dn(obj.state.dn, search_base)]
-
-	def filter(self, search_base, filter_params):
+		res = [obj for obj in matches if match_dn(obj.state.dn, search_base)]
+		cache_key = make_cache_key(search_base, filter_params)
+		if cache_key in self.cached_searches:
+			return res
 		conn = self.get_connection()
 		conn.search(search_base, encode_filter(filter_params), attributes=[ALL_ATTRIBUTES, ALL_OPERATIONAL_ATTRIBUTES])
-		res = []
 		for response in conn.response:
 			dn = response['dn']
 			if dn in self.state.objects or dn in self.state.deleted_objects:
@@ -214,7 +221,8 @@ class Session:
 			for attr, values in obj.state.attributes.items():
 				self.state.ref(obj, attr, values)
 			res.append(obj)
-		return res + self.filter_local(search_base, filter_params)
+		self.cached_searches.add(cache_key)
+		return res
 
 class Object:
 	def __init__(self, session=None, response=None):
diff --git a/ldap3_mapper_new/model.py b/ldap3_mapper_new/model.py
index 8981899f..7f65466d 100644
--- a/ldap3_mapper_new/model.py
+++ b/ldap3_mapper_new/model.py
@@ -21,7 +21,7 @@ class Session:
 		self.ldap_session = base.Session(get_connection)
 
 	def add(self, obj):
-		self.ldap_session.add(obj.ldap_object, obj.dn, obj.object_classes)
+		self.ldap_session.add(obj.ldap_object, obj.dn, obj.ldap_object_classes)
 
 	def delete(self, obj):
 		self.ldap_session.delete(obj.ldap_object)
@@ -80,6 +80,7 @@ class Model:
 	# Overwritten by models
 	ldap_search_base = None
 	ldap_filter_params = None
+	ldap_object_classes = None
 	ldap_dn_base = None
 	ldap_dn_attribute = None
 
diff --git a/ldap3_mapper_new/relationship.py b/ldap3_mapper_new/relationship.py
new file mode 100644
index 00000000..486d7276
--- /dev/null
+++ b/ldap3_mapper_new/relationship.py
@@ -0,0 +1,128 @@
+from collections.abc import MutableSet
+
+from .model import make_modelobj, make_modelobjs
+
+class UnboundObjectError(Exception):
+	pass
+
+class RelationshipSet(MutableSet):
+	def __init__(self, ldap_object, name, model, destmodel):
+		self.__ldap_object = ldap_object
+		self.__name = name
+		self.__model = model
+		self.__destmodel = destmodel
+
+	def __modify_check(self, value):
+		if self.__ldap_object.session is None:
+			raise UnboundObjectError()
+		if not isinstance(value, self.__destmodel):
+			raise TypeError()
+
+	def __repr__(self):
+		return repr(set(self))
+
+	def __contains__(self, value):
+		if value is None or not isinstance(value, self.__destmodel):
+			return False
+		return value.ldap_object.dn in self.__ldap_object.getattr(self.__name)
+
+	def __iter__(self):
+		def get(dn):
+			return make_modelobj(self.__ldap_object.session.get(dn, self.__model.ldap_filter_params), self.__destmodel)
+		dns = set(self.__ldap_object.getattr(self.__name))
+		return iter(filter(lambda obj: obj is not None, map(get, dns)))
+
+	def __len__(self):
+		return len(set(self))
+
+	def add(self, value):
+		self.__modify_check(value)
+		if value.ldap_object.session is None:
+			self.__ldap_object.session.add(value.ldap_object)
+		assert value.ldap_object.session == self.__ldap_object.session
+		self.__ldap_object.attradd(self.__name, value.dn)
+
+	def discard(self, value):
+		self.__modify_check(value)
+		self.__ldap_object.attrdel(self.__name, value.dn)
+
+class Relationship:
+	def __init__(self, name, destmodel, backref=None):
+		self.name = name
+		self.destmodel = destmodel
+		self.backref = backref
+
+	def __set_name__(self, cls, name):
+		if self.backref is not None:
+			setattr(self.destmodel, self.backref, Backreference(self.name, cls))
+
+	def __get__(self, obj, objtype=None):
+		if obj is None:
+			return self
+		return RelationshipSet(obj, self.name, type(obj), self.destmodel)
+
+	def __set__(self, obj, values):
+		tmp = self.__get__(obj)
+		tmp.clear()
+		for value in values:
+			tmp.add(value)
+
+class BackreferenceSet(MutableSet):
+	def __init__(self, ldap_object, name, model, srcmodel):
+		self.__ldap_object = ldap_object
+		self.__name = name
+		self.__model = model
+		self.__srcmodel = srcmodel
+
+	def __modify_check(self, value):
+		if self.__ldap_object.session is None:
+			raise UnboundObjectError()
+		if not isinstance(value, self.__srcmodel):
+			raise TypeError()
+
+	def __get(self):
+		if self.__ldap_object.session is None:
+			return set()
+		filter_params = self.__srcmodel.filter_params + [(self.__name, self.__ldap_object.dn)]
+		objs = self.__ldap_object.session.filter(self.__srcmodel.ldap_search_base, filter_params)
+		return set(make_modelobjs(objs, self.__srcmodel))
+
+	def __repr__(self):
+		return repr(self.__get())
+
+	def __contains__(self, value):
+		return value in self.__get()
+
+	def __iter__(self):
+		return iter(self.__get())
+
+	def __len__(self):
+		return len(self.__get())
+
+	def add(self, value):
+		self.__modify_check(value)
+		if value.ldap_object.session is None:
+			self.__ldap_object.session.add(value.ldap_object)
+		assert value.ldap_object.session == self.__ldap_object.session
+		if self.__ldap_object.dn not in value.ldap_object.getattr(self.__name):
+			value.ldap_object.attradd(self.__name, self.__ldap_object.dn)
+
+	def discard(self, value):
+		self.__modify_check(value)
+		value.ldap_object.attrdel(self.__name, self.__ldap_object.dn)
+
+class Backreference:
+	def __init__(self, name, srcmodel):
+		self.name = name
+		self.srcmodel = srcmodel
+
+	def __get__(self, obj, objtype=None):
+		if obj is None:
+			return self
+		return BackreferenceSet(obj, self.name, type(obj), self.srcmodel)
+
+	def __set__(self, obj, values):
+		tmp = self.__get__(obj)
+		tmp.clear()
+		for value in values:
+			tmp.add(value)
-- 
GitLab