From 9a0699fb99421b9f08b33fda2334bf0c98191b44 Mon Sep 17 00:00:00 2001
From: Julian Rother <julianr@fsmpi.rwth-aachen.de>
Date: Tue, 23 Feb 2021 01:22:17 +0100
Subject: [PATCH] Replace old ldap mapper code

---
 ldap3_mapper/__init__.py                      |  35 --
 ldap3_mapper/base.py                          | 320 ------------------
 ldap3_mapper/db_relation.py                   |  66 ----
 ldap3_mapper/types.py                         |  35 --
 {ldap3_mapper_new => ldap_mapper}/__init__.py |   0
 .../attribute.py                              |   0
 {ldap3_mapper_new => ldap_mapper}/core.py     |   0
 {ldap3_mapper_new => ldap_mapper}/dbutils.py  |   0
 {ldap3_mapper_new => ldap_mapper}/model.py    |   0
 .../relationship.py                           |   0
 uffd/ldap.py                                  |   2 +-
 uffd/role/models.py                           |   2 +-
 12 files changed, 2 insertions(+), 458 deletions(-)
 delete mode 100644 ldap3_mapper/__init__.py
 delete mode 100644 ldap3_mapper/base.py
 delete mode 100644 ldap3_mapper/db_relation.py
 delete mode 100644 ldap3_mapper/types.py
 rename {ldap3_mapper_new => ldap_mapper}/__init__.py (100%)
 rename {ldap3_mapper_new => ldap_mapper}/attribute.py (100%)
 rename {ldap3_mapper_new => ldap_mapper}/core.py (100%)
 rename {ldap3_mapper_new => ldap_mapper}/dbutils.py (100%)
 rename {ldap3_mapper_new => ldap_mapper}/model.py (100%)
 rename {ldap3_mapper_new => ldap_mapper}/relationship.py (100%)

diff --git a/ldap3_mapper/__init__.py b/ldap3_mapper/__init__.py
deleted file mode 100644
index 8c7114e7..00000000
--- a/ldap3_mapper/__init__.py
+++ /dev/null
@@ -1,35 +0,0 @@
-import ldap3
-
-from .types import LDAPCommitError
-from . import base
-
-class LDAP3Mapper:
-	def __init__(self, server=None, bind_dn=None, bind_password=None):
-		if not hasattr(type(self), 'server'):
-			self.server = server
-		if not hasattr(type(self), 'bind_dn'):
-			self.bind_dn = bind_dn
-		if not hasattr(type(self), 'bind_password'):
-			self.bind_password = bind_password
-		if not hasattr(type(self), 'session'):
-			self.session = base.Session()
-
-		class Model(base.Model):
-			ldap_mapper = self
-
-		class Attribute(base.Attribute):
-			ldap_mapper = self
-
-		class Relation(base.Relation):
-			ldap_mapper = self
-
-		class Backref(base.Backref):
-			ldap_mapper = self
-
-		self.Model = Model # pylint: disable=invalid-name
-		self.Attribute = Attribute # pylint: disable=invalid-name
-		self.Relation = Relation # pylint: disable=invalid-name
-		self.Backref = Backref # pylint: disable=invalid-name
-
-	def connect(self):
-		return ldap3.Connection(self.server, self.bind_dn, self.bind_password, auto_bind=True)
diff --git a/ldap3_mapper/base.py b/ldap3_mapper/base.py
deleted file mode 100644
index 62bff918..00000000
--- a/ldap3_mapper/base.py
+++ /dev/null
@@ -1,320 +0,0 @@
-from copy import deepcopy
-
-from ldap3.utils.conv import escape_filter_chars
-from ldap3 import MODIFY_REPLACE, MODIFY_DELETE, MODIFY_ADD, ALL_ATTRIBUTES, ALL_OPERATIONAL_ATTRIBUTES
-
-from .types import LDAPSet, LDAPCommitError
-
-class Session:
-	def __init__(self):
-		self.__objects = {} # dn -> instance
-		self.__to_delete = []
-		self.__relations = {} # (srccls, srcattr, dn) -> {srcobj, ...}
-
-	def lookup(self, dn):
-		return self.__objects.get(dn)
-
-	def register(self, obj):
-		if obj.dn in self.__objects and self.__objects[obj.dn] != obj:
-			raise Exception()
-		self.__objects[obj.dn] = obj
-		return obj
-
-	def lookup_relations(self, srccls, srcattr, dn):
-		key = (srccls, srcattr, dn)
-		return self.__relations.get(key, set())
-
-	def update_relations(self, srcobj, srcattr, delete_dns=None, add_dns=None):
-		for dn in (delete_dns or []):
-			key = (type(srcobj), srcattr, dn)
-			self.__relations[key] = self.__relations.get(key, set())
-			self.__relations[key].discard(srcobj)
-		for dn in (add_dns or []):
-			key = (type(srcobj), srcattr, dn)
-			self.__relations[key] = self.__relations.get(key, set())
-			self.__relations[key].add(srcobj)
-
-	def add(self, obj):
-		self.register(obj)
-
-	def delete(self, obj):
-		if obj.dn in self.__objects:
-			del self.__objects[obj.dn]
-		self.__to_delete.append(obj)
-
-	def commit(self):
-		while self.__to_delete:
-			self.__to_delete.pop(0).ldap_delete()
-		for obj in list(self.__objects.values()):
-			if not obj.ldap_created:
-				obj.ldap_create()
-			elif obj.ldap_dirty:
-				obj.ldap_modify()
-
-	def rollback(self):
-		self.__to_delete.clear()
-		self.__objects = {dn: obj for dn, obj in self.__objects.items() if obj.ldap_created}
-		for obj in self.__objects.values():
-			if obj.ldap_dirty:
-				obj.ldap_reset()
-
-class Model:
-	ldap_mapper = None # Overwritten by LDAP3Mapper
-
-	ldap_dn_attribute = None
-	ldap_dn_base = None
-	ldap_base = None
-	ldap_object_classes = None
-	ldap_filter = None
-	# Caution: Never mutate ldap_pre_create_hooks and ldap_relations, always reassign!
-	ldap_pre_create_hooks = []
-	ldap_relations = []
-
-	def __init__(self, _ldap_response=None, **kwargs):
-		self.ldap_session = self.ldap_mapper.session
-		self.ldap_relation_data = set()
-		self.__ldap_dn = None if _ldap_response is None else _ldap_response['dn']
-		self.__ldap_attributes = {}
-		for key, values in (_ldap_response or {}).get('attributes', {}).items():
-			if isinstance(values, list):
-				self.__ldap_attributes[key] = values
-			else:
-				self.__ldap_attributes[key] = [values]
-		self.__attributes = deepcopy(self.__ldap_attributes)
-		self.__changes = {}
-		for key, value, in kwargs.items():
-			if not hasattr(self, key):
-				raise Exception()
-			setattr(self, key, value)
-		for name in self.ldap_relations:
-			self.__update_relations(name, add_dns=self.__attributes.get(name, []))
-
-	def __update_relations(self, name, delete_dns=None, add_dns=None):
-		if name in self.ldap_relations:
-			self.ldap_session.update_relations(self, name, delete_dns, add_dns)
-
-	def ldap_getattr(self, name):
-		return self.__attributes.get(name, [])
-
-	def ldap_setattr(self, name, values):
-		self.__update_relations(name, delete_dns=self.__attributes.get(name, []))
-		self.__changes[name] = [(MODIFY_REPLACE, values)]
-		self.__attributes[name] = values
-		self.__update_relations(name, add_dns=values)
-
-	def ldap_attradd(self, name, value):
-		self.__changes[name] = self.__changes.get(name, []) + [(MODIFY_ADD, [value])]
-		self.__attributes[name].append(value)
-		self.__update_relations(name, add_dns=[value])
-
-	def ldap_attrdel(self, name, value):
-		self.__changes[name] = self.__changes.get(name, []) + [(MODIFY_DELETE, [value])]
-		if value in self.__attributes.get(name, []):
-			self.__attributes[name].remove(value)
-		self.__update_relations(name, delete_dns=[value])
-
-	def __repr__(self):
-		name = '%s.%s'%(type(self).__module__, type(self).__name__)
-		if self.__ldap_dn is None:
-			return '<%s>'%name
-		return '<%s %s>'%(name, self.__ldap_dn)
-
-	def build_dn(self):
-		if self.ldap_dn_attribute is None:
-			return None
-		if self.ldap_dn_base is None:
-			return None
-		if self.__attributes.get(self.ldap_dn_attribute) is None:
-			return None
-		return '%s=%s,%s'%(self.ldap_dn_attribute, escape_filter_chars(self.__attributes[self.ldap_dn_attribute][0]), self.ldap_dn_base)
-
-	@property
-	def dn(self):
-		if self.__ldap_dn is not None:
-			return self.__ldap_dn
-		return self.build_dn()
-
-	@classmethod
-	def ldap_get(cls, dn):
-		obj = cls.ldap_mapper.session.lookup(dn)
-		if obj is None:
-			conn = cls.ldap_mapper.connect()
-			conn.search(dn, cls.ldap_filter, attributes=[ALL_ATTRIBUTES, ALL_OPERATIONAL_ATTRIBUTES])
-			if not conn.response:
-				return None
-			if len(conn.response) != 1:
-				raise Exception()
-			obj = cls.ldap_mapper.session.register(cls(_ldap_response=conn.response[0]))
-		return obj
-
-	@classmethod
-	def ldap_all(cls):
-		conn = cls.ldap_mapper.connect()
-		conn.search(cls.ldap_base, cls.ldap_filter, attributes=[ALL_ATTRIBUTES, ALL_OPERATIONAL_ATTRIBUTES])
-		res = []
-		for entry in conn.response:
-			obj = cls.ldap_mapper.session.lookup(entry['dn'])
-			if obj is None:
-				obj = cls.ldap_mapper.session.register(cls(_ldap_response=entry))
-			res.append(obj)
-		return res
-
-	@classmethod
-	def ldap_filter_by_raw(cls, **kwargs):
-		filters = [cls.ldap_filter]
-		for key, value in kwargs.items():
-			filters.append('(%s=%s)'%(key, escape_filter_chars(value)))
-		conn = cls.ldap_mapper.connect()
-		conn.search(cls.ldap_base, '(&%s)'%(''.join(filters)), attributes=[ALL_ATTRIBUTES, ALL_OPERATIONAL_ATTRIBUTES])
-		res = []
-		for entry in conn.response:
-			obj = cls.ldap_mapper.session.lookup(entry['dn'])
-			if obj is None:
-				obj = cls.ldap_mapper.session.register(cls(_ldap_response=entry))
-			res.append(obj)
-		return res
-
-	@classmethod
-	def ldap_filter_by(cls, **kwargs):
-		_kwargs = {}
-		for key, value in kwargs.items():
-			attr = getattr(cls, key)
-			_kwargs[attr.name] = attr.encode(value)
-		return cls.ldap_filter_by_raw(**_kwargs)
-
-	def ldap_reset(self):
-		for name in self.ldap_relations:
-			self.__update_relations(name, delete_dns=self.__attributes.get(name, []))
-		self.__changes = {}
-		self.__attributes = deepcopy(self.__ldap_attributes)
-		for name in self.ldap_relations:
-			self.__update_relations(name, add_dns=self.__attributes.get(name, []))
-
-	@property
-	def ldap_dirty(self):
-		return bool(self.__changes)
-
-	@property
-	def ldap_created(self):
-		return bool(self.__ldap_attributes)
-
-	def ldap_modify(self):
-		if not self.ldap_created:
-			raise Exception()
-		if not self.ldap_dirty:
-			return
-		conn = self.ldap_mapper.connect()
-		success = conn.modify(self.dn, self.__changes)
-		if not success:
-			raise Exception()
-		self.__changes = {}
-		self.__ldap_attributes = deepcopy(self.__attributes)
-
-	def ldap_create(self):
-		if self.ldap_created:
-			raise Exception()
-		conn = self.ldap_mapper.connect()
-		for func in self.ldap_pre_create_hooks:
-			func(self)
-		success = conn.add(self.dn, self.ldap_object_classes, self.__attributes)
-		if not success:
-			raise LDAPCommitError()
-		self.__changes = {}
-		self.__ldap_attributes = deepcopy(self.__attributes)
-
-	def ldap_delete(self):
-		conn = self.ldap_mapper.connect()
-		success = conn.delete(self.dn)
-		if not success:
-			raise Exception()
-		self.__ldap_attributes = {}
-
-class Attribute:
-	ldap_mapper = None # Overwritten by LDAP3Mapper
-
-	def __init__(self, name, multi=False, default=None, encode=None, decode=None, aliases=None):
-		self.name = name
-		self.multi = multi
-		self.encode = encode or (lambda x: x)
-		self.decode = decode or (lambda x: x)
-		self.default_values = default
-		self.aliases = aliases or []
-
-	def default(self, obj):
-		if obj.ldap_getattr(self.name) == []:
-			values = self.default_values
-			if callable(values):
-				values = values()
-			self.__set__(obj, values)
-
-	def additem(self, obj, value):
-		obj.ldap_attradd(self.name, value)
-		for name in self.aliases:
-			obj.ldap_attradd(name, value)
-
-	def delitem(self, obj, value):
-		obj.ldap_attradd(self.name, value)
-		for name in self.aliases:
-			obj.ldap_attradd(name, value)
-
-	def __set_name__(self, cls, name):
-		if self.default_values is not None:
-			cls.ldap_pre_create_hooks = cls.ldap_pre_create_hooks + [self.default]
-
-	def __get__(self, obj, objtype=None):
-		if obj is None:
-			return self
-		if self.multi:
-			return LDAPSet(getitems=lambda: obj.ldap_getattr(self.name),
-			               additem=lambda value: self.additem(obj, value),
-			               delitem=lambda value: self.delitem(obj, value),
-			               encode=self.encode, decode=self.decode)
-		return self.decode((obj.ldap_getattr(self.name) or [None])[0])
-
-	def __set__(self, obj, values):
-		if not self.multi:
-			values = [values]
-		obj.ldap_setattr(self.name, [self.encode(value) for value in values])
-		for name in self.aliases:
-			obj.ldap_setattr(name, [self.encode(value) for value in values])
-
-class Backref:
-	ldap_mapper = None # Overwritten by LDAP3Mapper
-
-	def __init__(self, srccls, srcattr):
-		self.srccls = srccls
-		self.srcattr = srcattr
-		srccls.ldap_relations = srccls.ldap_relations + [srcattr]
-
-	def init(self, obj):
-		if self.srcattr not in obj.ldap_relation_data and obj.ldap_created:
-			# The query instanciates all related objects that in turn add their relations to session
-			self.srccls.ldap_filter_by_raw(**{self.srcattr: obj.dn})
-		obj.ldap_relation_data.add(self.srcattr)
-
-	def __get__(self, obj, objtype=None):
-		if obj is None:
-			return self
-		self.init(obj)
-		return LDAPSet(getitems=lambda: obj.ldap_session.lookup_relations(self.srccls, self.srcattr, obj.dn),
-									 additem=lambda value: value.ldap_attradd(self.srcattr, obj.dn),
-									 delitem=lambda value: value.ldap_attrdel(self.srcattr, obj.dn))
-
-	def __set__(self, obj, values):
-		current = self.__get__(obj)
-		current.clear()
-		for value in values:
-			current.add(value)
-
-class Relation(Attribute):
-	ldap_mapper = None # Overwritten by LDAP3Mapper
-
-	def __init__(self, name, dest, backref=None):
-		super().__init__(name, multi=True, encode=lambda value: value.dn, decode=dest.ldap_get)
-		self.name = name
-		self.dest = dest
-		self.backref = backref
-
-	def __set_name__(self, cls, name):
-		if self.backref is not None:
-			setattr(self.dest, self.backref, Backref(cls, self.name))
diff --git a/ldap3_mapper/db_relation.py b/ldap3_mapper/db_relation.py
deleted file mode 100644
index 337e0869..00000000
--- a/ldap3_mapper/db_relation.py
+++ /dev/null
@@ -1,66 +0,0 @@
-from .types import LDAPSet
-
-class DB2LDAPBackref:
-	def __init__(self, baseattr, mapcls, backattr):
-		self.baseattr = baseattr
-		self.mapcls = mapcls
-		self.backattr = backattr
-
-	def getitems(self, ldapobj):
-		return {getattr(mapobj, self.backattr) for mapobj in self.mapcls.query.filter_by(dn=ldapobj.dn)}
-
-	def additem(self, ldapobj, dbobj):
-		if dbobj not in self.getitems(ldapobj):
-			getattr(dbobj, self.baseattr).append(self.mapcls(dn=ldapobj.dn))
-
-	def delitem(self, ldapobj, dbobj):
-		for mapobj in list(getattr(dbobj, self.baseattr)):
-			if mapobj.dn == ldapobj.dn:
-				getattr(dbobj, self.baseattr).remove(mapobj)
-
-	def __get__(self, ldapobj, objtype=None):
-		if ldapobj is None:
-			return self
-		return LDAPSet(getitems=lambda: self.getitems(ldapobj),
-		               additem=lambda dbobj: self.additem(ldapobj, dbobj),
-		               delitem=lambda dbobj: self.delitem(ldapobj, dbobj))
-
-	def __set__(self, ldapobj, dbobjs):
-		rel = self.__get__(ldapobj)
-		rel.clear()
-		for dbobj in dbobjs:
-			rel.add(dbobj)
-
-class DB2LDAPRelation:
-	def __init__(self, baseattr, mapcls, ldapcls, backattr=None, backref=None):
-		self.baseattr = baseattr
-		self.mapcls = mapcls
-		self.ldapcls = ldapcls
-		if backref is not None:
-			setattr(ldapcls, backref, DB2LDAPBackref(baseattr, mapcls, backattr))
-
-	def getitems(self, dbobj):
-		return {mapobj.dn for mapobj in getattr(dbobj, self.baseattr)}
-
-	def additem(self, dbobj, dn):
-		if dn not in self.getitems(dbobj):
-			getattr(dbobj, self.baseattr).append(self.mapcls(dn=dn))
-
-	def delitem(self, dbobj, dn):
-		for mapobj in list(getattr(dbobj, self.baseattr)):
-			if mapobj.dn == dn:
-				getattr(dbobj, self.baseattr).remove(mapobj)
-
-	def __get__(self, dbobj, objtype=None):
-		if dbobj is None:
-			return self
-		return LDAPSet(getitems=lambda: self.getitems(dbobj),
-		               additem=lambda dn: self.additem(dbobj, dn),
-		               delitem=lambda dn: self.delitem(dbobj, dn),
-		               encode=lambda ldapobj: ldapobj.dn,
-		               decode=self.ldapcls.ldap_get)
-
-	def __set__(self, dbobj, ldapobjs):
-		getattr(dbobj, self.baseattr).clear()
-		for ldapobj in ldapobjs:
-			getattr(dbobj, self.baseattr).append(self.mapcls(dn=ldapobj.dn))
diff --git a/ldap3_mapper/types.py b/ldap3_mapper/types.py
deleted file mode 100644
index 59fb50fc..00000000
--- a/ldap3_mapper/types.py
+++ /dev/null
@@ -1,35 +0,0 @@
-from collections.abc import MutableSet
-
-class LDAPCommitError(Exception):
-	pass
-
-class LDAPSet(MutableSet):
-	def __init__(self, getitems, additem, delitem, encode=None, decode=None):
-		self.__getitems = getitems
-		self.__additem = additem
-		self.__delitem = delitem
-		self.__encode = encode or (lambda x: x)
-		self.__decode = decode or (lambda x: x)
-
-	def __repr__(self):
-		return repr(set(self))
-
-	def __contains__(self, value):
-		return value is not None and self.__encode(value) in self.__getitems()
-
-	def __iter__(self):
-		return iter(filter(lambda obj: obj is not None, map(self.__decode, self.__getitems())))
-
-	def __len__(self):
-		return len(set(self))
-
-	def add(self, value):
-		if value not in self:
-			self.__additem(self.__encode(value))
-
-	def discard(self, value):
-		self.__delitem(self.__encode(value))
-
-	def update(self, values):
-		for value in values:
-			self.add(value)
diff --git a/ldap3_mapper_new/__init__.py b/ldap_mapper/__init__.py
similarity index 100%
rename from ldap3_mapper_new/__init__.py
rename to ldap_mapper/__init__.py
diff --git a/ldap3_mapper_new/attribute.py b/ldap_mapper/attribute.py
similarity index 100%
rename from ldap3_mapper_new/attribute.py
rename to ldap_mapper/attribute.py
diff --git a/ldap3_mapper_new/core.py b/ldap_mapper/core.py
similarity index 100%
rename from ldap3_mapper_new/core.py
rename to ldap_mapper/core.py
diff --git a/ldap3_mapper_new/dbutils.py b/ldap_mapper/dbutils.py
similarity index 100%
rename from ldap3_mapper_new/dbutils.py
rename to ldap_mapper/dbutils.py
diff --git a/ldap3_mapper_new/model.py b/ldap_mapper/model.py
similarity index 100%
rename from ldap3_mapper_new/model.py
rename to ldap_mapper/model.py
diff --git a/ldap3_mapper_new/relationship.py b/ldap_mapper/relationship.py
similarity index 100%
rename from ldap3_mapper_new/relationship.py
rename to ldap_mapper/relationship.py
diff --git a/uffd/ldap.py b/uffd/ldap.py
index 0dcd9fdf..2b9fe0a4 100644
--- a/uffd/ldap.py
+++ b/uffd/ldap.py
@@ -2,7 +2,7 @@ from flask import current_app, request
 
 import ldap3
 
-from ldap3_mapper_new import LDAPMapper, LDAPCommitError # pylint: disable=unused-import
+from ldap_mapper import LDAPMapper, LDAPCommitError # pylint: disable=unused-import
 
 class FlaskLDAPMapper(LDAPMapper):
 	def __init__(self):
diff --git a/uffd/role/models.py b/uffd/role/models.py
index 996b6b3f..d7e0daf9 100644
--- a/uffd/role/models.py
+++ b/uffd/role/models.py
@@ -2,7 +2,7 @@ from sqlalchemy import Column, String, Integer, Text, ForeignKey
 from sqlalchemy.orm import relationship
 from sqlalchemy.ext.declarative import declared_attr
 
-from ldap3_mapper_new.dbutils import DBRelationship
+from ldap_mapper.dbutils import DBRelationship
 
 from uffd.database import db
 from uffd.user.models import User, Group
-- 
GitLab