From d0fd1d23bb01b77366ab4191fe4ed1b138a18122 Mon Sep 17 00:00:00 2001
From: Julian Rother <julianr@fsmpi.rwth-aachen.de>
Date: Sun, 21 Feb 2021 04:58:31 +0100
Subject: [PATCH] Implemented high-level Session, Model and Attribute

---
 ldap3_mapper_new/__init__.py |  20 ++---
 ldap3_mapper_new/model.py    | 153 +++++++++++++++++++++++++++++++++++
 2 files changed, 160 insertions(+), 13 deletions(-)
 create mode 100644 ldap3_mapper_new/model.py

diff --git a/ldap3_mapper_new/__init__.py b/ldap3_mapper_new/__init__.py
index 7bb8e9be..a071ffe4 100644
--- a/ldap3_mapper_new/__init__.py
+++ b/ldap3_mapper_new/__init__.py
@@ -1,27 +1,21 @@
 import ldap3
 
-from .types import LDAPCommitError
-from . import base
+from . import model
 
-class BaseModel(base.SessionObject)::
-	def __init__(self, _ldap_response=None, **kwargs):
-		super().__init__(_ldap_response)
-		for key, value, in kwargs.items():
-			if not hasattr(type(self), key):
-				raise Exception()
-			setattr(self, key, value)
+__all__ = ['LDAP3Mapper']
 
 class LDAP3Mapper:
 	def __init__(self, server=None, bind_dn=None, bind_password=None):
 
-		class Session(base.Session):
+		class Session(model.Session):
 			ldap_mapper = self
 
-		class Model(BaseModel):
+		class Model(model.Model):
 			ldap_mapper = self
 
 		self.Session = Session # pylint: disable=invalid-name
 		self.Model = Model # pylint: disable=invalid-name
+		self.Attribute = model.Attribute # pylint: disable=invalid-name
 
 		if not hasattr(type(self), 'server'):
 			self.server = server
@@ -30,7 +24,7 @@ class LDAP3Mapper:
 		if not hasattr(type(self), 'bind_password'):
 			self.bind_password = bind_password
 		if not hasattr(type(self), 'session'):
-			self.session = self.Session()
+			self.session = self.Session(self.get_connection)
 
-	def connect(self):
+	def get_connection(self):
 		return ldap3.Connection(self.server, self.bind_dn, self.bind_password, auto_bind=True)
diff --git a/ldap3_mapper_new/model.py b/ldap3_mapper_new/model.py
new file mode 100644
index 00000000..29c33e96
--- /dev/null
+++ b/ldap3_mapper_new/model.py
@@ -0,0 +1,153 @@
+from collections.abc import MutableSet
+
+from ldap3.utils.conv import escape_filter_chars
+
+from . import base
+
+class Session:
+	def __init__(self, get_connection):
+		self.ldap_session = base.Session(get_connection)
+
+	def add(self, obj):
+		self.ldap_session.add(obj.ldap_object, obj.dn, obj.object_classes)
+
+	def delete(self, obj):
+		self.ldap_session.delete(obj.ldap_object)
+
+	def commit(self):
+		self.ldap_session.commit()
+
+	def rollback(self):
+		self.ldap_session.rollback()
+
+def make_modelobj(obj, model):
+	if obj is None:
+		return None
+	if not hasattr(obj, 'model'):
+		obj.model = model()
+		obj.model.ldap_object = obj
+	if not isinstance(obj.model, model):
+		return None
+	return obj.model
+
+class ModelQuery:
+	def __init__(self, model):
+		self.model = model
+
+	def get(self, dn):
+		session = self.model.ldap_mapper.session.ldap_session
+		return make_modelobj(session.get(dn, self.model.ldap_search_filter), self.model)
+
+	def all(self):
+		session = self.model.ldap_mapper.session.ldap_session
+		objs = session.search(self.model.ldap_search_base, self.model.ldap_search_filter)
+		# TODO: check cached objects for non-committed objects
+		objs = [make_modelobj(obj, self.model) for obj in objs]
+		return [obj for obj in objs if obj is not None]
+
+	def filter_by(self, dn):
+		pass # TODO
+
+class ModelQueryWrapper:
+	def __get__(self, obj, objtype=None):
+		return ModelQuery(objtype)
+
+class Model:
+	# Overwritten by mapper
+	ldap_mapper = None
+
+	# Overwritten by models
+	ldap_search_base = None
+	ldap_search_filter = None
+	ldap_dn_base = None
+	ldap_dn_attribute = None
+
+	query = ModelQueryWrapper()
+
+	def __init__(self, **kwargs):
+		self.ldap_object = base.Object()
+		for key, value, in kwargs.items():
+			if not hasattr(self, key):
+				raise Exception()
+			setattr(self, key, value)
+
+	@property
+	def dn(self):
+		if self.ldap_object.state.dn is not None:
+			return self.ldap_object.state.dn
+		if self.ldap_dn_base is None or self.ldap_dn_attribute is None:
+			return None
+		values = self.ldap_object.getattr(self.ldap_dn_attribute)
+		if not values:
+			return None
+		return '%s=%s,%s'%(self.ldap_dn_attribute, escape_filter_chars(values[0]), self.ldap_dn_base)
+
+	def __repr__(self):
+		cls_name = '%s.%s'%(type(self).__module__, type(self).__name__)
+		if self.dn is not None:
+			return '<%s %s>'%(cls_name, self.dn)
+		return '<%s>'%cls_name
+
+class SetView(MutableSet):
+	def __init__(self, getitems, additem, delitem, encode=None, decode=None):
+		self.__getitems = getitems
+		self.__additem = additem
+		self.__delitem = delitem
+		self.__encode = encode or (lambda x: x)
+		self.__decode = decode or (lambda x: x)
+
+	def __repr__(self):
+		return repr(set(self))
+
+	def __contains__(self, value):
+		return value is not None and self.__encode(value) in self.__getitems()
+
+	def __iter__(self):
+		return iter(filter(lambda obj: obj is not None, map(self.__decode, self.__getitems())))
+
+	def __len__(self):
+		return len(set(self))
+
+	def add(self, value):
+		if value not in self:
+			self.__additem(self.__encode(value))
+
+	def discard(self, value):
+		self.__delitem(self.__encode(value))
+
+	def update(self, values):
+		for value in values:
+			self.add(value)
+
+class Attribute:
+	def __init__(self, name, multi=False, encode=None, decode=None, aliases=None):
+		self.name = name
+		self.multi = multi
+		self.encode = encode or (lambda x: x)
+		self.decode = decode or (lambda x: x)
+		self.aliases = [name] + (aliases or [])
+
+	def additem(self, obj, value):
+		for name in self.aliases:
+			obj.ldap_object.attradd(name, value)
+
+	def delitem(self, obj, value):
+		for name in self.aliases:
+			obj.ldap_object.attrdel(name, value)
+
+	def __get__(self, obj, objtype=None):
+		if obj is None:
+			return self
+		if self.multi:
+			return SetView(getitems=lambda: obj.ldap_object.getattr(self.name),
+			               additem=lambda value: self.additem(obj, value),
+			               delitem=lambda value: self.delitem(obj, value),
+			               encode=self.encode, decode=self.decode)
+		return self.decode((obj.ldap_object.getattr(self.name) or [None])[0])
+
+	def __set__(self, obj, values):
+		if not self.multi:
+			values = [values]
+		values = [self.encode(value) for value in values]
+		for name in self.aliases:
+			obj.ldap_object.setattr(name, values)
-- 
GitLab