From aa64b5c1577ed9da2638b295500e8759be9fef29 Mon Sep 17 00:00:00 2001 From: Julian Rother <julian@jrother.eu> Date: Wed, 10 Mar 2021 22:26:51 +0100 Subject: [PATCH] Implemented user-group relationships with LDAP mapping --- db.py | 260 ++++++++++++++++++++-------------------------------------- 1 file changed, 90 insertions(+), 170 deletions(-) diff --git a/db.py b/db.py index fb84192..ec49d48 100644 --- a/db.py +++ b/db.py @@ -3,12 +3,12 @@ import struct from crypt import crypt from ssl import SSLContext, SSLSocket -from sqlalchemy import create_engine, or_, and_, Column, Integer, String -from sqlalchemy.orm import sessionmaker +from sqlalchemy import create_engine, or_, and_, Column, Integer, String, Table, ForeignKey +from sqlalchemy.orm import sessionmaker, relationship, RelationshipProperty, aliased from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.ext.hybrid import hybrid_property from ldap import SearchScope, FilterAnd, FilterOr, FilterNot, FilterEqual, FilterPresent -from server import LDAPRequestHandler, LDAPInvalidCredentials, LDAPInsufficientAccessRights, LDAPConfidentialityRequired, LDAPNoSuchObject +from server import LDAPRequestHandler, LDAPInvalidCredentials, LDAPInsufficientAccessRights, LDAPConfidentialityRequired, LDAPNoSuchObject, encode_attribute import socketserver from dn import parse_dn, build_dn @@ -86,95 +86,54 @@ class BaseSearchEvaluator: def query(self, filter_obj): return [] -class StaticLDAPObject: - def __init__(self, dn, attributes=None): - self.dn = dn - self.attributes = attributes or {} - -class StaticSearchEvaluator(BaseSearchEvaluator): - def __init__(self): - self.present_map = {} # name -> set of objs - self.value_map = {} # (name, value) -> set of objs - self.dn_map = {} # parsed dn -> obj - self.singlelevel_map = {} # parsed dn part -> set objs - self.subtree_map = {} # parsed dn part -> set objs - self.all_objects = set() - - def add(self, dn, attributes): - dn = parse_dn(dn) - obj = StaticLDAPObject(dn) - assert dn not in self.dn_map - self.dn_map[dn] = {obj} - self.all_objects.add(obj) - if dn: - key = tuple(dn[1:]) - self.singlelevel_map[key] = self.singlelevel_map.get(key, set()) - self.singlelevel_map[key].add(obj) - path = list(dn) - for _ in range(len(path) + 1): - key = tuple(path) - self.subtree_map[key] = self.subtree_map.get(key, set()) - self.subtree_map[key].add(obj) - if path: - path.pop(0) - for name, values in attributes.items(): - if not isinstance(values, list): - values = [values] - obj.attributes[name] = [] - for value in values: - if isinstance(value, int): - value = str(value) - if isinstance(value, str): - value = value.encode() - obj.attributes[name].append(value) - key = name.lower() - self.present_map[key] = self.present_map.get(key, set()) - self.present_map[key].add(obj) - key = (name.lower(), value) - self.value_map[key] = self.value_map.get(key, set()) - self.value_map[key].add(obj) +engine = create_engine('sqlite:///db.sqlite', echo=True) +Session = sessionmaker(bind=engine) +session = Session() - def filter_present(self, name): - key = name.lower() - return self.present_map.get(key, set()) +user_groups = Table('user_groups', Base.metadata, + Column('user_id', ForeignKey('users.id'), primary_key=True), + Column('group_id', ForeignKey('groups.id'), primary_key=True) +) - def filter_equal(self, name, value): - key = (name.lower(), value) - return self.value_map.get(key, set()) +class User(Base): + __tablename__ = 'users' - def _filter_and(self, *subresults): - objs = subresults[0] - for subres in subresults[1:]: - objs = objs.intersection(subres) - return objs + id = Column(Integer, primary_key=True) + loginname = Column(String, unique=True, nullable=False) + displayname = Column(String, nullable=False, default='') + email = Column(String) + pwhash = Column(String) + groups = relationship('Group', secondary=user_groups, back_populates='users') - def _filter_or(self, *subresults): - objs = subresults[0] - for subres in subresults[1:]: - objs = objs.union(subres) - return objs + ldap_gid = 1 + ldap_sn = ' ' - def _filter_not(self, subresult): - return self.all_objects.difference(subresults) + @hybrid_property + def homedirectory(self): + return '/home/' + self.loginname - def filter_dn(self, base, scope): - dn = parse_dn(base) - if scope == SearchScope.baseObject: - return self.dn_map.get(dn, set()) - elif scope == SearchScope.singleLevel: - return self.singlelevel_map.get(dn, set()) - elif scope == SearchScope.wholeSubtree: - return self.subtree_map.get(dn, set()) - else: - return set() + # Write-only property + def password(self, value): + self.pwhash = crypt(value) + password = property(fset=password) - def query(self, filter_obj): - return [(build_dn(obj.dn), obj.attributes) for obj in filter_obj] + def check_password(self, password): + return self.pwhash is not None and crypt(password, self.pwhash) == self.pwhash + +class Group(Base): + __tablename__ = 'groups' + + id = Column(Integer, primary_key=True) + name = Column(String, unique=True, nullable=False) + description = Column(String, nullable=False, default='') + users = relationship('User', secondary=user_groups, back_populates='groups') + +Base.metadata.create_all(engine) -class SQLSearchEvaluator(BaseSearchEvaluator): - def __init__(self, model, session, attributes=None, objectclasses=None, rdn_attr='uid', dn_base=''): +class SQLModelWrapper(BaseSearchEvaluator): + def __init__(self, store, model, attributes=None, objectclasses=None, rdn_attr='uid', dn_base=''): + self.store = store self.model = model - self.session = session self.attributes = {} for ldap_name, attr_name in (attributes or {}).items(): self.attributes[ldap_name.lower()] = attr_name @@ -185,6 +144,7 @@ class SQLSearchEvaluator(BaseSearchEvaluator): value = value.encode() self.objectclasses.append(value) self.rdn_attr = rdn_attr + self.dn_base = dn_base self.dn_base_path = parse_dn(dn_base) def filter_present(self, name): @@ -192,7 +152,10 @@ class SQLSearchEvaluator(BaseSearchEvaluator): return True if name not in self.attributes: return False - return getattr(self.model, self.attributes[name]).isnot(None) + attr = getattr(self.model, self.attributes[name]) + if hasattr(attr, 'prop') and isinstance(attr.prop, RelationshipProperty): + return attr.any() + return attr.isnot(None) def filter_equal(self, name, value): if name == 'objectclass': @@ -200,13 +163,17 @@ class SQLSearchEvaluator(BaseSearchEvaluator): if name not in self.attributes: return False attr = getattr(self.model, self.attributes[name]) - if hasattr(attr, 'type') and isinstance(attr.type, String): + if isinstance(attr, str): value = value.decode() - elif hasattr(attr, 'type') and isinstance(attr.type, Integer): + elif isinstance(attr, int): value = int(value) - elif isinstance(attr, str): + elif isinstance(attr, bytes): + pass + elif hasattr(attr, 'prop') and isinstance(attr.prop, RelationshipProperty): + return attr.any(self.store.models[attr.prop.argument()].filter_dn(value.decode(), SearchScope.baseObject)) + elif hasattr(attr, 'type') and isinstance(attr.type, String): value = value.decode() - elif isinstance(attr, int): + elif hasattr(attr, 'type') and isinstance(attr.type, Integer): value = int(value) return attr == value @@ -240,71 +207,51 @@ class SQLSearchEvaluator(BaseSearchEvaluator): else: return False + def get_dn(self, obj): + attr_name = self.attributes[self.rdn_attr] + rdn_value = encode_attribute(getattr(obj, attr_name)) + dn_parts = (((self.rdn_attr, rdn_value),),) + self.dn_base_path + return build_dn(dn_parts) + def query(self, filter_obj): if filter_obj is False: return [] elif filter_obj is True: - objs = self.session.query(self.model) + objs = self.store.session.query(self.model) else: - objs = self.session.query(self.model).filter(filter_obj) + objs = self.store.session.query(self.model).filter(filter_obj) for obj in objs: attrs = {} for ldap_name, attr_name in self.attributes.items(): - value = getattr(obj, attr_name) - if value is None: + values = getattr(obj, attr_name) + if values is None: continue - if isinstance(value, int): - value = str(value) - if isinstance(value, str): - value = value.encode() - attrs[ldap_name] = [value] + if not isinstance(values, list): + values = [values] + attrs[ldap_name] = [] + for value in values: + if isinstance(value, Base): + value = self.store.models[type(value)].get_dn(value) + value = encode_attribute(value) + attrs[ldap_name].append(value) attrs['objectClass'] = self.objectclasses - dn_parts = (((self.rdn_attr, attrs[self.rdn_attr][0]),),) + self.dn_base_path - yield (build_dn(dn_parts), attrs) + yield (self.get_dn(obj), attrs) -engine = create_engine('sqlite:///db.sqlite', echo=True) -Session = sessionmaker(bind=engine) -session = Session() - -class User(Base): - __tablename__ = 'users' - - id = Column(Integer, primary_key=True) - loginname = Column(String, unique=True, nullable=False) - displayname = Column(String, nullable=False, default='') - email = Column(String) - pwhash = Column(String) - - ldap_gid = 1 - ldap_sn = ' ' - - @hybrid_property - def homedirectory(self): - return '/home/' + self.loginname - - # Write-only property - def password(self, value): - self.pwhash = crypt(value) - password = property(fset=password) - - def check_password(self, password): - return self.pwhash is not None and crypt(password, self.pwhash) == self.pwhash - -class Group(Base): - __tablename__ = 'groups' - - id = Column(Integer, primary_key=True) - name = Column(String, unique=True, nullable=False) - description = Column(String, nullable=False, default='') +class SQLObjectStore: + def __init__(self, session): + self.session = session + self.models = {} -Base.metadata.create_all(engine) + def register_model(self, model, attributes=None, objectclasses=None, rdn_attr='uid', dn_base=''): + self.models[model] = SQLModelWrapper(self, model, attributes, objectclasses, rdn_attr, dn_base) -staticobjs = StaticSearchEvaluator() -staticobjs.add(dn='', attributes={'objectClass': 'top', 'supportedSASLMechanisms': ['EXTERNAL']}) #'PLAIN', 'ANONYMOUS', 'EXTERNAL', 'SCRAM', 'DIGEST-MD5', 'CRAM-MD5', 'NTLM']}) + def search(self, baseobj, scope, filter): + for model, wrapper in self.models.items(): + yield from wrapper(baseobj, scope, filter) -usereval = SQLSearchEvaluator( +sqlstore = SQLObjectStore(session) +sqlstore.register_model( model=User, - session=session, attributes={ 'cn': 'displayname', 'displayname': 'displayname', @@ -315,19 +262,20 @@ usereval = SQLSearchEvaluator( 'sn': 'ldap_sn', 'uid': 'loginname', 'uidnumber': 'id', + 'memberof': 'groups', }, objectclasses=[b'top', b'inetOrgPerson', b'organizationalPerson', b'person', b'posixAccount'], rdn_attr='uid', dn_base='ou=users,dc=example,dc=com' ) -groupeval = SQLSearchEvaluator( +sqlstore.register_model( model=Group, - session=session, attributes={ 'cn': 'name', 'description': 'description', 'gidnumber': 'id', + 'uniqueMember': 'users', }, objectclasses=[b'top', b'posixGroup', b'groupOfUniqueNames'], rdn_attr='cn', @@ -337,33 +285,6 @@ groupeval = SQLSearchEvaluator( ssl_context = SSLContext() ssl_context.load_cert_chain('devcert.crt', 'devcert.key') -class OldRequestHandler(LDAPRequestHandler): - ssl_context = ssl_context - - def do_bind(self, name, password): - if not name and not password: - return None - if not isinstance(self.request, SSLSocket): - raise LDAPConfidentialityRequired() - try: - password = password.decode() - except UnicodeDecodeError: - raise LDAPInvalidCredentials() - try: - user = session.query(User).filter(usereval.filter_dn(name, SearchScope.baseObject)).one_or_none() - except ValueError: - raise LDAPInvalidCredentials() - if user is None or not user.check_password(password): - raise LDAPInvalidCredentials() - return user - - def do_search(self, baseobj, scope, filter): - #if self.bind_object is None: - # raise LDAPInsufficientAccessRights() - yield from staticobjs(baseobj, scope, filter) - yield from usereval(baseobj, scope, filter) - yield from groupeval(baseobj, scope, filter) - class RequestHandler(LDAPRequestHandler): ssl_context = ssl_context @@ -389,7 +310,7 @@ class RequestHandler(LDAPRequestHandler): except UnicodeDecodeError: raise LDAPInvalidCredentials() try: - user = session.query(User).filter(usereval.filter_dn(dn, SearchScope.baseObject)).one_or_none() + user = session.query(User).filter(sqlstore.models[User].filter_dn(dn, SearchScope.baseObject)).one_or_none() except ValueError: raise LDAPInvalidCredentials() if user is None or not user.check_password(password): @@ -421,8 +342,7 @@ class RequestHandler(LDAPRequestHandler): def do_search(self, baseobj, scope, filter): yield from super().do_search(baseobj, scope, filter) if self.bind_object is not None: - yield from usereval(baseobj, scope, filter) - yield from groupeval(baseobj, scope, filter) + yield from sqlstore.search(baseobj, scope, filter) socketserver.ForkingTCPServer(('127.0.0.1', 1337), RequestHandler).serve_forever() #socketserver.UnixStreamServer('/tmp/ldapd.sock', RequestHandler).serve_forever() -- GitLab