from collections.abc import MutableSet

from .model import make_modelobj, make_modelobjs


class UnboundObjectError(Exception):
    """Raised when a relationship is modified on an object that has no session."""


class RelationshipSet(MutableSet):
    """Live set view of the objects whose DNs are stored in an attribute.

    Nothing is cached: every read goes back to attribute ``name`` of
    ``ldap_object`` and every write updates that attribute.

    :param ldap_object: owner of the DN-valued attribute; must provide
        ``session``, ``getattr``, ``attradd`` and ``attrdel``.
    :param name: name of the attribute holding the member DNs.
    :param model: model class of the owner (stored, not used in this module).
    :param destmodel: model class of the referenced objects.
    """

    def __init__(self, ldap_object, name, model, destmodel):
        self.__ldap_object = ldap_object
        self.__name = name
        self.__model = model
        self.__destmodel = destmodel

    def __modify_check(self, value):
        """Validate a write before it touches the attribute.

        :raises UnboundObjectError: if the owner has no session.
        :raises TypeError: if ``value`` is not a ``destmodel`` instance.
        """
        if self.__ldap_object.session is None:
            raise UnboundObjectError()
        if not isinstance(value, self.__destmodel):
            raise TypeError()

    def __repr__(self):
        return repr(set(self))

    def __contains__(self, value):
        """Return True if ``value`` is a ``destmodel`` whose DN is in the attribute."""
        if value is None or not isinstance(value, self.__destmodel):
            return False
        return value.ldap_object.dn in self.__ldap_object.getattr(self.__name)

    def __iter__(self):
        """Iterate over the referenced objects that can be loaded.

        The DNs are snapshotted when ``__iter__`` is called; objects are then
        fetched lazily, and DNs for which ``make_modelobj`` returns None are
        skipped.
        """
        def load(dn):
            # The stored DNs point at destmodel objects, so they must be looked
            # up with destmodel's filter; the owner's filter would not match them.
            found = self.__ldap_object.session.get(dn, self.__destmodel.ldap_filter_params)
            return make_modelobj(found, self.__destmodel)

        dns = set(self.__ldap_object.getattr(self.__name))
        return (obj for obj in map(load, dns) if obj is not None)

    def __len__(self):
        return len(set(self))

    def add(self, value):
        """Add ``value`` by appending its DN to the attribute.

        If ``value`` is not bound to a session yet, its ldap object is added to
        the owner's session first.

        :raises UnboundObjectError: if the owner has no session.
        :raises TypeError: if ``value`` is not a ``destmodel`` instance.
        """
        self.__modify_check(value)
        if value.ldap_object.session is None:
            self.__ldap_object.session.add(value.ldap_object)
        assert value.ldap_object.session == self.__ldap_object.session
        # Adding an existing member must be a no-op (set semantics); without this
        # guard the DN would be appended a second time.
        if value.dn not in self.__ldap_object.getattr(self.__name):
            self.__ldap_object.attradd(self.__name, value.dn)

    def discard(self, value):
        """Remove ``value``'s DN from the attribute.

        :raises UnboundObjectError: if the owner has no session.
        :raises TypeError: if ``value`` is not a ``destmodel`` instance.
        """
        self.__modify_check(value)
        self.__ldap_object.attrdel(self.__name, value.dn)


class Relationship:
    """Descriptor exposing a DN-valued attribute as a :class:`RelationshipSet`.

    :param name: attribute that stores the DNs of the related objects.
    :param destmodel: model class of the related objects.
    :param backref: if given, a :class:`Backreference` with this name is
        installed on ``destmodel`` when the owning class is created.
    """

    def __init__(self, name, destmodel, backref=None):
        self.name = name
        self.destmodel = destmodel
        self.backref = backref

    def __set_name__(self, cls, name):
        if self.backref is not None:
            setattr(self.destmodel, self.backref, Backreference(self.name, cls))

    def __get__(self, obj, objtype=None):
        if obj is None:
            return self
        # NOTE(review): the model instance itself is passed as RelationshipSet's
        # ldap_object; confirm model instances expose session/getattr/attradd/
        # attrdel (elsewhere this module goes through `value.ldap_object`).
        return RelationshipSet(obj, self.name, type(obj), self.destmodel)

    def __set__(self, obj, values):
        """Replace all related objects of ``obj`` with ``values``."""
        # Materialize first: ``values`` may be a lazy view of this very
        # relationship (e.g. ``obj.rel = obj.rel``), which clear() would empty
        # before it is read.
        values = list(values)
        current = self.__get__(obj)
        current.clear()
        for value in values:
            current.add(value)


