Select Git revision
cccv-archive-key.gpg
Forked from
uffd / uffd
Source project has a limited visibility.
test.py 9.12 KiB
import socket
from ldap3.protocol.rfc4511 import LDAPMessage, ResultCode, ProtocolOp, BindResponse, SearchResultDone, ModifyResponse, AddResponse, DelResponse, ModifyDNResponse, CompareResponse, ExtendedResponse, SearchResultEntry, PartialAttribute
from ldap3.utils.asn1 import decoder, encode
from ldap_decode import decode_message_fast, compute_ldap_message_size
class LDAPObject:
    """Base class for models that the LDAP handler can search.

    A decoded search filter is a nested sequence ``(operator, *args)``; the
    private ``_ldap_filter*`` methods walk it and call the public
    ``ldap_filter_*`` hooks, which subclasses override. Every hook and
    combinator returns one of:

    * ``True``  -- the (sub)filter matches every object,
    * ``False`` -- it matches no object,
    * ``None``  -- it is undefined / cannot be evaluated,
    * anything else -- a model-specific result understood by the subclass's
      own hooks and by its ``ldap_search``.

    The base implementation evaluates every leaf to ``None`` and
    ``ldap_search`` returns no objects.
    """

    # Filter operator -> name of the classmethod that evaluates it.
    _FILTER_HANDLERS = {
        'and': '_ldap_filter_and',
        'or': '_ldap_filter_or',
        'not': '_ldap_filter_not',
        'equal': '_ldap_filter_equal',
        'present': '_ldap_filter_present',
    }

    @classmethod
    def _ldap_filter(cls, operator, *args):
        """Evaluate the filter node ``(operator, *args)``.

        Operators without a handler evaluate to ``None`` (undefined).
        """
        handler_name = cls._FILTER_HANDLERS.get(operator)
        if handler_name is None:
            return None
        return getattr(cls, handler_name)(*args)

    @classmethod
    def _ldap_filter_and(cls, *components):
        """Combine sub-filters with AND (RFC 4511, section 4.5.1.7).

        Returns ``False`` if any component is ``False``; otherwise ``None`` if
        any component is undefined; otherwise ``True`` if every component is
        ``True`` (this includes the empty AND, RFC 4526); otherwise the result
        of ``ldap_filter_and`` applied to the remaining custom results.
        """
        results = []
        undefined = False
        for component in components:
            subres = cls._ldap_filter(*component)
            if subres is True:
                continue  # TRUE is the neutral element of AND
            if subres is False:
                return False  # FALSE dominates, even over undefined components
            if subres is None:
                undefined = True
            else:
                results.append(subres)
        if undefined:
            return None
        if not results:
            return True
        return cls.ldap_filter_and(*results)

    @classmethod
    def _ldap_filter_or(cls, *components):
        """Combine sub-filters with OR (RFC 4511, section 4.5.1.7).

        Returns ``True`` if any component is ``True``; otherwise ``None`` if
        any component is undefined; otherwise ``False`` if every component is
        ``False`` (this includes the empty OR, RFC 4526); otherwise the result
        of ``ldap_filter_or`` applied to the remaining custom results.
        """
        results = []
        undefined = False
        for component in components:
            subres = cls._ldap_filter(*component)
            if subres is True:
                return True  # TRUE dominates, even over undefined components
            if subres is False:
                continue  # FALSE is the neutral element of OR
            if subres is None:
                undefined = True
            else:
                results.append(subres)
        if undefined:
            return None
        if not results:
            return False
        return cls.ldap_filter_or(*results)

    @classmethod
    def _ldap_filter_not(cls, subfilter):
        """Negate a sub-filter; undefined stays undefined."""
        subres = cls._ldap_filter(*subfilter)
        if subres is True:
            return False
        elif subres is False:
            return True
        elif subres is None:
            return None
        else:
            return cls.ldap_filter_not(subres)

    @classmethod
    def _ldap_filter_present(cls, name):
        """Evaluate ``(name=*)``; ``name`` is decoded and lower-cased first."""
        return cls.ldap_filter_present(name.decode().lower())

    @classmethod
    def _ldap_filter_equal(cls, name, value):
        """Evaluate ``(name=value)``; ``name`` is decoded/lower-cased, ``value`` converted to bytes."""
        return cls.ldap_filter_equal(name.decode().lower(), bytes(value))

    @classmethod
    def ldap_filter_present(cls, name):
        # return True if always true, False if always false, None if undefined, custom object else
        return None

    @classmethod
    def ldap_filter_equal(cls, name, value):
        """Hook for equality filters; same return convention as ``ldap_filter_present``."""
        return None

    @classmethod
    def ldap_filter_and(cls, *results):
        """Hook combining one or more custom (non-bool, non-None) results with AND."""
        return None

    @classmethod
    def ldap_filter_or(cls, *results):
        """Hook combining one or more custom (non-bool, non-None) results with OR."""
        return None

    @classmethod
    def ldap_filter_not(cls, result):
        """Hook negating a single custom (non-bool, non-None) result."""
        return None

    @classmethod
    def ldap_search(cls, dn_base, filter_res):
        """Return the objects under ``dn_base`` selected by ``filter_res``; none by default."""
        return []
