From 380af003cd225be83a83eeb45c0196ef6f2efbbb Mon Sep 17 00:00:00 2001
From: Julian Rother <julianr@fsmpi.rwth-aachen.de>
Date: Fri, 29 Jan 2021 18:45:50 +0100
Subject: [PATCH] Moved oauth/services permission checking into
 User.has_permission

---
 tests/test_oauth2.py   | 31 +------------------------------
 tests/test_user.py     | 25 +++++++++++++++++++++++++
 uffd/oauth2/models.py  | 13 +------------
 uffd/services/views.py | 10 +++++-----
 uffd/user/models.py    | 15 +++++++++++++++
 5 files changed, 47 insertions(+), 47 deletions(-)

diff --git a/tests/test_oauth2.py b/tests/test_oauth2.py
index 1d9751b7..ab02d4a5 100644
--- a/tests/test_oauth2.py
+++ b/tests/test_oauth2.py
@@ -41,39 +41,10 @@ class TestOAuth2Client(UffdTestCase):
 	def test_access_allowed(self):
 		user = get_user() # has 'users' and 'uffd_access' group
 		admin = get_admin() # has 'users', 'uffd_access' and 'uffd_admin' group
-		client = OAuth2Client('test', '', [''], None)
-		self.assertTrue(client.access_allowed(user))
-		self.assertTrue(client.access_allowed(admin))
-		client = OAuth2Client('test', '', [''], 'users')
-		self.assertTrue(client.access_allowed(user))
-		self.assertTrue(client.access_allowed(admin))
-		client = OAuth2Client('test', '', [''], 'notagroup')
-		self.assertFalse(client.access_allowed(user))
-		self.assertFalse(client.access_allowed(admin))
-		client = OAuth2Client('test', '', [''], 'uffd_admin')
-		self.assertFalse(client.access_allowed(user))
-		self.assertTrue(client.access_allowed(admin))
-		client = OAuth2Client('test', '', [''], ['uffd_admin'])
-		self.assertFalse(client.access_allowed(user))
-		self.assertTrue(client.access_allowed(admin))
-		client = OAuth2Client('test', '', [''], ['uffd_admin', 'notagroup'])
-		self.assertFalse(client.access_allowed(user))
-		self.assertTrue(client.access_allowed(admin))
-		client = OAuth2Client('test', '', [''], ['notagroup', 'uffd_admin' ])
-		self.assertFalse(client.access_allowed(user))
-		self.assertTrue(client.access_allowed(admin))
-		client = OAuth2Client('test', '', [''], ['uffd_admin', 'users'])
-		self.assertTrue(client.access_allowed(user))
-		self.assertTrue(client.access_allowed(admin))
-		client = OAuth2Client('test', '', [''], ['uffd_admin', 'users'])
-		self.assertTrue(client.access_allowed(user))
-		self.assertTrue(client.access_allowed(admin))
-		client = OAuth2Client('test', '', [''], [['uffd_admin', 'users'], ['users', 'uffd_access']])
-		self.assertTrue(client.access_allowed(user))
-		self.assertTrue(client.access_allowed(admin))
 		client = OAuth2Client('test', '', [''], ['uffd_admin', ['users', 'notagroup']])
 		self.assertFalse(client.access_allowed(user))
 		self.assertTrue(client.access_allowed(admin))
+		# More required_group values are tested by TestUserModel.test_has_permission
 
 class TestViews(UffdTestCase):
 	def setUpApp(self):
diff --git a/tests/test_user.py b/tests/test_user.py
index 3b095fd6..cca33eff 100644
--- a/tests/test_user.py
+++ b/tests/test_user.py
@@ -26,6 +26,31 @@ def get_user_password():
 def get_admin():
 	return User.from_ldap_dn('uid=testadmin,ou=users,dc=example,dc=com')
 
+class TestUserModel(UffdTestCase):
+	def test_has_permission(self):
+		user = get_user() # has 'users' and 'uffd_access' group
+		admin = get_admin() # has 'users', 'uffd_access' and 'uffd_admin' group
+		self.assertTrue(user.has_permission(None))
+		self.assertTrue(admin.has_permission(None))
+		self.assertTrue(user.has_permission('users'))
+		self.assertTrue(admin.has_permission('users'))
+		self.assertFalse(user.has_permission('notagroup'))
+		self.assertFalse(admin.has_permission('notagroup'))
+		self.assertFalse(user.has_permission('uffd_admin'))
+		self.assertTrue(admin.has_permission('uffd_admin'))
+		self.assertFalse(user.has_permission(['uffd_admin']))
+		self.assertTrue(admin.has_permission(['uffd_admin']))
+		self.assertFalse(user.has_permission(['uffd_admin', 'notagroup']))
+		self.assertTrue(admin.has_permission(['uffd_admin', 'notagroup']))
+		self.assertFalse(user.has_permission(['notagroup', 'uffd_admin']))
+		self.assertTrue(admin.has_permission(['notagroup', 'uffd_admin']))
+		self.assertTrue(user.has_permission(['uffd_admin', 'users']))
+		self.assertTrue(admin.has_permission(['uffd_admin', 'users']))
+		self.assertTrue(user.has_permission([['uffd_admin', 'users'], ['users', 'uffd_access']]))
+		self.assertTrue(admin.has_permission([['uffd_admin', 'users'], ['users', 'uffd_access']]))
+		self.assertFalse(user.has_permission(['uffd_admin', ['users', 'notagroup']]))
+		self.assertTrue(admin.has_permission(['uffd_admin', ['users', 'notagroup']]))
+
 class TestUserViews(UffdTestCase):
 	def setUp(self):
 		super().setUp()
diff --git a/uffd/oauth2/models.py b/uffd/oauth2/models.py
index cdc3b9f4..261a8cf1 100644
--- a/uffd/oauth2/models.py
+++ b/uffd/oauth2/models.py
@@ -29,18 +29,7 @@ class OAuth2Client:
 		return self.redirect_uris[0]
 
 	def access_allowed(self, user):
-		if not self.required_group:
-			return True
-		user_groups = {group.name for group in user.get_groups()}
-		group_sets = self.required_group
-		if isinstance(group_sets, str):
-			group_sets = [group_sets]
-		for group_set in group_sets:
-			if isinstance(group_set, str):
-				group_set = [group_set]
-			if set(group_set) - user_groups == set():
-				return True
-		return False
+		return user.has_permission(self.required_group)
 
 class OAuth2Grant(db.Model):
 	__tablename__ = 'oauth2grant'
diff --git a/uffd/services/views.py b/uffd/services/views.py
index 48fcf795..4868eafc 100644
--- a/uffd/services/views.py
+++ b/uffd/services/views.py
@@ -26,11 +26,11 @@ def get_services(user=None):
 			'links': [],
 		}
 		if service_data.get('required_group'):
-			if not user or not user.is_in_group(service_data['required_group']):
+			if not user or not user.has_permission(service_data['required_group']):
 				service['has_access'] = False
 		for permission_data in service_data.get('permission_levels', []):
 			if permission_data.get('required_group'):
-				if not user or not user.is_in_group(permission_data['required_group']):
+				if not user or not user.has_permission(permission_data['required_group']):
 					continue
 			if not permission_data.get('name'):
 				continue
@@ -40,14 +40,14 @@ def get_services(user=None):
 			continue
 		for group_data in service_data.get('groups', []):
 			if group_data.get('required_group'):
-				if not user or not user.is_in_group(group_data['required_group']):
+				if not user or not user.has_permission(group_data['required_group']):
 					continue
 			if not group_data.get('name'):
 				continue
 			service['groups'].append(group_data)
 		for info_data in service_data.get('infos', []):
 			if info_data.get('required_group'):
-				if not user or not user.is_in_group(info_data['required_group']):
+				if not user or not user.has_permission(info_data['required_group']):
 					continue
 			if not info_data.get('title') or not info_data.get('html'):
 				continue
@@ -59,7 +59,7 @@ def get_services(user=None):
 			service['infos'].append(info)
 		for link_data in service_data.get('links', []):
 			if link_data.get('required_group'):
-				if not user or not user.is_in_group(link_data['required_group']):
+				if not user or not user.has_permission(link_data['required_group']):
 					continue
 			if not link_data.get('url') or not link_data.get('title'):
 				continue
diff --git a/uffd/user/models.py b/uffd/user/models.py
index 5158b323..3cf84583 100644
--- a/uffd/user/models.py
+++ b/uffd/user/models.py
@@ -91,6 +91,7 @@ class User():
 				groups.append(newgroup)
 		self._groups = groups
 		return groups
+
 	def replace_group_dns(self, values):
 		self._groups = None
 		self.groups_ldap = values
@@ -105,6 +106,20 @@ class User():
 				return True
 		return False
 
+	def has_permission(self, required_group=None):
+		if not required_group:
+			return True
+		group_names = {group.name for group in self.get_groups()}
+		group_sets = required_group
+		if isinstance(group_sets, str):
+			group_sets = [group_sets]
+		for group_set in group_sets:
+			if isinstance(group_set, str):
+				group_set = [group_set]
+			if set(group_set) - group_names == set():
+				return True
+		return False
+
 	def set_loginname(self, value):
 		if not ldap.loginname_is_safe(value):
 			return False
-- 
GitLab