Skip to content
Snippets Groups Projects
Commit 46345640 authored by Julian Rother's avatar Julian Rother
Browse files

Implemented StaticSearchEvaluator and restructured model integration code

parent 58540c20
Branches
Tags
No related merge requests found
from crypt import crypt from crypt import crypt
from ssl import SSLContext, SSLSocket
from sqlalchemy import create_engine, or_, and_, Column, Integer, String from sqlalchemy import create_engine, or_, and_, Column, Integer, String
from sqlalchemy.orm import sessionmaker from sqlalchemy.orm import sessionmaker
from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.ext.hybrid import hybrid_property from sqlalchemy.ext.hybrid import hybrid_property
from ldap import SearchScope, FilterAnd, FilterOr, FilterNot, FilterEqual, FilterPresent from ldap import SearchScope, FilterAnd, FilterOr, FilterNot, FilterEqual, FilterPresent
from server import LDAPRequestHandler from server import LDAPRequestHandler, LDAPInvalidCredentials, LDAPInsufficientAccessRights, LDAPConfidentialityRequired
from socketserver import ForkingTCPServer from socketserver import ForkingTCPServer
from dn import parse_dn, build_dn from dn import parse_dn, build_dn
...@@ -83,6 +84,91 @@ class BaseSearchEvaluator: ...@@ -83,6 +84,91 @@ class BaseSearchEvaluator:
def query(self, filter_obj): def query(self, filter_obj):
return [] return []
class StaticLDAPObject:
def __init__(self, dn, attributes=None):
self.dn = dn
self.attributes = attributes or {}
class StaticSearchEvaluator(BaseSearchEvaluator):
def __init__(self):
self.present_map = {} # name -> set of objs
self.value_map = {} # (name, value) -> set of objs
self.dn_map = {} # parsed dn -> obj
self.singlelevel_map = {} # parsed dn part -> set objs
self.subtree_map = {} # parsed dn part -> set objs
self.all_objects = set()
def add(self, dn, attributes):
dn = parse_dn(dn)
obj = StaticLDAPObject(dn)
assert dn not in self.dn_map
self.dn_map[dn] = {obj}
self.all_objects.add(obj)
if dn:
key = tuple(dn[1:])
self.singlelevel_map[key] = self.singlelevel_map.get(key, set())
self.singlelevel_map[key].add(obj)
path = list(dn)
for _ in range(len(path) + 1):
key = tuple(path)
self.subtree_map[key] = self.subtree_map.get(key, set())
self.subtree_map[key].add(obj)
if path:
path.pop(0)
for name, values in attributes.items():
if not isinstance(values, list):
values = [values]
obj.attributes[name] = []
for value in values:
if isinstance(value, int):
value = str(value)
if isinstance(value, str):
value = value.encode()
obj.attributes[name].append(value)
key = name.lower()
self.present_map[key] = self.present_map.get(key, set())
self.present_map[key].add(obj)
key = (name.lower(), value)
self.value_map[key] = self.value_map.get(key, set())
self.value_map[key].add(obj)
def filter_present(self, name):
key = name.lower()
return self.present_map.get(key, set())
def filter_equal(self, name, value):
key = (name.lower(), value)
return self.value_map.get(key, set())
def _filter_and(self, *subresults):
objs = subresults[0]
for subres in subresults[1:]:
objs = objs.intersection(subres)
return objs
def _filter_or(self, *subresults):
objs = subresults[0]
for subres in subresults[1:]:
objs = objs.union(subres)
return objs
def _filter_not(self, subresult):
return self.all_objects.difference(subresults)
def filter_dn(self, base, scope):
dn = parse_dn(base)
if scope == SearchScope.baseObject:
return self.dn_map.get(dn, set())
elif scope == SearchScope.singleLevel:
return self.singlelevel_map.get(dn, set())
elif scope == SearchScope.wholeSubtree:
return self.subtree_map.get(dn, set())
else:
return set()
def query(self, filter_obj):
return [(build_dn(obj.dn), obj.attributes) for obj in filter_obj]
class SQLSearchEvaluator(BaseSearchEvaluator): class SQLSearchEvaluator(BaseSearchEvaluator):
def __init__(self, model, session, attributes=None, objectclasses=None, rdn_attr='uid', dn_base=''): def __init__(self, model, session, attributes=None, objectclasses=None, rdn_attr='uid', dn_base=''):
self.model = model self.model = model
...@@ -159,7 +245,6 @@ class SQLSearchEvaluator(BaseSearchEvaluator): ...@@ -159,7 +245,6 @@ class SQLSearchEvaluator(BaseSearchEvaluator):
objs = self.session.query(self.model) objs = self.session.query(self.model)
else: else:
objs = self.session.query(self.model).filter(filter_obj) objs = self.session.query(self.model).filter(filter_obj)
results = []
for obj in objs: for obj in objs:
attrs = {} attrs = {}
for ldap_name, attr_name in self.attributes.items(): for ldap_name, attr_name in self.attributes.items():
...@@ -173,8 +258,7 @@ class SQLSearchEvaluator(BaseSearchEvaluator): ...@@ -173,8 +258,7 @@ class SQLSearchEvaluator(BaseSearchEvaluator):
attrs[ldap_name] = [value] attrs[ldap_name] = [value]
attrs['objectClass'] = self.objectclasses attrs['objectClass'] = self.objectclasses
dn_parts = (((self.rdn_attr, attrs[self.rdn_attr][0]),),) + self.dn_base_path dn_parts = (((self.rdn_attr, attrs[self.rdn_attr][0]),),) + self.dn_base_path
results.append((build_dn(dn_parts), attrs)) yield (build_dn(dn_parts), attrs)
return results
engine = create_engine('sqlite:///db.sqlite', echo=True) engine = create_engine('sqlite:///db.sqlite', echo=True)
Session = sessionmaker(bind=engine) Session = sessionmaker(bind=engine)
...@@ -213,57 +297,69 @@ class Group(Base): ...@@ -213,57 +297,69 @@ class Group(Base):
Base.metadata.create_all(engine) Base.metadata.create_all(engine)
staticobjs = StaticSearchEvaluator()
staticobjs.add(dn='', attributes={'objectClass': 'top', 'supportedSASLMechanisms': ['PLAIN', 'ANONYMOUS', 'EXTERNAL', 'SCRAM', 'DIGEST-MD5', 'CRAM-MD5', 'NTLM']})
usereval = SQLSearchEvaluator(
model=User,
session=session,
attributes={
'cn': 'displayname',
'displayname': 'displayname',
'gidnumber': 'ldap_gid',
'givenname': 'displayname',
'homedirectory': 'homedirectory',
'mail': 'email',
'sn': 'ldap_sn',
'uid': 'loginname',
'uidnumber': 'id',
},
objectclasses=[b'top', b'inetOrgPerson', b'organizationalPerson', b'person', b'posixAccount'],
rdn_attr='uid',
dn_base='ou=users,dc=example,dc=com'
)
groupeval = SQLSearchEvaluator(
model=Group,
session=session,
attributes={
'cn': 'name',
'description': 'description',
'gidnumber': 'id',
},
objectclasses=[b'top', b'posixGroup', b'groupOfUniqueNames'],
rdn_attr='cn',
dn_base='ou=groups,dc=example,dc=com'
)
ssl_context = SSLContext()
ssl_context.load_cert_chain('devcert.crt', 'devcert.key')
class RequestHandler(LDAPRequestHandler): class RequestHandler(LDAPRequestHandler):
ssl_context = ssl_context
def do_bind(self, name, password): def do_bind(self, name, password):
if not name and not password: if not name and not password:
return None return None
if not isinstance(self.request, SSLSocket):
raise LDAPConfidentialityRequired()
try: try:
password = password.decode() password = password.decode()
except UnicodeDecodeError: except UnicodeDecodeError:
raise LDAPInvalidCredentials() raise LDAPInvalidCredentials()
try: try:
evaluator = SQLSearchEvaluator(User, session, attributes=User.ldap_attributes, user = session.query(User).filter(usereval.filter_dn(name, SearchScope.baseObject)).one_or_none()
objectclasses=User.ldap_objectclasses, rdn_attr=User.ldap_rdn_attribute,
dn_base=User.ldap_dn_base)
except ValueError: except ValueError:
raise LDAPInvalidCredentials() raise LDAPInvalidCredentials()
user = session.query(User).filter(evaluator.filter_dn(name, SearchScope.baseObject)).one_or_none()
if user is None or not user.check_password(password): if user is None or not user.check_password(password):
raise LDAPInvalidCredentials() raise LDAPInvalidCredentials()
return user return user
def do_search(self, baseobj, scope, filter): def do_search(self, baseobj, scope, filter):
# User #if self.bind_object is None:
ldap_attributes = { # raise LDAPInsufficientAccessRights()
'cn': 'displayname', yield from staticobjs(baseobj, scope, filter)
'displayname': 'displayname', yield from usereval(baseobj, scope, filter)
'gidnumber': 'ldap_gid', yield from groupeval(baseobj, scope, filter)
'givenname': 'displayname',
'homedirectory': 'homedirectory',
'mail': 'email',
'sn': 'ldap_sn',
'uid': 'loginname',
'uidnumber': 'id',
}
ldap_objectclasses = [b'top', b'inetOrgPerson', b'organizationalPerson', b'person', b'posixAccount']
ldap_rdn_attribute = 'uid'
ldap_dn_base = 'ou=users,dc=example,dc=com'
evaluator = SQLSearchEvaluator(User, session, attributes=ldap_attributes,
objectclasses=ldap_objectclasses, rdn_attr=ldap_rdn_attribute,
dn_base=ldap_dn_base)
yield from evaluator(baseobj, scope, filter)
# Group
ldap_attributes = {
'cn': 'name',
'description': 'description',
'gidnumber': 'id',
}
ldap_objectclasses = [b'top', b'posixGroup', b'groupOfUniqueNames']
ldap_rdn_attribute = 'cn'
ldap_dn_base = 'ou=groups,dc=example,dc=com'
evaluator = SQLSearchEvaluator(Group, session, attributes=ldap_attributes,
objectclasses=ldap_objectclasses, rdn_attr=ldap_rdn_attribute,
dn_base=ldap_dn_base)
yield from evaluator(baseobj, scope, filter)
ForkingTCPServer(('127.0.0.1', 1337), RequestHandler).serve_forever() ForkingTCPServer(('127.0.0.1', 1337), RequestHandler).serve_forever()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment