Skip to content
Snippets Groups Projects
Commit 9aaf5e60 authored by Julian's avatar Julian
Browse files

Made low-level search functions fully honor local objects

parent 3a3594db
Branches
No related tags found
No related merge requests found
from copy import deepcopy
from ldap3 import MODIFY_REPLACE, MODIFY_DELETE, MODIFY_ADD, ALL_ATTRIBUTES, ALL_OPERATIONAL_ATTRIBUTES
from ldap3.utils.conv import escape_filter_chars
def encode_filter(params):
return '(&%s)'%(''.join(['(%s=%s)'%(attr, escape_filter_chars(value)) for attr, value in params]))
def match_dn(dn, base):
return dn.endswith(base) # Probably good enougth for all valid dns
class LDAPCommitError(Exception):
pass
......@@ -164,13 +171,14 @@ class Session:
self.state = self.committed_state.copy()
self.changes.clear()
def get(self, dn, search_filter):
def get(self, dn, filter_params):
if dn in self.state.objects:
return self.state.objects[dn]
obj = self.state.objects[dn]
return obj if obj.matches(filter_params) else None
if dn in self.state.deleted_objects:
return None
conn = self.get_connection()
conn.search(dn, search_filter, attributes=[ALL_ATTRIBUTES, ALL_OPERATIONAL_ATTRIBUTES])
conn.search(dn, encode_filter(filter_params), attributes=[ALL_ATTRIBUTES, ALL_OPERATIONAL_ATTRIBUTES])
if not conn.response:
return None
assert len(conn.response) == 1
......@@ -182,24 +190,31 @@ class Session:
self.state.ref(obj, attr, values)
return obj
def search(self, search_base, search_filter):
def filter_local(self, search_base, filter_params):
if not filter_params:
matches = self.state.objects.values()
else:
submatches = [self.state.references.get((attr, value), set()) for attr, value in filter_params]
matches = submatches.pop(0)
while submatches:
matches = matches.intersection(submatches.pop(0))
return [obj for obj in matches if match_dn(obj.state.dn, search_base)]
def filter(self, search_base, filter_params):
conn = self.get_connection()
conn.search(search_base, search_filter, attributes=[ALL_ATTRIBUTES, ALL_OPERATIONAL_ATTRIBUTES])
conn.search(search_base, encode_filter(filter_params), attributes=[ALL_ATTRIBUTES, ALL_OPERATIONAL_ATTRIBUTES])
res = []
for response in conn.response:
dn = response['dn']
if dn in self.state.objects:
res.append(self.state.objects[dn])
elif dn in self.state.deleted_objects:
if dn in self.state.objects or dn in self.state.deleted_objects:
continue
else:
obj = Object(self, response)
self.state.objects[dn] = obj
self.committed_state.objects[dn] = obj
for attr, values in obj.state.attributes.items():
self.state.ref(obj, attr, values)
res.append(obj)
return res
obj = Object(self, response)
self.state.objects[dn] = obj
self.committed_state.objects[dn] = obj
for attr, values in obj.state.attributes.items():
self.state.ref(obj, attr, values)
res.append(obj)
return res + self.filter_local(search_base, filter_params)
class Object:
def __init__(self, session=None, response=None):
......@@ -211,6 +226,14 @@ class Object:
self.committed_state = ObjectState(session, attrs, response['dn'])
self.state = self.committed_state.copy()
@property
def dn(self):
return self.state.dn
@property
def session(self):
return self.state.session
def getattr(self, name):
return self.state.attributes.get(name, [])
......@@ -234,3 +257,9 @@ class Object:
if self.state.session:
oper.apply_session(self.state.session.state)
self.state.session.changes.append(oper)
def match(self, filter_params):
for attr, value in filter_params:
if value not in self.getattr(attr):
return False
return True
......@@ -30,23 +30,32 @@ def make_modelobj(obj, model):
return None
return obj.model
def make_modelobjs(objs, model):
modelobjs = []
for obj in objs:
modelobj = make_modelobj(obj, model)
if modelobj is not None:
modelobjs.append(modelobj)
return modelobjs
class ModelQuery:
def __init__(self, model):
self.model = model
def get(self, dn):
session = self.model.ldap_mapper.session.ldap_session
return make_modelobj(session.get(dn, self.model.ldap_search_filter), self.model)
return make_modelobj(session.get(dn, self.model.ldap_filter_params), self.model)
def all(self):
session = self.model.ldap_mapper.session.ldap_session
objs = session.search(self.model.ldap_search_base, self.model.ldap_search_filter)
# TODO: check cached objects for non-committed objects
objs = [make_modelobj(obj, self.model) for obj in objs]
return [obj for obj in objs if obj is not None]
objs = session.filter(self.model.ldap_search_base, self.model.ldap_filter_params)
return make_modelobjs(objs, self.model)
def filter_by(self, dn):
pass # TODO
def filter_by(self, **kwargs):
filter_params = self.model.ldap_filter_params + list(kwargs.items())
session = self.model.ldap_mapper.session.ldap_session
objs = session.filter(self.model.ldap_search_base, filter_params)
return make_modelobjs(objs, self.model)
class ModelQueryWrapper:
def __get__(self, obj, objtype=None):
......@@ -58,7 +67,7 @@ class Model:
# Overwritten by models
ldap_search_base = None
ldap_search_filter = None
ldap_filter_params = None
ldap_dn_base = None
ldap_dn_attribute = None
......@@ -73,8 +82,8 @@ class Model:
@property
def dn(self):
if self.ldap_object.state.dn is not None:
return self.ldap_object.state.dn
if self.ldap_object.dn is not None:
return self.ldap_object.dn
if self.ldap_dn_base is None or self.ldap_dn_attribute is None:
return None
values = self.ldap_object.getattr(self.ldap_dn_attribute)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment