Skip to content
Snippets Groups Projects
Commit 72e1d422 authored by Julian Rother's avatar Julian Rother
Browse files

Implemented sqlalchemy-like query class

parent db135ec7
No related branches found
No related tags found
No related merge requests found
......@@ -86,7 +86,7 @@ class DBBackreferenceSet(MutableSet):
def __get(self):
if self.__mapcls is None:
return self.__dbcls.query.filter_by(**{self.__relattr: self.__dn})
return self.__dbcls.query.filter_by(**{self.__relattr: self.__dn}).all()
return {getattr(mapobj, self.__backattr) for mapobj in self.__mapcls.query.filter_by(dn=self.__dn)}
def __repr__(self):
......
from collections.abc import Sequence
try:
# Added in v2.5
from ldap3.utils.dn import escape_rdn
......@@ -54,34 +56,63 @@ def make_modelobjs(objs, model):
modelobjs.append(modelobj)
return modelobjs
class ModelQuery:
def __init__(self, model):
self.model = model
class Query(Sequence):
def __init__(self, model, filter_params=None):
self.__model = model
self.__filter_params = list(model.ldap_filter_params) + (filter_params or [])
@property
def __session(self):
return self.__model.ldap_mapper.session.ldap_session
def get(self, dn):
session = self.model.ldap_mapper.session.ldap_session
return make_modelobj(session.get(dn, self.model.ldap_filter_params), self.model)
return make_modelobj(self.__session.get(dn, self.__filter_params), self.__model)
def all(self):
session = self.model.ldap_mapper.session.ldap_session
objs = session.filter(self.model.ldap_search_base, self.model.ldap_filter_params)
return make_modelobjs(objs, self.model)
objs = self.__session.filter(self.__model.ldap_search_base, self.__filter_params)
objs = sorted(objs, key=lambda obj: obj.dn)
return make_modelobjs(objs, self.__model)
def first(self):
return (self.all() or [None])[0]
def one(self):
modelobjs = self.all()
if len(modelobjs) != 1:
raise Exception()
return modelobjs[0]
def one_or_none(self):
modelobjs = self.all()
if len(modelobjs) > 1:
raise Exception()
return (modelobjs or [None])[0]
def __contains__(self, value):
return value in self.all()
def __iter__(self):
return iter(self.all())
def __len__(self):
return len(self.all())
def __getitem__(self, index):
return self.all()[index]
def filter_by(self, **kwargs):
filter_params = list(self.model.ldap_filter_params)
filter_params += [(getattr(self.model, attr).name, value) for attr, value in kwargs.items()]
session = self.model.ldap_mapper.session.ldap_session
objs = session.filter(self.model.ldap_search_base, filter_params)
return make_modelobjs(objs, self.model)
filter_params = [(getattr(self.__model, attr).name, value) for attr, value in kwargs.items()]
return Query(self.__model, self.__filter_params + filter_params)
class ModelQueryWrapper:
class QueryWrapper:
def __get__(self, obj, objtype=None):
return ModelQuery(objtype)
return objtype.query_class(objtype)
class Model:
# Overwritten by mapper
ldap_mapper = None
query = ModelQueryWrapper()
query_class = Query
query = QueryWrapper()
ldap_add_hooks = ()
# Overwritten by models
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment