diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index 036f0f994f3ac980574ebda8319d0e2ae5b93074..e3385bbc5cb858f46142c1640f89c353e498ed47 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -71,3 +71,13 @@ trans_de: script: - ./update_translations.sh de coverage: '/^TOTAL.*\s+(\d+\%)$/' + +publish-pip: + stage: deploy + script: + - pip3 install build twine + - PACKAGE_VERSION="${CI_COMMIT_TAG#v}" python3 -m build + - TWINE_USERNAME="${GITLABPKGS_USERNAME}" TWINE_PASSWORD="${GITLABPKGS_PASSWORD}" python3 -m twine upload --repository-url ${CI_API_V4_URL}/projects/${CI_PROJECT_ID}/packages/pypi dist/* + - TWINE_USERNAME="${PYPI_USERNAME}" TWINE_PASSWORD="${PYPI_PASSWORD}" python3 -m twine upload dist/* + rules: + - if: '$CI_COMMIT_TAG =~ /v[0-9]+[.][0-9]+[.][0-9]+.*/' diff --git a/.gitmodules b/.gitmodules deleted file mode 100644 index f71540171466594ed28073c553f91825530fbf00..0000000000000000000000000000000000000000 --- a/.gitmodules +++ /dev/null @@ -1,3 +0,0 @@ -[submodule "deps/ldapalchemy"] - path = deps/ldapalchemy - url = https://git.cccv.de/infra/uffd/ldapalchemy.git diff --git a/deps/ldapalchemy b/deps/ldapalchemy deleted file mode 160000 index 7a232d305fda3e261b6f8d3c0958a16f4c2e8d8b..0000000000000000000000000000000000000000 --- a/deps/ldapalchemy +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 7a232d305fda3e261b6f8d3c0958a16f4c2e8d8b diff --git a/setup.py b/setup.py new file mode 100644 index 0000000000000000000000000000000000000000..ead7fb04ff725f357d14ebfa1ee564458f91187f --- /dev/null +++ b/setup.py @@ -0,0 +1,30 @@ +from setuptools import setup, find_packages +import os + +with open('README.md', 'r', encoding='utf-8') as f: + long_description = f.read() + +setup( + name='uffd', + version=os.environ.get('PACKAGE_VERSION', 'local'), + description='UserFerwaltungsFrontend: Ldap based single sign on and user management web software', + long_description=long_description, + long_description_content_type='text/markdown', + url='https://git.cccv.de/uffd/uffd', + classifiers=[ + 'Programming Language :: Python :: 3', + 'Development Status :: 4 - Beta', + 'License :: OSI Approved :: GNU Affero General Public License v3 or later (AGPLv3+)', + 'Operating System :: OS Independent', + 'Topic :: System :: Systems Administration :: Authentication/Directory :: LDAP', + 'Topic :: Internet :: WWW/HTTP :: Dynamic Content', + 'Environment :: Web Environment', + 'Framework :: Flask', + ], + author='CCCV', + author_email='it@cccv.de', + license='AGPL3', + packages=find_packages(), + zip_safe=False, + python_requires='>=3.7', +) diff --git a/uffd/__init__.py b/uffd/__init__.py index 293b47a1eccb8f929657e8e2a886a792bb7ad9f5..c099ed3eb6253b7acdb53f969e258aa5deb87977 100644 --- a/uffd/__init__.py +++ b/uffd/__init__.py @@ -10,15 +10,11 @@ from werkzeug.contrib.profiler import ProfilerMiddleware from werkzeug.exceptions import InternalServerError from flask_migrate import Migrate -sys.path.append('deps/ldapalchemy') - -# pylint: disable=wrong-import-position -from uffd.database import db, SQLAlchemyJSON import uffd.ldap +from uffd.database import db, SQLAlchemyJSON from uffd.template_helper import register_template_helper from uffd.navbar import setup_navbar from uffd.secure_redirect import secure_local_redirect -# pylint: enable=wrong-import-position def load_config_file(app, cfg_name, silent=False): cfg_path = os.path.join(app.instance_path, cfg_name) diff --git a/uffd/invite/models.py b/uffd/invite/models.py index 876984556183f56d6b8369204ab9411ab808f1d0..fb6f315fd38fbb417c3fc261b01ada9935af313c 100644 --- a/uffd/invite/models.py +++ b/uffd/invite/models.py @@ -4,8 +4,8 @@ import datetime from flask import current_app from sqlalchemy import Column, String, Integer, ForeignKey, DateTime, Boolean from sqlalchemy.orm import relationship -from ldapalchemy.dbutils import DBRelationship +from uffd.ldapalchemy.dbutils import DBRelationship from uffd.database import db from uffd.user.models import User from uffd.signup.models import Signup diff --git a/uffd/ldap.py b/uffd/ldap.py index be03350ee56f9ea0086f7b9f4251fe53b5a890d4..015713d5d898d89358c360c07171590bff6adc68 100644 --- a/uffd/ldap.py +++ b/uffd/ldap.py @@ -6,9 +6,9 @@ from flask import current_app, request, abort, session import ldap3 from ldap3.core.exceptions import LDAPBindError, LDAPPasswordIsMandatoryError, LDAPInvalidDnError -from ldapalchemy import LDAPMapper, LDAPCommitError # pylint: disable=unused-import -from ldapalchemy.model import Query -from ldapalchemy.core import encode_filter +from uffd.ldapalchemy import LDAPMapper, LDAPCommitError # pylint: disable=unused-import +from uffd.ldapalchemy.model import Query +from uffd.ldapalchemy.core import encode_filter def check_hashed(password_hash, password): diff --git a/uffd/ldapalchemy/__init__.py b/uffd/ldapalchemy/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3c0730e8c7c033bb53637547e58e81d62b77121a --- /dev/null +++ b/uffd/ldapalchemy/__init__.py @@ -0,0 +1,30 @@ +import ldap3 + +from .core import LDAPCommitError +from . import model, attribute, relationship + +__all__ = ['LDAPMapper', 'LDAPCommitError'] + +class LDAPMapper: + def __init__(self, server=None, bind_dn=None, bind_password=None): + + class Model(model.Model): + ldap_mapper = self + + self.Model = Model # pylint: disable=invalid-name + self.Session = model.Session # pylint: disable=invalid-name + self.Attribute = attribute.Attribute # pylint: disable=invalid-name + self.Relationship = relationship.Relationship # pylint: disable=invalid-name + self.Backreference = relationship.Backreference # pylint: disable=invalid-name + + if not hasattr(type(self), 'server'): + self.server = server + if not hasattr(type(self), 'bind_dn'): + self.bind_dn = bind_dn + if not hasattr(type(self), 'bind_password'): + self.bind_password = bind_password + if not hasattr(type(self), 'session'): + self.session = self.Session(self.get_connection) + + def get_connection(self): + return ldap3.Connection(self.server, self.bind_dn, self.bind_password, auto_bind=True) diff --git a/uffd/ldapalchemy/attribute.py b/uffd/ldapalchemy/attribute.py new file mode 100644 index 0000000000000000000000000000000000000000..409cb58a9cf60e12569729f150d3a0ee52d05930 --- /dev/null +++ b/uffd/ldapalchemy/attribute.py @@ -0,0 +1,66 @@ +from collections.abc import MutableSequence + +class AttributeList(MutableSequence): + def __init__(self, ldap_object, name, aliases): + self.__ldap_object = ldap_object + self.__name = name + self.__aliases = [name] + aliases + + def __get(self): + return list(self.__ldap_object.getattr(self.__name)) + + def __set(self, values): + for name in self.__aliases: + self.__ldap_object.setattr(name, values) + + def __repr__(self): + return repr(self.__get()) + + def __setitem__(self, key, value): + tmp = self.__get() + tmp[key] = value + self.__set(tmp) + + def __delitem__(self, key): + tmp = self.__get() + del tmp[key] + self.__set(tmp) + + def __len__(self): + return len(self.__get()) + + def __getitem__(self, key): + return self.__get()[key] + + def insert(self, index, value): + tmp = self.__get() + tmp.insert(index, value) + self.__set(tmp) + +class Attribute: + def __init__(self, name, aliases=None, multi=False, default=None): + self.name = name + self.aliases = aliases if aliases is not None else [] + self.multi = multi + self.default = default + + def add_hook(self, obj): + if obj.ldap_object.getattr(self.name) == []: + self.__set__(obj, self.default() if callable(self.default) else self.default) + + def __set_name__(self, cls, name): + if self.default is not None: + cls.ldap_add_hooks = cls.ldap_add_hooks + (self.add_hook,) + + def __get__(self, obj, objtype=None): + if obj is None: + return self + if self.multi: + return AttributeList(obj.ldap_object, self.name, self.aliases) + return (obj.ldap_object.getattr(self.name) or [None])[0] + + def __set__(self, obj, values): + if not self.multi: + values = [values] + for name in [self.name] + self.aliases: + obj.ldap_object.setattr(name, values) diff --git a/uffd/ldapalchemy/core.py b/uffd/ldapalchemy/core.py new file mode 100644 index 0000000000000000000000000000000000000000..c20a93ef789a057ff52470a2e310665b673fe63c --- /dev/null +++ b/uffd/ldapalchemy/core.py @@ -0,0 +1,284 @@ +from ldap3 import MODIFY_REPLACE, MODIFY_DELETE, MODIFY_ADD, ALL_ATTRIBUTES, ALL_OPERATIONAL_ATTRIBUTES +from ldap3.utils.conv import escape_filter_chars + +def encode_filter(filter_params): + return '(&%s)'%(''.join(['(%s=%s)'%(attr, escape_filter_chars(value)) for attr, value in filter_params])) + +def match_dn(dn, base): + return dn.endswith(base) # Probably good enougth for all valid dns + +def make_cache_key(search_base, filter_params): + res = (search_base,) + for attr, value in sorted(filter_params): + res += ((attr, value),) + return res + +class LDAPCommitError(Exception): + pass + +class SessionState: + def __init__(self, objects=None, deleted_objects=None, references=None): + self.objects = objects or {} + self.deleted_objects = deleted_objects or {} + self.references = references or {} # {(attr_name, value): {srcobj, ...}, ...} + + def copy(self): + objects = self.objects.copy() + deleted_objects = self.deleted_objects.copy() + references = {key: objs.copy() for key, objs in self.references.items()} + return SessionState(objects, deleted_objects, references) + + def ref(self, obj, attr, values): + for value in values: + key = (attr, value) + if key not in self.references: + self.references[key] = {obj} + else: + self.references[key].add(obj) + + def unref(self, obj, attr, values): + for value in values: + self.references.get((attr, value), set()).discard(obj) + +class ObjectState: + def __init__(self, session=None, attributes=None, dn=None): + self.session = session + self.attributes = attributes or {} + self.dn = dn + + def copy(self): + attributes = {name: values.copy() for name, values in self.attributes.items()} + return ObjectState(attributes=attributes, dn=self.dn, session=self.session) + +class AddOperation: + def __init__(self, obj, dn, object_classes): + self.obj = obj + self.dn = dn + self.object_classes = object_classes + self.attributes = {name: values.copy() for name, values in obj.state.attributes.items()} + + def apply_object(self, obj_state): + obj_state.dn = self.dn + obj_state.attributes = {name: values.copy() for name, values in self.attributes.items()} + obj_state.attributes['objectClass'] = obj_state.attributes.get('objectClass', []) + list(self.object_classes) + + def apply_session(self, session_state): + assert self.dn not in session_state.objects + session_state.objects[self.dn] = self.obj + for name, values in self.attributes.items(): + session_state.ref(self.obj, name, values) + session_state.ref(self.obj, 'objectClass', self.object_classes) + + def apply_ldap(self, conn): + success = conn.add(self.dn, self.object_classes, self.attributes) + if not success: + raise LDAPCommitError() + +class DeleteOperation: + def __init__(self, obj): + self.dn = obj.state.dn + self.obj = obj + self.attributes = {name: values.copy() for name, values in obj.state.attributes.items()} + + def apply_object(self, obj_state): + obj_state.dn = None + + def apply_session(self, session_state): + assert self.dn in session_state.objects + del session_state.objects[self.dn] + session_state.deleted_objects[self.dn] = self.obj + for name, values in self.attributes.items(): + session_state.unref(self.obj, name, values) + + def apply_ldap(self, conn): + success = conn.delete(self.dn) + if not success: + raise LDAPCommitError() + +class ModifyOperation: + def __init__(self, obj, changes): + self.obj = obj + self.attributes = {name: values.copy() for name, values in obj.state.attributes.items()} + self.changes = changes + + def apply_object(self, obj_state): + for attr, changes in self.changes.items(): + for action, values in changes: + if action == MODIFY_REPLACE: + obj_state.attributes[attr] = values + elif action == MODIFY_ADD: + obj_state.attributes[attr] += values + elif action == MODIFY_DELETE: + for value in values: + if value in obj_state.attributes[attr]: + obj_state.attributes[attr].remove(value) + + def apply_session(self, session_state): + for attr, changes in self.changes.items(): + for action, values in changes: + if action == MODIFY_REPLACE: + session_state.unref(self.obj, attr, self.attributes.get(attr, [])) + session_state.ref(self.obj, attr, values) + elif action == MODIFY_ADD: + session_state.ref(self.obj, attr, values) + elif action == MODIFY_DELETE: + session_state.unref(self.obj, attr, values) + + def apply_ldap(self, conn): + success = conn.modify(self.obj.state.dn, self.changes) + if not success: + raise LDAPCommitError() + +class Session: + def __init__(self, get_connection): + self.get_connection = get_connection + self.committed_state = SessionState() + self.state = SessionState() + self.changes = [] + self.cached_searches = set() + + def add(self, obj, dn, object_classes): + if self.state.objects.get(dn) == obj: + return + assert obj.state.session is None + oper = AddOperation(obj, dn, object_classes) + oper.apply_object(obj.state) + obj.state.session = self + oper.apply_session(self.state) + self.changes.append(oper) + + def delete(self, obj): + if obj.state.dn not in self.state.objects: + return + assert obj.state.session == self + oper = DeleteOperation(obj) + oper.apply_object(obj.state) + obj.state.session = None + oper.apply_session(self.state) + self.changes.append(oper) + + def record(self, oper): + assert oper.obj.state.session == self + self.changes.append(oper) + + def commit(self): + conn = self.get_connection() + while self.changes: + oper = self.changes.pop(0) + try: + oper.apply_ldap(conn) + except Exception as err: + self.changes.insert(0, oper) + raise err + oper.apply_object(oper.obj.committed_state) + oper.apply_session(self.committed_state) + self.committed_state = self.state.copy() + + def rollback(self): + for obj in self.state.objects.values(): + obj.state = obj.committed_state.copy() + for obj in self.state.deleted_objects.values(): + obj.state = obj.committed_state.copy() + self.state = self.committed_state.copy() + self.changes.clear() + + def get(self, dn, filter_params): + if dn in self.state.objects: + obj = self.state.objects[dn] + return obj if obj.match(filter_params) else None + if dn in self.state.deleted_objects: + return None + conn = self.get_connection() + conn.search(dn, encode_filter(filter_params), attributes=[ALL_ATTRIBUTES, ALL_OPERATIONAL_ATTRIBUTES]) + if not conn.response: + return None + assert len(conn.response) == 1 + if conn.response[0]['dn'] != dn: + # To use DNs as cache keys, we assume each DN has a single unique string + # representation. This is not generally true: RDN attributes may be + # case insensitive or values may contain escape sequences. + # In this case, the provided DN differs from the canonical form the + # server returned. We cannot handle this consistently, so we report no + # match. + return None + obj = Object(self, conn.response[0]) + self.state.objects[dn] = obj + self.committed_state.objects[dn] = obj + for attr, values in obj.state.attributes.items(): + self.state.ref(obj, attr, values) + return obj + + def filter(self, search_base, filter_params): + if not filter_params: + matches = self.state.objects.values() + else: + submatches = [self.state.references.get((attr, value), set()) for attr, value in filter_params] + matches = submatches.pop(0) + while submatches: + matches = matches.intersection(submatches.pop(0)) + res = [obj for obj in matches if match_dn(obj.state.dn, search_base)] + cache_key = make_cache_key(search_base, filter_params) + if cache_key in self.cached_searches: + return res + conn = self.get_connection() + conn.search(search_base, encode_filter(filter_params), attributes=[ALL_ATTRIBUTES, ALL_OPERATIONAL_ATTRIBUTES]) + for response in conn.response: + dn = response['dn'] + if dn in self.state.objects or dn in self.state.deleted_objects: + continue + obj = Object(self, response) + self.state.objects[dn] = obj + self.committed_state.objects[dn] = obj + for attr, values in obj.state.attributes.items(): + self.state.ref(obj, attr, values) + res.append(obj) + self.cached_searches.add(cache_key) + return res + +class Object: + def __init__(self, session=None, response=None): + if response is None: + self.committed_state = ObjectState() + else: + assert session is not None + attrs = {attr: value if isinstance(value, list) else [value] for attr, value in response['attributes'].items()} + self.committed_state = ObjectState(session, attrs, response['dn']) + self.state = self.committed_state.copy() + + @property + def dn(self): + return self.state.dn + + @property + def session(self): + return self.state.session + + def getattr(self, name): + return self.state.attributes.get(name, []) + + def setattr(self, name, values): + oper = ModifyOperation(self, {name: [(MODIFY_REPLACE, values)]}) + oper.apply_object(self.state) + if self.state.session: + oper.apply_session(self.state.session.state) + self.state.session.changes.append(oper) + + def attr_append(self, name, value): + oper = ModifyOperation(self, {name: [(MODIFY_ADD, [value])]}) + oper.apply_object(self.state) + if self.state.session: + oper.apply_session(self.state.session.state) + self.state.session.changes.append(oper) + + def attr_remove(self, name, value): + oper = ModifyOperation(self, {name: [(MODIFY_DELETE, [value])]}) + oper.apply_object(self.state) + if self.state.session: + oper.apply_session(self.state.session.state) + self.state.session.changes.append(oper) + + def match(self, filter_params): + for attr, value in filter_params: + if value not in self.getattr(attr): + return False + return True diff --git a/uffd/ldapalchemy/dbutils.py b/uffd/ldapalchemy/dbutils.py new file mode 100644 index 0000000000000000000000000000000000000000..dfc93f9a721c7151aadcb89d443b811a871d9681 --- /dev/null +++ b/uffd/ldapalchemy/dbutils.py @@ -0,0 +1,145 @@ +from collections.abc import MutableSet + +from .model import add_to_session + +class DBRelationshipSet(MutableSet): + def __init__(self, dbobj, relattr, ldapcls, mapcls): + self.__dbobj = dbobj + self.__relattr = relattr + self.__ldapcls = ldapcls + self.__mapcls = mapcls + + def __get_dns(self): + return [mapobj.dn for mapobj in getattr(self.__dbobj, self.__relattr)] + + def __repr__(self): + return repr(set(self)) + + def __contains__(self, value): + if value is None or not isinstance(value, self.__ldapcls): + return False + return value.ldap_object.dn in self.__get_dns() + + def __iter__(self): + return iter(filter(lambda obj: obj is not None, [self.__ldapcls.query.get(dn) for dn in self.__get_dns()])) + + def __len__(self): + return len(set(self)) + + def add(self, value): + if not isinstance(value, self.__ldapcls): + raise TypeError() + if value.ldap_object.session is None: + add_to_session(value, self.__ldapcls.ldap_mapper.session.ldap_session) + if value.ldap_object.dn not in self.__get_dns(): + getattr(self.__dbobj, self.__relattr).append(self.__mapcls(dn=value.ldap_object.dn)) + + def discard(self, value): + if not isinstance(value, self.__ldapcls): + raise TypeError() + rel = getattr(self.__dbobj, self.__relattr) + for mapobj in list(rel): + if mapobj.dn == value.ldap_object.dn: + rel.remove(mapobj) + +class DBRelationship: + def __init__(self, relattr, ldapcls, mapcls=None, backref=None, backattr=None): + self.relattr = relattr + self.ldapcls = ldapcls + self.mapcls = mapcls + self.backref = backref + self.backattr = backattr + + def __set_name__(self, cls, name): + if self.backref: + setattr(self.ldapcls, self.backref, DBBackreference(cls, self.relattr, self.mapcls, self.backattr)) + + def __get__(self, obj, objtype=None): + if obj is None: + return self + if self.mapcls is not None: + return DBRelationshipSet(obj, self.relattr, self.ldapcls, self.mapcls) + dn = getattr(obj, self.relattr) + if dn is not None: + return self.ldapcls.query.get(dn) + return None + + def __set__(self, obj, values): + if self.mapcls is not None: + tmp = self.__get__(obj) + tmp.clear() + for value in values: + tmp.add(value) + else: + if not isinstance(values, self.ldapcls): + raise TypeError() + setattr(obj, self.relattr, values.ldap_object.dn) + +class DBBackreferenceSet(MutableSet): + def __init__(self, ldapobj, dbcls, relattr, mapcls, backattr): + self.__ldapobj = ldapobj + self.__dbcls = dbcls + self.__relattr = relattr + self.__mapcls = mapcls + self.__backattr = backattr + + @property + def __dn(self): + return self.__ldapobj.ldap_object.dn + + def __get(self): + if self.__mapcls is None: + return self.__dbcls.query.filter_by(**{self.__relattr: self.__dn}).all() + return {getattr(mapobj, self.__backattr) for mapobj in self.__mapcls.query.filter_by(dn=self.__dn)} + + def __repr__(self): + return repr(self.__get()) + + def __contains__(self, value): + return value in self.__get() + + def __iter__(self): + return iter(self.__get()) + + def __len__(self): + return len(self.__get()) + + def add(self, value): + assert self.__ldapobj.ldap_object.session is not None + if not isinstance(value, self.__dbcls): + raise TypeError() + if self.__mapcls is None: + setattr(value, self.__relattr, self.__dn) + else: + rel = getattr(value, self.__relattr) + if self.__dn not in {mapobj.dn for mapobj in rel}: + rel.append(self.__mapcls(dn=self.__dn)) + + def discard(self, value): + if not isinstance(value, self.__dbcls): + raise TypeError() + if self.__mapcls is None: + setattr(value, self.__relattr, None) + else: + rel = getattr(value, self.__relattr) + for mapobj in list(rel): + if mapobj.dn == self.__dn: + rel.remove(mapobj) + +class DBBackreference: + def __init__(self, dbcls, relattr, mapcls=None, backattr=None): + self.dbcls = dbcls + self.relattr = relattr + self.mapcls = mapcls + self.backattr = backattr + + def __get__(self, obj, objtype=None): + if obj is None: + return self + return DBBackreferenceSet(obj, self.dbcls, self.relattr, self.mapcls, self.backattr) + + def __set__(self, obj, values): + tmp = self.__get__(obj) + tmp.clear() + for value in values: + tmp.add(value) diff --git a/uffd/ldapalchemy/model.py b/uffd/ldapalchemy/model.py new file mode 100644 index 0000000000000000000000000000000000000000..c5e2fc33964ba1aba50d04b451e49f00f1a19380 --- /dev/null +++ b/uffd/ldapalchemy/model.py @@ -0,0 +1,148 @@ +from collections.abc import Sequence + +try: + # Added in v2.5 + from ldap3.utils.dn import escape_rdn +except ImportError: + # From ldap3 source code, Copyright Giovanni Cannata, LGPL v3 license + def escape_rdn(rdn): + # '/' must be handled first or the escape slashes will be escaped! + for char in ['\\', ',', '+', '"', '<', '>', ';', '=', '\x00']: + rdn = rdn.replace(char, '\\' + char) + if rdn[0] == '#' or rdn[0] == ' ': + rdn = ''.join(('\\', rdn)) + if rdn[-1] == ' ': + rdn = ''.join((rdn[:-1], '\\ ')) + return rdn + +from . import core + +def add_to_session(obj, session): + if obj.ldap_object.session is None: + for func in obj.ldap_add_hooks: + func(obj) + session.add(obj.ldap_object, obj.dn, obj.ldap_object_classes) + +class Session: + def __init__(self, get_connection): + self.ldap_session = core.Session(get_connection) + + def add(self, obj): + add_to_session(obj, self.ldap_session) + + def delete(self, obj): + self.ldap_session.delete(obj.ldap_object) + + def commit(self): + self.ldap_session.commit() + + def rollback(self): + self.ldap_session.rollback() + +def make_modelobj(obj, model): + if obj is None: + return None + if not hasattr(obj, 'model'): + obj.model = model() + obj.model.ldap_object = obj + if not isinstance(obj.model, model): + return None + return obj.model + +def make_modelobjs(objs, model): + modelobjs = [] + for obj in objs: + modelobj = make_modelobj(obj, model) + if modelobj is not None: + modelobjs.append(modelobj) + return modelobjs + +class Query(Sequence): + def __init__(self, model, filter_params=None): + self.__model = model + self.__filter_params = list(model.ldap_filter_params) + (filter_params or []) + + @property + def __session(self): + return self.__model.ldap_mapper.session.ldap_session + + def get(self, dn): + return make_modelobj(self.__session.get(dn, self.__filter_params), self.__model) + + def all(self): + objs = self.__session.filter(self.__model.ldap_search_base, self.__filter_params) + objs = sorted(objs, key=lambda obj: obj.dn) + return make_modelobjs(objs, self.__model) + + def first(self): + return (self.all() or [None])[0] + + def one(self): + modelobjs = self.all() + if len(modelobjs) != 1: + raise Exception() + return modelobjs[0] + + def one_or_none(self): + modelobjs = self.all() + if len(modelobjs) > 1: + raise Exception() + return (modelobjs or [None])[0] + + def __contains__(self, value): + return value in self.all() + + def __iter__(self): + return iter(self.all()) + + def __len__(self): + return len(self.all()) + + def __getitem__(self, index): + return self.all()[index] + + def filter_by(self, **kwargs): + filter_params = [(getattr(self.__model, attr).name, value) for attr, value in kwargs.items()] + return type(self)(self.__model, self.__filter_params + filter_params) + +class QueryWrapper: + def __get__(self, obj, objtype=None): + return objtype.query_class(objtype) + +class Model: + # Overwritten by mapper + ldap_mapper = None + query_class = Query + query = QueryWrapper() + ldap_add_hooks = () + + # Overwritten by models + ldap_search_base = None + ldap_filter_params = () + ldap_object_classes = () + ldap_dn_base = None + ldap_dn_attribute = None + + def __init__(self, **kwargs): + self.ldap_object = core.Object() + for key, value, in kwargs.items(): + setattr(self, key, value) + + @property + def dn(self): + if self.ldap_object.dn is not None: + return self.ldap_object.dn + if self.ldap_dn_base is None or self.ldap_dn_attribute is None: + return None + values = self.ldap_object.getattr(self.ldap_dn_attribute) + if not values: + return None + # escape_rdn can't handle empty strings + rdn = escape_rdn(values[0]) if values[0] else '' + return '%s=%s,%s'%(self.ldap_dn_attribute, rdn, self.ldap_dn_base) + + def __repr__(self): + cls_name = '%s.%s'%(type(self).__module__, type(self).__name__) + if self.dn is not None: + return '<%s %s>'%(cls_name, self.dn) + return '<%s>'%cls_name diff --git a/uffd/ldapalchemy/relationship.py b/uffd/ldapalchemy/relationship.py new file mode 100644 index 0000000000000000000000000000000000000000..5666a5d3583910f35e10b375d25fda5ca49df63c --- /dev/null +++ b/uffd/ldapalchemy/relationship.py @@ -0,0 +1,136 @@ +from collections.abc import MutableSet + +from .model import make_modelobj, make_modelobjs, add_to_session + +class UnboundObjectError(Exception): + pass + +class RelationshipSet(MutableSet): + def __init__(self, ldap_object, name, model, destmodel): + self.__ldap_object = ldap_object + self.__name = name + self.__model = model # pylint: disable=unused-private-member + self.__destmodel = destmodel + + def __modify_check(self, value): + if self.__ldap_object.session is None: + raise UnboundObjectError() + if not isinstance(value, self.__destmodel): + raise TypeError() + + def __repr__(self): + return repr(set(self)) + + def __contains__(self, value): + if value is None or not isinstance(value, self.__destmodel): + return False + return value.ldap_object.dn in self.__ldap_object.getattr(self.__name) + + def __iter__(self): + def get(dn): + return make_modelobj(self.__ldap_object.session.get(dn, self.__destmodel.ldap_filter_params), self.__destmodel) + dns = set(self.__ldap_object.getattr(self.__name)) + return iter(filter(lambda obj: obj is not None, map(get, dns))) + + def __len__(self): + return len(set(self)) + + def add(self, value): + self.__modify_check(value) + if value.ldap_object.session is None: + add_to_session(value, self.__ldap_object.session) + assert value.ldap_object.session == self.__ldap_object.session + self.__ldap_object.attr_append(self.__name, value.dn) + + def discard(self, value): + self.__modify_check(value) + self.__ldap_object.attr_remove(self.__name, value.dn) + + def update(self, values): + for value in values: + self.add(value) + +class Relationship: + def __init__(self, name, destmodel, backref=None): + self.name = name + self.destmodel = destmodel + self.backref = backref + + def __set_name__(self, cls, name): + if self.backref is not None: + setattr(self.destmodel, self.backref, Backreference(self.name, cls)) + + def __get__(self, obj, objtype=None): + if obj is None: + return self + return RelationshipSet(obj.ldap_object, self.name, type(obj), self.destmodel) + + def __set__(self, obj, values): + tmp = self.__get__(obj) + tmp.clear() + for value in values: + tmp.add(value) + +class BackreferenceSet(MutableSet): + def __init__(self, ldap_object, name, model, srcmodel): + self.__ldap_object = ldap_object + self.__name = name + self.__model = model # pylint: disable=unused-private-member + self.__srcmodel = srcmodel + + def __modify_check(self, value): + if self.__ldap_object.session is None: + raise UnboundObjectError() + if not isinstance(value, self.__srcmodel): + raise TypeError() + + def __get(self): + if self.__ldap_object.session is None: + return set() + filter_params = list(self.__srcmodel.ldap_filter_params) + [(self.__name, self.__ldap_object.dn)] + objs = self.__ldap_object.session.filter(self.__srcmodel.ldap_search_base, filter_params) + return set(make_modelobjs(objs, self.__srcmodel)) + + def __repr__(self): + return repr(self.__get()) + + def __contains__(self, value): + return value in self.__get() + + def __iter__(self): + return iter(self.__get()) + + def __len__(self): + return len(self.__get()) + + def add(self, value): + self.__modify_check(value) + if value.ldap_object.session is None: + add_to_session(value, self.__ldap_object.session) + assert value.ldap_object.session == self.__ldap_object.session + if self.__ldap_object.dn not in value.ldap_object.getattr(self.__name): + value.ldap_object.attr_append(self.__name, self.__ldap_object.dn) + + def discard(self, value): + self.__modify_check(value) + value.ldap_object.attr_remove(self.__name, self.__ldap_object.dn) + + def update(self, values): + for value in values: + self.add(value) + +class Backreference: + def __init__(self, name, srcmodel): + self.name = name + self.srcmodel = srcmodel + + def __get__(self, obj, objtype=None): + if obj is None: + return self + return BackreferenceSet(obj.ldap_object, self.name, type(obj), self.srcmodel) + + def __set__(self, obj, values): + tmp = self.__get__(obj) + tmp.clear() + for value in values: + tmp.add(value) diff --git a/uffd/mfa/models.py b/uffd/mfa/models.py index 541f231e127541bfafa3ebc0f145af322dc7f9f2..1e3eda273c78e20d599c0f8a4dfeba459d93e64d 100644 --- a/uffd/mfa/models.py +++ b/uffd/mfa/models.py @@ -13,8 +13,8 @@ import crypt from flask import request, current_app from sqlalchemy import Column, Integer, Enum, String, DateTime, Text -from ldapalchemy.dbutils import DBRelationship +from uffd.ldapalchemy.dbutils import DBRelationship from uffd.database import db from uffd.user.models import User diff --git a/uffd/oauth2/models.py b/uffd/oauth2/models.py index c56ba0bd9c430bc7cd26e2caf4f4d9d2d38466b1..450edd9aadf0a183f17b97a4ce2e4e581cd0f762 100644 --- a/uffd/oauth2/models.py +++ b/uffd/oauth2/models.py @@ -1,7 +1,7 @@ from flask import current_app from sqlalchemy import Column, Integer, String, DateTime, Text -from ldapalchemy.dbutils import DBRelationship +from uffd.ldapalchemy.dbutils import DBRelationship from uffd.database import db from uffd.user.models import User from uffd.session.models import DeviceLoginInitiation, DeviceLoginType diff --git a/uffd/role/models.py b/uffd/role/models.py index 580a8a62d4322864699ddb7d32b200ed0d394a22..e5b40190cf2b3ded9d3a732a4317fd756ab4e28c 100644 --- a/uffd/role/models.py +++ b/uffd/role/models.py @@ -3,8 +3,7 @@ from sqlalchemy.orm import relationship from sqlalchemy.orm.collections import MappedCollection, collection from sqlalchemy.ext.declarative import declared_attr -from ldapalchemy.dbutils import DBRelationship - +from uffd.ldapalchemy.dbutils import DBRelationship from uffd.database import db from uffd.user.models import User, Group diff --git a/uffd/session/models.py b/uffd/session/models.py index d515395f66f815b94e216c39893ea00f5b24e433..c91619e6d830a27a238e975d6b014e010a705ef1 100644 --- a/uffd/session/models.py +++ b/uffd/session/models.py @@ -6,8 +6,8 @@ import enum from sqlalchemy import Column, String, Integer, DateTime, ForeignKey, Enum from sqlalchemy.orm import relationship from sqlalchemy.ext.hybrid import hybrid_property -from ldapalchemy.dbutils import DBRelationship +from uffd.ldapalchemy.dbutils import DBRelationship from uffd.database import db from uffd.user.models import User diff --git a/uffd/signup/models.py b/uffd/signup/models.py index 7c8dd6658fd8286b80de2c394e7dbc531c490ffd..7c7f4cc9b476b5632100b4038bf26e341117264b 100644 --- a/uffd/signup/models.py +++ b/uffd/signup/models.py @@ -3,8 +3,8 @@ import datetime from crypt import crypt from sqlalchemy import Column, String, Text, DateTime -from ldapalchemy.dbutils import DBRelationship +from uffd.ldapalchemy.dbutils import DBRelationship from uffd.database import db from uffd.ldap import ldap from uffd.user.models import User