From 2cfda245c32ed047251b8081daaa89ba3085bd8a Mon Sep 17 00:00:00 2001
From: Julian Rother <julian@jrother.eu>
Date: Thu, 11 Mar 2021 19:22:21 +0100
Subject: [PATCH] Code cleanup and restructuring

---
 db.py              |  51 +++++----
 exceptions.py      |  42 +++++++
 ldap.py            |   4 +
 sasl.py            |  48 --------
 sasl/__init__.py   |   0
 sasl/digest_md5.py | 140 +++++++++++++++++++++++
 server.py          | 270 ++++++++++++++++++++++-----------------------
 7 files changed, 343 insertions(+), 212 deletions(-)
 create mode 100644 exceptions.py
 delete mode 100644 sasl.py
 create mode 100644 sasl/__init__.py
 create mode 100644 sasl/digest_md5.py

diff --git a/db.py b/db.py
index ec49d48..c1f3176 100644
--- a/db.py
+++ b/db.py
@@ -8,9 +8,10 @@ from sqlalchemy.orm import sessionmaker, relationship, RelationshipProperty, ali
 from sqlalchemy.ext.declarative import declarative_base
 from sqlalchemy.ext.hybrid import hybrid_property
 from ldap import SearchScope, FilterAnd, FilterOr, FilterNot, FilterEqual, FilterPresent
-from server import LDAPRequestHandler, LDAPInvalidCredentials, LDAPInsufficientAccessRights, LDAPConfidentialityRequired, LDAPNoSuchObject, encode_attribute
+from server import LDAPRequestHandler, LDAPInvalidCredentials, LDAPInsufficientAccessRights, LDAPConfidentialityRequired, LDAPNoSuchObject, encode_attribute, CaseInsensitiveDict
 import socketserver
 from dn import parse_dn, build_dn
+import sasl.digest_md5
 
 Base = declarative_base()
 
@@ -131,12 +132,10 @@ class Group(Base):
 Base.metadata.create_all(engine)
 
 class SQLModelWrapper(BaseSearchEvaluator):
-	def __init__(self, store, model, attributes=None, objectclasses=None, rdn_attr='uid', dn_base=''):
+	def __init__(self, store, model, attribute_map=None, objectclasses=None, rdn_attr='uid', dn_base=''):
 		self.store = store
 		self.model = model
-		self.attributes = {}
-		for ldap_name, attr_name in (attributes or {}).items():
-			self.attributes[ldap_name.lower()] = attr_name
+		self.attribute_map = CaseInsensitiveDict(attribute_map or {})
 		self.objectclasses = []
 		for value in (objectclasses or []):
 			value = value.lower()
@@ -144,15 +143,14 @@ class SQLModelWrapper(BaseSearchEvaluator):
 				value = value.encode()
 			self.objectclasses.append(value)
 		self.rdn_attr = rdn_attr
-		self.dn_base = dn_base
 		self.dn_base_path = parse_dn(dn_base)
 
 	def filter_present(self, name):
 		if name == 'objectclass':
 			return True
-		if name not in self.attributes:
+		if name not in self.attribute_map:
 			return False
-		attr = getattr(self.model, self.attributes[name])
+		attr = getattr(self.model, self.attribute_map[name])
 		if hasattr(attr, 'prop') and isinstance(attr.prop, RelationshipProperty):
 			return attr.any()
 		return attr.isnot(None)
@@ -160,9 +158,9 @@ class SQLModelWrapper(BaseSearchEvaluator):
 	def filter_equal(self, name, value):
 		if name == 'objectclass':
 			return value in self.objectclasses
-		if name not in self.attributes:
+		if name not in self.attribute_map:
 			return False
-		attr = getattr(self.model, self.attributes[name])
+		attr = getattr(self.model, self.attribute_map[name])
 		if isinstance(attr, str):
 			value = value.decode()
 		elif isinstance(attr, int):
@@ -208,7 +206,7 @@ class SQLModelWrapper(BaseSearchEvaluator):
 			return False
 
 	def get_dn(self, obj):
-		attr_name = self.attributes[self.rdn_attr]
+		attr_name = self.attribute_map[self.rdn_attr]
 		rdn_value = encode_attribute(getattr(obj, attr_name))
 		dn_parts = (((self.rdn_attr, rdn_value),),) + self.dn_base_path
 		return build_dn(dn_parts)
@@ -222,7 +220,7 @@ class SQLModelWrapper(BaseSearchEvaluator):
 			objs = self.store.session.query(self.model).filter(filter_obj)
 		for obj in objs:
 			attrs = {}
-			for ldap_name, attr_name in self.attributes.items():
+			for ldap_name, attr_name in self.attribute_map.items():
 				values = getattr(obj, attr_name)
 				if values is None:
 					continue
@@ -242,8 +240,8 @@ class SQLObjectStore:
 		self.session = session
 		self.models = {}
 
-	def register_model(self, model, attributes=None, objectclasses=None, rdn_attr='uid', dn_base=''):
-		self.models[model] = SQLModelWrapper(self, model, attributes, objectclasses, rdn_attr, dn_base)
+	def register_model(self, model, attribute_map=None, objectclasses=None, rdn_attr='uid', dn_base=''):
+		self.models[model] = SQLModelWrapper(self, model, attribute_map, objectclasses, rdn_attr, dn_base)
 
 	def search(self, baseobj, scope, filter):
 		for model, wrapper in self.models.items():
@@ -252,17 +250,17 @@ class SQLObjectStore:
 sqlstore = SQLObjectStore(session)
 sqlstore.register_model(
 	model=User,
-	attributes={
+	attribute_map={
 		'cn': 'displayname',
 		'displayname': 'displayname',
-		'gidnumber': 'ldap_gid',
+		'gidNumber': 'ldap_gid',
 		'givenname': 'displayname',
-		'homedirectory': 'homedirectory',
+		'homeDirectory': 'homedirectory',
 		'mail': 'email',
 		'sn': 'ldap_sn',
 		'uid': 'loginname',
-		'uidnumber': 'id',
-		'memberof': 'groups',
+		'uidNumber': 'id',
+		'memberOf': 'groups',
 	},
 	objectclasses=[b'top', b'inetOrgPerson', b'organizationalPerson', b'person', b'posixAccount'],
 	rdn_attr='uid',
@@ -271,10 +269,10 @@ sqlstore.register_model(
 
 sqlstore.register_model(
 	model=Group,
-	attributes={
+	attribute_map={
 		'cn': 'name',
 		'description': 'description',
-		'gidnumber': 'id',
+		'gidNumber': 'id',
 		'uniqueMember': 'users',
 	},
 	objectclasses=[b'top', b'posixGroup', b'groupOfUniqueNames'],
@@ -288,7 +286,12 @@ ssl_context.load_cert_chain('devcert.crt', 'devcert.key')
 class RequestHandler(LDAPRequestHandler):
 	ssl_context = ssl_context
 
-	sasl_enable_plain = True
+	supports_sasl_digest_md5 = True
+
+	def do_bind_sasl_digest_md5(self, username, realm, host, serv_name=None, authzid=None, charset='utf-8'):
+		return [(None, sasl.digest_md5.credential_digest(username, realm, pw)) for pw in ['foobar', 'abcdef']]
+
+	supports_sasl_plain = True
 
 	def do_bind_sasl_plain(self, identity, password, authzid=None):
 		if not isinstance(self.request, SSLSocket):
@@ -318,7 +321,7 @@ class RequestHandler(LDAPRequestHandler):
 		return user
 
 	@property
-	def sasl_enable_external(self):
+	def supports_sasl_external(self):
 		return self.request.family == socket.AF_UNIX
 
 	def do_bind_sasl_external(self, authzid=None):
@@ -334,6 +337,8 @@ class RequestHandler(LDAPRequestHandler):
 			raise LDAPNoSuchObject()
 		return user
 
+	supports_whoami = True
+
 	def do_whoami(self):
 		if self.bind_object is None:
 			return ''
diff --git a/exceptions.py b/exceptions.py
new file mode 100644
index 0000000..7263a09
--- /dev/null
+++ b/exceptions.py
@@ -0,0 +1,42 @@
+from ldap import LDAPResultCode
+
+class LDAPError(Exception):
+	def __init__(self, code=LDAPResultCode.other, message=''):
+		self.code = code
+		self.message = message
+
+class LDAPOperationsError(LDAPError):
+	def __init__(self, message=''):
+		super().__init__(LDAPResultCode.operationsError, message)
+
+class LDAPProtocolError(LDAPError):
+	def __init__(self, message=''):
+		super().__init__(LDAPResultCode.protocolError, message)
+
+class LDAPInvalidCredentials(LDAPError):
+	def __init__(self, message=''):
+		super().__init__(LDAPResultCode.invalidCredentials, message)
+
+class LDAPInsufficientAccessRights(LDAPError):
+	def __init__(self, message=''):
+		super().__init__(LDAPResultCode.insufficientAccessRights, message)
+
+class LDAPUnwillingToPerform(LDAPError):
+	def __init__(self, message=''):
+		super().__init__(LDAPResultCode.unwillingToPerform, message)
+
+class LDAPAuthMethodNotSupported(LDAPError):
+	def __init__(self, message=''):
+		super().__init__(LDAPResultCode.authMethodNotSupported, message)
+
+class LDAPConfidentialityRequired(LDAPError):
+	def __init__(self, message=''):
+		super().__init__(LDAPResultCode.confidentialityRequired, message)
+
+class LDAPNoSuchObject(LDAPError):
+	def __init__(self, message=''):
+		super().__init__(LDAPResultCode.noSuchObject, message)
+
+class LDAPUnavailableCriticalExtension(LDAPError):
+	def __init__(self, message=''):
+		super().__init__(LDAPResultCode.unavailableCriticalExtension, message)
diff --git a/ldap.py b/ldap.py
index a72fefa..64a9143 100644
--- a/ldap.py
+++ b/ldap.py
@@ -667,6 +667,10 @@ class ShallowLDAPMessage(BERType):
 
 # Extended Operation Values
 
+EXT_STARTTLS_OID = '1.3.6.1.4.1.1466.20037'
+EXT_WHOAMI_OID = '1.3.6.1.4.1.4203.1.11.3'
+EXT_PASSWORD_MODIFY_OID = '1.3.6.1.4.1.4203.1.11.1'
+
 class PasswdModifyRequestValue(Sequence):
 	sequence_fields = [
 		(retag(LDAPString, (2, False, 0)), 'userIdentity', None, True),
diff --git a/sasl.py b/sasl.py
deleted file mode 100644
index 48febc7..0000000
--- a/sasl.py
+++ /dev/null
@@ -1,48 +0,0 @@
-SEP = [b'(', b')', b'<', b'>', b'@', b',', b';', b':', b'\\', b'\'', b'/', b'[', b']', b'?', b'=', b'{', b'}', b' ', b'\t']
-CTL = [bytes([c]) for c in range(0, 31)] + [b'127']
-
-def parse_token(s):
-	for index in range(len(s)):
-		c = bytes([s[index]])
-		if c in SEP + CTL:
-			return bytes(s[:index]), bytes(s[index:])
-	return s, b''
-
-def parse_qstr(s):
-	if s[0] != b'"'[0]:
-		raise ValueError()
-	res = b''
-	escaped = False
-	for index in range(1, len(s)):
-		c = bytes([s[index]])
-		if escaped:
-			res += c
-			escaped = False
-		elif c == b'\\':
-			escaped = True
-		elif c == b'"':
-			return res, bytes(s[index+1:])
-		else:
-			res += c
-	raise ValueError()
-
-def parse_token_qstr(s):
-	if s[0] == b'"'[0]:
-		return parse_qstr(s)
-	return parse_token(s)
-
-def parse_kwargs(s):
-	res = []
-	while True:
-		key, s = parse_token(s)
-		if s[0] != b'='[0]:
-			raise ValueError()
-		value, s = parse_token_qstr(bytes(s[1:]))
-		res.append((key, value))
-		if not s:
-			return res
-		if s[0] != b','[0]:
-			raise ValueError()
-		s = bytes(s[1:])
-	return res
-
diff --git a/sasl/__init__.py b/sasl/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/sasl/digest_md5.py b/sasl/digest_md5.py
new file mode 100644
index 0000000..fd81af1
--- /dev/null
+++ b/sasl/digest_md5.py
@@ -0,0 +1,140 @@
+import hashlib
+import secrets
+
+from exceptions import *
+
+def _parse_token(s):
+	SEP = [b'(', b')', b'<', b'>', b'@', b',', b';', b':', b'\\', b'\'', b'/',
+	       b'[', b']', b'?', b'=', b'{', b'}', b' ', b'\t']
+	CTL = [bytes([c]) for c in range(0, 31)] + [b'127']
+	for index in range(len(s)):
+		c = bytes([s[index]])
+		if c in SEP + CTL:
+			return bytes(s[:index]), bytes(s[index:])
+	return s, b''
+
+def _parse_qstr(s):
+	if s[0] != b'"'[0]:
+		raise ValueError()
+	res = b''
+	escaped = False
+	for index in range(1, len(s)):
+		c = bytes([s[index]])
+		if escaped:
+			res += c
+			escaped = False
+		elif c == b'\\':
+			escaped = True
+		elif c == b'"':
+			return res, bytes(s[index+1:])
+		else:
+			res += c
+	raise ValueError()
+
+def _parse_token_qstr(s):
+	if s[0] == b'"'[0]:
+		return _parse_qstr(s)
+	return _parse_token(s)
+
+def _parse_kwargs(s):
+	res = []
+	while True:
+		key, s = _parse_token(s)
+		if s[0] != b'='[0]:
+			raise ValueError()
+		value, s = _parse_token_qstr(bytes(s[1:]))
+		res.append((key, value))
+		if not s:
+			return res
+		if s[0] != b','[0]:
+			raise ValueError()
+		s = bytes(s[1:])
+	return res
+
+def _generate_nonce():
+	return secrets.token_urlsafe(1024).encode()
+
+def _hexdigest(data):
+	ctx = hashlib.md5()
+	ctx.update(data)
+	return ctx.hexdigest().lower().encode()
+
+def _handle_ldap_bind(get_credentials, get_nonce=_generate_nonce, initial_response=None):
+	# Defined by RFC2831 and RFC2829, obsoleted by RFC6331
+	nonce = get_nonce()
+	challenge = b'nonce="%s",charset="utf-8",algorithm="md5-sess"'%(nonce)
+	resp = yield challenge
+	args = {key: value for key, value in _parse_kwargs(resp)}
+	if args[b'nonce'] != nonce:
+		raise LDAPProtocolError()
+	try:
+		charset = 'utf-8' if args.get(b'charset', b'utf-8') == b'utf-8' else 'latin_1'
+		username = args[b'username']
+		realm = args.get(b'realm', b'')
+		cnonce = args[b'cnonce']
+		nc = args.get(b'nc', b'00000001')
+		qop = args.get(b'qop', b'auth')
+		digest_uri = args[b'digest-uri']
+		parts = digest_uri.decode(charset).split('/', 2)
+		serv_type = parts[0]
+		host = parts[1]
+		serv_name = parts[2] if len(parts) == 3 else None
+		authzid = args[b'authzid'].decode(charset) if b'authzid' in args else None
+		response = args[b'response']
+	except (KeyError, IndexError):
+		raise LDAPProtocolError()
+	except UnicodeError:
+		raise LDAPProtocolError()
+	if serv_type != 'ldap':
+		raise LDAPInvalidCredentials()
+	valid_credentials = get_credentials(username, realm, host, serv_name=None, authzid=None, charset=charset)
+	a2 = b'AUTHENTICATE:' + digest_uri
+	data = nonce + b':' + nc + b':' + cnonce + b':' + qop + b':' + _hexdigest(a2)
+	for bind_obj, pwdigest in valid_credentials:
+		a1 = pwdigest + b':' + nonce + b':' + cnonce
+		key = _hexdigest(a1)
+		expected_response = _hexdigest(key + b':' + data)
+		if expected_response != response:
+			continue
+		# We don't support subsequent authentication so according to RFC 2829 the
+		# serverSaslCreds field in our response should be absent and we should
+		# return (bind_obj, None). But this seems to confuse some clients (e.g.
+		# openldap's ldapsearch) so we return serverSaslCreds with rspauth instead.
+		a2 = b':' + digest_uri
+		data = nonce + b':' + nc + b':' + cnonce + b':' + qop + b':' + _hexdigest(a2)
+		response = b'rspauth=%s'%_hexdigest(key + b':' + data)
+		return bind_obj, response
+	raise LDAPInvalidCredentials()
+
+def _encode_latin1_or_utf8(s):
+	try:
+		return s.encode('latin_1')
+	except UnicodeEncodeError:
+		pass
+	return s.encode('utf-8')
+
+def credential_digest(username, realm, password):
+	'''Compute DIGEST-MD5-specific credential digest
+
+	:param username: Name of the user account
+	:type username: bytes or str
+	:param realm: Realm containing the user account
+	:type realm: bytes or str
+	:param password: Password for the user account
+	:type password: bytes or str
+
+	:returns: DIGEST-MD5-specific credential digest (16 bytes)
+	:rtype: bytes
+
+	Parameters passed as strings are encoded according to the special DIGEST-MD5
+	encoding rules (latin_1 whenever all characters can be encoded wit it, utf-8
+	otherwise).'''
+	if isinstance(username, str):
+		username = _encode_latin1_or_utf8(username)
+	if isinstance(realm, str):
+		realm = _encode_latin1_or_utf8(realm)
+	if isinstance(password, str):
+		password = _encode_latin1_or_utf8(password)
+	ctx = hashlib.md5()
+	ctx.update(username + b':' + realm + b':' + password)
+	return ctx.digest()
diff --git a/server.py b/server.py
index 5e203e8..17a964e 100644
--- a/server.py
+++ b/server.py
@@ -5,49 +5,9 @@ import socket
 import ssl
 from socketserver import BaseRequestHandler
 
-import sasl
-from ldap import LDAPMessage, ShallowLDAPMessage, BindRequest, BindResponse, SearchRequest, SearchResultEntry, PartialAttribute, SearchResultDone, UnbindRequest, LDAPResultCode, IncompleteBERError, SimpleAuthentication, ModifyRequest, ModifyResponse, AddRequest, AddResponse, DelRequest, DelResponse, ModifyDNRequest, ModifyDNResponse, CompareRequest, CompareResponse, AbandonRequest, ExtendedRequest, ExtendedResponse, PasswdModifyRequestValue, PasswdModifyResponseValue, SaslCredentials, SearchScope, FilterPresent
-
-class LDAPError(Exception):
-	def __init__(self, code=LDAPResultCode.other, message=''):
-		self.code = code
-		self.message = message
-
-class LDAPOperationsError(LDAPError):
-	def __init__(self, message=''):
-		super().__init__(LDAPResultCode.operationsError, message)
-
-class LDAPProtocolError(LDAPError):
-	def __init__(self, message=''):
-		super().__init__(LDAPResultCode.protocolError, message)
-
-class LDAPInvalidCredentials(LDAPError):
-	def __init__(self, message=''):
-		super().__init__(LDAPResultCode.invalidCredentials, message)
-
-class LDAPInsufficientAccessRights(LDAPError):
-	def __init__(self, message=''):
-		super().__init__(LDAPResultCode.insufficientAccessRights, message)
-
-class LDAPUnwillingToPerform(LDAPError):
-	def __init__(self, message=''):
-		super().__init__(LDAPResultCode.unwillingToPerform, message)
-
-class LDAPAuthMethodNotSupported(LDAPError):
-	def __init__(self, message=''):
-		super().__init__(LDAPResultCode.authMethodNotSupported, message)
-
-class LDAPConfidentialityRequired(LDAPError):
-	def __init__(self, message=''):
-		super().__init__(LDAPResultCode.confidentialityRequired, message)
-
-class LDAPNoSuchObject(LDAPError):
-	def __init__(self, message=''):
-		super().__init__(LDAPResultCode.noSuchObject, message)
-
-class LDAPUnavailableCriticalExtension(LDAPError):
-	def __init__(self, message=''):
-		super().__init__(LDAPResultCode.unavailableCriticalExtension, message)
+from ldap import *
+import sasl.digest_md5
+from exceptions import *
 
 def decode_msg(shallowmsg):
 	try:
@@ -63,33 +23,37 @@ def encode_attribute(value):
 		value = value.encode()
 	return value
 
-class AttributeKey(str):
+class CaseInsensitiveKey(str):
 	def __hash__(self):
 		return hash(self.lower())
 
 	def __eq__(self, value):
 		return self.lower() == value.lower()
 
-class AttributeDict(dict):
+class CaseInsensitiveDict(dict):
 	def __init__(self, *args, **kwargs):
 		if len(args) == 1 and isinstance(args[0], dict):
-			kwargs = {AttributeKey(k): v for k, v in args[0].items()}
+			kwargs = {CaseInsensitiveKey(k): v for k, v in args[0].items()}
 			args = []
 		else:
-			kwargs = {AttributeKey(k): v for k, v in kwargs.items()}
-			args = [(AttributeKey(k), v) for k, v in args]
+			kwargs = {CaseInsensitiveKey(k): v for k, v in kwargs.items()}
+			args = [(CaseInsensitiveKey(k), v) for k, v in args]
 		super().__init__(*args, **kwargs)
 
 	def __contains__(self, key):
-		return super().__contains__(AttributeKey(key))
+		return super().__contains__(CaseInsensitiveKey(key))
 
 	def __setitem__(self, key, value):
-		super().__setitem__(AttributeKey(key), value)
+		super().__setitem__(CaseInsensitiveKey(key), value)
+
+	def __getitem__(self, key):
+		return super().__getitem__(CaseInsensitiveKey(key))
 
+class AttributeDict(CaseInsensitiveDict):
 	def __getitem__(self, key):
 		if key not in self:
 			self[key] = []
-		return super().__getitem__(AttributeKey(key))
+		return super().__getitem__()
 
 class RootDSE(AttributeDict):
 	def search(self, baseobj, scope, filter):
@@ -117,11 +81,6 @@ def check_controls(controls=None, supported_oids=[]):
 class LDAPRequestHandler(BaseRequestHandler):
 	ssl_context = None
 
-	sasl_enable_anonymous = False
-	sasl_enable_plain = False
-	sasl_enable_external = False
-	sasl_enable_digest_md5 = False
-
 	def setup(self):
 		super().setup()
 		self.rootdse = RootDSE()
@@ -134,20 +93,38 @@ class LDAPRequestHandler(BaseRequestHandler):
 		self.bind_sasl_state = None
 
 	def get_extentions(self):
+		'''Get supported LDAP extentions
+
+		:returns: OIDs of supported LDAP extentions
+		:rtype: list of bytes objects
+
+		Called whenever the root DSE attribute "supportedExtension" is queried.'''
 		res = []
-		if self.ssl_context is not None and not isinstance(self.request, ssl.SSLSocket):
-			res.append(b'1.3.6.1.4.1.1466.20037')
+		if self.supports_starttls:
+			res.append(EXT_STARTTLS_OID.encode())
+		if self.supports_whoami:
+			res.append(EXT_WHOAMI_OID.encode())
+		if self.supports_password_modify:
+			res.append(EXT_PASSWORD_MODIFY_OID.encode())
 		return res
 
 	def get_sasl_mechanisms(self):
+		'''Get supported SASL mechanisms
+
+		:returns: Names of supported SASL mechanisms
+		:rtype: list of bytes objects
+
+		SASL mechanism name are typically all-caps, like "EXTERNAL".
+		
+		Called whenever the root DSE attribute "supportedSASLMechanisms" is queried.'''
 		res = []
-		if self.sasl_enable_anonymous:
+		if self.supports_sasl_anonymous:
 			res.append(b'ANONYMOUS')
-		if self.sasl_enable_plain:
+		if self.supports_sasl_plain:
 			res.append(b'PLAIN')
-		if self.sasl_enable_external:
+		if self.supports_sasl_external:
 			res.append(b'EXTERNAL')
-		if self.sasl_enable_digest_md5:
+		if self.supports_sasl_digest_md5:
 			res.append(b'DIGEST-MD5')
 		return res
 
@@ -254,7 +231,6 @@ class LDAPRequestHandler(BaseRequestHandler):
 
 		Calld by `do_bind_simple()`. The default implementation always raises an
 		`LDAPInvalidCredentials` exception.'''
-
 		raise LDAPInvalidCredentials()
 
 	def do_bind_sasl(self, mechanism, credentials=None, dn=None):
@@ -289,23 +265,25 @@ class LDAPRequestHandler(BaseRequestHandler):
 		if not mechanism:
 			# Request to abort current negotiation (RFC4513 5.2.1.2)
 			raise LDAPAuthMethodNotSupported()
-		if mechanism == 'ANONYMOUS' and self.sasl_enable_anonymous:
+		if mechanism == 'ANONYMOUS' and self.supports_sasl_anonymous:
 			if credentials is not None:
 				credentials = credentials.decode()
 			return self.do_bind_sasl_anonymous(trace_info=credentials), None
-		if mechanism == 'PLAIN' and self.sasl_enable_plain:
+		if mechanism == 'PLAIN' and self.supports_sasl_plain:
 			if credentials is None:
 				raise LDAPProtocolError('Unsupported protocol version')
 			authzid, authcid, password = credentials.split(b'\0', 2)
 			return self.do_bind_sasl_plain(authcid.decode(), password.decode(), authzid.decode() or None), None
-		if mechanism == 'EXTERNAL' and self.sasl_enable_external:
+		if mechanism == 'EXTERNAL' and self.supports_sasl_external:
 			if credentials is not None:
 				credentials = credentials.decode()
 			return self.do_bind_sasl_external(authzid=credentials), None
-		if mechanism == 'DIGEST-MD5' and self.sasl_enable_digest_md5:
-			return self.do_bind_sasl_digest_md5(credentials)
+		if mechanism == 'DIGEST-MD5' and self.supports_sasl_digest_md5:
+			return sasl.digest_md5._handle_ldap_bind(self.do_bind_sasl_digest_md5, initial_response=credentials)
 		raise LDAPAuthMethodNotSupported()
 
+	supports_sasl_anonymous = False
+
 	def do_bind_sasl_anonymous(self, trace_info=None):
 		'''Do LDAP BIND with SASL "ANONYMOUS" mechanism (RFC 4505)
 
@@ -322,6 +300,8 @@ class LDAPRequestHandler(BaseRequestHandler):
 		`LDAPAuthMethodNotSupported` exception.'''
 		raise LDAPAuthMethodNotSupported()
 
+	supports_sasl_plain = False
+
 	def do_bind_sasl_plain(self, identity, password, authzid=None):
 		'''Do LDAP BIND with SASL "PLAIN" mechanism (RFC 4616)
 
@@ -341,6 +321,8 @@ class LDAPRequestHandler(BaseRequestHandler):
 		`LDAPAuthMethodNotSupported` exception.'''
 		raise LDAPAuthMethodNotSupported()
 
+	supports_sasl_external = False
+
 	def do_bind_sasl_external(self, authzid=None):
 		'''Do LDAP BIND with SASL "EXTERNAL" mechanism (RFC 4422 and 4513)
 
@@ -359,58 +341,42 @@ class LDAPRequestHandler(BaseRequestHandler):
 		`LDAPAuthMethodNotSupported` exception.'''
 		raise LDAPAuthMethodNotSupported()
 
-	def do_bind_sasl_digest_md5_password(self, username, realm, authzid=None):
-		# should return (bind_obj, password string) for username
-		raise LDAPAuthMethodNotSupported()
+	supports_sasl_digest_md5 = False
 
-	def do_bind_sasl_digest_md5_pwdigest(self, username, realm, charset='utf-8', authzid=None):
-		# charset is either 'utf-8' or 'latin_1', it should only affect username and password
-		bind_obj, password = self.do_bind_sasl_digest_md5_password(username, realm, authzid=authzid)
-		ctx = hashlib.md5()
-		ctx.update(username.encode(charset) + b':' + realm + b':' + password.encode(charset))
-		return bind_obj, ctx.digest()
-
-	def do_bind_sasl_digest_md5(self, credentials=None):
-		# Defined by RFC2831 and RFC2829, obsoleted by RFC6331
-		nonce = secrets.token_urlsafe(1024).encode()
-		challenge = b'nonce="%s",charset="utf-8",algorithm="md5-sess"'%(nonce)
-		resp = yield challenge
-		args = {key: value for key, value in sasl.parse_kwargs(resp)}
-		if args[b'nonce'] != nonce:
-			raise LDAPProtocolError()
-		try:
-			charset = 'utf-8' if args.get(b'charset', b'utf-8') == b'utf-8' else 'latin_1'
-			username = args[b'username'].decode(charset)
-			realm = args.get(b'realm', b'')
-			cnonce = args[b'cnonce']
-			nc = args.get(b'nc', b'00000001')
-			qop = args.get(b'qop', b'auth')
-			digest_uri = args[b'digest-uri']
-			response = args[b'response']
-		except KeyError:
-			raise LDAPProtocolError()
-		except UnicodeError:
-			raise LDAPProtocolError()
-		bind_obj, pwdigest = self.do_bind_sasl_digest_md5_pwdigest(username, realm, charset)
-		def md5digest(data):
-			ctx = hashlib.md5()
-			ctx.update(data)
-			return ctx.hexdigest().lower().encode()
-		a1 = pwdigest + b':' + nonce + b':' + cnonce
-		a2 = b'AUTHENTICATE:' + digest_uri
-		key = md5digest(a1)
-		data = nonce + b':' + nc + b':' + cnonce + b':' + qop + b':' + md5digest(a2)
-		expected_response = md5digest(key + b':' + data)
-		if expected_response != response:
-			raise LDAPInvalidCredentials()
-		# We don't support subsequent authentication so according to RFC2829 the
-		# serverSaslCreds field in our response should be absent and we should
-		# return (bind_obj, None). But this seems to confuse some clients (e.g.
-		# openldap's ldapsearch) so we return serverSaslCreds with rspauth instead.
-		a2 = b':' + digest_uri
-		data = nonce + b':' + nc + b':' + cnonce + b':' + qop + b':' + md5digest(a2)
-		response = b'rspauth=%s'%md5digest(key + b':' + data)
-		return bind_obj, response
+	def do_bind_sasl_digest_md5(self, username, realm, host, serv_name=None, authzid=None, charset='utf-8'):
+		'''Do LDAP BIND with SASL "DIGEST-MD5" mechanism (RFC 2829)
+
+		:param username: Name of the user account
+		:type username: bytes
+		:param realm: Realm containing the user account
+		:type realm: bytes
+		:param host: DNS host name or IP address for the requested service, should be verified
+		:type host: str
+		:param serv_name: Name of the service if it is replicated, see RFC 2829 for details
+		:type serv_name: str, optional
+		:param authzid: Authorization identity
+		:type authzid: str, optional
+		:param charset: Charset ("utf-8" or "latin_1") that username and realm are encoded with
+		:type charset: str
+
+		:returns: Pairs of bind objects and credential digests that are acceptable for username and realm
+		:rtype: [(obj, bytes), ...]
+
+		WARNING: "DIGEST-MD5" is insecure and was obsoleted by RFC 6331. It is only
+		implemented for completeness and widespread client support.
+
+		To implement this mechanism, passwords must be stored either unencrypted or
+		as mechanism-specific credential digests. Use `sasl.md5_digest.credential_digest` to
+		generate the digest from username, realm and password.
+
+		Note that username and realm are passed as bytes instead of strings to
+		reduce the risk of encoding/normalization-related problems. Decode them with
+		`username.decode(charset)`. Make sure to pass the values as bytes to
+		`sasl.md5_digest.credential_digest` if possible.
+
+		Called by `do_bind_sasl()`. The default implementation raises an
+		`LDAPInvalidCredentials` exception.'''
+		raise LDAPInvalidCredentials()
 
 	def handle_search(self, req):
 		check_controls(req.controls)
@@ -445,20 +411,19 @@ class LDAPRequestHandler(BaseRequestHandler):
 	def handle_extended(self, req):
 		check_controls(req.controls)
 		op = req.protocolOp
-		if op.requestName == '1.3.6.1.4.1.1466.20037':
+		if op.requestName == EXT_STARTTLS_OID and self.supports_starttls:
 			# StartTLS (RFC 4511)
-			sent_response = False
-			for _ in self.do_starttls():
-				if not sent_response:
-					self.send_msg(LDAPMessage(req.messageID, ExtendedResponse(LDAPResultCode.success, responseName='1.3.6.1.4.1.1466.20037')))
-				sent_response = True
-			if not sent_response:
-				raise LDAPProtocolError()
-		elif op.requestName == '1.3.6.1.4.1.4203.1.11.3':
+			self.send_msg(LDAPMessage(req.messageID, ExtendedResponse(LDAPResultCode.success, responseName=EXT_STARTTLS_OID)))
+			try:
+				self.do_starttls()
+			except Exception as e:
+				traceback.print_exc()
+				self.keep_running = False
+		elif op.requestName == EXT_WHOAMI_OID and self.supports_whoami:
 			# "Who am I?" Operation (RFC 4532)
 			identity = (self.do_whoami() or '').encode()
 			self.send_msg(LDAPMessage(req.messageID, ExtendedResponse(LDAPResultCode.success, responseValue=identity)))
-		elif op.requestName == '1.3.6.1.4.1.4203.1.11.1':
+		elif op.requestName == EXT_PASSWORD_MODIFY_OID and self.supports_password_modify:
 			# Password Modify Extended Operation (RFC 3062)
 			newpw = None
 			if op.requestValue is None:
@@ -474,27 +439,50 @@ class LDAPRequestHandler(BaseRequestHandler):
 		else:
 			raise LDAPProtocolError()
 
+	@property
+	def supports_starttls(self):
+		return self.ssl_context is not None and not isinstance(self.request, ssl.SSLSocket)
+
 	def do_starttls(self):
-		if self.ssl_context is None or isinstance(self.request, ssl.SSLSocket):
-			raise LDAPProtocolError()
-		yield None
-		try:
-			self.request = self.ssl_context.wrap_socket(self.request, server_side=True)
-		except Exception as e:
-			traceback.print_exc()
-			self.keep_running = False
+		'''Do StartTLS extended operation (RFC 4511)
+
+		Called by `handle_extended()` if `supports_starttls` is True. The default
+		implementation uses `ssl_context`.
+
+		Note that the (success) response to the request is sent before this method
+		is called. If a call to this method fails, the LDAP connection is
+		immediately terminated.'''
+		self.request = self.ssl_context.wrap_socket(self.request, server_side=True)
+
+	supports_whoami = False
 
 	def do_whoami(self):
-		'''Do "Who am I" operation (RFC 4532)
+		'''Do "Who am I?" extended operation (RFC 4532)
 
 		:returns: Current authorization identity (authzid) or empty string for anonymous sessions
 		:rtype: str
 
-		The default implementation always returns an empty string.'''
+		Called by `handle_extended()` if `supports_whoami` is True. The default
+		implementation always returns an empty string.'''
 		return ''
 
-	def do_passwd(self, user=None, oldpasswd=None, newpasswd=None):
-		raise LDAPUnwillingToPerform('Password change is not supported')
+	supports_password_modify = False
+
+	def do_password_modify(self, user=None, old_password=None, new_password=None):
+		'''Do password modify extended operation (RFC 3062)
+		
+		:param user: User the request relates to, may or may not be a
+		             distinguished name. If absent, the request relates to the
+		             user currently associated with the LDAP connection
+		:type user: str, optional
+		:param old_password: Current password of user
+		:type old_password: bytes, optional
+		:param new_password: Desired password for user
+		:type new_password: bytes, optional
+		
+		Called by `handle_extended()` if `supports_password_modify` is True. The
+		default implementation always raises an LDAPUnwillingToPerform error.'''
+		raise LDAPUnwillingToPerform()
 
 	def handle_modify(self, req):
 		check_controls(req.controls)
-- 
GitLab