Skip to content
Snippets Groups Projects
Commit 46345640 authored by Julian Rother's avatar Julian Rother
Browse files

Implemented StaticSearchEvaluator and restructured model integration code

parent 58540c20
No related branches found
No related tags found
No related merge requests found
from crypt import crypt
from ssl import SSLContext, SSLSocket
from sqlalchemy import create_engine, or_, and_, Column, Integer, String
from sqlalchemy.orm import sessionmaker
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.ext.hybrid import hybrid_property
from ldap import SearchScope, FilterAnd, FilterOr, FilterNot, FilterEqual, FilterPresent
from server import LDAPRequestHandler
from server import LDAPRequestHandler, LDAPInvalidCredentials, LDAPInsufficientAccessRights, LDAPConfidentialityRequired
from socketserver import ForkingTCPServer
from dn import parse_dn, build_dn
......@@ -83,6 +84,91 @@ class BaseSearchEvaluator:
def query(self, filter_obj):
return []
class StaticLDAPObject:
def __init__(self, dn, attributes=None):
self.dn = dn
self.attributes = attributes or {}
class StaticSearchEvaluator(BaseSearchEvaluator):
def __init__(self):
self.present_map = {} # name -> set of objs
self.value_map = {} # (name, value) -> set of objs
self.dn_map = {} # parsed dn -> obj
self.singlelevel_map = {} # parsed dn part -> set objs
self.subtree_map = {} # parsed dn part -> set objs
self.all_objects = set()
def add(self, dn, attributes):
dn = parse_dn(dn)
obj = StaticLDAPObject(dn)
assert dn not in self.dn_map
self.dn_map[dn] = {obj}
self.all_objects.add(obj)
if dn:
key = tuple(dn[1:])
self.singlelevel_map[key] = self.singlelevel_map.get(key, set())
self.singlelevel_map[key].add(obj)
path = list(dn)
for _ in range(len(path) + 1):
key = tuple(path)
self.subtree_map[key] = self.subtree_map.get(key, set())
self.subtree_map[key].add(obj)
if path:
path.pop(0)
for name, values in attributes.items():
if not isinstance(values, list):
values = [values]
obj.attributes[name] = []
for value in values:
if isinstance(value, int):
value = str(value)
if isinstance(value, str):
value = value.encode()
obj.attributes[name].append(value)
key = name.lower()
self.present_map[key] = self.present_map.get(key, set())
self.present_map[key].add(obj)
key = (name.lower(), value)
self.value_map[key] = self.value_map.get(key, set())
self.value_map[key].add(obj)
def filter_present(self, name):
key = name.lower()
return self.present_map.get(key, set())
def filter_equal(self, name, value):
key = (name.lower(), value)
return self.value_map.get(key, set())
def _filter_and(self, *subresults):
objs = subresults[0]
for subres in subresults[1:]:
objs = objs.intersection(subres)
return objs
def _filter_or(self, *subresults):
objs = subresults[0]
for subres in subresults[1:]:
objs = objs.union(subres)
return objs
def _filter_not(self, subresult):
return self.all_objects.difference(subresults)
def filter_dn(self, base, scope):
dn = parse_dn(base)
if scope == SearchScope.baseObject:
return self.dn_map.get(dn, set())
elif scope == SearchScope.singleLevel:
return self.singlelevel_map.get(dn, set())
elif scope == SearchScope.wholeSubtree:
return self.subtree_map.get(dn, set())
else:
return set()
def query(self, filter_obj):
return [(build_dn(obj.dn), obj.attributes) for obj in filter_obj]
class SQLSearchEvaluator(BaseSearchEvaluator):
def __init__(self, model, session, attributes=None, objectclasses=None, rdn_attr='uid', dn_base=''):
self.model = model
......@@ -159,7 +245,6 @@ class SQLSearchEvaluator(BaseSearchEvaluator):
objs = self.session.query(self.model)
else:
objs = self.session.query(self.model).filter(filter_obj)
results = []
for obj in objs:
attrs = {}
for ldap_name, attr_name in self.attributes.items():
......@@ -173,8 +258,7 @@ class SQLSearchEvaluator(BaseSearchEvaluator):
attrs[ldap_name] = [value]
attrs['objectClass'] = self.objectclasses
dn_parts = (((self.rdn_attr, attrs[self.rdn_attr][0]),),) + self.dn_base_path
results.append((build_dn(dn_parts), attrs))
return results
yield (build_dn(dn_parts), attrs)
engine = create_engine('sqlite:///db.sqlite', echo=True)
Session = sessionmaker(bind=engine)
......@@ -213,57 +297,69 @@ class Group(Base):
Base.metadata.create_all(engine)
staticobjs = StaticSearchEvaluator()
staticobjs.add(dn='', attributes={'objectClass': 'top', 'supportedSASLMechanisms': ['PLAIN', 'ANONYMOUS', 'EXTERNAL', 'SCRAM', 'DIGEST-MD5', 'CRAM-MD5', 'NTLM']})
usereval = SQLSearchEvaluator(
model=User,
session=session,
attributes={
'cn': 'displayname',
'displayname': 'displayname',
'gidnumber': 'ldap_gid',
'givenname': 'displayname',
'homedirectory': 'homedirectory',
'mail': 'email',
'sn': 'ldap_sn',
'uid': 'loginname',
'uidnumber': 'id',
},
objectclasses=[b'top', b'inetOrgPerson', b'organizationalPerson', b'person', b'posixAccount'],
rdn_attr='uid',
dn_base='ou=users,dc=example,dc=com'
)
groupeval = SQLSearchEvaluator(
model=Group,
session=session,
attributes={
'cn': 'name',
'description': 'description',
'gidnumber': 'id',
},
objectclasses=[b'top', b'posixGroup', b'groupOfUniqueNames'],
rdn_attr='cn',
dn_base='ou=groups,dc=example,dc=com'
)
ssl_context = SSLContext()
ssl_context.load_cert_chain('devcert.crt', 'devcert.key')
class RequestHandler(LDAPRequestHandler):
ssl_context = ssl_context
def do_bind(self, name, password):
if not name and not password:
return None
if not isinstance(self.request, SSLSocket):
raise LDAPConfidentialityRequired()
try:
password = password.decode()
except UnicodeDecodeError:
raise LDAPInvalidCredentials()
try:
evaluator = SQLSearchEvaluator(User, session, attributes=User.ldap_attributes,
objectclasses=User.ldap_objectclasses, rdn_attr=User.ldap_rdn_attribute,
dn_base=User.ldap_dn_base)
user = session.query(User).filter(usereval.filter_dn(name, SearchScope.baseObject)).one_or_none()
except ValueError:
raise LDAPInvalidCredentials()
user = session.query(User).filter(evaluator.filter_dn(name, SearchScope.baseObject)).one_or_none()
if user is None or not user.check_password(password):
raise LDAPInvalidCredentials()
return user
def do_search(self, baseobj, scope, filter):
# User
ldap_attributes = {
'cn': 'displayname',
'displayname': 'displayname',
'gidnumber': 'ldap_gid',
'givenname': 'displayname',
'homedirectory': 'homedirectory',
'mail': 'email',
'sn': 'ldap_sn',
'uid': 'loginname',
'uidnumber': 'id',
}
ldap_objectclasses = [b'top', b'inetOrgPerson', b'organizationalPerson', b'person', b'posixAccount']
ldap_rdn_attribute = 'uid'
ldap_dn_base = 'ou=users,dc=example,dc=com'
evaluator = SQLSearchEvaluator(User, session, attributes=ldap_attributes,
objectclasses=ldap_objectclasses, rdn_attr=ldap_rdn_attribute,
dn_base=ldap_dn_base)
yield from evaluator(baseobj, scope, filter)
# Group
ldap_attributes = {
'cn': 'name',
'description': 'description',
'gidnumber': 'id',
}
ldap_objectclasses = [b'top', b'posixGroup', b'groupOfUniqueNames']
ldap_rdn_attribute = 'cn'
ldap_dn_base = 'ou=groups,dc=example,dc=com'
evaluator = SQLSearchEvaluator(Group, session, attributes=ldap_attributes,
objectclasses=ldap_objectclasses, rdn_attr=ldap_rdn_attribute,
dn_base=ldap_dn_base)
yield from evaluator(baseobj, scope, filter)
#if self.bind_object is None:
# raise LDAPInsufficientAccessRights()
yield from staticobjs(baseobj, scope, filter)
yield from usereval(baseobj, scope, filter)
yield from groupeval(baseobj, scope, filter)
ForkingTCPServer(('127.0.0.1', 1337), RequestHandler).serve_forever()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment