From 1fab0a6aee4e483511672d61d32a154a614e19ea Mon Sep 17 00:00:00 2001
From: Julian Rother <julian@jrother.eu>
Date: Sat, 23 Oct 2021 16:56:13 +0200
Subject: [PATCH] Refactored api wrapper

---
 server.py | 18 ++++++++++++++----
 1 file changed, 14 insertions(+), 4 deletions(-)

diff --git a/server.py b/server.py
index 06681dd..5fccb24 100644
--- a/server.py
+++ b/server.py
@@ -1,6 +1,7 @@
 import sys
 import json
 import socketserver
+
 import requests
 from cachecontrol import CacheControl
 from cachecontrol.heuristics import ExpiresAfter
@@ -39,6 +40,15 @@ class UffdAPI:
 		assert(resp.ok)
 		return resp.json()
 
+	def get_users(self, id=None, loginname=None, group=None):
+		return self.get('getusers', id=id, loginname=loginname, group=group)
+
+	def get_groups(self, id=None, name=None, member=None):
+		return self.get('getgroups', id=id, name=name, member=member)
+
+	def check_password(self, loginname, password):
+		return self.api.post('checkpassword', loginname=loginname, password=password)
+
 def normalize_user_loginname(loginname):
 	# The equality matching rule for uid is caseIgnoreMatch. It prepares
 	# attribute and assertion value according to LDAP stringprep with
@@ -88,7 +98,7 @@ class RequestHandler(LDAPRequestHandler):
 			return True
 		if not dn.is_direct_child_of(DN('ou=users') + self.dn_base) or len(dn[0]) != 1 or dn[0][0].attribute != 'uid':
 			raise LDAPInvalidCredentials()
-		if self.api.post('checkpassword', loginname=dn[0][0].value, password=password):
+		if self.api.check_password(loginname=dn[0][0].value, password=password):
 			return True
 		raise LDAPInvalidCredentials()
 
@@ -97,7 +107,7 @@ class RequestHandler(LDAPRequestHandler):
 	def do_bind_sasl_plain(self, identity, password, authzid=None):
 		if authzid is not None and identity != authzid:
 			raise LDAPInvalidCredentials()
-		user = self.api.post('checkpassword', loginname=identity, password=password)
+		user = self.api.check_password(loginname=identity, password=password)
 		if user is None:
 			raise LDAPInvalidCredentials()
 		return user
@@ -165,7 +175,7 @@ class RequestHandler(LDAPRequestHandler):
 				if value.is_direct_child_of(DN(self.dn_base, ou='groups')) and value.object_attribute == 'cn':
 					request_params = {'group': normalize_group_name(value.object_value)}
 					break
-		for user in self.api.get('getusers', **request_params):
+		for user in self.api.get_users(**request_params):
 			yield template.create_object(user['loginname'],
 				cn=[user['displayname']],
 				displayname=[user['displayname']],
@@ -199,7 +209,7 @@ class RequestHandler(LDAPRequestHandler):
 				if value.is_direct_child_of(DN(self.dn_base, ou='users')) and value.object_attribute == 'uid':
 					request_params = {'member': normalize_user_loginname(value.object_value)}
 					break
-		for group in self.api.get('getgroups', **request_params):
+		for group in self.api.get_groups(**request_params):
 			yield template.create_object(group['name'],
 				cn=[group['name']],
 				gidNumber=[group['id']],
-- 
GitLab