Skip to content
Snippets Groups Projects
Commit 72e1d422 authored by Julian Rother's avatar Julian Rother
Browse files

Implemented sqlalchemy-like query class

parent db135ec7
No related branches found
No related tags found
No related merge requests found
...@@ -86,7 +86,7 @@ class DBBackreferenceSet(MutableSet): ...@@ -86,7 +86,7 @@ class DBBackreferenceSet(MutableSet):
def __get(self): def __get(self):
if self.__mapcls is None: if self.__mapcls is None:
return self.__dbcls.query.filter_by(**{self.__relattr: self.__dn}) return self.__dbcls.query.filter_by(**{self.__relattr: self.__dn}).all()
return {getattr(mapobj, self.__backattr) for mapobj in self.__mapcls.query.filter_by(dn=self.__dn)} return {getattr(mapobj, self.__backattr) for mapobj in self.__mapcls.query.filter_by(dn=self.__dn)}
def __repr__(self): def __repr__(self):
......
from collections.abc import Sequence
try: try:
# Added in v2.5 # Added in v2.5
from ldap3.utils.dn import escape_rdn from ldap3.utils.dn import escape_rdn
...@@ -54,34 +56,63 @@ def make_modelobjs(objs, model): ...@@ -54,34 +56,63 @@ def make_modelobjs(objs, model):
modelobjs.append(modelobj) modelobjs.append(modelobj)
return modelobjs return modelobjs
class ModelQuery: class Query(Sequence):
def __init__(self, model): def __init__(self, model, filter_params=None):
self.model = model self.__model = model
self.__filter_params = list(model.ldap_filter_params) + (filter_params or [])
@property
def __session(self):
return self.__model.ldap_mapper.session.ldap_session
def get(self, dn): def get(self, dn):
session = self.model.ldap_mapper.session.ldap_session return make_modelobj(self.__session.get(dn, self.__filter_params), self.__model)
return make_modelobj(session.get(dn, self.model.ldap_filter_params), self.model)
def all(self): def all(self):
session = self.model.ldap_mapper.session.ldap_session objs = self.__session.filter(self.__model.ldap_search_base, self.__filter_params)
objs = session.filter(self.model.ldap_search_base, self.model.ldap_filter_params) objs = sorted(objs, key=lambda obj: obj.dn)
return make_modelobjs(objs, self.model) return make_modelobjs(objs, self.__model)
def first(self):
return (self.all() or [None])[0]
def one(self):
modelobjs = self.all()
if len(modelobjs) != 1:
raise Exception()
return modelobjs[0]
def one_or_none(self):
modelobjs = self.all()
if len(modelobjs) > 1:
raise Exception()
return (modelobjs or [None])[0]
def __contains__(self, value):
return value in self.all()
def __iter__(self):
return iter(self.all())
def __len__(self):
return len(self.all())
def __getitem__(self, index):
return self.all()[index]
def filter_by(self, **kwargs): def filter_by(self, **kwargs):
filter_params = list(self.model.ldap_filter_params) filter_params = [(getattr(self.__model, attr).name, value) for attr, value in kwargs.items()]
filter_params += [(getattr(self.model, attr).name, value) for attr, value in kwargs.items()] return Query(self.__model, self.__filter_params + filter_params)
session = self.model.ldap_mapper.session.ldap_session
objs = session.filter(self.model.ldap_search_base, filter_params)
return make_modelobjs(objs, self.model)
class ModelQueryWrapper: class QueryWrapper:
def __get__(self, obj, objtype=None): def __get__(self, obj, objtype=None):
return ModelQuery(objtype) return objtype.query_class(objtype)
class Model: class Model:
# Overwritten by mapper # Overwritten by mapper
ldap_mapper = None ldap_mapper = None
query = ModelQueryWrapper() query_class = Query
query = QueryWrapper()
ldap_add_hooks = () ldap_add_hooks = ()
# Overwritten by models # Overwritten by models
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment