import socket

from ldap3.protocol.rfc4511 import (
    LDAPMessage, ResultCode, ProtocolOp, BindResponse, SearchResultDone,
    ModifyResponse, AddResponse, DelResponse, ModifyDNResponse,
    CompareResponse, ExtendedResponse, SearchResultEntry, PartialAttribute,
)
from ldap3.utils.asn1 import decoder, encode

from ldap_decode import decode_message_fast, compute_ldap_message_size


class LDAPObject:
    """Base class for models that the LDAP handler can search.

    Filters arrive as nested tuples ``(operator, *args)``. ``_ldap_filter``
    reduces every (sub-)filter to one of:

    * ``True``  -- the filter matches every object,
    * ``False`` -- the filter matches no object,
    * ``None``  -- the filter is undefined (e.g. an unsupported operator),
    * any other value -- a model-specific result produced by the public
      ``ldap_filter_*`` hooks, which subclasses override.

    ``and`` / ``or`` / ``not`` follow the three-valued logic of RFC 4511
    section 4.5.1.7: an ``and`` is False if any part is False, an ``or`` is
    True if any part is True, and otherwise any undefined part makes the
    whole result undefined.
    """

    # Filter operator name -> private evaluator method name.
    _FILTER_DISPATCH = {
        'and': '_ldap_filter_and',
        'or': '_ldap_filter_or',
        'not': '_ldap_filter_not',
        'equal': '_ldap_filter_equal',
        'present': '_ldap_filter_present',
    }

    @classmethod
    def _ldap_filter(cls, operator, *args):
        """Evaluate the filter ``(operator, *args)``; unknown operators yield None."""
        method_name = cls._FILTER_DISPATCH.get(operator)
        if method_name is None:
            return None
        return getattr(cls, method_name)(*args)

    @classmethod
    def _ldap_filter_and(cls, *components):
        """Evaluate an ``and`` over sub-filter tuples.

        Returns False if any part is False, None if any part is undefined,
        True if every part is True (including the empty "(&)" filter), and
        otherwise ``cls.ldap_filter_and`` of the model-specific parts.
        """
        results = []
        undefined = False
        for component in components:
            subres = cls._ldap_filter(*component)
            if subres is True:
                continue  # neutral element of AND
            if subres is False:
                return False  # FALSE wins even over undefined parts
            if subres is None:
                undefined = True  # keep scanning: a later part may be False
                continue
            results.append(subres)
        if undefined:
            return None
        if not results:
            return True
        return cls.ldap_filter_and(*results)

    @classmethod
    def _ldap_filter_or(cls, *components):
        """Evaluate an ``or`` over sub-filter tuples.

        Returns True if any part is True, None if any part is undefined,
        False if every part is False (including the empty "(|)" filter), and
        otherwise ``cls.ldap_filter_or`` of the model-specific parts.

        NOTE: an undefined part makes the whole OR undefined even when other
        parts match some objects; this is conservative (it may return fewer
        entries than full per-entry RFC evaluation would).
        """
        results = []
        undefined = False
        for component in components:
            subres = cls._ldap_filter(*component)
            if subres is True:
                return True  # TRUE wins even over undefined parts
            if subres is False:
                continue  # neutral element of OR
            if subres is None:
                undefined = True  # keep scanning: a later part may be True
                continue
            results.append(subres)
        if undefined:
            return None
        if not results:
            return False
        return cls.ldap_filter_or(*results)

    @classmethod
    def _ldap_filter_not(cls, subfilter):
        """Evaluate a ``not``; an undefined operand stays undefined."""
        subres = cls._ldap_filter(*subfilter)
        if subres is True:
            return False
        if subres is False:
            return True
        if subres is None:
            return None
        return cls.ldap_filter_not(subres)

    @classmethod
    def _ldap_filter_present(cls, name):
        """Evaluate a presence filter; *name* is bytes and is lower-cased."""
        return cls.ldap_filter_present(name.decode().lower())

    @classmethod
    def _ldap_filter_equal(cls, name, value):
        """Evaluate an equality filter; the name is lower-cased, the value kept as bytes."""
        return cls.ldap_filter_equal(name.decode().lower(), bytes(value))

    # Public hooks for subclasses. Each returns True if the filter is always
    # true, False if always false, None if undefined, or a custom object.

    @classmethod
    def ldap_filter_present(cls, name):
        return None

    @classmethod
    def ldap_filter_equal(cls, name, value):
        return None

    @classmethod
    def ldap_filter_and(cls, *results):
        return None

    @classmethod
    def ldap_filter_or(cls, *results):
        return None

    @classmethod
    def ldap_filter_not(cls, result):
        return None

    @classmethod
    def ldap_search(cls, dn_base, filter_res):
        """Return the objects under *dn_base* matching *filter_res*."""
        return []


def split_dn(dn):
    """Split *dn* on commas into its RDN strings, most specific first.

    Empty or None DNs yield an empty list. NOTE(review): escaped commas
    (``\\,``) inside RDN values are not handled -- they are split too.
    """
    if not dn:
        return []
    return dn.split(',')


class StaticLDAPObject(LDAPObject):
    """In-memory model whose instances index themselves on construction.

    The filter hooks return ``set`` objects of matching instances. Attribute
    names are lower-cased; values are stored and compared as exact bytes.

    NOTE(review): the indexes are class attributes mutated through
    ``type(self)``, so subclasses share them with this class unless they
    define their own.
    """

    present_map = {}           # attribute name -> set of objs having it
    value_map = {}             # (attribute name, value bytes) -> set of objs
    all_objects = set()        # every registered obj
    dn_map = {tuple(): set()}  # DN suffix (tuple of RDNs) -> objs at or below it

    def __init__(self, dn, **kwargs):
        self.ldap_dn = dn
        cls = type(self)

        # Index under every suffix of the DN (the empty suffix through the
        # full DN) so a search-base lookup is a single dict access.
        rdns = split_dn(dn)
        cls.dn_map.setdefault(tuple(), set()).add(self)
        for start in range(len(rdns)):
            cls.dn_map.setdefault(tuple(rdns[start:]), set()).add(self)
        cls.all_objects.add(self)

        self.ldap_attributes = {}
        for name, raw_values in kwargs.items():
            name = name.lower()
            values = self._normalize_values(raw_values)
            if not values:
                continue  # attributes without usable values are omitted entirely
            self.ldap_attributes[name] = values
            cls.present_map.setdefault(name, set()).add(self)
            for value in values:
                cls.value_map.setdefault((name, value), set()).add(self)

    @staticmethod
    def _normalize_values(raw_values):
        """Turn a scalar or list of values into bytes, dropping None/'' /b''.

        ints (and therefore bools) are stringified; strs are encoded with the
        default codec (UTF-8).
        """
        if not isinstance(raw_values, list):
            raw_values = [raw_values]
        values = []
        for value in raw_values:
            if value is None or value == '' or value == b'':
                continue
            if isinstance(value, int):
                value = str(value)
            if isinstance(value, str):
                value = value.encode()
            values.append(value)
        return values

    @classmethod
    def ldap_filter_present(cls, name):
        return cls.present_map.get(name, set())

    @classmethod
    def ldap_filter_equal(cls, name, value):
        return cls.value_map.get((name, value), set())

    @classmethod
    def ldap_filter_and(cls, *results):
        if not results:
            return True  # empty AND matches everything
        # Always returns a new set, so callers never alias an index set.
        return results[0].intersection(*results[1:])

    @classmethod
    def ldap_filter_or(cls, *results):
        if not results:
            return False  # empty OR matches nothing
        return results[0].union(*results[1:])

    @classmethod
    def ldap_filter_not(cls, result):
        return cls.all_objects.difference(result)

    @classmethod
    def ldap_search(cls, dn_base, filter_res):
        """Return the objects at or below *dn_base* that satisfy *filter_res*.

        The search scope is not taken into account: every object whose DN
        ends with *dn_base* (including *dn_base* itself) is a candidate.
        """
        if filter_res is None or filter_res is False:
            return []
        objs = cls.dn_map.get(tuple(split_dn(dn_base)), set())
        if filter_res is True:
            return set(objs)  # copy so callers cannot mutate the index
        return objs.intersection(filter_res)


StaticLDAPObject('uid=testuser,ou=users,dc=example,dc=com',
                 objectClass=['top', 'person'], givenName='Test User',
                 mail='testuser@example.com', uid='testuser')
StaticLDAPObject('uid=testadmin,ou=users,dc=example,dc=com',
                 objectClass=['top', 'person'], givenName='Test Admin',
                 mail='testadmin@example.com', uid='testadmin')


