Skip to content
Snippets Groups Projects
Select Git revision
  • 697165fe7b7b349dd108be9ac31e9eb0513097fc
  • master default protected
  • igor-compat
3 results

auth.php

Blame
  • dbutils.py 3.52 KiB
    from collections.abc import MutableSet
    
    from .model import add_to_session
    
    class DBRelationshipSet(MutableSet):
    	"""Live set view of LDAP objects referenced by DN from a DB-side relationship.

    	The DB object ``dbobj`` holds, in its list attribute ``relattr``, mapping rows
    	(instances of ``mapcls``) that each carry a ``dn``. Iteration resolves every DN via
    	``ldapcls.query.get`` and silently skips DNs that do not resolve to an object.
    	"""

    	def __init__(self, dbobj, relattr, ldapcls, mapcls):
    		self.__dbobj = dbobj
    		self.__relattr = relattr
    		self.__ldapcls = ldapcls
    		self.__mapcls = mapcls

    	def __rel(self):
    		"""Return the underlying list of mapping rows on the DB object."""
    		return getattr(self.__dbobj, self.__relattr)

    	def __get_dns(self):
    		return [mapobj.dn for mapobj in self.__rel()]

    	def __repr__(self):
    		return repr(set(self))

    	def __contains__(self, value):
    		# isinstance() is already False for None, so no separate None check is needed.
    		if not isinstance(value, self.__ldapcls):
    			return False
    		return value.ldap_object.dn in self.__get_dns()

    	def __iter__(self):
    		# All DNs are resolved eagerly when iteration starts; unresolvable ones are dropped.
    		objs = [self.__ldapcls.query.get(dn) for dn in self.__get_dns()]
    		return (obj for obj in objs if obj is not None)

    	def __len__(self):
    		return len(set(self))

    	def add(self, value):
    		"""Link LDAP object ``value`` by appending a ``mapcls(dn=...)`` row (no duplicates).

    		If ``value`` has no LDAP session yet, it is first passed to ``add_to_session``
    		together with the mapper's LDAP session.

    		:raises TypeError: if ``value`` is not an instance of ``ldapcls``
    		"""
    		if not isinstance(value, self.__ldapcls):
    			raise TypeError()
    		if value.ldap_object.session is None:
    			add_to_session(value, self.__ldapcls.ldap_mapper.session.ldap_session)
    		if value.ldap_object.dn not in self.__get_dns():
    			self.__rel().append(self.__mapcls(dn=value.ldap_object.dn))

    	def discard(self, value):
    		"""Remove every mapping row whose DN equals ``value``'s DN.

    		:raises TypeError: if ``value`` is not an instance of ``ldapcls``
    		"""
    		if not isinstance(value, self.__ldapcls):
    			raise TypeError()
    		rel = self.__rel()
    		# Iterate over a copy because rows are removed from ``rel`` while scanning.
    		for mapobj in list(rel):
    			if mapobj.dn == value.ldap_object.dn:
    				rel.remove(mapobj)

    	def clear(self):
    		"""Remove every mapping row, including rows whose DN no longer resolves.

    		The inherited MutableSet.clear() pops through __iter__, which re-resolves all
    		remaining DNs for every removed element (quadratic lookups) and never yields
    		rows with dangling DNs, so those stale rows would survive a clear.
    		"""
    		rel = self.__rel()
    		for mapobj in list(rel):
    			rel.remove(mapobj)
    
    class DBRelationship:
    	"""Descriptor exposing a DB-stored DN relationship as a DBRelationshipSet.

    	:param relattr: name of the list attribute on the DB object holding mapping rows
    	:param ldapcls: LDAP class whose objects the stored DNs refer to
    	:param mapcls: mapping-row class, constructed as ``mapcls(dn=...)``
    	:param backref: optional attribute name installed on ``ldapcls`` as a reverse
    		DBBackreference when this descriptor is assigned in a class body
    	:param backattr: attribute of a mapping row that yields the owning DB object;
    		required when ``backref`` is given
    	"""

    	def __init__(self, relattr, ldapcls, mapcls, backref=None, backattr=None):
    		self.relattr = relattr
    		self.ldapcls = ldapcls
    		self.mapcls = mapcls
    		self.backref = backref
    		self.backattr = backattr

    	def __set_name__(self, cls, name):
    		# Now that the owning DB class is known, install the reverse accessor on the LDAP class.
    		if self.backref:
    			# Explicit check instead of ``assert`` so it still fires under ``python -O``.
    			if not self.backattr:
    				raise ValueError('backattr is required when backref is set')
    			setattr(self.ldapcls, self.backref, DBBackreference(cls, self.relattr, self.mapcls, self.backattr))

    	def __get__(self, obj, objtype=None):
    		if obj is None:
    			return self
    		return DBRelationshipSet(obj, self.relattr, self.ldapcls, self.mapcls)

    	def __set__(self, obj, values):
    		"""Replace the relationship's contents with ``values``.

    		``values`` is materialized before anything is cleared. It may be a live view of
    		this very relationship: ``obj.rel = obj.rel``, or the implicit re-assignment
    		Python performs after ``obj.rel |= {...}`` / ``obj.rel -= {...}``. Without the
    		copy, clearing would also empty ``values`` and wipe the relationship.

    		:raises TypeError: if any value is not an instance of ``ldapcls``; checked before
    			any change, so the relationship is left untouched
    		"""
    		values = list(values)
    		for value in values:
    			if not isinstance(value, self.ldapcls):
    				raise TypeError()
    		tmp = self.__get__(obj)
    		tmp.clear()
    		for value in values:
    			tmp.add(value)
    
    class DBBackreferenceSet(MutableSet):
    	"""Set view of the DB objects whose relationship references one LDAP object.

    	Membership is computed by querying ``mapcls`` for rows whose ``dn`` equals the
    	LDAP object's DN and reading ``backattr`` from each row. ``add``/``discard``
    	edit the ``relattr`` list on the given DB object (an instance of ``dbcls``).
    	"""

    	def __init__(self, ldapobj, dbcls, relattr, mapcls, backattr):
    		self.__ldapobj = ldapobj
    		self.__dbcls = dbcls
    		self.__relattr = relattr
    		self.__mapcls = mapcls
    		self.__backattr = backattr

    	@property
    	def __dn(self):
    		return self.__ldapobj.ldap_object.dn

    	def __get(self):
    		# Runs one query per call and returns a snapshot set, not a live collection.
    		return {getattr(mapobj, self.__backattr) for mapobj in self.__mapcls.query.filter_by(dn=self.__dn)}

    	def __repr__(self):
    		return repr(self.__get())

    	def __contains__(self, value):
    		return value in self.__get()

    	def __iter__(self):
    		return iter(self.__get())

    	def __len__(self):
    		return len(self.__get())

    	def add(self, value):
    		"""Link DB object ``value`` to this LDAP object by appending a mapping row.

    		:raises TypeError: if ``value`` is not an instance of ``dbcls``
    		"""
    		# TODO: add value to db session if necessary
    		# NOTE(review): this precondition is an ``assert`` and is stripped under ``-O``;
    		# the LDAP object is expected to be attached to an LDAP session already.
    		assert self.__ldapobj.ldap_object.session is not None
    		if not isinstance(value, self.__dbcls):
    			raise TypeError()
    		rel = getattr(value, self.__relattr)
    		if self.__dn not in {mapobj.dn for mapobj in rel}:
    			rel.append(self.__mapcls(dn=self.__dn))

    	def discard(self, value):
    		"""Remove the mapping rows linking DB object ``value`` to this LDAP object.

    		:raises TypeError: if ``value`` is not an instance of ``dbcls``
    		"""
    		if not isinstance(value, self.__dbcls):
    			raise TypeError()
    		rel = getattr(value, self.__relattr)
    		# Iterate over a copy because rows are removed from ``rel`` while scanning.
    		for mapobj in list(rel):
    			if mapobj.dn == self.__dn:
    				rel.remove(mapobj)

    	def clear(self):
    		"""Unlink this LDAP object from every DB object currently referencing it.

    		The inherited MutableSet.clear() pops one element at a time and re-runs the
    		mapping query after each discard. It only terminates if that query reflects
    		the removals just made in the session (presumably via autoflush). Otherwise
    		the same element is returned forever. One snapshot avoids that dependency and
    		issues a single query.
    		"""
    		for value in self.__get():
    			self.discard(value)
    
    class DBBackreference:
    	"""Descriptor on an LDAP class exposing the DB objects that reference it.

    	Installed by DBRelationship.__set_name__ when a ``backref`` is configured.

    	:param dbcls: DB class owning the forward relationship
    	:param relattr: name of the list attribute on ``dbcls`` instances holding mapping rows
    	:param mapcls: mapping-row class, queried by ``dn``
    	:param backattr: attribute of a mapping row that yields the owning DB object
    	"""

    	def __init__(self, dbcls, relattr, mapcls, backattr):
    		self.dbcls = dbcls
    		self.relattr = relattr
    		self.mapcls = mapcls
    		self.backattr = backattr

    	def __get__(self, obj, objtype=None):
    		if obj is None:
    			return self
    		return DBBackreferenceSet(obj, self.dbcls, self.relattr, self.mapcls, self.backattr)

    	def __set__(self, obj, values):
    		"""Replace the set of DB objects referencing ``obj`` with ``values``.

    		``values`` is materialized before anything is cleared. It may be a view of this
    		same backreference, e.g. from the implicit re-assignment after
    		``obj.attr |= {...}``. Without the copy, clearing would also empty ``values``.

    		:raises TypeError: if any value is not an instance of ``dbcls``; checked before
    			any change, so nothing is modified
    		"""
    		values = list(values)
    		for value in values:
    			if not isinstance(value, self.dbcls):
    				raise TypeError()
    		tmp = self.__get__(obj)
    		tmp.clear()
    		for value in values:
    			tmp.add(value)