diff --git a/ldap3_mapper_new/base.py b/ldap3_mapper_new/base.py index ff474a556805950f625d30aef9daa6e1feb87eba..11399220cdf53b5a9b453259398d24ed8750556a 100644 --- a/ldap3_mapper_new/base.py +++ b/ldap3_mapper_new/base.py @@ -3,12 +3,18 @@ from copy import deepcopy from ldap3 import MODIFY_REPLACE, MODIFY_DELETE, MODIFY_ADD, ALL_ATTRIBUTES, ALL_OPERATIONAL_ATTRIBUTES from ldap3.utils.conv import escape_filter_chars -def encode_filter(params): - return '(&%s)'%(''.join(['(%s=%s)'%(attr, escape_filter_chars(value)) for attr, value in params])) +def encode_filter(filter_params): + return '(&%s)'%(''.join(['(%s=%s)'%(attr, escape_filter_chars(value)) for attr, value in filter_params])) def match_dn(dn, base): return dn.endswith(base) # Probably good enougth for all valid dns +def make_cache_key(search_base, filter_params): + res = [search_base] + for attr, value in sorted(filter_params): + res.append((attr, value)) + return res + class LDAPCommitError(Exception): pass @@ -125,6 +131,7 @@ class Session: self.committed_state = SessionState() self.state = SessionState() self.changes = [] + self.cached_searches = set() def add(self, obj, dn, object_classes): if self.state.objects.get(dn) == obj: @@ -190,7 +197,7 @@ class Session: self.state.ref(obj, attr, values) return obj - def filter_local(self, search_base, filter_params): + def filter(self, search_base, filter_params): if not filter_params: matches = self.state.objects.values() else: @@ -198,12 +205,12 @@ class Session: matches = submatches.pop(0) while submatches: matches = matches.intersection(submatches.pop(0)) - return [obj for obj in matches if match_dn(obj.state.dn, search_base)] - - def filter(self, search_base, filter_params): + res = [obj for obj in matches if match_dn(obj.state.dn, search_base)] + cache_key = make_cache_key(search_base, filter_params) + if cache_key in self.cached_searches: + return res conn = self.get_connection() conn.search(search_base, encode_filter(filter_params), attributes=[ALL_ATTRIBUTES, ALL_OPERATIONAL_ATTRIBUTES]) - res = [] for response in conn.response: dn = response['dn'] if dn in self.state.objects or dn in self.state.deleted_objects: @@ -214,7 +221,8 @@ class Session: for attr, values in obj.state.attributes.items(): self.state.ref(obj, attr, values) res.append(obj) - return res + self.filter_local(search_base, filter_params) + self.cached_searches.add(cache_key) + return res class Object: def __init__(self, session=None, response=None): diff --git a/ldap3_mapper_new/model.py b/ldap3_mapper_new/model.py index 8981899ff41ef726a5bb981f3cc217a3b9aa7771..7f65466d15ac4c3a91196456d0569b845d35af46 100644 --- a/ldap3_mapper_new/model.py +++ b/ldap3_mapper_new/model.py @@ -21,7 +21,7 @@ class Session: self.ldap_session = base.Session(get_connection) def add(self, obj): - self.ldap_session.add(obj.ldap_object, obj.dn, obj.object_classes) + self.ldap_session.add(obj.ldap_object, obj.dn, obj.ldap_object_classes) def delete(self, obj): self.ldap_session.delete(obj.ldap_object) @@ -80,6 +80,7 @@ class Model: # Overwritten by models ldap_search_base = None ldap_filter_params = None + ldap_object_classes = None ldap_dn_base = None ldap_dn_attribute = None diff --git a/ldap3_mapper_new/relationship.py b/ldap3_mapper_new/relationship.py new file mode 100644 index 0000000000000000000000000000000000000000..486d7276a3ea4307200097a8cbfc1a65343e4362 --- /dev/null +++ b/ldap3_mapper_new/relationship.py @@ -0,0 +1,128 @@ +from collections.abc import MutableSet + +from .model import make_modelobj, make_modelobjs + +class UnboundObjectError(Exception): + pass + +class RelationshipSet(MutableSet): + def __init__(self, ldap_object, name, model, destmodel): + self.__ldap_object = ldap_object + self.__name = name + self.__model = model + self.__destmodel = destmodel + + def __modify_check(self, value): + if self.__ldap_object.session is None: + raise UnboundObjectError() + if not isinstance(value, self.__destmodel): + raise TypeError() + + def __repr__(self): + return repr(set(self)) + + def __contains__(self, value): + if value is None or not isinstance(value, self.__destmodel): + return False + return value.ldap_object.dn in self.__ldap_object.getattr(self.__name) + + def __iter__(self): + def get(dn): + return make_modelobj(self.__ldap_object.session.get(dn, self.__model.ldap_filter_params), self.__destmodel) + dns = set(self.__ldap_object.getattr(self.__name)) + return iter(filter(lambda obj: obj is not None, map(get, dns))) + + def __len__(self): + return len(set(self)) + + def add(self, value): + self.__modify_check(value) + if value.ldap_object.session is None: + self.__ldap_object.session.add(value.ldap_object) + assert value.ldap_object.session == self.__ldap_object.session + self.__ldap_object.attradd(self.__name, value.dn) + + def discard(self, value): + self.__modify_check(value) + self.__ldap_object.attrdel(self.__name, value.dn) + +class Relationship: + def __init__(self, name, destmodel, backref=None): + self.name = name + self.destmodel = destmodel + self.backref = backref + + def __set_name__(self, cls, name): + if self.backref is not None: + setattr(self.destmodel, self.backref, Backreference(self.name, cls)) + + def __get__(self, obj, objtype=None): + if obj is None: + return self + return RelationshipSet(obj, self.name, type(obj), self.destmodel) + + def __set__(self, obj, values): + tmp = self.__get__(obj) + tmp.clear() + for value in values: + tmp.add(value) + +class BackreferenceSet(MutableSet): + def __init__(self, ldap_object, name, model, srcmodel): + self.__ldap_object = ldap_object + self.__name = name + self.__model = model + self.__srcmodel = srcmodel + + def __modify_check(self, value): + if self.__ldap_object.session is None: + raise UnboundObjectError() + if not isinstance(value, self.__srcmodel): + raise TypeError() + + def __get(self): + if self.__ldap_object.session is None: + return set() + filter_params = self.__srcmodel.filter_params + [(self.__name, self.__ldap_object.dn)] + objs = self.__ldap_object.session.filter(self.__srcmodel.ldap_search_base, filter_params) + return set(make_modelobjs(objs, self.__srcmodel)) + + def __repr__(self): + return repr(self.__get()) + + def __contains__(self, value): + return value in self.__get() + + def __iter__(self): + return iter(self.__get()) + + def __len__(self): + return len(self.__get()) + + def add(self, value): + self.__modify_check(value) + if value.ldap_object.session is None: + self.__ldap_object.session.add(value.ldap_object) + assert value.ldap_object.session == self.__ldap_object.session + if self.__ldap_object.dn not in value.ldap_object.getattr(self.__name): + value.ldap_object.attradd(self.__name, self.__ldap_object.dn) + + def discard(self, value): + self.__modify_check(value) + value.ldap_object.attrdel(self.__name, self.__ldap_object.dn) + +class Backreference: + def __init__(self, name, srcmodel): + self.name = name + self.srcmodel = srcmodel + + def __get__(self, obj, objtype=None): + if obj is None: + return self + return BackreferenceSet(obj, self.name, type(obj), self.srcmodel) + + def __set__(self, obj, values): + tmp = self.__get__(obj) + tmp.clear() + for value in values: + tmp.add(value)