diff --git a/ldapserver/server.py b/ldapserver/server.py index 98b29177622c53df5594e80c3f14df48cb2ff1a6..4157b0a44e076569ff302f37d917fe4dd0c30bb2 100644 --- a/ldapserver/server.py +++ b/ldapserver/server.py @@ -523,14 +523,12 @@ class LDAPRequestHandler(BaseLDAPRequestHandler): def handle_compare(self, op, controls=None): self.logger.info('COMPRAE request "%s" %s=%s', op.entry, op.ava.attributeDesc, repr(op.ava.assertionValue)) obj = self.do_compare(op.entry, op.ava.attributeDesc, op.ava.assertionValue) - if obj is None: - raise exceptions.LDAPNoSuchObject() - if obj.match_compare(op.ava.attributeDesc, op.ava.assertionValue): - self.logger.info('COMPRAE result TRUE') - return [ldap.CompareResponse(ldap.LDAPResultCode.compareTrue)] - else: - self.logger.info('COMPRAE result FALSE') - return [ldap.CompareResponse(ldap.LDAPResultCode.compareFalse)] + if obj is not None: + if obj.compare(op.entry, op.ava.attributeDesc, op.ava.assertionValue): + return [ldap.CompareResponse(ldap.LDAPResultCode.compareTrue)] + else: + return [ldap.CompareResponse(ldap.LDAPResultCode.compareFalse)] + raise exceptions.LDAPNoSuchObject() def do_compare(self, dn, attribute, value): '''Lookup object for COMPARE operation @@ -550,7 +548,7 @@ class LDAPRequestHandler(BaseLDAPRequestHandler): objs = self.do_search(dn, ldap.SearchScope.baseObject, ldap.FilterPresent(attribute='objectClass')) for obj in objs: try: - obj.compare(dn) + obj.compare(dn, attribute, value) return obj except exceptions.LDAPNoSuchObject: pass diff --git a/tests/test_server.py b/tests/test_server.py index a471fcf0759a9bc4b7644fcdb6cfc947b7f35ab6..b7a908c3c460d972790bf58e798a8cd9e1fa2445 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -1,6 +1,6 @@ import unittest -from ldapserver import BaseLDAPRequestHandler, LDAPRequestHandler, ldap +from ldapserver import BaseLDAPRequestHandler, LDAPRequestHandler, ldap, exceptions class MockConnection: def __init__(self, data, chunksize): @@ -169,3 +169,71 @@ class TestLDAPRequestHandler(unittest.TestCase): self.assertEqual(ldap.ProtocolOp.to_ber(resps[0]), ldap.ProtocolOp.to_ber(ldap.SearchResultEntry('cn=Test1,dc=example,dc=com'))) self.assertEqual(ldap.ProtocolOp.to_ber(resps[1]), ldap.ProtocolOp.to_ber(ldap.SearchResultEntry('cn=Test2,dc=example,dc=com'))) self.assertEqual(ldap.ProtocolOp.to_ber(resps[2]), ldap.ProtocolOp.to_ber(ldap.SearchResultDone())) + + def test_compare(self): + class MockObject: + def __init__(_self, result=None): + _self.result = result + + def compare(_self, dn, attribute, value): + self.assertEqual(dn, 'cn=Test,dc=example,dc=com') + self.assertEqual(attribute, 'foo') + self.assertEqual(value, b'bar') + if isinstance(_self.result, Exception): + raise _self.result + return _self.result + + class RequestHandler(LDAPRequestHandler): + def handle(self): + pass + + def do_search(_self, base_obj, scope, filter_obj): + self.assertEqual(base_obj, 'cn=Test,dc=example,dc=com') + if _self.mode == 'true': + yield MockObject(exceptions.LDAPNoSuchObject()) + yield MockObject(exceptions.LDAPNoSuchObject()) + yield MockObject(True) + yield MockObject(exceptions.LDAPNoSuchObject()) + elif _self.mode == 'false': + yield MockObject(exceptions.LDAPNoSuchObject()) + yield MockObject(exceptions.LDAPNoSuchObject()) + yield MockObject(False) + yield MockObject(exceptions.LDAPNoSuchObject()) + elif _self.mode == 'empty': + pass + elif _self.mode == 'notfound': + yield MockObject(exceptions.LDAPNoSuchObject()) + yield MockObject(exceptions.LDAPNoSuchObject()) + elif _self.mode == 'error': + yield MockObject(exceptions.LDAPNoSuchObject()) + yield MockObject(exceptions.LDAPUndefinedAttributeType()) + yield MockObject(exceptions.LDAPNoSuchObject()) + + handler = RequestHandler(None, None, None) + handler.mode = 'true' + resps = list(handler.handle_compare(ldap.CompareRequest('cn=Test,dc=example,dc=com', ava=ldap.AttributeValueAssertion('foo', b'bar')))) + self.assertEqual(len(resps), 1) + self.assertEqual(ldap.ProtocolOp.to_ber(resps[0]), ldap.ProtocolOp.to_ber(ldap.CompareResponse(ldap.LDAPResultCode.compareTrue))) + + handler.mode = 'false' + resps = list(handler.handle_compare(ldap.CompareRequest('cn=Test,dc=example,dc=com', ava=ldap.AttributeValueAssertion('foo', b'bar')))) + self.assertEqual(len(resps), 1) + self.assertEqual(ldap.ProtocolOp.to_ber(resps[0]), ldap.ProtocolOp.to_ber(ldap.CompareResponse(ldap.LDAPResultCode.compareFalse))) + + handler.mode = 'empty' + with self.assertRaises(exceptions.LDAPNoSuchObject): + resps = list(handler.handle_compare(ldap.CompareRequest('cn=Test,dc=example,dc=com', ava=ldap.AttributeValueAssertion('foo', b'bar')))) + self.assertEqual(len(resps), 1) + self.assertEqual(ldap.ProtocolOp.to_ber(resps[0]), ldap.ProtocolOp.to_ber(ldap.CompareResponse(ldap.LDAPResultCode.noSuchObject))) + + handler.mode = 'notfound' + with self.assertRaises(exceptions.LDAPNoSuchObject): + resps = list(handler.handle_compare(ldap.CompareRequest('cn=Test,dc=example,dc=com', ava=ldap.AttributeValueAssertion('foo', b'bar')))) + self.assertEqual(len(resps), 1) + self.assertEqual(ldap.ProtocolOp.to_ber(resps[0]), ldap.ProtocolOp.to_ber(ldap.CompareResponse(ldap.LDAPResultCode.noSuchObject))) + + handler.mode = 'error' + with self.assertRaises(exceptions.LDAPUndefinedAttributeType): + resps = list(handler.handle_compare(ldap.CompareRequest('cn=Test,dc=example,dc=com', ava=ldap.AttributeValueAssertion('foo', b'bar')))) + self.assertEqual(len(resps), 1) + self.assertEqual(ldap.ProtocolOp.to_ber(resps[0]), ldap.ProtocolOp.to_ber(ldap.CompareResponse(ldap.LDAPResultCode.undefinedAttributeType)))