import enum

from . import ldap
from .dn import DN

class FilterResult(enum.Enum):
	'''Outcome of evaluating an LDAP search filter

	TRUE, FALSE and UNDEFINED are the three values of LDAP filter evaluation.
	MAYBE_TRUE is an extension used by ObjectTemplate: the result depends on
	template values (WILDCARD_VALUE) that are not known yet.'''
	TRUE = 1
	FALSE = 2
	UNDEFINED = 3
	MAYBE_TRUE = 4

def match_to_filter_result(match_result):
	'''Convert a matching rule result into a :any:`FilterResult`

	Only the singletons ``True`` and ``False`` map to TRUE and FALSE; they are
	checked by identity, so values like ``1`` or ``None`` map to UNDEFINED.'''
	if match_result is not True and match_result is not False:
		return FilterResult.UNDEFINED
	return FilterResult.TRUE if match_result is True else FilterResult.FALSE

def any_3value(iterable):
	'''Extended three-valued logic equivalent of any builtin

	If any item is TRUE, return TRUE (remaining items are not consumed).
	Otherwise, if any item is MAYBE_TRUE, return MAYBE_TRUE. If neither TRUE
	nor MAYBE_TRUE are in items, but any item is UNDEFINED, return UNDEFINED.
	Otherwise (all items are FALSE, or there are no items), return FALSE.'''
	result = FilterResult.FALSE
	for item in iterable:
		if item == FilterResult.TRUE:
			return FilterResult.TRUE
		elif item == FilterResult.MAYBE_TRUE:
			result = FilterResult.MAYBE_TRUE
		elif item == FilterResult.UNDEFINED and result == FilterResult.FALSE:
			# MAYBE_TRUE outranks UNDEFINED, so only upgrade from FALSE
			result = FilterResult.UNDEFINED
	return result

def all_3value(iterable):
	'''Extended three-valued logic equivalent of all builtin

	If any item is FALSE, return FALSE (remaining items are not consumed).
	Otherwise, if any item is UNDEFINED, return UNDEFINED; otherwise, if any
	item is MAYBE_TRUE, return MAYBE_TRUE. If every item is TRUE (or there are
	no items), return TRUE.'''
	saw_undefined = False
	saw_maybe = False
	for item in iterable:
		if item == FilterResult.FALSE:
			return FilterResult.FALSE
		if item == FilterResult.UNDEFINED:
			saw_undefined = True
		elif item == FilterResult.MAYBE_TRUE:
			saw_maybe = True
	# UNDEFINED takes precedence over MAYBE_TRUE
	if saw_undefined:
		return FilterResult.UNDEFINED
	if saw_maybe:
		return FilterResult.MAYBE_TRUE
	return FilterResult.TRUE

class AttributeDict(dict):
	'''Special dictionary holding LDAP attribute values

	Attribute values can be set and accessed by their attribute type's numeric
	OID or short descriptive name. Attribute types must be defined within the
	schema to be used. Attribute values are always lists. Accessing an unset
	attribute behaves the same as accessing an empty attribute. List items must
	conform to the attribute's syntax.

	A stored value may also be a callable; it is invoked on every item access
	and its return value is used as the attribute's value list.'''
	def __init__(self, schema, **attributes):
		super().__init__()
		self.schema = schema
		for key, value in attributes.items():
			self[key] = value

	def __contains__(self, key):
		# Attribute types unknown to the schema are reported as "not contained"
		try:
			return super().__contains__(self.schema.get_attribute_type(key).name)
		except KeyError:
			return False

	def __setitem__(self, key, values):
		# Always store under the canonical name so every accepted spelling of
		# the attribute type refers to the same entry
		super().__setitem__(self.schema.get_attribute_type(key).name, values)

	def __getitem__(self, key):
		canonical_oid = self.schema.get_attribute_type(key).name
		if not super().__contains__(canonical_oid):
			# Store the empty list so the returned list is the stored one and
			# in-place mutation (e.g. append) persists
			super().__setitem__(canonical_oid, [])
		result = super().__getitem__(canonical_oid)
		if callable(result):
			return result()
		return result

	def setdefault(self, key, default=None):
		canonical_oid = self.schema.get_attribute_type(key).name
		return super().setdefault(canonical_oid, default)

	def get(self, key, default=None):
		'''Return the values of `key`, or `default` if there are none

		Like :meth:`dict.get`, this never raises KeyError: attribute types
		unknown to the schema also yield `default`.'''
		try:
			values = self[key]
		except KeyError:
			return default
		return values or default

	def get_with_subtypes(self, key):
		'''Return the concatenated values of `key` and the subtypes reported by
		``schema.get_attribute_type_and_subtypes``'''
		result = []
		for attribute_type in self.schema.get_attribute_type_and_subtypes(key):
			result += self[attribute_type.name]
		return result

	def match_present(self, key):
		'''Evaluate a present filter for `key`

		The filter is TRUE if the attribute or any of its subtypes has a value
		(RFC 4511), FALSE otherwise. Attribute types unknown to the schema
		yield UNDEFINED.'''
		try:
			attribute_type = self.schema.get_attribute_type(key)
		except KeyError:
			# get_attribute_type signals unknown types with KeyError, exactly
			# as handled by the other matchers below
			return FilterResult.UNDEFINED
		if attribute_type is None:
			return FilterResult.UNDEFINED
		if self.get_with_subtypes(key) != []:
			return FilterResult.TRUE
		return FilterResult.FALSE

	def match_equal(self, key, assertion_value):
		'''Evaluate an equality assertion against `key` and its subtypes

		Returns UNDEFINED if the attribute type is unknown, has no equality
		matching rule, or the assertion value cannot be decoded.'''
		try:
			attribute_type = self.schema.get_attribute_type(key)
		except KeyError:
			return FilterResult.UNDEFINED
		if attribute_type.equality is None:
			return FilterResult.UNDEFINED
		assertion_value = attribute_type.equality.syntax.decode(self.schema, assertion_value)
		if assertion_value is None:
			return FilterResult.UNDEFINED
		return any_3value(map(lambda attrval: match_to_filter_result(attribute_type.equality.match_equal(self.schema, attrval, assertion_value)), self.get_with_subtypes(key)))

	def match_substr(self, key, inital_substring, any_substrings, final_substring):
		'''Evaluate a substrings assertion against `key` and its subtypes

		Empty initial/final substrings are passed through undecoded. Returns
		UNDEFINED if the attribute type is unknown, lacks an equality or
		substring matching rule, or any substring cannot be decoded.'''
		try:
			attribute_type = self.schema.get_attribute_type(key)
		except KeyError:
			return FilterResult.UNDEFINED
		if attribute_type.equality is None or attribute_type.substr is None:
			return FilterResult.UNDEFINED
		if inital_substring:
			inital_substring = attribute_type.equality.syntax.decode(self.schema, inital_substring)
			if inital_substring is None:
				return FilterResult.UNDEFINED
		any_substrings = [attribute_type.equality.syntax.decode(self.schema, substring) for substring in any_substrings]
		if None in any_substrings:
			return FilterResult.UNDEFINED
		if final_substring:
			final_substring = attribute_type.equality.syntax.decode(self.schema, final_substring)
			if final_substring is None:
				return FilterResult.UNDEFINED
		return any_3value(map(lambda attrval: match_to_filter_result(attribute_type.substr.match_substr(self.schema, attrval, inital_substring, any_substrings, final_substring)), self.get_with_subtypes(key)))

	def match_approx(self, key, assertion_value):
		'''Evaluate an approximate-match assertion against `key` and its
		subtypes, using the equality matching rule's ``match_approx``'''
		try:
			attribute_type = self.schema.get_attribute_type(key)
		except KeyError:
			return FilterResult.UNDEFINED
		if attribute_type.equality is None:
			return FilterResult.UNDEFINED
		assertion_value = attribute_type.equality.syntax.decode(self.schema, assertion_value)
		if assertion_value is None:
			return FilterResult.UNDEFINED
		return any_3value(map(lambda attrval: match_to_filter_result(attribute_type.equality.match_approx(self.schema, attrval, assertion_value)), self.get_with_subtypes(key)))

	def match_greater_or_equal(self, key, assertion_value):
		'''Evaluate a greaterOrEqual assertion using the ordering matching rule

		Returns UNDEFINED if the attribute type is unknown, has no ordering
		matching rule, or the assertion value cannot be decoded.'''
		try:
			attribute_type = self.schema.get_attribute_type(key)
		except KeyError:
			return FilterResult.UNDEFINED
		if attribute_type.ordering is None:
			return FilterResult.UNDEFINED
		assertion_value = attribute_type.ordering.syntax.decode(self.schema, assertion_value)
		if assertion_value is None:
			return FilterResult.UNDEFINED
		return any_3value(map(lambda attrval: match_to_filter_result(attribute_type.ordering.match_greater_or_equal(self.schema, attrval, assertion_value)), self.get_with_subtypes(key)))

	def match_less(self, key, assertion_value):
		'''Evaluate a strict "less than" assertion using the ordering matching
		rule (helper for :meth:`match_less_or_equal`)'''
		try:
			attribute_type = self.schema.get_attribute_type(key)
		except KeyError:
			return FilterResult.UNDEFINED
		if attribute_type.ordering is None:
			return FilterResult.UNDEFINED
		assertion_value = attribute_type.ordering.syntax.decode(self.schema, assertion_value)
		if assertion_value is None:
			return FilterResult.UNDEFINED
		return any_3value(map(lambda attrval: match_to_filter_result(attribute_type.ordering.match_less(self.schema, attrval, assertion_value)), self.get_with_subtypes(key)))

	def match_less_or_equal(self, key, assertion_value):
		'''Evaluate a lessOrEqual assertion as "equal OR less"'''
		return any_3value((self.match_equal(key, assertion_value),
		                   self.match_less(key, assertion_value)))

	def match_filter(self, filter_obj):
		'''Evaluate an LDAP filter object against the stored attributes

		AND/OR use :func:`all_3value`/:func:`any_3value`. NOT swaps TRUE and
		FALSE and passes UNDEFINED and MAYBE_TRUE through unchanged. Filter
		types not handled here evaluate to UNDEFINED.'''
		if isinstance(filter_obj, ldap.FilterAnd):
			return all_3value(map(self.match_filter, filter_obj.filters))
		elif isinstance(filter_obj, ldap.FilterOr):
			return any_3value(map(self.match_filter, filter_obj.filters))
		elif isinstance(filter_obj, ldap.FilterNot):
			subresult = self.match_filter(filter_obj.filter)
			if subresult == FilterResult.TRUE:
				return FilterResult.FALSE
			elif subresult == FilterResult.FALSE:
				return FilterResult.TRUE
			else:
				return subresult
		elif isinstance(filter_obj, ldap.FilterPresent):
			return self.match_present(filter_obj.attribute)
		elif isinstance(filter_obj, ldap.FilterEqual):
			return self.match_equal(filter_obj.attribute, filter_obj.value)
		elif isinstance(filter_obj, ldap.FilterSubstrings):
			return self.match_substr(filter_obj.attribute, filter_obj.initial_substring,
			                         filter_obj.any_substrings, filter_obj.final_substring)
		elif isinstance(filter_obj, ldap.FilterApproxMatch):
			return self.match_approx(filter_obj.attribute, filter_obj.value)
		elif isinstance(filter_obj, ldap.FilterGreaterOrEqual):
			return self.match_greater_or_equal(filter_obj.attribute, filter_obj.value)
		elif isinstance(filter_obj, ldap.FilterLessOrEqual):
			return self.match_less_or_equal(filter_obj.attribute, filter_obj.value)
		else:
			return FilterResult.UNDEFINED

class Object(AttributeDict):
	'''LDAP entry: an :any:`AttributeDict` bound to a distinguished name

	`dn` is passed through ``DN(...)`` and stored as ``self.dn``.'''
	def __init__(self, schema, dn, **attributes):
		super().__init__(schema, **attributes)
		self.dn = DN(dn)

	def match_dn(self, basedn, scope):
		'''Return whether this entry lies within `scope` relative to `basedn`

		Unrecognized scopes never match.'''
		if scope == ldap.SearchScope.wholeSubtree:
			return self.dn.in_subtree_of(basedn)
		if scope == ldap.SearchScope.singleLevel:
			return self.dn.is_direct_child_of(basedn)
		if scope == ldap.SearchScope.baseObject:
			return self.dn == basedn
		return False

	def match_search(self, base_obj, scope, filter_obj):
		'''Return whether a search returns this entry

		`base_obj` is parsed with ``DN.from_str``. The filter must evaluate to
		exactly TRUE; UNDEFINED (and MAYBE_TRUE) count as no match.'''
		in_scope = self.match_dn(DN.from_str(base_obj), scope)
		if not in_scope:
			return in_scope
		return self.match_filter(filter_obj) == FilterResult.TRUE

	def get_search_result_entry(self, attributes=None, types_only=False):
		'''Build an ``ldap.SearchResultEntry`` for this entry

		`attributes` is the list of requested attribute selectors: ``'*'``
		selects the schema's user attribute types, ``'1.1'`` selects nothing,
		any other selector names one attribute type (unknown names are
		ignored). An empty or missing list behaves like ``['*']``. Attributes
		without values are omitted; with `types_only` the included attributes
		carry an empty value list.'''
		wanted = set()
		for selector in attributes or ['*']:
			if selector == '1.1':
				continue
			if selector == '*':
				wanted |= set(self.schema.user_attribute_types)
				continue
			try:
				wanted.add(self.schema.get_attribute_type(selector))
			except KeyError:
				pass
		entries = []
		for attribute_type in wanted:
			values = self[attribute_type.name]
			if values == []:
				continue
			if types_only:
				encoded = []
			else:
				encoded = [attribute_type.syntax.encode(self.schema, value) for value in values]
			entries.append(ldap.PartialAttribute(attribute_type.name, encoded))
		return ldap.SearchResultEntry(str(self.dn), entries)

class RootDSE(Object):
	'''The root DSE: an :any:`Object` whose DN is the empty ``DN()``

	It is only returned by a baseObject search with an empty base and a present
	filter on objectClass (compared case-insensitively), i.e. ``(objectClass=*)``.'''
	def __init__(self, schema, *args, **kwargs):
		super().__init__(schema, DN(), *args, **kwargs)

	def match_search(self, base_obj, scope, filter_obj):
		if base_obj or scope != ldap.SearchScope.baseObject:
			return False
		if not isinstance(filter_obj, ldap.FilterPresent):
			return False
		return filter_obj.attribute.lower() == 'objectclass'

class WildcardValue:
	'''Placeholder for attribute values left open in an :any:`ObjectTemplate`

	Use the ``WILDCARD_VALUE`` singleton instead of creating new instances.'''
	def __repr__(self):
		# Readable in debug output and template dumps (matches the name used
		# in ObjectTemplate's error messages)
		return 'WILDCARD_VALUE'

# Singleton marker looked up with ``in`` on attribute value lists
WILDCARD_VALUE = WildcardValue()

class ObjectTemplate(AttributeDict):
	'''Template describing a family of objects directly below `parent_dn`

	Objects created from the template are named by `rdn_attribute` below
	`parent_dn`. Attribute values may contain ``WILDCARD_VALUE``, meaning "not
	known in the template". Filter matching against such values yields
	``FilterResult.MAYBE_TRUE`` instead of a definite result.'''
	def __init__(self, schema, parent_dn, rdn_attribute, **attributes):
		super().__init__(schema, **attributes)
		self.parent_dn = parent_dn
		self.rdn_attribute = rdn_attribute

	def match_present(self, key):
		'''Evaluate a present filter for `key` and its subtypes

		Consistent with the other matchers of this class, values of subtypes
		are taken into account (RFC 4511). A wildcard yields MAYBE_TRUE.'''
		try:
			self.schema.get_attribute_type(key)
		except KeyError:
			return FilterResult.UNDEFINED
		values = self.get_with_subtypes(key)
		if values == []:
			return FilterResult.FALSE
		elif WILDCARD_VALUE in values:
			return FilterResult.MAYBE_TRUE
		else:
			return FilterResult.TRUE

	def match_equal(self, key, assertion_value):
		'''Like :meth:`AttributeDict.match_equal`, but a wildcard among the
		values yields MAYBE_TRUE'''
		try:
			attribute_type = self.schema.get_attribute_type(key)
		except KeyError:
			return FilterResult.UNDEFINED
		if attribute_type.equality is None:
			return FilterResult.UNDEFINED
		assertion_value = attribute_type.equality.syntax.decode(self.schema, assertion_value)
		if assertion_value is None:
			return FilterResult.UNDEFINED
		values = self.get_with_subtypes(key)
		if WILDCARD_VALUE in values:
			return FilterResult.MAYBE_TRUE
		return any_3value(map(lambda attrval: match_to_filter_result(attribute_type.equality.match_equal(self.schema, attrval, assertion_value)), values))

	def match_substr(self, key, inital_substring, any_substrings, final_substring):
		'''Like :meth:`AttributeDict.match_substr`, but a wildcard among the
		values yields MAYBE_TRUE'''
		try:
			attribute_type = self.schema.get_attribute_type(key)
		except KeyError:
			return FilterResult.UNDEFINED
		if attribute_type.equality is None or attribute_type.substr is None:
			return FilterResult.UNDEFINED
		if inital_substring:
			inital_substring = attribute_type.equality.syntax.decode(self.schema, inital_substring)
			if inital_substring is None:
				return FilterResult.UNDEFINED
		any_substrings = [attribute_type.equality.syntax.decode(self.schema, substring) for substring in any_substrings]
		if None in any_substrings:
			return FilterResult.UNDEFINED
		if final_substring:
			final_substring = attribute_type.equality.syntax.decode(self.schema, final_substring)
			if final_substring is None:
				return FilterResult.UNDEFINED
		values = self.get_with_subtypes(key)
		if WILDCARD_VALUE in values:
			return FilterResult.MAYBE_TRUE
		return any_3value(map(lambda attrval: match_to_filter_result(attribute_type.substr.match_substr(self.schema, attrval, inital_substring, any_substrings, final_substring)), values))

	def match_approx(self, key, assertion_value):
		'''Like :meth:`AttributeDict.match_approx`, but a wildcard among the
		values yields MAYBE_TRUE'''
		try:
			attribute_type = self.schema.get_attribute_type(key)
		except KeyError:
			return FilterResult.UNDEFINED
		if attribute_type.equality is None:
			return FilterResult.UNDEFINED
		assertion_value = attribute_type.equality.syntax.decode(self.schema, assertion_value)
		if assertion_value is None:
			return FilterResult.UNDEFINED
		values = self.get_with_subtypes(key)
		if WILDCARD_VALUE in values:
			return FilterResult.MAYBE_TRUE
		return any_3value(map(lambda attrval: match_to_filter_result(attribute_type.equality.match_approx(self.schema, attrval, assertion_value)), values))

	def match_greater_or_equal(self, key, assertion_value):
		'''Like :meth:`AttributeDict.match_greater_or_equal`, but a wildcard
		among the values yields MAYBE_TRUE'''
		try:
			attribute_type = self.schema.get_attribute_type(key)
		except KeyError:
			return FilterResult.UNDEFINED
		if attribute_type.ordering is None:
			return FilterResult.UNDEFINED
		assertion_value = attribute_type.ordering.syntax.decode(self.schema, assertion_value)
		if assertion_value is None:
			return FilterResult.UNDEFINED
		values = self.get_with_subtypes(key)
		if WILDCARD_VALUE in values:
			return FilterResult.MAYBE_TRUE
		return any_3value(map(lambda attrval: match_to_filter_result(attribute_type.ordering.match_greater_or_equal(self.schema, attrval, assertion_value)), values))

	def match_less(self, key, assertion_value):
		'''Like :meth:`AttributeDict.match_less`, but a wildcard among the
		values yields MAYBE_TRUE'''
		try:
			attribute_type = self.schema.get_attribute_type(key)
		except KeyError:
			return FilterResult.UNDEFINED
		if attribute_type.ordering is None:
			return FilterResult.UNDEFINED
		assertion_value = attribute_type.ordering.syntax.decode(self.schema, assertion_value)
		if assertion_value is None:
			return FilterResult.UNDEFINED
		values = self.get_with_subtypes(key)
		if WILDCARD_VALUE in values:
			return FilterResult.MAYBE_TRUE
		return any_3value(map(lambda attrval: match_to_filter_result(attribute_type.ordering.match_less(self.schema, attrval, assertion_value)), values))

	def __extract_dn_constraints(self, basedn, scope):
		'''Return ``(may_match, constraints)`` for a search base and scope

		`constraints` is an AttributeDict holding the RDN value implied by a
		base DN that names a (potential) object of this template.'''
		if scope == ldap.SearchScope.baseObject:
			if basedn[1:] != self.parent_dn or basedn.object_attribute != self.rdn_attribute:
				return False, AttributeDict(self.schema)
			return True, AttributeDict(self.schema, **{self.rdn_attribute: [basedn.object_value]})
		elif scope == ldap.SearchScope.singleLevel:
			return basedn == self.parent_dn, AttributeDict(self.schema)
		elif scope == ldap.SearchScope.wholeSubtree:
			if self.parent_dn.in_subtree_of(basedn):
				return True, AttributeDict(self.schema)
			# The base may still name one object of this template
			if basedn[1:] != self.parent_dn or basedn.object_attribute != self.rdn_attribute:
				return False, AttributeDict(self.schema)
			return True, AttributeDict(self.schema, **{self.rdn_attribute: [basedn.object_value]})
		else:
			return False, AttributeDict(self.schema)

	def match_dn(self, basedn, scope):
		'''Return whether objects from this template might match the provided parameters'''
		return self.__extract_dn_constraints(basedn, scope)[0]

	def extract_dn_constraints(self, basedn, scope):
		'''Return the attribute values implied by the search base and scope'''
		return self.__extract_dn_constraints(basedn, scope)[1]

	def extract_filter_constraints(self, filter_obj):
		'''Return the attribute values implied by `filter_obj`

		Only equality filters (directly or nested in AND filters) contribute;
		all other filters yield an empty AttributeDict.'''
		if isinstance(filter_obj, ldap.FilterEqual):
			try:
				attribute_type = self.schema.get_attribute_type(filter_obj.attribute)
			except KeyError:
				return AttributeDict(self.schema)
			if attribute_type.equality is None:
				return AttributeDict(self.schema)
			assertion_value = attribute_type.equality.syntax.decode(self.schema, filter_obj.value)
			if assertion_value is None:
				return AttributeDict(self.schema)
			return AttributeDict(self.schema, **{filter_obj.attribute: [assertion_value]})
		if isinstance(filter_obj, ldap.FilterAnd):
			result = AttributeDict(self.schema)
			for subfilter in filter_obj.filters:
				for name, values in self.extract_filter_constraints(subfilter).items():
					result[name] += values
			return result
		return AttributeDict(self.schema)

	def match_search(self, base_obj, scope, filter_obj):
		'''Return whether objects based on this template might match the search parameters'''
		return self.match_dn(DN.from_str(base_obj), scope) and self.match_filter(filter_obj) in (FilterResult.TRUE, FilterResult.MAYBE_TRUE)

	def extract_search_constraints(self, base_obj, scope, filter_obj):
		'''Return the attribute values implied by the filter, base and scope'''
		constraints = self.extract_filter_constraints(filter_obj)
		for key, values in self.extract_dn_constraints(DN.from_str(base_obj), scope).items():
			constraints[key] += values
		return constraints

	def create_object(self, rdn_value, **attributes):
		'''Create an :any:`Object` from this template

		`attributes` may only set attributes whose template values contain
		WILDCARD_VALUE (ValueError otherwise). All template attributes without
		a wildcard are copied to the new object.'''
		obj = Object(self.schema, DN(self.parent_dn, **{self.rdn_attribute: rdn_value}))
		for key, values in attributes.items():
			if WILDCARD_VALUE not in self[key]:
				raise ValueError(f'Cannot set attribute "{key}" that is not set to [WILDCARD_VALUE] in the template')
			obj[key] = values
		for attribute_type, values in self.items():
			if WILDCARD_VALUE not in values:
				# Copy the list: sharing it would let in-place changes on one
				# object (e.g. ``obj[key] += [...]``) leak into the template
				# and into every other object created from it
				obj[attribute_type] = list(values)
		return obj

class SubschemaSubentry(Object):
	'''Special :any:`Object` providing information on a Schema

	The entry publishes the schema's object classes, syntaxes, matching rules
	and attribute types as attribute values. It also offers factory helpers
	(``AttributeDict``, ``Object``, ``RootDSE`` and ``ObjectTemplate``) that bind
	`schema`. All of them except ``AttributeDict`` also set ``subschemaSubentry``
	to this entry's DN.'''
	def __init__(self, schema, dn, **attributes):
		super().__init__(schema, dn, **attributes)
		# Hand out the normalized DN object everywhere, the same value used
		# for this entry's own subschemaSubentry attribute
		subschema_dn = self.dn
		self['subschemaSubentry'] = [subschema_dn]
		# NOTE(review): RFC 3672 defines the structural object class of
		# subentries as "subentry"; "subtree" is kept unchanged here - confirm.
		self['structuralObjectClass'] = ['subtree']
		self['objectClass'] = ['top', 'subtree', 'subschema']
		self['objectClasses'] = [item.to_definition() for item in schema.object_classes]
		self['ldapSyntaxes'] = [item.to_definition() for item in schema.syntaxes]
		self['matchingRules'] = [item.to_definition() for item in schema.matching_rules]
		self['attributeTypes'] = [item.to_definition() for item in schema.attribute_types]
		# pylint: disable=invalid-name
		self.AttributeDict = lambda **attributes: AttributeDict(schema, **attributes)
		# Forward the positional arguments (the new object's DN). Previously
		# they were dropped and the subschema entry's own DN was reused.
		self.Object = lambda *args, **attributes: Object(schema, *args, subschemaSubentry=[subschema_dn], **attributes)
		self.RootDSE = lambda **attributes: RootDSE(schema, subschemaSubentry=[subschema_dn], **attributes)
		self.ObjectTemplate = lambda *args, **kwargs: ObjectTemplate(schema, *args, subschemaSubentry=[subschema_dn], **kwargs)

	def match_search(self, base_obj, scope, filter_obj):
		'''Match only the subschema discovery search: base is this entry's DN,
		scope is baseObject, and the filter is ``(objectClass=subschema)``
		(compared case-insensitively)'''
		return DN.from_str(base_obj) == self.dn and  \
		       scope == ldap.SearchScope.baseObject and \
		       isinstance(filter_obj, ldap.FilterEqual) and \
		       filter_obj.attribute.lower() == 'objectclass' and \
		       filter_obj.value.lower() == b'subschema'