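# NOTE(review): the labels "Newer" / "Older" were diff-viewer residue captured with this file; kept only as a comment.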
import time
import unittest
from flask import url_for, request
from uffd.models import DeviceLoginConfirmation, Service, OAuth2Client, OAuth2DeviceLoginInitiation, User, RecoveryCodeMethod, TOTPMethod
from uffd.models.mfa import _hotp
class TestSession(UffdTestCase):
    """View tests for session handling.

    Covers password login (including failure modes and rate limiting), logout,
    session expiry, group-restricted views and the OAuth2 device-login flow.
    """

    def setUpApp(self):
        """Shorten the session lifetime and register helper views used by the tests."""
        # Two seconds keeps test_timeout (which sleeps three) reasonably short.
        self.app.config['SESSION_LIFETIME_SECONDS'] = 2

        @self.app.route('/test_login_required')
        @login_required()
        def test_login_required():
            # NOTE(review): body reconstructed -- it was missing from the scraped
            # source; assertLoggedIn expects exactly b'SUCCESS testuser' from here.
            return 'SUCCESS ' + request.user.loginname, 200

        @self.app.route('/test_group_required1')
        @login_required(lambda: request.user.is_in_group('users'))
        def test_group_required1():
            return 'SUCCESS', 200

        @self.app.route('/test_group_required2')
        @login_required(lambda: request.user.is_in_group('notagroup'))
        def test_group_required2():
            return 'SUCCESS', 200

    def setUp(self):
        super().setUp()
        # Every test starts without an authenticated user.
        self.assertIsNone(request.user)

    def login(self):
        """Log in as the regular test user and check that it worked."""
        self.login_as('user')
        self.assertLoggedIn()

    def assertLoggedIn(self):
        """Assert that the client session is authenticated as 'testuser'."""
        self.assertIsNotNone(request.user)
        self.assertEqual(
            self.client.get(path=url_for('test_login_required'), follow_redirects=True).data,
            b'SUCCESS testuser',
        )

    def assertLoggedOut(self):
        """Assert that the client session is not authenticated."""
        # Make the request first: the tests read request.user right after client
        # calls, so presumably `request` reflects the most recent request. Checking
        # it before a fresh request could see stale state (e.g. after deactivation
        # or session expiry).
        self.assertNotIn(
            b'SUCCESS',
            self.client.get(path=url_for('test_login_required'), follow_redirects=True).data,
        )
        self.assertIsNone(request.user)

    def test_login(self):
        r = self.client.get(path=url_for('session.login'), follow_redirects=True)
        dump('login', r)
        self.assertEqual(r.status_code, 200)
        # NOTE(review): the login POST was missing from the scraped source.
        r = self.login_as('user')
        dump('login_post', r)
        self.assertEqual(r.status_code, 200)
        self.assertLoggedIn()

    def test_login_password_rehash(self):
        """A login with a legacy plaintext hash re-hashes the password."""
        self.get_user().password = PlaintextPasswordHash.from_password('userpassword')
        db.session.commit()
        self.assertIsInstance(self.get_user().password, PlaintextPasswordHash)
        db_flush()
        r = self.login_as('user')
        self.assertEqual(r.status_code, 200)
        self.assertLoggedIn()
        self.assertIsInstance(self.get_user().password, User.password.method_cls)
        self.assertTrue(self.get_user().password.verify('userpassword'))

    def test_titlecase_password(self):
        r = self.client.post(
            path=url_for('session.login'),
            data={'loginname': self.get_user().loginname.title(), 'password': 'userpassword'},
            follow_redirects=True,
        )
        self.assertEqual(r.status_code, 200)
        self.assertLoggedIn()

    def test_redirect(self):
        """After login the client ends up at the `ref` target."""
        r = self.login_as('user', ref=url_for('test_login_required'))
        self.assertEqual(r.status_code, 200)
        self.assertEqual(r.data, b'SUCCESS testuser')

    def test_wrong_password(self):
        r = self.client.post(
            path=url_for('session.login'),
            data={'loginname': self.get_user().loginname, 'password': 'wrongpassword'},
            follow_redirects=True,
        )
        dump('login_wrong_password', r)
        self.assertEqual(r.status_code, 200)
        self.assertLoggedOut()

    def test_empty_password(self):
        r = self.client.post(
            path=url_for('session.login'),
            data={'loginname': self.get_user().loginname, 'password': ''},
            follow_redirects=True,
        )
        dump('login_empty_password', r)
        self.assertEqual(r.status_code, 200)
        self.assertLoggedOut()

    # Regression test for #100 (uncatched LDAPSASLPrepError)
    def test_saslprep_invalid_password(self):
        r = self.client.post(
            path=url_for('session.login'),
            data={'loginname': 'testuser', 'password': 'wrongpassword\n'},
            follow_redirects=True,
        )
        dump('login_saslprep_invalid_password', r)
        self.assertEqual(r.status_code, 200)
        self.assertLoggedOut()

    def test_wrong_user(self):
        r = self.client.post(
            path=url_for('session.login'),
            data={'loginname': 'nouser', 'password': 'userpassword'},
            follow_redirects=True,
        )
        dump('login_wrong_user', r)
        self.assertEqual(r.status_code, 200)
        self.assertLoggedOut()

    def test_empty_user(self):
        r = self.client.post(
            path=url_for('session.login'),
            data={'loginname': '', 'password': 'userpassword'},
            follow_redirects=True,
        )
        dump('login_empty_user', r)
        self.assertEqual(r.status_code, 200)
        self.assertLoggedOut()

    def test_no_access(self):
        r = self.client.post(
            path=url_for('session.login'),
            data={'loginname': 'testservice', 'password': 'servicepassword'},
            follow_redirects=True,
        )
        dump('login_no_access', r)
        self.assertEqual(r.status_code, 200)
        self.assertLoggedOut()

    def test_deactivated(self):
        self.get_user().is_deactivated = True
        db.session.commit()
        r = self.login_as('user')
        dump('login_deactivated', r)
        self.assertEqual(r.status_code, 200)
        self.assertLoggedOut()

    def test_deactivated_after_login(self):
        """Deactivating a user ends their existing session."""
        self.login_as('user')
        self.get_user().is_deactivated = True
        db.session.commit()
        self.assertLoggedOut()

    def test_group_required(self):
        self.login()
        self.assertEqual(
            self.client.get(path=url_for('test_group_required1'), follow_redirects=True).data,
            b'SUCCESS',
        )
        self.assertNotEqual(
            self.client.get(path=url_for('test_group_required2'), follow_redirects=True).data,
            b'SUCCESS',
        )

    def test_logout(self):
        self.login()
        r = self.client.get(path=url_for('session.logout'), follow_redirects=True)
        dump('logout', r)
        self.assertEqual(r.status_code, 200)
        self.assertLoggedOut()

    def test_timeout(self):
        """A session expires after SESSION_LIFETIME_SECONDS (set to 2 in setUpApp)."""
        self.login()
        time.sleep(3)
        self.assertLoggedOut()

    def test_ratelimit(self):
        """After many failed attempts even the correct password is rejected."""
        for i in range(20):
            self.client.post(
                path=url_for('session.login'),
                data={'loginname': self.get_user().loginname, 'password': 'wrongpassword_%i' % i},
                follow_redirects=True,
            )
        r = self.login_as('user')
        dump('login_ratelimit', r)
        self.assertEqual(r.status_code, 200)
        self.assertIsNone(request.user)

    # NOTE(review): the method header was missing from the scraped source; the
    # name is inferred from the 'deviceauth*' dump labels used below.
    def test_deviceauth(self):
        """Confirm an OAuth2 device-login initiation from a logged-in session."""
        oauth2_client = OAuth2Client(
            service=Service(name='test', limit_access=False),
            client_id='test',
            client_secret='testsecret',
            redirect_uris=['http://localhost:5009/callback', 'http://localhost:5009/callback2'],
        )
        initiation = OAuth2DeviceLoginInitiation(client=oauth2_client)
        db.session.add(initiation)
        db.session.commit()
        code = initiation.code
        self.login()
        r = self.client.get(path=url_for('session.deviceauth'), follow_redirects=True)
        dump('deviceauth', r)
        self.assertEqual(r.status_code, 200)
        # The parameter name contains a hyphen, hence the dict unpacking.
        r = self.client.get(path=url_for('session.deviceauth', **{'initiation-code': code}), follow_redirects=True)
        dump('deviceauth_check', r)
        self.assertEqual(r.status_code, 200)
        self.assertIn(b'test', r.data)
        r = self.client.post(path=url_for('session.deviceauth_submit'), data={'initiation-code': code}, follow_redirects=True)
        dump('deviceauth_submit', r)
        self.assertEqual(r.status_code, 200)
        # Re-load the initiation to see the confirmation created by the request.
        initiation = OAuth2DeviceLoginInitiation.query.filter_by(code=code).one()
        self.assertEqual(len(initiation.confirmations), 1)
        self.assertEqual(initiation.confirmations[0].session.user.loginname, 'testuser')
        self.assertIn(initiation.confirmations[0].code.encode(), r.data)
        r = self.client.get(path=url_for('session.deviceauth_finish'), follow_redirects=True)
        self.assertEqual(r.status_code, 200)
        self.assertEqual(DeviceLoginConfirmation.query.all(), [])
class TestMfaViews(UffdTestCase):
    """View tests for the second (MFA) step of the login flow."""

    def add_recovery_codes(self, count=10):
        """Add `count` recovery-code methods to the test user and commit."""
        user = self.get_user()
        for _ in range(count):
            db.session.add(RecoveryCodeMethod(user=user))
        db.session.commit()

    def add_totp(self):
        """Add a TOTP method named 'My phone' to the test user and commit."""
        db.session.add(TOTPMethod(user=self.get_user(), name='My phone'))
        db.session.commit()

    def _add_totp_key(self):
        """Add a TOTP method named 'testname' to the test user; return its raw key."""
        method = TOTPMethod(user=self.get_user(), name='testname')
        raw_key = method.raw_key
        db.session.add(method)
        db.session.commit()
        return raw_key

    @staticmethod
    def _current_totp_code(raw_key):
        """Return the code for `raw_key` at counter int(time.time() / 30), i.e. the current 30-second step."""
        return _hotp(int(time.time()/30), raw_key)

    @staticmethod
    def _wrong_code(code):
        """Return `code` with its first digit incremented (9 wraps to 0) so it no longer matches."""
        return str(int(code[0])+1)[-1] + code[1:]

    def test_auth_integration(self):
        """A password login for a user with MFA methods lands on the MFA page, unauthenticated."""
        self.add_recovery_codes()
        self.add_totp()
        self.assertIsNone(request.user)
        r = self.login_as('user')
        dump('mfa_auth_redirected', r)
        self.assertEqual(r.status_code, 200)
        self.assertIn(b'/mfa/auth', r.data)
        self.assertIsNone(request.user)
        r = self.client.get(path=url_for('session.mfa_auth'), follow_redirects=False)
        dump('mfa_auth', r)
        self.assertEqual(r.status_code, 200)
        self.assertIsNone(request.user)

    def test_auth_disabled(self):
        """Without MFA methods the MFA page redirects straight to `ref`."""
        self.assertIsNone(request.user)
        self.login_as('user')
        r = self.client.get(path=url_for('session.mfa_auth', ref='/redirecttarget'), follow_redirects=False)
        self.assertEqual(r.status_code, 302)
        self.assertTrue(r.location.endswith('/redirecttarget'))
        self.assertIsNotNone(request.user)

    def test_auth_recovery_only(self):
        """Recovery codes alone do not trigger the MFA step."""
        self.add_recovery_codes()
        self.assertIsNone(request.user)
        self.login_as('user')
        r = self.client.get(path=url_for('session.mfa_auth', ref='/redirecttarget'), follow_redirects=False)
        self.assertEqual(r.status_code, 302)
        self.assertTrue(r.location.endswith('/redirecttarget'))
        self.assertIsNotNone(request.user)

    def test_auth_recovery_code(self):
        """A valid recovery code completes MFA and is consumed."""
        self.add_recovery_codes()
        self.add_totp()
        method = RecoveryCodeMethod(user=self.get_user())
        db.session.add(method)
        db.session.commit()
        method_id = method.id
        self.login_as('user')
        r = self.client.get(path=url_for('session.mfa_auth'), follow_redirects=False)
        dump('mfa_auth_recovery_code', r)
        self.assertEqual(r.status_code, 200)
        self.assertIsNone(request.user)
        r = self.client.post(path=url_for('session.mfa_auth_finish', ref='/redirecttarget'), data={'code': method.code_value})
        self.assertEqual(r.status_code, 302)
        self.assertTrue(r.location.endswith('/redirecttarget'))
        self.assertIsNotNone(request.user)
        # The used recovery code must have been deleted.
        self.assertEqual(len(RecoveryCodeMethod.query.filter_by(id=method_id).all()), 0)

    def test_auth_totp_code(self):
        """A current TOTP code completes MFA."""
        self.add_recovery_codes()
        self.add_totp()
        raw_key = self._add_totp_key()
        self.login_as('user')
        r = self.client.get(path=url_for('session.mfa_auth'), follow_redirects=False)
        dump('mfa_auth_totp_code', r)
        self.assertEqual(r.status_code, 200)
        self.assertIsNone(request.user)
        code = self._current_totp_code(raw_key)
        r = self.client.post(path=url_for('session.mfa_auth_finish'), data={'code': code}, follow_redirects=True)
        dump('mfa_auth_totp_code_submit', r)
        self.assertEqual(r.status_code, 200)
        self.assertIsNotNone(request.user)

    def test_auth_totp_code_reuse(self):
        """A TOTP code that was already used once is rejected the second time."""
        self.add_recovery_codes()
        self.add_totp()
        raw_key = self._add_totp_key()
        self.login_as('user')
        r = self.client.get(path=url_for('session.mfa_auth'), follow_redirects=False)
        self.assertEqual(r.status_code, 200)
        self.assertIsNone(request.user)
        code = self._current_totp_code(raw_key)
        r = self.client.post(path=url_for('session.mfa_auth_finish'), data={'code': code}, follow_redirects=True)
        self.assertEqual(r.status_code, 200)
        self.assertIsNotNone(request.user)
        self.login_as('user')
        r = self.client.get(path=url_for('session.mfa_auth'), follow_redirects=False)
        self.assertEqual(r.status_code, 200)
        self.assertIsNone(request.user)
        r = self.client.post(path=url_for('session.mfa_auth_finish'), data={'code': code}, follow_redirects=True)
        self.assertEqual(r.status_code, 200)
        self.assertIsNone(request.user)

    def test_auth_empty_code(self):
        self.add_recovery_codes()
        self.add_totp()
        self.login_as('user')
        r = self.client.get(path=url_for('session.mfa_auth'), follow_redirects=False)
        self.assertEqual(r.status_code, 200)
        self.assertIsNone(request.user)
        r = self.client.post(path=url_for('session.mfa_auth_finish'), data={'code': ''}, follow_redirects=True)
        dump('mfa_auth_empty_code', r)
        self.assertEqual(r.status_code, 200)
        self.assertIsNone(request.user)

    def test_auth_invalid_code(self):
        self.add_recovery_codes()
        self.add_totp()
        raw_key = self._add_totp_key()
        self.login_as('user')
        r = self.client.get(path=url_for('session.mfa_auth'), follow_redirects=False)
        self.assertEqual(r.status_code, 200)
        self.assertIsNone(request.user)
        code = self._wrong_code(self._current_totp_code(raw_key))
        r = self.client.post(path=url_for('session.mfa_auth_finish'), data={'code': code}, follow_redirects=True)
        dump('mfa_auth_invalid_code', r)
        self.assertEqual(r.status_code, 200)
        self.assertIsNone(request.user)

    def test_auth_ratelimit(self):
        """After many wrong codes even a valid TOTP code is rejected."""
        self.add_recovery_codes()
        self.add_totp()
        raw_key = self._add_totp_key()
        self.login_as('user')
        self.assertIsNone(request.user)
        code = self._current_totp_code(raw_key)
        inv_code = self._wrong_code(code)
        for _ in range(20):
            r = self.client.post(path=url_for('session.mfa_auth_finish'), data={'code': inv_code}, follow_redirects=True)
            self.assertEqual(r.status_code, 200)
            self.assertIsNone(request.user)
        r = self.client.post(path=url_for('session.mfa_auth_finish'), data={'code': code}, follow_redirects=True)
        dump('mfa_auth_ratelimit', r)
        self.assertEqual(r.status_code, 200)
        self.assertIsNone(request.user)

    # TODO: webauthn auth tests