Skip to content
Snippets Groups Projects
Commit 9aaf5e60 authored by Julian's avatar Julian
Browse files

Made low-level search functions fully honor local objects

parent 3a3594db
No related branches found
No related tags found
No related merge requests found
from copy import deepcopy from copy import deepcopy
from ldap3 import MODIFY_REPLACE, MODIFY_DELETE, MODIFY_ADD, ALL_ATTRIBUTES, ALL_OPERATIONAL_ATTRIBUTES from ldap3 import MODIFY_REPLACE, MODIFY_DELETE, MODIFY_ADD, ALL_ATTRIBUTES, ALL_OPERATIONAL_ATTRIBUTES
from ldap3.utils.conv import escape_filter_chars
def encode_filter(params):
return '(&%s)'%(''.join(['(%s=%s)'%(attr, escape_filter_chars(value)) for attr, value in params]))
def match_dn(dn, base):
return dn.endswith(base) # Probably good enougth for all valid dns
class LDAPCommitError(Exception): class LDAPCommitError(Exception):
pass pass
...@@ -164,13 +171,14 @@ class Session: ...@@ -164,13 +171,14 @@ class Session:
self.state = self.committed_state.copy() self.state = self.committed_state.copy()
self.changes.clear() self.changes.clear()
def get(self, dn, search_filter): def get(self, dn, filter_params):
if dn in self.state.objects: if dn in self.state.objects:
return self.state.objects[dn] obj = self.state.objects[dn]
return obj if obj.matches(filter_params) else None
if dn in self.state.deleted_objects: if dn in self.state.deleted_objects:
return None return None
conn = self.get_connection() conn = self.get_connection()
conn.search(dn, search_filter, attributes=[ALL_ATTRIBUTES, ALL_OPERATIONAL_ATTRIBUTES]) conn.search(dn, encode_filter(filter_params), attributes=[ALL_ATTRIBUTES, ALL_OPERATIONAL_ATTRIBUTES])
if not conn.response: if not conn.response:
return None return None
assert len(conn.response) == 1 assert len(conn.response) == 1
...@@ -182,24 +190,31 @@ class Session: ...@@ -182,24 +190,31 @@ class Session:
self.state.ref(obj, attr, values) self.state.ref(obj, attr, values)
return obj return obj
def search(self, search_base, search_filter): def filter_local(self, search_base, filter_params):
if not filter_params:
matches = self.state.objects.values()
else:
submatches = [self.state.references.get((attr, value), set()) for attr, value in filter_params]
matches = submatches.pop(0)
while submatches:
matches = matches.intersection(submatches.pop(0))
return [obj for obj in matches if match_dn(obj.state.dn, search_base)]
def filter(self, search_base, filter_params):
conn = self.get_connection() conn = self.get_connection()
conn.search(search_base, search_filter, attributes=[ALL_ATTRIBUTES, ALL_OPERATIONAL_ATTRIBUTES]) conn.search(search_base, encode_filter(filter_params), attributes=[ALL_ATTRIBUTES, ALL_OPERATIONAL_ATTRIBUTES])
res = [] res = []
for response in conn.response: for response in conn.response:
dn = response['dn'] dn = response['dn']
if dn in self.state.objects: if dn in self.state.objects or dn in self.state.deleted_objects:
res.append(self.state.objects[dn])
elif dn in self.state.deleted_objects:
continue continue
else: obj = Object(self, response)
obj = Object(self, response) self.state.objects[dn] = obj
self.state.objects[dn] = obj self.committed_state.objects[dn] = obj
self.committed_state.objects[dn] = obj for attr, values in obj.state.attributes.items():
for attr, values in obj.state.attributes.items(): self.state.ref(obj, attr, values)
self.state.ref(obj, attr, values) res.append(obj)
res.append(obj) return res + self.filter_local(search_base, filter_params)
return res
class Object: class Object:
def __init__(self, session=None, response=None): def __init__(self, session=None, response=None):
...@@ -211,6 +226,14 @@ class Object: ...@@ -211,6 +226,14 @@ class Object:
self.committed_state = ObjectState(session, attrs, response['dn']) self.committed_state = ObjectState(session, attrs, response['dn'])
self.state = self.committed_state.copy() self.state = self.committed_state.copy()
@property
def dn(self):
return self.state.dn
@property
def session(self):
return self.state.session
def getattr(self, name): def getattr(self, name):
return self.state.attributes.get(name, []) return self.state.attributes.get(name, [])
...@@ -234,3 +257,9 @@ class Object: ...@@ -234,3 +257,9 @@ class Object:
if self.state.session: if self.state.session:
oper.apply_session(self.state.session.state) oper.apply_session(self.state.session.state)
self.state.session.changes.append(oper) self.state.session.changes.append(oper)
def match(self, filter_params):
for attr, value in filter_params:
if value not in self.getattr(attr):
return False
return True
...@@ -30,23 +30,32 @@ def make_modelobj(obj, model): ...@@ -30,23 +30,32 @@ def make_modelobj(obj, model):
return None return None
return obj.model return obj.model
def make_modelobjs(objs, model):
modelobjs = []
for obj in objs:
modelobj = make_modelobj(obj, model)
if modelobj is not None:
modelobjs.append(modelobj)
return modelobjs
class ModelQuery: class ModelQuery:
def __init__(self, model): def __init__(self, model):
self.model = model self.model = model
def get(self, dn): def get(self, dn):
session = self.model.ldap_mapper.session.ldap_session session = self.model.ldap_mapper.session.ldap_session
return make_modelobj(session.get(dn, self.model.ldap_search_filter), self.model) return make_modelobj(session.get(dn, self.model.ldap_filter_params), self.model)
def all(self): def all(self):
session = self.model.ldap_mapper.session.ldap_session session = self.model.ldap_mapper.session.ldap_session
objs = session.search(self.model.ldap_search_base, self.model.ldap_search_filter) objs = session.filter(self.model.ldap_search_base, self.model.ldap_filter_params)
# TODO: check cached objects for non-committed objects return make_modelobjs(objs, self.model)
objs = [make_modelobj(obj, self.model) for obj in objs]
return [obj for obj in objs if obj is not None]
def filter_by(self, dn): def filter_by(self, **kwargs):
pass # TODO filter_params = self.model.ldap_filter_params + list(kwargs.items())
session = self.model.ldap_mapper.session.ldap_session
objs = session.filter(self.model.ldap_search_base, filter_params)
return make_modelobjs(objs, self.model)
class ModelQueryWrapper: class ModelQueryWrapper:
def __get__(self, obj, objtype=None): def __get__(self, obj, objtype=None):
...@@ -58,7 +67,7 @@ class Model: ...@@ -58,7 +67,7 @@ class Model:
# Overwritten by models # Overwritten by models
ldap_search_base = None ldap_search_base = None
ldap_search_filter = None ldap_filter_params = None
ldap_dn_base = None ldap_dn_base = None
ldap_dn_attribute = None ldap_dn_attribute = None
...@@ -73,8 +82,8 @@ class Model: ...@@ -73,8 +82,8 @@ class Model:
@property @property
def dn(self): def dn(self):
if self.ldap_object.state.dn is not None: if self.ldap_object.dn is not None:
return self.ldap_object.state.dn return self.ldap_object.dn
if self.ldap_dn_base is None or self.ldap_dn_attribute is None: if self.ldap_dn_base is None or self.ldap_dn_attribute is None:
return None return None
values = self.ldap_object.getattr(self.ldap_dn_attribute) values = self.ldap_object.getattr(self.ldap_dn_attribute)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment