diff --git a/ldap.py b/ldap.py
index eadc75e10a4e9a9b812466089ac6766563c2152d..a72fefacb62048b5dcc2afafa3e201694bade145 100644
--- a/ldap.py
+++ b/ldap.py
@@ -442,6 +442,13 @@ class SimpleAuthentication(Wrapper, AuthenticationChoice):
 			return '<%s(EMPTY PASSWORD)>'%(type(self).__name__)
 		return '<%s(PASSWORD HIDDEN)>'%(type(self).__name__)
 
+class SaslCredentials(Sequence, AuthenticationChoice):
+	ber_tag = (2, True, 3)
+	sequence_fields = [
+		(LDAPString, 'mechanism', None, False),
+		(OctetString, 'credentials', None, True),
+	]
+
 class AttributeValueSet(Set):
 	set_type = OctetString
 
@@ -475,8 +482,14 @@ class BindRequest(Sequence, ProtocolOp):
 		(AuthenticationChoice, 'authentication', lambda: SimpleAuthentication(), False)
 	]
 
-class BindResponse(LDAPResult, ProtocolOp):
+class BindResponse(Sequence, ProtocolOp):
 	ber_tag = (1, True, 1)
+	sequence_fields = [
+		(wrapenum(LDAPResultCode), 'resultCode', None, False),
+		(LDAPString, 'matchedDN', '', False),
+		(LDAPString, 'diagnosticMessage', '', False),
+		(retag(OctetString, (2, False, 7)), 'serverSaslCreds', None, True)
+	]
 
 class UnbindRequest(Sequence, ProtocolOp):
 	ber_tag = (1, False, 2)
@@ -621,20 +634,36 @@ class LDAPMessage(Sequence):
 		(Controls, 'controls', None, True)
 	]
 
-class ShallowProtocolOp:
+class ShallowLDAPMessage(BERType):
+	ber_tag = (0, True, 16)
+
+	def __init__(self, messageID=None, protocolOpType=None, data=None):
+		self.messageID = messageID
+		self.protocolOpType = protocolOpType
+		self.data = data
+
+	def decode(self):
+		return LDAPMessage.from_ber(self.data)
+
 	@classmethod
 	def from_ber(cls, data):
-		obj, rest = decode_ber(data)
+		seq, rest = decode_ber(data)
+		data = data[:len(data)-len(rest)]
+		if seq.tag != cls.ber_tag:
+			raise ValueError()
+		content = seq.content
+		messageID, content = Integer.from_ber(content)
+		op, content = decode_ber(content)
 		for subcls in ProtocolOp.__subclasses__():
-			if subcls.ber_tag == obj.tag:
-				return subcls, rest
-		return None, rest
+			if subcls.ber_tag == op.tag:
+				return cls(messageID, subcls, data), rest
+		return cls(messageID, None, data), rest
 
-class ShallowLDAPMessage(Sequence):
-	sequence_fields = [
-		(Integer, 'messageID', None, False),
-		(ShallowProtocolOp, 'protocolOp', None, False)
-	]
+	@classmethod
+	def to_ber(cls, obj):
+		if not isinstance(obj, cls):
+			raise TypeError()
+		return obj.data
 
 # Extended Operation Values
 
diff --git a/sasl.py b/sasl.py
new file mode 100644
index 0000000000000000000000000000000000000000..48febc7c6300d7c148bcfc6e04a447dcddbc2a0a
--- /dev/null
+++ b/sasl.py
@@ -0,0 +1,48 @@
+SEP = [b'(', b')', b'<', b'>', b'@', b',', b';', b':', b'\\', b'\'', b'/', b'[', b']', b'?', b'=', b'{', b'}', b' ', b'\t']
+CTL = [bytes([c]) for c in range(0, 31)] + [b'127']
+
+def parse_token(s):
+	for index in range(len(s)):
+		c = bytes([s[index]])
+		if c in SEP + CTL:
+			return bytes(s[:index]), bytes(s[index:])
+	return s, b''
+
+def parse_qstr(s):
+	if s[0] != b'"'[0]:
+		raise ValueError()
+	res = b''
+	escaped = False
+	for index in range(1, len(s)):
+		c = bytes([s[index]])
+		if escaped:
+			res += c
+			escaped = False
+		elif c == b'\\':
+			escaped = True
+		elif c == b'"':
+			return res, bytes(s[index+1:])
+		else:
+			res += c
+	raise ValueError()
+
+def parse_token_qstr(s):
+	if s[0] == b'"'[0]:
+		return parse_qstr(s)
+	return parse_token(s)
+
+def parse_kwargs(s):
+	res = []
+	while True:
+		key, s = parse_token(s)
+		if s[0] != b'='[0]:
+			raise ValueError()
+		value, s = parse_token_qstr(bytes(s[1:]))
+		res.append((key, value))
+		if not s:
+			return res
+		if s[0] != b','[0]:
+			raise ValueError()
+		s = bytes(s[1:])
+	return res
+
diff --git a/server.py b/server.py
index 00a27628efb521c8da7ef739bd83b4d2aaa9f001..2a65f509b19bc783feccd3e0e4aa041df4656b13 100644
--- a/server.py
+++ b/server.py
@@ -1,7 +1,10 @@
 import traceback
+import hashlib
+import secrets
 from socketserver import BaseRequestHandler
 
-from ldap import LDAPMessage, ShallowLDAPMessage, BindRequest, BindResponse, SearchRequest, SearchResultEntry, PartialAttribute, SearchResultDone, UnbindRequest, LDAPResultCode, IncompleteBERError, SimpleAuthentication, ModifyRequest, ModifyResponse, AddRequest, AddResponse, DelRequest, DelResponse, ModifyDNRequest, ModifyDNResponse, CompareRequest, CompareResponse, AbandonRequest, ExtendedRequest, ExtendedResponse, PasswdModifyRequestValue, PasswdModifyResponseValue
+import sasl
+from ldap import LDAPMessage, ShallowLDAPMessage, BindRequest, BindResponse, SearchRequest, SearchResultEntry, PartialAttribute, SearchResultDone, UnbindRequest, LDAPResultCode, IncompleteBERError, SimpleAuthentication, ModifyRequest, ModifyResponse, AddRequest, AddResponse, DelRequest, DelResponse, ModifyDNRequest, ModifyDNResponse, CompareRequest, CompareResponse, AbandonRequest, ExtendedRequest, ExtendedResponse, PasswdModifyRequestValue, PasswdModifyResponseValue, SaslCredentials
 
 class LDAPError(Exception):
 	def __init__(self, code=LDAPResultCode.other, message=''):
@@ -32,20 +35,129 @@ class LDAPAuthMethodNotSupported(LDAPError):
 	def __init__(self, message=''):
 		super().__init__(LDAPResultCode.authMethodNotSupported, message)
 
+class LDAPConfidentialityRequired(LDAPError):
+	def __init__(self, message=''):
+		super().__init__(LDAPResultCode.confidentialityRequired, message)
+
+def decode_msg(shallowmsg):
+	try:
+		return shallowmsg.decode()[0]
+	except:
+		traceback.print_exc()
+		raise LDAPProtocolError()
+
 class LDAPRequestHandler(BaseRequestHandler):
+	ssl_context = None
+
+	def setup(self):
+		super().setup()
+		self.keep_running = True
+		self.bind_object = None
+		self.bind_sasl_state = None
+
 	def handle_bind(self, req):
 		op = req.protocolOp
 		if op.version != 3:
 			raise LDAPProtocolError('Unsupported protocol version')
-		if isinstance(op.authentication, SimpleAuthentication):
-			self.bind_object = self.do_bind(op.name, op.authentication.password)
+		auth = op.authentication
+		# Resume ongoing SASL dialog
+		if self.bind_sasl_state and isinstance(auth, SaslCredentials) \
+				and auth.mechanism == self.bind_sasl_state[0]:
+			iterator = self.bind_sasl_state[1]
+			resp_code = LDAPResultCode.saslBindInProgress
+			try:
+				resp = iterator.send(auth.credentials)
+			except StopIteration as e:
+				resp_code = LDAPResultCode.success
+				self.bind_sasl_state = None
+				self.bind_object, resp = e.value
+			self.send_msg(LDAPMessage(req.messageID, BindResponse(resp_code, serverSaslCreds=resp)))
+			return
+		# If auth type or SASL method changed, abort SASL dialog
+		self.bind_sasl_state = None
+		if isinstance(auth, SimpleAuthentication):
+			self.bind_object = self.do_bind(op.name, auth.password)
+			self.send_msg(LDAPMessage(req.messageID, BindResponse(LDAPResultCode.success)))
+		elif isinstance(auth, SaslCredentials):
+			ret = self.do_bind_sasl(auth.mechanism, auth.credentials)
+			if isinstance(ret, tuple):
+				self.bind_object, resp = ret
+				self.send_msg(LDAPMessage(req.messageID, BindResponse(LDAPResultCode.success, serverSaslCreds=resp)))
+				return
+			iterator = iter(ret)
+			resp_code = LDAPResultCode.saslBindInProgress
+			try:
+				resp = next(iterator)
+			except StopIteration as e:
+				resp_code = LDAPResultCode.success
+				self.bind_sasl_state = None
+				self.bind_object, resp = e.value
+			self.send_msg(LDAPMessage(req.messageID, BindResponse(resp_code, serverSaslCreds=resp)))
+			self.bind_sasl_state = (auth.mechanism, iterator)
 		else:
 			raise LDAPAuthMethodNotSupported()
-		self.send_msg(LDAPMessage(req.messageID, BindResponse(LDAPResultCode.success)))
 
 	def do_bind(self, user, password):
 		raise LDAPInvalidCredentials()
 
+	def do_bind_sasl(self, mechanism, credentials):
+		if mechanism == 'DIGEST-MD5':
+			return self.do_bind_sasl_digest_md5(mechanism, credentials)
+		raise LDAPAuthMethodNotSupported()
+
+	def do_bind_sasl_digest_md5_password(self, username, realm):
+		return None, 'foo'
+		# should return (bind_obj, password string) for username
+		raise LDAPAuthMethodNotSupported()
+
+	def do_bind_sasl_digest_md5_pwdigest(self, username, realm, charset):
+		# charset is either 'utf-8' or 'latin_1', it should only affect username and password
+		bind_obj, password = self.do_bind_sasl_digest_md5_password(username, realm)
+		ctx = hashlib.md5()
+		ctx.update(username.encode(charset) + b':' + realm + b':' + password.encode(charset))
+		return bind_obj, ctx.digest()
+
+	def do_bind_sasl_digest_md5(self, mechanism, credentials):
+		nonce = secrets.token_urlsafe(1024).encode()
+		challenge = b'nonce="%s",charset="utf-8",algorithm="md5-sess"'%(nonce)
+		resp = yield challenge
+		args = {key: value for key, value in sasl.parse_kwargs(resp)}
+		if args[b'nonce'] != nonce:
+			raise LDAPProtocolError()
+		try:
+			charset = 'utf-8' if args.get(b'charset', b'utf-8') == b'utf-8' else 'latin_1'
+			username = args[b'username'].decode(charset)
+			realm = args.get(b'realm', b'')
+			cnonce = args[b'cnonce']
+			nc = args.get(b'nc', b'00000001')
+			qop = args.get(b'qop', b'auth')
+			digest_uri = args[b'digest-uri']
+			response = args[b'response']
+		except KeyError:
+			raise LDAPProtocolError()
+		except UnicodeError:
+			raise LDAPProtocolError()
+		bind_obj, pwdigest = self.do_bind_sasl_digest_md5_pwdigest(username, realm, charset)
+		def md5digest(data):
+			ctx = hashlib.md5()
+			ctx.update(data)
+			return ctx.hexdigest().lower().encode()
+		a1 = pwdigest + b':' + nonce + b':' + cnonce
+		a2 = b'AUTHENTICATE:' + digest_uri
+		key = md5digest(a1)
+		data = nonce + b':' + nc + b':' + cnonce + b':' + qop + b':' + md5digest(a2)
+		expected_response = md5digest(key + b':' + data)
+		if expected_response != response:
+			raise LDAPInvalidCredentials()
+		# We don't support subsequent authentication so according to RFC 2829 the
+		# serverSaslCreds field in our response should be absent and we should
+		# return (bind_obj, None). But this seems to confuse some clients (e.g.
+		# openldap's ldapsearch) so we return serverSaslCreds with rspauth instead.
+		a2 = b':' + digest_uri
+		data = nonce + b':' + nc + b':' + cnonce + b':' + qop + b':' + md5digest(a2)
+		response = b'rspauth=%s'%md5digest(key + b':' + data)
+		return bind_obj, response
+
 	def handle_search(self, req):
 		search = req.protocolOp
 		for dn, attributes in self.do_search(search.baseObject, search.scope, search.filter):
@@ -54,7 +166,7 @@ class LDAPRequestHandler(BaseRequestHandler):
 		self.send_msg(LDAPMessage(req.messageID, SearchResultDone(LDAPResultCode.success)))
 
 	def do_search(self, baseobj, scope, filter):
-		yield from []
+		return []
 
 	def handle_unbind(self, req):
 		self.keep_running = False
@@ -63,7 +175,13 @@ class LDAPRequestHandler(BaseRequestHandler):
 		op = req.protocolOp
 		if op.requestName == '1.3.6.1.4.1.1466.20037':
 			# StartTLS (RFC 4511)
-			raise LDAPProtocolError()
+			sent_response = False
+			for _ in self.do_starttls():
+				if not sent_response:
+					self.send_msg(LDAPMessage(req.messageID, ExtendedResponse(LDAPResultCode.success, responseName='1.3.6.1.4.1.1466.20037')))
+				sent_response = True
+			if not sent_response:
+				raise LDAPProtocolError()
 		elif op.requestName == '1.3.6.1.4.1.4203.1.11.1':
 			# Password Modify Extended Operation (RFC 3062)
 			newpw = None
