From 7cdfc82e623262573f8a88132325360915286dd4 Mon Sep 17 00:00:00 2001 From: Julian Rother <julian@jrother.eu> Date: Fri, 5 Mar 2021 23:52:43 +0100 Subject: [PATCH] Adapted db code to new server implementation --- db.py | 118 ++++++++++++++++++++++++++++++++++++++++++------------ dn.py | 11 ----- ldap.py | 10 +++-- server.py | 41 ++++++++++++++++--- 4 files changed, 136 insertions(+), 44 deletions(-) diff --git a/db.py b/db.py index ae5bf07..8c2791b 100644 --- a/db.py +++ b/db.py @@ -3,7 +3,10 @@ from crypt import crypt from sqlalchemy import create_engine, or_, and_, Column, Integer, String from sqlalchemy.orm import sessionmaker from sqlalchemy.ext.declarative import declarative_base -from dn import parse_dn, build_dn, DNScope +from sqlalchemy.ext.hybrid import hybrid_property +from ldap import SearchScope, FilterAnd, FilterOr, FilterNot, FilterEqual, FilterPresent +from server import Server as LDAPServer +from dn import parse_dn, build_dn Base = declarative_base() @@ -14,17 +17,16 @@ class BaseSearchEvaluator: return self.query(self.filter_and(dn_res, filter_res)) def filter_expr(self, expr): - operator, *args = expr - if operator == 'and': - return self.filter_and(*[self.filter_expr(subexpr) for subexpr in args]) - elif operator == 'or': - return self.filter_or(*[self.filter_expr(subexpr) for subexpr in args]) - elif operator == 'not': - return self.filter_not(self.filter_expr(args[0])) - elif operator == 'equal': - return self.filter_equal(args[0].lower(), args[1]) - elif operator == 'present': - return self.filter_present(args[0].lower()) + if isinstance(expr, FilterAnd): + return self.filter_and(*[self.filter_expr(subexpr) for subexpr in expr.filters]) + elif isinstance(expr, FilterOr): + return self.filter_or(*[self.filter_expr(subexpr) for subexpr in expr.filters]) + elif isinstance(expr, FilterNot): + return self.filter_not(self.filter_expr(expr.filter)) + elif isinstance(expr, FilterEqual): + return self.filter_equal(expr.attribute.lower(), expr.value) + elif isinstance(expr, FilterPresent): + return self.filter_present(expr.attribute.lower()) else: return False @@ -84,8 +86,15 @@ class SQLSearchEvaluator(BaseSearchEvaluator): def __init__(self, model, session, attributes=None, objectclasses=None, rdn_attr='uid', dn_base=''): self.model = model self.session = session - self.attributes = attributes or {} - self.objectclasses = objectclasses or [] + self.attributes = {} + for ldap_name, attr_name in (attributes or {}).items(): + self.attributes[ldap_name.lower()] = attr_name + self.objectclasses = [] + for value in (objectclasses or []): + value = value.lower() + if isinstance(value, str): + value = value.encode() + self.objectclasses.append(value) self.rdn_attr = rdn_attr self.dn_base_path = parse_dn(dn_base) @@ -94,7 +103,7 @@ class SQLSearchEvaluator(BaseSearchEvaluator): return True if name not in self.attributes: return False - return getattr(self.model, self.attributes[name]).is_not(None) + return getattr(self.model, self.attributes[name]).isnot(None) def filter_equal(self, name, value): if name == 'objectclass': @@ -102,9 +111,13 @@ class SQLSearchEvaluator(BaseSearchEvaluator): if name not in self.attributes: return False attr = getattr(self.model, self.attributes[name]) - if isinstance(attr.type, String): + if hasattr(attr, 'type') and isinstance(attr.type, String): value = value.decode() - elif isinstance(attr.type, Integer): + elif hasattr(attr, 'type') and isinstance(attr.type, Integer): + value = int(value) + elif isinstance(attr, str): + value = value.decode() + elif isinstance(attr, int): value = int(value) return attr == value @@ -123,13 +136,13 @@ class SQLSearchEvaluator(BaseSearchEvaluator): while search_path and base_path: if search_path.pop() != base_path.pop(): return False - if scope == DNScope.baseObject: + if scope == SearchScope.baseObject: if base_path or len(search_path) != 1 or len(search_path[0]) != 1 or search_path[0][0][0] != self.rdn_attr: return False return self.filter_equal(self.rdn_attr, search_path[0][0][1]) - elif scope == DNScope.singleLevel: + elif scope == SearchScope.singleLevel: return not search_path and not base_path - elif scope == DNScope.wholeSubtree: + elif scope == SearchScope.wholeSubtree: if not search_path: return True if len(search_path) > 1 or len(search_path[0]) != 1 or search_path[0][0][0] != self.rdn_attr: @@ -149,9 +162,16 @@ class SQLSearchEvaluator(BaseSearchEvaluator): for obj in objs: attrs = {} for ldap_name, attr_name in self.attributes.items(): - attrs [ldap_name] = getattr(obj, attr_name) + value = getattr(obj, attr_name) + if value is None: + continue + if isinstance(value, int): + value = str(value) + if isinstance(value, str): + value = value.encode() + attrs[ldap_name] = [value] attrs['objectClass'] = self.objectclasses - dn_parts = (((self.rdn_attr, attrs[self.rdn_attr]),),) + self.dn_base_path + dn_parts = (((self.rdn_attr, attrs[self.rdn_attr][0]),),) + self.dn_base_path results.append((build_dn(dn_parts), attrs)) return results @@ -175,12 +195,17 @@ class LDAPViewMixin: class User(Base, LDAPViewMixin): __tablename__ = 'users' ldap_attributes = { + 'cn': 'displayname', + 'displayname': 'displayname', + 'gidnumber': 'ldap_gid', 'givenname': 'displayname', + 'homedirectory': 'homedirectory', 'mail': 'email', + 'sn': 'ldap_sn', 'uid': 'loginname', - 'uidnumeric': 'id', + 'uidnumber': 'id', } - ldap_objectclasses = [b'top', b'person'] + ldap_objectclasses = [b'top', b'inetOrgPerson', b'organizationalPerson', b'person', b'posixAccount'] ldap_rdn_attribute = 'uid' ldap_dn_base = 'ou=users,dc=example,dc=com' @@ -190,12 +215,55 @@ class User(Base, LDAPViewMixin): email = Column(String) pwhash = Column(String) + ldap_gid = 1 + ldap_sn = ' ' + + @hybrid_property + def homedirectory(self): + return '/home/' + self.loginname + # Write-only property def password(self, value): self.pwhash = crypt(value) password = property(fset=password) def check_password(self, password): - return self.pwhash is not None and crypt(value, self.pwhash) == self.pwhash + return self.pwhash is not None and crypt(password, self.pwhash) == self.pwhash + +class Group(Base, LDAPViewMixin): + __tablename__ = 'groups' + ldap_attributes = { + 'cn': 'name', + 'description': 'description', + 'gidnumber': 'id', + } + ldap_objectclasses = [b'top', b'posixGroup', b'groupOfUniqueNames'] + ldap_rdn_attribute = 'cn' + ldap_dn_base = 'ou=groups,dc=example,dc=com' + + id = Column(Integer, primary_key=True) + name = Column(String, unique=True, nullable=False) + description = Column(String, nullable=False, default='') Base.metadata.create_all(engine) + +ldap_server = LDAPServer() +ldap_server.search_handler(User.ldap_search) +ldap_server.search_handler(Group.ldap_search) + +@ldap_server.bind_handler +def ldap_bind(name, password, conn): + try: + password = password.decode() + except UnicodeDecodeError: + return False + evaluator = SQLSearchEvaluator(User, session, attributes=User.ldap_attributes, + objectclasses=User.ldap_objectclasses, rdn_attr=User.ldap_rdn_attribute, + dn_base=User.ldap_dn_base) + res = session.query(User).filter(evaluator.filter_dn(name, SearchScope.baseObject)).one() + if res: + return res.check_password(password) + return False + +ldap_server.run('127.0.0.1', 1337) + diff --git a/dn.py b/dn.py index 020e7d0..f6927b8 100644 --- a/dn.py +++ b/dn.py @@ -111,14 +111,3 @@ def build_rdn(assertions): def build_dn(rdns): return ','.join(map(build_rdn, rdns)) - -from enum import Enum - -class DNScope(Enum): - baseObject = 0 # The scope is constrained to the entry named by baseObject. - singleLevel = 1 # The scope is constrained to the immediate subordinates of the entry named by baseObject. - wholeSubtree = 2 # The scope is constrained to the entry named by baseObject and to all its subordinates. - - @classmethod - def from_bytes(cls, data): - return self(data) diff --git a/ldap.py b/ldap.py index 48d3091..5aceb6d 100644 --- a/ldap.py +++ b/ldap.py @@ -49,9 +49,13 @@ def decode_ber_integer(data): def encode_ber(obj): tag = (obj.tag[0] & 0b11) << 6 | (obj.tag[1] & 1) << 5 | (obj.tag[2] & 0b11111) length = len(obj.content) - if length >= 127: - raise NotImplementedError('Long form length encoding not implemented') - return bytes([tag, length]) + obj.content + if length < 127: + return bytes([tag, length]) + obj.content + octets = [] + while length: + octets.append(length & 0xff) + length = length >> 8 + return bytes([tag, 0x80 | len(octets)]) + bytes(reversed(octets)) + obj.content def encode_ber_integer(value): if value < 0 or value > 255: diff --git a/server.py b/server.py index 10abdbe..96c52d0 100644 --- a/server.py +++ b/server.py @@ -1,18 +1,34 @@ import traceback from socketserver import ForkingTCPServer, BaseRequestHandler -from ldap import LDAPMessage, ShallowLDAPMessage, BindRequest, BindResponse, SearchRequest, SearchResultDone, UnbindRequest, LDAPResultCode, IncompleteBERError +from ldap import LDAPMessage, ShallowLDAPMessage, BindRequest, BindResponse, SearchRequest, SearchResultEntry, PartialAttribute, SearchResultDone, UnbindRequest, LDAPResultCode, IncompleteBERError, SimpleAuthentication class Handler(BaseRequestHandler): + ldap_server = None + def setup(self): self.bind_dn = b'' self.keep_running = True def handle_bind(self, req): - self.bind_dn = req.protocolOp.name - self.send_msg(LDAPMessage(req.messageID, BindResponse(LDAPResultCode.success))) + if not isinstance(req.protocolOp.authentication, SimpleAuthentication): + self.send_msg(LDAPMessage(req.messageID, BindResponse(LDAPResultCode.authMethodNotSupported))) + name = req.protocolOp.name + password = req.protocolOp.authentication.password + for func in self.ldap_server.bind_handlers: + if func(name, password, self): + self.bind_dn = name + self.send_msg(LDAPMessage(req.messageID, BindResponse(LDAPResultCode.success))) + self.send_msg(LDAPMessage(req.messageID, BindResponse(LDAPResultCode.invalidCredentials))) def handle_search(self, req): + search = req.protocolOp + entries = [] + for func in self.ldap_server.search_handlers: + entries += func(search.baseObject, search.scope, search.filter, self) + for dn, attributes in entries: + attributes = [PartialAttribute(name, values) for name, values in attributes.items()] + self.send_msg(LDAPMessage(req.messageID, SearchResultEntry(dn, attributes))) self.send_msg(LDAPMessage(req.messageID, SearchResultDone(LDAPResultCode.success))) def handle_unbind(self, req): @@ -37,7 +53,7 @@ class Handler(BaseRequestHandler): self.send_msg(LDAPMessage(shallowmsg.messageID, errfunc(LDAPResultCode.protocolError))) traceback.print_exc() return rest - print(msg) + print('received', msg) try: if func: func(msg) @@ -49,6 +65,7 @@ class Handler(BaseRequestHandler): return rest def send_msg(self, msg): + print('sending', msg) self.request.sendall(LDAPMessage.to_ber(msg)) def handle(self): @@ -62,4 +79,18 @@ class Handler(BaseRequestHandler): return data += chunk -ForkingTCPServer(('127.0.0.1', 1338), Handler).serve_forever() +class Server: + def __init__(self): + self.bind_handlers = [] + self.search_handlers = [] + + def bind_handler(self, func): + self.bind_handlers.append(func) + + def search_handler(self, func): + self.search_handlers.append(func) + + def run(self, host='127.0.0.1', port=1337): + class BoundHandler(Handler): + ldap_server = self + ForkingTCPServer((host, port), BoundHandler).serve_forever() -- GitLab