From 0cb7c3f18cfb209f9c86baececd8db54d3708a53 Mon Sep 17 00:00:00 2001 From: Julian Rother <julian@jrother.eu> Date: Sat, 6 Mar 2021 23:56:21 +0100 Subject: [PATCH] Restructured server code --- db.py | 116 +++++++++++++++++++++++++++--------------------------- ldap.py | 14 +++++++ server.py | 113 +++++++++++++++++++++++++++++++++++----------------- 3 files changed, 148 insertions(+), 95 deletions(-) diff --git a/db.py b/db.py index c6a5024..36d131f 100644 --- a/db.py +++ b/db.py @@ -5,7 +5,8 @@ from sqlalchemy.orm import sessionmaker from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.ext.hybrid import hybrid_property from ldap import SearchScope, FilterAnd, FilterOr, FilterNot, FilterEqual, FilterPresent -from server import Server as LDAPServer +from server import LDAPRequestHandler +from socketserver import ForkingTCPServer from dn import parse_dn, build_dn Base = declarative_base() @@ -179,35 +180,8 @@ engine = create_engine('sqlite:///db.sqlite', echo=True) Session = sessionmaker(bind=engine) session = Session() -class LDAPViewMixin: - ldap_attributes = {} - ldap_objectclasses = [b'top'] - ldap_rdn_attribute = 'uid' - ldap_dn_base = '' - - @classmethod - def ldap_search(cls, base, scope, filter_expr, conn): - evaluator = SQLSearchEvaluator(cls, session, attributes=cls.ldap_attributes, - objectclasses=cls.ldap_objectclasses, rdn_attr=cls.ldap_rdn_attribute, - dn_base=cls.ldap_dn_base) - return evaluator(base, scope, filter_expr) - -class User(Base, LDAPViewMixin): +class User(Base): __tablename__ = 'users' - ldap_attributes = { - 'cn': 'displayname', - 'displayname': 'displayname', - 'gidnumber': 'ldap_gid', - 'givenname': 'displayname', - 'homedirectory': 'homedirectory', - 'mail': 'email', - 'sn': 'ldap_sn', - 'uid': 'loginname', - 'uidnumber': 'id', - } - ldap_objectclasses = [b'top', b'inetOrgPerson', b'organizationalPerson', b'person', b'posixAccount'] - ldap_rdn_attribute = 'uid' - ldap_dn_base = 'ou=users,dc=example,dc=com' id = Column(Integer, primary_key=True) loginname = Column(String, unique=True, nullable=False) @@ -230,16 +204,8 @@ class User(Base, LDAPViewMixin): def check_password(self, password): return self.pwhash is not None and crypt(password, self.pwhash) == self.pwhash -class Group(Base, LDAPViewMixin): +class Group(Base): __tablename__ = 'groups' - ldap_attributes = { - 'cn': 'name', - 'description': 'description', - 'gidnumber': 'id', - } - ldap_objectclasses = [b'top', b'posixGroup', b'groupOfUniqueNames'] - ldap_rdn_attribute = 'cn' - ldap_dn_base = 'ou=groups,dc=example,dc=com' id = Column(Integer, primary_key=True) name = Column(String, unique=True, nullable=False) @@ -247,23 +213,57 @@ class Group(Base, LDAPViewMixin): Base.metadata.create_all(engine) -ldap_server = LDAPServer() -ldap_server.search_handler(User.ldap_search) -ldap_server.search_handler(Group.ldap_search) - -@ldap_server.bind_handler -def ldap_bind(name, password, conn): - try: - password = password.decode() - except UnicodeDecodeError: - return False - evaluator = SQLSearchEvaluator(User, session, attributes=User.ldap_attributes, - objectclasses=User.ldap_objectclasses, rdn_attr=User.ldap_rdn_attribute, - dn_base=User.ldap_dn_base) - res = session.query(User).filter(evaluator.filter_dn(name, SearchScope.baseObject)).one_or_none() - if res: - return res.check_password(password) - return False - -ldap_server.run('127.0.0.1', 1337) - +class RequestHandler(LDAPRequestHandler): + def do_bind(self, name, password): + if not name and not password: + return None + try: + password = password.decode() + except UnicodeDecodeError: + raise LDAPInvalidCredentials() + try: + evaluator = SQLSearchEvaluator(User, session, attributes=User.ldap_attributes, + objectclasses=User.ldap_objectclasses, rdn_attr=User.ldap_rdn_attribute, + dn_base=User.ldap_dn_base) + except ValueError: + raise LDAPInvalidCredentials() + user = session.query(User).filter(evaluator.filter_dn(name, SearchScope.baseObject)).one_or_none() + if user is None or not user.check_password(password): + raise LDAPInvalidCredentials() + return user + + def do_search(self, baseobj, scope, filter): + # User + ldap_attributes = { + 'cn': 'displayname', + 'displayname': 'displayname', + 'gidnumber': 'ldap_gid', + 'givenname': 'displayname', + 'homedirectory': 'homedirectory', + 'mail': 'email', + 'sn': 'ldap_sn', + 'uid': 'loginname', + 'uidnumber': 'id', + } + ldap_objectclasses = [b'top', b'inetOrgPerson', b'organizationalPerson', b'person', b'posixAccount'] + ldap_rdn_attribute = 'uid' + ldap_dn_base = 'ou=users,dc=example,dc=com' + evaluator = SQLSearchEvaluator(User, session, attributes=ldap_attributes, + objectclasses=ldap_objectclasses, rdn_attr=ldap_rdn_attribute, + dn_base=ldap_dn_base) + yield from evaluator(baseobj, scope, filter) + # Group + ldap_attributes = { + 'cn': 'name', + 'description': 'description', + 'gidnumber': 'id', + } + ldap_objectclasses = [b'top', b'posixGroup', b'groupOfUniqueNames'] + ldap_rdn_attribute = 'cn' + ldap_dn_base = 'ou=groups,dc=example,dc=com' + evaluator = SQLSearchEvaluator(Group, session, attributes=ldap_attributes, + objectclasses=ldap_objectclasses, rdn_attr=ldap_rdn_attribute, + dn_base=ldap_dn_base) + yield from evaluator(baseobj, scope, filter) + +ForkingTCPServer(('127.0.0.1', 1337), RequestHandler).serve_forever() diff --git a/ldap.py b/ldap.py index a5145f7..eadc75e 100644 --- a/ldap.py +++ b/ldap.py @@ -636,6 +636,20 @@ class ShallowLDAPMessage(Sequence): (ShallowProtocolOp, 'protocolOp', None, False) ] +# Extended Operation Values + +class PasswdModifyRequestValue(Sequence): + sequence_fields = [ + (retag(LDAPString, (2, False, 0)), 'userIdentity', None, True), + (retag(OctetString, (2, False, 1)), 'oldPasswd', None, True), + (retag(OctetString, (2, False, 2)), 'newPasswd', None, True), + ] + +class PasswdModifyResponseValue(Sequence): + sequence_fields = [ + (retag(OctetString, (2, False, 0)), 'genPasswd', None, True), + ] + bind1 = b'0\x0c\x02\x01\x01`\x07\x02\x01\x03\x04\x00\x80\x00' bind2 = b'0\x1c\x02\x01\x01`\x17\x02\x01\x03\x04\nuid=foobar\x80\x06foobar' search1 = b'0?\x02\x01\x02c:\x04\x1aou=users,dc=example,dc=com\n\x01\x01\n\x01\x00\x02\x01\x00\x02\x01\x00\x01\x01\x00\x87\x0bobjectclass0\x00' diff --git a/server.py b/server.py index 69af473..00a2762 100644 --- a/server.py +++ b/server.py @@ -1,40 +1,88 @@ import traceback -from socketserver import ForkingTCPServer, BaseRequestHandler +from socketserver import BaseRequestHandler -from ldap import LDAPMessage, ShallowLDAPMessage, BindRequest, BindResponse, SearchRequest, SearchResultEntry, PartialAttribute, SearchResultDone, UnbindRequest, LDAPResultCode, IncompleteBERError, SimpleAuthentication, ModifyRequest, ModifyResponse, AddRequest, AddResponse, DelRequest, DelResponse, ModifyDNRequest, ModifyDNResponse, CompareRequest, CompareResponse, AbandonRequest, ExtendedRequest, ExtendedResponse +from ldap import LDAPMessage, ShallowLDAPMessage, BindRequest, BindResponse, SearchRequest, SearchResultEntry, PartialAttribute, SearchResultDone, UnbindRequest, LDAPResultCode, IncompleteBERError, SimpleAuthentication, ModifyRequest, ModifyResponse, AddRequest, AddResponse, DelRequest, DelResponse, ModifyDNRequest, ModifyDNResponse, CompareRequest, CompareResponse, AbandonRequest, ExtendedRequest, ExtendedResponse, PasswdModifyRequestValue, PasswdModifyResponseValue -class Handler(BaseRequestHandler): - ldap_server = None +class LDAPError(Exception): + def __init__(self, code=LDAPResultCode.other, message=''): + self.code = code + self.message = message - def setup(self): - self.bind_dn = b'' - self.keep_running = True +class LDAPOperationsError(LDAPError): + def __init__(self, message=''): + super().__init__(LDAPResultCode.operationsError, message) + +class LDAPProtocolError(LDAPError): + def __init__(self, message=''): + super().__init__(LDAPResultCode.protocolError, message) + +class LDAPInvalidCredentials(LDAPError): + def __init__(self, message=''): + super().__init__(LDAPResultCode.invalidCredentials, message) + +class LDAPInsufficientAccessRights(LDAPError): + def __init__(self, message=''): + super().__init__(LDAPResultCode.insufficientAccessRights, message) +class LDAPUnwillingToPerform(LDAPError): + def __init__(self, message=''): + super().__init__(LDAPResultCode.unwillingToPerform, message) + +class LDAPAuthMethodNotSupported(LDAPError): + def __init__(self, message=''): + super().__init__(LDAPResultCode.authMethodNotSupported, message) + +class LDAPRequestHandler(BaseRequestHandler): def handle_bind(self, req): - if not isinstance(req.protocolOp.authentication, SimpleAuthentication): - self.send_msg(LDAPMessage(req.messageID, BindResponse(LDAPResultCode.authMethodNotSupported))) - name = req.protocolOp.name - password = req.protocolOp.authentication.password - for func in self.ldap_server.bind_handlers: - if func(name, password, self): - self.bind_dn = name - self.send_msg(LDAPMessage(req.messageID, BindResponse(LDAPResultCode.success))) - return - self.send_msg(LDAPMessage(req.messageID, BindResponse(LDAPResultCode.invalidCredentials))) + op = req.protocolOp + if op.version != 3: + raise LDAPProtocolError('Unsupported protocol version') + if isinstance(op.authentication, SimpleAuthentication): + self.bind_object = self.do_bind(op.name, op.authentication.password) + else: + raise LDAPAuthMethodNotSupported() + self.send_msg(LDAPMessage(req.messageID, BindResponse(LDAPResultCode.success))) + + def do_bind(self, user, password): + raise LDAPInvalidCredentials() def handle_search(self, req): search = req.protocolOp - entries = [] - for func in self.ldap_server.search_handlers: - entries += func(search.baseObject, search.scope, search.filter, self) - for dn, attributes in entries: + for dn, attributes in self.do_search(search.baseObject, search.scope, search.filter): attributes = [PartialAttribute(name, values) for name, values in attributes.items()] self.send_msg(LDAPMessage(req.messageID, SearchResultEntry(dn, attributes))) self.send_msg(LDAPMessage(req.messageID, SearchResultDone(LDAPResultCode.success))) + def do_search(self, baseobj, scope, filter): + yield from [] + def handle_unbind(self, req): self.keep_running = False + def handle_extended(self, req): + op = req.protocolOp + if op.requestName == '1.3.6.1.4.1.1466.20037': + # StartTLS (RFC 4511) + raise LDAPProtocolError() + elif op.requestName == '1.3.6.1.4.1.4203.1.11.1': + # Password Modify Extended Operation (RFC 3062) + newpw = None + if op.requestValue is None: + newpw = self.do_passwd() + else: + decoded, _ = PasswdModifyRequestValue.from_ber(op.requestValue) + newpw = self.do_passwd(decoded.userIdentity, decoded.oldPasswd, decoded.newPasswd) + if newpw is None: + self.send_msg(LDAPMessage(req.messageID, ExtendedResponse(LDAPResultCode.success))) + else: + encoded = PasswdModifyResponseValue.to_ber(PasswdModifyResponseValue(newpw)) + self.send_msg(LDAPMessage(req.messageID, ExtendedResponse(LDAPResultCode.success, responseValue=encoded))) + else: + raise LDAPProtocolError() + + def do_passwd(self, user=None, oldpasswd=None, newpasswd=None): + raise LDAPUnwillingToPerform('Password change is not supported') + def handle_message(self, data): handlers = { BindRequest: (self.handle_bind, BindResponse), @@ -46,7 +94,7 @@ class Handler(BaseRequestHandler): ModifyDNRequest: (None, ModifyDNResponse), CompareRequest: (None, CompareResponse), AbandonRequest: (None, None), - ExtendedRequest: (None, ExtendedResponse), # TODO + ExtendedRequest: (self.handle_extended, ExtendedResponse), } shallowmsg, rest = ShallowLDAPMessage.from_ber(data) if shallowmsg.protocolOp is None: @@ -67,6 +115,10 @@ class Handler(BaseRequestHandler): func(msg) elif errfunc: self.send_msg(LDAPMessage(shallowmsg.messageID, errfunc(LDAPResultCode.insufficientAccessRights))) + except LDAPError as e: + if errfunc: + self.send_msg(LDAPMessage(shallowmsg.messageID, errfunc(e.code, diagnosticMessage=e.message))) + return rest except Exception as e: if errfunc: self.send_msg(LDAPMessage(shallowmsg.messageID, errfunc(LDAPResultCode.other))) @@ -79,6 +131,8 @@ class Handler(BaseRequestHandler): self.request.sendall(LDAPMessage.to_ber(msg)) def handle(self): + self.keep_running = True + self.bind_object = None data = b'' while self.keep_running: try: @@ -89,18 +143,3 @@ class Handler(BaseRequestHandler): return data += chunk -class Server: - def __init__(self): - self.bind_handlers = [] - self.search_handlers = [] - - def bind_handler(self, func): - self.bind_handlers.append(func) - - def search_handler(self, func): - self.search_handlers.append(func) - - def run(self, host='127.0.0.1', port=1337): - class BoundHandler(Handler): - ldap_server = self - ForkingTCPServer((host, port), BoundHandler).serve_forever() -- GitLab