From 0d6c5e146402abd03b170f1fe56f013205a41bdb Mon Sep 17 00:00:00 2001
From: Julian Rother <julian@jrother.eu>
Date: Thu, 2 Dec 2021 16:49:48 +0100
Subject: [PATCH] Started decorator-based api

---
 examples/new_api.py                 |  81 +++++
 examples/passwd.py                  |  71 +++--
 ldapserver/future.py                | 478 ++++++++++++++++++++++++++++
 ldapserver/schema/definitions.py    |  47 ++-
 ldapserver/schema/matching_rules.py |   5 +
 ldapserver/schema/syntaxes.py       |  15 +-
 ldapserver/value_templates.py       | 244 ++++++++++++++
 7 files changed, 909 insertions(+), 32 deletions(-)
 create mode 100644 examples/new_api.py
 create mode 100644 ldapserver/future.py
 create mode 100644 ldapserver/value_templates.py

diff --git a/examples/new_api.py b/examples/new_api.py
new file mode 100644
index 0000000..776e97a
--- /dev/null
+++ b/examples/new_api.py
@@ -0,0 +1,81 @@
+from ldapserver import Subtree, Server
+from ldapserver.exceptions import LDAPInvalidCredentials
+
+api = None # Overwritten later
+bind_password = None # if None anonymous reads are allowed
+
+server = Server(schema)
+
+service = server.add_entry('cn=service,ou=system,dc=example,dc=com',
+		objectclass=['top', 'organizationalRole', 'simpleSecurityObject'],
+		structuralObjectClass=['organizationalRole'])
+
+@service.bind
+def service_bind(password):
+	if password == bind_password:
+		return
+	raise LDAPInvalidCredentials()
+
+@server.template('uid={loginname},ou=users,dc=example,dc=com',
+		structuralObjectClass=['inetorgperson'],
+		objectClass=['top', 'inetorgperson', 'organizationalperson', 'person', 'posixaccount'],
+		sn=[' '],
+		cn=['{displayname}'],
+		displayname=['{displayname}'],
+		givenname=['{displayname}'],
+		homeDirectory=['/home/{loginname}'],
+		mail=['{email}'],
+		uid=['{loginname}'],
+		uidNumber=['{uid}'],
+		memberOf=['cn={*group_names},ou=groups,dc=example,dc=com'])
+def users(loginname=None, uid=None, group_names=None, **kwargs):
+	if not connection.bind_state:
+		return
+	request_params = {}
+	if not request_params and loginname is not None:
+		request_params = {'loginname': normalize_user_loginname(loginname)}
+	if not request_params and uid is not None:
+		request_params = {'id': uid}
+	if not request_params and group_names:
+		request_params = {'group': normalize_group_name(group_names[0])}
+	for user in api.get_users(**request_params):
+		yield dict(loginname=user['loginname'], displayname=user['displayname'],
+		           email=user['email'], uid=user['id'], groups=user['groups'])
+
+@users.bind
+def users_bind(entry, password):
+	if api.check_password(loginname=entry['uid'][0], password=password):
+		return
+	raise LDAPInvalidCredentials()
+
+@server.bind_sasl_plain
+def bind_sasl_plain(identity, password, authzid=None):
+	if authzid is not None and identity != authzid:
+		raise LDAPInvalidCredentials()
+	user = api.check_password(loginname=identity, password=password)
+	if user is None:
+		raise LDAPInvalidCredentials()
+	return user
+
+@server.template('cn={name},ou=groups',
+		structuralObjectClass=['groupOfUniqueNames'],
+		objectClass=['top', 'groupOfUniqueNames', 'posixGroup'],
+		description=[' '],
+		cn=['{name}'],
+		gidNumber=['{gid}'],
+		uniqueMember=['uid={*member_names},ou=users,dc=example,dc=com'])
+def groups(name=None, gid=None, member_names=None):
+	if not connection.bind_state:
+		return
+	request_params = {}
+	if not request_params and name is not None:
+		request_params = {'name': normalize_group_name(name)}
+	if not request_params and gid is not None:
+		request_params = {'id': gid}
+	if not request_params and member_names:
+		request_params = {'member': normalize_user_loginname(member_names[0])}
+	for group in api.get_groups(**request_params):
+		yield groups.create(name=group['name'], gid=group['id'], member_names=group['members'])
+
+if __name__ == '__main__':
+	server.run()
diff --git a/examples/passwd.py b/examples/passwd.py
index ce024a3..1d4590f 100644
--- a/examples/passwd.py
+++ b/examples/passwd.py
@@ -4,38 +4,49 @@ import pwd
 import grp
 
 import ldapserver
+import ldapserver.future
 
-logging.basicConfig(level=logging.INFO)
+server = ldapserver.future.Server(ldapserver.schema.RFC2307BIS_SCHEMA)
+server.add_entry('dc=example,dc=com',
+	objectClass=['top', 'dcObject', 'organization'],
+	structuralObjectClass=['organization'],
+)
+server.add_entry('ou=users,dc=example,dc=com',
+	objectClass=['top', 'organizationalUnit'],
+	structuralObjectClass=['organizationalUnit'],
+	ou=['users'],
+)
+server.add_entry('ou=groups,dc=example,dc=com',
+	objectClass=['top', 'organizationalUnit'],
+	structuralObjectClass=['organizationalUnit'],
+	ou=['groups'],
+)
 
-class RequestHandler(ldapserver.LDAPRequestHandler):
-	subschema = ldapserver.SubschemaSubentry(ldapserver.schema.RFC2307BIS_SCHEMA, 'cn=Subschema')
+@server.template('uid={name},ou=users,dc=example,dc=com',
+	objectClass=['top', 'organizationalperson', 'person', 'posixaccount'],
+	structuralObjectClass=['organizationalperson'],
+	uid=['{name}'],
+	uidNumber=['{uid}'],
+	gidNumber=['{gid}'],
+	cn=['{gecos}'],
+	homeDirectory=['{home}']
+)
+def users(name=None, gecos=None, uid=None, gid=None, home=None):
+	print('users', repr(name), repr(gecos), repr(uid), repr(gid), repr(home))
+	for user in pwd.getpwall():
+		yield {'name': user.pw_name, 'gecos': user.pw_gecos or ' ', 'uid': user.pw_uid, 'gid': user.pw_gid, 'home': user.pw_dir}
 
-	def do_search(self, basedn, scope, filterobj):
-		yield from super().do_search(basedn, scope, filterobj)
-		yield self.subschema.ObjectEntry('dc=example,dc=com', **{
-			'objectClass': ['top', 'dcObject', 'organization'],
-			'structuralObjectClass': ['organization'],
-		})
-		user_gids = {}
-		for user in pwd.getpwall():
-			user_gids[user.pw_gid] = user_gids.get(user.pw_gid, set()) | {user.pw_name}
-			yield self.subschema.ObjectEntry(self.subschema.DN('ou=users,dc=example,dc=com', uid=user.pw_name), **{
-				'objectClass': ['top', 'organizationalperson', 'person', 'posixaccount'],
-				'structuralObjectClass': ['organizationalperson'],
-				'uid': [user.pw_name],
-				'uidNumber': [user.pw_uid],
-				'gidNumber': [user.pw_gid],
-				'cn': [user.pw_gecos],
-			})
-		for group in grp.getgrall():
-			members = set(group.gr_mem) | user_gids.get(group.gr_gid, set())
-			yield self.subschema.ObjectEntry(self.subschema.DN('ou=groups,dc=example,dc=com', cn=group.gr_name), **{
-				'objectClass': ['top', 'groupOfUniqueNames', 'posixGroup'],
-				'structuralObjectClass': ['groupOfUniqueNames'],
-				'cn': [group.gr_name],
-				'gidNumber': [group.gr_gid],
-				'uniqueMember': [self.subschema.DN('ou=user,dc=example,dc=com', uid=name) for name in members],
-			})
+@server.template('cn={name},ou=groups,dc=example,dc=com',
+	objectClass=['top', 'groupOfUniqueNames', 'posixGroup'],
+	structuralObjectClass=['groupOfUniqueNames'],
+	cn=['{name}'],
+	gidNumber=['{gid}'],
+	uniqueMember=['uid={*member_names},ou=users,dc=example,dc=com'],
+)
+def groups(name=None, gid=None, member_names=None):
+	print('groups', repr(name), repr(gid), repr(member_names))
+	for group in grp.getgrall():
+		yield {'name': group.gr_name, 'gid': group.gr_gid, 'member_names': group.gr_mem}
 
 if __name__ == '__main__':
