Select Git revision
-
Julian Rother authored
test.py 9.12 KiB
import socket
from ldap3.protocol.rfc4511 import LDAPMessage, ResultCode, ProtocolOp, BindResponse, SearchResultDone, ModifyResponse, AddResponse, DelResponse, ModifyDNResponse, CompareResponse, ExtendedResponse, SearchResultEntry, PartialAttribute
from ldap3.utils.asn1 import decoder, encode
from ldap_decode import decode_message_fast, compute_ldap_message_size
class LDAPObject:
    """Base class for models that can be queried with LDAP search filters.

    The ``_ldap_filter*`` methods walk a decoded filter tree and combine the
    results of the overridable ``ldap_filter_*`` hooks.  Every filter
    evaluation yields one of:

    * ``True``  -- the filter matches every object,
    * ``False`` -- the filter matches no object,
    * ``None``  -- the filter is undefined / unsupported,
    * any other object -- a model-specific result that is later handed to
      ``ldap_search``.

    The hooks of this base class all return ``None`` and ``ldap_search``
    returns no objects; subclasses override them.
    """

    @classmethod
    def _ldap_filter(cls, operator, *args):
        """Evaluate one filter node.

        ``operator`` selects the node type (``'and'``, ``'or'``, ``'not'``,
        ``'equal'`` or ``'present'``) and ``args`` are that node's operands.
        Any other operator evaluates to ``None`` (undefined).
        """
        if operator == 'and':
            return cls._ldap_filter_and(*args)
        if operator == 'or':
            return cls._ldap_filter_or(*args)
        if operator == 'not':
            return cls._ldap_filter_not(*args)
        if operator == 'equal':
            return cls._ldap_filter_equal(*args)
        if operator == 'present':
            return cls._ldap_filter_present(*args)
        return None

    @classmethod
    def _ldap_filter_and(cls, *components):
        """Evaluate an AND node; each component is an ``(operator, *args)`` sequence.

        Always-true components are dropped, the first always-false or
        undefined component decides the result immediately, and the remaining
        results are combined with ``ldap_filter_and``.
        """
        results = []
        for component in components:
            subres = cls._ldap_filter(*component)
            if subres is True:
                continue
            if subres is False:
                return False
            if subres is None:
                return None
            results.append(subres)
        if not results:
            # Nothing left to combine: every component was always-true, or
            # there were no components at all (the "(&)" absolute true
            # filter).  Calling the hook with zero results would force every
            # implementation to special-case an empty argument list.
            return True
        return cls.ldap_filter_and(*results)

    @classmethod
    def _ldap_filter_or(cls, *components):
        """Evaluate an OR node; each component is an ``(operator, *args)`` sequence.

        Always-false components are dropped, the first always-true or
        undefined component decides the result immediately, and the remaining
        results are combined with ``ldap_filter_or``.
        """
        results = []
        for component in components:
            subres = cls._ldap_filter(*component)
            if subres is True:
                return True
            if subres is False:
                continue
            if subres is None:
                return None
            results.append(subres)
        if not results:
            # Every component was always-false, or there were none (the
            # "(|)" absolute false filter).
            return False
        return cls.ldap_filter_or(*results)

    @classmethod
    def _ldap_filter_not(cls, subfilter):
        """Evaluate a NOT node wrapping the ``(operator, *args)`` sequence ``subfilter``."""
        subres = cls._ldap_filter(*subfilter)
        if subres is True:
            return False
        if subres is False:
            return True
        if subres is None:
            return None
        return cls.ldap_filter_not(subres)

    @classmethod
    def _ldap_filter_present(cls, name):
        """Evaluate a presence test; ``name`` is decoded and lower-cased first."""
        return cls.ldap_filter_present(name.decode().lower())

    @classmethod
    def _ldap_filter_equal(cls, name, value):
        """Evaluate an equality test.

        ``name`` is decoded and lower-cased, ``value`` is converted to
        ``bytes`` before both are passed to ``ldap_filter_equal``.
        """
        return cls.ldap_filter_equal(name.decode().lower(), bytes(value))

    @classmethod
    def ldap_filter_present(cls, name):
        """Hook: result of "attribute ``name`` is present"."""
        # return True if always true, False if always false, None if undefined, custom object else
        return None

    @classmethod
    def ldap_filter_equal(cls, name, value):
        """Hook: result of "attribute ``name`` equals ``value``"."""
        return None

    @classmethod
    def ldap_filter_and(cls, *results):
        """Hook: combine one or more custom results with AND."""
        return None

    @classmethod
    def ldap_filter_or(cls, *results):
        """Hook: combine one or more custom results with OR."""
        return None

    @classmethod
    def ldap_filter_not(cls, result):
        """Hook: negate a custom result."""
        return None

    @classmethod
    def ldap_search(cls, dn_base, filter_res):
        """Hook: return the objects below ``dn_base`` selected by ``filter_res``.

        ``filter_res`` is whatever ``_ldap_filter`` produced.  The base class
        has no objects and returns an empty list.
        """
        return []