class BackreferenceSet(MutableSet):
    """Live set view of the ``srcmodel`` objects that reference an object.

    Members are found by searching, under ``srcmodel.ldap_search_base``, for
    ``srcmodel`` objects whose attribute ``name`` contains the owner's DN.
    Writes modify that attribute on the referencing object.

    :param ldap_object: the referenced object; must provide ``session`` and ``dn``.
    :param name: attribute on ``srcmodel`` objects that stores the DNs.
    :param model: model class of the referenced object (stored, not used here).
    :param srcmodel: model class of the referencing objects.
    """

    def __init__(self, ldap_object, name, model, srcmodel):
        self.__ldap_object = ldap_object
        self.__name = name
        self.__model = model
        self.__srcmodel = srcmodel

    def __modify_check(self, value):
        """Validate a write.

        :raises UnboundObjectError: if the owner has no session.
        :raises TypeError: if ``value`` is not a ``srcmodel`` instance.
        """
        if self.__ldap_object.session is None:
            raise UnboundObjectError()
        if not isinstance(value, self.__srcmodel):
            raise TypeError()

    def __get(self):
        """Return the current referencing objects as a set (empty when unbound)."""
        if self.__ldap_object.session is None:
            return set()
        # Use the same model attribute name as RelationshipSet and copy it into
        # a list so a tuple-valued setting can still be extended.
        filter_params = list(self.__srcmodel.ldap_filter_params) + [(self.__name, self.__ldap_object.dn)]
        objs = self.__ldap_object.session.filter(self.__srcmodel.ldap_search_base, filter_params)
        return set(make_modelobjs(objs, self.__srcmodel))

    def __repr__(self):
        return repr(self.__get())

    def __contains__(self, value):
        return value in self.__get()

    def __iter__(self):
        return iter(self.__get())

    def __len__(self):
        return len(self.__get())

    def add(self, value):
        """Make ``value`` reference the owner by adding the owner's DN to it.

        :raises UnboundObjectError: if the owner has no session.
        :raises TypeError: if ``value`` is not a ``srcmodel`` instance.
        """
        self.__modify_check(value)
        if value.ldap_object.session is None:
            self.__ldap_object.session.add(value.ldap_object)
        assert value.ldap_object.session == self.__ldap_object.session
        if self.__ldap_object.dn not in value.ldap_object.getattr(self.__name):
            value.ldap_object.attradd(self.__name, self.__ldap_object.dn)

    def discard(self, value):
        """Remove the owner's DN from ``value``'s attribute.

        :raises UnboundObjectError: if the owner has no session.
        :raises TypeError: if ``value`` is not a ``srcmodel`` instance.
        """
        self.__modify_check(value)
        value.ldap_object.attrdel(self.__name, self.__ldap_object.dn)


class Backreference:
    """Descriptor exposing the reverse side of a :class:`Relationship`.

    :param name: attribute on ``srcmodel`` objects that stores the DNs.
    :param srcmodel: model class that declares the forward Relationship.
    """

    def __init__(self, name, srcmodel):
        self.name = name
        self.srcmodel = srcmodel

    def __get__(self, obj, objtype=None):
        if obj is None:
            return self
        # NOTE(review): as in Relationship.__get__, the model instance is used
        # directly as the set's ldap_object; confirm it exposes session and dn.
        return BackreferenceSet(obj, self.name, type(obj), self.srcmodel)

    def __set__(self, obj, values):
        """Replace all objects referencing ``obj`` with ``values``."""
        # Materialize first so assigning a view of this same backreference is not
        # emptied by clear() before being read.
        values = list(values)
        current = self.__get__(obj)
        current.clear()
        for value in values:
            current.add(value)