Skip to content
Snippets Groups Projects
Commit db135ec7 authored by Julian Rother's avatar Julian Rother
Browse files

Extended DBRelationship/DBBackreference to support 1-n-relationships

parent 6a0b85f2
No related branches found
No related tags found
No related merge requests found
...@@ -43,7 +43,7 @@ class DBRelationshipSet(MutableSet): ...@@ -43,7 +43,7 @@ class DBRelationshipSet(MutableSet):
rel.remove(mapobj) rel.remove(mapobj)
class DBRelationship: class DBRelationship:
def __init__(self, relattr, ldapcls, mapcls, backref=None, backattr=None): def __init__(self, relattr, ldapcls, mapcls=None, backref=None, backattr=None):
self.relattr = relattr self.relattr = relattr
self.ldapcls = ldapcls self.ldapcls = ldapcls
self.mapcls = mapcls self.mapcls = mapcls
...@@ -52,19 +52,25 @@ class DBRelationship: ...@@ -52,19 +52,25 @@ class DBRelationship:
def __set_name__(self, cls, name): def __set_name__(self, cls, name):
if self.backref: if self.backref:
assert self.backattr
setattr(self.ldapcls, self.backref, DBBackreference(cls, self.relattr, self.mapcls, self.backattr)) setattr(self.ldapcls, self.backref, DBBackreference(cls, self.relattr, self.mapcls, self.backattr))
def __get__(self, obj, objtype=None): def __get__(self, obj, objtype=None):
if obj is None: if obj is None:
return self return self
if self.mapcls is not None:
return DBRelationshipSet(obj, self.relattr, self.ldapcls, self.mapcls) return DBRelationshipSet(obj, self.relattr, self.ldapcls, self.mapcls)
return self.ldapcls.query.get(getattr(obj, self.relattr))
def __set__(self, obj, values): def __set__(self, obj, values):
if self.mapcls is not None:
tmp = self.__get__(obj) tmp = self.__get__(obj)
tmp.clear() tmp.clear()
for value in values: for value in values:
tmp.add(value) tmp.add(value)
else:
if not isinstance(values, self.ldapcls):
raise TypeError()
setattr(obj, self.relattr, values.ldap_object.dn)
class DBBackreferenceSet(MutableSet): class DBBackreferenceSet(MutableSet):
def __init__(self, ldapobj, dbcls, relattr, mapcls, backattr): def __init__(self, ldapobj, dbcls, relattr, mapcls, backattr):
...@@ -79,6 +85,8 @@ class DBBackreferenceSet(MutableSet): ...@@ -79,6 +85,8 @@ class DBBackreferenceSet(MutableSet):
return self.__ldapobj.ldap_object.dn return self.__ldapobj.ldap_object.dn
def __get(self): def __get(self):
if self.__mapcls is None:
return self.__dbcls.query.filter_by(**{self.__relattr: self.__dn})
return {getattr(mapobj, self.__backattr) for mapobj in self.__mapcls.query.filter_by(dn=self.__dn)} return {getattr(mapobj, self.__backattr) for mapobj in self.__mapcls.query.filter_by(dn=self.__dn)}
def __repr__(self): def __repr__(self):
...@@ -98,6 +106,9 @@ class DBBackreferenceSet(MutableSet): ...@@ -98,6 +106,9 @@ class DBBackreferenceSet(MutableSet):
assert self.__ldapobj.ldap_object.session is not None assert self.__ldapobj.ldap_object.session is not None
if not isinstance(value, self.__dbcls): if not isinstance(value, self.__dbcls):
raise TypeError() raise TypeError()
if self.__mapcls is None:
setattr(value, self.__relattr, self.__dn)
else:
rel = getattr(value, self.__relattr) rel = getattr(value, self.__relattr)
if self.__dn not in {mapobj.dn for mapobj in rel}: if self.__dn not in {mapobj.dn for mapobj in rel}:
rel.append(self.__mapcls(dn=self.__dn)) rel.append(self.__mapcls(dn=self.__dn))
...@@ -105,13 +116,16 @@ class DBBackreferenceSet(MutableSet): ...@@ -105,13 +116,16 @@ class DBBackreferenceSet(MutableSet):
def discard(self, value): def discard(self, value):
if not isinstance(value, self.__dbcls): if not isinstance(value, self.__dbcls):
raise TypeError() raise TypeError()
if self.__mapcls is None:
setattr(value, self.__relattr, None)
else:
rel = getattr(value, self.__relattr) rel = getattr(value, self.__relattr)
for mapobj in list(rel): for mapobj in list(rel):
if mapobj.dn == self.__dn: if mapobj.dn == self.__dn:
rel.remove(mapobj) rel.remove(mapobj)
class DBBackreference: class DBBackreference:
def __init__(self, dbcls, relattr, mapcls, backattr): def __init__(self, dbcls, relattr, mapcls=None, backattr=None):
self.dbcls = dbcls self.dbcls = dbcls
self.relattr = relattr self.relattr = relattr
self.mapcls = mapcls self.mapcls = mapcls
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment