Skip to content
Snippets Groups Projects
Commit 50a326d6 authored by Julian's avatar Julian
Browse files

Implemented relationships

parent e319a8bd
No related branches found
No related tags found
1 merge request!18LDAP Object Mapper
......@@ -3,12 +3,18 @@ from copy import deepcopy
from ldap3 import MODIFY_REPLACE, MODIFY_DELETE, MODIFY_ADD, ALL_ATTRIBUTES, ALL_OPERATIONAL_ATTRIBUTES
from ldap3.utils.conv import escape_filter_chars
def encode_filter(params):
return '(&%s)'%(''.join(['(%s=%s)'%(attr, escape_filter_chars(value)) for attr, value in params]))
def encode_filter(filter_params):
return '(&%s)'%(''.join(['(%s=%s)'%(attr, escape_filter_chars(value)) for attr, value in filter_params]))
def match_dn(dn, base):
return dn.endswith(base) # Probably good enougth for all valid dns
def make_cache_key(search_base, filter_params):
res = [search_base]
for attr, value in sorted(filter_params):
res.append((attr, value))
return res
class LDAPCommitError(Exception):
pass
......@@ -125,6 +131,7 @@ class Session:
self.committed_state = SessionState()
self.state = SessionState()
self.changes = []
self.cached_searches = set()
def add(self, obj, dn, object_classes):
if self.state.objects.get(dn) == obj:
......@@ -190,7 +197,7 @@ class Session:
self.state.ref(obj, attr, values)
return obj
def filter_local(self, search_base, filter_params):
def filter(self, search_base, filter_params):
if not filter_params:
matches = self.state.objects.values()
else:
......@@ -198,12 +205,12 @@ class Session:
matches = submatches.pop(0)
while submatches:
matches = matches.intersection(submatches.pop(0))
return [obj for obj in matches if match_dn(obj.state.dn, search_base)]
def filter(self, search_base, filter_params):
res = [obj for obj in matches if match_dn(obj.state.dn, search_base)]
cache_key = make_cache_key(search_base, filter_params)
if cache_key in self.cached_searches:
return res
conn = self.get_connection()
conn.search(search_base, encode_filter(filter_params), attributes=[ALL_ATTRIBUTES, ALL_OPERATIONAL_ATTRIBUTES])
res = []
for response in conn.response:
dn = response['dn']
if dn in self.state.objects or dn in self.state.deleted_objects:
......@@ -214,7 +221,8 @@ class Session:
for attr, values in obj.state.attributes.items():
self.state.ref(obj, attr, values)
res.append(obj)
return res + self.filter_local(search_base, filter_params)
self.cached_searches.add(cache_key)
return res
class Object:
def __init__(self, session=None, response=None):
......
......@@ -21,7 +21,7 @@ class Session:
self.ldap_session = base.Session(get_connection)
def add(self, obj):
self.ldap_session.add(obj.ldap_object, obj.dn, obj.object_classes)
self.ldap_session.add(obj.ldap_object, obj.dn, obj.ldap_object_classes)
def delete(self, obj):
self.ldap_session.delete(obj.ldap_object)
......@@ -80,6 +80,7 @@ class Model:
# Overwritten by models
ldap_search_base = None
ldap_filter_params = None
ldap_object_classes = None
ldap_dn_base = None
ldap_dn_attribute = None
......
from collections.abc import MutableSet
from .model import make_modelobj, make_modelobjs
class UnboundObjectError(Exception):
pass
class RelationshipSet(MutableSet):
def __init__(self, ldap_object, name, model, destmodel):
self.__ldap_object = ldap_object
self.__name = name
self.__model = model
self.__destmodel = destmodel
def __modify_check(self, value):
if self.__ldap_object.session is None:
raise UnboundObjectError()
if not isinstance(value, self.__destmodel):
raise TypeError()
def __repr__(self):
return repr(set(self))
def __contains__(self, value):
if value is None or not isinstance(value, self.__destmodel):
return False
return value.ldap_object.dn in self.__ldap_object.getattr(self.__name)
def __iter__(self):
def get(dn):
return make_modelobj(self.__ldap_object.session.get(dn, self.__model.ldap_filter_params), self.__destmodel)
dns = set(self.__ldap_object.getattr(self.__name))
return iter(filter(lambda obj: obj is not None, map(get, dns)))
def __len__(self):
return len(set(self))
def add(self, value):
self.__modify_check(value)
if value.ldap_object.session is None:
self.__ldap_object.session.add(value.ldap_object)
assert value.ldap_object.session == self.__ldap_object.session
self.__ldap_object.attradd(self.__name, value.dn)
def discard(self, value):
self.__modify_check(value)
self.__ldap_object.attrdel(self.__name, value.dn)
class Relationship:
def __init__(self, name, destmodel, backref=None):
self.name = name
self.destmodel = destmodel
self.backref = backref
def __set_name__(self, cls, name):
if self.backref is not None:
setattr(self.destmodel, self.backref, Backreference(self.name, cls))
def __get__(self, obj, objtype=None):
if obj is None:
return self
return RelationshipSet(obj, self.name, type(obj), self.destmodel)
def __set__(self, obj, values):
tmp = self.__get__(obj)
tmp.clear()
for value in values:
tmp.add(value)
class BackreferenceSet(MutableSet):
def __init__(self, ldap_object, name, model, srcmodel):
self.__ldap_object = ldap_object
self.__name = name
self.__model = model
self.__srcmodel = srcmodel
def __modify_check(self, value):
if self.__ldap_object.session is None:
raise UnboundObjectError()
if not isinstance(value, self.__srcmodel):
raise TypeError()
def __get(self):
if self.__ldap_object.session is None:
return set()
filter_params = self.__srcmodel.filter_params + [(self.__name, self.__ldap_object.dn)]
objs = self.__ldap_object.session.filter(self.__srcmodel.ldap_search_base, filter_params)
return set(make_modelobjs(objs, self.__srcmodel))
def __repr__(self):
return repr(self.__get())
def __contains__(self, value):
return value in self.__get()
def __iter__(self):
return iter(self.__get())
def __len__(self):
return len(self.__get())
def add(self, value):
self.__modify_check(value)
if value.ldap_object.session is None:
self.__ldap_object.session.add(value.ldap_object)
assert value.ldap_object.session == self.__ldap_object.session
if self.__ldap_object.dn not in value.ldap_object.getattr(self.__name):
value.ldap_object.attradd(self.__name, self.__ldap_object.dn)
def discard(self, value):
self.__modify_check(value)
value.ldap_object.attrdel(self.__name, self.__ldap_object.dn)
class Backreference:
def __init__(self, name, srcmodel):
self.name = name
self.srcmodel = srcmodel
def __get__(self, obj, objtype=None):
if obj is None:
return self
return BackreferenceSet(obj, self.name, type(obj), self.srcmodel)
def __set__(self, obj, values):
tmp = self.__get__(obj)
tmp.clear()
for value in values:
tmp.add(value)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment