Skip to content
Snippets Groups Projects
Commit c57d3016 authored by Julian Rother's avatar Julian Rother
Browse files

Implemented dn utils and exemplary sqlalchemy adaptor

parent f062794f
No related branches found
No related tags found
No related merge requests found
db.py 0 → 100644
from crypt import crypt
from sqlalchemy import create_engine, or_, and_, Column, Integer, String
from sqlalchemy.orm import sessionmaker
from sqlalchemy.ext.declarative import declarative_base
from dn import parse_dn, build_dn, DNScope
Base = declarative_base()
class BaseSearchEvaluator:
def __call__(self, base, scope, filter_expr):
dn_res = self.filter_dn(base, scope)
filter_res = self.filter_expr(filter_expr)
return self.query(self.filter_and(dn_res, filter_res))
def filter_expr(self, expr):
operator, *args = expr
if operator == 'and':
return self.filter_and(*[self.filter_expr(subexpr) for subexpr in args])
elif operator == 'or':
return self.filter_or(*[self.filter_expr(subexpr) for subexpr in args])
elif operator == 'not':
return self.filter_not(self.filter_expr(args[0]))
elif operator == 'equal':
return self.filter_equal(args[0].lower(), args[1])
elif operator == 'present':
return self.filter_present(args[0].lower())
else:
return False
def filter_present(self, name):
return False
def filter_equal(self, name, value):
return False
def filter_and(self, *subresults):
filtered = []
for subres in subresults:
if subres is True:
continue
if subres is False:
return False
filtered.append(subres)
if not filtered:
return True
return self._filter_and(*filtered)
def _filter_and(self, *subresults):
return False
def filter_or(self, *subresults):
filtered = []
for subres in subresults:
if subres is True:
return True
if subres is False:
continue
filtered.append(subres)
if not filtered:
return False
return self._filter_or(*filtered)
def _filter_or(self, *subresults):
return False
def filter_not(self, subresult):
if subresult is True:
return False
if subresult is False:
return True
return self._filter_not(subresult)
def _filter_not(self, subresult):
return False
def filter_dn(self, base, scope):
return False
def query(self, filter_obj):
return []
class SQLSearchEvaluator(BaseSearchEvaluator):
def __init__(self, model, session, attributes=None, objectclasses=None, rdn_attr='uid', dn_base=''):
self.model = model
self.session = session
self.attributes = attributes or {}
self.objectclasses = objectclasses or []
self.rdn_attr = rdn_attr
self.dn_base_path = parse_dn(dn_base)
def filter_present(self, name):
if name == 'objectclass':
return True
if name not in self.attributes:
return False
return getattr(self.model, self.attributes[name]).is_not(None)
def filter_equal(self, name, value):
if name == 'objectclass':
return value in self.objectclasses
if name not in self.attributes:
return False
attr = getattr(self.model, self.attributes[name])
if isinstance(attr.type, String):
value = value.decode()
elif isinstance(attr.type, Integer):
value = int(value)
return attr == value
def _filter_and(self, *subresults):
return and_(*subresults)
def _filter_or(self, *subresults):
return or_(*subresults)
def _filter_not(self, subresult):
return ~subresult
def filter_dn(self, base, scope):
search_path = list(parse_dn(base))
base_path = list(self.dn_base_path)
while search_path and base_path:
if search_path.pop() != base_path.pop():
return False
if scope == DNScope.baseObject:
if base_path or len(search_path) != 1 or len(search_path[0]) != 1 or search_path[0][0][0] != self.rdn_attr:
return False
return self.filter_equal(self.rdn_attr, search_path[0][0][1])
elif scope == DNScope.singleLevel:
return not search_path and not base_path
elif scope == DNScope.wholeSubtree:
if not search_path:
return True
if len(search_path) > 1 or len(search_path[0]) != 1 or search_path[0][0][0] != self.rdn_attr:
return False
return self.filter_equal(self.rdn_attr, search_path[0][0][1])
else:
return False
def query(self, filter_obj):
if filter_obj is False:
return []
elif filter_obj is True:
objs = self.session.query(self.model)
else:
objs = self.session.query(self.model).filter(filter_obj)
results = []
for obj in objs:
attrs = {}
for ldap_name, attr_name in self.attributes.items():
attrs [ldap_name] = getattr(obj, attr_name)
attrs['objectClass'] = self.objectclasses
dn_parts = (((self.rdn_attr, attrs[self.rdn_attr]),),) + self.dn_base_path
results.append((build_dn(dn_parts), attrs))
return results
engine = create_engine('sqlite:///db.sqlite', echo=True)
Session = sessionmaker(bind=engine)
session = Session()
class LDAPViewMixin:
ldap_attributes = {}
ldap_objectclasses = [b'top']
ldap_rdn_attribute = 'uid'
ldap_dn_base = ''
@classmethod
def ldap_search(cls, base, scope, filter_expr, conn):
evaluator = SQLSearchEvaluator(cls, session, attributes=cls.ldap_attributes,
objectclasses=cls.ldap_objectclasses, rdn_attr=cls.ldap_rdn_attribute,
dn_base=cls.ldap_dn_base)
return evaluator(base, scope, filter_expr)
class User(Base, LDAPViewMixin):
__tablename__ = 'users'
ldap_attributes = {
'givenname': 'displayname',
'mail': 'email',
'uid': 'loginname',
'uidnumeric': 'id',
}
ldap_objectclasses = [b'top', b'person']
ldap_rdn_attribute = 'uid'
ldap_dn_base = 'ou=users,dc=example,dc=com'
id = Column(Integer, primary_key=True)
loginname = Column(String, unique=True, nullable=False)
displayname = Column(String, nullable=False, default='')
email = Column(String)
pwhash = Column(String)
# Write-only property
def password(self, value):
self.pwhash = crypt(value)
password = property(fset=password)
def check_password(self, password):
return self.pwhash is not None and crypt(value, self.pwhash) == self.pwhash
Base.metadata.create_all(engine)
dn.py 0 → 100644
from string import hexdigits as HEXDIGITS
DN_ESCAPED = ('"', '+', ',', ';', '<', '>')
DN_SPECIAL = DN_ESCAPED + (' ', '#', '=')
def parse_assertion(expr, case_ignore_attrs=None):
case_ignore_attrs = case_ignore_attrs or []
hexdigit = None
escaped = False
tokens = []
token = b''
for c in expr:
if hexdigit is not None:
if c not in HEXDIGITS:
raise ValueError('Invalid hexpair: \\%c%c'%(hexdigit, c))
token += bytes.fromhex('%c%c'%(hexdigit, c))
hexdigit = None
elif escaped:
escaped = False
if c in DN_SPECIAL or c == '\\':
token += c.encode()
elif c in HEXDIGITS:
hexdigit = c
else:
raise ValueError('Invalid escape: \\%c'%c)
elif c == '\\':
escaped = True
elif c == '=':
tokens.append(token)
token = b''
else:
token += c.encode()
tokens.append(token)
if len(tokens) != 2:
raise ValueError('Invalid assertion in RDN: "%s"'%expr)
name = tokens[0].decode().lower()
value = tokens[1]
if not name or not value:
raise ValueError('Invalid assertion in RDN: "%s"'%expr)
# TODO: handle hex strings
if name in case_ignore_attrs:
value = value.lower()
return (name, value)
def parse_rdn(rdn, case_ignore_attrs=None):
escaped = False
assertions = []
token = ''
for c in rdn:
if escaped:
escaped = False
token += c
elif c == '+':
assertions.append(parse_assertion(token, case_ignore_attrs=case_ignore_attrs))
token = ''
else:
if c == '\\':
escaped = True
token += c
assertions.append(parse_assertion(token, case_ignore_attrs=case_ignore_attrs))
if not assertions:
raise ValueError('Invalid RDN "%s"'%rdn)
return tuple(sorted(assertions))
def parse_dn(dn, case_ignore_attrs=None):
if not dn:
return tuple()
escaped = False
rdns = []
rdn = ''
for c in dn:
if escaped:
escaped = False
rdn += c
elif c == ',':
rdns.append(parse_rdn(rdn, case_ignore_attrs=case_ignore_attrs))
rdn = ''
else:
if c == '\\':
escaped = True
rdn += c
rdns.append(parse_rdn(rdn, case_ignore_attrs=case_ignore_attrs))
return tuple(rdns)
# >>> parse_dn('OU=Sales+CN=J. Smith,DC=example,DC=net', case_ignore_attrs=['cn', 'ou', 'dc'])
# ((('cn', b'j. smith'), ('ou', b'sales')), (('dc', b'example'),), (('dc', b'net'),))
def escape_dn_value(value):
if isinstance(value, int):
value = str(value)
if isinstance(value, str):
value = value.encode()
res = ''
for c in value:
c = bytes((c,))
try:
s = c.decode()
except UnicodeDecodeError:
s = '\\'+c.hex()
if s in DN_SPECIAL:
s = '\\'+s
res += s
return res
def build_assertion(assertion):
name, value = assertion
return '%s=%s'%(escape_dn_value(name.encode()), escape_dn_value(value))
def build_rdn(assertions):
return '+'.join(map(build_assertion, assertions))
def build_dn(rdns):
return ','.join(map(build_rdn, rdns))
from enum import Enum
class DNScope(Enum):
baseObject = 0 # The scope is constrained to the entry named by baseObject.
singleLevel = 1 # The scope is constrained to the immediate subordinates of the entry named by baseObject.
wholeSubtree = 2 # The scope is constrained to the entry named by baseObject and to all its subordinates.
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment