From 72e1d4226f7b795c719488d5e91fa4fdcecf172f Mon Sep 17 00:00:00 2001 From: Julian Rother <julian@jrother.eu> Date: Tue, 23 Feb 2021 21:43:05 +0100 Subject: [PATCH] Implemented sqlalchemy-like query class --- ldapalchemy/dbutils.py | 2 +- ldapalchemy/model.py | 63 +++++++++++++++++++++++++++++++----------- 2 files changed, 48 insertions(+), 17 deletions(-) diff --git a/ldapalchemy/dbutils.py b/ldapalchemy/dbutils.py index a9f399a..009cd79 100644 --- a/ldapalchemy/dbutils.py +++ b/ldapalchemy/dbutils.py @@ -86,7 +86,7 @@ class DBBackreferenceSet(MutableSet): def __get(self): if self.__mapcls is None: - return self.__dbcls.query.filter_by(**{self.__relattr: self.__dn}) + return self.__dbcls.query.filter_by(**{self.__relattr: self.__dn}).all() return {getattr(mapobj, self.__backattr) for mapobj in self.__mapcls.query.filter_by(dn=self.__dn)} def __repr__(self): diff --git a/ldapalchemy/model.py b/ldapalchemy/model.py index 0e5484b..ddecd7f 100644 --- a/ldapalchemy/model.py +++ b/ldapalchemy/model.py @@ -1,3 +1,5 @@ +from collections.abc import Sequence + try: # Added in v2.5 from ldap3.utils.dn import escape_rdn @@ -54,34 +56,63 @@ def make_modelobjs(objs, model): modelobjs.append(modelobj) return modelobjs -class ModelQuery: - def __init__(self, model): - self.model = model +class Query(Sequence): + def __init__(self, model, filter_params=None): + self.__model = model + self.__filter_params = list(model.ldap_filter_params) + (filter_params or []) + + @property + def __session(self): + return self.__model.ldap_mapper.session.ldap_session def get(self, dn): - session = self.model.ldap_mapper.session.ldap_session - return make_modelobj(session.get(dn, self.model.ldap_filter_params), self.model) + return make_modelobj(self.__session.get(dn, self.__filter_params), self.__model) def all(self): - session = self.model.ldap_mapper.session.ldap_session - objs = session.filter(self.model.ldap_search_base, self.model.ldap_filter_params) - return make_modelobjs(objs, self.model) + objs = self.__session.filter(self.__model.ldap_search_base, self.__filter_params) + objs = sorted(objs, key=lambda obj: obj.dn) + return make_modelobjs(objs, self.__model) + + def first(self): + return (self.all() or [None])[0] + + def one(self): + modelobjs = self.all() + if len(modelobjs) != 1: + raise Exception() + return modelobjs[0] + + def one_or_none(self): + modelobjs = self.all() + if len(modelobjs) > 1: + raise Exception() + return (modelobjs or [None])[0] + + def __contains__(self, value): + return value in self.all() + + def __iter__(self): + return iter(self.all()) + + def __len__(self): + return len(self.all()) + + def __getitem__(self, index): + return self.all()[index] def filter_by(self, **kwargs): - filter_params = list(self.model.ldap_filter_params) - filter_params += [(getattr(self.model, attr).name, value) for attr, value in kwargs.items()] - session = self.model.ldap_mapper.session.ldap_session - objs = session.filter(self.model.ldap_search_base, filter_params) - return make_modelobjs(objs, self.model) + filter_params = [(getattr(self.__model, attr).name, value) for attr, value in kwargs.items()] + return Query(self.__model, self.__filter_params + filter_params) -class ModelQueryWrapper: +class QueryWrapper: def __get__(self, obj, objtype=None): - return ModelQuery(objtype) + return objtype.query_class(objtype) class Model: # Overwritten by mapper ldap_mapper = None - query = ModelQueryWrapper() + query_class = Query + query = QueryWrapper() ldap_add_hooks = () # Overwritten by models -- GitLab