From 39f0d293187293a64ef11bcd26f17335bab947c9 Mon Sep 17 00:00:00 2001
From: Julian Rother <julian@jrother.eu>
Date: Fri, 5 Mar 2021 21:02:35 +0100
Subject: [PATCH] Implemented custom asn.1/ldap parser

---
 dn.py     |   3 +
 ldap.py   | 488 ++++++++++++++++++++++++++++++++++++++++++++++++++++++
 server.py |  65 ++++++++
 3 files changed, 556 insertions(+)
 create mode 100644 ldap.py
 create mode 100644 server.py

diff --git a/dn.py b/dn.py
index 1ce21a3..020e7d0 100644
--- a/dn.py
+++ b/dn.py
@@ -119,3 +119,6 @@ class DNScope(Enum):
 	singleLevel = 1 # The scope is constrained to the immediate subordinates of the entry named by baseObject.
 	wholeSubtree = 2 # The scope is constrained to the entry named by baseObject and to all its subordinates.
 
+	@classmethod
+	def from_bytes(cls, data):
+		return self(data)
diff --git a/ldap.py b/ldap.py
new file mode 100644
index 0000000..48d3091
--- /dev/null
+++ b/ldap.py
@@ -0,0 +1,488 @@
+from collections import namedtuple
+import enum
+
+BERObject = namedtuple('BERObject', ['tag', 'content'])
+
+class IncompleteBERError(ValueError):
+	def __init__(self, expected_length=-1):
+		super().__init__()
+		self.expected_length = expected_length
+
+def decode_ber(data):
+	index = 0
+	if len(data) < 2:
+		raise IncompleteBERError(2)
+	identifier = data[index]
+	ber_class = identifier >> 6
+	ber_constructed = bool(identifier & 0x20)
+	ber_type = identifier & 0x1f
+	index += 1
+	if not data[index] & 0x80:
+		length = data[index]
+	elif data[index] == 0x80:
+		raise ValueError('Indefinite form not implemented')
+	elif data[index] == 0xff:
+		return ValueError('BER length invalid')
+	else:
+		num = data[index] & ~0x80
+		index += 1
+		if len(data) < index + num:
+			raise IncompleteBERError(index + num)
+		length = 0
+		for octet in data[index:index + num]:
+			length = length << 8 | octet
+		index += num
+	if len(data) < index + length + 1:
+		raise IncompleteBERError(index + length + 1)
+	ber_content = data[index + 1: index + length + 1]
+	rest = data[index + length + 1:]
+	return BERObject((ber_class, ber_constructed, ber_type), ber_content), rest
+
+def decode_ber_integer(data):
+	if not data:
+		return 0
+	value = -1 if data[0] & 0x80 else 0
+	for octet in data:
+		value = value << 8 | octet
+	return value
+
+def encode_ber(obj):
+	tag = (obj.tag[0] & 0b11) << 6 | (obj.tag[1] & 1) << 5 | (obj.tag[2] & 0b11111)
+	length = len(obj.content)
+	if length >= 127:
+		raise NotImplementedError('Long form length encoding not implemented')
+	return bytes([tag, length]) + obj.content
+
+def encode_ber_integer(value):
+	if value < 0 or value > 255:
+		raise NotImplementedError('Encoding of integers greater than 255 is not implemented')
+	return bytes([value])
+
+class BERType:
+	@classmethod
+	def from_ber(cls, data):
+		raise NotImplementedError()
+
+	@classmethod
+	def to_ber(cls, obj):
+		raise NotImplementedError()
+
+	def __bytes__(self):
+		return type(self).to_ber(self)
+
+class OctetString(BERType):
+	ber_tag = (0, False, 4)
+
+	@classmethod
+	def from_ber(cls, data):
+		obj, rest = decode_ber(data)
+		if obj.tag != cls.ber_tag:
+			raise ValueError('Expected tag %s but found %s'%(cls.ber_tag, obj.tag))
+		return obj.content, rest
+
+	@classmethod
+	def to_ber(cls, obj):
+		if not isinstance(obj, bytes):
+			raise TypeError()
+		return encode_ber(BERObject(cls.ber_tag, obj))
+
+class Integer(BERType):
+	ber_tag = (0, False, 2)
+
+	@classmethod
+	def from_ber(cls, data):
+		obj, rest = decode_ber(data)
+		if obj.tag != cls.ber_tag:
+			raise ValueError()
+		return decode_ber_integer(obj.content), rest
+
+	@classmethod
+	def to_ber(cls, obj):
+		if not isinstance(obj, int):
+			raise TypeError()
+		return encode_ber(BERObject(cls.ber_tag, encode_ber_integer(obj)))
+
+class Boolean(BERType):
+	ber_tag = (0, False, 1)
+
+	@classmethod
+	def from_ber(cls, data):
+		obj, rest = decode_ber(data)
+		if obj.tag != cls.ber_tag:
+			raise ValueError()
+		return bool(decode_ber_integer(obj.content)), rest
+
+	@classmethod
+	def to_ber(cls, obj):
+		if not isinstance(obj, bool):
+			raise TypeError()
+		content = b'\xff' if obj else b'\x00'
+		return encode_ber(BERObject(cls.ber_tag, content))
+
+class LDAPString(OctetString):
+	@classmethod
+	def from_ber(cls, data):
+		raw, rest = super().from_ber(data)
+		return raw.decode(), rest
+
+	@classmethod
+	def to_ber(cls, obj):
+		if not isinstance(obj, str):
+			raise TypeError()
+		return super().to_ber(obj.encode())
+
+class Set(BERType):
+	ber_tag = (0, True, 17)
+	set_type = OctetString
+
+	@classmethod
+	def from_ber(cls, data):
+		setobj, rest = decode_ber(data)
+		if setobj.tag != cls.ber_tag:
+			raise ValueError()
+		objs = []
+		data = setobj.content
+		while data:
+			obj, data = cls.set_type.from_ber(data)
+			objs.append(obj)
+		return list(objs), rest
+
+	@classmethod
+	def to_ber(cls, obj):
+		content = b''
+		for item in obj:
+			content += cls.set_type.to_ber(item)
+		return encode_ber(BERObject(cls.ber_tag, content))
+
+class SequenceOf(Set):
+	ber_tag = (0, True, 16)
+
+class Sequence(BERType):
+	ber_tag = (0, True, 16)
+	sequence_fields = [
+		#(Type, attr_name, default_value),
+	]
+
+	def __init__(self, *args, **kwargs):
+		for index, spec in enumerate(type(self).sequence_fields):
+			field_type, name, default = spec
+			if index < len(args):
+				value = args[index]
+			elif name in kwargs:
+				value = kwargs[name]
+			else:
+				value = default() if callable(default) else default
+			setattr(self, name, value)
+
+	def __repr__(self):
+		args = []
+		for field_type, name, default in type(self).sequence_fields:
+			args.append('%s=%s'%(name, repr(getattr(self, name))))
+		return '<%s(%s)>'%(type(self).__name__, ', '.join(args))
+
+	@classmethod
+	def from_ber(cls, data):
+		seqobj, rest = decode_ber(data)
+		if seqobj.tag != cls.ber_tag:
+			raise ValueError()
+		args = []
+		data = seqobj.content
+		for field_type, name, default in cls.sequence_fields:
+			obj, data = field_type.from_ber(data)
+			args.append(obj)
+		return cls(*args), rest
+
+	@classmethod
+	def to_ber(cls, obj):
+		if not isinstance(obj, cls):
+			raise TypeError()
+		content = b''
+		for field_type, name, default in cls.sequence_fields:
+			content += field_type.to_ber(getattr(obj, name))
+		return encode_ber(BERObject(cls.ber_tag, content))
+
+class Choice(BERType):
+	ber_tag = None
+
+	@classmethod
+	def from_ber(cls, data):
+		obj, rest = decode_ber(data)
+		for subcls in cls.__subclasses__():
+			if subcls.ber_tag == obj.tag:
+				return subcls.from_ber(data)
+		return None, rest
+
+	@classmethod
+	def to_ber(cls, obj):
+		for subcls in cls.__subclasses__():
+			if isinstance(obj, subcls):
+				return subcls.to_ber(obj)
+		raise TypeError()
+
+class Wrapper(BERType):
+	ber_tag = None
+	wrapped_attribute = None
+	wrapped_type = None
+	wrapped_default = None
+	wrapped_clsattrs = {}
+
+	def __init__(self, *args, **kwargs):
+		cls = type(self)
+		attribute = cls.wrapped_attribute
+		if args:
+			setattr(self, attribute, args[0])
+		elif kwargs:
+			setattr(self, attribute, kwargs[attribute])
+		else:
+			setattr(self, attribute, cls.wrapped_default() if callable(cls.wrapped_default) else cls.wrapped_default)
+
+	def __repr__(self):
+		return '<%s(%s)>'%(type(self).__name__, repr(getattr(self, type(self).wrapped_attribute)))
+
+	@classmethod
+	def from_ber(cls, data):
+		class WrappedType(cls.wrapped_type):
+			ber_tag = cls.ber_tag
+		for key, value in cls.wrapped_clsattrs.items():
+			setattr(WrappedType, key, value)
+		value, rest = WrappedType.from_ber(data)
+		return cls(value), rest
+
+	@classmethod
+	def to_ber(cls, obj):
+		class WrappedType(cls.wrapped_type):
+			ber_tag = cls.ber_tag
+		for key, value in cls.wrapped_clsattrs.items():
+			setattr(WrappedType, key, value)
+		if not isinstance(obj, cls):
+			raise TypeError()
+		return WrappedType.to_ber(getattr(obj, cls.wrapped_attribute))
+
+class Filter(Choice):
+	pass
+
+class FilterAnd(Wrapper, Filter):
+	ber_tag = (2, True, 0)
+	wrapped_attribute = 'filters'
+	wrapped_type = Set
+	wrapped_clsattrs = {'set_type': Filter}
+
+class FilterOr(Wrapper, Filter):
+	ber_tag = (2, True, 1)
+	wrapped_attribute = 'filters'
+	wrapped_type = Set
+	wrapped_clsattrs = {'set_type': Filter}
+
+class FilterNot(Sequence, Filter):
+	ber_tag = (2, True, 2)
+	sequence_fields = [
+		(Filter, 'filter', None)
+	]
+
+class FilterEqual(Sequence, Filter):
+	ber_tag = (2, True, 3)
+	sequence_fields = [
+		(LDAPString, 'attribute', None),
+		(OctetString, 'value', None)
+	]
+
+class FilterPresent(Wrapper, Filter):
+	ber_tag = (2, False, 7)
+	wrapped_attribute = 'attribute'
+	wrapped_type = LDAPString
+	wrapped_default = None
+
+class Enum(BERType):
+	ber_tag = (0, False, 10)
+	enum_type = None
+
+	@classmethod
+	def from_ber(cls, data):
+		obj, rest = decode_ber(data)
+		if obj.tag != cls.ber_tag:
+			raise ValueError()
+		value = decode_ber_integer(obj.content)
+		return cls.enum_type(value), rest
+
+	@classmethod
+	def to_ber(cls, obj):
+		if not isinstance(obj, cls.enum_type):
+			raise TypeError()
+		return encode_ber(BERObject(cls.ber_tag, encode_ber_integer(obj.value)))
+
+class SearchScope(enum.Enum):
+	baseObject = 0 # The scope is constrained to the entry named by baseObject.
+	singleLevel = 1 # The scope is constrained to the immediate subordinates of the entry named by baseObject.
+	wholeSubtree = 2 # The scope is constrained to the entry named by baseObject and to all its subordinates.
+
+class SearchScopeEnum(Enum):
+	enum_type = SearchScope
+
+class DerefAliases(enum.Enum):
+	neverDerefAliases = 0
+	derefInSearching = 1
+	derefFindingBaseObj = 2
+	derefAlways = 3
+
+class DerefAliasesEnum(Enum):
+	enum_type = DerefAliases
+
+class LDAPResultCode(enum.Enum):
+	success                      = 0
+	operationsError              = 1
+	protocolError                = 2
+	timeLimitExceeded            = 3
+	sizeLimitExceeded            = 4
+	compareFalse                 = 5
+	compareTrue                  = 6
+	authMethodNotSupported       = 7
+	strongerAuthRequired         = 8
+	# -- 9 reserved --
+	referral                     = 10
+	adminLimitExceeded           = 11
+	unavailableCriticalExtension = 12
+	confidentialityRequired      = 13
+	saslBindInProgress           = 14
+	noSuchAttribute              = 16
+	undefinedAttributeType       = 17
+	inappropriateMatching        = 18
+	constraintViolation          = 19
+	attributeOrValueExists       = 20
+	invalidAttributeSyntax       = 21
+	# -- 22-31 unused --
+	noSuchObject                 = 32
+	aliasProblem                 = 33
+	invalidDNSyntax              = 34
+	# -- 35 reserved for undefined isLeaf --
+	aliasDereferencingProblem    = 36
+	# -- 37-47 unused --
+	inappropriateAuthentication  = 48
+	invalidCredentials           = 49
+	insufficientAccessRights     = 50
+	busy                         = 51
+	unavailable                  = 52
+	unwillingToPerform           = 53
+	loopDetect                   = 54
+	# -- 55-63 unused --
+	namingViolation              = 64
+	objectClassViolation         = 65
+	notAllowedOnNonLeaf          = 66
+	notAllowedOnRDN              = 67
+	entryAlreadyExists           = 68
+	objectClassModsProhibited    = 69
+	# -- 70 reserved for CLDAP --
+	affectsMultipleDSAs          = 71
+	# -- 72-79 unused --
+	other                        = 80
+
+class LDAPResultCodeEnum(Enum):
+	enum_type = LDAPResultCode
+
+class LDAPResult(Sequence):
+	ber_tag = (5, True, 1)
+	sequence_fields = [
+		(LDAPResultCodeEnum, 'resultCode', None),
+		(LDAPString, 'matchedDN', ''),
+		(LDAPString, 'diagnosticMessage', ''),
+	]
+
+class AttributeSelection(SequenceOf):
+	set_type = LDAPString
+
+class AuthenticationChoice(Choice):
+	pass
+
+class SimpleAuthentication(Wrapper, AuthenticationChoice):
+	ber_tag = (2, False, 0)
+	wrapped_attribute = 'password'
+	wrapped_type = OctetString
+	wrapped_default = b''
+
+	def __repr__(self):
+		if not self.password:
+			return '<%s(EMPTY PASSWORD)>'%(type(self).__name__)
+		return '<%s(PASSWORD HIDDEN)>'%(type(self).__name__)
+
+class AttributeValueSet(Set):
+	set_type = OctetString
+
+class PartialAttribute(Sequence):
+	sequence_fields = [
+		(LDAPString, 'type', None),
+		(AttributeValueSet, 'vals', lambda: []),
+	]
+
+class PartialAttributeList(SequenceOf):
+	set_type = PartialAttribute
+
+class ProtocolOp(Choice):
+	pass
+
+class BindRequest(Sequence, ProtocolOp):
+	ber_tag = (1, True, 0)
+	sequence_fields = [
+		(Integer, 'version', 3),
+		(LDAPString, 'name', ''),
+		(AuthenticationChoice, 'authentication', lambda: SimpleAuthentication())
+	]
+
+class BindResponse(LDAPResult, ProtocolOp):
+	ber_tag = (1, True, 1)
+
+class SearchRequest(Sequence, ProtocolOp):
+	ber_tag = (1, True, 3)
+	sequence_fields = [
+		(LDAPString, 'baseObject', ''),
+		(SearchScopeEnum, 'scope', SearchScope.wholeSubtree),
+		(DerefAliasesEnum, 'derefAliases', DerefAliases.neverDerefAliases),
+		(Integer, 'sizeLimit', 0),
+		(Integer, 'timeLimit', 0),
+		(Boolean, 'typesOnly', False),
+		(Filter, 'filter', lambda: FilterPresent('objectClass')),
+		(AttributeSelection, 'attributes', lambda: [])
+	]
+
+	@classmethod
+	def from_ber(cls, data):
+		return super().from_ber(data)
+
+class SearchResultEntry(Sequence, ProtocolOp):
+	ber_tag = (1, True, 4)
+	sequence_fields = [
+		(LDAPString, 'objectName', ''),
+		(PartialAttributeList, 'attributes', lambda: []),
+	]
+
+class SearchResultDone(LDAPResult, ProtocolOp):
+	ber_tag = (1, True, 5)
+
+class UnbindRequest(Sequence, ProtocolOp):
+	ber_tag = (1, False, 2)
+
+class LDAPMessage(Sequence):
+	sequence_fields = [
+		(Integer, 'messageID', None),
+		(ProtocolOp, 'protocolOp', None)
+	]
+
+class ShallowProtocolOp:
+	@classmethod
+	def from_ber(cls, data):
+		obj, rest = decode_ber(data)
+		for subcls in ProtocolOp.__subclasses__():
+			if subcls.ber_tag == obj.tag:
+				return subcls, rest
+		return None, rest
+
+class ShallowLDAPMessage(Sequence):
+	sequence_fields = [
+		(Integer, 'messageID', None),
+		(ShallowProtocolOp, 'protocolOp', None)
+	]
+
+bind1 = b'0\x0c\x02\x01\x01`\x07\x02\x01\x03\x04\x00\x80\x00'
+bind2 = b'0\x1c\x02\x01\x01`\x17\x02\x01\x03\x04\nuid=foobar\x80\x06foobar'
+search1 = b'0?\x02\x01\x02c:\x04\x1aou=users,dc=example,dc=com\n\x01\x01\n\x01\x00\x02\x01\x00\x02\x01\x00\x01\x01\x00\x87\x0bobjectclass0\x00'
+search2 = b'0@\x02\x01\x02c;\x04\x00\n\x01\x02\n\x01\x00\x02\x01\x00\x02\x01\x00\x01\x01\x00\xa0&\xa3\x12\x04\x0bobjectClass\x04\x03top\xa3\x10\x04\x03uid\x04\ttestadmin0\x00'
+unbind1 = b'0\x05\x02\x01\x03B\x00'
+
diff --git a/server.py b/server.py
new file mode 100644
index 0000000..10abdbe
--- /dev/null
+++ b/server.py
@@ -0,0 +1,65 @@
+import traceback
+from socketserver import ForkingTCPServer, BaseRequestHandler
+
+from ldap import LDAPMessage, ShallowLDAPMessage, BindRequest, BindResponse, SearchRequest, SearchResultDone, UnbindRequest, LDAPResultCode, IncompleteBERError
+
+class Handler(BaseRequestHandler):
+	def setup(self):
+		self.bind_dn = b''
+		self.keep_running = True
+
+	def handle_bind(self, req):
+		self.bind_dn = req.protocolOp.name
+		self.send_msg(LDAPMessage(req.messageID, BindResponse(LDAPResultCode.success)))
+
+	def handle_search(self, req):
+		self.send_msg(LDAPMessage(req.messageID, SearchResultDone(LDAPResultCode.success)))
+
+	def handle_unbind(self, req):
+		self.keep_running = False
+
+	def handle_message(self, data):
+		handlers = {
+			BindRequest: (self.handle_bind, BindResponse),
+			SearchRequest: (self.handle_search, SearchResultDone),
+			UnbindRequest: (self.handle_unbind, None),
+		}
+		shallowmsg, rest = ShallowLDAPMessage.from_ber(data)
+		if shallowmsg.protocolOp is None:
+			print('Ignoring unknown message')
+			return rest
+		func, errfunc = handlers[shallowmsg.protocolOp]
+		msg = None
+		try:
+			msg, _ = LDAPMessage.from_ber(data)
+		except Exception as e:
+			if errfunc:
+				self.send_msg(LDAPMessage(shallowmsg.messageID, errfunc(LDAPResultCode.protocolError)))
+			traceback.print_exc()
+			return rest
+		print(msg)
+		try:
+			if func:
+				func(msg)
+		except Exception as e:
+			if errfunc:
+				self.send_msg(LDAPMessage(shallowmsg.messageID, errfunc(LDAPResultCode.other)))
+			traceback.print_exc()
+			return rest
+		return rest
+
+	def send_msg(self, msg):
+		self.request.sendall(LDAPMessage.to_ber(msg))
+
+	def handle(self):
+		data = b''
+		while self.keep_running:
+			try:
+				data = self.handle_message(data)
+			except IncompleteBERError:
+				chunk = self.request.recv(5)
+				if not chunk:
+					return
+				data += chunk
+
+ForkingTCPServer(('127.0.0.1', 1338), Handler).serve_forever()
-- 
GitLab