diff --git a/ldap3_mapper/__init__.py b/ldap3_mapper/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8c7114e736a66a44f36273c403a9440f95cedf38 --- /dev/null +++ b/ldap3_mapper/__init__.py @@ -0,0 +1,35 @@ +import ldap3 + +from .types import LDAPCommitError +from . import base + +class LDAP3Mapper: + def __init__(self, server=None, bind_dn=None, bind_password=None): + if not hasattr(type(self), 'server'): + self.server = server + if not hasattr(type(self), 'bind_dn'): + self.bind_dn = bind_dn + if not hasattr(type(self), 'bind_password'): + self.bind_password = bind_password + if not hasattr(type(self), 'session'): + self.session = base.Session() + + class Model(base.Model): + ldap_mapper = self + + class Attribute(base.Attribute): + ldap_mapper = self + + class Relation(base.Relation): + ldap_mapper = self + + class Backref(base.Backref): + ldap_mapper = self + + self.Model = Model # pylint: disable=invalid-name + self.Attribute = Attribute # pylint: disable=invalid-name + self.Relation = Relation # pylint: disable=invalid-name + self.Backref = Backref # pylint: disable=invalid-name + + def connect(self): + return ldap3.Connection(self.server, self.bind_dn, self.bind_password, auto_bind=True) diff --git a/ldap3_mapper/base.py b/ldap3_mapper/base.py new file mode 100644 index 0000000000000000000000000000000000000000..62bff9189403482d209ff30885d2299df3e3ff95 --- /dev/null +++ b/ldap3_mapper/base.py @@ -0,0 +1,320 @@ +from copy import deepcopy + +from ldap3.utils.conv import escape_filter_chars +from ldap3 import MODIFY_REPLACE, MODIFY_DELETE, MODIFY_ADD, ALL_ATTRIBUTES, ALL_OPERATIONAL_ATTRIBUTES + +from .types import LDAPSet, LDAPCommitError + +class Session: + def __init__(self): + self.__objects = {} # dn -> instance + self.__to_delete = [] + self.__relations = {} # (srccls, srcattr, dn) -> {srcobj, ...} + + def lookup(self, dn): + return self.__objects.get(dn) + + def register(self, obj): + if obj.dn in self.__objects and self.__objects[obj.dn] != obj: + raise Exception() + self.__objects[obj.dn] = obj + return obj + + def lookup_relations(self, srccls, srcattr, dn): + key = (srccls, srcattr, dn) + return self.__relations.get(key, set()) + + def update_relations(self, srcobj, srcattr, delete_dns=None, add_dns=None): + for dn in (delete_dns or []): + key = (type(srcobj), srcattr, dn) + self.__relations[key] = self.__relations.get(key, set()) + self.__relations[key].discard(srcobj) + for dn in (add_dns or []): + key = (type(srcobj), srcattr, dn) + self.__relations[key] = self.__relations.get(key, set()) + self.__relations[key].add(srcobj) + + def add(self, obj): + self.register(obj) + + def delete(self, obj): + if obj.dn in self.__objects: + del self.__objects[obj.dn] + self.__to_delete.append(obj) + + def commit(self): + while self.__to_delete: + self.__to_delete.pop(0).ldap_delete() + for obj in list(self.__objects.values()): + if not obj.ldap_created: + obj.ldap_create() + elif obj.ldap_dirty: + obj.ldap_modify() + + def rollback(self): + self.__to_delete.clear() + self.__objects = {dn: obj for dn, obj in self.__objects.items() if obj.ldap_created} + for obj in self.__objects.values(): + if obj.ldap_dirty: + obj.ldap_reset() + +class Model: + ldap_mapper = None # Overwritten by LDAP3Mapper + + ldap_dn_attribute = None + ldap_dn_base = None + ldap_base = None + ldap_object_classes = None + ldap_filter = None + # Caution: Never mutate ldap_pre_create_hooks and ldap_relations, always reassign! + ldap_pre_create_hooks = [] + ldap_relations = [] + + def __init__(self, _ldap_response=None, **kwargs): + self.ldap_session = self.ldap_mapper.session + self.ldap_relation_data = set() + self.__ldap_dn = None if _ldap_response is None else _ldap_response['dn'] + self.__ldap_attributes = {} + for key, values in (_ldap_response or {}).get('attributes', {}).items(): + if isinstance(values, list): + self.__ldap_attributes[key] = values + else: + self.__ldap_attributes[key] = [values] + self.__attributes = deepcopy(self.__ldap_attributes) + self.__changes = {} + for key, value, in kwargs.items(): + if not hasattr(self, key): + raise Exception() + setattr(self, key, value) + for name in self.ldap_relations: + self.__update_relations(name, add_dns=self.__attributes.get(name, [])) + + def __update_relations(self, name, delete_dns=None, add_dns=None): + if name in self.ldap_relations: + self.ldap_session.update_relations(self, name, delete_dns, add_dns) + + def ldap_getattr(self, name): + return self.__attributes.get(name, []) + + def ldap_setattr(self, name, values): + self.__update_relations(name, delete_dns=self.__attributes.get(name, [])) + self.__changes[name] = [(MODIFY_REPLACE, values)] + self.__attributes[name] = values + self.__update_relations(name, add_dns=values) + + def ldap_attradd(self, name, value): + self.__changes[name] = self.__changes.get(name, []) + [(MODIFY_ADD, [value])] + self.__attributes[name].append(value) + self.__update_relations(name, add_dns=[value]) + + def ldap_attrdel(self, name, value): + self.__changes[name] = self.__changes.get(name, []) + [(MODIFY_DELETE, [value])] + if value in self.__attributes.get(name, []): + self.__attributes[name].remove(value) + self.__update_relations(name, delete_dns=[value]) + + def __repr__(self): + name = '%s.%s'%(type(self).__module__, type(self).__name__) + if self.__ldap_dn is None: + return '<%s>'%name + return '<%s %s>'%(name, self.__ldap_dn) + + def build_dn(self): + if self.ldap_dn_attribute is None: + return None + if self.ldap_dn_base is None: + return None + if self.__attributes.get(self.ldap_dn_attribute) is None: + return None + return '%s=%s,%s'%(self.ldap_dn_attribute, escape_filter_chars(self.__attributes[self.ldap_dn_attribute][0]), self.ldap_dn_base) + + @property + def dn(self): + if self.__ldap_dn is not None: + return self.__ldap_dn + return self.build_dn() + + @classmethod + def ldap_get(cls, dn): + obj = cls.ldap_mapper.session.lookup(dn) + if obj is None: + conn = cls.ldap_mapper.connect() + conn.search(dn, cls.ldap_filter, attributes=[ALL_ATTRIBUTES, ALL_OPERATIONAL_ATTRIBUTES]) + if not conn.response: + return None + if len(conn.response) != 1: + raise Exception() + obj = cls.ldap_mapper.session.register(cls(_ldap_response=conn.response[0])) + return obj + + @classmethod + def ldap_all(cls): + conn = cls.ldap_mapper.connect() + conn.search(cls.ldap_base, cls.ldap_filter, attributes=[ALL_ATTRIBUTES, ALL_OPERATIONAL_ATTRIBUTES]) + res = [] + for entry in conn.response: + obj = cls.ldap_mapper.session.lookup(entry['dn']) + if obj is None: + obj = cls.ldap_mapper.session.register(cls(_ldap_response=entry)) + res.append(obj) + return res + + @classmethod + def ldap_filter_by_raw(cls, **kwargs): + filters = [cls.ldap_filter] + for key, value in kwargs.items(): + filters.append('(%s=%s)'%(key, escape_filter_chars(value))) + conn = cls.ldap_mapper.connect() + conn.search(cls.ldap_base, '(&%s)'%(''.join(filters)), attributes=[ALL_ATTRIBUTES, ALL_OPERATIONAL_ATTRIBUTES]) + res = [] + for entry in conn.response: + obj = cls.ldap_mapper.session.lookup(entry['dn']) + if obj is None: + obj = cls.ldap_mapper.session.register(cls(_ldap_response=entry)) + res.append(obj) + return res + + @classmethod + def ldap_filter_by(cls, **kwargs): + _kwargs = {} + for key, value in kwargs.items(): + attr = getattr(cls, key) + _kwargs[attr.name] = attr.encode(value) + return cls.ldap_filter_by_raw(**_kwargs) + + def ldap_reset(self): + for name in self.ldap_relations: + self.__update_relations(name, delete_dns=self.__attributes.get(name, [])) + self.__changes = {} + self.__attributes = deepcopy(self.__ldap_attributes) + for name in self.ldap_relations: + self.__update_relations(name, add_dns=self.__attributes.get(name, [])) + + @property + def ldap_dirty(self): + return bool(self.__changes) + + @property + def ldap_created(self): + return bool(self.__ldap_attributes) + + def ldap_modify(self): + if not self.ldap_created: + raise Exception() + if not self.ldap_dirty: + return + conn = self.ldap_mapper.connect() + success = conn.modify(self.dn, self.__changes) + if not success: + raise Exception() + self.__changes = {} + self.__ldap_attributes = deepcopy(self.__attributes) + + def ldap_create(self): + if self.ldap_created: + raise Exception() + conn = self.ldap_mapper.connect() + for func in self.ldap_pre_create_hooks: + func(self) + success = conn.add(self.dn, self.ldap_object_classes, self.__attributes) + if not success: + raise LDAPCommitError() + self.__changes = {} + self.__ldap_attributes = deepcopy(self.__attributes) + + def ldap_delete(self): + conn = self.ldap_mapper.connect() + success = conn.delete(self.dn) + if not success: + raise Exception() + self.__ldap_attributes = {} + +class Attribute: + ldap_mapper = None # Overwritten by LDAP3Mapper + + def __init__(self, name, multi=False, default=None, encode=None, decode=None, aliases=None): + self.name = name + self.multi = multi + self.encode = encode or (lambda x: x) + self.decode = decode or (lambda x: x) + self.default_values = default + self.aliases = aliases or [] + + def default(self, obj): + if obj.ldap_getattr(self.name) == []: + values = self.default_values + if callable(values): + values = values() + self.__set__(obj, values) + + def additem(self, obj, value): + obj.ldap_attradd(self.name, value) + for name in self.aliases: + obj.ldap_attradd(name, value) + + def delitem(self, obj, value): + obj.ldap_attradd(self.name, value) + for name in self.aliases: + obj.ldap_attradd(name, value) + + def __set_name__(self, cls, name): + if self.default_values is not None: + cls.ldap_pre_create_hooks = cls.ldap_pre_create_hooks + [self.default] + + def __get__(self, obj, objtype=None): + if obj is None: + return self + if self.multi: + return LDAPSet(getitems=lambda: obj.ldap_getattr(self.name), + additem=lambda value: self.additem(obj, value), + delitem=lambda value: self.delitem(obj, value), + encode=self.encode, decode=self.decode) + return self.decode((obj.ldap_getattr(self.name) or [None])[0]) + + def __set__(self, obj, values): + if not self.multi: + values = [values] + obj.ldap_setattr(self.name, [self.encode(value) for value in values]) + for name in self.aliases: + obj.ldap_setattr(name, [self.encode(value) for value in values]) + +class Backref: + ldap_mapper = None # Overwritten by LDAP3Mapper + + def __init__(self, srccls, srcattr): + self.srccls = srccls + self.srcattr = srcattr + srccls.ldap_relations = srccls.ldap_relations + [srcattr] + + def init(self, obj): + if self.srcattr not in obj.ldap_relation_data and obj.ldap_created: + # The query instanciates all related objects that in turn add their relations to session + self.srccls.ldap_filter_by_raw(**{self.srcattr: obj.dn}) + obj.ldap_relation_data.add(self.srcattr) + + def __get__(self, obj, objtype=None): + if obj is None: + return self + self.init(obj) + return LDAPSet(getitems=lambda: obj.ldap_session.lookup_relations(self.srccls, self.srcattr, obj.dn), + additem=lambda value: value.ldap_attradd(self.srcattr, obj.dn), + delitem=lambda value: value.ldap_attrdel(self.srcattr, obj.dn)) + + def __set__(self, obj, values): + current = self.__get__(obj) + current.clear() + for value in values: + current.add(value) + +class Relation(Attribute): + ldap_mapper = None # Overwritten by LDAP3Mapper + + def __init__(self, name, dest, backref=None): + super().__init__(name, multi=True, encode=lambda value: value.dn, decode=dest.ldap_get) + self.name = name + self.dest = dest + self.backref = backref + + def __set_name__(self, cls, name): + if self.backref is not None: + setattr(self.dest, self.backref, Backref(cls, self.name)) diff --git a/ldap3_mapper/db_relation.py b/ldap3_mapper/db_relation.py new file mode 100644 index 0000000000000000000000000000000000000000..337e086960184d086605f10729ea9aceb73acd97 --- /dev/null +++ b/ldap3_mapper/db_relation.py @@ -0,0 +1,66 @@ +from .types import LDAPSet + +class DB2LDAPBackref: + def __init__(self, baseattr, mapcls, backattr): + self.baseattr = baseattr + self.mapcls = mapcls + self.backattr = backattr + + def getitems(self, ldapobj): + return {getattr(mapobj, self.backattr) for mapobj in self.mapcls.query.filter_by(dn=ldapobj.dn)} + + def additem(self, ldapobj, dbobj): + if dbobj not in self.getitems(ldapobj): + getattr(dbobj, self.baseattr).append(self.mapcls(dn=ldapobj.dn)) + + def delitem(self, ldapobj, dbobj): + for mapobj in list(getattr(dbobj, self.baseattr)): + if mapobj.dn == ldapobj.dn: + getattr(dbobj, self.baseattr).remove(mapobj) + + def __get__(self, ldapobj, objtype=None): + if ldapobj is None: + return self + return LDAPSet(getitems=lambda: self.getitems(ldapobj), + additem=lambda dbobj: self.additem(ldapobj, dbobj), + delitem=lambda dbobj: self.delitem(ldapobj, dbobj)) + + def __set__(self, ldapobj, dbobjs): + rel = self.__get__(ldapobj) + rel.clear() + for dbobj in dbobjs: + rel.add(dbobj) + +class DB2LDAPRelation: + def __init__(self, baseattr, mapcls, ldapcls, backattr=None, backref=None): + self.baseattr = baseattr + self.mapcls = mapcls + self.ldapcls = ldapcls + if backref is not None: + setattr(ldapcls, backref, DB2LDAPBackref(baseattr, mapcls, backattr)) + + def getitems(self, dbobj): + return {mapobj.dn for mapobj in getattr(dbobj, self.baseattr)} + + def additem(self, dbobj, dn): + if dn not in self.getitems(dbobj): + getattr(dbobj, self.baseattr).append(self.mapcls(dn=dn)) + + def delitem(self, dbobj, dn): + for mapobj in list(getattr(dbobj, self.baseattr)): + if mapobj.dn == dn: + getattr(dbobj, self.baseattr).remove(mapobj) + + def __get__(self, dbobj, objtype=None): + if dbobj is None: + return self + return LDAPSet(getitems=lambda: self.getitems(dbobj), + additem=lambda dn: self.additem(dbobj, dn), + delitem=lambda dn: self.delitem(dbobj, dn), + encode=lambda ldapobj: ldapobj.dn, + decode=self.ldapcls.ldap_get) + + def __set__(self, dbobj, ldapobjs): + getattr(dbobj, self.baseattr).clear() + for ldapobj in ldapobjs: + getattr(dbobj, self.baseattr).append(self.mapcls(dn=ldapobj.dn)) diff --git a/ldap3_mapper/types.py b/ldap3_mapper/types.py new file mode 100644 index 0000000000000000000000000000000000000000..59fb50fc1358b86635eb2bc2714b97edb29666a7 --- /dev/null +++ b/ldap3_mapper/types.py @@ -0,0 +1,35 @@ +from collections.abc import MutableSet + +class LDAPCommitError(Exception): + pass + +class LDAPSet(MutableSet): + def __init__(self, getitems, additem, delitem, encode=None, decode=None): + self.__getitems = getitems + self.__additem = additem + self.__delitem = delitem + self.__encode = encode or (lambda x: x) + self.__decode = decode or (lambda x: x) + + def __repr__(self): + return repr(set(self)) + + def __contains__(self, value): + return value is not None and self.__encode(value) in self.__getitems() + + def __iter__(self): + return iter(filter(lambda obj: obj is not None, map(self.__decode, self.__getitems()))) + + def __len__(self): + return len(set(self)) + + def add(self, value): + if value not in self: + self.__additem(self.__encode(value)) + + def discard(self, value): + self.__delitem(self.__encode(value)) + + def update(self, values): + for value in values: + self.add(value) diff --git a/uffd/__init__.py b/uffd/__init__.py index be953535077d3ac966347348ef037dffc002789b..3d5ae6262855692f379638f88a0f89b4a6adf067 100644 --- a/uffd/__init__.py +++ b/uffd/__init__.py @@ -5,7 +5,6 @@ from flask import Flask, redirect, url_for from werkzeug.routing import IntegerConverter from uffd.database import db, SQLAlchemyJSON -from uffd.ldap import ldap from uffd.template_helper import register_template_helper from uffd.navbar import setup_navbar diff --git a/uffd/ldap.py b/uffd/ldap.py index dc2333f1e6369241af75a2e068d8e43b02ce845d..2365a81ee0385026ffc3f2152eec3a5ab1ae7bc6 100644 --- a/uffd/ldap.py +++ b/uffd/ldap.py @@ -1,471 +1,35 @@ -from copy import deepcopy -from collections.abc import MutableSet - from flask import current_app, request -from ldap3.utils.conv import escape_filter_chars -from ldap3.core.exceptions import LDAPBindError, LDAPPasswordIsMandatoryError -from ldap3 import MODIFY_REPLACE, MODIFY_DELETE, MODIFY_ADD - -from ldap3 import Server, Connection, ALL, ALL_ATTRIBUTES, ALL_OPERATIONAL_ATTRIBUTES, MOCK_SYNC - -def fix_connection(conn): - old_search = conn.search - def search(*args, **kwargs): - kwargs.update({'attributes': [ALL_ATTRIBUTES, ALL_OPERATIONAL_ATTRIBUTES]}) - return old_search(*args, **kwargs) - conn.search = search - return conn - -def get_mock_conn(): - if not current_app.debug: - raise Exception('LDAP_SERVICE_MOCK cannot be enabled on production instances') - # Entries are stored in-memory in the mocked `Connection` object. To make - # changes persistent across requests we reuse the same `Connection` object - # for all calls to `service_conn()` and `user_conn()`. - if not hasattr(current_app, 'ldap_mock'): - server = Server.from_definition('ldap_mock', 'ldap_server_info.json', 'ldap_server_schema.json') - current_app.ldap_mock = fix_connection(Connection(server, client_strategy=MOCK_SYNC)) - current_app.ldap_mock.strategy.entries_from_json('ldap_server_entries.json') - current_app.ldap_mock.bind() - return current_app.ldap_mock - -def get_conn(): - if current_app.config.get('LDAP_SERVICE_MOCK', False): - return get_mock_conn() - server = Server(current_app.config["LDAP_SERVICE_URL"], get_info=ALL) - return fix_connection(Connection(server, current_app.config["LDAP_SERVICE_BIND_DN"], current_app.config["LDAP_SERVICE_BIND_PASSWORD"], auto_bind=True)) +import ldap3 -def user_conn(dn, password): - if current_app.config.get('LDAP_SERVICE_MOCK', False): - conn = get_mock_conn() - # Since we reuse the same conn for all calls to `user_conn()` we - # simulate the password check by rebinding. Note that ldap3's mocking - # implementation just compares the string in the objects's userPassword - # field with the password, no support for hashing or OpenLDAP-style - # password-prefixes ("{PLAIN}..." or "{ssha512}..."). - try: - if not conn.rebind(dn, password): - return False - except (LDAPBindError, LDAPPasswordIsMandatoryError): - return False - return get_mock_conn() - server = Server(current_app.config["LDAP_SERVICE_URL"], get_info=ALL) - try: - return fix_connection(Connection(server, dn, password, auto_bind=True)) - except (LDAPBindError, LDAPPasswordIsMandatoryError): - return False +from ldap3_mapper import LDAP3Mapper, LDAPCommitError # pylint: disable=unused-import +from ldap3_mapper.base import Session -class LDAPCommitError(Exception): - pass - -class LDAPSession: +class FlaskLDAP3Mapper(LDAP3Mapper): def __init__(self): - self.__objects = {} # dn -> instance - self.__to_delete = [] - self.__relations = {} # (srccls, srcattr, dn) -> {srcobj, ...} - - def lookup(self, dn): - return self.__objects.get(dn) - - def register(self, obj): - if obj.dn in self.__objects and self.__objects[obj.dn] != obj: - raise Exception() - self.__objects[obj.dn] = obj - return obj - - def lookup_relations(self, srccls, srcattr, dn): - key = (srccls, srcattr, dn) - return self.__relations.get(key, set()) + super().__init__() - def update_relations(self, srcobj, srcattr, delete_dns=None, add_dns=None): - for dn in (delete_dns or []): - key = (type(srcobj), srcattr, dn) - self.__relations[key] = self.__relations.get(key, set()) - self.__relations[key].discard(srcobj) - for dn in (add_dns or []): - key = (type(srcobj), srcattr, dn) - self.__relations[key] = self.__relations.get(key, set()) - self.__relations[key].add(srcobj) - - def add(self, obj): - self.register(obj) - - def delete(self, obj): - if obj.dn in self.__objects: - del self.__objects[obj.dn] - self.__to_delete.append(obj) - - def commit(self): - while self.__to_delete: - self.__to_delete.pop(0).ldap_delete() - for obj in list(self.__objects.values()): - if not obj.ldap_created: - obj.ldap_create() - elif obj.ldap_dirty: - obj.ldap_modify() - - def rollback(self): - self.__to_delete.clear() - self.__objects = {dn: obj for dn, obj in self.__objects.items() if obj.ldap_created} - for obj in self.__objects.values(): - if obj.ldap_dirty: - obj.ldap_reset() - -class FlaskLDAPMapper: @property def session(self): if not hasattr(request, 'ldap_session'): - request.ldap_session = LDAPSession() + request.ldap_session = Session() return request.ldap_session -ldap = FlaskLDAPMapper() - -class LDAPSet(MutableSet): - def __init__(self, getitems, additem, delitem, encode=None, decode=None): - self.__getitems = getitems - self.__additem = additem - self.__delitem = delitem - self.__encode = encode or (lambda x: x) - self.__decode = decode or (lambda x: x) - - def __repr__(self): - return repr(set(self)) - - def __contains__(self, value): - return value is not None and self.__encode(value) in self.__getitems() - - def __iter__(self): - return iter(filter(lambda obj: obj is not None, map(self.__decode, self.__getitems()))) - - def __len__(self): - return len(set(self)) - - def add(self, value): - if value not in self: - self.__additem(self.__encode(value)) - - def discard(self, value): - self.__delitem(self.__encode(value)) - - def update(self, values): - for value in values: - self.add(value) - -class LDAPAttribute: - def __init__(self, name, multi=False, default=None, encode=None, decode=None, aliases=None): - self.name = name - self.multi = multi - self.encode = encode or (lambda x: x) - self.decode = decode or (lambda x: x) - self.default_values = default - self.aliases = aliases or [] - - def default(self, obj): - if obj.ldap_getattr(self.name) == []: - values = self.default_values - if callable(values): - values = values() - self.__set__(obj, values) - - def additem(self, obj, value): - obj.ldap_attradd(self.name, value) - for name in self.aliases: - obj.ldap_attradd(name, value) - - def delitem(self, obj, value): - obj.ldap_attradd(self.name, value) - for name in self.aliases: - obj.ldap_attradd(name, value) - - def __set_name__(self, cls, name): - if self.default_values is not None: - cls.ldap_pre_create_hooks = cls.ldap_pre_create_hooks + [self.default] - - def __get__(self, obj, objtype=None): - if obj is None: - return self - if self.multi: - return LDAPSet(getitems=lambda: obj.ldap_getattr(self.name), - additem=lambda value: self.additem(obj, value), - delitem=lambda value: self.delitem(obj, value), - encode=self.encode, decode=self.decode) - return self.decode((obj.ldap_getattr(self.name) or [None])[0]) - - def __set__(self, obj, values): - if not self.multi: - values = [values] - obj.ldap_setattr(self.name, [self.encode(value) for value in values]) - for name in self.aliases: - obj.ldap_setattr(name, [self.encode(value) for value in values]) - -class LDAPBackref: - def __init__(self, srccls, srcattr): - self.srccls = srccls - self.srcattr = srcattr - srccls.ldap_relations = srccls.ldap_relations + [srcattr] - - def init(self, obj): - if self.srcattr not in obj.ldap_relation_data and obj.ldap_created: - # The query instanciates all related objects that in turn add their relations to session - self.srccls.ldap_filter_by_raw(**{self.srcattr: obj.dn}) - obj.ldap_relation_data.add(self.srcattr) - - def __get__(self, obj, objtype=None): - if obj is None: - return self - self.init(obj) - return LDAPSet(getitems=lambda: ldap.session.lookup_relations(self.srccls, self.srcattr, obj.dn), - additem=lambda value: value.ldap_attradd(self.srcattr, obj.dn), - delitem=lambda value: value.ldap_attrdel(self.srcattr, obj.dn)) - - def __set__(self, obj, values): - current = self.__get__(obj) - current.clear() - for value in values: - current.add(value) - -class LDAPRelation(LDAPAttribute): - def __init__(self, name, dest, backref=None): - super().__init__(name, multi=True, encode=lambda value: value.dn, decode=dest.ldap_get) - self.name = name - self.dest = dest - self.backref = backref - - def __set_name__(self, cls, name): - if self.backref is not None: - setattr(self.dest, self.backref, LDAPBackref(cls, self.name)) - -class LDAPModel: - ldap_dn_attribute = None - ldap_dn_base = None - ldap_base = None - ldap_object_classes = None - ldap_filter = None - # Caution: Never mutate ldap_pre_create_hooks and ldap_relations, always reassign! - ldap_pre_create_hooks = [] - ldap_relations = [] - - def __init__(self, _ldap_response=None, **kwargs): - self.ldap_relation_data = set() - self.__ldap_dn = None if _ldap_response is None else _ldap_response['dn'] - self.__ldap_attributes = {} - for key, values in (_ldap_response or {}).get('attributes', {}).items(): - if isinstance(values, list): - self.__ldap_attributes[key] = values - else: - self.__ldap_attributes[key] = [values] - self.__attributes = deepcopy(self.__ldap_attributes) - self.__changes = {} - for key, value, in kwargs.items(): - if not hasattr(self, key): - raise Exception() - setattr(self, key, value) - for name in self.ldap_relations: - self.__update_relations(name, add_dns=self.__attributes.get(name, [])) - - def __update_relations(self, name, delete_dns=None, add_dns=None): - if name in self.ldap_relations: - ldap.session.update_relations(self, name, delete_dns, add_dns) - - def ldap_getattr(self, name): - return self.__attributes.get(name, []) - - def ldap_setattr(self, name, values): - self.__update_relations(name, delete_dns=self.__attributes.get(name, [])) - self.__changes[name] = [(MODIFY_REPLACE, values)] - self.__attributes[name] = values - self.__update_relations(name, add_dns=values) - - def ldap_attradd(self, name, value): - self.__changes[name] = self.__changes.get(name, []) + [(MODIFY_ADD, [value])] - self.__attributes[name].append(value) - self.__update_relations(name, add_dns=[value]) - - def ldap_attrdel(self, name, value): - self.__changes[name] = self.__changes.get(name, []) + [(MODIFY_DELETE, [value])] - if value in self.__attributes.get(name, []): - self.__attributes[name].remove(value) - self.__update_relations(name, delete_dns=[value]) - - def __repr__(self): - name = '%s.%s'%(type(self).__module__, type(self).__name__) - if self.__ldap_dn is None: - return '<%s>'%name - return '<%s %s>'%(name, self.__ldap_dn) - - def build_dn(self): - if self.ldap_dn_attribute is None: - return None - if self.ldap_dn_base is None: - return None - if self.__attributes.get(self.ldap_dn_attribute) is None: - return None - return '%s=%s,%s'%(self.ldap_dn_attribute, escape_filter_chars(self.__attributes[self.ldap_dn_attribute][0]), self.ldap_dn_base) - - @property - def dn(self): - if self.__ldap_dn is not None: - return self.__ldap_dn - return self.build_dn() - - @classmethod - def ldap_get(cls, dn): - obj = ldap.session.lookup(dn) - if obj is None: - conn = get_conn() - conn.search(dn, cls.ldap_filter) - if not conn.response: - return None - if len(conn.response) != 1: - raise Exception() - obj = ldap.session.register(cls(_ldap_response=conn.response[0])) - return obj - - @classmethod - def ldap_all(cls): - conn = get_conn() - conn.search(cls.ldap_base, cls.ldap_filter) - res = [] - for entry in conn.response: - obj = ldap.session.lookup(entry['dn']) - if obj is None: - obj = ldap.session.register(cls(_ldap_response=entry)) - res.append(obj) - return res - - @classmethod - def ldap_filter_by_raw(cls, **kwargs): - filters = [cls.ldap_filter] - for key, value in kwargs.items(): - filters.append('(%s=%s)'%(key, escape_filter_chars(value))) - conn = get_conn() - conn.search(cls.ldap_base, '(&%s)'%(''.join(filters))) - res = [] - for entry in conn.response: - obj = ldap.session.lookup(entry['dn']) - if obj is None: - obj = ldap.session.register(cls(_ldap_response=entry)) - res.append(obj) - return res - - @classmethod - def ldap_filter_by(cls, **kwargs): - _kwargs = {} - for key, value in kwargs.items(): - attr = getattr(cls, key) - _kwargs[attr.name] = attr.encode(value) - return cls.ldap_filter_by_raw(**_kwargs) - - def ldap_reset(self): - for name in self.ldap_relations: - self.__update_relations(name, delete_dns=self.__attributes.get(name, [])) - self.__changes = {} - self.__attributes = deepcopy(self.__ldap_attributes) - for name in self.ldap_relations: - self.__update_relations(name, add_dns=self.__attributes.get(name, [])) - - @property - def ldap_dirty(self): - return bool(self.__changes) - - @property - def ldap_created(self): - return bool(self.__ldap_attributes) - - def ldap_modify(self): - if not self.ldap_created: - raise Exception() - if not self.ldap_dirty: - return - conn = get_conn() - success = conn.modify(self.dn, self.__changes) - if not success: - raise Exception() - self.__changes = {} - self.__ldap_attributes = deepcopy(self.__attributes) - - def ldap_create(self): - if self.ldap_created: - raise Exception() - conn = get_conn() - for func in self.ldap_pre_create_hooks: - func(self) - success = conn.add(self.dn, self.ldap_object_classes, self.__attributes) - if not success: - print('commit error', success, conn.result) - raise LDAPCommitError() - self.__changes = {} - self.__ldap_attributes = deepcopy(self.__attributes) - - def ldap_delete(self): - conn = get_conn() - success = conn.delete(self.dn) - if not success: - raise Exception() - self.__ldap_attributes = {} - -class DB2LDAPBackref: - def __init__(self, baseattr, mapcls, backattr): - self.baseattr = baseattr - self.mapcls = mapcls - self.backattr = backattr - - def getitems(self, ldapobj): - return {getattr(mapobj, self.backattr) for mapobj in self.mapcls.query.filter_by(dn=ldapobj.dn)} - - def additem(self, ldapobj, dbobj): - if dbobj not in self.getitems(ldapobj): - getattr(dbobj, self.baseattr).append(self.mapcls(dn=ldapobj.dn)) - - def delitem(self, ldapobj, dbobj): - for mapobj in list(getattr(dbobj, self.baseattr)): - if mapobj.dn == ldapobj.dn: - getattr(dbobj, self.baseattr).remove(mapobj) - - def __get__(self, ldapobj, objtype=None): - if ldapobj is None: - return self - return LDAPSet(getitems=lambda: self.getitems(ldapobj), - additem=lambda dbobj: self.additem(ldapobj, dbobj), - delitem=lambda dbobj: self.delitem(ldapobj, dbobj)) - - def __set__(self, ldapobj, dbobjs): - rel = self.__get__(ldapobj) - rel.clear() - for dbobj in dbobjs: - rel.add(dbobj) - -class DB2LDAPRelation: - def __init__(self, baseattr, mapcls, ldapcls, backattr=None, backref=None): - self.baseattr = baseattr - self.mapcls = mapcls - self.ldapcls = ldapcls - if backref is not None: - setattr(ldapcls, backref, DB2LDAPBackref(baseattr, mapcls, backattr)) - - def getitems(self, dbobj): - return {mapobj.dn for mapobj in getattr(dbobj, self.baseattr)} - - def additem(self, dbobj, dn): - if dn not in self.getitems(dbobj): - getattr(dbobj, self.baseattr).append(self.mapcls(dn=dn)) - - def delitem(self, dbobj, dn): - for mapobj in list(getattr(dbobj, self.baseattr)): - if mapobj.dn == dn: - getattr(dbobj, self.baseattr).remove(mapobj) - - def __get__(self, dbobj, objtype=None): - if dbobj is None: - return self - return LDAPSet(getitems=lambda: self.getitems(dbobj), - additem=lambda dn: self.additem(dbobj, dn), - delitem=lambda dn: self.delitem(dbobj, dn), - encode=lambda ldapobj: ldapobj.dn, - decode=self.ldapcls.ldap_get) - - def __set__(self, dbobj, ldapobjs): - getattr(dbobj, self.baseattr).clear() - for ldapobj in ldapobjs: - getattr(dbobj, self.baseattr).append(self.mapcls(dn=ldapobj.dn)) + def connect(self): + if current_app.config.get('LDAP_SERVICE_MOCK', False): + if not current_app.debug: + raise Exception('LDAP_SERVICE_MOCK cannot be enabled on production instances') + # Entries are stored in-memory in the mocked `Connection` object. To make + # changes persistent across requests we reuse the same `Connection` object + # for all calls to `service_conn()` and `user_conn()`. + if not hasattr(current_app, 'ldap_mock'): + server = ldap3.Server.from_definition('ldap_mock', 'ldap_server_info.json', 'ldap_server_schema.json') + current_app.ldap_mock = ldap3.Connection(server, client_strategy=ldap3.MOCK_SYNC) + current_app.ldap_mock.strategy.entries_from_json('ldap_server_entries.json') + current_app.ldap_mock.bind() + return current_app.ldap_mock + server = ldap3.Server(current_app.config["LDAP_SERVICE_URL"], get_info=ldap3.ALL) + return ldap3.Connection(server, current_app.config["LDAP_SERVICE_BIND_DN"], + current_app.config["LDAP_SERVICE_BIND_PASSWORD"], auto_bind=True) + +ldap = FlaskLDAP3Mapper() diff --git a/uffd/mail/models.py b/uffd/mail/models.py index 26b7c82afd4418ffff20203ad69a82d7006a37dd..22a7e0b2aefb7e0183ad508a3456bc3e3e9dbe66 100644 --- a/uffd/mail/models.py +++ b/uffd/mail/models.py @@ -1,13 +1,13 @@ -from uffd.ldap import LDAPModel, LDAPAttribute +from uffd.ldap import ldap from uffd.lazyconfig import lazyconfig_str, lazyconfig_list -class Mail(LDAPModel): +class Mail(ldap.Model): ldap_base = lazyconfig_str('LDAP_BASE_MAIL') ldap_dn_attribute = 'uid' ldap_dn_base = lazyconfig_str('LDAP_BASE_MAIL') ldap_filter = '(objectClass=postfixVirtual)' ldap_object_classes = lazyconfig_list('MAIL_LDAP_OBJECTCLASSES') - uid = LDAPAttribute('uid') - receivers = LDAPAttribute('mailacceptinggeneralid', multi=True) - destinations = LDAPAttribute('maildrop', multi=True) + uid = ldap.Attribute('uid') + receivers = ldap.Attribute('mailacceptinggeneralid', multi=True) + destinations = ldap.Attribute('maildrop', multi=True) diff --git a/uffd/role/models.py b/uffd/role/models.py index 79445573e77c7448f236cfc6dcb2539cb7ec6c92..44512b55239c89dca057241e67d8a310848f134a 100644 --- a/uffd/role/models.py +++ b/uffd/role/models.py @@ -2,8 +2,9 @@ from sqlalchemy import Column, String, Integer, Text, ForeignKey from sqlalchemy.orm import relationship from sqlalchemy.ext.declarative import declared_attr +from ldap3_mapper.db_relation import DB2LDAPRelation + from uffd.database import db -from uffd.ldap import DB2LDAPRelation from uffd.user.models import User, Group class LdapMapping: diff --git a/uffd/session/views.py b/uffd/session/views.py index e9b28992a85dea8a85fa8c5e0f5154e212981d8e..cb28ee5780386213135b02afb76896d59e9a1da9 100644 --- a/uffd/session/views.py +++ b/uffd/session/views.py @@ -4,8 +4,11 @@ import functools from flask import Blueprint, render_template, request, url_for, redirect, flash, current_app, session, abort +import ldap3 +from ldap3.core.exceptions import LDAPBindError, LDAPPasswordIsMandatoryError + from uffd.user.models import User -from uffd.ldap import user_conn +from uffd.ldap import ldap from uffd.ratelimit import Ratelimit, host_ratelimit, format_delay bp = Blueprint("session", __name__, template_folder='templates', url_prefix='/') @@ -13,15 +16,27 @@ bp = Blueprint("session", __name__, template_folder='templates', url_prefix='/') login_ratelimit = Ratelimit('login', 1*60, 3) def login_get_user(loginname, password): - print('login with', loginname, password) dn = User(loginname=loginname).dn - conn = user_conn(dn, password) - if not conn: - print('conn is None') - return None + if current_app.config.get('LDAP_SERVICE_MOCK', False): + conn = ldap.connect() + # Since we reuse the same conn for all calls to `user_conn()` we + # simulate the password check by rebinding. Note that ldap3's mocking + # implementation just compares the string in the objects's userPassword + # field with the password, no support for hashing or OpenLDAP-style + # password-prefixes ("{PLAIN}..." or "{ssha512}..."). + try: + if not conn.rebind(dn, password): + return None + except (LDAPBindError, LDAPPasswordIsMandatoryError): + return None + else: + server = ldap3.Server(current_app.config["LDAP_SERVICE_URL"], get_info=ldap3.ALL) + try: + conn = ldap3.Connection(server, dn, password, auto_bind=True) + except (LDAPBindError, LDAPPasswordIsMandatoryError): + return None conn.search(conn.user, '(objectClass=person)') if len(conn.entries) != 1: - print('wrong number of entries', conn.entries) return None return User.ldap_get(dn) diff --git a/uffd/user/models.py b/uffd/user/models.py index 88308dc14860670dc5ef4123f0ee5cad76f59e75..e4507027e7f2a1706be5a6ee3e5f1c41636d55db 100644 --- a/uffd/user/models.py +++ b/uffd/user/models.py @@ -4,7 +4,7 @@ import string from flask import current_app from ldap3.utils.hashed import hashed, HASHED_SALTED_SHA512 -from uffd.ldap import LDAPModel, LDAPAttribute, LDAPRelation +from uffd.ldap import ldap from uffd.lazyconfig import lazyconfig_str, lazyconfig_list def get_next_uid(): @@ -17,18 +17,18 @@ def get_next_uid(): raise Exception('No free uid found') return next_uid -class User(LDAPModel): +class User(ldap.Model): ldap_base = lazyconfig_str('LDAP_BASE_USER') ldap_dn_attribute = 'uid' ldap_dn_base = lazyconfig_str('LDAP_BASE_USER') ldap_filter = '(objectClass=person)' ldap_object_classes = lazyconfig_list('LDAP_USER_OBJECTCLASSES') - uid = LDAPAttribute('uidNumber', default=get_next_uid) - loginname = LDAPAttribute('uid') - displayname = LDAPAttribute('cn', aliases=['givenName', 'displayName']) - mail = LDAPAttribute('mail') - pwhash = LDAPAttribute('userPassword', default=lambda: hashed(HASHED_SALTED_SHA512, secrets.token_hex(128))) + uid = ldap.Attribute('uidNumber', default=get_next_uid) + loginname = ldap.Attribute('uid') + displayname = ldap.Attribute('cn', aliases=['givenName', 'displayName']) + mail = ldap.Attribute('mail') + pwhash = ldap.Attribute('userPassword', default=lambda: hashed(HASHED_SALTED_SHA512, secrets.token_hex(128))) groups = [] # Shuts up pylint, overwritten by back-reference roles = [] # Shuts up pylint, overwritten by back-reference @@ -41,7 +41,7 @@ class User(LDAPModel): if self.ldap_getattr('gidNumber') == []: self.ldap_setattr('gidNumber', [current_app.config['LDAP_USER_GID']]) - ldap_pre_create_hooks = LDAPModel.ldap_pre_create_hooks + [dummy_attribute_defaults] + ldap_pre_create_hooks = ldap.Model.ldap_pre_create_hooks + [dummy_attribute_defaults] # Write-only property def password(self, value): @@ -101,13 +101,13 @@ class User(LDAPModel): self.mail = value return True -class Group(LDAPModel): +class Group(ldap.Model): ldap_base = lazyconfig_str('LDAP_BASE_GROUPS') ldap_filter = '(objectClass=groupOfUniqueNames)' - gid = LDAPAttribute('gidNumber') - name = LDAPAttribute('cn') - description = LDAPAttribute('description', default='') - members = LDAPRelation('uniqueMember', User, backref='groups') + gid = ldap.Attribute('gidNumber') + name = ldap.Attribute('cn') + description = ldap.Attribute('description', default='') + members = ldap.Relation('uniqueMember', User, backref='groups') roles = [] # Shuts up pylint, overwritten by back-reference