Skip to content
Snippets Groups Projects
test_ratelimit.py 2.04 KiB
Newer Older
  • Learn to ignore specific revisions
  • Julian's avatar
    Julian committed
    import time
    
    from flask import Flask, Blueprint, session, url_for
    
    from uffd.ratelimit import get_addrkey, format_delay, Ratelimit, RatelimitEvent
    
    from utils import UffdTestCase
    
    class TestRatelimit(UffdTestCase):
    	"""Tests for get_addrkey, format_delay and the Ratelimit class."""

    	def test_limiting(self):
    		"""Logging `limit` events keeps the delay within `interval`; one more event pushes it past."""
    		cases = [
    			(1*60, 3),
    			(1*60*60, 3),
    			(1*60*60, 25),
    		]
    		for index, (interval, limit) in enumerate(cases):
    			with self.subTest(interval=interval, limit=limit):
    				# Every case logs under the same ratelimit name 'test', so each
    				# one uses its own key to keep earlier cases' events out of it.
    				key = str(index)
    				ratelimit = Ratelimit('test', interval, limit)
    				for _ in range(limit):
    					ratelimit.log(key)
    				self.assertLessEqual(ratelimit.get_delay(key), interval)
    				ratelimit.log(key)
    				self.assertGreater(ratelimit.get_delay(key), interval)

    	def test_addrkey(self):
    		"""get_addrkey maps nearby addresses to one key, separates other networks, and never raises on odd input."""
    		self.assertEqual(get_addrkey('192.168.0.1'), get_addrkey('192.168.0.99'))
    		self.assertNotEqual(get_addrkey('192.168.0.1'), get_addrkey('192.168.1.1'))
    		self.assertEqual(get_addrkey('fdee:707a:f38a:c369::'), get_addrkey('fdee:707a:f38a:ffff::'))
    		self.assertNotEqual(get_addrkey('fdee:707a:f38a:c369::'), get_addrkey('fdee:707a:f38b:c369::'))
    		# Malformed, unusual or non-IP inputs must still produce a string key
    		# (a raised exception would fail the assertion below).
    		cases = [
    			'',
    			'192.168.0.',
    			':',
    			'::',
    			'192.168.0.1/24',
    			'fdee:707a:f38a:c369::/64',
    			'host.example.com',
    		]
    		for case in cases:
    			with self.subTest(addr=case):
    				self.assertIsInstance(get_addrkey(case), str)

    	def test_format_delay(self):
    		"""format_delay returns a string for zero, seconds-, minutes- and hours-scale delays."""
    		for delay in [0, 1, 30, 60, 120, 3600, 4000]:
    			with self.subTest(delay=delay):
    				self.assertIsInstance(format_delay(delay), str)

    	def test_cleanup(self):
    		"""cleanup() deletes events older than the 1-second interval and keeps newer ones."""
    		ratelimit = Ratelimit('test', 1, 1)

    		def count_events():
    			# Number of stored events for this ratelimit name.
    			return RatelimitEvent.query.filter(RatelimitEvent.name == 'test').count()

    		for key in ['', '1', '2', '3', '4']:
    			ratelimit.log(key)
    		# Let the first five events age past the 1-second interval before
    		# logging a fresh one, so cleanup has both old and new events to sort.
    		time.sleep(1)
    		ratelimit.log('5')
    		self.assertEqual(count_events(), 6)
    		ratelimit.cleanup()
    		self.assertEqual(count_events(), 1)
    		# Now the remaining event ages out as well.
    		time.sleep(1)
    		ratelimit.cleanup()
    		self.assertEqual(count_events(), 0)