From 9aaf5e60bbd25d5014c0411270f813eba5e5300c Mon Sep 17 00:00:00 2001 From: Julian Rother <julianr@fsmpi.rwth-aachen.de> Date: Mon, 22 Feb 2021 11:39:05 +0100 Subject: [PATCH] Made low-level search functions fully honor local objects --- ldap3_mapper_new/base.py | 61 +++++++++++++++++++++++++++++---------- ldap3_mapper_new/model.py | 29 ++++++++++++------- 2 files changed, 64 insertions(+), 26 deletions(-) diff --git a/ldap3_mapper_new/base.py b/ldap3_mapper_new/base.py index 0549f99d..ff474a55 100644 --- a/ldap3_mapper_new/base.py +++ b/ldap3_mapper_new/base.py @@ -1,6 +1,13 @@ from copy import deepcopy from ldap3 import MODIFY_REPLACE, MODIFY_DELETE, MODIFY_ADD, ALL_ATTRIBUTES, ALL_OPERATIONAL_ATTRIBUTES +from ldap3.utils.conv import escape_filter_chars + +def encode_filter(params): + return '(&%s)'%(''.join(['(%s=%s)'%(attr, escape_filter_chars(value)) for attr, value in params])) + +def match_dn(dn, base): + return dn.endswith(base) # Probably good enougth for all valid dns class LDAPCommitError(Exception): pass @@ -164,13 +171,14 @@ class Session: self.state = self.committed_state.copy() self.changes.clear() - def get(self, dn, search_filter): + def get(self, dn, filter_params): if dn in self.state.objects: - return self.state.objects[dn] + obj = self.state.objects[dn] + return obj if obj.matches(filter_params) else None if dn in self.state.deleted_objects: return None conn = self.get_connection() - conn.search(dn, search_filter, attributes=[ALL_ATTRIBUTES, ALL_OPERATIONAL_ATTRIBUTES]) + conn.search(dn, encode_filter(filter_params), attributes=[ALL_ATTRIBUTES, ALL_OPERATIONAL_ATTRIBUTES]) if not conn.response: return None assert len(conn.response) == 1 @@ -182,24 +190,31 @@ class Session: self.state.ref(obj, attr, values) return obj - def search(self, search_base, search_filter): + def filter_local(self, search_base, filter_params): + if not filter_params: + matches = self.state.objects.values() + else: + submatches = [self.state.references.get((attr, value), set()) for attr, value in filter_params] + matches = submatches.pop(0) + while submatches: + matches = matches.intersection(submatches.pop(0)) + return [obj for obj in matches if match_dn(obj.state.dn, search_base)] + + def filter(self, search_base, filter_params): conn = self.get_connection() - conn.search(search_base, search_filter, attributes=[ALL_ATTRIBUTES, ALL_OPERATIONAL_ATTRIBUTES]) + conn.search(search_base, encode_filter(filter_params), attributes=[ALL_ATTRIBUTES, ALL_OPERATIONAL_ATTRIBUTES]) res = [] for response in conn.response: dn = response['dn'] - if dn in self.state.objects: - res.append(self.state.objects[dn]) - elif dn in self.state.deleted_objects: + if dn in self.state.objects or dn in self.state.deleted_objects: continue - else: - obj = Object(self, response) - self.state.objects[dn] = obj - self.committed_state.objects[dn] = obj - for attr, values in obj.state.attributes.items(): - self.state.ref(obj, attr, values) - res.append(obj) - return res + obj = Object(self, response) + self.state.objects[dn] = obj + self.committed_state.objects[dn] = obj + for attr, values in obj.state.attributes.items(): + self.state.ref(obj, attr, values) + res.append(obj) + return res + self.filter_local(search_base, filter_params) class Object: def __init__(self, session=None, response=None): @@ -211,6 +226,14 @@ class Object: self.committed_state = ObjectState(session, attrs, response['dn']) self.state = self.committed_state.copy() + @property + def dn(self): + return self.state.dn + + @property + def session(self): + return self.state.session + def getattr(self, name): return self.state.attributes.get(name, []) @@ -234,3 +257,9 @@ class Object: if self.state.session: oper.apply_session(self.state.session.state) self.state.session.changes.append(oper) + + def match(self, filter_params): + for attr, value in filter_params: + if value not in self.getattr(attr): + return False + return True diff --git a/ldap3_mapper_new/model.py b/ldap3_mapper_new/model.py index 29c33e96..cf681be8 100644 --- a/ldap3_mapper_new/model.py +++ b/ldap3_mapper_new/model.py @@ -30,23 +30,32 @@ def make_modelobj(obj, model): return None return obj.model +def make_modelobjs(objs, model): + modelobjs = [] + for obj in objs: + modelobj = make_modelobj(obj, model) + if modelobj is not None: + modelobjs.append(modelobj) + return modelobjs + class ModelQuery: def __init__(self, model): self.model = model def get(self, dn): session = self.model.ldap_mapper.session.ldap_session - return make_modelobj(session.get(dn, self.model.ldap_search_filter), self.model) + return make_modelobj(session.get(dn, self.model.ldap_filter_params), self.model) def all(self): session = self.model.ldap_mapper.session.ldap_session - objs = session.search(self.model.ldap_search_base, self.model.ldap_search_filter) - # TODO: check cached objects for non-committed objects - objs = [make_modelobj(obj, self.model) for obj in objs] - return [obj for obj in objs if obj is not None] + objs = session.filter(self.model.ldap_search_base, self.model.ldap_filter_params) + return make_modelobjs(objs, self.model) - def filter_by(self, dn): - pass # TODO + def filter_by(self, **kwargs): + filter_params = self.model.ldap_filter_params + list(kwargs.items()) + session = self.model.ldap_mapper.session.ldap_session + objs = session.filter(self.model.ldap_search_base, filter_params) + return make_modelobjs(objs, self.model) class ModelQueryWrapper: def __get__(self, obj, objtype=None): @@ -58,7 +67,7 @@ class Model: # Overwritten by models ldap_search_base = None - ldap_search_filter = None + ldap_filter_params = None ldap_dn_base = None ldap_dn_attribute = None @@ -73,8 +82,8 @@ class Model: @property def dn(self): - if self.ldap_object.state.dn is not None: - return self.ldap_object.state.dn + if self.ldap_object.dn is not None: + return self.ldap_object.dn if self.ldap_dn_base is None or self.ldap_dn_attribute is None: return None values = self.ldap_object.getattr(self.ldap_dn_attribute) -- GitLab