import unittest
import datetime
import time

from flask import url_for, session, request

# These imports are required, because otherwise we get circular imports?!
from uffd import user

from uffd.user.models import User
from uffd.role.models import Role, RoleGroup
from uffd.mfa.models import MFAMethod, MFAType, RecoveryCodeMethod, TOTPMethod, WebauthnMethod, _hotp
from uffd import create_app, db

from utils import dump, UffdTestCase, db_flush

class TestMfaPrimitives(unittest.TestCase):
	def test_hotp(self):
		self.assertEqual(_hotp(5555555, b'\xae\xa3T\x05\x89\xd6\xb76\xf61r\x92\xcc\xb5WZ\xe6)\x05q'), '458290')
		self.assertEqual(_hotp(5555555, b'\xae\xa3T\x05\x89\xd6\xb76\xf61r\x92\xcc\xb5WZ\xe6)\x05q', digits=8), '20458290')
		for digits in range(1, 10):
			self.assertEqual(len(_hotp(1, b'abcd', digits=digits)), digits)
		self.assertEqual(_hotp(1234, b''), '161024')
		self.assertEqual(_hotp(0, b'\x04\x8fM\xcc\x7f\x82\x9c$a\x1b\xb3'), '279354')
		self.assertEqual(_hotp(2**64-1, b'abcde'), '899292')

def get_fido2_test_cred(self):
	try:
		from fido2.ctap2 import AttestedCredentialData
	except ImportError:
		self.skipTest('fido2 could not be imported')
	# Example public key from webauthn spec 6.5.1.1
	return AttestedCredentialData(bytes.fromhex('00000000000000000000000000000000'+'0040'+'053cbcc9d37a61d3bac87cdcc77ee326256def08ab15775d3a720332e4101d14fae95aeee3bc9698781812e143c0597dc6e180595683d501891e9dd030454c0a'+'A501020326200121582065eda5a12577c2bae829437fe338701a10aaa375e1bb5b5de108de439c08551d2258201e52ed75701163f7f9e40ddf9f341b3dc9ba860af7e0ca7ca7e9eecd0084d19c'))

class TestMfaMethodModels(UffdTestCase):
	def test_common_attributes(self):
		method = TOTPMethod(user=self.get_user(), name='testname')
		self.assertTrue(method.created <= datetime.datetime.now())
		self.assertEqual(method.name, 'testname')
		self.assertEqual(method.user.loginname, 'testuser')
		method.user = self.get_admin()
		self.assertEqual(method.user.loginname, 'testadmin')

	def test_recovery_code_method(self):
		method = RecoveryCodeMethod(user=self.get_user())
		db.session.add(method)
		db.session.commit()
		db.session = db.create_scoped_session() # Ensure the next query does not return the cached method object
		_method = RecoveryCodeMethod.query.get(method.id)
		self.assertFalse(hasattr(_method, 'code'))
		self.assertFalse(_method.verify(''))
		self.assertFalse(_method.verify('A'*8))
		self.assertTrue(_method.verify(method.code))

	def test_totp_method_attributes(self):
		method = TOTPMethod(user=self.get_user(), name='testname')
		raw_key = method.raw_key
		issuer = method.issuer
		accountname = method.accountname
		key_uri = method.key_uri
		self.assertEqual(method.name, 'testname')
		# Restore method with key parameter
		_method = TOTPMethod(user=self.get_user(), key=method.key, name='testname')
		self.assertEqual(_method.name, 'testname')
		self.assertEqual(_method.raw_key, raw_key)
		self.assertEqual(_method.issuer, issuer)
		self.assertEqual(_method.accountname, accountname)
		self.assertEqual(_method.key_uri, key_uri)
		db.session.add(method)
		db.session.commit()
		db.session = db.create_scoped_session() # Ensure the next query does not return the cached method object
		# Restore method from db
		_method = TOTPMethod.query.get(method.id)
		self.assertEqual(_method.name, 'testname')
		self.assertEqual(_method.raw_key, raw_key)
		self.assertEqual(_method.issuer, issuer)
		self.assertEqual(_method.accountname, accountname)
		self.assertEqual(_method.key_uri, key_uri)

	def test_totp_method_verify(self):
		method = TOTPMethod(user=self.get_user())
		counter = int(time.time()/30)
		self.assertFalse(method.verify(''))
		self.assertFalse(method.verify(_hotp(counter-2, method.raw_key)))
		self.assertTrue(method.verify(_hotp(counter, method.raw_key)))
		self.assertFalse(method.verify(_hotp(counter+2, method.raw_key)))

	def test_webauthn_method(self):
		data = get_fido2_test_cred(self)
		method = WebauthnMethod(user=self.get_user(), cred=data, name='testname')
		self.assertEqual(method.name, 'testname')
		db.session.add(method)
		db.session.commit()
		db.session = db.create_scoped_session() # Ensure the next query does not return the cached method object
		_method = WebauthnMethod.query.get(method.id)
		self.assertEqual(_method.name, 'testname')
		self.assertEqual(bytes(method.cred), bytes(_method.cred))
		self.assertEqual(data.credential_id, _method.cred.credential_id)
		self.assertEqual(data.public_key, _method.cred.public_key)
		# We only test (de-)serialization here, as everything else is currently implemented in the views

class TestMfaViews(UffdTestCase):
	def setUp(self):
		super().setUp()
		db.session.add(RecoveryCodeMethod(user=self.get_admin()))
		db.session.add(TOTPMethod(user=self.get_admin(), name='Admin Phone'))
		# We don't want to skip all tests only because fido2 is not installed!
		#db.session.add(WebauthnMethod(user=get_testadmin(), cred=get_fido2_test_cred(self), name='Admin FIDO2 dongle'))
		db.session.commit()

	def add_recovery_codes(self, count=10):
		user = self.get_user()
		for _ in range(count):
			db.session.add(RecoveryCodeMethod(user=user))
		db.session.commit()

	def add_totp(self):
		db.session.add(TOTPMethod(user=self.get_user(), name='My phone'))
		db.session.commit()

	def add_webauthn(self):
		db.session.add(WebauthnMethod(user=self.get_user(), cred=get_fido2_test_cred(self), name='My FIDO2 dongle'))
		db.session.commit()

	def test_setup_disabled(self):
		self.login_as('user')
		r = self.client.get(path=url_for('mfa.setup'), follow_redirects=True)
		dump('mfa_setup_disabled', r)
		self.assertEqual(r.status_code, 200)

	def test_setup_recovery_codes(self):
		self.login_as('user')
		self.add_recovery_codes()
		r = self.client.get(path=url_for('mfa.setup'), follow_redirects=True)
		dump('mfa_setup_only_recovery_codes', r)
		self.assertEqual(r.status_code, 200)

	def test_setup_enabled(self):
		self.login_as('user')
		self.add_recovery_codes()
		self.add_totp()
		self.add_webauthn()
		r = self.client.get(path=url_for('mfa.setup'), follow_redirects=True)
		dump('mfa_setup_enabled', r)
		self.assertEqual(r.status_code, 200)

	def test_setup_few_recovery_codes(self):
		self.login_as('user')
		self.add_totp()
		self.add_recovery_codes(1)
		r = self.client.get(path=url_for('mfa.setup'), follow_redirects=True)
		dump('mfa_setup_few_recovery_codes', r)
		self.assertEqual(r.status_code, 200)

	def test_setup_no_recovery_codes(self):
		self.login_as('user')
		self.add_totp()
		r = self.client.get(path=url_for('mfa.setup'), follow_redirects=True)
		dump('mfa_setup_no_recovery_codes', r)
		self.assertEqual(r.status_code, 200)

	def test_disable(self):
		baserole = Role(name='baserole', is_default=True)
		db.session.add(baserole)
		baserole.groups[self.get_access_group()] = RoleGroup()
		db.session.commit()
		self.login_as('user')
		self.add_recovery_codes()
		self.add_totp()
		admin_methods = len(MFAMethod.query.filter_by(user=self.get_admin()).all())
		r = self.client.get(path=url_for('mfa.disable'), follow_redirects=True)
		dump('mfa_disable', r)
		self.assertEqual(r.status_code, 200)
		r = self.client.post(path=url_for('mfa.disable_confirm'), follow_redirects=True)
		dump('mfa_disable_submit', r)
		self.assertEqual(r.status_code, 200)
		self.assertEqual(len(MFAMethod.query.filter_by(user=request.user).all()), 0)
		self.assertEqual(len(MFAMethod.query.filter_by(user=self.get_admin()).all()), admin_methods)

	def test_disable_recovery_only(self):
		baserole = Role(name='baserole', is_default=True)
		db.session.add(baserole)
		baserole.groups[self.get_access_group()] = RoleGroup()
		db.session.commit()
		self.login_as('user')
		self.add_recovery_codes()
		admin_methods = len(MFAMethod.query.filter_by(user=self.get_admin()).all())
		self.assertNotEqual(len(MFAMethod.query.filter_by(user=request.user).all()), 0)
		r = self.client.get(path=url_for('mfa.disable'), follow_redirects=True)
		dump('mfa_disable_recovery_only', r)
		self.assertEqual(r.status_code, 200)
		r = self.client.post(path=url_for('mfa.disable_confirm'), follow_redirects=True)
		dump('mfa_disable_recovery_only_submit', r)
		self.assertEqual(r.status_code, 200)
		self.assertEqual(len(MFAMethod.query.filter_by(user=request.user).all()), 0)
		self.assertEqual(len(MFAMethod.query.filter_by(user=self.get_admin()).all()), admin_methods)

	def test_admin_disable(self):
		for method in MFAMethod.query.filter_by(user=self.get_admin()).all():
			if not isinstance(method, RecoveryCodeMethod):
				db.session.delete(method)
		db.session.commit()
		self.add_recovery_codes()
		self.add_totp()
		self.login_as('admin')
		self.assertIsNotNone(request.user)
		admin_methods = len(MFAMethod.query.filter_by(user=self.get_admin()).all())
		r = self.client.get(path=url_for('mfa.admin_disable', id=self.get_user().id), follow_redirects=True)
		dump('mfa_admin_disable', r)
		self.assertEqual(r.status_code, 200)
		self.assertEqual(len(MFAMethod.query.filter_by(user=self.get_user()).all()), 0)
		self.assertEqual(len(MFAMethod.query.filter_by(user=self.get_admin()).all()), admin_methods)

	def test_setup_recovery(self):
		self.login_as('user')
		self.assertEqual(len(RecoveryCodeMethod.query.filter_by(user=request.user).all()), 0)
		r = self.client.post(path=url_for('mfa.setup_recovery'), follow_redirects=True)
		dump('mfa_setup_recovery', r)
		self.assertEqual(r.status_code, 200)
		methods = RecoveryCodeMethod.query.filter_by(user=request.user).all()
		self.assertNotEqual(len(methods), 0)
		r = self.client.post(path=url_for('mfa.setup_recovery'), follow_redirects=True)
		dump('mfa_setup_recovery_reset', r)
		self.assertEqual(r.status_code, 200)
		self.assertEqual(len(RecoveryCodeMethod.query.filter_by(id=methods[0].id).all()), 0)
		self.assertNotEqual(len(methods), 0)

	def test_setup_totp(self):
		self.login_as('user')
		self.add_recovery_codes()
		r = self.client.get(path=url_for('mfa.setup_totp', name='My TOTP Authenticator'), follow_redirects=True)
		dump('mfa_setup_totp', r)
		self.assertEqual(r.status_code, 200)
		self.assertNotEqual(len(session.get('mfa_totp_key', '')), 0)

	def test_setup_totp_without_recovery(self):
		self.login_as('user')
		r = self.client.get(path=url_for('mfa.setup_totp', name='My TOTP Authenticator'), follow_redirects=True)
		dump('mfa_setup_totp_without_recovery', r)
		self.assertEqual(r.status_code, 200)

	def test_setup_totp_finish(self):
		baserole = Role(name='baserole', is_default=True)
		db.session.add(baserole)
		baserole.groups[self.get_access_group()] = RoleGroup()
		db.session.commit()
		self.login_as('user')
		self.add_recovery_codes()
		self.assertEqual(len(TOTPMethod.query.filter_by(user=request.user).all()), 0)
		r = self.client.get(path=url_for('mfa.setup_totp', name='My TOTP Authenticator'), follow_redirects=True)
		method = TOTPMethod(request.user, key=session.get('mfa_totp_key', ''))
		code = _hotp(int(time.time()/30), method.raw_key)
		r = self.client.post(path=url_for('mfa.setup_totp_finish', name='My TOTP Authenticator'), data={'code': code}, follow_redirects=True)
		dump('mfa_setup_totp_finish', r)
		self.assertEqual(r.status_code, 200)
		self.assertEqual(len(TOTPMethod.query.filter_by(user=request.user).all()), 1)

	def test_setup_totp_finish_without_recovery(self):
		self.login_as('user')
		self.assertEqual(len(TOTPMethod.query.filter_by(user=request.user).all()), 0)
		r = self.client.get(path=url_for('mfa.setup_totp', name='My TOTP Authenticator'), follow_redirects=True)
		method = TOTPMethod(request.user, key=session.get('mfa_totp_key', ''))
		code = _hotp(int(time.time()/30), method.raw_key)
		r = self.client.post(path=url_for('mfa.setup_totp_finish', name='My TOTP Authenticator'), data={'code': code}, follow_redirects=True)
		dump('mfa_setup_totp_finish_without_recovery', r)
		self.assertEqual(r.status_code, 200)
		self.assertEqual(len(TOTPMethod.query.filter_by(user=request.user).all()), 0)

	def test_setup_totp_finish_wrong_code(self):
		self.login_as('user')
		self.add_recovery_codes()
		self.assertEqual(len(TOTPMethod.query.filter_by(user=request.user).all()), 0)
		r = self.client.get(path=url_for('mfa.setup_totp', name='My TOTP Authenticator'), follow_redirects=True)
		method = TOTPMethod(request.user, key=session.get('mfa_totp_key', ''))
		code = _hotp(int(time.time()/30), method.raw_key)
		code = str(int(code[0])+1)[-1] + code[1:]
		r = self.client.post(path=url_for('mfa.setup_totp_finish', name='My TOTP Authenticator'), data={'code': code}, follow_redirects=True)
		dump('mfa_setup_totp_finish_wrong_code', r)
		self.assertEqual(r.status_code, 200)
		db_flush()
		self.assertEqual(len(TOTPMethod.query.filter_by(user=request.user).all()), 0)

	def test_setup_totp_finish_empty_code(self):
		self.login_as('user')
		self.add_recovery_codes()
		self.assertEqual(len(TOTPMethod.query.filter_by(user=request.user).all()), 0)
		r = self.client.get(path=url_for('mfa.setup_totp', name='My TOTP Authenticator'), follow_redirects=True)
		r = self.client.post(path=url_for('mfa.setup_totp_finish', name='My TOTP Authenticator'), data={'code': ''}, follow_redirects=True)
		dump('mfa_setup_totp_finish_empty_code', r)
		self.assertEqual(r.status_code, 200)
		db_flush()
		self.assertEqual(len(TOTPMethod.query.filter_by(user=request.user).all()), 0)

	def test_delete_totp(self):
		baserole = Role(name='baserole', is_default=True)
		db.session.add(baserole)
		baserole.groups[self.get_access_group()] = RoleGroup()
		db.session.commit()
		self.login_as('user')
		self.add_recovery_codes()
		self.add_totp()
		method = TOTPMethod(request.user, name='test')
		db.session.add(method)
		db.session.commit()
		self.assertEqual(len(TOTPMethod.query.filter_by(user=request.user).all()), 2)
		r = self.client.get(path=url_for('mfa.delete_totp', id=method.id), follow_redirects=True)
		dump('mfa_delete_totp', r)
		self.assertEqual(r.status_code, 200)
		self.assertEqual(len(TOTPMethod.query.filter_by(id=method.id).all()), 0)
		self.assertEqual(len(TOTPMethod.query.filter_by(user=request.user).all()), 1)

	# TODO: webauthn setup tests

	def test_auth_integration(self):
		self.add_recovery_codes()
		self.add_totp()
		db.session.commit()
		self.assertIsNone(request.user)
		r = self.login_as('user')
		dump('mfa_auth_redirected', r)
		self.assertEqual(r.status_code, 200)
		self.assertIn(b'/mfa/auth', r.data)
		self.assertIsNone(request.user)
		r = self.client.get(path=url_for('mfa.auth'), follow_redirects=False)
		dump('mfa_auth', r)
		self.assertEqual(r.status_code, 200)
		self.assertIsNone(request.user)

	def test_auth_disabled(self):
		self.assertIsNone(request.user)
		self.login_as('user')
		r = self.client.get(path=url_for('mfa.auth', ref='/redirecttarget'), follow_redirects=False)
		self.assertEqual(r.status_code, 302)
		self.assertTrue(r.location.endswith('/redirecttarget'))
		self.assertIsNotNone(request.user)

	def test_auth_recovery_only(self):
		self.add_recovery_codes()
		self.assertIsNone(request.user)
		self.login_as('user')
		r = self.client.get(path=url_for('mfa.auth', ref='/redirecttarget'), follow_redirects=False)
		self.assertEqual(r.status_code, 302)
		self.assertTrue(r.location.endswith('/redirecttarget'))
		self.assertIsNotNone(request.user)

	def test_auth_recovery_code(self):
		self.add_recovery_codes()
		self.add_totp()
		method = RecoveryCodeMethod(user=self.get_user())
		db.session.add(method)
		db.session.commit()
		method_id = method.id
		self.login_as('user')
		r = self.client.get(path=url_for('mfa.auth'), follow_redirects=False)
		dump('mfa_auth_recovery_code', r)
		self.assertEqual(r.status_code, 200)
		self.assertIsNone(request.user)
		r = self.client.post(path=url_for('mfa.auth_finish', ref='/redirecttarget'), data={'code': method.code})
		self.assertEqual(r.status_code, 302)
		self.assertTrue(r.location.endswith('/redirecttarget'))
		self.assertIsNotNone(request.user)
		self.assertEqual(len(RecoveryCodeMethod.query.filter_by(id=method_id).all()), 0)

	def test_auth_totp_code(self):
		self.add_recovery_codes()
		self.add_totp()
		method = TOTPMethod(user=self.get_user(), name='testname')
		raw_key = method.raw_key
		db.session.add(method)
		db.session.commit()
		self.login_as('user')
		r = self.client.get(path=url_for('mfa.auth'), follow_redirects=False)
		dump('mfa_auth_totp_code', r)
		self.assertEqual(r.status_code, 200)
		self.assertIsNone(request.user)
		code = _hotp(int(time.time()/30), raw_key)
		r = self.client.post(path=url_for('mfa.auth_finish'), data={'code': code}, follow_redirects=True)
		dump('mfa_auth_totp_code_submit', r)
		self.assertEqual(r.status_code, 200)
		self.assertIsNotNone(request.user)

	def test_auth_empty_code(self):
		self.add_recovery_codes()
		self.add_totp()
		self.login_as('user')
		r = self.client.get(path=url_for('mfa.auth'), follow_redirects=False)
		self.assertEqual(r.status_code, 200)
		self.assertIsNone(request.user)
		r = self.client.post(path=url_for('mfa.auth_finish'), data={'code': ''}, follow_redirects=True)
		dump('mfa_auth_empty_code', r)
		self.assertEqual(r.status_code, 200)
		self.assertIsNone(request.user)

	def test_auth_invalid_code(self):
		self.add_recovery_codes()
		self.add_totp()
		method = TOTPMethod(user=self.get_user(), name='testname')
		raw_key = method.raw_key
		db.session.add(method)
		db.session.commit()
		self.login_as('user')
		r = self.client.get(path=url_for('mfa.auth'), follow_redirects=False)
		self.assertEqual(r.status_code, 200)
		self.assertIsNone(request.user)
		code = _hotp(int(time.time()/30), raw_key)
		code = str(int(code[0])+1)[-1] + code[1:]
		r = self.client.post(path=url_for('mfa.auth_finish'), data={'code': code}, follow_redirects=True)
		dump('mfa_auth_invalid_code', r)
		self.assertEqual(r.status_code, 200)
		self.assertIsNone(request.user)

	def test_auth_ratelimit(self):
		self.add_recovery_codes()
		self.add_totp()
		method = TOTPMethod(user=self.get_user(), name='testname')
		raw_key = method.raw_key
		db.session.add(method)
		db.session.commit()
		self.login_as('user')
		self.assertIsNone(request.user)
		code = _hotp(int(time.time()/30), raw_key)
		inv_code = str(int(code[0])+1)[-1] + code[1:]
		for i in range(20):
			r = self.client.post(path=url_for('mfa.auth_finish'), data={'code': inv_code}, follow_redirects=True)
			self.assertEqual(r.status_code, 200)
			self.assertIsNone(request.user)
		r = self.client.post(path=url_for('mfa.auth_finish'), data={'code': code}, follow_redirects=True)
		dump('mfa_auth_ratelimit', r)
		self.assertEqual(r.status_code, 200)
		self.assertIsNone(request.user)

	# TODO: webauthn auth tests