From c8bab6e70337ccde280a1ed911661de6b6d9d3dd Mon Sep 17 00:00:00 2001
From: Julian Rother <julian@jrother.eu>
Date: Wed, 17 Nov 2021 14:26:06 +0100
Subject: [PATCH] Proper CLI and update to ldapserver 0.0.1.dev5

---
 requirements.txt |   2 +-
 server.py        | 208 +++++++++++++++++++++++++++++------------------
 2 files changed, 130 insertions(+), 80 deletions(-)

diff --git a/requirements.txt b/requirements.txt
index 13331b7..3021879 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,5 +1,5 @@
 --extra-index-url https://git.cccv.de/api/v4/projects/220/packages/pypi/simple
-ldapserver==0.0.1.dev4
+ldapserver==0.0.1.dev5
 
 requests==2.*
 CacheControl
diff --git a/server.py b/server.py
index 5fccb24..b6ff7ac 100644
--- a/server.py
+++ b/server.py
@@ -1,27 +1,25 @@
+import os
 import sys
-import json
 import socketserver
+import logging
+import socket
 
+import click
 import requests
 from cachecontrol import CacheControl
 from cachecontrol.heuristics import ExpiresAfter
 
-import ldapserver
 from ldapserver import LDAPRequestHandler, rfc4518_stringprep
-from ldapserver.dn import DN
 from ldapserver.exceptions import LDAPInvalidCredentials
-from ldapserver.schema import RFC2307BIS_SUBSCHEMA, RFC2798_SUBSCHEMA, WILDCARD_VALUE
+from ldapserver.schema import RFC2307BIS_SCHEMA, RFC2798_SCHEMA
+from ldapserver.objects import SubschemaSubentry, WILDCARD_VALUE
 
-memberOf = ldapserver.schema.AttributeType(
-	'1.2.840.113556.1.2.102',
-	name='memberOf',
-	desc='Group that the entry belongs to',
-	equality=ldapserver.schema.rfc4517.matching_rules.distinguishedNameMatch,
-	syntax=ldapserver.schema.rfc4517.syntaxes.DN(),
-	usage=ldapserver.schema.AttributeTypeUsage.dSAOperation
-)
+logger = logging.getLogger(__name__)
 
-CUSTOM_SUBSCHEMA = RFC2307BIS_SUBSCHEMA.extend(RFC2798_SUBSCHEMA, attribute_types=[memberOf])
+CUSTOM_SCHEMA = (RFC2307BIS_SCHEMA|RFC2798_SCHEMA).extend(attribute_type_definitions=[
+	# pylint: disable=line-too-long
+	"( 1.2.840.113556.1.2.102 NAME 'memberOf' DESC 'Group that the entry belongs to' EQUALITY distinguishedNameMatch SYNTAX 1.3.6.1.4.1.1466.115.121.1.12 USAGE dSAOperation )"
+])
 
 class UffdAPI:
 	def __init__(self, baseurl, key, cache_ttl=60):
@@ -31,23 +29,24 @@ class UffdAPI:
 		self.session.headers['Authorization'] = 'Bearer '+self.key
 
 	def get(self, endpoint, **kwargs):
-		resp = self.session.get(self.baseurl + '/' + endpoint, params=kwargs)
-		assert(resp.ok)
+		resp = self.session.get(self.baseurl + endpoint, params=kwargs)
+		resp.raise_for_status()
 		return resp.json()
 
 	def post(self, endpoint, **kwargs):
-		resp = self.session.post(self.baseurl + '/' + endpoint, data=kwargs)
-		assert(resp.ok)
+		resp = self.session.post(self.baseurl + endpoint, data=kwargs)
+		resp.raise_for_status()
 		return resp.json()
 
+	# pylint: disable=invalid-name,redefined-builtin
 	def get_users(self, id=None, loginname=None, group=None):
-		return self.get('getusers', id=id, loginname=loginname, group=group)
+		return self.get('/api/v1/getusers', id=id, loginname=loginname, group=group)
 
 	def get_groups(self, id=None, name=None, member=None):
-		return self.get('getgroups', id=id, name=name, member=member)
+		return self.get('/api/v1/getgroups', id=id, name=name, member=member)
 
 	def check_password(self, loginname, password):
-		return self.api.post('checkpassword', loginname=loginname, password=password)
+		return self.post('/api/v1/checkpassword', loginname=loginname, password=password)
 
 def normalize_user_loginname(loginname):
 	# The equality matching rule for uid is caseIgnoreMatch. It prepares
@@ -80,23 +79,19 @@ def normalize_group_name(name):
 	# See https://git.cccv.de/uffd/uffd/-/issues/127
 	return normalize_user_loginname(name)
 
-class RequestHandler(LDAPRequestHandler):
-	subschema = CUSTOM_SUBSCHEMA
+class UffdLDAPRequestHandler(LDAPRequestHandler):
+	subschema = SubschemaSubentry(CUSTOM_SCHEMA, 'cn=Subschema')
 
 	# Overwritten before use
 	api = None
 	dn_base = None
-	bind_dn = None
-	bind_password = None
-
-	def setup(self):
-		super().setup()
+	bind_password = None # if None anonymous reads are allowed
 
 	def do_bind_simple_authenticated(self, dn, password):
-		dn = DN.from_str(dn)
-		if dn == self.bind_dn and password == self.bind_password:
+		dn = self.subschema.DN.from_str(dn)
+		if dn == self.subschema.DN('cn=service,ou=system') + self.dn_base and password == self.bind_password:
 			return True
-		if not dn.is_direct_child_of(DN('ou=users') + self.dn_base) or len(dn[0]) != 1 or dn[0][0].attribute != 'uid':
+		if not dn.is_direct_child_of(self.subschema.DN('ou=users') + self.dn_base) or len(dn[0]) != 1 or dn[0][0].attribute != 'uid':
 			raise LDAPInvalidCredentials()
 		if self.api.check_password(loginname=dn[0][0].value, password=password):
 			return True
@@ -112,44 +107,44 @@ class RequestHandler(LDAPRequestHandler):
 			raise LDAPInvalidCredentials()
 		return user
 
-	def do_search(self, baseobj, scope, filter):
-		yield from super().do_search(baseobj, scope, filter)
-		if self.bind_object:
+	def do_search(self, baseobj, scope, filterobj):
+		yield from super().do_search(baseobj, scope, filterobj)
+		if self.bind_object or self.bind_password is None:
 			yield from self.do_search_static()
-			yield from self.do_search_users(baseobj, scope, filter)
-			yield from self.do_search_groups(baseobj, scope, filter)
+			yield from self.do_search_users(baseobj, scope, filterobj)
+			yield from self.do_search_groups(baseobj, scope, filterobj)
 
 	def do_search_static(self):
 		base_attrs = {
 			'objectClass': ['top', 'dcObject', 'organization'],
 			'structuralObjectClass': ['organization'],
 		}
-		for rdnassertion in self.dn_base[0]:
+		for rdnassertion in self.dn_base[0]: # pylint: disable=unsubscriptable-object
 			base_attrs[rdnassertion.attribute] = [rdnassertion.value]
 		yield self.subschema.Object(self.dn_base, **base_attrs)
