From 76eadd597bec20bd53866474a5d9b5b3d826d5a3 Mon Sep 17 00:00:00 2001
From: Julian Rother <julian@jrother.eu>
Date: Sat, 6 Mar 2021 18:20:13 +0100
Subject: [PATCH] Implemented stubs for all remaining standard operations

---
 db.py     |  2 +-
 ldap.py   | 75 +++++++++++++++++++++++++++++++++++++++++++++++++------
 server.py | 14 +++++++++--
 3 files changed, 81 insertions(+), 10 deletions(-)

diff --git a/db.py b/db.py
index 8c2791b..c6a5024 100644
--- a/db.py
+++ b/db.py
@@ -260,7 +260,7 @@ def ldap_bind(name, password, conn):
 	evaluator = SQLSearchEvaluator(User, session, attributes=User.ldap_attributes,
 		objectclasses=User.ldap_objectclasses, rdn_attr=User.ldap_rdn_attribute,
 		dn_base=User.ldap_dn_base)
-	res = session.query(User).filter(evaluator.filter_dn(name, SearchScope.baseObject)).one()
+	res = session.query(User).filter(evaluator.filter_dn(name, SearchScope.baseObject)).one_or_none()
 	if res:
 		return res.check_password(password)
 	return False
diff --git a/ldap.py b/ldap.py
index 5aceb6d..a60561d 100644
--- a/ldap.py
+++ b/ldap.py
@@ -19,10 +19,11 @@ def decode_ber(data):
 	index += 1
 	if not data[index] & 0x80:
 		length = data[index]
+		index += 1
 	elif data[index] == 0x80:
 		raise ValueError('Indefinite form not implemented')
 	elif data[index] == 0xff:
-		return ValueError('BER length invalid')
+		raise ValueError('BER length invalid')
 	else:
 		num = data[index] & ~0x80
 		index += 1
@@ -32,10 +33,10 @@ def decode_ber(data):
 		for octet in data[index:index + num]:
 			length = length << 8 | octet
 		index += num
-	if len(data) < index + length + 1:
-		raise IncompleteBERError(index + length + 1)
-	ber_content = data[index + 1: index + length + 1]
-	rest = data[index + length + 1:]
+	if len(data) < index + length:
+		raise IncompleteBERError(index + length)
+	ber_content = data[index: index + length]
+	rest = data[index + length:]
 	return BERObject((ber_class, ber_constructed, ber_type), ber_content), rest
 
 def decode_ber_integer(data):
@@ -433,6 +434,9 @@ class BindRequest(Sequence, ProtocolOp):
 class BindResponse(LDAPResult, ProtocolOp):
 	ber_tag = (1, True, 1)
 
+class UnbindRequest(Sequence, ProtocolOp):
+	ber_tag = (1, False, 2)
+
 class SearchRequest(Sequence, ProtocolOp):
 	ber_tag = (1, True, 3)
 	sequence_fields = [
@@ -460,8 +464,65 @@ class SearchResultEntry(Sequence, ProtocolOp):
 class SearchResultDone(LDAPResult, ProtocolOp):
 	ber_tag = (1, True, 5)
 
-class UnbindRequest(Sequence, ProtocolOp):
-	ber_tag = (1, False, 2)
+class ModifyRequest(Sequence, ProtocolOp):
+	ber_tag = (1, True, 6)
+	# stub
+
+class ModifyResponse(LDAPResult, ProtocolOp):
+	ber_tag = (1, True, 7)
+
+class AddRequest(Sequence, ProtocolOp):
+	ber_tag = (1, True, 8)
+	# stub
+
+class AddResponse(LDAPResult, ProtocolOp):
+	ber_tag = (1, True, 9)
+
+class DelRequest(Wrapper, ProtocolOp):
+	ber_tag = (1, False, 10)
+	wrapped_attribute = 'dn'
+	wrapped_type = LDAPString
+	wrapped_default = None
+
+class DelResponse(LDAPResult, ProtocolOp):
+	ber_tag = (1, True, 11)
+
+class ModifyDNRequest(Sequence, ProtocolOp):
+	ber_tag = (1, True, 12)
+	# stub
+
+class ModifyDNResponse(LDAPResult, ProtocolOp):
+	ber_tag = (1, True, 13)
+
+class CompareRequest(Sequence, ProtocolOp):
+	ber_tag = (1, True, 14)
+	# stub
+
+class CompareResponse(LDAPResult, ProtocolOp):
+	ber_tag = (1, True, 15)
+
+class AbandonRequest(Wrapper, ProtocolOp):
+	ber_tag = (1, False, 16)
+	wrapped_attribute = 'messageID'
+	wrapped_type = Integer
+	wrapped_default = None
+
+class ExtendedRequest(Sequence, ProtocolOp):
+	ber_tag = (1, True, 23)
+	# stub
+
+class ExtendedResponse(Sequence, ProtocolOp):
+	ber_tag = (1, True, 24)
+	sequence_fields = [
+		(LDAPResultCodeEnum, 'resultCode', None),
+		(LDAPString, 'matchedDN', ''),
+		(LDAPString, 'diagnosticMessage', ''),
+	]
+	# stub
+
+class IntermediateResponse(Sequence, ProtocolOp):
+	ber_tag = (1, True, 25)
+	# stub
 
 class LDAPMessage(Sequence):
 	sequence_fields = [
diff --git a/server.py b/server.py
index 96c52d0..69af473 100644
--- a/server.py
+++ b/server.py
@@ -1,7 +1,7 @@
 import traceback
 from socketserver import ForkingTCPServer, BaseRequestHandler
 
-from ldap import LDAPMessage, ShallowLDAPMessage, BindRequest, BindResponse, SearchRequest, SearchResultEntry, PartialAttribute, SearchResultDone, UnbindRequest, LDAPResultCode, IncompleteBERError, SimpleAuthentication
+from ldap import LDAPMessage, ShallowLDAPMessage, BindRequest, BindResponse, SearchRequest, SearchResultEntry, PartialAttribute, SearchResultDone, UnbindRequest, LDAPResultCode, IncompleteBERError, SimpleAuthentication, ModifyRequest, ModifyResponse, AddRequest, AddResponse, DelRequest, DelResponse, ModifyDNRequest, ModifyDNResponse, CompareRequest, CompareResponse, AbandonRequest, ExtendedRequest, ExtendedResponse
 
 class Handler(BaseRequestHandler):
 	ldap_server = None
@@ -19,6 +19,7 @@ class Handler(BaseRequestHandler):
 			if func(name, password, self):
 				self.bind_dn = name
 				self.send_msg(LDAPMessage(req.messageID, BindResponse(LDAPResultCode.success)))
+				return
 		self.send_msg(LDAPMessage(req.messageID, BindResponse(LDAPResultCode.invalidCredentials)))
 
 	def handle_search(self, req):
@@ -37,8 +38,15 @@ class Handler(BaseRequestHandler):
 	def handle_message(self, data):
 		handlers = {
 			BindRequest: (self.handle_bind, BindResponse),
-			SearchRequest: (self.handle_search, SearchResultDone),
 			UnbindRequest: (self.handle_unbind, None),
+			SearchRequest: (self.handle_search, SearchResultDone),
+			ModifyRequest: (None, ModifyResponse),
+			AddRequest: (None,  AddResponse),
+			DelRequest: (None, DelResponse),
+			ModifyDNRequest: (None, ModifyDNResponse),
+			CompareRequest: (None, CompareResponse),
+			AbandonRequest: (None, None),
+			ExtendedRequest: (None, ExtendedResponse), # TODO
 		}
 		shallowmsg, rest = ShallowLDAPMessage.from_ber(data)
 		if shallowmsg.protocolOp is None:
@@ -57,6 +65,8 @@ class Handler(BaseRequestHandler):
 		try:
 			if func:
 				func(msg)
+			elif errfunc:
+				self.send_msg(LDAPMessage(shallowmsg.messageID, errfunc(LDAPResultCode.insufficientAccessRights)))
 		except Exception as e:
 			if errfunc:
 				self.send_msg(LDAPMessage(shallowmsg.messageID, errfunc(LDAPResultCode.other)))
-- 
GitLab