-	socketserver.ThreadingTCPServer(('127.0.0.1', 3890), RequestHandler).serve_forever()
+	socketserver.ThreadingTCPServer(('127.0.0.1', 3890), server.socketserver_handler).serve_forever()
diff --git a/ldapserver/future.py b/ldapserver/future.py
new file mode 100644
index 0000000..93c6bb3
--- /dev/null
+++ b/ldapserver/future.py
@@ -0,0 +1,478 @@
+import traceback
+import re
+import ssl
+import socketserver
+import typing
+import logging
+import time
+import random
+import string
+import itertools
+import collections.abc
+import enum
+import contextvars
+
+from . import ldap, exceptions, entries, asn1, schema, value_templates
+from .dn import DN, RDN, RDNAssertion
+
+'''
+server-level
+- simple anonymous bind
+- SASL PLAIN bind
+- SASL EXTERNAL bind
+- SASL ANONYMOUS bind
+- other SASL bind
+- unbind
+- add
+- abandon
+- extended
+
+special
+- search
+
+object-level
+- simple authenticated bind (with dn)
+- simple unauthenticated bind (with dn)
+- modify
+- modifydn
+- compare
+- del
+'''
+
+class Server:
+	def __init__(self, schema):
+		self.schema = schema
+		self.lookup_handlers = []
+		self.search_handlers = []
+		self.sasl_handlers = {} # 'METHOD': func
+		self.extended_handlers = {} # 'OID': func
+
+		class ServerRequestHandler(RequestHandler):
+			ldap_server = self
+
+		self.socketserver_handler = ServerRequestHandler
+
+	def sasl_handler(self, method):
+		def decorator(func):
+			assert method not in self.sasl_bind_handler
+			self.sasl_bind_handler[method] = func
+			return func
+		return decorator
+
+	def sasl_plain_handler(self, func):
+		@self.bind_sasl('PLAIN')
+		def wrapper(credentials, dn):
+			if credentials is None:
+				raise exceptions.LDAPProtocolError('Unsupported protocol version')
+			authzid, authcid, password = credentials.split(b'\0', 2)
+			return func(authcid.decode(), password.decode(), authzid.decode() or None), None
+		return func
+
+	def extended_handler(self, oid):
+		def decorator(func):
+			assert oid not in self.extended_handlers
+			self.extended_handlers[oid] = func
+			return func
+		return decorator
+
+	def add_entry(self, dn, **attributes):
+		entry = Entry(self.schema, dn, attributes)
+		self.lookup_handlers.append(entry.lookup)
+		self.search_handlers.append(entry.search)
+		return entry
+
+	def template(self, dn_template, **attribute_templates):
+		def decorator(func):
+			templ = Template(self.schema, dn_template, func, attribute_templates)
+			self.lookup_handlers.append(templ.lookup)
+			self.search_handlers.append(templ.search)
+			return templ
+		return decorator
+
+	def lookup_entry(self, dn):
+		for func in self.lookup_handlers:
+			entry = func(dn)
+			if entry:
+				return entry
+		raise exceptions.LDAPNoSuchObject()
+
+	def process_connect(self):
+		connection.bind_object = None
+
+	def process_message(self, shallowmsg: ldap.ShallowLDAPMessage) -> typing.Iterable[ldap.LDAPMessage]:
+		msgtypes = {
+			ldap.BindRequest: (self.process_bind, ldap.BindResponse),
+			#ldap.UnbindRequest: (self.process_unbind, None),
+			ldap.SearchRequest: (self.process_search, ldap.SearchResultDone),
+			#ldap.ModifyRequest: (self.process_modify, ldap.ModifyResponse),
+			#ldap.AddRequest: (self.process_add,  ldap.AddResponse),
+			#ldap.DelRequest: (self.process_delete, ldap.DelResponse),
+			#ldap.ModifyDNRequest: (self.process_modifydn, ldap.ModifyDNResponse),
+			ldap.CompareRequest: (self.process_compare, ldap.CompareResponse),
+			#ldap.AbandonRequest: (self.process_abandon, None),
+			ldap.ExtendedRequest: (self.process_extended, ldap.ExtendedResponse),
+		}
+		handler, response_type = msgtypes.get(shallowmsg.protocolOpType, (None, None))
+		try:
+			if handler is None:
+				raise exceptions.LDAPProtocolError()
+			try:
+				msg = shallowmsg.decode()[0]
+			except ValueError as e:
+				raise exceptions.LDAPProtocolError() from e
+			for args in handler(msg.protocolOp, msg.controls):
+				response, controls = args if isinstance(args, tuple) else (args, None)
+				yield ldap.LDAPMessage(shallowmsg.messageID, response, controls)
+		except exceptions.LDAPError as e:
+			if response_type is not None:
+				respmsg = ldap.LDAPMessage(shallowmsg.messageID, response_type(e.code, diagnosticMessage=e.message))
+				yield respmsg
+		except Exception as e: # pylint: disable=broad-except
+			traceback.print_exc()
+			if response_type is not None:
+				respmsg = ldap.LDAPMessage(shallowmsg.messageID, response_type(ldap.LDAPResultCode.other))
+				yield respmsg
+
+	def process_bind(self, op, controls=None):
+		if op.version != 3:
+			raise exceptions.LDAPProtocolError('Unsupported protocol version')
+		auth = op.authentication
+		if isinstance(auth, ldap.SimpleAuthentication):
+			if not op.name:
+				connection.bind_object = None
+			else:
+				try:
+					entry = self.lookup_entry(op.name)
+				except exceptions.LDAPNoSuchObject as exc:
+					raise exceptions.LDAPInvalidCredentials() from exc
+				if auth.password:
+					connection.bind_object = entry.bind(auth.password)
+				else:
+					connection.bind_object = entry.bind_unauthenticated()
+			yield ldap.BindResponse(ldap.LDAPResultCode.success)
+		elif isinstance(auth, ldap.SaslCredentials):
+			if auth.mechanism not in self.sasl_handlers:
+				raise exceptions.LDAPAuthMethodNotSupported()
+			connection.bind_object, resp = self.sasl_handlers[auth.mechanism](auth.credentials)
+			yield ldap.BindResponse(ldap.LDAPResultCode.success, serverSaslCreds=resp)
+		else:
+			raise exceptions.LDAPAuthMethodNotSupported()
+
+	def process_search(self, op, controls=None):
+		for func in self.search_handlers:
+			yield from func(op.baseObject, op.scope, op.filter, op.attributes, op.typesOnly)
+		yield ldap.SearchResultDone(ldap.LDAPResultCode.success)
+
+	def process_compare(self, op, controls=None):
+		if self.lookup_entry(op.entry).compare(op.ava.attributeDesc, op.ava.assertionValue):
+			return [ldap.CompareResponse(ldap.LDAPResultCode.compareTrue)]
+		else:
+			return [ldap.CompareResponse(ldap.LDAPResultCode.compareFalse)]
+
+	def process_extended(self, op, controls=None):
+		if op.requestName not in self.extended_handlers:
+			exceptions.LDAPProtocolError('Unknown extended request')
+		resp = self.extended_handlers[op.requestName](op.requestValue)
+		yield ldap.ExtendedResponse(ldap.LDAPResultCode.success, responseValue=resp)
+
+class FilterResult(enum.Enum):
+	TRUE = enum.auto()
+	FALSE = enum.auto()
+	UNDEFINED = enum.auto()
+
+class Entry(entries.AttributeDict):
+	def __init__(self, schema, dn, attributes):
+		super().__init__(schema, **attributes)
+		#: Entry's distinguished name (:class:`DN`)
+		self.dn = DN(schema, dn)
+
+	def __search_match_dn(self, basedn, scope):
+		if scope == ldap.SearchScope.baseObject:
+			return self.dn == basedn
+		elif scope == ldap.SearchScope.singleLevel:
+			return self.dn.is_direct_child_of(basedn)
+		elif scope == ldap.SearchScope.wholeSubtree:
+			return self.dn.in_subtree_of(basedn)
+		else:
+			return False
+
+	def __search_match_filter(self, filter_obj):
+		# pylint: disable=too-many-branches,too-many-return-statements,too-many-statements,too-many-nested-blocks
+		if isinstance(filter_obj, ldap.FilterAnd):
+			result = FilterResult.TRUE
+			for subfilter in filter_obj.filters:
+				subresult = self.__search_match_filter(subfilter)
+				if subresult == FilterResult.FALSE:
+					return FilterResult.FALSE
+				elif subresult == FilterResult.UNDEFINED:
+					result = FilterResult.UNDEFINED
+			return result
+		elif isinstance(filter_obj, ldap.FilterOr):
+			result = FilterResult.FALSE
+			for subfilter in filter_obj.filters:
+				subresult = self.__search_match_filter(subfilter)
+				if subresult == FilterResult.TRUE:
+					return FilterResult.TRUE
+				elif subresult == FilterResult.UNDEFINED:
+					result = FilterResult.UNDEFINED
+			return result
+		elif isinstance(filter_obj, ldap.FilterNot):
+			subresult = self.__search_match_filter(filter_obj.filter)
+			if subresult == FilterResult.TRUE:
+				return FilterResult.FALSE
+			elif subresult == FilterResult.FALSE:
+				return FilterResult.TRUE
+			else:
+				return subresult
+		elif isinstance(filter_obj, (ldap.FilterPresent,
+		                             ldap.FilterEqual,
+		                             ldap.FilterSubstrings,
+		                             ldap.FilterApproxMatch,
+		                             ldap.FilterGreaterOrEqual,
+		                             ldap.FilterLessOrEqual)):
+			try:
+				attribute_type = self.schema.attribute_types[filter_obj.attribute]
+			except KeyError:
+				return FilterResult.UNDEFINED
+			values = self.get(filter_obj.attribute, subtypes=True)
+			try:
+				if isinstance(filter_obj, ldap.FilterPresent):
+					result = values != []
+				elif isinstance(filter_obj, ldap.FilterEqual):
+					result = attribute_type.match_equal(values, filter_obj.value)
+				elif isinstance(filter_obj, ldap.FilterSubstrings):
+					result = attribute_type.match_substr(values,
+							filter_obj.initial_substring, filter_obj.any_substrings, filter_obj.final_substring)
+				elif isinstance(filter_obj, ldap.FilterApproxMatch):
+					result = attribute_type.match_approx(values, filter_obj.value)
+				elif isinstance(filter_obj, ldap.FilterGreaterOrEqual):
+					result = attribute_type.match_greater_or_equal(values, filter_obj.value)
+				elif isinstance(filter_obj, ldap.FilterLessOrEqual):
+					result = attribute_type.match_less_or_equal(values, filter_obj.value)
+				else:
+					return FilterResult.UNDEFINED
+				return FilterResult.TRUE if result else FilterResult.FALSE
+			except exceptions.LDAPError:
+				return FilterResult.UNDEFINED
+		elif isinstance(filter_obj, ldap.FilterExtensibleMatch):
+			attribute_types = []
+			matching_rule = None
+			try:
+				if filter_obj.type is not None and filter_obj.matchingRule is not None:
+					attribute_types = [self.schema.attribute_types[filter_obj.type]]
+					matching_rule = self.schema.matching_rules[filter_obj.matchingRule]
+				elif filter_obj.type is not None:
+					attribute_types = [self.schema.attribute_types[filter_obj.type]]
+				elif filter_obj.matchingRule is not None:
+					matching_rule = self.schema.matching_rules[filter_obj.matchingRule]
+					attribute_types = matching_rule.compatible_attribute_types
+			except KeyError:
+				pass
+			result = FilterResult.FALSE
+			for attribute_type in attribute_types:
+				values = self.get(attribute_type.oid, subtypes=True)
+				if filter_obj.dnAttributes:
+					for rdn in self.dn:
+						for assertion in rdn:
+							if assertion.attribute.lower() == attribute_type.ref.lower():
+								values.append(assertion.value)
+				try:
+					if attribute_type.match_extensible(values, filter_obj.matchValue, matching_rule):
+						return FilterResult.TRUE
+				except exceptions.LDAPError:
+					result = FilterResult.UNDEFINED
+			return result
+		else:
+			return FilterResult.UNDEFINED
+
+	def search(self, base_obj, scope, filter_obj, attributes, types_only):
+		if not self.__search_match_dn(DN.from_str(self.schema, base_obj), scope):
+			return []
+		if not self.__search_match_filter(filter_obj) == FilterResult.TRUE:
+			return []
+		selected_attributes = set()
+		for selector in attributes or ['*']:
+			if selector == '*':
+				selected_attributes |= self.schema.user_attribute_types
+			elif selector == '1.1':
+				continue
+			elif selector in self.schema.attribute_types:
+				selected_attributes.add(self.schema.attribute_types[selector])
+		partial_attributes = []
+		for attribute_type, values in self.items(types=True):
+			if attribute_type in selected_attributes:
+				if types_only:
+					values = []
+				encoded_values = [attribute_type.encode(value) for value in values]
+				partial_attributes.append(ldap.PartialAttribute(attribute_type.ref, encoded_values))
+		return [ldap.SearchResultEntry(str(self.dn), partial_attributes)]
+
+	def lookup(self, dn):
+		return self if self.dn == DN.from_str(self.schema, dn) else None
+
+	def bind(self, password):
+		raise exceptions.LDAPInvalidCredentials()
+
+	def bind_handler(self, func):
+		self.bind = func
+		return func
+
+	def compare(self, attribute, value):
+		try:
+			attribute_type = self.schema.attribute_types[attribute]
+		except KeyError as exc:
+			raise exceptions.LDAPUndefinedAttributeType() from exc
+		return attribute_type.match_equal(self.get(attribute_type, subtypes=True), value)
+
+class AttributeTemplate:
+	def __init__(self, attribute_type, exprs):
+		self.attribute_type = attribute_type
+		self.value_templates = []
+		for expr in exprs:
+			value_template = attribute_type.syntax.definition.parse_template(attribute_type.schema, expr)
+			self.value_templates.append(value_template)
+
+	def render(self, params):
+		results = []
+		for template in self.value_templates:
+			if any([params.get(param) is None for param in template.optional_params]):
+				continue
+			if not template.starred_param:
+				results.append(template.render(params))
+				continue
+			for value in params.get(template.starred_param, []):
+				results.append(template.render(params | {template.starred_param: value}))
+		return results
+
+	def match_equal(self, assertion_value):
+		if not self.attribute_type.equality:
+			return {}
+		return self.attribute_type.equality.definition.match_equal_template(self.attribute_type.schema, self.value_templates, assertion_value)
+
+class Template:
+	def __init__(self, schema, dn_template, lookup_func, attribute_templates):
+		self.schema = schema
+		self.dn_template = value_templates.DNTemplate.from_str(schema, dn_template)
+		if self.dn_template.starred_param is not None:
+			raise ValueError()
+		if self.dn_template.optional_params:
+			raise ValueError()
+		self.lookup_func = lookup_func
+		self.attribute_templates = {schema.attribute_types[key].ref: AttributeTemplate(schema.attribute_types[key], values) for key, values in attribute_templates.items()}
+		self.bind_handler_func = None
+
+	def bind_handler(self, func):
+		self.bind_handler_func = func
+
+	def __lookup_entries(self, input_params):
+		for output_params in self.lookup_func(**input_params):
+			entry = Entry(self.schema, self.dn_template.render(output_params),
+			              {key: value.render(output_params) for key, value in self.attribute_templates.items()})
+			if self.bind_handler_func:
+				entry.bind = lambda password: self.bind_handler_func(entry, password)
+			yield entry
+
+	def __extract_filter_constraints(self, filter_obj):
+		if isinstance(filter_obj, ldap.FilterEqual):
+			try:
+				attribute_type = self.schema.attribute_types[filter_obj.attribute]
+			except KeyError:
+				return entries.AttributeDict(self.schema)
+			if attribute_type.equality is None:
+				return entries.AttributeDict(self.schema)
+			assertion_value = attribute_type.equality.syntax.decode(filter_obj.value)
+			if assertion_value is None:
+				return entries.AttributeDict(self.schema)
+			return entries.AttributeDict(self.schema, **{filter_obj.attribute: [assertion_value]})
+		if isinstance(filter_obj, ldap.FilterAnd):
+			result = entries.AttributeDict(self.schema)
+			for subfilter in filter_obj.filters:
+				for name, values in self.__extract_filter_constraints(subfilter).items():
+					result[name] += values
+			return result
+		return entries.AttributeDict(self.schema)
+
+	def search(self, base_obj, scope, filter_obj, attributes, types_only):
+		constraints = self.dn_template.match(DN.from_str(self.schema, base_obj), scope=scope)
+		if constraints is False:
+			return
+		if constraints is True:
+			constraints = {}
+		constraints = {key: {value} for key, value in constraints.items()}
+		filter_constraints = self.__extract_filter_constraints(filter_obj)
+		for attribute_type, values in filter_constraints.items(types=True):
+			if attribute_type.ref not in self.attribute_templates:
+				continue
+			templates = self.attribute_templates[attribute_type.ref].value_templates
+			for value in values:
+				subresult = attribute_type.equality.definition.match_equal_template(self.schema, templates, value)
+				if subresult is False:
+					return
+				if subresult is True:
+					continue
+				for key, value in subresult.items():
+					constraints.setdefault(key, set()).add(value)
+		constraints = {key: list(values)[0] for key, values in constraints.items() if len(values) == 1}
+		for entry in self.__lookup_entries(constraints):
+			yield from entry.search(base_obj, scope, filter_obj, attributes, types_only)
+
+	def lookup(self, dn):
+		dn = DN.from_str(self.schema, dn)
+		constraints = self.dn_template.match(dn)
+		if constraints is False:
+			return None
+		if constraints is True:
+			constraints = {}
+		for entry in self.__lookup_entries(constraints):
+			if entry.dn == dn:
+				return entry
+
+class Connection:
+	pass
+
+_connection: contextvars.ContextVar[Connection] = contextvars.ContextVar('connection')
+
+internal_conn = _connection
+
+class Proxy:
+	def __init__(self, var):
+		super().__setattr__('__var', var)
+
+	def __getattr__(self, key):
+		return getattr(super().__getattribute__('__var').get(), key)
+
+	def __setattr__(self, key, value):
+		return setattr(super().__getattribute__('__var').get(), key, value)
+
+	def __delattr__(self, key):
+		return delattr(super().__getattribute__('__var').get(), key)
+
+connection = Proxy(_connection)
+
+class RequestHandler(socketserver.BaseRequestHandler):
+	ldap_server = None
+
+	def setup(self):
+		super().setup()
+		self.keep_running = True
+		conn = Connection()
+		self.ctxtoken = _connection.set(conn)
+	
+	def handle(self):
+		buf = b''
+		while self.keep_running:
+			try:
+				shallowmsg, buf = ldap.ShallowLDAPMessage.from_ber(buf)
+				for respmsg in self.ldap_server.process_message(shallowmsg):
+					self.request.sendall(ldap.LDAPMessage.to_ber(respmsg))
+			except asn1.IncompleteBERError:
+				chunk = self.request.recv(4096)
+				if not chunk:
+					self.keep_running = False
+					self.request.close()
+				else:
+					buf += chunk
+		self.request.close()
+
+	def finish(self):
+		_connection.reset(self.ctxtoken)
diff --git a/ldapserver/schema/definitions.py b/ldapserver/schema/definitions.py
index 3943cda..abda201 100644
--- a/ldapserver/schema/definitions.py
+++ b/ldapserver/schema/definitions.py
@@ -1,7 +1,7 @@
 import enum
 import re
 
-from .. import exceptions
+from .. import exceptions, value_templates
 
 __all__ = [
 	'SyntaxDefinition',
@@ -207,6 +207,9 @@ class SyntaxDefinition:
 		:raises exceptions.LDAPError: if raw_value is invalid'''
 		raise exceptions.LDAPInvalidAttributeSyntax()
 
+	def parse_template(self, schema, expr):
+		return value_templates.SimpleValueTemplate(expr, decode=lambda expr: self.decode(schema, expr.encode()))
+
 class MatchingRuleKind(enum.Enum):
 	'''Values for :any:`MatchingRuleDefinition.kind`'''
 	#:
@@ -329,6 +332,48 @@ class MatchingRuleDefinition:
 		:raises exceptions.LDAPError: if the result is undefined'''
 		raise exceptions.LDAPInappropriateMatching()
 
+	def match_equal_template(self, schema, templates, assertion_value):
+		static_values = []
+		params = set()
+		starred_params = set()
+		for template in templates:
+			if isinstance(template, value_templates.SimpleValueTemplate):
+				if not template.param:
+					static_values.append(template.static_value)
+				else:
+					params.add(template.param)
+					if template.starred_param:
+						starred_params.add(template.starred_param)
+			elif isinstance(template, value_templates.SubstrValueTemplate):
+				if len(template.tokens) != 1:
+					return {}
+				text, key = template.tokens[0]
+				if not text and key:
+					params.add(key)
+					if template.starred_param:
+						starred_params.add(template.starred_param)
+				elif text and not key:
+					static_values.append(text)
+				else:
+					return {}
+			else:
+				return {}
+		try:
+			result = self.match_equal(schema, static_values, assertion_value)
+			if result:
+				return True
+		except exceptions.LDAPInappropriateMatching:
+			pass
+		if starred_params:
+			if len(starred_params) == 1:
+				return {params.pop(): [assertion_value]}
+			return {}
+		elif params:
+			if len(params) == 1:
+				return {params.pop(): assertion_value}
+		else:
+			return False
+
 class MatchingRuleUseDefinition:
 	def __init__(self, oid, name=None, desc='', obsolete=False, applies=None, extensions=None):
 		#: Numeric OID (string)
diff --git a/ldapserver/schema/matching_rules.py b/ldapserver/schema/matching_rules.py
index 851eb6f..7716061 100644
--- a/ldapserver/schema/matching_rules.py
+++ b/ldapserver/schema/matching_rules.py
@@ -23,6 +23,11 @@ class GenericMatchingRuleDefinition(MatchingRuleDefinition):
 				return True
 		return False
 
+class DistinguishedNameMatchingRuleDefinition(GenericMatchingRuleDefinition):
+	def match_equal_template(self, schema, templates, assertion_value):
+		for template in templates:
+			pass
+
 def _substr_match(attribute_value, inital_substring, any_substrings, final_substring):
 	if inital_substring:
 		if not attribute_value.startswith(inital_substring):
diff --git a/ldapserver/schema/syntaxes.py b/ldapserver/schema/syntaxes.py
index 2660a03..6eae0af 100644
--- a/ldapserver/schema/syntaxes.py
+++ b/ldapserver/schema/syntaxes.py
@@ -2,7 +2,7 @@ import re
 import datetime
 
 from .definitions import SyntaxDefinition
-from .. import dn, exceptions
+from .. import dn, exceptions, value_templates
 
 class BytesSyntaxDefinition(SyntaxDefinition):
 	def encode(self, schema, value):
@@ -29,6 +29,9 @@ class StringSyntaxDefinition(SyntaxDefinition):
 			raise exceptions.LDAPInvalidAttributeSyntax()
 		return value
 
+	def parse_template(self, schema, expr):
+		return value_templates.SubstrValueTemplate(expr)
+
 class IntegerSyntaxDefinition(StringSyntaxDefinition):
 	def __init__(self, oid, **kwargs):
 		super().__init__(oid, encoding='ascii', re_pattern='([0-9]|-?[1-9][0-9]+)', **kwargs)
@@ -42,6 +45,10 @@ class IntegerSyntaxDefinition(StringSyntaxDefinition):
 		except ValueError as exc:
 			raise exceptions.LDAPInvalidAttributeSyntax() from exc
 
+	def parse_template(self, schema, expr):
+		return value_templates.SimpleValueTemplate(expr, decode=lambda expr: self.decode(schema, expr.encode()))
+
+
 class SchemaElementSyntaxDefinition(SyntaxDefinition):
 	def encode(self, schema, value):
 		return str(value).encode('utf8')
@@ -68,6 +75,9 @@ class DNSyntaxDefinition(SyntaxDefinition):
 		except (UnicodeDecodeError, TypeError, ValueError) as exc:
 			raise exceptions.LDAPInvalidAttributeSyntax() from exc
 
+	def parse_template(self, schema, expr):
+		return value_templates.DNTemplate.from_str(schema, expr)
+
 class NameAndOptionalUIDSyntaxDefinition(StringSyntaxDefinition):
 	def encode(self, schema, value):
 		return str(value).encode('utf8')
@@ -78,6 +88,9 @@ class NameAndOptionalUIDSyntaxDefinition(StringSyntaxDefinition):
 		except (UnicodeDecodeError, TypeError, ValueError) as exc:
 			raise exceptions.LDAPInvalidAttributeSyntax() from exc
 
+	def parse_template(self, schema, expr):
+		return value_templates.DNTemplate.from_str(schema, expr)
+
 class GeneralizedTimeSyntaxDefinition(SyntaxDefinition):
 	def encode(self, schema, value):
 		if value.microsecond:
diff --git a/ldapserver/value_templates.py b/ldapserver/value_templates.py
new file mode 100644
index 0000000..159fb79
--- /dev/null
+++ b/ldapserver/value_templates.py
@@ -0,0 +1,244 @@
+import re
+
+from . import ldap
+from .dn import DN, RDN, RDNAssertion
+
+class SimpleValueTemplate:
+	def __init__(self, expr, decode):
+		self.param = None
+		self.starred_param = None
+		self.optional_params = set()
+		self.static_value = None
+		if not isinstance(expr, str):
+			self.static_value = decode(expr)
+			return
+		match = re.fullmatch(r'{([*?]?)([A-Za-z0-9_]+)}', expr)
+		if not match:
+			self.static_value = decode(expr.replace('{{', '{').replace('}}', '}'))
+			return
+		option, key = match.groups()
+		option = option or None
+		self.param = key
+		if option is None:
+			pass
+		elif option == '*':
+			self.starred_param = key
+		elif option == '?':
+			self.optional_params.add(key)
+		else:
+			raise ValueError(f'Invalid template option {option!r}')
+
+	def render(self, params):
+		if self.param is None:
+			return self.static_value
+		return params[self.param]
+
+class SubstrValueTemplate:
+	def __init__(self, expr):
+		self.tokens = []
+		self.starred_param = None
+		self.optional_params = set()
+		while expr:
+			match = re.match(r'^(([^{}]|{{|}})*)({([*?]?)([A-Za-z0-9_]+)})?', expr)
+			if not match or not match.end():
+				raise ValueError('Invalid token')
+			text, _, _, option, key, = match.groups()
+			text = text.replace('{{', '{').replace('}}', '}')
+			option = option or None
+			expr = expr[match.end():]
+			self.tokens.append((text, key))
+			if key is None:
+				pass
+			elif option is None:
+				pass
+			elif option == '?':
+				self.optional_params.add(key)
+			elif option == '*':
+				if self.starred_param is not None and self.starred_param != key:
+					raise ValueError('Only one parameter may be starred')
+				self.starred_param = key
+			else:
+				raise ValueError(f'Invalid template option {option!r}')
+
+	def render(self, params):
+		result = ''
+		for text, key in self.tokens:
+			result += text
+			if key is not None:
+				value = params[key]
+				result += str(value)
+		return result
+
+class DNTemplate(tuple):
+	def __new__(cls, schema, *args):
+		for rdn in args:
+			if not isinstance(rdn, RDNTemplate):
+				raise TypeError(f'Argument {repr(rdn)} is of type {repr(type(rdn))}, expected ldapserver.RDNTemplate object')
+		dn = super().__new__(cls, args)
+		dn.schema = schema
+		dn.starred_param = None
+		dn.optional_params = set()
+		for rdn in args:
+			dn.optional_params |= rdn.optional_params
+			if rdn.starred_param:
+				if dn.starred_param and rdn.starred_param != dn.starred_param:
+					raise ValueError()
+				dn.starred_param = rdn.starred_param
+		return dn
+
+	@classmethod
+	def from_str(cls, schema, expr):
+		escaped = False
+		rdns = []
+		token = ''
+		for char in expr:
+			if escaped:
+				escaped = False
+				token += char
+			elif char == ',':
+				rdns.append(RDNTemplate.from_str(schema, token))
+				token = ''
+			else:
+				if char == '\\':
+					escaped = True
+				token += char
+		if token:
+			rdns.append(RDNTemplate.from_str(schema, token))
+		return cls(schema, *rdns)
+
+	def render(self, params):
+		return DN(self.schema, *[rdn.render(params) for rdn in self])
+
+	def match(self, dn, scope=ldap.SearchScope.baseObject):
+		if scope == ldap.SearchScope.baseObject:
+			if len(self) != len(dn):
+				return False
+		elif scope == ldap.SearchScope.singleLevel:
+			if len(self) != len(dn) + 1:
+				return False
+		elif scope == ldap.SearchScope.wholeSubtree:
+			if len(self) < len(dn):
+				return False
+		result = True
+		for rdn_template, rdn in zip(self[::-1], dn[::-1]):
+			subresult = rdn_template.match(rdn)
+			if subresult is False:
+				return False
+			if subresult is True:
+				continue
+			if result is True:
+				result = {}
+			for key, value in subresult.items():
+				result.setdefault(key, []).append(value)
+		if result is True:
+			return True
+		return {key: values[0] for key, values in result.items() if len(values) == 1}
+
+class RDNTemplate(tuple):
+	def __new__(cls, schema, *assertions):
+		for assertion in assertions:
+			if not isinstance(assertion, RDNAssertionTemplate):
+				raise TypeError(f'Argument {repr(assertion)} is of type {repr(type(assertion))}, expected ldapserver.RDNAssertionTemplate')
+			if assertion.attribute_type.schema is not schema:
+				raise ValueError('RDNAssertion has different schema')
+		assertions = list(assertions)
+		if not assertions:
+			raise ValueError('RDN must have at least one assertion')
+		rdn = super().__new__(cls, assertions)
+		rdn.schema = schema
+		rdn.starred_param = None
+		rdn.optional_params = set()
+		for assertion in assertions:
+			rdn.optional_params |= assertion.optional_params
+			if assertion.starred_param:
+				if rdn.starred_param and assertion.starred_param != rdn.starred_param:
+					raise ValueError()
+				rdn.starred_param = assertion.starred_param
+		return rdn
+
+	@classmethod
+	def from_str(cls, schema, expr):
+		escaped = False
+		assertions = []
+		token = ''
+		for char in expr:
+			if escaped:
+				escaped = False
+				token += char
+			elif char == '+':
+				assertions.append(RDNAssertionTemplate.from_str(schema, token))
+				token = ''
+			else:
+				if char == '\\':
+					escaped = True
+				token += char
+		if token:
+			assertions.append(RDNAssertionTemplate.from_str(schema, token))
+		return cls(schema, *assertions)
+
+	def render(self, params):
+		return RDN(self.schema, *[assertion.render(params) for assertion in self])
+
+	def match(self, rdn):
+		if len(self) != len(rdn):
+			return False
+		if len(self) != 1:
+			return {}
+		return self[0].match(rdn[0])
+
+class RDNAssertionTemplate:
+	def __init__(self, schema, attribute, value_template):
+		self.schema = schema
+		try:
+			self.attribute_type = schema.attribute_types[attribute]
+		except KeyError as exc:
+			raise ValueError('Invalid RDN attribute type: Attribute type undefined in schema') from exc
+		self.attribute = attribute
+		if not self.attribute_type.equality:
+			raise ValueError('Invalid RDN attribute type: Attribute type has no EQUALITY matching rule')
+		self.value_template = value_template
+		self.starred_param = value_template.starred_param
+		self.optional_params = value_template.optional_params
+
+	@classmethod
+	def from_str(cls, schema, expr):
+		attribute, escaped_value = expr.split('=', 1)
+		if escaped_value.startswith('#'):
+			# The "#..." form is used for unknown attribute types and those without
+			# an LDAP string encoding. Supporting it would require us to somehow
+			# handle the hex-encoded BER encoding of the data. We'll stay away from
+			# this mess for now.
+			raise ValueError('Hex-encoded RDN assertion values are not supported')
+		escaped = False
+		hexdigit = None
+		encoded_value = b''
+		for char in escaped_value:
+			if hexdigit is not None:
+				encoded_value += bytes.fromhex('%s%s'%(hexdigit, char))
+				hexdigit = None
+			elif escaped:
+				if ord(char) in DN_SPECIAL + (b'\\'[0],):
+					encoded_value += char.encode('utf8')
+				elif ord(char) in HEXDIGITS:
+					hexdigit = char
+				else:
+					raise ValueError('Invalid escape: \\%s'%char)
+				escaped = False
+			elif char == '\\':
+				escaped = True
+			else:
+				encoded_value += char.encode('utf8')
+		try:
+			attribute_type = schema.attribute_types[attribute]
+		except KeyError as exc:
+			raise ValueError('Invalid RDN attribute type: Attribute type undefined in schema') from exc
+		value_template = attribute_type.syntax.definition.parse_template(schema, encoded_value.decode())
+		return cls(schema, attribute, value_template)
+
+	def render(self, params):
+		return RDNAssertion(self.schema, self.attribute, self.value_template.render(params))
+
+	def match(self, rdnassertion):
+		if self.attribute_type != rdnassertion.attribute_type:
+			return False
+		return self.attribute_type.equality.definition.match_equal_template(self.schema, [self.value_template], rdnassertion.value)
-- 
GitLab