def split_dn(dn):
    """Split a DN string into its RDN components.

    A comma preceded by a backslash escape (RFC 4514) stays part of its RDN.
    Escape sequences are kept verbatim rather than decoded, and whitespace is
    not trimmed, so DNs without escaped commas split exactly as with
    ``dn.split(',')``. An empty or ``None`` DN yields ``[]``.

    >>> split_dn('uid=a,ou=users')
    ['uid=a', 'ou=users']
    """
    if not dn:
        return []
    parts = []
    current = []
    escaped = False
    for char in dn:
        if escaped:
            # The character after a backslash is literal, even a comma.
            current.append(char)
            escaped = False
        elif char == '\\':
            current.append(char)
            escaped = True
        elif char == ',':
            parts.append(''.join(current))
            current = []
        else:
            current.append(char)
    parts.append(''.join(current))
    return parts
class StaticLDAPObject(LDAPObject):
    """LDAP model backed by an in-memory registry of its instances.

    Each instance adds itself to the class-level indexes below when created.
    The filter hooks return sets of matching instances.

    NOTE(review): the indexes are class attributes, so subclasses share them
    with StaticLDAPObject unless they define their own.
    """

    present_map = {}  # attribute name -> set of objs having that attribute
    value_map = {}  # (attribute name, bytes value) -> set of objs
    all_objects = set()  # set of all objs
    dn_map = {tuple(): set()}  # DN suffix (tuple of RDN strings) -> set of objs at or below it

    def __init__(self, dn, **kwargs):
        """Create and register an entry with DN ``dn`` and attributes ``kwargs``.

        Attribute names are lower-cased. Each value may be a single value or a
        list of values. ``int`` values become ``str``, ``str`` values are
        encoded to ``bytes``, and ``None``, ``''`` and ``b''`` are dropped.
        Attributes left without values are omitted.
        """
        self.ldap_dn = dn
        cls = type(self)
        parts = split_dn(dn)
        # Index the entry under every DN suffix, from the empty root key up to
        # its full DN, so ldap_search can find everything below a base DN.
        for start in range(len(parts), -1, -1):
            cls.dn_map.setdefault(tuple(parts[start:]), set()).add(self)
        cls.all_objects.add(self)
        self.ldap_attributes = {}
        for name, raw_values in kwargs.items():
            name = name.lower()
            values = self._normalize_values(raw_values)
            if not values:
                continue
            self.ldap_attributes[name] = values
            cls.present_map.setdefault(name, set()).add(self)
            for value in values:
                cls.value_map.setdefault((name, value), set()).add(self)

    @staticmethod
    def _normalize_values(raw_values):
        """Return ``raw_values`` (a value or list of values) as a list of non-empty values."""
        if not isinstance(raw_values, list):
            raw_values = [raw_values]
        values = []
        for value in raw_values:
            if value is None or value == '' or value == b'':
                continue
            if isinstance(value, int):
                value = str(value)
            if isinstance(value, str):
                value = value.encode()
            values.append(value)
        return values

    @classmethod
    def ldap_filter_present(cls, name):
        """Return the set of objects that have attribute ``name``."""
        return cls.present_map.get(name, set())

    @classmethod
    def ldap_filter_equal(cls, name, value):
        """Return the set of objects whose attribute ``name`` includes ``value``."""
        return cls.value_map.get((name, value), set())

    @classmethod
    def ldap_filter_and(cls, *results):
        """Intersect object sets into a new set; an empty AND is ``True``."""
        if not results:
            return True
        return results[0].intersection(*results[1:])

    @classmethod
    def ldap_filter_or(cls, *results):
        """Unite object sets into a new set; an empty OR is ``False``."""
        if not results:
            return False
        return results[0].union(*results[1:])

    @classmethod
    def ldap_filter_not(cls, result):
        """Return every registered object that is not in ``result``."""
        return cls.all_objects.difference(result)

    @classmethod
    def ldap_search(cls, dn_base, filter_res):
        """Return the objects at or below ``dn_base`` selected by ``filter_res``.

        ``None`` (undefined) and ``False`` select nothing and return ``[]``.
        ``True`` selects every object under the base. A set restricts the
        result to its members. The client's requested search scope is not an
        input here.
        """
        if filter_res is None or filter_res is False:
            return []
        objs = cls.dn_map.get(tuple(split_dn(dn_base)), set())
        if filter_res is True:
            # Return a copy so callers cannot mutate the shared DN index.
            return set(objs)
        return objs.intersection(filter_res)
