From db135ec7a4940a384e82a4500cfb0e73ecd2f557 Mon Sep 17 00:00:00 2001
From: Julian Rother <julian@jrother.eu>
Date: Tue, 23 Feb 2021 19:35:36 +0100
Subject: [PATCH] Extended DBRelationship/DBBackreference to support
 1-n-relationships

---
 ldapalchemy/dbutils.py | 44 ++++++++++++++++++++++++++++--------------
 1 file changed, 29 insertions(+), 15 deletions(-)

diff --git a/ldapalchemy/dbutils.py b/ldapalchemy/dbutils.py
index 42fccb9..a9f399a 100644
--- a/ldapalchemy/dbutils.py
+++ b/ldapalchemy/dbutils.py
@@ -43,7 +43,7 @@ class DBRelationshipSet(MutableSet):
 				rel.remove(mapobj)
 
 class DBRelationship:
-	def __init__(self, relattr, ldapcls, mapcls, backref=None, backattr=None):
+	def __init__(self, relattr, ldapcls, mapcls=None, backref=None, backattr=None):
 		self.relattr = relattr
 		self.ldapcls = ldapcls
 		self.mapcls = mapcls
@@ -52,19 +52,25 @@ class DBRelationship:
 
 	def __set_name__(self, cls, name):
 		if self.backref:
-			assert self.backattr
 			setattr(self.ldapcls, self.backref, DBBackreference(cls, self.relattr, self.mapcls, self.backattr))
 
 	def __get__(self, obj, objtype=None):
 		if obj is None:
 			return self
-		return DBRelationshipSet(obj, self.relattr, self.ldapcls, self.mapcls)
+		if self.mapcls is not None:
+			return DBRelationshipSet(obj, self.relattr, self.ldapcls, self.mapcls)
+		return self.ldapcls.query.get(getattr(obj, self.relattr))
 
 	def __set__(self, obj, values):
-		tmp = self.__get__(obj)
-		tmp.clear()
-		for value in values:
-			tmp.add(value)
+		if self.mapcls is not None:
+			tmp = self.__get__(obj)
+			tmp.clear()
+			for value in values:
+				tmp.add(value)
+		else:
+			if not isinstance(values, self.ldapcls):
+				raise TypeError()
+			setattr(obj, self.relattr, values.ldap_object.dn)
 
 class DBBackreferenceSet(MutableSet):
 	def __init__(self, ldapobj, dbcls, relattr, mapcls, backattr):
@@ -79,6 +85,8 @@ class DBBackreferenceSet(MutableSet):
 		return self.__ldapobj.ldap_object.dn
 
 	def __get(self):
+		if self.__mapcls is None:
+			return self.__dbcls.query.filter_by(**{self.__relattr: self.__dn})
 		return {getattr(mapobj, self.__backattr) for mapobj in self.__mapcls.query.filter_by(dn=self.__dn)}
 
 	def __repr__(self):
@@ -98,20 +106,26 @@ class DBBackreferenceSet(MutableSet):
 		assert self.__ldapobj.ldap_object.session is not None
 		if not isinstance(value, self.__dbcls):
 			raise TypeError()
-		rel = getattr(value, self.__relattr)
-		if self.__dn not in {mapobj.dn for mapobj in rel}:
-			rel.append(self.__mapcls(dn=self.__dn))
+		if self.__mapcls is None:
+			setattr(value, self.__relattr, self.__dn)
+		else:
+			rel = getattr(value, self.__relattr)
+			if self.__dn not in {mapobj.dn for mapobj in rel}:
+				rel.append(self.__mapcls(dn=self.__dn))
 
 	def discard(self, value):
 		if not isinstance(value, self.__dbcls):
 			raise TypeError()
-		rel = getattr(value, self.__relattr)
-		for mapobj in list(rel):
-			if mapobj.dn == self.__dn:
-				rel.remove(mapobj)
+		if self.__mapcls is None:
+			setattr(value, self.__relattr, None)
+		else:
+			rel = getattr(value, self.__relattr)
+			for mapobj in list(rel):
+				if mapobj.dn == self.__dn:
+					rel.remove(mapobj)
 
 class DBBackreference:
-	def __init__(self, dbcls, relattr, mapcls, backattr):
+	def __init__(self, dbcls, relattr, mapcls=None, backattr=None):
 		self.dbcls = dbcls
 		self.relattr = relattr
 		self.mapcls = mapcls
-- 
GitLab