From db135ec7a4940a384e82a4500cfb0e73ecd2f557 Mon Sep 17 00:00:00 2001 From: Julian Rother <julian@jrother.eu> Date: Tue, 23 Feb 2021 19:35:36 +0100 Subject: [PATCH] Extended DBRelationship/DBBackreference to support 1-n-relationships --- ldapalchemy/dbutils.py | 44 ++++++++++++++++++++++++++++-------------- 1 file changed, 29 insertions(+), 15 deletions(-) diff --git a/ldapalchemy/dbutils.py b/ldapalchemy/dbutils.py index 42fccb9..a9f399a 100644 --- a/ldapalchemy/dbutils.py +++ b/ldapalchemy/dbutils.py @@ -43,7 +43,7 @@ class DBRelationshipSet(MutableSet): rel.remove(mapobj) class DBRelationship: - def __init__(self, relattr, ldapcls, mapcls, backref=None, backattr=None): + def __init__(self, relattr, ldapcls, mapcls=None, backref=None, backattr=None): self.relattr = relattr self.ldapcls = ldapcls self.mapcls = mapcls @@ -52,19 +52,25 @@ class DBRelationship: def __set_name__(self, cls, name): if self.backref: - assert self.backattr setattr(self.ldapcls, self.backref, DBBackreference(cls, self.relattr, self.mapcls, self.backattr)) def __get__(self, obj, objtype=None): if obj is None: return self - return DBRelationshipSet(obj, self.relattr, self.ldapcls, self.mapcls) + if self.mapcls is not None: + return DBRelationshipSet(obj, self.relattr, self.ldapcls, self.mapcls) + return self.ldapcls.query.get(getattr(obj, self.relattr)) def __set__(self, obj, values): - tmp = self.__get__(obj) - tmp.clear() - for value in values: - tmp.add(value) + if self.mapcls is not None: + tmp = self.__get__(obj) + tmp.clear() + for value in values: + tmp.add(value) + else: + if not isinstance(values, self.ldapcls): + raise TypeError() + setattr(obj, self.relattr, values.ldap_object.dn) class DBBackreferenceSet(MutableSet): def __init__(self, ldapobj, dbcls, relattr, mapcls, backattr): @@ -79,6 +85,8 @@ class DBBackreferenceSet(MutableSet): return self.__ldapobj.ldap_object.dn def __get(self): + if self.__mapcls is None: + return self.__dbcls.query.filter_by(**{self.__relattr: self.__dn}) return {getattr(mapobj, self.__backattr) for mapobj in self.__mapcls.query.filter_by(dn=self.__dn)} def __repr__(self): @@ -98,20 +106,26 @@ class DBBackreferenceSet(MutableSet): assert self.__ldapobj.ldap_object.session is not None if not isinstance(value, self.__dbcls): raise TypeError() - rel = getattr(value, self.__relattr) - if self.__dn not in {mapobj.dn for mapobj in rel}: - rel.append(self.__mapcls(dn=self.__dn)) + if self.__mapcls is None: + setattr(value, self.__relattr, self.__dn) + else: + rel = getattr(value, self.__relattr) + if self.__dn not in {mapobj.dn for mapobj in rel}: + rel.append(self.__mapcls(dn=self.__dn)) def discard(self, value): if not isinstance(value, self.__dbcls): raise TypeError() - rel = getattr(value, self.__relattr) - for mapobj in list(rel): - if mapobj.dn == self.__dn: - rel.remove(mapobj) + if self.__mapcls is None: + setattr(value, self.__relattr, None) + else: + rel = getattr(value, self.__relattr) + for mapobj in list(rel): + if mapobj.dn == self.__dn: + rel.remove(mapobj) class DBBackreference: - def __init__(self, dbcls, relattr, mapcls, backattr): + def __init__(self, dbcls, relattr, mapcls=None, backattr=None): self.dbcls = dbcls self.relattr = relattr self.mapcls = mapcls -- GitLab