From 72e1d4226f7b795c719488d5e91fa4fdcecf172f Mon Sep 17 00:00:00 2001
From: Julian Rother <julian@jrother.eu>
Date: Tue, 23 Feb 2021 21:43:05 +0100
Subject: [PATCH] Implemented sqlalchemy-like query class

---
 ldapalchemy/dbutils.py |  2 +-
 ldapalchemy/model.py   | 63 +++++++++++++++++++++++++++++++-----------
 2 files changed, 48 insertions(+), 17 deletions(-)

diff --git a/ldapalchemy/dbutils.py b/ldapalchemy/dbutils.py
index a9f399a..009cd79 100644
--- a/ldapalchemy/dbutils.py
+++ b/ldapalchemy/dbutils.py
@@ -86,7 +86,7 @@ class DBBackreferenceSet(MutableSet):
 
 	def __get(self):
 		if self.__mapcls is None:
-			return self.__dbcls.query.filter_by(**{self.__relattr: self.__dn})
+			return self.__dbcls.query.filter_by(**{self.__relattr: self.__dn}).all()
 		return {getattr(mapobj, self.__backattr) for mapobj in self.__mapcls.query.filter_by(dn=self.__dn)}
 
 	def __repr__(self):
diff --git a/ldapalchemy/model.py b/ldapalchemy/model.py
index 0e5484b..ddecd7f 100644
--- a/ldapalchemy/model.py
+++ b/ldapalchemy/model.py
@@ -1,3 +1,5 @@
+from collections.abc import Sequence
+
 try:
 	# Added in v2.5
 	from ldap3.utils.dn import escape_rdn
@@ -54,34 +56,63 @@ def make_modelobjs(objs, model):
 			modelobjs.append(modelobj)
 	return modelobjs
 
-class ModelQuery:
-	def __init__(self, model):
-		self.model = model
+class Query(Sequence):
+	def __init__(self, model, filter_params=None):
+		self.__model = model
+		self.__filter_params = list(model.ldap_filter_params) + (filter_params or [])
+
+	@property
+	def __session(self):
+		return self.__model.ldap_mapper.session.ldap_session
 
 	def get(self, dn):
-		session = self.model.ldap_mapper.session.ldap_session
-		return make_modelobj(session.get(dn, self.model.ldap_filter_params), self.model)
+		return make_modelobj(self.__session.get(dn, self.__filter_params), self.__model)
 
 	def all(self):
-		session = self.model.ldap_mapper.session.ldap_session
-		objs = session.filter(self.model.ldap_search_base, self.model.ldap_filter_params)
-		return make_modelobjs(objs, self.model)
+		objs = self.__session.filter(self.__model.ldap_search_base, self.__filter_params)
+		objs = sorted(objs, key=lambda obj: obj.dn)
+		return make_modelobjs(objs, self.__model)
+
+	def first(self):
+		return (self.all() or [None])[0]
+
+	def one(self):
+		modelobjs = self.all()
+		if len(modelobjs) != 1:
+			raise Exception()
+		return modelobjs[0]
+
+	def one_or_none(self):
+		modelobjs = self.all()
+		if len(modelobjs) > 1:
+			raise Exception()
+		return (modelobjs or [None])[0]
+
+	def __contains__(self, value):
+		return value in self.all()
+
+	def __iter__(self):
+		return iter(self.all())
+
+	def __len__(self):
+		return len(self.all())
+
+	def __getitem__(self, index):
+		return self.all()[index]
 
 	def filter_by(self, **kwargs):
-		filter_params = list(self.model.ldap_filter_params)
-		filter_params += [(getattr(self.model, attr).name, value) for attr, value in kwargs.items()]
-		session = self.model.ldap_mapper.session.ldap_session
-		objs = session.filter(self.model.ldap_search_base, filter_params)
-		return make_modelobjs(objs, self.model)
+		filter_params = [(getattr(self.__model, attr).name, value) for attr, value in kwargs.items()]
+		return Query(self.__model, self.__filter_params + filter_params)
 
-class ModelQueryWrapper:
+class QueryWrapper:
 	def __get__(self, obj, objtype=None):
-		return ModelQuery(objtype)
+		return objtype.query_class(objtype)
 
 class Model:
 	# Overwritten by mapper
 	ldap_mapper = None
-	query = ModelQueryWrapper()
+	query_class = Query
+	query = QueryWrapper()
 	ldap_add_hooks = ()
 
 	# Overwritten by models
-- 
GitLab