From 49cd235f1d10bdd991c0f743db42bf1992776262 Mon Sep 17 00:00:00 2001
From: Julian Rother <julian@jrother.eu>
Date: Sat, 6 Mar 2021 21:58:54 +0100
Subject: [PATCH] Completed ldap message definitions

---
 ldap.py | 195 +++++++++++++++++++++++++++++++++++++++++---------------
 1 file changed, 143 insertions(+), 52 deletions(-)

diff --git a/ldap.py b/ldap.py
index a60561d..a5145f7 100644
--- a/ldap.py
+++ b/ldap.py
@@ -136,6 +136,9 @@ class LDAPString(OctetString):
 			raise TypeError()
 		return super().to_ber(obj.encode())
 
+class LDAPOID(LDAPString):
+	pass
+
 class Set(BERType):
 	ber_tag = (0, True, 17)
 	set_type = OctetString
@@ -165,12 +168,12 @@ class SequenceOf(Set):
 class Sequence(BERType):
 	ber_tag = (0, True, 16)
 	sequence_fields = [
-		#(Type, attr_name, default_value),
+		#(Type, attr_name, default_value, optional?),
 	]
 
 	def __init__(self, *args, **kwargs):
 		for index, spec in enumerate(type(self).sequence_fields):
-			field_type, name, default = spec
+			field_type, name, default, optional = spec
 			if index < len(args):
 				value = args[index]
 			elif name in kwargs:
@@ -181,7 +184,7 @@ class Sequence(BERType):
 
 	def __repr__(self):
 		args = []
-		for field_type, name, default in type(self).sequence_fields:
+		for field_type, name, default, optional in type(self).sequence_fields:
 			args.append('%s=%s'%(name, repr(getattr(self, name))))
 		return '<%s(%s)>'%(type(self).__name__, ', '.join(args))
 
@@ -192,9 +195,14 @@ class Sequence(BERType):
 			raise ValueError()
 		args = []
 		data = seqobj.content
-		for field_type, name, default in cls.sequence_fields:
-			obj, data = field_type.from_ber(data)
-			args.append(obj)
+		for field_type, name, default, optional in cls.sequence_fields:
+			try:
+				obj, data = field_type.from_ber(data)
+				args.append(obj)
+			except ValueError as e:
+				if not optional:
+					raise e
+				args.append(None)
 		return cls(*args), rest
 
 	@classmethod
@@ -202,8 +210,9 @@ class Sequence(BERType):
 		if not isinstance(obj, cls):
 			raise TypeError()
 		content = b''
-		for field_type, name, default in cls.sequence_fields:
-			content += field_type.to_ber(getattr(obj, name))
+		for field_type, name, default, optional in cls.sequence_fields:
+			if not optional or getattr(obj, name) is not None:
+				content += field_type.to_ber(getattr(obj, name))
 		return encode_ber(BERObject(cls.ber_tag, content))
 
 class Choice(BERType):
@@ -263,6 +272,17 @@ class Wrapper(BERType):
 			raise TypeError()
 		return WrappedType.to_ber(getattr(obj, cls.wrapped_attribute))
 
+def retag(cls, tag):
+	class Overwritten(cls):
+		ber_tag = tag
+	return Overwritten
+
+class AttributeValueAssertion(Sequence):
+	sequence_fields = [
+		(LDAPString, 'attributeDesc', None, False),
+		(OctetString, 'assertionValue', None, False),
+	]
+
 class Filter(Choice):
 	pass
 
@@ -281,22 +301,40 @@ class FilterOr(Wrapper, Filter):
 class FilterNot(Sequence, Filter):
 	ber_tag = (2, True, 2)
 	sequence_fields = [
-		(Filter, 'filter', None)
+		(Filter, 'filter', None, False)
 	]
 
 class FilterEqual(Sequence, Filter):
 	ber_tag = (2, True, 3)
 	sequence_fields = [
-		(LDAPString, 'attribute', None),
-		(OctetString, 'value', None)
+		(LDAPString, 'attribute', None, False),
+		(OctetString, 'value', None, False)
 	]
 
+class FilterGreaterOrEqual(AttributeValueAssertion, Filter):
+	ber_tag = (2, True, 5)
+
+class FilterLessOrEqual(AttributeValueAssertion, Filter):
+	ber_tag = (2, True, 6)
+
 class FilterPresent(Wrapper, Filter):
 	ber_tag = (2, False, 7)
 	wrapped_attribute = 'attribute'
 	wrapped_type = LDAPString
 	wrapped_default = None
 
+class FilterApproxMatch(AttributeValueAssertion, Filter):
+	ber_tag = (2, True, 8)
+
+class FilterExtensibleMatch(Sequence, Filter):
+	ber_tag = (2, True, 9)
+	sequence_fields = [
+		(retag(LDAPString, (2, False, 1)), 'matchingRule', None, True),
+		(retag(LDAPString, (2, False, 2)), 'type', None, True),
+		(retag(OctetString, (2, False, 3)), 'matchValue', None, False),
+		(retag(Boolean, (2, False, 4)), 'dnAttributes', None, True),
+	]
+
 class Enum(BERType):
 	ber_tag = (0, False, 10)
 	enum_type = None
