From bb029c06c2de17b7331e1ac73e882a45312f69c1 Mon Sep 17 00:00:00 2001
From: Julian Rother <julian@jrother.eu>
Date: Sat, 23 Oct 2021 12:49:15 +0200
Subject: [PATCH] Implemented COMPARE in request handler

The default implementation relies on do_search to get the requested object.
---
 ldapserver/objects.py | 25 ++++++++++++++++++++-----
 ldapserver/server.py  | 33 +++++++++++++++++++++++++++++++++
 2 files changed, 53 insertions(+), 5 deletions(-)

diff --git a/ldapserver/objects.py b/ldapserver/objects.py
index ae267be..da758b6 100644
--- a/ldapserver/objects.py
+++ b/ldapserver/objects.py
@@ -1,6 +1,6 @@
 import enum
 
-from . import ldap
+from . import ldap, exceptions
 from .dn import DN
 
 class FilterResult(enum.Enum):
@@ -99,7 +99,7 @@ class AttributeDict(dict):
 		attribute_type = self.schema.get_attribute_type(key)
 		if attribute_type is None:
 			return FilterResult.UNDEFINED
-		if self[attribute_type] != []:
+		if self[attribute_type.name] != []:
 			return FilterResult.TRUE
 		else:
 			return FilterResult.FALSE
@@ -223,6 +223,21 @@ class Object(AttributeDict):
 	def match_search(self, base_obj, scope, filter_obj):
 		return self.match_dn(DN.from_str(base_obj), scope) and self.match_filter(filter_obj) == FilterResult.TRUE
 
+	def match_compare(self, attribute, value):
+		try:
+			attribute_type = self.schema.get_attribute_type(attribute)
+		except KeyError as exc:
+			raise exceptions.LDAPUndefinedAttributeType() from exc
+		if attribute_type.equality is None:
+			raise exceptions.LDAPInappropriateMatching()
+		value = attribute_type.equality.syntax.decode(self.schema, value)
+		if value is None:
+			raise exceptions.LDAPInvalidAttributeSyntax()
+		for attrval in self.get_with_subtypes(attribute):
+			if attribute_type.equality.match_equal(self.schema, attrval, value):
+				return True
+		return False
+
 	def get_search_result_entry(self, attributes=None, types_only=False):
 		selected_attributes = set()
 		for selector in attributes or ['*']:
@@ -437,9 +452,9 @@ class SubschemaSubentry(Object):
 		self['attributeTypes'] = [item.to_definition() for item in schema.attribute_types]
 		# pylint: disable=invalid-name
 		self.AttributeDict = lambda **attributes: AttributeDict(schema, **attributes)
-		self.Object = lambda *args, **attributes: Object(schema, dn, subschemaSubentry=[dn], **attributes)
-		self.RootDSE = lambda **attributes: RootDSE(schema, subschemaSubentry=[dn], **attributes)
-		self.ObjectTemplate = lambda *args, **kwargs: ObjectTemplate(schema, *args, subschemaSubentry=[dn], **kwargs)
+		self.Object = lambda *args, **attributes: Object(schema, *args, subschemaSubentry=[self.dn], **attributes)
+		self.RootDSE = lambda **attributes: RootDSE(schema, subschemaSubentry=[self.dn], **attributes)
+		self.ObjectTemplate = lambda *args, **kwargs: ObjectTemplate(schema, *args, subschemaSubentry=[self.dn], **kwargs)
 
 	def match_search(self, base_obj, scope, filter_obj):
 		return DN.from_str(base_obj) == self.dn and  \
diff --git a/ldapserver/server.py b/ldapserver/server.py
index d554fb2..bbe8dbf 100644
--- a/ldapserver/server.py
+++ b/ldapserver/server.py
@@ -4,6 +4,7 @@ import socketserver
 import typing
 
 from . import asn1, exceptions, ldap, schema, objects
+from .dn import DN
 
 def reject_critical_controls(controls=None):
 	for control in controls or []:
@@ -421,6 +422,38 @@ class LDAPRequestHandler(BaseLDAPRequestHandler):
 		yield self.subschema
 		yield from self.static_objects
 
+	def handle_compare(self, op, controls=None):
+		obj = self.do_compare(op.entry, op.ava.attributeDesc, op.ava.assertionValue)
+		if obj is None:
+			raise exceptions.LDAPNoSuchObject()
+		if obj.match_compare(op.ava.attributeDesc, op.ava.assertionValue):
+			return [ldap.CompareResponse(ldap.LDAPResultCode.compareTrue)]
+		else:
+			return [ldap.CompareResponse(ldap.LDAPResultCode.compareFalse)]
+
+	def do_compare(self, dn, attribute, value):
+		'''Lookup object for COMPARE operation
+
+		:param dn: Distinguished name of the LDAP entry
+		:type dn: str
+		:param attribute: Attribute type
+		:type attribute: str
+		:param value: Attribute value
+		:type value: bytes
+
+		:raises exceptions.LDAPError: on error
+
+		:returns: `Object` or None
+
+		The default implementation calls `do_search` and returns the first object
+		with the right DN.'''
+		objs = self.do_search(dn, ldap.SearchScope.baseObject, ldap.FilterPresent(attribute='objectClass'))
+		dn = DN.from_str(dn)
+		for obj in objs:
+			if obj.dn == dn:
+				return obj
+		return None
+
 	def handle_unbind(self, op, controls=None):
 		reject_critical_controls(controls)
 		self.keep_running = False
-- 
GitLab