@@ -80,11 +198,39 @@ class LDAPRequestHandler(BaseRequestHandler):
 		else:
 			raise LDAPProtocolError()
 
+	def do_starttls(self):
+		if self.ssl_context is None:
+			raise LDAPProtocolError()
+		yield None
+		try:
+			self.request = self.ssl_context.wrap_socket(self.request, server_side=True)
+		except Exception as e:
+			traceback.print_exc()
+			self.keep_running = False
+
 	def do_passwd(self, user=None, oldpasswd=None, newpasswd=None):
 		raise LDAPUnwillingToPerform('Password change is not supported')
 
-	def handle_message(self, data):
-		handlers = {
+	def handle_modify(self, req):
+		raise LDAPInsufficientAccessRights()
+
+	def handle_add(self, req):
+		raise LDAPInsufficientAccessRights()
+
+	def handle_delete(self, req):
+		raise LDAPInsufficientAccessRights()
+
+	def handle_modifydn(self, req):
+		raise LDAPInsufficientAccessRights()
+
+	def handle_compare(self, req):
+		raise LDAPInsufficientAccessRights()
+
+	def handle_abandon(self, req):
+		pass
+
+	def handle_message(self, shallowmsg):
+		msgtypes = {
 			BindRequest: (self.handle_bind, BindResponse),
 			UnbindRequest: (self.handle_unbind, None),
 			SearchRequest: (self.handle_search, SearchResultDone),
@@ -96,50 +242,49 @@ class LDAPRequestHandler(BaseRequestHandler):
 			AbandonRequest: (None, None),
 			ExtendedRequest: (self.handle_extended, ExtendedResponse),
 		}
-		shallowmsg, rest = ShallowLDAPMessage.from_ber(data)
-		if shallowmsg.protocolOp is None:
-			print('Ignoring unknown message')
-			return rest
-		func, errfunc = handlers[shallowmsg.protocolOp]
-		msg = None
-		try:
-			msg, _ = LDAPMessage.from_ber(data)
-		except Exception as e:
-			if errfunc:
-				self.send_msg(LDAPMessage(shallowmsg.messageID, errfunc(LDAPResultCode.protocolError)))
-			traceback.print_exc()
-			return rest
-		print('received', msg)
+
+		responses = {
+			BindRequest: BindResponse,
+			SearchRequest: SearchResultDone,
+			ModifyRequest: ModifyResponse,
+			AddRequest: AddResponse,
+			DelRequest: DelResponse,
+			ModifyDNRequest: ModifyDNResponse,
+			CompareRequest: CompareResponse,
+			ExtendedRequest: ExtendedResponse,
+		}
+		handler, response = msgtypes.get(shallowmsg.protocolOpType, (None, None))
 		try:
-			if func:
-				func(msg)
-			elif errfunc:
-				self.send_msg(LDAPMessage(shallowmsg.messageID, errfunc(LDAPResultCode.insufficientAccessRights)))
+			if handler is None:
+				raise LDAPProtocolError()
+			try:
+				msg = decode_msg(shallowmsg)
+			except ValueError:
+				raise LDAPProtocolError()
+			print('recved', msg)
+			handler(msg)
 		except LDAPError as e:
-			if errfunc:
-				self.send_msg(LDAPMessage(shallowmsg.messageID, errfunc(e.code, diagnosticMessage=e.message)))
-			return rest
+			if response is not None:
+				self.send_msg(LDAPMessage(shallowmsg.messageID, response(e.code, diagnosticMessage=e.message)))
 		except Exception as e:
-			if errfunc:
-				self.send_msg(LDAPMessage(shallowmsg.messageID, errfunc(LDAPResultCode.other)))
+			if response is not None:
+				self.send_msg(LDAPMessage(shallowmsg.messageID, response(LDAPResultCode.other)))
 			traceback.print_exc()
-			return rest
-		return rest
 
 	def send_msg(self, msg):
 		print('sending', msg)
 		self.request.sendall(LDAPMessage.to_ber(msg))
 
 	def handle(self):
-		self.keep_running = True
-		self.bind_object = None
-		data = b''
+		buf = b''
 		while self.keep_running:
 			try:
-				data = self.handle_message(data)
+				shallowmsg, buf = ShallowLDAPMessage.from_ber(buf)
+				self.handle_message(shallowmsg)
 			except IncompleteBERError:
 				chunk = self.request.recv(5)
 				if not chunk:
-					return
-				data += chunk
+					self.keep_running = False
+					return None
+				buf += chunk