diff --git a/db.py b/db.py new file mode 100644 index 0000000000000000000000000000000000000000..ae5bf0756a0bbae4bca03eeb7ecd7ea672d3704a --- /dev/null +++ b/db.py @@ -0,0 +1,201 @@ +from crypt import crypt + +from sqlalchemy import create_engine, or_, and_, Column, Integer, String +from sqlalchemy.orm import sessionmaker +from sqlalchemy.ext.declarative import declarative_base +from dn import parse_dn, build_dn, DNScope + +Base = declarative_base() + +class BaseSearchEvaluator: + def __call__(self, base, scope, filter_expr): + dn_res = self.filter_dn(base, scope) + filter_res = self.filter_expr(filter_expr) + return self.query(self.filter_and(dn_res, filter_res)) + + def filter_expr(self, expr): + operator, *args = expr + if operator == 'and': + return self.filter_and(*[self.filter_expr(subexpr) for subexpr in args]) + elif operator == 'or': + return self.filter_or(*[self.filter_expr(subexpr) for subexpr in args]) + elif operator == 'not': + return self.filter_not(self.filter_expr(args[0])) + elif operator == 'equal': + return self.filter_equal(args[0].lower(), args[1]) + elif operator == 'present': + return self.filter_present(args[0].lower()) + else: + return False + + def filter_present(self, name): + return False + + def filter_equal(self, name, value): + return False + + def filter_and(self, *subresults): + filtered = [] + for subres in subresults: + if subres is True: + continue + if subres is False: + return False + filtered.append(subres) + if not filtered: + return True + return self._filter_and(*filtered) + + def _filter_and(self, *subresults): + return False + + def filter_or(self, *subresults): + filtered = [] + for subres in subresults: + if subres is True: + return True + if subres is False: + continue + filtered.append(subres) + if not filtered: + return False + return self._filter_or(*filtered) + + def _filter_or(self, *subresults): + return False + + def filter_not(self, subresult): + if subresult is True: + return False + if subresult is False: + return True + return self._filter_not(subresult) + + def _filter_not(self, subresult): + return False + + def filter_dn(self, base, scope): + return False + + def query(self, filter_obj): + return [] + +class SQLSearchEvaluator(BaseSearchEvaluator): + def __init__(self, model, session, attributes=None, objectclasses=None, rdn_attr='uid', dn_base=''): + self.model = model + self.session = session + self.attributes = attributes or {} + self.objectclasses = objectclasses or [] + self.rdn_attr = rdn_attr + self.dn_base_path = parse_dn(dn_base) + + def filter_present(self, name): + if name == 'objectclass': + return True + if name not in self.attributes: + return False + return getattr(self.model, self.attributes[name]).is_not(None) + + def filter_equal(self, name, value): + if name == 'objectclass': + return value in self.objectclasses + if name not in self.attributes: + return False + attr = getattr(self.model, self.attributes[name]) + if isinstance(attr.type, String): + value = value.decode() + elif isinstance(attr.type, Integer): + value = int(value) + return attr == value + + def _filter_and(self, *subresults): + return and_(*subresults) + + def _filter_or(self, *subresults): + return or_(*subresults) + + def _filter_not(self, subresult): + return ~subresult + + def filter_dn(self, base, scope): + search_path = list(parse_dn(base)) + base_path = list(self.dn_base_path) + while search_path and base_path: + if search_path.pop() != base_path.pop(): + return False + if scope == DNScope.baseObject: + if base_path or len(search_path) != 1 or len(search_path[0]) != 1 or search_path[0][0][0] != self.rdn_attr: + return False + return self.filter_equal(self.rdn_attr, search_path[0][0][1]) + elif scope == DNScope.singleLevel: + return not search_path and not base_path + elif scope == DNScope.wholeSubtree: + if not search_path: + return True + if len(search_path) > 1 or len(search_path[0]) != 1 or search_path[0][0][0] != self.rdn_attr: + return False + return self.filter_equal(self.rdn_attr, search_path[0][0][1]) + else: + return False + + def query(self, filter_obj): + if filter_obj is False: + return [] + elif filter_obj is True: + objs = self.session.query(self.model) + else: + objs = self.session.query(self.model).filter(filter_obj) + results = [] + for obj in objs: + attrs = {} + for ldap_name, attr_name in self.attributes.items(): + attrs [ldap_name] = getattr(obj, attr_name) + attrs['objectClass'] = self.objectclasses + dn_parts = (((self.rdn_attr, attrs[self.rdn_attr]),),) + self.dn_base_path + results.append((build_dn(dn_parts), attrs)) + return results + +engine = create_engine('sqlite:///db.sqlite', echo=True) +Session = sessionmaker(bind=engine) +session = Session() + +class LDAPViewMixin: + ldap_attributes = {} + ldap_objectclasses = [b'top'] + ldap_rdn_attribute = 'uid' + ldap_dn_base = '' + + @classmethod + def ldap_search(cls, base, scope, filter_expr, conn): + evaluator = SQLSearchEvaluator(cls, session, attributes=cls.ldap_attributes, + objectclasses=cls.ldap_objectclasses, rdn_attr=cls.ldap_rdn_attribute, + dn_base=cls.ldap_dn_base) + return evaluator(base, scope, filter_expr) + +class User(Base, LDAPViewMixin): + __tablename__ = 'users' + ldap_attributes = { + 'givenname': 'displayname', + 'mail': 'email', + 'uid': 'loginname', + 'uidnumeric': 'id', + } + ldap_objectclasses = [b'top', b'person'] + ldap_rdn_attribute = 'uid' + ldap_dn_base = 'ou=users,dc=example,dc=com' + + id = Column(Integer, primary_key=True) + loginname = Column(String, unique=True, nullable=False) + displayname = Column(String, nullable=False, default='') + email = Column(String) + pwhash = Column(String) + + # Write-only property + def password(self, value): + self.pwhash = crypt(value) + password = property(fset=password) + + def check_password(self, password): + return self.pwhash is not None and crypt(value, self.pwhash) == self.pwhash + +Base.metadata.create_all(engine) diff --git a/dn.py b/dn.py new file mode 100644 index 0000000000000000000000000000000000000000..1ce21a366f435b6fcbc505920d0f2d5fd5b0a095 --- /dev/null +++ b/dn.py @@ -0,0 +1,121 @@ +from string import hexdigits as HEXDIGITS + +DN_ESCAPED = ('"', '+', ',', ';', '<', '>') +DN_SPECIAL = DN_ESCAPED + (' ', '#', '=') + +def parse_assertion(expr, case_ignore_attrs=None): + case_ignore_attrs = case_ignore_attrs or [] + hexdigit = None + escaped = False + tokens = [] + token = b'' + for c in expr: + if hexdigit is not None: + if c not in HEXDIGITS: + raise ValueError('Invalid hexpair: \\%c%c'%(hexdigit, c)) + token += bytes.fromhex('%c%c'%(hexdigit, c)) + hexdigit = None + elif escaped: + escaped = False + if c in DN_SPECIAL or c == '\\': + token += c.encode() + elif c in HEXDIGITS: + hexdigit = c + else: + raise ValueError('Invalid escape: \\%c'%c) + elif c == '\\': + escaped = True + elif c == '=': + tokens.append(token) + token = b'' + else: + token += c.encode() + tokens.append(token) + if len(tokens) != 2: + raise ValueError('Invalid assertion in RDN: "%s"'%expr) + name = tokens[0].decode().lower() + value = tokens[1] + if not name or not value: + raise ValueError('Invalid assertion in RDN: "%s"'%expr) + # TODO: handle hex strings + if name in case_ignore_attrs: + value = value.lower() + return (name, value) + +def parse_rdn(rdn, case_ignore_attrs=None): + escaped = False + assertions = [] + token = '' + for c in rdn: + if escaped: + escaped = False + token += c + elif c == '+': + assertions.append(parse_assertion(token, case_ignore_attrs=case_ignore_attrs)) + token = '' + else: + if c == '\\': + escaped = True + token += c + assertions.append(parse_assertion(token, case_ignore_attrs=case_ignore_attrs)) + if not assertions: + raise ValueError('Invalid RDN "%s"'%rdn) + return tuple(sorted(assertions)) + +def parse_dn(dn, case_ignore_attrs=None): + if not dn: + return tuple() + escaped = False + rdns = [] + rdn = '' + for c in dn: + if escaped: + escaped = False + rdn += c + elif c == ',': + rdns.append(parse_rdn(rdn, case_ignore_attrs=case_ignore_attrs)) + rdn = '' + else: + if c == '\\': + escaped = True + rdn += c + rdns.append(parse_rdn(rdn, case_ignore_attrs=case_ignore_attrs)) + return tuple(rdns) + +# >>> parse_dn('OU=Sales+CN=J. Smith,DC=example,DC=net', case_ignore_attrs=['cn', 'ou', 'dc']) +# ((('cn', b'j. smith'), ('ou', b'sales')), (('dc', b'example'),), (('dc', b'net'),)) + +def escape_dn_value(value): + if isinstance(value, int): + value = str(value) + if isinstance(value, str): + value = value.encode() + res = '' + for c in value: + c = bytes((c,)) + try: + s = c.decode() + except UnicodeDecodeError: + s = '\\'+c.hex() + if s in DN_SPECIAL: + s = '\\'+s + res += s + return res + +def build_assertion(assertion): + name, value = assertion + return '%s=%s'%(escape_dn_value(name.encode()), escape_dn_value(value)) + +def build_rdn(assertions): + return '+'.join(map(build_assertion, assertions)) + +def build_dn(rdns): + return ','.join(map(build_rdn, rdns)) + +from enum import Enum + +class DNScope(Enum): + baseObject = 0 # The scope is constrained to the entry named by baseObject. + singleLevel = 1 # The scope is constrained to the immediate subordinates of the entry named by baseObject. + wholeSubtree = 2 # The scope is constrained to the entry named by baseObject and to all its subordinates. +