diff --git a/db.py b/db.py index 36d131f0699b6d0b110e44670e7d00f63cfc1c81..50e08a8a212bae6ffb9a67e70f4088ca5445a81d 100644 --- a/db.py +++ b/db.py @@ -1,11 +1,12 @@ from crypt import crypt +from ssl import SSLContext, SSLSocket from sqlalchemy import create_engine, or_, and_, Column, Integer, String from sqlalchemy.orm import sessionmaker from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.ext.hybrid import hybrid_property from ldap import SearchScope, FilterAnd, FilterOr, FilterNot, FilterEqual, FilterPresent -from server import LDAPRequestHandler +from server import LDAPRequestHandler, LDAPInvalidCredentials, LDAPInsufficientAccessRights, LDAPConfidentialityRequired from socketserver import ForkingTCPServer from dn import parse_dn, build_dn @@ -83,6 +84,91 @@ class BaseSearchEvaluator: def query(self, filter_obj): return [] +class StaticLDAPObject: + def __init__(self, dn, attributes=None): + self.dn = dn + self.attributes = attributes or {} + +class StaticSearchEvaluator(BaseSearchEvaluator): + def __init__(self): + self.present_map = {} # name -> set of objs + self.value_map = {} # (name, value) -> set of objs + self.dn_map = {} # parsed dn -> obj + self.singlelevel_map = {} # parsed dn part -> set objs + self.subtree_map = {} # parsed dn part -> set objs + self.all_objects = set() + + def add(self, dn, attributes): + dn = parse_dn(dn) + obj = StaticLDAPObject(dn) + assert dn not in self.dn_map + self.dn_map[dn] = {obj} + self.all_objects.add(obj) + if dn: + key = tuple(dn[1:]) + self.singlelevel_map[key] = self.singlelevel_map.get(key, set()) + self.singlelevel_map[key].add(obj) + path = list(dn) + for _ in range(len(path) + 1): + key = tuple(path) + self.subtree_map[key] = self.subtree_map.get(key, set()) + self.subtree_map[key].add(obj) + if path: + path.pop(0) + for name, values in attributes.items(): + if not isinstance(values, list): + values = [values] + obj.attributes[name] = [] + for value in values: + if isinstance(value, int): + value = str(value) + if isinstance(value, str): + value = value.encode() + obj.attributes[name].append(value) + key = name.lower() + self.present_map[key] = self.present_map.get(key, set()) + self.present_map[key].add(obj) + key = (name.lower(), value) + self.value_map[key] = self.value_map.get(key, set()) + self.value_map[key].add(obj) + + def filter_present(self, name): + key = name.lower() + return self.present_map.get(key, set()) + + def filter_equal(self, name, value): + key = (name.lower(), value) + return self.value_map.get(key, set()) + + def _filter_and(self, *subresults): + objs = subresults[0] + for subres in subresults[1:]: + objs = objs.intersection(subres) + return objs + + def _filter_or(self, *subresults): + objs = subresults[0] + for subres in subresults[1:]: + objs = objs.union(subres) + return objs + + def _filter_not(self, subresult): + return self.all_objects.difference(subresults) + + def filter_dn(self, base, scope): + dn = parse_dn(base) + if scope == SearchScope.baseObject: + return self.dn_map.get(dn, set()) + elif scope == SearchScope.singleLevel: + return self.singlelevel_map.get(dn, set()) + elif scope == SearchScope.wholeSubtree: + return self.subtree_map.get(dn, set()) + else: + return set() + + def query(self, filter_obj): + return [(build_dn(obj.dn), obj.attributes) for obj in filter_obj] + class SQLSearchEvaluator(BaseSearchEvaluator): def __init__(self, model, session, attributes=None, objectclasses=None, rdn_attr='uid', dn_base=''): self.model = model @@ -159,7 +245,6 @@ class SQLSearchEvaluator(BaseSearchEvaluator): objs = self.session.query(self.model) else: objs = self.session.query(self.model).filter(filter_obj) - results = [] for obj in objs: attrs = {} for ldap_name, attr_name in self.attributes.items(): @@ -173,8 +258,7 @@ class SQLSearchEvaluator(BaseSearchEvaluator): attrs[ldap_name] = [value] attrs['objectClass'] = self.objectclasses dn_parts = (((self.rdn_attr, attrs[self.rdn_attr][0]),),) + self.dn_base_path - results.append((build_dn(dn_parts), attrs)) - return results + yield (build_dn(dn_parts), attrs) engine = create_engine('sqlite:///db.sqlite', echo=True) Session = sessionmaker(bind=engine) @@ -213,57 +297,69 @@ class Group(Base): Base.metadata.create_all(engine) +staticobjs = StaticSearchEvaluator() +staticobjs.add(dn='', attributes={'objectClass': 'top', 'supportedSASLMechanisms': ['PLAIN', 'ANONYMOUS', 'EXTERNAL', 'SCRAM', 'DIGEST-MD5', 'CRAM-MD5', 'NTLM']}) + +usereval = SQLSearchEvaluator( + model=User, + session=session, + attributes={ + 'cn': 'displayname', + 'displayname': 'displayname', + 'gidnumber': 'ldap_gid', + 'givenname': 'displayname', + 'homedirectory': 'homedirectory', + 'mail': 'email', + 'sn': 'ldap_sn', + 'uid': 'loginname', + 'uidnumber': 'id', + }, + objectclasses=[b'top', b'inetOrgPerson', b'organizationalPerson', b'person', b'posixAccount'], + rdn_attr='uid', + dn_base='ou=users,dc=example,dc=com' +) + +groupeval = SQLSearchEvaluator( + model=Group, + session=session, + attributes={ + 'cn': 'name', + 'description': 'description', + 'gidnumber': 'id', + }, + objectclasses=[b'top', b'posixGroup', b'groupOfUniqueNames'], + rdn_attr='cn', + dn_base='ou=groups,dc=example,dc=com' +) + +ssl_context = SSLContext() +ssl_context.load_cert_chain('devcert.crt', 'devcert.key') + class RequestHandler(LDAPRequestHandler): + ssl_context = ssl_context + def do_bind(self, name, password): if not name and not password: return None + if not isinstance(self.request, SSLSocket): + raise LDAPConfidentialityRequired() try: password = password.decode() except UnicodeDecodeError: raise LDAPInvalidCredentials() try: - evaluator = SQLSearchEvaluator(User, session, attributes=User.ldap_attributes, - objectclasses=User.ldap_objectclasses, rdn_attr=User.ldap_rdn_attribute, - dn_base=User.ldap_dn_base) + user = session.query(User).filter(usereval.filter_dn(name, SearchScope.baseObject)).one_or_none() except ValueError: raise LDAPInvalidCredentials() - user = session.query(User).filter(evaluator.filter_dn(name, SearchScope.baseObject)).one_or_none() if user is None or not user.check_password(password): raise LDAPInvalidCredentials() return user def do_search(self, baseobj, scope, filter): - # User - ldap_attributes = { - 'cn': 'displayname', - 'displayname': 'displayname', - 'gidnumber': 'ldap_gid', - 'givenname': 'displayname', - 'homedirectory': 'homedirectory', - 'mail': 'email', - 'sn': 'ldap_sn', - 'uid': 'loginname', - 'uidnumber': 'id', - } - ldap_objectclasses = [b'top', b'inetOrgPerson', b'organizationalPerson', b'person', b'posixAccount'] - ldap_rdn_attribute = 'uid' - ldap_dn_base = 'ou=users,dc=example,dc=com' - evaluator = SQLSearchEvaluator(User, session, attributes=ldap_attributes, - objectclasses=ldap_objectclasses, rdn_attr=ldap_rdn_attribute, - dn_base=ldap_dn_base) - yield from evaluator(baseobj, scope, filter) - # Group - ldap_attributes = { - 'cn': 'name', - 'description': 'description', - 'gidnumber': 'id', - } - ldap_objectclasses = [b'top', b'posixGroup', b'groupOfUniqueNames'] - ldap_rdn_attribute = 'cn' - ldap_dn_base = 'ou=groups,dc=example,dc=com' - evaluator = SQLSearchEvaluator(Group, session, attributes=ldap_attributes, - objectclasses=ldap_objectclasses, rdn_attr=ldap_rdn_attribute, - dn_base=ldap_dn_base) - yield from evaluator(baseobj, scope, filter) + #if self.bind_object is None: + # raise LDAPInsufficientAccessRights() + yield from staticobjs(baseobj, scope, filter) + yield from usereval(baseobj, scope, filter) + yield from groupeval(baseobj, scope, filter) ForkingTCPServer(('127.0.0.1', 1337), RequestHandler).serve_forever()