From c8bab6e70337ccde280a1ed911661de6b6d9d3dd Mon Sep 17 00:00:00 2001 From: Julian Rother <julian@jrother.eu> Date: Wed, 17 Nov 2021 14:26:06 +0100 Subject: [PATCH] Proper CLI and update to ldapserver 0.0.1.dev5 --- requirements.txt | 2 +- server.py | 208 +++++++++++++++++++++++++++++------------------ 2 files changed, 130 insertions(+), 80 deletions(-) diff --git a/requirements.txt b/requirements.txt index 13331b7..3021879 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ --extra-index-url https://git.cccv.de/api/v4/projects/220/packages/pypi/simple -ldapserver==0.0.1.dev4 +ldapserver==0.0.1.dev5 requests==2.* CacheControl diff --git a/server.py b/server.py index 5fccb24..b6ff7ac 100644 --- a/server.py +++ b/server.py @@ -1,27 +1,25 @@ +import os import sys -import json import socketserver +import logging +import socket +import click import requests from cachecontrol import CacheControl from cachecontrol.heuristics import ExpiresAfter -import ldapserver from ldapserver import LDAPRequestHandler, rfc4518_stringprep -from ldapserver.dn import DN from ldapserver.exceptions import LDAPInvalidCredentials -from ldapserver.schema import RFC2307BIS_SUBSCHEMA, RFC2798_SUBSCHEMA, WILDCARD_VALUE +from ldapserver.schema import RFC2307BIS_SCHEMA, RFC2798_SCHEMA +from ldapserver.objects import SubschemaSubentry, WILDCARD_VALUE -memberOf = ldapserver.schema.AttributeType( - '1.2.840.113556.1.2.102', - name='memberOf', - desc='Group that the entry belongs to', - equality=ldapserver.schema.rfc4517.matching_rules.distinguishedNameMatch, - syntax=ldapserver.schema.rfc4517.syntaxes.DN(), - usage=ldapserver.schema.AttributeTypeUsage.dSAOperation -) +logger = logging.getLogger(__name__) -CUSTOM_SUBSCHEMA = RFC2307BIS_SUBSCHEMA.extend(RFC2798_SUBSCHEMA, attribute_types=[memberOf]) +CUSTOM_SCHEMA = (RFC2307BIS_SCHEMA|RFC2798_SCHEMA).extend(attribute_type_definitions=[ + # pylint: disable=line-too-long + "( 1.2.840.113556.1.2.102 NAME 'memberOf' DESC 'Group that the entry belongs to' EQUALITY distinguishedNameMatch SYNTAX 1.3.6.1.4.1.1466.115.121.1.12 USAGE dSAOperation )" +]) class UffdAPI: def __init__(self, baseurl, key, cache_ttl=60): @@ -31,23 +29,24 @@ class UffdAPI: self.session.headers['Authorization'] = 'Bearer '+self.key def get(self, endpoint, **kwargs): - resp = self.session.get(self.baseurl + '/' + endpoint, params=kwargs) - assert(resp.ok) + resp = self.session.get(self.baseurl + endpoint, params=kwargs) + resp.raise_for_status() return resp.json() def post(self, endpoint, **kwargs): - resp = self.session.post(self.baseurl + '/' + endpoint, data=kwargs) - assert(resp.ok) + resp = self.session.post(self.baseurl + endpoint, data=kwargs) + resp.raise_for_status() return resp.json() + # pylint: disable=invalid-name,redefined-builtin def get_users(self, id=None, loginname=None, group=None): - return self.get('getusers', id=id, loginname=loginname, group=group) + return self.get('/api/v1/getusers', id=id, loginname=loginname, group=group) def get_groups(self, id=None, name=None, member=None): - return self.get('getgroups', id=id, name=name, member=member) + return self.get('/api/v1/getgroups', id=id, name=name, member=member) def check_password(self, loginname, password): - return self.api.post('checkpassword', loginname=loginname, password=password) + return self.post('/api/v1/checkpassword', loginname=loginname, password=password) def normalize_user_loginname(loginname): # The equality matching rule for uid is caseIgnoreMatch. It prepares @@ -80,23 +79,19 @@ def normalize_group_name(name): # See https://git.cccv.de/uffd/uffd/-/issues/127 return normalize_user_loginname(name) -class RequestHandler(LDAPRequestHandler): - subschema = CUSTOM_SUBSCHEMA +class UffdLDAPRequestHandler(LDAPRequestHandler): + subschema = SubschemaSubentry(CUSTOM_SCHEMA, 'cn=Subschema') # Overwritten before use api = None dn_base = None - bind_dn = None - bind_password = None - - def setup(self): - super().setup() + bind_password = None # if None anonymous reads are allowed def do_bind_simple_authenticated(self, dn, password): - dn = DN.from_str(dn) - if dn == self.bind_dn and password == self.bind_password: + dn = self.subschema.DN.from_str(dn) + if dn == self.subschema.DN('cn=service,ou=system') + self.dn_base and password == self.bind_password: return True - if not dn.is_direct_child_of(DN('ou=users') + self.dn_base) or len(dn[0]) != 1 or dn[0][0].attribute != 'uid': + if not dn.is_direct_child_of(self.subschema.DN('ou=users') + self.dn_base) or len(dn[0]) != 1 or dn[0][0].attribute != 'uid': raise LDAPInvalidCredentials() if self.api.check_password(loginname=dn[0][0].value, password=password): return True @@ -112,44 +107,44 @@ class RequestHandler(LDAPRequestHandler): raise LDAPInvalidCredentials() return user - def do_search(self, baseobj, scope, filter): - yield from super().do_search(baseobj, scope, filter) - if self.bind_object: + def do_search(self, baseobj, scope, filterobj): + yield from super().do_search(baseobj, scope, filterobj) + if self.bind_object or self.bind_password is None: yield from self.do_search_static() - yield from self.do_search_users(baseobj, scope, filter) - yield from self.do_search_groups(baseobj, scope, filter) + yield from self.do_search_users(baseobj, scope, filterobj) + yield from self.do_search_groups(baseobj, scope, filterobj) def do_search_static(self): base_attrs = { 'objectClass': ['top', 'dcObject', 'organization'], 'structuralObjectClass': ['organization'], } - for rdnassertion in self.dn_base[0]: + for rdnassertion in self.dn_base[0]: # pylint: disable=unsubscriptable-object base_attrs[rdnassertion.attribute] = [rdnassertion.value] yield self.subschema.Object(self.dn_base, **base_attrs) - yield self.subschema.Object(DN('ou=users') + self.dn_base, + yield self.subschema.Object(self.subschema.DN('ou=users') + self.dn_base, ou=['users'], objectClass=['top', 'organizationalUnit'], structuralObjectClass=['organizationalUnit'], ) - yield self.subschema.Object(DN('ou=groups') + self.dn_base, + yield self.subschema.Object(self.subschema.DN('ou=groups') + self.dn_base, ou=['groups'], objectClass=['top', 'organizationalUnit'], structuralObjectClass=['organizationalUnit'], ) - yield self.subschema.Object(DN('ou=system') + self.dn_base, + yield self.subschema.Object(self.subschema.DN('ou=system') + self.dn_base, ou=['system'], objectClass=['top', 'organizationalUnit'], structuralObjectClass=['organizationalUnit'], ) - yield self.subschema.Object(DN('cn=service,ou=system') + self.dn_base, + yield self.subschema.Object(self.subschema.DN('cn=service,ou=system') + self.dn_base, cn=['service'], objectClass=['top', 'organizationalRole', 'simpleSecurityObject'], structuralObjectClass=['organizationalRole'], ) - def do_search_users(self, baseobj, scope, filter): - template = self.subschema.ObjectTemplate(DN(self.dn_base, ou='users'), 'uid', + def do_search_users(self, baseobj, scope, filterobj): + template = self.subschema.ObjectTemplate(self.subschema.DN(self.dn_base, ou='users'), 'uid', structuralObjectClass=['inetorgperson'], objectClass=['top', 'inetorgperson', 'organizationalperson', 'person', 'posixaccount'], cn=[WILDCARD_VALUE], @@ -162,9 +157,9 @@ class RequestHandler(LDAPRequestHandler): uidNumber=[WILDCARD_VALUE], memberOf=[WILDCARD_VALUE], ) - if not template.match_search(baseobj, scope, filter): + if not template.match_search(baseobj, scope, filterobj): return - constraints = template.extract_search_constraints(baseobj, scope, filter) + constraints = template.extract_search_constraints(baseobj, scope, filterobj) request_params = {} if 'uid' in constraints: request_params = {'loginname': normalize_user_loginname(constraints['uid'][0])} @@ -172,7 +167,7 @@ class RequestHandler(LDAPRequestHandler): request_params = {'id': constraints['uidnumber'][0]} elif 'memberof' in constraints: for value in constraints['memberof']: - if value.is_direct_child_of(DN(self.dn_base, ou='groups')) and value.object_attribute == 'cn': + if value.is_direct_child_of(self.subschema.DN(self.dn_base, ou='groups')) and value.object_attribute == 'cn': request_params = {'group': normalize_group_name(value.object_value)} break for user in self.api.get_users(**request_params): @@ -184,11 +179,11 @@ class RequestHandler(LDAPRequestHandler): mail=[user['email']], uid=[user['loginname']], uidNumber=[user['id']], - memberOf=[DN(DN(self.dn_base, ou='groups'), cn=group) for group in user['groups']], + memberOf=[self.subschema.DN(self.subschema.DN(self.dn_base, ou='groups'), cn=group) for group in user['groups']], ) - def do_search_groups(self, baseobj, scope, filter): - template = self.subschema.ObjectTemplate(DN(self.dn_base, ou='groups'), 'cn', + def do_search_groups(self, baseobj, scope, filterobj): + template = self.subschema.ObjectTemplate(self.subschema.DN(self.dn_base, ou='groups'), 'cn', structuralObjectClass=['groupOfUniqueNames'], objectClass=['top', 'groupOfUniqueNames', 'posixGroup'], cn=[WILDCARD_VALUE], @@ -196,9 +191,9 @@ class RequestHandler(LDAPRequestHandler): gidNumber=[WILDCARD_VALUE], uniqueMember=[WILDCARD_VALUE], ) - if not template.match_search(baseobj, scope, filter): + if not template.match_search(baseobj, scope, filterobj): return - constraints = template.extract_search_constraints(baseobj, scope, filter) + constraints = template.extract_search_constraints(baseobj, scope, filterobj) request_params = {} if 'cn' in constraints: request_params = {'name': normalize_group_name(constraints['cn'][0])} @@ -206,47 +201,102 @@ class RequestHandler(LDAPRequestHandler): request_params = {'id': constraints['gidnumber'][0]} elif 'uniquemember' in constraints: for value in constraints['uniquemember']: - if value.is_direct_child_of(DN(self.dn_base, ou='users')) and value.object_attribute == 'uid': + if value.is_direct_child_of(self.subschema.DN(self.dn_base, ou='users')) and value.object_attribute == 'uid': request_params = {'member': normalize_user_loginname(value.object_value)} break for group in self.api.get_groups(**request_params): yield template.create_object(group['name'], cn=[group['name']], gidNumber=[group['id']], - uniqueMember=[DN(DN(self.dn_base, ou='users'), uid=user) for user in group['members']], + uniqueMember=[self.subschema.DN(self.subschema.DN(self.dn_base, ou='users'), uid=user) for user in group['members']], ) -def main(config): - dn_base = DN.from_str(config['dn_base']) - api = UffdAPI(config['api_baseurl'], config['api_key'], config.get('cache_ttl', 60)) +def make_requesthandler(api, dn_base, bind_password=None): + class RequestHandler(UffdLDAPRequestHandler): + pass + dn_base = RequestHandler.subschema.DN.from_str(dn_base) + RequestHandler.api = api + RequestHandler.dn_base = dn_base + RequestHandler.bind_password = bind_password.encode() if bind_password else None + return RequestHandler - subschema = RFC2307BIS_SUBSCHEMA +class FilenoUnixStreamServer(socketserver.UnixStreamServer): + def __init__(self, fd, RequestHandlerClass, bind_and_activate=True): + self.server_fd = fd + super().__init__(None, RequestHandlerClass, bind_and_activate=bind_and_activate) - class CustomRequestHandler(RequestHandler): - pass + def server_bind(self): + self.socket.close() # UnixStreamServer.__init__ creates an unbound socket + self.socket = socket.fromfd(self.server_fd, socket.AF_UNIX, socket.SOCK_STREAM) + self.server_address = self.socket.getsockname() + +class ThreadingFilenoUnixStreamServer(socketserver.ThreadingMixIn, FilenoUnixStreamServer): + pass + +def cleanup_unix_socket(path): + if not os.path.exists(path): + return + conn = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + try: + conn.connect(path) + except ConnectionRefusedError: + os.remove(path) + conn.close() + +def parse_network_address(addr): + port = '389' + if addr.startswith('['): + addr, remainder = addr[1:].split(']', 1) + if remainder.startswith(':'): + port = remainder[1:] + elif ':' in addr: + addr, port = addr.split(':') + return addr, port + +class StdoutFilter(logging.Filter): + def filter(self, record): + return record.levelno <= logging.INFO + +# pylint: disable=line-too-long +@click.command(help='LDAP proxy for integrating LDAP service with uffd SSO. Supports user and group searches and as well as binds with user passwords.') +@click.option('--socket-address', help='Host and port "ip:port" to listen on') +@click.option('--socket-path', type=click.Path(), help='Path for UNIX domain socket') +@click.option('--socket-fd', type=int, help='Use fd number as server socket (alternative to --socket-path)') +@click.option('--api-url', required=True, help='Uffd base URL without API prefix or trailing slash (e.g. https://example.com)') +@click.option('--api-key', required=True, help='API secret, do not set this on the command-line, use environment variable SERVER_API_KEY instead') +@click.option('--cache-ttl', default=60, help='Time-to-live for API response caching in seconds') +@click.option('--base-dn', required=True, help='Base DN for user, group and system objects. E.g. "dc=example,dc=com"') +@click.option('--bind-password', help='Authentication password for the service connection to LDAP. Bind DN is always "cn=service,ou=system,BASEDN". If set, anonymous access is disabled.') +def main(socket_address, socket_path, socket_fd, api_url, api_key, cache_ttl, base_dn, bind_password): + # pylint: disable=too-many-locals + if (socket_address is not None) \ + + (socket_path is not None) \ + + (socket_fd is not None) != 1: + raise click.ClickException('Either --socket-address, --socket-path or --socket-fd must be specified') - CustomRequestHandler.api = api - CustomRequestHandler.dn_base = dn_base - CustomRequestHandler.bind_dn = DN('cn=service,ou=system') + dn_base - CustomRequestHandler.bind_password = config['bind_password'].encode() + stdout_handler = logging.StreamHandler(sys.stdout) + stdout_handler.setLevel(logging.INFO) + stdout_handler.addFilter(StdoutFilter()) + stderr_handler = logging.StreamHandler(sys.stderr) + stderr_handler.setLevel(logging.WARNING) + root_logger = logging.getLogger() + root_logger.setLevel(logging.INFO) + root_logger.addHandler(stdout_handler) + root_logger.addHandler(stderr_handler) - if config['listen_addr'].startswith('unix:'): - socketserver.ThreadingUnixStreamServer(config['listen_addr'][5:], CustomRequestHandler).serve_forever() + api = UffdAPI(api_url, api_key, cache_ttl) + RequestHandler = make_requesthandler(api, base_dn, bind_password) + if socket_address is not None: + host, port = parse_network_address(socket_address) + server = socketserver.ThreadingTCPServer((host, int(port)), RequestHandler) + elif socket_path is not None: + cleanup_unix_socket(socket_path) + server = socketserver.ThreadingUnixStreamServer(socket_path, RequestHandler) else: - addr = config['listen_addr'] - port = '389' - if addr.startswith('['): - addr, remainder = addr[1:].split(']', 1) - if remainder.startswith(':'): - port = remainder[1:] - elif ':' in addr: - addr, port = addr.split(':') - socketserver.ThreadingTCPServer((addr, int(port)), CustomRequestHandler).serve_forever() + server = ThreadingFilenoUnixStreamServer(socket_fd, RequestHandler) + server.serve_forever() if __name__ == '__main__': - if len(sys.argv) != 2: - print('usage: server.py CONFIG_PATH') - exit(1) - with open(sys.argv[1], 'r') as f: - config = json.load(f) - main(config) + # Pylint does not seem to understand the click's decorators + # pylint: disable=unexpected-keyword-arg,no-value-for-parameter + main(auto_envvar_prefix='SERVER') -- GitLab