class LDAPHandler:
    """Serve LDAP requests arriving on one connected socket.

    Binds always succeed without checking credentials, searches are answered
    from ``models``, and write/compare operations are refused with ``other``.
    """

    RECV_CHUNK_SIZE = 4096

    # protocolOp tag -> handler method name. Assumes decode_message_fast
    # reports protocolOp as a plain int tag.
    _REQUEST_HANDLERS = {
        0: 'handle_bind',       # bindRequest
        3: 'handle_search',     # searchRequest
        6: 'handle_modify',     # modifyRequest
        8: 'handle_add',        # addRequest
        10: 'handle_del',       # delRequest
        12: 'handle_moddn',     # modDNRequest
        14: 'handle_compare',   # compareRequest
        23: 'handle_extended',  # extendedReq
    }
    _UNBIND_REQUEST = 2
    _ABANDON_REQUEST = 16

    def __init__(self, conn, models=None):
        self.models = [] if models is None else models
        self.conn = conn
        self.user = None  # None -> unbound, b'' -> anonymous bind, else -> self.user == dn
        self.buf = b''

    def _send_result(self, msgid, response_cls, response_type, result_code):
        """Send an LDAPResult-shaped response with empty matchedDN and message."""
        resp = response_cls()
        resp['resultCode'] = ResultCode(result_code)
        resp['matchedDN'] = ''
        resp['diagnosticMessage'] = ''
        self.send_response(msgid, response_type, resp)

    def handle_bind(self, msgid, op):
        """Record the bind name and report success; no credentials are checked."""
        self.user = op['name']
        # RFC 4511 4.1.9: matchedDN is empty for a successful result.
        self._send_result(msgid, BindResponse, 'bindResponse', 'success')

    @staticmethod
    def _build_entry(obj):
        """Build a SearchResultEntry with obj's DN and all of its attributes."""
        entry = SearchResultEntry()
        entry['object'] = obj.ldap_dn
        for name, values in obj.ldap_attributes.items():
            attr = PartialAttribute()
            attr['type'] = name
            for value in values:
                attr['vals'].append(value)
            entry['attributes'].append(attr)
        return entry

    def handle_search(self, msgid, op):
        """Send one searchResEntry per match across all models, then searchResDone.

        Only ``baseObject`` and ``filter`` are read from *op*; other request
        fields (e.g. the requested attribute list) are ignored.
        """
        base = op['baseObject'].decode()
        matches = []
        for model in self.models:
            filter_res = model._ldap_filter(*op['filter'])
            matches.extend(model.ldap_search(base, filter_res))
        for obj in matches:
            self.send_response(msgid, 'searchResEntry', self._build_entry(obj))
        self._send_result(msgid, SearchResultDone, 'searchResDone', 'success')

    def handle_modify(self, msgid, op):
        self._send_result(msgid, ModifyResponse, 'modifyResponse', 'other')

    def handle_add(self, msgid, op):
        self._send_result(msgid, AddResponse, 'addResponse', 'other')

    def handle_del(self, msgid, op):
        self._send_result(msgid, DelResponse, 'delResponse', 'other')

    def handle_moddn(self, msgid, op):
        self._send_result(msgid, ModifyDNResponse, 'modDNResponse', 'other')

    def handle_compare(self, msgid, op):
        self._send_result(msgid, CompareResponse, 'compareResponse', 'other')

    def handle_extended(self, msgid, op):
        self._send_result(msgid, ExtendedResponse, 'extendedResp', 'protocolError')

    def recv_request(self):
        """Read and decode one complete LDAP message from the connection.

        Returns the decoded message, or None if the peer closed the
        connection before a full message arrived. Bytes belonging to a
        following (pipelined) message remain in ``self.buf``.
        """
        while True:
            # -1 is treated as "size not yet known" (header incomplete).
            size = compute_ldap_message_size(self.buf)
            if size != -1 and len(self.buf) >= size:
                break
            chunk = self.conn.recv(self.RECV_CHUNK_SIZE)
            if not chunk:
                return None
            self.buf += chunk
        print('received ', self.buf)
        req = decode_message_fast(self.buf)
        print(req)
        self.buf = self.buf[size:]
        return req

    def send_response(self, message_id, response_type, response):
        """Wrap *response* in an LDAPMessage and send it."""
        msg = LDAPMessage()
        msg['messageID'] = message_id
        msg['protocolOp'] = ProtocolOp().setComponentByName(response_type, response)
        print('sent', msg)
        self.conn.sendall(encode(msg))

    def run(self):
        """Serve requests until unbind or EOF; always closes the connection.

        Raises:
            ValueError: if the client sends an unsupported protocolOp.
        """
        try:
            while True:
                msg = self.recv_request()
                if msg is None:
                    break
                op = msg['protocolOp']
                if op == self._UNBIND_REQUEST:
                    break
                if op == self._ABANDON_REQUEST:
                    continue  # requests are answered synchronously; nothing to abandon
                handler_name = self._REQUEST_HANDLERS.get(op)
                if handler_name is None:
                    raise ValueError(f'unsupported LDAP protocolOp {op!r}')
                getattr(self, handler_name)(msg['messageID'], msg['payload'])
        finally:
            self.conn.close()


def main(host='127.0.0.1', port=1337):
    """Accept connections forever, serving one client at a time."""
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0) as sock:
        # Allow quick restarts without waiting for TIME_WAIT sockets to expire.
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        sock.bind((host, port))
        sock.listen(1)
        while True:
            conn, addr = sock.accept()
            print('accepted connection from', addr)
            try:
                LDAPHandler(conn, models=[StaticLDAPObject]).run()
            except Exception as exc:  # one bad client must not kill the server
                print('error while serving', addr, ':', repr(exc))
            print('connection closed')


if __name__ == '__main__':
    main()