# Sample directory entries served by the test server. Constructing a
# StaticLDAPObject registers it in that class's lookup indexes.
for _dn, _attributes in (
    (
        'uid=testuser,ou=users,dc=example,dc=com',
        dict(objectClass=['top', 'person'], givenName='Test User', mail='testuser@example.com', uid='testuser'),
    ),
    (
        'uid=testadmin,ou=users,dc=example,dc=com',
        dict(objectClass=['top', 'person'], givenName='Test Admin', mail='testadmin@example.com', uid='testadmin'),
    ),
):
    StaticLDAPObject(_dn, **_attributes)
del _dn, _attributes
class LDAPHandler:
    """Serve LDAP requests arriving on one connected socket.

    Binds always succeed. Searches are answered from ``models`` (classes
    implementing the LDAPObject filter/search hooks). Modify, add, delete,
    modDN and compare are rejected with resultCode ``other``, and extended
    requests with ``protocolError``.
    """

    # Maximum number of bytes requested per recv() call. recv_request() keeps
    # any bytes beyond the current message in self.buf for the next call.
    recv_chunk_size = 4096

    # protocolOp tag (RFC 4511 [APPLICATION n]) -> name of the handler method.
    # unbindRequest (2) and abandonRequest (16) are handled directly in run().
    _REQUEST_HANDLERS = {
        0: 'handle_bind',  # bindRequest
        3: 'handle_search',  # searchRequest
        6: 'handle_modify',  # modifyRequest
        8: 'handle_add',  # addRequest
        10: 'handle_del',  # delRequest
        12: 'handle_moddn',  # modDNRequest
        14: 'handle_compare',  # compareRequest
        23: 'handle_extended',  # extendedReq
    }

    def __init__(self, conn, models=None):
        """Wrap the connected socket-like ``conn`` (needs recv/sendall/close).

        ``models`` is the list of model classes consulted by searches and
        defaults to a fresh empty list.
        """
        # A None default avoids one list object being shared by all handlers.
        self.models = models if models is not None else []
        self.conn = conn
        self.user = None  # None -> unbound, b'' -> anonymous bind, else -> self.user == dn
        self.buf = b''

    def handle_bind(self, msgid, op):
        """Record the bind name and answer with success.

        NOTE(review): no credentials are checked; every bind succeeds.
        """
        self.user = op['name']
        resp = BindResponse()
        resp['resultCode'] = ResultCode('success')
        resp['matchedDN'] = op['name']
        resp['diagnosticMessage'] = ''
        self.send_response(msgid, 'bindResponse', resp)

    def handle_search(self, msgid, op):
        """Send one searchResEntry per matching object, then searchResDone.

        Only ``op['baseObject']`` and ``op['filter']`` are used. Each model
        evaluates the filter with its own hooks and returns its matches.
        """
        base_dn = op['baseObject'].decode()
        objs = []
        for model in self.models:
            filter_obj = model._ldap_filter(*op['filter'])
            objs += model.ldap_search(base_dn, filter_obj)
        for obj in objs:
            resp = SearchResultEntry()
            resp['object'] = obj.ldap_dn
            for name, values in obj.ldap_attributes.items():
                attr = PartialAttribute()
                attr['type'] = name
                for value in values:
                    attr['vals'].append(value)
                resp['attributes'].append(attr)
            self.send_response(msgid, 'searchResEntry', resp)
        self._send_result(msgid, SearchResultDone, 'searchResDone', 'success')

    def handle_modify(self, msgid, op):
        """Reject a modify request with resultCode ``other``."""
        self._send_result(msgid, ModifyResponse, 'modifyResponse', 'other')

    def handle_add(self, msgid, op):
        """Reject an add request with resultCode ``other``."""
        self._send_result(msgid, AddResponse, 'addResponse', 'other')

    def handle_del(self, msgid, op):
        """Reject a delete request with resultCode ``other``."""
        self._send_result(msgid, DelResponse, 'delResponse', 'other')

    def handle_moddn(self, msgid, op):
        """Reject a modify-DN request with resultCode ``other``."""
        self._send_result(msgid, ModifyDNResponse, 'modDNResponse', 'other')

    def handle_compare(self, msgid, op):
        """Reject a compare request with resultCode ``other``."""
        self._send_result(msgid, CompareResponse, 'compareResponse', 'other')

    def handle_extended(self, msgid, op):
        """Answer any extended request with resultCode ``protocolError``."""
        self._send_result(msgid, ExtendedResponse, 'extendedResp', 'protocolError')

    def _send_result(self, msgid, response_class, response_type, result_code):
        """Send a ``response_class`` carrying ``result_code`` and empty matchedDN/diagnosticMessage."""
        resp = response_class()
        resp['resultCode'] = ResultCode(result_code)
        resp['matchedDN'] = ''
        resp['diagnosticMessage'] = ''
        self.send_response(msgid, response_type, resp)

    def recv_request(self):
        """Read and decode the next complete LDAP message.

        Returns the decoded message, or ``None`` if the peer closed the
        connection before a complete message was buffered.
        """
        while True:
            size = compute_ldap_message_size(self.buf)
            # NOTE(review): -1 presumably means "length not yet known from the buffered bytes".
            if size != -1 and len(self.buf) >= size:
                break
            chunk = self.conn.recv(self.recv_chunk_size)
            if not chunk:
                return None
            self.buf += chunk
        # Decode exactly one message; bytes of any following message stay buffered.
        message, self.buf = self.buf[:size], self.buf[size:]
        print('received ', message)
        req = decode_message_fast(message)
        print(req)
        return req

    def send_response(self, message_id, response_type, response):
        """Wrap ``response`` as the ``response_type`` protocolOp of an LDAPMessage and send it."""
        msg = LDAPMessage()
        msg['messageID'] = message_id
        msg['protocolOp'] = ProtocolOp().setComponentByName(response_type, response)
        print('sent', msg)
        self.conn.sendall(encode(msg))

    def run(self):
        """Serve requests until unbind or EOF; the connection is always closed.

        Raises ValueError for a protocolOp this server does not recognise.
        Any exception from a handler also propagates, after the connection
        has been closed.
        """
        try:
            while True:
                msg = self.recv_request()
                if msg is None:
                    break
                op = msg['protocolOp']
                if op == 2:  # unbindRequest: no response, end the session
                    break
                if op == 16:  # abandonRequest: no response is defined
                    continue
                handler_name = self._REQUEST_HANDLERS.get(op)
                if handler_name is None:
                    raise ValueError('unsupported LDAP protocolOp: {!r}'.format(op))
                getattr(self, handler_name)(msg['messageID'], msg['payload'])
        finally:
            self.conn.close()
import traceback

# Single-threaded test server: clients on 127.0.0.1:1337 are served one at a time.
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
# Let a restarted server rebind the port while old connections sit in TIME_WAIT.
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
sock.bind(('127.0.0.1', 1337))
sock.listen(1)
while True:
    conn, addr = sock.accept()
    print('accepted connection from', addr)
    try:
        # ``with`` closes the client socket even if the handler fails.
        with conn:
            LDAPHandler(conn, models=[StaticLDAPObject]).run()
    except Exception:
        # Log and keep serving: one misbehaving client must not stop the server.
        traceback.print_exc()
    print('connection closed')