def split_dn(dn):
    """Split a distinguished name string into its comma-separated parts.

    An empty or ``None`` ``dn`` yields an empty list.  A comma preceded by a
    backslash escape (e.g. ``cn=Doe\\, John``) belongs to the value and does
    not start a new part; escape sequences are kept verbatim in the returned
    parts.  For DNs without backslashes the result equals ``dn.split(',')``.
    """
    if not dn:
        return []
    parts = []
    current = []
    escaped = False
    for char in dn:
        if escaped:
            # Character following a backslash is taken literally.
            current.append(char)
            escaped = False
        elif char == '\\':
            current.append(char)
            escaped = True
        elif char == ',':
            parts.append(''.join(current))
            current = []
        else:
            current.append(char)
    parts.append(''.join(current))
    return parts
class StaticLDAPObject(LDAPObject):
    """In-memory LDAP entry whose filter results are plain sets of objects.

    Every instance registers itself in the class-level indexes below when it
    is constructed.  The indexes are only ever mutated in place, so they are
    shared by all instances.
    """

    present_map = {}  # name -> set of objs
    value_map = {}  # (name, value) -> set of objs
    all_objects = set()  # set of all objs
    dn_map = {tuple(): set()}  # (dn part tuples) -> set of objs

    def __init__(self, dn, **kwargs):
        """Create an entry with DN ``dn`` and index it.

        Each keyword argument is one attribute; its value is a single value
        or a list of values.  ``None`` and empty strings/bytes are skipped,
        ints are converted to their decimal string, and strings are encoded
        to bytes.  Attributes left without any value are not stored.
        Attribute names are lower-cased.
        """
        self.ldap_dn = dn
        cls = type(self)
        # The empty base DN matches every object.
        cls.dn_map[tuple()].add(self)
        # Register the object under every suffix of its DN so that a lookup
        # by base DN finds the entry itself and everything below that base.
        dn_parts = split_dn(dn)
        for start in range(len(dn_parts)):
            cls.dn_map.setdefault(tuple(dn_parts[start:]), set()).add(self)
        cls.all_objects.add(self)
        self.ldap_attributes = {}
        for name, _values in kwargs.items():
            name = name.lower()
            if not isinstance(_values, list):
                _values = [_values]
            values = []
            for value in _values:
                if value is None or value == '' or value == b'':
                    continue
                if isinstance(value, int):
                    value = str(value)
                if isinstance(value, str):
                    value = value.encode()
                values.append(value)
            if not values:
                continue
            self.ldap_attributes[name] = values
            cls.present_map.setdefault(name, set()).add(self)
            for value in values:
                cls.value_map.setdefault((name, value), set()).add(self)

    @classmethod
    def ldap_filter_present(cls, name):
        """Return the set of objects that have attribute ``name``."""
        return cls.present_map.get(name, set())

    @classmethod
    def ldap_filter_equal(cls, name, value):
        """Return the set of objects whose attribute ``name`` contains ``value``."""
        return cls.value_map.get((name, value), set())

    @classmethod
    def ldap_filter_and(cls, *results):
        """Return the intersection of all ``results`` as a new set.

        With no results at all nothing restricts the match, so a copy of the
        set of all objects is returned instead of failing on ``results[0]``.
        """
        if not results:
            return set(cls.all_objects)
        first, *rest = results
        return first.intersection(*rest)

    @classmethod
    def ldap_filter_or(cls, *results):
        """Return the union of all ``results`` as a new set.

        With no results at all nothing can match, so an empty set is returned
        instead of failing on ``results[0]``.
        """
        if not results:
            return set()
        first, *rest = results
        return first.union(*rest)

    @classmethod
    def ldap_filter_not(cls, result):
        """Return every known object that is not in ``result``."""
        return cls.all_objects.difference(result)

    @classmethod
    def ldap_search(cls, dn_base, filter_res):
        """Return the objects registered under ``dn_base`` that match ``filter_res``.

        ``filter_res`` is ``None``/``False`` (nothing matches, an empty list
        is returned), ``True`` (everything under the base matches) or a set
        of objects to intersect with the objects under the base.
        """
        key = tuple(split_dn(dn_base))
        objs = cls.dn_map.get(key, set())
        if filter_res is None or filter_res is False:
            return []
        if filter_res is not True:
            objs = objs.intersection(filter_res)
        return objs
# Sample entries; constructing a StaticLDAPObject registers it in the
# class-level indexes, so the instances do not need to be kept in variables.
StaticLDAPObject(
    'uid=testuser,ou=users,dc=example,dc=com',
    objectClass=['top', 'person'],
    givenName='Test User',
    mail='testuser@example.com',
    uid='testuser',
)
StaticLDAPObject(
    'uid=testadmin,ou=users,dc=example,dc=com',
    objectClass=['top', 'person'],
    givenName='Test Admin',
    mail='testadmin@example.com',
    uid='testadmin',
)
class LDAPHandler:
    """Serve LDAP requests arriving on one client connection.

    ``conn`` is a connected socket-like object (``recv``, ``sendall`` and
    ``close`` are used).  ``models`` is a list of classes providing
    ``_ldap_filter`` and ``ldap_search``; they are consulted for every
    search request.
    """

    def __init__(self, conn, models=None):
        # A None default avoids sharing one mutable list between all
        # handlers that are created without an explicit ``models`` argument.
        self.models = [] if models is None else models
        self.conn = conn
        self.user = None  # None -> unbound, b'' -> anonymous bind, else -> self.user == dn
        self.buf = b''

    def _send_result(self, msgid, response_type, response, result_code, matched_dn=''):
        """Fill the common result fields of ``response`` and send it."""
        response['resultCode'] = ResultCode(result_code)
        response['matchedDN'] = matched_dn
        response['diagnosticMessage'] = ''
        self.send_response(msgid, response_type, response)

    def handle_bind(self, msgid, op):
        """Accept the bind unconditionally and remember the bind name as user."""
        self.user = op['name']
        self._send_result(msgid, 'bindResponse', BindResponse(), 'success',
                          matched_dn=op['name'])

    def handle_search(self, msgid, op):
        """Send one entry per matching object of every model, then the done message."""
        objs = []
        for model in self.models:
            filter_obj = model._ldap_filter(*op['filter'])
            objs += model.ldap_search(op['baseObject'].decode(), filter_obj)
        for obj in objs:
            resp = SearchResultEntry()
            resp['object'] = obj.ldap_dn
            for name, values in obj.ldap_attributes.items():
                attr = PartialAttribute()
                attr['type'] = name
                for value in values:
                    attr['vals'].append(value)
                resp['attributes'].append(attr)
            self.send_response(msgid, 'searchResEntry', resp)
        self._send_result(msgid, 'searchResDone', SearchResultDone(), 'success')

    def handle_modify(self, msgid, op):
        """Reject the modify request with result code ``other``."""
        self._send_result(msgid, 'modifyResponse', ModifyResponse(), 'other')

    def handle_add(self, msgid, op):
        """Reject the add request with result code ``other``."""
        self._send_result(msgid, 'addResponse', AddResponse(), 'other')

    def handle_del(self, msgid, op):
        """Reject the delete request with result code ``other``."""
        self._send_result(msgid, 'delResponse', DelResponse(), 'other')

    def handle_moddn(self, msgid, op):
        """Reject the modify-DN request with result code ``other``."""
        self._send_result(msgid, 'modDNResponse', ModifyDNResponse(), 'other')

    def handle_compare(self, msgid, op):
        """Reject the compare request with result code ``other``."""
        self._send_result(msgid, 'compareResponse', CompareResponse(), 'other')

    def handle_extended(self, msgid, op):
        """Reject the extended request with result code ``protocolError``."""
        self._send_result(msgid, 'extendedResp', ExtendedResponse(), 'protocolError')

    def recv_request(self):
        """Read and decode the next request from the connection.

        Bytes are accumulated in ``self.buf`` until
        ``compute_ldap_message_size`` reports a size (anything but ``-1``)
        that the buffer covers.  Returns the decoded message, or ``None``
        when the peer closes the connection before a full message arrived.
        Bytes beyond the decoded message stay buffered for the next call.
        """
        while True:
            size = compute_ldap_message_size(self.buf)
            if size != -1 and len(self.buf) >= size:
                break
            # NOTE(review): the 5-byte read size looks like a deliberate way
            # to exercise message reassembly in this test server -- confirm
            # before raising it.
            chunk = self.conn.recv(5)
            if not chunk:
                return None
            self.buf += chunk
        print('received ', self.buf)
        req = decode_message_fast(self.buf)
        print(req)
        self.buf = self.buf[size:]
        return req

    def send_response(self, message_id, response_type, response):
        """Wrap ``response`` in an ``LDAPMessage`` and send it encoded."""
        msg = LDAPMessage()
        msg['messageID'] = message_id
        msg['protocolOp'] = ProtocolOp().setComponentByName(response_type, response)
        print('sent', msg)
        self.conn.sendall(encode(msg))

    def run(self):
        """Handle requests until the peer disconnects or sends an unbind.

        The connection is closed when the loop ends, including when a
        handler raises.  An unknown ``protocolOp`` raises ``Exception``.
        """
        try:
            while True:
                msg = self.recv_request()
                if msg is None:
                    break
                if msg['protocolOp'] == 0:  # bindRequest
                    self.handle_bind(msg['messageID'], msg['payload'])
                elif msg['protocolOp'] == 2:  # unbindRequest
                    break
                elif msg['protocolOp'] == 3:  # searchRequest
                    self.handle_search(msg['messageID'], msg['payload'])
                elif msg['protocolOp'] == 6:  # modifyRequest
                    self.handle_modify(msg['messageID'], msg['payload'])
                elif msg['protocolOp'] == 8:  # addRequest
                    self.handle_add(msg['messageID'], msg['payload'])
                elif msg['protocolOp'] == 10:  # delRequest
                    self.handle_del(msg['messageID'], msg['payload'])
                elif msg['protocolOp'] == 12:  # modDNRequest
                    self.handle_moddn(msg['messageID'], msg['payload'])
                elif msg['protocolOp'] == 14:  # compareRequest
                    self.handle_compare(msg['messageID'], msg['payload'])
                elif msg['protocolOp'] == 16:  # abandonRequest
                    pass
                elif msg['protocolOp'] == 23:  # extendedReq
                    self.handle_extended(msg['messageID'], msg['payload'])
                else:
                    raise Exception('unsupported protocolOp: %r' % (msg['protocolOp'],))
        finally:
            # Close the connection on every exit path so a failing handler
            # does not leak the socket.
            self.conn.close()
# Accept clients on 127.0.0.1:1337 and serve them one at a time.  The ``with``
# block closes the listening socket when the loop is left by an exception
# (e.g. KeyboardInterrupt).
with socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0) as sock:
    # Allow an immediate restart while old connections linger in TIME_WAIT.
    sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
    sock.bind(('127.0.0.1', 1337))
    sock.listen(1)
    while True:
        conn, addr = sock.accept()
        print('accepted connection from', addr)
        try:
            LDAPHandler(conn, models=[StaticLDAPObject]).run()
        finally:
            # Closing an already closed socket is harmless; this guarantees
            # the client socket is released even if the handler raises.
            conn.close()
        print('connection closed')