From 0cb7c3f18cfb209f9c86baececd8db54d3708a53 Mon Sep 17 00:00:00 2001
From: Julian Rother <julian@jrother.eu>
Date: Sat, 6 Mar 2021 23:56:21 +0100
Subject: [PATCH] Restructured server code

---
 db.py     | 116 +++++++++++++++++++++++++++---------------------------
 ldap.py   |  14 +++++++
 server.py | 113 +++++++++++++++++++++++++++++++++++-----------------
 3 files changed, 148 insertions(+), 95 deletions(-)

diff --git a/db.py b/db.py
index c6a5024..36d131f 100644
--- a/db.py
+++ b/db.py
@@ -5,7 +5,8 @@ from sqlalchemy.orm import sessionmaker
 from sqlalchemy.ext.declarative import declarative_base
 from sqlalchemy.ext.hybrid import hybrid_property
 from ldap import SearchScope, FilterAnd, FilterOr, FilterNot, FilterEqual, FilterPresent
-from server import Server as LDAPServer
+from server import LDAPRequestHandler
+from socketserver import ForkingTCPServer
 from dn import parse_dn, build_dn
 
 Base = declarative_base()
@@ -179,35 +180,8 @@ engine = create_engine('sqlite:///db.sqlite', echo=True)
 Session = sessionmaker(bind=engine)
 session = Session()
 
-class LDAPViewMixin:
-	ldap_attributes = {}
-	ldap_objectclasses = [b'top']
-	ldap_rdn_attribute = 'uid'
-	ldap_dn_base = ''
-
-	@classmethod
-	def ldap_search(cls, base, scope, filter_expr, conn):
-		evaluator = SQLSearchEvaluator(cls, session, attributes=cls.ldap_attributes,
-			objectclasses=cls.ldap_objectclasses, rdn_attr=cls.ldap_rdn_attribute,
-			dn_base=cls.ldap_dn_base)
-		return evaluator(base, scope, filter_expr)
-
-class User(Base, LDAPViewMixin):
+class User(Base):
 	__tablename__ = 'users'
-	ldap_attributes = {
-		'cn': 'displayname',
-		'displayname': 'displayname',
-		'gidnumber': 'ldap_gid',
-		'givenname': 'displayname',
-		'homedirectory': 'homedirectory',
-		'mail': 'email',
-		'sn': 'ldap_sn',
-		'uid': 'loginname',
-		'uidnumber': 'id',
-	}
-	ldap_objectclasses = [b'top', b'inetOrgPerson', b'organizationalPerson', b'person', b'posixAccount']
-	ldap_rdn_attribute = 'uid'
-	ldap_dn_base = 'ou=users,dc=example,dc=com'
 
 	id = Column(Integer, primary_key=True)
 	loginname = Column(String, unique=True, nullable=False)
@@ -230,16 +204,8 @@ class User(Base, LDAPViewMixin):
 	def check_password(self, password):
 		return self.pwhash is not None and crypt(password, self.pwhash) == self.pwhash
 
-class Group(Base, LDAPViewMixin):
+class Group(Base):
 	__tablename__ = 'groups'
-	ldap_attributes = {
-		'cn': 'name',
-		'description': 'description',
-		'gidnumber': 'id',
-	}
-	ldap_objectclasses = [b'top', b'posixGroup', b'groupOfUniqueNames']
-	ldap_rdn_attribute = 'cn'
-	ldap_dn_base = 'ou=groups,dc=example,dc=com'
 
 	id = Column(Integer, primary_key=True)
 	name = Column(String, unique=True, nullable=False)
@@ -247,23 +213,57 @@ class Group(Base, LDAPViewMixin):
 
 Base.metadata.create_all(engine)
 
-ldap_server = LDAPServer()
-ldap_server.search_handler(User.ldap_search)
-ldap_server.search_handler(Group.ldap_search)
-
-@ldap_server.bind_handler
-def ldap_bind(name, password, conn):
-	try:
-		password = password.decode()
-	except UnicodeDecodeError:
-		return False
-	evaluator = SQLSearchEvaluator(User, session, attributes=User.ldap_attributes,
-		objectclasses=User.ldap_objectclasses, rdn_attr=User.ldap_rdn_attribute,
-		dn_base=User.ldap_dn_base)
-	res = session.query(User).filter(evaluator.filter_dn(name, SearchScope.baseObject)).one_or_none()
-	if res:
-		return res.check_password(password)
-	return False
-
-ldap_server.run('127.0.0.1', 1337)
-
+class RequestHandler(LDAPRequestHandler):
+	def do_bind(self, name, password):
+		if not name and not password:
+			return None
+		try:
+			password = password.decode()
+		except UnicodeDecodeError:
+			raise LDAPInvalidCredentials()
+		try:
+			evaluator = SQLSearchEvaluator(User, session, attributes=User.ldap_attributes,
+				objectclasses=User.ldap_objectclasses, rdn_attr=User.ldap_rdn_attribute,
+				dn_base=User.ldap_dn_base)
+		except ValueError:
+			raise LDAPInvalidCredentials()
+		user = session.query(User).filter(evaluator.filter_dn(name, SearchScope.baseObject)).one_or_none()
+		if user is None or not user.check_password(password):
+			raise LDAPInvalidCredentials()
+		return user
+
+	def do_search(self, baseobj, scope, filter):
+		# User
+		ldap_attributes = {
+			'cn': 'displayname',
+			'displayname': 'displayname',
+			'gidnumber': 'ldap_gid',
+			'givenname': 'displayname',
+			'homedirectory': 'homedirectory',
+			'mail': 'email',
+			'sn': 'ldap_sn',
+			'uid': 'loginname',
+			'uidnumber': 'id',
+		}
+		ldap_objectclasses = [b'top', b'inetOrgPerson', b'organizationalPerson', b'person', b'posixAccount']
+		ldap_rdn_attribute = 'uid'
+		ldap_dn_base = 'ou=users,dc=example,dc=com'
+		evaluator = SQLSearchEvaluator(User, session, attributes=ldap_attributes,
+			objectclasses=ldap_objectclasses, rdn_attr=ldap_rdn_attribute,
+			dn_base=ldap_dn_base)
+		yield from evaluator(baseobj, scope, filter)
+		# Group
+		ldap_attributes = {
+			'cn': 'name',
+			'description': 'description',
+			'gidnumber': 'id',
+		}
+		ldap_objectclasses = [b'top', b'posixGroup', b'groupOfUniqueNames']
+		ldap_rdn_attribute = 'cn'
+		ldap_dn_base = 'ou=groups,dc=example,dc=com'
+		evaluator = SQLSearchEvaluator(Group, session, attributes=ldap_attributes,
+			objectclasses=ldap_objectclasses, rdn_attr=ldap_rdn_attribute,
+			dn_base=ldap_dn_base)
+		yield from evaluator(baseobj, scope, filter)
+
+ForkingTCPServer(('127.0.0.1', 1337), RequestHandler).serve_forever()
diff --git a/ldap.py b/ldap.py
index a5145f7..eadc75e 100644
--- a/ldap.py
+++ b/ldap.py
@@ -636,6 +636,20 @@ class ShallowLDAPMessage(Sequence):
 		(ShallowProtocolOp, 'protocolOp', None, False)
 	]
 
+# Extended Operation Values
+
+class PasswdModifyRequestValue(Sequence):
+	sequence_fields = [
+		(retag(LDAPString, (2, False, 0)), 'userIdentity', None, True),
+		(retag(OctetString, (2, False, 1)), 'oldPasswd', None, True),
+		(retag(OctetString, (2, False, 2)), 'newPasswd', None, True),
+	]
+
+class PasswdModifyResponseValue(Sequence):
+	sequence_fields = [
+		(retag(OctetString, (2, False, 0)), 'genPasswd', None, True),
+	]
+
 bind1 = b'0\x0c\x02\x01\x01`\x07\x02\x01\x03\x04\x00\x80\x00'
 bind2 = b'0\x1c\x02\x01\x01`\x17\x02\x01\x03\x04\nuid=foobar\x80\x06foobar'
 search1 = b'0?\x02\x01\x02c:\x04\x1aou=users,dc=example,dc=com\n\x01\x01\n\x01\x00\x02\x01\x00\x02\x01\x00\x01\x01\x00\x87\x0bobjectclass0\x00'
diff --git a/server.py b/server.py
index 69af473..00a2762 100644
--- a/server.py
+++ b/server.py
@@ -1,40 +1,88 @@
 import traceback
-from socketserver import ForkingTCPServer, BaseRequestHandler
+from socketserver import BaseRequestHandler
 
-from ldap import LDAPMessage, ShallowLDAPMessage, BindRequest, BindResponse, SearchRequest, SearchResultEntry, PartialAttribute, SearchResultDone, UnbindRequest, LDAPResultCode, IncompleteBERError, SimpleAuthentication, ModifyRequest, ModifyResponse, AddRequest, AddResponse, DelRequest, DelResponse, ModifyDNRequest, ModifyDNResponse, CompareRequest, CompareResponse, AbandonRequest, ExtendedRequest, ExtendedResponse
+from ldap import LDAPMessage, ShallowLDAPMessage, BindRequest, BindResponse, SearchRequest, SearchResultEntry, PartialAttribute, SearchResultDone, UnbindRequest, LDAPResultCode, IncompleteBERError, SimpleAuthentication, ModifyRequest, ModifyResponse, AddRequest, AddResponse, DelRequest, DelResponse, ModifyDNRequest, ModifyDNResponse, CompareRequest, CompareResponse, AbandonRequest, ExtendedRequest, ExtendedResponse, PasswdModifyRequestValue, PasswdModifyResponseValue
 
-class Handler(BaseRequestHandler):
-	ldap_server = None
+class LDAPError(Exception):
+	def __init__(self, code=LDAPResultCode.other, message=''):
+		self.code = code
+		self.message = message
 
-	def setup(self):
-		self.bind_dn = b''
-		self.keep_running = True
+class LDAPOperationsError(LDAPError):
+	def __init__(self, message=''):
+		super().__init__(LDAPResultCode.operationsError, message)
+
+class LDAPProtocolError(LDAPError):
+	def __init__(self, message=''):
+		super().__init__(LDAPResultCode.protocolError, message)
+
+class LDAPInvalidCredentials(LDAPError):
+	def __init__(self, message=''):
+		super().__init__(LDAPResultCode.invalidCredentials, message)
+
+class LDAPInsufficientAccessRights(LDAPError):
+	def __init__(self, message=''):
+		super().__init__(LDAPResultCode.insufficientAccessRights, message)
 
+class LDAPUnwillingToPerform(LDAPError):
+	def __init__(self, message=''):
+		super().__init__(LDAPResultCode.unwillingToPerform, message)
+
+class LDAPAuthMethodNotSupported(LDAPError):
+	def __init__(self, message=''):
+		super().__init__(LDAPResultCode.authMethodNotSupported, message)
+
+class LDAPRequestHandler(BaseRequestHandler):
 	def handle_bind(self, req):
-		if not isinstance(req.protocolOp.authentication, SimpleAuthentication):
-			self.send_msg(LDAPMessage(req.messageID, BindResponse(LDAPResultCode.authMethodNotSupported)))
-		name = req.protocolOp.name
-		password = req.protocolOp.authentication.password
-		for func in self.ldap_server.bind_handlers:
-			if func(name, password, self):
-				self.bind_dn = name
-				self.send_msg(LDAPMessage(req.messageID, BindResponse(LDAPResultCode.success)))
-				return
-		self.send_msg(LDAPMessage(req.messageID, BindResponse(LDAPResultCode.invalidCredentials)))
+		op = req.protocolOp
+		if op.version != 3:
+			raise LDAPProtocolError('Unsupported protocol version')
+		if isinstance(op.authentication, SimpleAuthentication):
+			self.bind_object = self.do_bind(op.name, op.authentication.password)
+		else:
+			raise LDAPAuthMethodNotSupported()
+		self.send_msg(LDAPMessage(req.messageID, BindResponse(LDAPResultCode.success)))
+
+	def do_bind(self, user, password):
+		raise LDAPInvalidCredentials()
 
 	def handle_search(self, req):
 		search = req.protocolOp
-		entries = []
-		for func in self.ldap_server.search_handlers:
-			entries += func(search.baseObject, search.scope, search.filter, self)
-		for dn, attributes in entries:
+		for dn, attributes in self.do_search(search.baseObject, search.scope, search.filter):
 			attributes = [PartialAttribute(name, values) for name, values in attributes.items()]
 			self.send_msg(LDAPMessage(req.messageID, SearchResultEntry(dn, attributes)))
 		self.send_msg(LDAPMessage(req.messageID, SearchResultDone(LDAPResultCode.success)))
 
+	def do_search(self, baseobj, scope, filter):
+		yield from []
+
 	def handle_unbind(self, req):
 		self.keep_running = False
 
+	def handle_extended(self, req):
+		op = req.protocolOp
+		if op.requestName == '1.3.6.1.4.1.1466.20037':
+			# StartTLS (RFC 4511)
+			raise LDAPProtocolError()
+		elif op.requestName == '1.3.6.1.4.1.4203.1.11.1':
+			# Password Modify Extended Operation (RFC 3062)
+			newpw = None
+			if op.requestValue is None:
+				newpw = self.do_passwd()
+			else:
+				decoded, _ = PasswdModifyRequestValue.from_ber(op.requestValue)
+				newpw = self.do_passwd(decoded.userIdentity, decoded.oldPasswd, decoded.newPasswd)
+			if newpw is None:
+				self.send_msg(LDAPMessage(req.messageID, ExtendedResponse(LDAPResultCode.success)))
+			else:
+				encoded = PasswdModifyResponseValue.to_ber(PasswdModifyResponseValue(newpw))
+				self.send_msg(LDAPMessage(req.messageID, ExtendedResponse(LDAPResultCode.success, responseValue=encoded)))
+		else:
+			raise LDAPProtocolError()
+
+	def do_passwd(self, user=None, oldpasswd=None, newpasswd=None):
+		raise LDAPUnwillingToPerform('Password change is not supported')
+
 	def handle_message(self, data):
 		handlers = {
 			BindRequest: (self.handle_bind, BindResponse),
@@ -46,7 +94,7 @@ class Handler(BaseRequestHandler):
 			ModifyDNRequest: (None, ModifyDNResponse),
 			CompareRequest: (None, CompareResponse),
 			AbandonRequest: (None, None),
-			ExtendedRequest: (None, ExtendedResponse), # TODO
+			ExtendedRequest: (self.handle_extended, ExtendedResponse),
 		}
 		shallowmsg, rest = ShallowLDAPMessage.from_ber(data)
 		if shallowmsg.protocolOp is None:
@@ -67,6 +115,10 @@ class Handler(BaseRequestHandler):
 				func(msg)
 			elif errfunc:
 				self.send_msg(LDAPMessage(shallowmsg.messageID, errfunc(LDAPResultCode.insufficientAccessRights)))
+		except LDAPError as e:
+			if errfunc:
+				self.send_msg(LDAPMessage(shallowmsg.messageID, errfunc(e.code, diagnosticMessage=e.message)))
+			return rest
 		except Exception as e:
 			if errfunc:
 				self.send_msg(LDAPMessage(shallowmsg.messageID, errfunc(LDAPResultCode.other)))
@@ -79,6 +131,8 @@ class Handler(BaseRequestHandler):
 		self.request.sendall(LDAPMessage.to_ber(msg))
 
 	def handle(self):
+		self.keep_running = True
+		self.bind_object = None
 		data = b''
 		while self.keep_running:
 			try:
@@ -89,18 +143,3 @@ class Handler(BaseRequestHandler):
 					return
 				data += chunk
 
-class Server:
-	def __init__(self):
-		self.bind_handlers = []
-		self.search_handlers = []
-
-	def bind_handler(self, func):
-		self.bind_handlers.append(func)
-
-	def search_handler(self, func):
-		self.search_handlers.append(func)
-
-	def run(self, host='127.0.0.1', port=1337):
-		class BoundHandler(Handler):
-			ldap_server = self
-		ForkingTCPServer((host, port), BoundHandler).serve_forever()
-- 
GitLab