@@ -315,23 +353,22 @@ class Enum(BERType):
 			raise TypeError()
 		return encode_ber(BERObject(cls.ber_tag, encode_ber_integer(obj.value)))
 
+def wrapenum(enumtype):
+	class WrappedEnum(Enum):
+		enum_type = enumtype
+	return WrappedEnum
+
 class SearchScope(enum.Enum):
 	baseObject = 0 # The scope is constrained to the entry named by baseObject.
 	singleLevel = 1 # The scope is constrained to the immediate subordinates of the entry named by baseObject.
 	wholeSubtree = 2 # The scope is constrained to the entry named by baseObject and to all its subordinates.
 
-class SearchScopeEnum(Enum):
-	enum_type = SearchScope
-
 class DerefAliases(enum.Enum):
 	neverDerefAliases = 0
 	derefInSearching = 1
 	derefFindingBaseObj = 2
 	derefAlways = 3
 
-class DerefAliasesEnum(Enum):
-	enum_type = DerefAliases
-
 class LDAPResultCode(enum.Enum):
 	success                      = 0
 	operationsError              = 1
@@ -380,15 +417,12 @@ class LDAPResultCode(enum.Enum):
 	# -- 72-79 unused --
 	other                        = 80
 
-class LDAPResultCodeEnum(Enum):
-	enum_type = LDAPResultCode
-
 class LDAPResult(Sequence):
 	ber_tag = (5, True, 1)
 	sequence_fields = [
-		(LDAPResultCodeEnum, 'resultCode', None),
-		(LDAPString, 'matchedDN', ''),
-		(LDAPString, 'diagnosticMessage', ''),
+		(wrapenum(LDAPResultCode), 'resultCode', None, False),
+		(LDAPString, 'matchedDN', '', False),
+		(LDAPString, 'diagnosticMessage', '', False),
 	]
 
 class AttributeSelection(SequenceOf):
@@ -413,22 +447,32 @@ class AttributeValueSet(Set):
 
 class PartialAttribute(Sequence):
 	sequence_fields = [
-		(LDAPString, 'type', None),
-		(AttributeValueSet, 'vals', lambda: []),
+		(LDAPString, 'type', None, False),
+		(AttributeValueSet, 'vals', lambda: [], False),
 	]
 
 class PartialAttributeList(SequenceOf):
 	set_type = PartialAttribute
 
+class Attribute(Sequence):
+	# Constrain: vals must not be empty
+	sequence_fields = [
+		(LDAPString, 'type', None, False),
+		(AttributeValueSet, 'vals', lambda: [], False),
+	]
+
+class AttributeList(SequenceOf):
+	set_type = Attribute
+
 class ProtocolOp(Choice):
 	pass
 
 class BindRequest(Sequence, ProtocolOp):
 	ber_tag = (1, True, 0)
 	sequence_fields = [
-		(Integer, 'version', 3),
-		(LDAPString, 'name', ''),
-		(AuthenticationChoice, 'authentication', lambda: SimpleAuthentication())
+		(Integer, 'version', 3, False),
+		(LDAPString, 'name', '', False),
+		(AuthenticationChoice, 'authentication', lambda: SimpleAuthentication(), False)
 	]
 
 class BindResponse(LDAPResult, ProtocolOp):
@@ -440,14 +484,14 @@ class UnbindRequest(Sequence, ProtocolOp):
 class SearchRequest(Sequence, ProtocolOp):
 	ber_tag = (1, True, 3)
 	sequence_fields = [
-		(LDAPString, 'baseObject', ''),
-		(SearchScopeEnum, 'scope', SearchScope.wholeSubtree),
-		(DerefAliasesEnum, 'derefAliases', DerefAliases.neverDerefAliases),
-		(Integer, 'sizeLimit', 0),
-		(Integer, 'timeLimit', 0),
-		(Boolean, 'typesOnly', False),
-		(Filter, 'filter', lambda: FilterPresent('objectClass')),
-		(AttributeSelection, 'attributes', lambda: [])
+		(LDAPString, 'baseObject', '', False),
+		(wrapenum(SearchScope), 'scope', SearchScope.wholeSubtree, False),
+		(wrapenum(DerefAliases), 'derefAliases', DerefAliases.neverDerefAliases, False),
+		(Integer, 'sizeLimit', 0, False),
+		(Integer, 'timeLimit', 0, False),
+		(Boolean, 'typesOnly', False, False),
+		(Filter, 'filter', lambda: FilterPresent('objectClass'), False),
+		(AttributeSelection, 'attributes', lambda: [], False)
 	]
 
 	@classmethod
@@ -457,23 +501,43 @@ class SearchRequest(Sequence, ProtocolOp):
 class SearchResultEntry(Sequence, ProtocolOp):
 	ber_tag = (1, True, 4)
 	sequence_fields = [
-		(LDAPString, 'objectName', ''),
-		(PartialAttributeList, 'attributes', lambda: []),
+		(LDAPString, 'objectName', '', False),
+		(PartialAttributeList, 'attributes', lambda: [], False),
 	]
 
 class SearchResultDone(LDAPResult, ProtocolOp):
 	ber_tag = (1, True, 5)
 
+class ModifyOperation(enum.Enum):
+	add = 0
+	delete = 1
+	replace = 2
+
+class ModifyChange(Sequence):
+	sequence_fields = [
+		(wrapenum(ModifyOperation), 'operation', None, False),
+		(PartialAttribute, 'modification', None, False),
+	]
+
+class ModifyChanges(SequenceOf):
+	set_type = ModifyChange
+
 class ModifyRequest(Sequence, ProtocolOp):
 	ber_tag = (1, True, 6)
-	# stub
+	sequence_fields = [
+		(LDAPString, 'object', None, False),
+		(ModifyChanges, 'changes', None, False),
+	]
 
 class ModifyResponse(LDAPResult, ProtocolOp):
 	ber_tag = (1, True, 7)
 
 class AddRequest(Sequence, ProtocolOp):
 	ber_tag = (1, True, 8)
-	# stub
+	sequence_fields = [
+		(LDAPString, 'entry', None, False),
+		(AttributeList, 'attributes', None, False),
+	]
 
 class AddResponse(LDAPResult, ProtocolOp):
 	ber_tag = (1, True, 9)
@@ -489,14 +553,22 @@ class DelResponse(LDAPResult, ProtocolOp):
 
 class ModifyDNRequest(Sequence, ProtocolOp):
 	ber_tag = (1, True, 12)
-	# stub
+	sequence_fields = [
+		(LDAPString, 'entry', None, False),
+		(LDAPString, 'newrdn', None, False),
+		(Boolean, 'deleteoldrdn', None, False),
+		(retag(LDAPString, (2, False, 0)), 'newSuperior', None, True),
+	]
 
 class ModifyDNResponse(LDAPResult, ProtocolOp):
 	ber_tag = (1, True, 13)
 
 class CompareRequest(Sequence, ProtocolOp):
 	ber_tag = (1, True, 14)
-	# stub
+	sequence_fields = [
+		(LDAPString, 'entry', None, False),
+		(AttributeValueAssertion, 'ava', None, False),
+	]
 
 class CompareResponse(LDAPResult, ProtocolOp):
 	ber_tag = (1, True, 15)
@@ -509,25 +581,44 @@ class AbandonRequest(Wrapper, ProtocolOp):
 
 class ExtendedRequest(Sequence, ProtocolOp):
 	ber_tag = (1, True, 23)
-	# stub
+	sequence_fields = [
+		(retag(LDAPOID, (2, False, 0)), 'requestName', None, True),
+		(retag(OctetString, (2, False, 1)), 'requestValue', None, True),
+	]
 
 class ExtendedResponse(Sequence, ProtocolOp):
 	ber_tag = (1, True, 24)
 	sequence_fields = [
-		(LDAPResultCodeEnum, 'resultCode', None),
-		(LDAPString, 'matchedDN', ''),
-		(LDAPString, 'diagnosticMessage', ''),
+		(wrapenum(LDAPResultCode), 'resultCode', None, False),
+		(LDAPString, 'matchedDN', '', False),
+		(LDAPString, 'diagnosticMessage', '', False),
+		(retag(LDAPOID, (2, False, 10)), 'responseName', None, True),
+		(retag(OctetString, (2, False, 11)), 'responseValue', None, True),
 	]
-	# stub
 
 class IntermediateResponse(Sequence, ProtocolOp):
 	ber_tag = (1, True, 25)
-	# stub
+	sequence_fields = [
+		(retag(LDAPOID, (2, False, 0)), 'responseName', None, True),
+		(retag(OctetString, (2, False, 1)), 'responseValue', None, True),
+	]
+
+class Control(Sequence):
+	sequence_fields = [
+		(LDAPOID, 'controlType', None, False),
+		(Boolean, 'criticality', None, True),
+		(OctetString, 'controlValue', None, True),
+	]
+
+class Controls(SequenceOf):
+	ber_tag = (2, True, 0)
+	set_type = Control
 
 class LDAPMessage(Sequence):
 	sequence_fields = [
-		(Integer, 'messageID', None),
-		(ProtocolOp, 'protocolOp', None)
+		(Integer, 'messageID', None, False),
+		(ProtocolOp, 'protocolOp', None, False),
+		(Controls, 'controls', None, True)
 	]
 
 class ShallowProtocolOp:
@@ -541,8 +632,8 @@ class ShallowProtocolOp:
 
 class ShallowLDAPMessage(Sequence):
 	sequence_fields = [
-		(Integer, 'messageID', None),
-		(ShallowProtocolOp, 'protocolOp', None)
+		(Integer, 'messageID', None, False),
+		(ShallowProtocolOp, 'protocolOp', None, False)
 	]
 
 bind1 = b'0\x0c\x02\x01\x01`\x07\x02\x01\x03\x04\x00\x80\x00'
-- 
GitLab