From 5bebfb459e6676bfd233698bbe6725deb1c62442 Mon Sep 17 00:00:00 2001
From: Julian Rother <julianr@fsmpi.rwth-aachen.de>
Date: Mon, 22 Feb 2021 20:11:33 +0100
Subject: [PATCH] Added default parameter

---
 ldap3_mapper_new/attribute.py    | 11 ++++++++++-
 ldap3_mapper_new/dbutils.py      |  4 +++-
 ldap3_mapper_new/model.py        | 18 +++++++++++-------
 ldap3_mapper_new/relationship.py |  6 +++---
 4 files changed, 27 insertions(+), 12 deletions(-)

diff --git a/ldap3_mapper_new/attribute.py b/ldap3_mapper_new/attribute.py
index 6b8376e4..9e984025 100644
--- a/ldap3_mapper_new/attribute.py
+++ b/ldap3_mapper_new/attribute.py
@@ -38,10 +38,19 @@ class AttributeList(MutableSequence):
 		self.__set(tmp)
 
 class Attribute:
-	def __init__(self, name, aliases=None, multi=False):
+	def __init__(self, name, aliases=None, multi=False, default=None):
 		self.name = name
 		self.aliases = aliases or []
 		self.multi = multi
+		self.default = default
+
+	def add_hook(self, obj):
+		if obj.ldap_object.getattr(self.name) == []:
+			self.__set__(self.name, self.default() if callable(self.default) else self.default)
+
+	def __set_name__(self, cls, name):
+		if self.default is not None:
+			cls.ldap_add_hooks = cls.ldap_add_hooks + (self.add_hook,)
 
 	def __get__(self, obj, objtype=None):
 		if obj is None:
diff --git a/ldap3_mapper_new/dbutils.py b/ldap3_mapper_new/dbutils.py
index 9dee2179..21a62321 100644
--- a/ldap3_mapper_new/dbutils.py
+++ b/ldap3_mapper_new/dbutils.py
@@ -1,5 +1,7 @@
 from collections.abc import MutableSet
 
+from .model import add_to_session
+
 class DBRelationshipSet(MutableSet):
 	def __init__(self, dbobj, relattr, ldapcls):
 		self.__dbobj = dbobj
@@ -27,7 +29,7 @@ class DBRelationshipSet(MutableSet):
 		if not isinstance(value, self.__ldapcls):
 			raise TypeError()
 		if value.ldap_object.session is not None:
-			self.__ldapcls.ldap_mapper.session.add(value)
+			add_to_session(value, self.__ldapcls.ldap_mapper.session)
 		if value.ldap_object.dn not in self.__get_dns():
 			getattr(self.__dbobj, self.__relattr).append(self.__ldapcls(dn=value.ldap_object.dn))
 
diff --git a/ldap3_mapper_new/model.py b/ldap3_mapper_new/model.py
index 7f65466d..4a0c20d7 100644
--- a/ldap3_mapper_new/model.py
+++ b/ldap3_mapper_new/model.py
@@ -1,4 +1,3 @@
-
 try:
 	# Added in v2.5
 	from ldap3.utils.dn import escape_rdn
@@ -14,14 +13,19 @@ except ImportError:
 			rdn = ''.join((rdn[:-1], '\\ '))
 		return rdn
 
-from . import base
+from . import core
+
+def add_to_session(obj, session):
+	for func in obj.ldap_add_hooks:
+		func(obj)
+	session.add(obj.ldap_object, obj.dn, obj.ldap_object_classes)
 
 class Session:
 	def __init__(self, get_connection):
-		self.ldap_session = base.Session(get_connection)
+		self.ldap_session = core.Session(get_connection)
 
 	def add(self, obj):
-		self.ldap_session.add(obj.ldap_object, obj.dn, obj.ldap_object_classes)
+		add_to_session(obj, self.ldap_session)
 
 	def delete(self, obj):
 		self.ldap_session.delete(obj.ldap_object)
@@ -76,6 +80,8 @@ class ModelQueryWrapper:
 class Model:
 	# Overwritten by mapper
 	ldap_mapper = None
+	query = ModelQueryWrapper()
+	ldap_add_hooks = tuple()
 
 	# Overwritten by models
 	ldap_search_base = None
@@ -84,10 +90,8 @@ class Model:
 	ldap_dn_base = None
 	ldap_dn_attribute = None
 
-	query = ModelQueryWrapper()
-
 	def __init__(self, **kwargs):
-		self.ldap_object = base.Object()
+		self.ldap_object = core.Object()
 		for key, value, in kwargs.items():
 			if not hasattr(self, key):
 				raise Exception()
diff --git a/ldap3_mapper_new/relationship.py b/ldap3_mapper_new/relationship.py
index 486d7276..06986266 100644
--- a/ldap3_mapper_new/relationship.py
+++ b/ldap3_mapper_new/relationship.py
@@ -1,6 +1,6 @@
 from collections.abc import MutableSet
 
-from .model import make_modelobj, make_modelobjs
+from .model import make_modelobj, make_modelobjs, add_to_session
 
 class UnboundObjectError(Exception):
 	pass
@@ -38,7 +38,7 @@ class RelationshipSet(MutableSet):
 	def add(self, value):
 		self.__modify_check(value)
 		if value.ldap_object.session is None:
-			self.__ldap_object.session.add(value.ldap_object)
+			add_to_session(value, self.__ldap_object.session)
 		assert value.ldap_object.session == self.__ldap_object.session
 		self.__ldap_object.attradd(self.__name, value.dn)
 
@@ -102,7 +102,7 @@ class BackreferenceSet(MutableSet):
 	def add(self, value):
 		self.__modify_check(value)
 		if value.ldap_object.session is None:
-			self.__ldap_object.session.add(value.ldap_object)
+			add_to_session(value, self.__ldap_object.session)
 		assert value.ldap_object.session == self.__ldap_object.session
 		if self.__ldap_object.dn not in value.ldap_object.getattr(self.__name):
 			value.ldap_object.attradd(self.__name, self.__ldap_object.dn)
-- 
GitLab