-		yield self.subschema.Object(DN('ou=users') + self.dn_base,
+		yield self.subschema.Object(self.subschema.DN('ou=users') + self.dn_base,
 			ou=['users'],
 			objectClass=['top', 'organizationalUnit'],
 			structuralObjectClass=['organizationalUnit'],
 		)
-		yield self.subschema.Object(DN('ou=groups') + self.dn_base,
+		yield self.subschema.Object(self.subschema.DN('ou=groups') + self.dn_base,
 			ou=['groups'],
 			objectClass=['top', 'organizationalUnit'],
 			structuralObjectClass=['organizationalUnit'],
 		)
-		yield self.subschema.Object(DN('ou=system') + self.dn_base,
+		yield self.subschema.Object(self.subschema.DN('ou=system') + self.dn_base,
 			ou=['system'],
 			objectClass=['top', 'organizationalUnit'],
 			structuralObjectClass=['organizationalUnit'],
 		)
-		yield self.subschema.Object(DN('cn=service,ou=system') + self.dn_base,
+		yield self.subschema.Object(self.subschema.DN('cn=service,ou=system') + self.dn_base,
 			cn=['service'],
 			objectClass=['top', 'organizationalRole', 'simpleSecurityObject'],
 			structuralObjectClass=['organizationalRole'],
 		)
 
-	def do_search_users(self, baseobj, scope, filter):
-		template = self.subschema.ObjectTemplate(DN(self.dn_base, ou='users'), 'uid',
+	def do_search_users(self, baseobj, scope, filterobj):
+		template = self.subschema.ObjectTemplate(self.subschema.DN(self.dn_base, ou='users'), 'uid',
 			structuralObjectClass=['inetorgperson'],
 			objectClass=['top', 'inetorgperson', 'organizationalperson', 'person', 'posixaccount'],
 			cn=[WILDCARD_VALUE],
@@ -162,9 +157,9 @@ class RequestHandler(LDAPRequestHandler):
 			uidNumber=[WILDCARD_VALUE],
 			memberOf=[WILDCARD_VALUE],
 		)
-		if not template.match_search(baseobj, scope, filter):
+		if not template.match_search(baseobj, scope, filterobj):
 			return
-		constraints = template.extract_search_constraints(baseobj, scope, filter)
+		constraints = template.extract_search_constraints(baseobj, scope, filterobj)
 		request_params = {}
 		if 'uid' in constraints:
 			request_params = {'loginname': normalize_user_loginname(constraints['uid'][0])}
@@ -172,7 +167,7 @@ class RequestHandler(LDAPRequestHandler):
 			request_params = {'id': constraints['uidnumber'][0]}
 		elif 'memberof' in constraints:
 			for value in constraints['memberof']:
-				if value.is_direct_child_of(DN(self.dn_base, ou='groups')) and value.object_attribute == 'cn':
+				if value.is_direct_child_of(self.subschema.DN(self.dn_base, ou='groups')) and value.object_attribute == 'cn':
 					request_params = {'group': normalize_group_name(value.object_value)}
 					break
 		for user in self.api.get_users(**request_params):
@@ -184,11 +179,11 @@ class RequestHandler(LDAPRequestHandler):
 				mail=[user['email']],
 				uid=[user['loginname']],
 				uidNumber=[user['id']],
-				memberOf=[DN(DN(self.dn_base, ou='groups'), cn=group) for group in user['groups']],
+				memberOf=[self.subschema.DN(self.subschema.DN(self.dn_base, ou='groups'), cn=group) for group in user['groups']],
 			)
 
-	def do_search_groups(self, baseobj, scope, filter):
-		template = self.subschema.ObjectTemplate(DN(self.dn_base, ou='groups'), 'cn',
+	def do_search_groups(self, baseobj, scope, filterobj):
+		template = self.subschema.ObjectTemplate(self.subschema.DN(self.dn_base, ou='groups'), 'cn',
 			structuralObjectClass=['groupOfUniqueNames'],
 			objectClass=['top', 'groupOfUniqueNames', 'posixGroup'],
 			cn=[WILDCARD_VALUE],
@@ -196,9 +191,9 @@ class RequestHandler(LDAPRequestHandler):
 			gidNumber=[WILDCARD_VALUE],
 			uniqueMember=[WILDCARD_VALUE],
 		)
-		if not template.match_search(baseobj, scope, filter):
+		if not template.match_search(baseobj, scope, filterobj):
 			return
-		constraints = template.extract_search_constraints(baseobj, scope, filter)
+		constraints = template.extract_search_constraints(baseobj, scope, filterobj)
 		request_params = {}
 		if 'cn' in constraints:
 			request_params = {'name': normalize_group_name(constraints['cn'][0])}
@@ -206,47 +201,102 @@ class RequestHandler(LDAPRequestHandler):
 			request_params = {'id': constraints['gidnumber'][0]}
 		elif 'uniquemember' in constraints:
 			for value in constraints['uniquemember']:
-				if value.is_direct_child_of(DN(self.dn_base, ou='users')) and value.object_attribute == 'uid':
+				if value.is_direct_child_of(self.subschema.DN(self.dn_base, ou='users')) and value.object_attribute == 'uid':
 					request_params = {'member': normalize_user_loginname(value.object_value)}
 					break
 		for group in self.api.get_groups(**request_params):
 			yield template.create_object(group['name'],
 				cn=[group['name']],
 				gidNumber=[group['id']],
-				uniqueMember=[DN(DN(self.dn_base, ou='users'), uid=user) for user in group['members']],
+				uniqueMember=[self.subschema.DN(self.subschema.DN(self.dn_base, ou='users'), uid=user) for user in group['members']],
 			)
 
-def main(config):
-	dn_base = DN.from_str(config['dn_base'])
-	api = UffdAPI(config['api_baseurl'], config['api_key'], config.get('cache_ttl', 60))
+def make_requesthandler(api, dn_base, bind_password=None):
+	class RequestHandler(UffdLDAPRequestHandler):
+		pass
+	dn_base = RequestHandler.subschema.DN.from_str(dn_base)
+	RequestHandler.api = api
+	RequestHandler.dn_base = dn_base
+	RequestHandler.bind_password = bind_password.encode() if bind_password else None
+	return RequestHandler
 
-	subschema = RFC2307BIS_SUBSCHEMA
+class FilenoUnixStreamServer(socketserver.UnixStreamServer):
+	def __init__(self, fd, RequestHandlerClass, bind_and_activate=True):
+		self.server_fd = fd
+		super().__init__(None, RequestHandlerClass, bind_and_activate=bind_and_activate)
 
-	class CustomRequestHandler(RequestHandler):
-		pass
+	def server_bind(self):
+		self.socket.close() # UnixStreamServer.__init__ creates an unbound socket
+		self.socket = socket.fromfd(self.server_fd, socket.AF_UNIX, socket.SOCK_STREAM)
+		self.server_address = self.socket.getsockname()
+
+class ThreadingFilenoUnixStreamServer(socketserver.ThreadingMixIn, FilenoUnixStreamServer):
+	pass
+
+def cleanup_unix_socket(path):
+	if not os.path.exists(path):
+		return
+	conn = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
+	try:
+		conn.connect(path)
+	except ConnectionRefusedError:
+		os.remove(path)
+	conn.close()
+
+def parse_network_address(addr):
+	port = '389'
+	if addr.startswith('['):
+		addr, remainder = addr[1:].split(']', 1)
+		if remainder.startswith(':'):
+			port = remainder[1:]
+	elif ':' in addr:
+		addr, port = addr.split(':')
+	return addr, port
+
+class StdoutFilter(logging.Filter):
+	def filter(self, record):
+		return record.levelno <= logging.INFO
+
+# pylint: disable=line-too-long
+@click.command(help='LDAP proxy for integrating LDAP service with uffd SSO. Supports user and group searches and as well as binds with user passwords.')
+@click.option('--socket-address', help='Host and port "ip:port" to listen on')
+@click.option('--socket-path', type=click.Path(), help='Path for UNIX domain socket')
+@click.option('--socket-fd', type=int, help='Use fd number as server socket (alternative to --socket-path)')
+@click.option('--api-url', required=True, help='Uffd base URL without API prefix or trailing slash (e.g. https://example.com)')
+@click.option('--api-key', required=True, help='API secret, do not set this on the command-line, use environment variable SERVER_API_KEY instead')
+@click.option('--cache-ttl', default=60, help='Time-to-live for API response caching in seconds')
+@click.option('--base-dn', required=True, help='Base DN for user, group and system objects. E.g. "dc=example,dc=com"')
+@click.option('--bind-password', help='Authentication password for the service connection to LDAP. Bind DN is always "cn=service,ou=system,BASEDN". If set, anonymous access is disabled.')
+def main(socket_address, socket_path, socket_fd, api_url, api_key, cache_ttl, base_dn, bind_password):
+	# pylint: disable=too-many-locals
+	if (socket_address is not None) \
+	   + (socket_path is not None) \
+	   + (socket_fd is not None) != 1:
+		raise click.ClickException('Either --socket-address, --socket-path or --socket-fd must be specified')
 
-	CustomRequestHandler.api = api
-	CustomRequestHandler.dn_base = dn_base
-	CustomRequestHandler.bind_dn = DN('cn=service,ou=system') + dn_base
-	CustomRequestHandler.bind_password = config['bind_password'].encode()
+	stdout_handler = logging.StreamHandler(sys.stdout)
+	stdout_handler.setLevel(logging.INFO)
+	stdout_handler.addFilter(StdoutFilter())
+	stderr_handler = logging.StreamHandler(sys.stderr)
+	stderr_handler.setLevel(logging.WARNING)
+	root_logger = logging.getLogger()
+	root_logger.setLevel(logging.INFO)
+	root_logger.addHandler(stdout_handler)
+	root_logger.addHandler(stderr_handler)
 
-	if config['listen_addr'].startswith('unix:'):
-		socketserver.ThreadingUnixStreamServer(config['listen_addr'][5:], CustomRequestHandler).serve_forever()
+	api = UffdAPI(api_url, api_key, cache_ttl)
+	RequestHandler = make_requesthandler(api, base_dn, bind_password)
+	if socket_address is not None:
+		host, port = parse_network_address(socket_address)
+		server = socketserver.ThreadingTCPServer((host, int(port)), RequestHandler)
+	elif socket_path is not None:
+		cleanup_unix_socket(socket_path)
+		server = socketserver.ThreadingUnixStreamServer(socket_path, RequestHandler)
 	else:
-		addr = config['listen_addr']
-		port = '389'
-		if addr.startswith('['):
-			addr, remainder = addr[1:].split(']', 1)
-			if remainder.startswith(':'):
-				port = remainder[1:]
-		elif ':' in addr:
-			addr, port = addr.split(':')
-		socketserver.ThreadingTCPServer((addr, int(port)), CustomRequestHandler).serve_forever()
+		server = ThreadingFilenoUnixStreamServer(socket_fd, RequestHandler)
+	server.serve_forever()
 
 if __name__ == '__main__':
-	if len(sys.argv) != 2:
-		print('usage: server.py CONFIG_PATH')
-		exit(1)
-	with open(sys.argv[1], 'r') as f:
-		config = json.load(f)
-	main(config)
+	# Pylint does not seem to understand the click's decorators
+	# pylint: disable=unexpected-keyword-arg,no-value-for-parameter
+	main(auto_envvar_prefix='SERVER')
-- 
GitLab