Skip to content
Snippets Groups Projects
Select Git revision
  • 98a6ec9924191f63b95f949adfd28ec89f4f3cd3
  • master default protected
  • decorator-interface
  • v0.1.2 protected
  • v0.1.1 protected
  • v0.1.0 protected
  • v0.0.1.dev6 protected
  • v0.0.1.dev5 protected
  • v0.0.1.dev4 protected
  • v0.0.1.dev3 protected
  • v0.0.1.dev2 protected
  • v0.0.1.dev1 protected
  • v0.0.1.dev0 protected
13 results

test.py

Blame
  • test.py 9.12 KiB
    import socket
    from ldap3.protocol.rfc4511 import LDAPMessage, ResultCode, ProtocolOp, BindResponse, SearchResultDone, ModifyResponse, AddResponse, DelResponse, ModifyDNResponse, CompareResponse, ExtendedResponse, SearchResultEntry, PartialAttribute
    from ldap3.utils.asn1 import decoder, encode
    from ldap_decode import decode_message_fast, compute_ldap_message_size
    
    class LDAPObject:
    	"""Base model that turns a decoded LDAP filter into a backend-specific result.

    	A filter is a sequence ``(operator, *args)``.  The private ``_ldap_filter*``
    	class methods walk that structure and delegate to the public
    	``ldap_filter_*`` hooks, which subclasses override.  Every hook returns:

    	* ``True``  -- the filter is always true,
    	* ``False`` -- the filter is always false,
    	* ``None``  -- the filter is undefined (unsupported),
    	* anything else -- a custom, backend-specific result object.

    	All hooks of this base class return ``None`` and ``ldap_search`` returns
    	an empty list, i.e. the base class supports nothing by itself.
    	"""

    	# Operators understood by _ldap_filter; each maps to ``_ldap_filter_<operator>``.
    	_LDAP_FILTER_OPERATORS = ('and', 'or', 'not', 'equal', 'present')

    	@classmethod
    	def _ldap_filter(cls, operator, *args):
    		"""Evaluate one filter node.

    		operator -- one of 'and', 'or', 'not', 'equal', 'present'
    		args     -- operator-specific arguments (sub-filters or name/value)

    		Returns the result of the matching ``_ldap_filter_<operator>`` method,
    		or ``None`` for an unknown operator.
    		"""
    		if operator in cls._LDAP_FILTER_OPERATORS:
    			return getattr(cls, '_ldap_filter_' + operator)(*args)
    		return None

    	@classmethod
    	def _ldap_filter_and(cls, *components):
    		"""Combine sub-filters with AND, folding constant results.

    		``True`` sub-results are dropped, a ``False`` or ``None`` sub-result
    		short-circuits and is returned as is.  If nothing but ``True`` remains
    		(including the empty filter ``(&)``) the result is ``True``; the
    		backend hook is never called without arguments.
    		"""
    		results = []
    		for component in components:
    			subres = cls._ldap_filter(*component)
    			if subres is True:
    				continue
    			if subres is False or subres is None:
    				return subres
    			results.append(subres)
    		if not results:
    			# Empty conjunction: every remaining operand was always-true.
    			return True
    		return cls.ldap_filter_and(*results)

    	@classmethod
    	def _ldap_filter_or(cls, *components):
    		"""Combine sub-filters with OR, folding constant results.

    		``False`` sub-results are dropped, a ``True`` or ``None`` sub-result
    		short-circuits and is returned as is.  If nothing but ``False`` remains
    		(including the empty filter ``(|)``) the result is ``False``; the
    		backend hook is never called without arguments.
    		"""
    		results = []
    		for component in components:
    			subres = cls._ldap_filter(*component)
    			if subres is False:
    				continue
    			if subres is True or subres is None:
    				return subres
    			results.append(subres)
    		if not results:
    			# Empty disjunction: every remaining operand was always-false.
    			return False
    		return cls.ldap_filter_or(*results)

    	@classmethod
    	def _ldap_filter_not(cls, subfilter):
    		"""Negate a sub-filter: constants are inverted, ``None`` stays ``None``."""
    		subres = cls._ldap_filter(*subfilter)
    		if subres is None:
    			return None
    		if subres is True or subres is False:
    			return not subres
    		return cls.ldap_filter_not(subres)

    	@classmethod
    	def _ldap_filter_present(cls, name):
    		"""Presence test; ``name`` is decoded and lower-cased before the hook is called."""
    		return cls.ldap_filter_present(name.decode().lower())

    	@classmethod
    	def _ldap_filter_equal(cls, name, value):
    		"""Equality test; the hook receives a lower-cased ``str`` name and ``bytes`` value."""
    		return cls.ldap_filter_equal(name.decode().lower(), bytes(value))

    	@classmethod
    	def ldap_filter_present(cls, name):
    		"""Hook: result for "attribute ``name`` is present"."""
    		# return True if always true, False if always false, None if undefined, custom object else
    		return None

    	@classmethod
    	def ldap_filter_equal(cls, name, value):
    		"""Hook: result for "attribute ``name`` equals ``value``"."""
    		return None

    	@classmethod
    	def ldap_filter_and(cls, *results):
    		"""Hook: combine one or more custom results with AND."""
    		return None

    	@classmethod
    	def ldap_filter_or(cls, *results):
    		"""Hook: combine one or more custom results with OR."""
    		return None

    	@classmethod
    	def ldap_filter_not(cls, result):
    		"""Hook: negate a custom result."""
    		return None

    	@classmethod
    	def ldap_search(cls, dn_base, filter_res):
    		"""Hook: return the objects below ``dn_base`` that match ``filter_res``."""
    		return []
    
    def split_dn(dn):
    	"""Split a DN string at every comma; a falsy DN (``''``/``None``) gives ``[]``."""
    	return dn.split(',') if dn else []
    
    class StaticLDAPObject(LDAPObject):
    	"""In-memory LDAP objects with class-level indexes for set-based filtering.

    	Creating an instance registers it in the indexes below; filter results
    	are plain ``set`` objects of instances.  The indexes are class
    	attributes, so they are shared by every instance (and by subclasses
    	that do not override them).
    	"""
    	present_map = {} # name -> set of objs
    	value_map = {} # (name, value) -> set of objs
    	all_objects = set() # set of all objs
    	dn_map = {tuple(): set()} # (dn part tuples) -> set of objs

    	def __init__(self, dn, **kwargs):
    		"""Create and index an object.

    		dn     -- distinguished name string, split with ``split_dn``
    		kwargs -- attributes; a value may be a single item or a list.
    		          ``None`` and empty strings/bytes are dropped, ``int`` is
    		          converted to ``str`` and ``str`` is encoded to ``bytes``.
    		          Attribute names are stored lower-cased.
    		"""
    		self.ldap_dn = dn
    		cls = type(self)
    		# The empty key holds every object (search below the root).
    		cls.dn_map.setdefault(tuple(), set()).add(self)
    		# Register the object under every suffix of its DN so that a search
    		# with any ancestor as base finds it.
    		parts = split_dn(dn)
    		for start in range(len(parts) - 1, -1, -1):
    			cls.dn_map.setdefault(tuple(parts[start:]), set()).add(self)
    		cls.all_objects.add(self)
    		self.ldap_attributes = {}
    		for name, raw_values in kwargs.items():
    			name = name.lower()
    			if not isinstance(raw_values, list):
    				raw_values = [raw_values]
    			values = []
    			for value in raw_values:
    				if value is None or value == '' or value == b'':
    					continue
    				if isinstance(value, int):
    					value = str(value)
    				if isinstance(value, str):
    					value = value.encode()
    				values.append(value)
    			if not values:
    				# Attributes without any usable value are not stored at all.
    				continue
    			self.ldap_attributes[name] = values
    			cls.present_map.setdefault(name, set()).add(self)
    			for value in values:
    				cls.value_map.setdefault((name, value), set()).add(self)

    	@classmethod
    	def ldap_filter_present(cls, name):
    		"""Return the set of objects that have attribute ``name``."""
    		return cls.present_map.get(name, set())

    	@classmethod
    	def ldap_filter_equal(cls, name, value):
    		"""Return the set of objects whose attribute ``name`` contains ``value`` (exact bytes)."""
    		return cls.value_map.get((name, value), set())

    	@classmethod
    	def ldap_filter_and(cls, *results):
    		"""Intersect result sets; without any operand every object matches."""
    		if not results:
    			return set(cls.all_objects)
    		return results[0].intersection(*results[1:])

    	@classmethod
    	def ldap_filter_or(cls, *results):
    		"""Unite result sets; without any operand no object matches."""
    		if not results:
    			return set()
    		return results[0].union(*results[1:])

    	@classmethod
    	def ldap_filter_not(cls, result):
    		"""Return all objects that are not in ``result``."""
    		return cls.all_objects.difference(result)

    	@classmethod
    	def ldap_search(cls, dn_base, filter_res):
    		"""Return the objects registered under ``dn_base`` that match ``filter_res``.

    		``filter_res`` follows the convention of the filter hooks: ``None`` or
    		``False`` yields an empty list, ``True`` yields every object under
    		``dn_base``, a set is intersected with the objects under ``dn_base``.
    		"""
    		if filter_res is None or filter_res is False:
    			return []
    		objs = cls.dn_map.get(tuple(split_dn(dn_base)), set())
    		if filter_res is True:
    			# Copy so callers cannot modify the index through the result.
    			return set(objs)
    		return objs.intersection(filter_res)
    
    # Fixture entries served by the test server; instantiating registers them
    # in the StaticLDAPObject indexes.
    StaticLDAPObject(
    	'uid=testuser,ou=users,dc=example,dc=com',
    	objectClass=['top', 'person'],
    	givenName='Test User',
    	mail='testuser@example.com',
    	uid='testuser',
    )
    StaticLDAPObject(
    	'uid=testadmin,ou=users,dc=example,dc=com',
    	objectClass=['top', 'person'],
    	givenName='Test Admin',
    	mail='testadmin@example.com',
    	uid='testadmin',
    )
    
    class LDAPHandler:
    	"""Serve LDAP requests that arrive on one connected socket.

    	``run()`` reads requests until the peer disconnects or unbinds and
    	dispatches them to the ``handle_*`` methods.  Bind always succeeds,
    	search is answered from ``models``; modify/add/delete/modDN/compare are
    	answered with result code ``other`` and extended requests with
    	``protocolError``.
    	"""

    	# Number of bytes requested per recv() call (value kept from the original code).
    	# NOTE(review): the tiny size looks intended to exercise partial reads -- confirm.
    	RECV_CHUNK_SIZE = 5

    	def __init__(self, conn, models=None):
    		"""conn -- connected socket; models -- classes used to answer searches."""
    		# A new list per handler: a mutable default would be shared by all instances.
    		self.models = [] if models is None else models
    		self.conn = conn
    		self.user = None # None -> unbound, b'' -> anonymous bind, else -> self.user == dn
    		self.buf = b''

    	def _send_result(self, msgid, response_type, response, result_code):
    		"""Fill the common result fields of ``response`` and send it."""
    		response['resultCode'] = ResultCode(result_code)
    		response['matchedDN'] = ''
    		response['diagnosticMessage'] = ''
    		self.send_response(msgid, response_type, response)

    	def handle_bind(self, msgid, op):
    		"""Accept every bind request without checking credentials."""
    		self.user = op['name']
    		resp = BindResponse()
    		resp['resultCode'] = ResultCode('success')
    		resp['matchedDN'] = op['name']
    		resp['diagnosticMessage'] = ''
    		self.send_response(msgid, 'bindResponse', resp)

    	def handle_search(self, msgid, op):
    		"""Send one entry per matching object of every model, then the done message."""
    		objs = []
    		for model in self.models:
    			filter_obj = model._ldap_filter(*op['filter'])
    			objs += model.ldap_search(op['baseObject'].decode(), filter_obj)
    		for obj in objs:
    			resp = SearchResultEntry()
    			resp['object'] = obj.ldap_dn
    			for name, values in obj.ldap_attributes.items():
    				attr = PartialAttribute()
    				attr['type'] = name
    				for value in values:
    					attr['vals'].append(value)
    				resp['attributes'].append(attr)
    			self.send_response(msgid, 'searchResEntry', resp)
    		self._send_result(msgid, 'searchResDone', SearchResultDone(), 'success')

    	def handle_modify(self, msgid, op):
    		"""Modify is not supported: answer with result code ``other``."""
    		self._send_result(msgid, 'modifyResponse', ModifyResponse(), 'other')

    	def handle_add(self, msgid, op):
    		"""Add is not supported: answer with result code ``other``."""
    		self._send_result(msgid, 'addResponse', AddResponse(), 'other')

    	def handle_del(self, msgid, op):
    		"""Delete is not supported: answer with result code ``other``."""
    		self._send_result(msgid, 'delResponse', DelResponse(), 'other')

    	def handle_moddn(self, msgid, op):
    		"""ModifyDN is not supported: answer with result code ``other``."""
    		self._send_result(msgid, 'modDNResponse', ModifyDNResponse(), 'other')

    	def handle_compare(self, msgid, op):
    		"""Compare is not supported: answer with result code ``other``."""
    		self._send_result(msgid, 'compareResponse', CompareResponse(), 'other')

    	def handle_extended(self, msgid, op):
    		"""Extended operations are not supported: answer with ``protocolError``."""
    		self._send_result(msgid, 'extendedResp', ExtendedResponse(), 'protocolError')

    	def recv_request(self):
    		"""Read and decode the next request.

    		Returns the decoded message, or ``None`` if the peer closed the
    		connection before a complete message arrived.  Bytes that follow the
    		message stay in ``self.buf`` for the next call.
    		"""
    		while True:
    			size = compute_ldap_message_size(self.buf)
    			if size != -1 and len(self.buf) >= size:
    				break
    			chunk = self.conn.recv(self.RECV_CHUNK_SIZE)
    			if not chunk:
    				return None
    			self.buf += chunk
    		# Decode exactly one message; the buffer may already hold the start
    		# of the next request.
    		frame = self.buf[:size]
    		print('received ', frame)
    		req = decode_message_fast(frame)
    		print(req)
    		self.buf = self.buf[size:]
    		return req

    	def send_response(self, message_id, response_type, response):
    		"""Wrap ``response`` in an LDAPMessage with ``message_id`` and send it."""
    		msg = LDAPMessage()
    		msg['messageID'] = message_id
    		msg['protocolOp'] = ProtocolOp().setComponentByName(response_type, response)
    		print('sent', msg)
    		self.conn.sendall(encode(msg))

    	def run(self):
    		"""Process requests until EOF or unbind; the connection is always closed.

    		Raises ``Exception`` for a request type that is not handled.
    		"""
    		try:
    			while True:
    				msg = self.recv_request()
    				if msg is None:
    					break
    				op = msg['protocolOp']
    				if op == 0: # bindRequest
    					self.handle_bind(msg['messageID'], msg['payload'])
    				elif op == 2: # unbindRequest
    					break
    				elif op == 3: # searchRequest
    					self.handle_search(msg['messageID'], msg['payload'])
    				elif op == 6: # modifyRequest
    					self.handle_modify(msg['messageID'], msg['payload'])
    				elif op == 8: # addRequest
    					self.handle_add(msg['messageID'], msg['payload'])
    				elif op == 10: # delRequest
    					self.handle_del(msg['messageID'], msg['payload'])
    				elif op == 12: # modDNRequest
    					self.handle_moddn(msg['messageID'], msg['payload'])
    				elif op == 14: # compareRequest
    					self.handle_compare(msg['messageID'], msg['payload'])
    				elif op == 16: # abandonRequest
    					pass # requests are answered synchronously, nothing to abandon
    				elif op == 23: # extendedReq
    					self.handle_extended(msg['messageID'], msg['payload'])
    				else:
    					raise Exception('unsupported protocolOp: %r' % (op,))
    		finally:
    			# Close the socket even when a handler or the decoder raised.
    			self.conn.close()
    
    # Single-threaded test server: accept one client at a time on localhost:1337
    # and serve it until it disconnects.  The listening socket is closed when
    # the loop is left by an exception (e.g. KeyboardInterrupt).
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0) as sock:
    	# Allow an immediate restart while old connections linger in TIME_WAIT.
    	sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
    	sock.bind(('127.0.0.1', 1337))
    	sock.listen(1)
    	while True:
    		conn, addr = sock.accept()
    		print('accepted connection from', addr)
    		LDAPHandler(conn, models=[StaticLDAPObject]).run()
    		print('connection closed')