Skip to content
Snippets Groups Projects
Commit db135ec7 authored by Julian Rother's avatar Julian Rother
Browse files

Extended DBRelationship/DBBackreference to support 1-n-relationships

parent 6a0b85f2
No related branches found
No related tags found
No related merge requests found
......@@ -43,7 +43,7 @@ class DBRelationshipSet(MutableSet):
rel.remove(mapobj)
class DBRelationship:
def __init__(self, relattr, ldapcls, mapcls, backref=None, backattr=None):
def __init__(self, relattr, ldapcls, mapcls=None, backref=None, backattr=None):
self.relattr = relattr
self.ldapcls = ldapcls
self.mapcls = mapcls
......@@ -52,19 +52,25 @@ class DBRelationship:
def __set_name__(self, cls, name):
if self.backref:
assert self.backattr
setattr(self.ldapcls, self.backref, DBBackreference(cls, self.relattr, self.mapcls, self.backattr))
def __get__(self, obj, objtype=None):
if obj is None:
return self
return DBRelationshipSet(obj, self.relattr, self.ldapcls, self.mapcls)
if self.mapcls is not None:
return DBRelationshipSet(obj, self.relattr, self.ldapcls, self.mapcls)
return self.ldapcls.query.get(getattr(obj, self.relattr))
def __set__(self, obj, values):
tmp = self.__get__(obj)
tmp.clear()
for value in values:
tmp.add(value)
if self.mapcls is not None:
tmp = self.__get__(obj)
tmp.clear()
for value in values:
tmp.add(value)
else:
if not isinstance(values, self.ldapcls):
raise TypeError()
setattr(obj, self.relattr, values.ldap_object.dn)
class DBBackreferenceSet(MutableSet):
def __init__(self, ldapobj, dbcls, relattr, mapcls, backattr):
......@@ -79,6 +85,8 @@ class DBBackreferenceSet(MutableSet):
return self.__ldapobj.ldap_object.dn
def __get(self):
if self.__mapcls is None:
return self.__dbcls.query.filter_by(**{self.__relattr: self.__dn})
return {getattr(mapobj, self.__backattr) for mapobj in self.__mapcls.query.filter_by(dn=self.__dn)}
def __repr__(self):
......@@ -98,20 +106,26 @@ class DBBackreferenceSet(MutableSet):
assert self.__ldapobj.ldap_object.session is not None
if not isinstance(value, self.__dbcls):
raise TypeError()
rel = getattr(value, self.__relattr)
if self.__dn not in {mapobj.dn for mapobj in rel}:
rel.append(self.__mapcls(dn=self.__dn))
if self.__mapcls is None:
setattr(value, self.__relattr, self.__dn)
else:
rel = getattr(value, self.__relattr)
if self.__dn not in {mapobj.dn for mapobj in rel}:
rel.append(self.__mapcls(dn=self.__dn))
def discard(self, value):
if not isinstance(value, self.__dbcls):
raise TypeError()
rel = getattr(value, self.__relattr)
for mapobj in list(rel):
if mapobj.dn == self.__dn:
rel.remove(mapobj)
if self.__mapcls is None:
setattr(value, self.__relattr, None)
else:
rel = getattr(value, self.__relattr)
for mapobj in list(rel):
if mapobj.dn == self.__dn:
rel.remove(mapobj)
class DBBackreference:
def __init__(self, dbcls, relattr, mapcls, backattr):
def __init__(self, dbcls, relattr, mapcls=None, backattr=None):
self.dbcls = dbcls
self.relattr = relattr
self.mapcls = mapcls
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment