From 551a997f52ee075e909ccb80d439b2eef3e32298 Mon Sep 17 00:00:00 2001
From: Julian Rother <julian@jrother.eu>
Date: Tue, 22 Feb 2022 16:44:15 +0100
Subject: [PATCH] Fix COMPARE request handling

The code for COMPRAE requests in LDAPRequestHandler was not updated when the
interface of Entry changed. For this reason, all COMPARE requests failed with
a TypeError exception and consequently result code "other".
---
 ldapserver/server.py | 16 +++++-----
 tests/test_server.py | 70 +++++++++++++++++++++++++++++++++++++++++++-
 2 files changed, 76 insertions(+), 10 deletions(-)

diff --git a/ldapserver/server.py b/ldapserver/server.py
index 98b2917..4157b0a 100644
--- a/ldapserver/server.py
+++ b/ldapserver/server.py
@@ -523,14 +523,12 @@ class LDAPRequestHandler(BaseLDAPRequestHandler):
 	def handle_compare(self, op, controls=None):
 		self.logger.info('COMPRAE request "%s" %s=%s', op.entry, op.ava.attributeDesc, repr(op.ava.assertionValue))
 		obj = self.do_compare(op.entry, op.ava.attributeDesc, op.ava.assertionValue)
-		if obj is None:
-			raise exceptions.LDAPNoSuchObject()
-		if obj.match_compare(op.ava.attributeDesc, op.ava.assertionValue):
-			self.logger.info('COMPRAE result TRUE')
-			return [ldap.CompareResponse(ldap.LDAPResultCode.compareTrue)]
-		else:
-			self.logger.info('COMPRAE result FALSE')
-			return [ldap.CompareResponse(ldap.LDAPResultCode.compareFalse)]
+		if obj is not None:
+			if obj.compare(op.entry, op.ava.attributeDesc, op.ava.assertionValue):
+				return [ldap.CompareResponse(ldap.LDAPResultCode.compareTrue)]
+			else:
+				return [ldap.CompareResponse(ldap.LDAPResultCode.compareFalse)]
+		raise exceptions.LDAPNoSuchObject()
 
 	def do_compare(self, dn, attribute, value):
 		'''Lookup object for COMPARE operation
@@ -550,7 +548,7 @@ class LDAPRequestHandler(BaseLDAPRequestHandler):
 		objs = self.do_search(dn, ldap.SearchScope.baseObject, ldap.FilterPresent(attribute='objectClass'))
 		for obj in objs:
 			try:
-				obj.compare(dn)
+				obj.compare(dn, attribute, value)
 				return obj
 			except exceptions.LDAPNoSuchObject:
 				pass
diff --git a/tests/test_server.py b/tests/test_server.py
index a471fcf..b7a908c 100644
--- a/tests/test_server.py
+++ b/tests/test_server.py
@@ -1,6 +1,6 @@
 import unittest
 
-from ldapserver import BaseLDAPRequestHandler, LDAPRequestHandler, ldap
+from ldapserver import BaseLDAPRequestHandler, LDAPRequestHandler, ldap, exceptions
 
 class MockConnection:
 	def __init__(self, data, chunksize):
@@ -169,3 +169,71 @@ class TestLDAPRequestHandler(unittest.TestCase):
 		self.assertEqual(ldap.ProtocolOp.to_ber(resps[0]), ldap.ProtocolOp.to_ber(ldap.SearchResultEntry('cn=Test1,dc=example,dc=com')))
 		self.assertEqual(ldap.ProtocolOp.to_ber(resps[1]), ldap.ProtocolOp.to_ber(ldap.SearchResultEntry('cn=Test2,dc=example,dc=com')))
 		self.assertEqual(ldap.ProtocolOp.to_ber(resps[2]), ldap.ProtocolOp.to_ber(ldap.SearchResultDone()))
+
+	def test_compare(self):
+		class MockObject:
+			def __init__(_self, result=None):
+				_self.result = result
+
+			def compare(_self, dn, attribute, value):
+				self.assertEqual(dn, 'cn=Test,dc=example,dc=com')
+				self.assertEqual(attribute, 'foo')
+				self.assertEqual(value, b'bar')
+				if isinstance(_self.result, Exception):
+					raise _self.result
+				return _self.result
+
+		class RequestHandler(LDAPRequestHandler):
+			def handle(self):
+				pass
+
+			def do_search(_self, base_obj, scope, filter_obj):
+				self.assertEqual(base_obj, 'cn=Test,dc=example,dc=com')
+				if _self.mode == 'true':
+					yield MockObject(exceptions.LDAPNoSuchObject())
+					yield MockObject(exceptions.LDAPNoSuchObject())
+					yield MockObject(True)
+					yield MockObject(exceptions.LDAPNoSuchObject())
+				elif _self.mode == 'false':
+					yield MockObject(exceptions.LDAPNoSuchObject())
+					yield MockObject(exceptions.LDAPNoSuchObject())
+					yield MockObject(False)
+					yield MockObject(exceptions.LDAPNoSuchObject())
+				elif _self.mode == 'empty':
+					pass
+				elif _self.mode == 'notfound':
+					yield MockObject(exceptions.LDAPNoSuchObject())
+					yield MockObject(exceptions.LDAPNoSuchObject())
+				elif _self.mode == 'error':
+					yield MockObject(exceptions.LDAPNoSuchObject())
+					yield MockObject(exceptions.LDAPUndefinedAttributeType())
+					yield MockObject(exceptions.LDAPNoSuchObject())
+
+		handler = RequestHandler(None, None, None)
+		handler.mode = 'true'
+		resps = list(handler.handle_compare(ldap.CompareRequest('cn=Test,dc=example,dc=com', ava=ldap.AttributeValueAssertion('foo', b'bar'))))
+		self.assertEqual(len(resps), 1)
+		self.assertEqual(ldap.ProtocolOp.to_ber(resps[0]), ldap.ProtocolOp.to_ber(ldap.CompareResponse(ldap.LDAPResultCode.compareTrue)))
+
+		handler.mode = 'false'
+		resps = list(handler.handle_compare(ldap.CompareRequest('cn=Test,dc=example,dc=com', ava=ldap.AttributeValueAssertion('foo', b'bar'))))
+		self.assertEqual(len(resps), 1)
+		self.assertEqual(ldap.ProtocolOp.to_ber(resps[0]), ldap.ProtocolOp.to_ber(ldap.CompareResponse(ldap.LDAPResultCode.compareFalse)))
+
+		handler.mode = 'empty'
+		with self.assertRaises(exceptions.LDAPNoSuchObject):
+			resps = list(handler.handle_compare(ldap.CompareRequest('cn=Test,dc=example,dc=com', ava=ldap.AttributeValueAssertion('foo', b'bar'))))
+			self.assertEqual(len(resps), 1)
+			self.assertEqual(ldap.ProtocolOp.to_ber(resps[0]), ldap.ProtocolOp.to_ber(ldap.CompareResponse(ldap.LDAPResultCode.noSuchObject)))
+
+		handler.mode = 'notfound'
+		with self.assertRaises(exceptions.LDAPNoSuchObject):
+			resps = list(handler.handle_compare(ldap.CompareRequest('cn=Test,dc=example,dc=com', ava=ldap.AttributeValueAssertion('foo', b'bar'))))
+			self.assertEqual(len(resps), 1)
+			self.assertEqual(ldap.ProtocolOp.to_ber(resps[0]), ldap.ProtocolOp.to_ber(ldap.CompareResponse(ldap.LDAPResultCode.noSuchObject)))
+
+		handler.mode = 'error'
+		with self.assertRaises(exceptions.LDAPUndefinedAttributeType):
+			resps = list(handler.handle_compare(ldap.CompareRequest('cn=Test,dc=example,dc=com', ava=ldap.AttributeValueAssertion('foo', b'bar'))))
+			self.assertEqual(len(resps), 1)
+			self.assertEqual(ldap.ProtocolOp.to_ber(resps[0]), ldap.ProtocolOp.to_ber(ldap.CompareResponse(ldap.LDAPResultCode.undefinedAttributeType)))
-- 
GitLab