From 5bebfb459e6676bfd233698bbe6725deb1c62442 Mon Sep 17 00:00:00 2001 From: Julian Rother <julianr@fsmpi.rwth-aachen.de> Date: Mon, 22 Feb 2021 20:11:33 +0100 Subject: [PATCH] Added default parameter --- ldap3_mapper_new/attribute.py | 11 ++++++++++- ldap3_mapper_new/dbutils.py | 4 +++- ldap3_mapper_new/model.py | 18 +++++++++++------- ldap3_mapper_new/relationship.py | 6 +++--- 4 files changed, 27 insertions(+), 12 deletions(-) diff --git a/ldap3_mapper_new/attribute.py b/ldap3_mapper_new/attribute.py index 6b8376e4..9e984025 100644 --- a/ldap3_mapper_new/attribute.py +++ b/ldap3_mapper_new/attribute.py @@ -38,10 +38,19 @@ class AttributeList(MutableSequence): self.__set(tmp) class Attribute: - def __init__(self, name, aliases=None, multi=False): + def __init__(self, name, aliases=None, multi=False, default=None): self.name = name self.aliases = aliases or [] self.multi = multi + self.default = default + + def add_hook(self, obj): + if obj.ldap_object.getattr(self.name) == []: + self.__set__(self.name, self.default() if callable(self.default) else self.default) + + def __set_name__(self, cls, name): + if self.default is not None: + cls.ldap_add_hooks = cls.ldap_add_hooks + (self.add_hook,) def __get__(self, obj, objtype=None): if obj is None: diff --git a/ldap3_mapper_new/dbutils.py b/ldap3_mapper_new/dbutils.py index 9dee2179..21a62321 100644 --- a/ldap3_mapper_new/dbutils.py +++ b/ldap3_mapper_new/dbutils.py @@ -1,5 +1,7 @@ from collections.abc import MutableSet +from .model import add_to_session + class DBRelationshipSet(MutableSet): def __init__(self, dbobj, relattr, ldapcls): self.__dbobj = dbobj @@ -27,7 +29,7 @@ class DBRelationshipSet(MutableSet): if not isinstance(value, self.__ldapcls): raise TypeError() if value.ldap_object.session is not None: - self.__ldapcls.ldap_mapper.session.add(value) + add_to_session(value, self.__ldapcls.ldap_mapper.session) if value.ldap_object.dn not in self.__get_dns(): getattr(self.__dbobj, self.__relattr).append(self.__ldapcls(dn=value.ldap_object.dn)) diff --git a/ldap3_mapper_new/model.py b/ldap3_mapper_new/model.py index 7f65466d..4a0c20d7 100644 --- a/ldap3_mapper_new/model.py +++ b/ldap3_mapper_new/model.py @@ -1,4 +1,3 @@ - try: # Added in v2.5 from ldap3.utils.dn import escape_rdn @@ -14,14 +13,19 @@ except ImportError: rdn = ''.join((rdn[:-1], '\\ ')) return rdn -from . import base +from . import core + +def add_to_session(obj, session): + for func in obj.ldap_add_hooks: + func(obj) + session.add(obj.ldap_object, obj.dn, obj.ldap_object_classes) class Session: def __init__(self, get_connection): - self.ldap_session = base.Session(get_connection) + self.ldap_session = core.Session(get_connection) def add(self, obj): - self.ldap_session.add(obj.ldap_object, obj.dn, obj.ldap_object_classes) + add_to_session(obj, self.ldap_session) def delete(self, obj): self.ldap_session.delete(obj.ldap_object) @@ -76,6 +80,8 @@ class ModelQueryWrapper: class Model: # Overwritten by mapper ldap_mapper = None + query = ModelQueryWrapper() + ldap_add_hooks = tuple() # Overwritten by models ldap_search_base = None @@ -84,10 +90,8 @@ class Model: ldap_dn_base = None ldap_dn_attribute = None - query = ModelQueryWrapper() - def __init__(self, **kwargs): - self.ldap_object = base.Object() + self.ldap_object = core.Object() for key, value, in kwargs.items(): if not hasattr(self, key): raise Exception() diff --git a/ldap3_mapper_new/relationship.py b/ldap3_mapper_new/relationship.py index 486d7276..06986266 100644 --- a/ldap3_mapper_new/relationship.py +++ b/ldap3_mapper_new/relationship.py @@ -1,6 +1,6 @@ from collections.abc import MutableSet -from .model import make_modelobj, make_modelobjs +from .model import make_modelobj, make_modelobjs, add_to_session class UnboundObjectError(Exception): pass @@ -38,7 +38,7 @@ class RelationshipSet(MutableSet): def add(self, value): self.__modify_check(value) if value.ldap_object.session is None: - self.__ldap_object.session.add(value.ldap_object) + add_to_session(value, self.__ldap_object.session) assert value.ldap_object.session == self.__ldap_object.session self.__ldap_object.attradd(self.__name, value.dn) @@ -102,7 +102,7 @@ class BackreferenceSet(MutableSet): def add(self, value): self.__modify_check(value) if value.ldap_object.session is None: - self.__ldap_object.session.add(value.ldap_object) + add_to_session(value, self.__ldap_object.session) assert value.ldap_object.session == self.__ldap_object.session if self.__ldap_object.dn not in value.ldap_object.getattr(self.__name): value.ldap_object.attradd(self.__name, self.__ldap_object.dn) -- GitLab