Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found

Target

Select target project
  • uffd/uffd
  • rixx/uffd
  • thies/uffd
  • leona/uffd
  • enbewe/uffd
  • strifel/uffd
  • thies/uffd-2
7 results
Show changes
Showing
with 263 additions and 1414 deletions
from flask import current_app
from flask.cli import AppGroup
from sqlalchemy.exc import IntegrityError
import click
from uffd.database import db
from uffd.models import User, Role
user_command = AppGroup('user', help='Manage users')
# pylint: disable=too-many-arguments
def update_attrs(user, mail=None, displayname=None, password=None,
prompt_password=False, clear_roles=False,
add_role=tuple(), remove_role=tuple(), deactivate=None):
if password is None and prompt_password:
password = click.prompt('Password', hide_input=True, confirmation_prompt='Confirm password')
if mail is not None and not user.set_primary_email_address(mail):
raise click.ClickException('Invalid mail address')
if displayname is not None and not user.set_displayname(displayname):
raise click.ClickException('Invalid displayname')
if password is not None and not user.set_password(password):
raise click.ClickException('Invalid password')
if deactivate is not None:
user.is_deactivated = deactivate
if clear_roles:
user.roles.clear()
for role_name in add_role:
role = Role.query.filter_by(name=role_name).one_or_none()
if role is None:
raise click.ClickException(f'Role {role_name} not found')
role.members.append(user)
for role_name in remove_role:
role = Role.query.filter_by(name=role_name).one_or_none()
if role is None:
raise click.ClickException(f'Role {role_name} not found')
role.members.remove(user)
user.update_groups()
@user_command.command(help='List login names of all users')
def list():
with current_app.test_request_context():
for user in User.query:
click.echo(user.loginname)
@user_command.command(help='Show details of user')
@click.argument('loginname')
def show(loginname):
with current_app.test_request_context():
user = User.query.filter_by(loginname=loginname).one_or_none()
if user is None:
raise click.ClickException(f'User {loginname} not found')
click.echo(f'Loginname: {user.loginname}')
click.echo(f'Deactivated: {user.is_deactivated}')
click.echo(f'Displayname: {user.displayname}')
click.echo(f'Mail: {user.primary_email.address}')
click.echo(f'Service User: {user.is_service_user}')
click.echo(f'Roles: {", ".join([role.name for role in user.roles])}')
click.echo(f'Groups: {", ".join([group.name for group in user.groups])}')
@user_command.command(help='Create new user')
@click.argument('loginname')
@click.option('--mail', required=True, metavar='EMAIL_ADDRESS', help='E-Mail address')
@click.option('--displayname', help='Set display name. Defaults to login name.')
@click.option('--service/--no-service', default=False, help='Create service or regular (default) user. '+\
'Regular users automatically have roles marked as default. '+\
'Service users do not.')
@click.option('--password', help='Password for SSO login. Login disabled if unset.')
@click.option('--prompt-password', is_flag=True, flag_value=True, default=False, help='Read password interactively from terminal.')
@click.option('--add-role', multiple=True, help='Add role to user. Repeat to add multiple roles.', metavar='ROLE_NAME')
@click.option('--deactivate', is_flag=True, flag_value=True, default=None, help='Deactivate account.')
def create(loginname, mail, displayname, service, password, prompt_password, add_role, deactivate):
with current_app.test_request_context():
if displayname is None:
displayname = loginname
user = User(is_service_user=service)
if not user.set_loginname(loginname, ignore_blocklist=True):
raise click.ClickException('Invalid loginname')
try:
db.session.add(user)
update_attrs(user, mail, displayname, password, prompt_password, add_role=add_role, deactivate=deactivate)
db.session.commit()
except IntegrityError:
# pylint: disable=raise-missing-from
raise click.ClickException('Login name or e-mail address is already in use')
@user_command.command(help='Update user attributes and roles')
@click.argument('loginname')
@click.option('--mail', metavar='EMAIL_ADDRESS', help='Set e-mail address.')
@click.option('--displayname', help='Set display name.')
@click.option('--password', help='Set password for SSO login.')
@click.option('--prompt-password', is_flag=True, flag_value=True, default=False, help='Set password by reading it interactivly from terminal.')
@click.option('--clear-roles', is_flag=True, flag_value=True, default=False, help='Remove all roles from user. Executed before --add-role.')
@click.option('--add-role', multiple=True, help='Add role to user. Repeat to add multiple roles.')
@click.option('--remove-role', multiple=True, help='Remove role from user. Repeat to remove multiple roles.')
@click.option('--deactivate/--activate', default=None, help='Deactivate or reactivate account.')
def update(loginname, mail, displayname, password, prompt_password, clear_roles, add_role, remove_role, deactivate):
with current_app.test_request_context():
user = User.query.filter_by(loginname=loginname).one_or_none()
if user is None:
raise click.ClickException(f'User {loginname} not found')
try:
update_attrs(user, mail, displayname, password, prompt_password, clear_roles, add_role, remove_role, deactivate)
db.session.commit()
except IntegrityError:
# pylint: disable=raise-missing-from
raise click.ClickException('E-mail address is already in use')
@user_command.command(help='Delete user')
@click.argument('loginname')
def delete(loginname):
with current_app.test_request_context():
user = User.query.filter_by(loginname=loginname).one_or_none()
if user is None:
raise click.ClickException(f'User {loginname} not found')
db.session.delete(user)
db.session.commit()
......@@ -4,9 +4,7 @@ from flask import Blueprint, request, session
bp = Blueprint("csrf", __name__)
# pylint: disable=invalid-name
csrfEndpoints = []
# pylint: enable=invalid-name
csrf_endpoints = []
def csrf_protect(blueprint=None, endpoint=None):
def wraper(func):
......@@ -15,7 +13,7 @@ def csrf_protect(blueprint=None, endpoint=None):
urlendpoint = "{}.{}".format(blueprint.name, func.__name__)
else:
urlendpoint = func.__name__
csrfEndpoints.append(urlendpoint)
csrf_endpoints.append(urlendpoint)
@wraps(func)
def decorator(*args, **kwargs):
if '_csrf_token' in request.values:
......@@ -32,6 +30,6 @@ def csrf_protect(blueprint=None, endpoint=None):
@bp.app_url_defaults
def csrf_inject(endpoint, values):
if endpoint not in csrfEndpoints or not session.get('_csrf_token'):
if endpoint not in csrf_endpoints or not session.get('_csrf_token'):
return
values['_csrf_token'] = session['_csrf_token']
from .csrf import bp as csrf_bp, csrf_protect
bp = [csrf_bp]
from collections import OrderedDict
from sqlalchemy import MetaData
from sqlalchemy import MetaData, event
from sqlalchemy.types import TypeDecorator, Text
from sqlalchemy.ext.mutable import MutableList
from flask_sqlalchemy import SQLAlchemy
from flask.json import JSONEncoder
convention = {
'ix': 'ix_%(column_0_label)s',
......@@ -13,15 +12,56 @@ convention = {
}
metadata = MetaData(naming_convention=convention)
# pylint: disable=C0103
db = SQLAlchemy(metadata=metadata)
# pylint: enable=C0103
class SQLAlchemyJSON(JSONEncoder):
def default(self, o):
if isinstance(o, db.Model):
result = OrderedDict()
for key in o.__mapper__.c.keys():
result[key] = getattr(o, key)
return result
return JSONEncoder.default(self, o)
def enable_sqlite_foreign_key_support(dbapi_connection, connection_record):
# pylint: disable=unused-argument
cursor = dbapi_connection.cursor()
cursor.execute('PRAGMA foreign_keys=ON')
cursor.close()
# We want to enable SQLite foreign key support for app and test code, but not
# for migrations.
# The common way to add the handler to the Engine class (so it applies to all
# instances) would also affect the migrations. With flask_sqlalchemy v2.4 and
# newer we could overwrite SQLAlchemy.create_engine and add our handler there.
# However Debian Buster and Bullseye ship v2.1, so we do this here and call
# this function in create_app.
def customize_db_engine(engine):
if engine.name == 'sqlite':
event.listen(engine, 'connect', enable_sqlite_foreign_key_support)
elif engine.name in ('mysql', 'mariadb'):
@event.listens_for(engine, 'connect')
def receive_connect(dbapi_connection, connection_record): # pylint: disable=unused-argument
cursor = dbapi_connection.cursor()
cursor.execute('SHOW VARIABLES LIKE "character_set_connection"')
character_set_connection = cursor.fetchone()[1]
if character_set_connection != 'utf8mb4':
raise Exception(f'Unsupported connection charset "{character_set_connection}". Make sure to add "?charset=utf8mb4" to SQLALCHEMY_DATABASE_URI!')
cursor.execute('SHOW VARIABLES LIKE "collation_database"')
collation_database = cursor.fetchone()[1]
if collation_database != 'utf8mb4_nopad_bin':
raise Exception(f'Unsupported database collation "{collation_database}". Create the database with "CHARACTER SET utf8mb4 COLLATE utf8mb4_nopad_bin"!')
cursor.execute('SET NAMES utf8mb4 COLLATE utf8mb4_nopad_bin')
cursor.close()
class CommaSeparatedList(TypeDecorator):
# For some reason TypeDecorator.process_literal_param and
# TypeEngine.python_type are abstract but not actually required
# pylint: disable=abstract-method
impl = Text
cache_ok = True
def process_bind_param(self, value, dialect):
if value is None:
return None
for item in value:
if ',' in item:
raise ValueError('Items of comma-separated list must not contain commas')
return ','.join(value)
def process_result_value(self, value, dialect):
if value is None:
return None
return MutableList(value.split(','))
LDAP_USER_SEARCH_BASE="ou=users,dc=example,dc=com"
LDAP_USER_SEARCH_FILTER=[("objectClass", "person")]
LDAP_USER_OBJECTCLASSES=["top", "inetOrgPerson", "organizationalPerson", "person", "posixAccount"]
LDAP_USER_MIN_UID=10000
LDAP_USER_MAX_UID=18999
LDAP_USER_SERVICE_MIN_UID=19000
LDAP_USER_SERVICE_MAX_UID=19999
LDAP_USER_GID=20001
LDAP_USER_DN_ATTRIBUTE="uid"
LDAP_USER_UID_ATTRIBUTE="uidNumber"
LDAP_USER_UID_ALIASES=[]
LDAP_USER_LOGINNAME_ATTRIBUTE="uid"
LDAP_USER_LOGINNAME_ALIASES=[]
LDAP_USER_DISPLAYNAME_ATTRIBUTE="cn"
LDAP_USER_DISPLAYNAME_ALIASES=["givenName", "displayName"]
LDAP_USER_MAIL_ATTRIBUTE="mail"
LDAP_USER_MAIL_ALIASES=[]
LDAP_USER_DEFAULT_ATTRIBUTES={
"sn": " ",
# All string values are subject to python str.format-style format expansion. To insert literal braces use "{{" and "}}".
# Variables: uid, loginname, displayname, mail and possibly other attributes of the User class
"homeDirectory": "/home/{loginname}",
"gidNumber": LDAP_USER_GID,
# "multiValueAttribute": ["value1", "value2"],
}
USER_GID=20001
# Service and non-service users must either have the same UID range or must not overlap
USER_MIN_UID=10000
USER_MAX_UID=18999
USER_SERVICE_MIN_UID=19000
USER_SERVICE_MAX_UID=19999
LDAP_GROUP_SEARCH_BASE="ou=groups,dc=example,dc=com"
LDAP_GROUP_SEARCH_FILTER=[("objectClass","groupOfUniqueNames")]
LDAP_GROUP_GID_ATTRIBUTE="gidNumber"
LDAP_GROUP_NAME_ATTRIBUTE="cn"
LDAP_GROUP_DESCRIPTION_ATTRIBUTE="description"
LDAP_GROUP_MEMBER_ATTRIBUTE="uniqueMember"
LDAP_MAIL_SEARCH_BASE="ou=postfix,dc=example,dc=com"
LDAP_MAIL_SEARCH_FILTER=[("objectClass","postfixVirtual")]
LDAP_MAIL_OBJECTCLASSES=["top", "postfixVirtual"]
LDAP_MAIL_DN_ATTRIBUTE="uid"
LDAP_MAIL_UID_ATTRIBUTE="uid"
LDAP_MAIL_RECEIVERS_ATTRIBUTE="mailacceptinggeneralid"
LDAP_MAIL_DESTINATIONS_ATTRIBUTE="maildrop"
LDAP_SERVICE_URL="ldapi:///"
LDAP_SERVICE_USE_STARTTLS=True
LDAP_SERVICE_BIND_DN=""
LDAP_SERVICE_BIND_PASSWORD=""
# Connections use LDAP_SERVICE_BIND_DN if LDAP_SERVICE_USER_BIND=False, otherwise they use the users credentials.
# When using a user connection, some features are not available, since they require a service connection
LDAP_SERVICE_USER_BIND=False
GROUP_MIN_GID=20000
GROUP_MAX_GID=49999
# The period of time that a login lasts for.
SESSION_LIFETIME_SECONDS=3600
# The period of time that the session cookie lasts for. This is refreshed on each page load.
PERMANENT_SESSION_LIFETIME=2678400
# CSRF protection
SESSION_COOKIE_SECURE=True
SESSION_COOKIE_HTTPONLY=True
......@@ -60,27 +27,42 @@ LANGUAGES={
}
ACL_ADMIN_GROUP="uffd_admin"
# Group required to access selfservice functions (view selfservice, change profile/password/roles)
ACL_SELFSERVICE_GROUP="uffd_access"
# Group required to login
ACL_ACCESS_GROUP="uffd_access"
# Members can create invite links for signup
ACL_SIGNUP_GROUP="uffd_signup"
MAIL_SERVER='' # e.g. example.com
MAIL_PORT=465
MAIL_USERNAME='yourId@example.com'
MAIL_USERNAME='yourId@example.com' # set to empty string to disable authentication
MAIL_PASSWORD='*****'
MAIL_USE_STARTTLS=True
MAIL_FROM_ADDRESS='foo@bar.com'
# The following settings are not available when using a user connection
ENABLE_INVITE=True
ENABLE_PASSWORDRESET=True
ENABLE_ROLESELFSERVICE=True
# Set to a domain name (e.g. "remailer.example.com") to enable remailer.
# Requires special mail server configuration (see uffd-socketmapd). Can be
# enabled/disabled per-service in the service settings. If enabled, services
# no longer get real user mail addresses but instead special autogenerated
# addresses that are replaced with the real mail addresses by the mail server.
REMAILER_DOMAIN = ''
REMAILER_OLD_DOMAINS = []
# Secret used for construction and verification of remailer addresses.
# If None, the value of SECRET_KEY is used.
REMAILER_SECRET_KEY = None
# Set to list of user loginnames to limit remailer to a small list of users.
# Useful for debugging. If None remailer is active for all users (if
# configured and enabled for a service). This option is deprecated. Use the
# per-service setting in the web interface instead.
REMAILER_LIMIT_TO_USERS = None
# Do not enable this on a public service! There is no spam protection implemented at the moment.
SELF_SIGNUP=False
INVITE_MAX_VALID_DAYS=21
LOGINNAME_BLACKLIST=['^admin$', '^root$']
LOGINNAME_BLOCKLIST=['^admin$', '^root$']
#MFA_ICON_URL = 'https://example.com/logo.png'
#MFA_RP_ID = 'example.com' # If unset, hostname from current request is used
......@@ -90,16 +72,9 @@ SQLALCHEMY_TRACK_MODIFICATIONS=False
FOOTER_LINKS=[{"url": "https://example.com", "title": "example"}]
OAUTH2_CLIENTS={
#'test_client_id' : {'client_secret': 'random_secret', 'redirect_uris': ['https://example.com/oauth']},
# You can optionally restrict access to users with a certain group. Set 'required_group' to the name of an LDAP group name or a list of groups.
# ... 'required_group': 'test_access_group' ... only allows users with group "test_access_group" access
# ... 'required_group': ['groupa', ['groupb', 'groupc']] ... allows users with group "groupa" as well as users with both "groupb" and "groupc" access
}
API_CLIENTS={
#'token': {'scopes': ['checkpassword']}
}
# The default page after login or clicking the top left home button is the self-service
# page. If you would like it to be the services list instead, set this to True.
DEFAULT_PAGE_SERVICES=False
# Service overview page (disabled if empty)
SERVICES=[
......@@ -155,7 +130,11 @@ SERVICES_BANNER_PUBLIC=True
# Enable the service overview page for users who are not logged in
SERVICES_PUBLIC=True
# An optional banner that will be displayed above the login form
#LOGIN_BANNER='Always check the URL. Never enter your SSO password on any other site.'
BRANDING_LOGO_URL='/static/empty.png'
SITE_TITLE='uffd'
# Name and contact mail address are displayed to users in a few places (plain text only!)
ORGANISATION_NAME='Example Organisation'
......@@ -169,7 +148,6 @@ WELCOME_TEXT='See https://docs.example.com/ for further information.'
#TEMPLATES_AUTO_RELOAD=True
#SQLALCHEMY_ECHO=True
#FLASK_ENV=development
#LDAP_SERVICE_MOCK=True
# DO set in production
......
# pylint: skip-file
from flask_babel import gettext as _
from warnings import warn
from flask import request, current_app
import urllib.parse
# WebAuthn support is optional because fido2 has a pretty unstable
# interface and might be difficult to install with the correct version
try:
import fido2 as __fido2
if __fido2.__version__.startswith('0.5.'):
from fido2.client import ClientData
from fido2.server import Fido2Server, RelyingParty as __PublicKeyCredentialRpEntity
from fido2.ctap2 import AttestationObject, AuthenticatorData, AttestedCredentialData
from fido2 import cbor
cbor.encode = cbor.dumps
cbor.decode = lambda arg: cbor.loads(arg)[0]
class PublicKeyCredentialRpEntity(__PublicKeyCredentialRpEntity):
def __init__(self, name, id):
super().__init__(id, name)
elif __fido2.__version__.startswith('0.9.'):
from fido2.client import ClientData
from fido2.webauthn import PublicKeyCredentialRpEntity
from fido2.server import Fido2Server
from fido2.ctap2 import AttestationObject, AuthenticatorData, AttestedCredentialData
from fido2 import cbor
elif __fido2.__version__.startswith('1.'):
from fido2.webauthn import PublicKeyCredentialRpEntity, CollectedClientData as ClientData, AttestationObject, AuthenticatorData, AttestedCredentialData
from fido2.server import Fido2Server
from fido2 import cbor
else:
raise ImportError(f'Unsupported fido2 version: {__fido2.__version__}')
def get_webauthn_server():
hostname = urllib.parse.urlsplit(request.url).hostname
return Fido2Server(PublicKeyCredentialRpEntity(id=current_app.config.get('MFA_RP_ID', hostname),
name=current_app.config['MFA_RP_NAME']))
WEBAUTHN_SUPPORTED = True
except ImportError as err:
warn(_('2FA WebAuthn support disabled because import of the fido2 module failed (%s)')%err)
WEBAUTHN_SUPPORTED = False
from .views import bp as _bp
bp = [_bp]
{% extends 'base.html' %}
{% block body %}
<div class="row mt-2 justify-content-center">
<div class="col-lg-6 col-md-10" style="background: #f7f7f7; box-shadow: 0px 2px 2px rgba(0, 0, 0, 0.3); padding: 30px;">
<div class="text-center">
<img alt="branding logo" src="{{ config.get("BRANDING_LOGO_URL") }}" class="col-lg-8 col-md-12" >
</div>
<div class="col-12 mb-3">
<h2 class="text-center">{{_('Invite Link')}}</h2>
</div>
{% if not request.user %}
<p>{{_('Welcome to the %(org_name)s Single-Sign-On!', org_name=config.ORGANISATION_NAME)}}</p>
{% endif %}
{% if invite.roles and invite.allow_signup %}
<p>{{_('With this link you can register a new user account with the following roles or add the roles to an existing account:')}}</p>
{% elif invite.roles %}
<p>{{_('With this link you can add the following roles to an existing account:')}}</p>
{% elif invite.allow_signup %}
<p>{{_('With this link you can register a new user account.')}}</p>
{% endif %}
{% if invite.roles %}
<ul>
{% for role in invite.roles %}
<li>{{ role.name }}{% if role.description %}: {{ role.description }}{% endif %}</li>
{% endfor %}
</ul>
{% endif %}
{% if request.user %}
{% if invite.roles %}
<form method="POST" action="{{ url_for("invite.grant", token=invite.token) }}" class="mb-2">
<button type="submit" class="btn btn-primary btn-block">{{_('Add the roles to your account now')}}</button>
</form>
<a href="{{ url_for("session.logout", ref=url_for("session.login", ref=request.full_path)) }}" class="btn btn-secondary btn-block">{{_('Logout and switch to a different account')}}</a>
{% endif %}
{% if invite.allow_signup %}
<a href="{{ url_for("session.logout", ref=url_for("invite.signup_start", token=invite.token)) }}" class="btn btn-secondary btn-block">{{_('Logout to register a new account')}}</a>
{% endif %}
{% else %}
{% if invite.allow_signup %}
<a href="{{ url_for("invite.signup_start", token=invite.token) }}" class="btn btn-primary btn-block">{{_('Register a new account')}}</a>
{% endif %}
{% if invite.roles %}
<a href="{{ url_for("session.login", ref=request.full_path) }}" class="btn btn-primary btn-block">{{_('Login and add the roles to your account')}}</a>
{% endif %}
{% endif %}
</div>
</div>
{% endblock %}
from collections import UserString, UserList
from flask import current_app
class LazyConfigString(UserString):
def __init__(self, seq=None, key=None, default=None, error=True):
# pylint: disable=super-init-not-called
self.__seq = seq
self.__key = key
self.__default = default
self.__error = error
@property
def data(self):
if self.__seq is not None:
obj = self.__seq
elif self.__error:
obj = current_app.config[self.__key]
else:
obj = current_app.config.get(self.__key, self.__default)
return str(obj)
def __bytes__(self):
return self.data.encode()
def __get__(self, obj, owner=None):
return self.data
def lazyconfig_str(key, **kwargs):
return LazyConfigString(None, key, **kwargs)
class LazyConfigList(UserList):
def __init__(self, seq=None, key=None, default=None, error=True):
# pylint: disable=super-init-not-called
self.__seq = seq
self.__key = key
self.__default = default
self.__error = error
@property
def data(self):
if self.__seq is not None:
obj = self.__seq
elif self.__error:
obj = current_app.config[self.__key]
else:
obj = current_app.config.get(self.__key, self.__default)
return obj
def __get__(self, obj, owner=None):
return self.data
def lazyconfig_list(key, **kwargs):
return LazyConfigList(None, key, **kwargs)
import base64
import hashlib
from flask import current_app, request, abort, session
import ldap3
from ldap3.core.exceptions import LDAPBindError, LDAPPasswordIsMandatoryError, LDAPInvalidDnError
from uffd.ldapalchemy import LDAPMapper, LDAPCommitError # pylint: disable=unused-import
from uffd.ldapalchemy.model import Query
from uffd.ldapalchemy.core import encode_filter
def check_hashed(password_hash, password):
'''Return if password matches a LDAP-compatible password hash (only used for LDAP_SERVICE_MOCK)
:param password_hash: LDAP-compatible password hash (plain password or "{ssha512}...")
:type password_hash: bytes
:param password: Plain, (ideally) utf8-encoded password
:type password: bytes'''
algorithms = {
b'md5': 'MD5',
b'sha': 'SHA1',
b'sha256': 'SHA256',
b'sha384': 'SHA384',
b'sha512': 'SHA512'
}
if not password_hash.startswith(b'{'):
return password_hash == password
algorithm, data = password_hash[1:].split(b'}', 1)
data = base64.b64decode(data)
if algorithm in algorithms:
ctx = hashlib.new(algorithms[algorithm], password)
return data == ctx.digest()
if algorithm.startswith(b's') and algorithm[1:] in algorithms:
ctx = hashlib.new(algorithms[algorithm[1:]], password)
salt = data[ctx.digest_size:]
ctx.update(salt)
return data == ctx.digest() + salt
raise NotImplementedError()
class FlaskQuery(Query):
def get_or_404(self, dn):
res = self.get(dn)
if res is None:
abort(404)
return res
def first_or_404(self):
res = self.first()
if res is None:
abort(404)
return res
def test_user_bind(bind_dn, bind_pw):
try:
if current_app.config.get('LDAP_SERVICE_MOCK', False):
# Since we reuse the same conn and ldap3's mock only supports plain
# passwords for bind and rebind, we simulate the bind by retrieving
# and checking the password hash ourselves.
conn = ldap.get_connection()
conn.search(bind_dn, search_filter='(objectclass=*)', search_scope=ldap3.BASE,
attributes=ldap3.ALL_ATTRIBUTES)
if not conn.response:
return False
if not conn.response[0]['attributes'].get('userPassword'):
return False
return check_hashed(conn.response[0]['attributes']['userPassword'][0], bind_pw.encode())
server = ldap3.Server(current_app.config["LDAP_SERVICE_URL"])
conn = connect_and_bind_to_ldap(server, bind_dn, bind_pw)
if not conn:
return False
except (LDAPBindError, LDAPPasswordIsMandatoryError, LDAPInvalidDnError):
return False
conn.search(conn.user, encode_filter(current_app.config["LDAP_USER_SEARCH_FILTER"]))
lazy_entries = conn.entries
# Do not end the connection when using mock, as it will be reused afterwards
if not current_app.config.get('LDAP_SERVICE_MOCK', False):
conn.unbind()
return len(lazy_entries) == 1
def connect_and_bind_to_ldap(server, bind_dn, bind_pw):
# Using auto_bind cannot close the connection, so define the connection with extra steps
connection = ldap3.Connection(server, bind_dn, bind_pw)
if connection.closed:
connection.open(read_server_info=False)
if current_app.config["LDAP_SERVICE_USE_STARTTLS"]:
connection.start_tls(read_server_info=False)
if not connection.bind(read_server_info=True):
connection.unbind()
raise LDAPBindError
return connection
class FlaskLDAPMapper(LDAPMapper):
def __init__(self):
super().__init__()
class Model(self.Model):
query_class = FlaskQuery
self.Model = Model # pylint: disable=invalid-name
@property
def session(self):
if not hasattr(request, 'ldap_session'):
request.ldap_session = self.Session(self.get_connection)
return request.ldap_session
def get_connection(self):
if hasattr(request, 'ldap_connection'):
return request.ldap_connection
if current_app.config.get('LDAP_SERVICE_MOCK', False):
if not current_app.debug:
raise Exception('LDAP_SERVICE_MOCK cannot be enabled on production instances')
# Entries are stored in-memory in the mocked `Connection` object. To make
# changes persistent across requests we reuse the same `Connection` object
# for all calls to `service_conn()` and `user_conn()`.
if not hasattr(current_app, 'ldap_mock'):
server = ldap3.Server.from_definition('ldap_mock', 'tests/openldap_mock/ldap_server_info.json',
'tests/openldap_mock/ldap_server_schema.json')
current_app.ldap_mock = ldap3.Connection(server, client_strategy=ldap3.MOCK_SYNC)
current_app.ldap_mock.strategy.entries_from_json('tests/openldap_mock/ldap_server_entries.json')
current_app.ldap_mock.bind()
return current_app.ldap_mock
server = ldap3.Server(current_app.config["LDAP_SERVICE_URL"], get_info=ldap3.ALL)
if current_app.config['LDAP_SERVICE_USER_BIND']:
bind_dn = session['user_dn']
bind_pw = session['user_pw']
else:
bind_dn = current_app.config["LDAP_SERVICE_BIND_DN"]
bind_pw = current_app.config["LDAP_SERVICE_BIND_PASSWORD"]
request.ldap_connection = connect_and_bind_to_ldap(server, bind_dn, bind_pw)
return request.ldap_connection
ldap = FlaskLDAPMapper()
import ldap3
from .core import LDAPCommitError
from . import model, attribute, relationship
__all__ = ['LDAPMapper', 'LDAPCommitError']
class LDAPMapper:
def __init__(self, server=None, bind_dn=None, bind_password=None):
class Model(model.Model):
ldap_mapper = self
self.Model = Model # pylint: disable=invalid-name
self.Session = model.Session # pylint: disable=invalid-name
self.Attribute = attribute.Attribute # pylint: disable=invalid-name
self.Relationship = relationship.Relationship # pylint: disable=invalid-name
self.Backreference = relationship.Backreference # pylint: disable=invalid-name
if not hasattr(type(self), 'server'):
self.server = server
if not hasattr(type(self), 'bind_dn'):
self.bind_dn = bind_dn
if not hasattr(type(self), 'bind_password'):
self.bind_password = bind_password
if not hasattr(type(self), 'session'):
self.session = self.Session(self.get_connection)
def get_connection(self):
return ldap3.Connection(self.server, self.bind_dn, self.bind_password, auto_bind=True)
from collections.abc import MutableSequence
class AttributeList(MutableSequence):
def __init__(self, ldap_object, name, aliases):
self.__ldap_object = ldap_object
self.__name = name
self.__aliases = [name] + aliases
def __get(self):
return list(self.__ldap_object.getattr(self.__name))
def __set(self, values):
for name in self.__aliases:
self.__ldap_object.setattr(name, values)
def __repr__(self):
return repr(self.__get())
def __setitem__(self, key, value):
tmp = self.__get()
tmp[key] = value
self.__set(tmp)
def __delitem__(self, key):
tmp = self.__get()
del tmp[key]
self.__set(tmp)
def __len__(self):
return len(self.__get())
def __getitem__(self, key):
return self.__get()[key]
def insert(self, index, value):
tmp = self.__get()
tmp.insert(index, value)
self.__set(tmp)
class Attribute:
def __init__(self, name, aliases=None, multi=False, default=None):
self.name = name
self.aliases = aliases if aliases is not None else []
self.multi = multi
self.default = default
def add_hook(self, obj):
if obj.ldap_object.getattr(self.name) == []:
self.__set__(obj, self.default() if callable(self.default) else self.default)
def __set_name__(self, cls, name):
if self.default is not None:
cls.ldap_add_hooks = cls.ldap_add_hooks + (self.add_hook,)
def __get__(self, obj, objtype=None):
if obj is None:
return self
if self.multi:
return AttributeList(obj.ldap_object, self.name, self.aliases)
return (obj.ldap_object.getattr(self.name) or [None])[0]
def __set__(self, obj, values):
if not self.multi:
values = [values]
for name in [self.name] + self.aliases:
obj.ldap_object.setattr(name, values)
from ldap3 import MODIFY_REPLACE, MODIFY_DELETE, MODIFY_ADD, ALL_ATTRIBUTES, ALL_OPERATIONAL_ATTRIBUTES
from ldap3.utils.conv import escape_filter_chars
def encode_filter(filter_params):
return '(&%s)'%(''.join(['(%s=%s)'%(attr, escape_filter_chars(value)) for attr, value in filter_params]))
def match_dn(dn, base):
return dn.endswith(base) # Probably good enougth for all valid dns
def make_cache_key(search_base, filter_params):
res = (search_base,)
for attr, value in sorted(filter_params):
res += ((attr, value),)
return res
class LDAPCommitError(Exception):
pass
class SessionState:
def __init__(self, objects=None, deleted_objects=None, references=None):
self.objects = objects or {}
self.deleted_objects = deleted_objects or {}
self.references = references or {} # {(attr_name, value): {srcobj, ...}, ...}
def copy(self):
objects = self.objects.copy()
deleted_objects = self.deleted_objects.copy()
references = {key: objs.copy() for key, objs in self.references.items()}
return SessionState(objects, deleted_objects, references)
def ref(self, obj, attr, values):
for value in values:
key = (attr, value)
if key not in self.references:
self.references[key] = {obj}
else:
self.references[key].add(obj)
def unref(self, obj, attr, values):
for value in values:
self.references.get((attr, value), set()).discard(obj)
class ObjectState:
def __init__(self, session=None, attributes=None, dn=None):
self.session = session
self.attributes = attributes or {}
self.dn = dn
def copy(self):
attributes = {name: values.copy() for name, values in self.attributes.items()}
return ObjectState(attributes=attributes, dn=self.dn, session=self.session)
class AddOperation:
def __init__(self, obj, dn, object_classes):
self.obj = obj
self.dn = dn
self.object_classes = object_classes
self.attributes = {name: values.copy() for name, values in obj.state.attributes.items()}
def apply_object(self, obj_state):
obj_state.dn = self.dn
obj_state.attributes = {name: values.copy() for name, values in self.attributes.items()}
obj_state.attributes['objectClass'] = obj_state.attributes.get('objectClass', []) + list(self.object_classes)
def apply_session(self, session_state):
assert self.dn not in session_state.objects
session_state.objects[self.dn] = self.obj
for name, values in self.attributes.items():
session_state.ref(self.obj, name, values)
session_state.ref(self.obj, 'objectClass', self.object_classes)
def apply_ldap(self, conn):
success = conn.add(self.dn, self.object_classes, self.attributes)
if not success:
raise LDAPCommitError()
class DeleteOperation:
def __init__(self, obj):
self.dn = obj.state.dn
self.obj = obj
self.attributes = {name: values.copy() for name, values in obj.state.attributes.items()}
def apply_object(self, obj_state):
obj_state.dn = None
def apply_session(self, session_state):
assert self.dn in session_state.objects
del session_state.objects[self.dn]
session_state.deleted_objects[self.dn] = self.obj
for name, values in self.attributes.items():
session_state.unref(self.obj, name, values)
def apply_ldap(self, conn):
success = conn.delete(self.dn)
if not success:
raise LDAPCommitError()
class ModifyOperation:
def __init__(self, obj, changes):
self.obj = obj
self.attributes = {name: values.copy() for name, values in obj.state.attributes.items()}
self.changes = changes
def apply_object(self, obj_state):
for attr, changes in self.changes.items():
for action, values in changes:
if action == MODIFY_REPLACE:
obj_state.attributes[attr] = values
elif action == MODIFY_ADD:
obj_state.attributes[attr] += values
elif action == MODIFY_DELETE:
for value in values:
if value in obj_state.attributes[attr]:
obj_state.attributes[attr].remove(value)
def apply_session(self, session_state):
for attr, changes in self.changes.items():
for action, values in changes:
if action == MODIFY_REPLACE:
session_state.unref(self.obj, attr, self.attributes.get(attr, []))
session_state.ref(self.obj, attr, values)
elif action == MODIFY_ADD:
session_state.ref(self.obj, attr, values)
elif action == MODIFY_DELETE:
session_state.unref(self.obj, attr, values)
def apply_ldap(self, conn):
success = conn.modify(self.obj.state.dn, self.changes)
if not success:
raise LDAPCommitError()
class Session:
def __init__(self, get_connection):
self.get_connection = get_connection
self.committed_state = SessionState()
self.state = SessionState()
self.changes = []
self.cached_searches = set()
def add(self, obj, dn, object_classes):
if self.state.objects.get(dn) == obj:
return
assert obj.state.session is None
oper = AddOperation(obj, dn, object_classes)
oper.apply_object(obj.state)
obj.state.session = self
oper.apply_session(self.state)
self.changes.append(oper)
def delete(self, obj):
if obj.state.dn not in self.state.objects:
return
assert obj.state.session == self
oper = DeleteOperation(obj)
oper.apply_object(obj.state)
obj.state.session = None
oper.apply_session(self.state)
self.changes.append(oper)
def record(self, oper):
assert oper.obj.state.session == self
self.changes.append(oper)
def commit(self):
conn = self.get_connection()
while self.changes:
oper = self.changes.pop(0)
try:
oper.apply_ldap(conn)
except Exception as err:
self.changes.insert(0, oper)
raise err
oper.apply_object(oper.obj.committed_state)
oper.apply_session(self.committed_state)
self.committed_state = self.state.copy()
def rollback(self):
for obj in self.state.objects.values():
obj.state = obj.committed_state.copy()
for obj in self.state.deleted_objects.values():
obj.state = obj.committed_state.copy()
self.state = self.committed_state.copy()
self.changes.clear()
def get(self, dn, filter_params):
if dn in self.state.objects:
obj = self.state.objects[dn]
return obj if obj.match(filter_params) else None
if dn in self.state.deleted_objects:
return None
conn = self.get_connection()
conn.search(dn, encode_filter(filter_params), attributes=[ALL_ATTRIBUTES, ALL_OPERATIONAL_ATTRIBUTES])
if not conn.response:
return None
assert len(conn.response) == 1
if conn.response[0]['dn'] != dn:
# To use DNs as cache keys, we assume each DN has a single unique string
# representation. This is not generally true: RDN attributes may be
# case insensitive or values may contain escape sequences.
# In this case, the provided DN differs from the canonical form the
# server returned. We cannot handle this consistently, so we report no
# match.
return None
obj = Object(self, conn.response[0])
self.state.objects[dn] = obj
self.committed_state.objects[dn] = obj
for attr, values in obj.state.attributes.items():
self.state.ref(obj, attr, values)
return obj
def filter(self, search_base, filter_params):
if not filter_params:
matches = self.state.objects.values()
else:
submatches = [self.state.references.get((attr, value), set()) for attr, value in filter_params]
matches = submatches.pop(0)
while submatches:
matches = matches.intersection(submatches.pop(0))
res = [obj for obj in matches if match_dn(obj.state.dn, search_base)]
cache_key = make_cache_key(search_base, filter_params)
if cache_key in self.cached_searches:
return res
conn = self.get_connection()
conn.search(search_base, encode_filter(filter_params), attributes=[ALL_ATTRIBUTES, ALL_OPERATIONAL_ATTRIBUTES])
for response in conn.response:
dn = response['dn']
if dn in self.state.objects or dn in self.state.deleted_objects:
continue
obj = Object(self, response)
self.state.objects[dn] = obj
self.committed_state.objects[dn] = obj
for attr, values in obj.state.attributes.items():
self.state.ref(obj, attr, values)
res.append(obj)
self.cached_searches.add(cache_key)
return res
class Object:
def __init__(self, session=None, response=None):
if response is None:
self.committed_state = ObjectState()
else:
assert session is not None
attrs = {attr: value if isinstance(value, list) else [value] for attr, value in response['attributes'].items()}
self.committed_state = ObjectState(session, attrs, response['dn'])
self.state = self.committed_state.copy()
@property
def dn(self):
return self.state.dn
@property
def session(self):
return self.state.session
def getattr(self, name):
return self.state.attributes.get(name, [])
def setattr(self, name, values):
oper = ModifyOperation(self, {name: [(MODIFY_REPLACE, values)]})
oper.apply_object(self.state)
if self.state.session:
oper.apply_session(self.state.session.state)
self.state.session.changes.append(oper)
def attr_append(self, name, value):
oper = ModifyOperation(self, {name: [(MODIFY_ADD, [value])]})
oper.apply_object(self.state)
if self.state.session:
oper.apply_session(self.state.session.state)
self.state.session.changes.append(oper)
def attr_remove(self, name, value):
oper = ModifyOperation(self, {name: [(MODIFY_DELETE, [value])]})
oper.apply_object(self.state)
if self.state.session:
oper.apply_session(self.state.session.state)
self.state.session.changes.append(oper)
def match(self, filter_params):
for attr, value in filter_params:
if value not in self.getattr(attr):
return False
return True
from collections.abc import MutableSet
from .model import add_to_session
class DBRelationshipSet(MutableSet):
def __init__(self, dbobj, relattr, ldapcls, mapcls):
self.__dbobj = dbobj
self.__relattr = relattr
self.__ldapcls = ldapcls
self.__mapcls = mapcls
def __get_dns(self):
return [mapobj.dn for mapobj in getattr(self.__dbobj, self.__relattr)]
def __repr__(self):
return repr(set(self))
def __contains__(self, value):
if value is None or not isinstance(value, self.__ldapcls):
return False
return value.ldap_object.dn in self.__get_dns()
def __iter__(self):
return iter(filter(lambda obj: obj is not None, [self.__ldapcls.query.get(dn) for dn in self.__get_dns()]))
def __len__(self):
return len(set(self))
def add(self, value):
if not isinstance(value, self.__ldapcls):
raise TypeError()
if value.ldap_object.session is None:
add_to_session(value, self.__ldapcls.ldap_mapper.session.ldap_session)
if value.ldap_object.dn not in self.__get_dns():
getattr(self.__dbobj, self.__relattr).append(self.__mapcls(dn=value.ldap_object.dn))
def discard(self, value):
if not isinstance(value, self.__ldapcls):
raise TypeError()
rel = getattr(self.__dbobj, self.__relattr)
for mapobj in list(rel):
if mapobj.dn == value.ldap_object.dn:
rel.remove(mapobj)
class DBRelationship:
def __init__(self, relattr, ldapcls, mapcls=None, backref=None, backattr=None):
self.relattr = relattr
self.ldapcls = ldapcls
self.mapcls = mapcls
self.backref = backref
self.backattr = backattr
def __set_name__(self, cls, name):
if self.backref:
setattr(self.ldapcls, self.backref, DBBackreference(cls, self.relattr, self.mapcls, self.backattr))
def __get__(self, obj, objtype=None):
if obj is None:
return self
if self.mapcls is not None:
return DBRelationshipSet(obj, self.relattr, self.ldapcls, self.mapcls)
dn = getattr(obj, self.relattr)
if dn is not None:
return self.ldapcls.query.get(dn)
return None
def __set__(self, obj, values):
if self.mapcls is not None:
tmp = self.__get__(obj)
tmp.clear()
for value in values:
tmp.add(value)
else:
if not isinstance(values, self.ldapcls):
raise TypeError()
setattr(obj, self.relattr, values.ldap_object.dn)
class DBBackreferenceSet(MutableSet):
def __init__(self, ldapobj, dbcls, relattr, mapcls, backattr):
self.__ldapobj = ldapobj
self.__dbcls = dbcls
self.__relattr = relattr
self.__mapcls = mapcls
self.__backattr = backattr
@property
def __dn(self):
return self.__ldapobj.ldap_object.dn
def __get(self):
if self.__mapcls is None:
return self.__dbcls.query.filter_by(**{self.__relattr: self.__dn}).all()
return {getattr(mapobj, self.__backattr) for mapobj in self.__mapcls.query.filter_by(dn=self.__dn)}
def __repr__(self):
return repr(self.__get())
def __contains__(self, value):
return value in self.__get()
def __iter__(self):
return iter(self.__get())
def __len__(self):
return len(self.__get())
def add(self, value):
assert self.__ldapobj.ldap_object.session is not None
if not isinstance(value, self.__dbcls):
raise TypeError()
if self.__mapcls is None:
setattr(value, self.__relattr, self.__dn)
else:
rel = getattr(value, self.__relattr)
if self.__dn not in {mapobj.dn for mapobj in rel}:
rel.append(self.__mapcls(dn=self.__dn))
def discard(self, value):
if not isinstance(value, self.__dbcls):
raise TypeError()
if self.__mapcls is None:
setattr(value, self.__relattr, None)
else:
rel = getattr(value, self.__relattr)
for mapobj in list(rel):
if mapobj.dn == self.__dn:
rel.remove(mapobj)
class DBBackreference:
def __init__(self, dbcls, relattr, mapcls=None, backattr=None):
self.dbcls = dbcls
self.relattr = relattr
self.mapcls = mapcls
self.backattr = backattr
def __get__(self, obj, objtype=None):
if obj is None:
return self
return DBBackreferenceSet(obj, self.dbcls, self.relattr, self.mapcls, self.backattr)
def __set__(self, obj, values):
tmp = self.__get__(obj)
tmp.clear()
for value in values:
tmp.add(value)
from collections.abc import Sequence
try:
# Added in v2.5
from ldap3.utils.dn import escape_rdn
except ImportError:
# From ldap3 source code, Copyright Giovanni Cannata, LGPL v3 license
def escape_rdn(rdn):
# '/' must be handled first or the escape slashes will be escaped!
for char in ['\\', ',', '+', '"', '<', '>', ';', '=', '\x00']:
rdn = rdn.replace(char, '\\' + char)
if rdn[0] == '#' or rdn[0] == ' ':
rdn = ''.join(('\\', rdn))
if rdn[-1] == ' ':
rdn = ''.join((rdn[:-1], '\\ '))
return rdn
from . import core
def add_to_session(obj, session):
if obj.ldap_object.session is None:
for func in obj.ldap_add_hooks:
func(obj)
session.add(obj.ldap_object, obj.dn, obj.ldap_object_classes)
class Session:
def __init__(self, get_connection):
self.ldap_session = core.Session(get_connection)
def add(self, obj):
add_to_session(obj, self.ldap_session)
def delete(self, obj):
self.ldap_session.delete(obj.ldap_object)
def commit(self):
self.ldap_session.commit()
def rollback(self):
self.ldap_session.rollback()
def make_modelobj(obj, model):
if obj is None:
return None
if not hasattr(obj, 'model'):
obj.model = model()
obj.model.ldap_object = obj
if not isinstance(obj.model, model):
return None
return obj.model
def make_modelobjs(objs, model):
modelobjs = []
for obj in objs:
modelobj = make_modelobj(obj, model)
if modelobj is not None:
modelobjs.append(modelobj)
return modelobjs
class Query(Sequence):
def __init__(self, model, filter_params=None):
self.__model = model
self.__filter_params = list(model.ldap_filter_params) + (filter_params or [])
@property
def __session(self):
return self.__model.ldap_mapper.session.ldap_session
def get(self, dn):
return make_modelobj(self.__session.get(dn, self.__filter_params), self.__model)
def all(self):
objs = self.__session.filter(self.__model.ldap_search_base, self.__filter_params)
objs = sorted(objs, key=lambda obj: obj.dn)
return make_modelobjs(objs, self.__model)
def first(self):
return (self.all() or [None])[0]
def one(self):
modelobjs = self.all()
if len(modelobjs) != 1:
raise Exception()
return modelobjs[0]
def one_or_none(self):
modelobjs = self.all()
if len(modelobjs) > 1:
raise Exception()
return (modelobjs or [None])[0]
def __contains__(self, value):
return value in self.all()
def __iter__(self):
return iter(self.all())
def __len__(self):
return len(self.all())
def __getitem__(self, index):
return self.all()[index]
def filter_by(self, **kwargs):
filter_params = [(getattr(self.__model, attr).name, value) for attr, value in kwargs.items()]
return type(self)(self.__model, self.__filter_params + filter_params)
class QueryWrapper:
def __get__(self, obj, objtype=None):
return objtype.query_class(objtype)
class Model:
# Overwritten by mapper
ldap_mapper = None
query_class = Query
query = QueryWrapper()
ldap_add_hooks = ()
# Overwritten by models
ldap_search_base = None
ldap_filter_params = ()
ldap_object_classes = ()
ldap_dn_base = None
ldap_dn_attribute = None
def __init__(self, **kwargs):
self.ldap_object = core.Object()
for key, value, in kwargs.items():
setattr(self, key, value)
@property
def dn(self):
if self.ldap_object.dn is not None:
return self.ldap_object.dn
if self.ldap_dn_base is None or self.ldap_dn_attribute is None:
return None
values = self.ldap_object.getattr(self.ldap_dn_attribute)
if not values:
return None
# escape_rdn can't handle empty strings
rdn = escape_rdn(values[0]) if values[0] else ''
return '%s=%s,%s'%(self.ldap_dn_attribute, rdn, self.ldap_dn_base)
def __repr__(self):
cls_name = '%s.%s'%(type(self).__module__, type(self).__name__)
if self.dn is not None:
return '<%s %s>'%(cls_name, self.dn)
return '<%s>'%cls_name
from collections.abc import MutableSet
from .model import make_modelobj, make_modelobjs, add_to_session
class UnboundObjectError(Exception):
pass
class RelationshipSet(MutableSet):
def __init__(self, ldap_object, name, model, destmodel):
self.__ldap_object = ldap_object
self.__name = name
self.__model = model # pylint: disable=unused-private-member
self.__destmodel = destmodel
def __modify_check(self, value):
if self.__ldap_object.session is None:
raise UnboundObjectError()
if not isinstance(value, self.__destmodel):
raise TypeError()
def __repr__(self):
return repr(set(self))
def __contains__(self, value):
if value is None or not isinstance(value, self.__destmodel):
return False
return value.ldap_object.dn in self.__ldap_object.getattr(self.__name)
def __iter__(self):
def get(dn):
return make_modelobj(self.__ldap_object.session.get(dn, self.__destmodel.ldap_filter_params), self.__destmodel)
dns = set(self.__ldap_object.getattr(self.__name))
return iter(filter(lambda obj: obj is not None, map(get, dns)))
def __len__(self):
return len(set(self))
def add(self, value):
self.__modify_check(value)
if value.ldap_object.session is None:
add_to_session(value, self.__ldap_object.session)
assert value.ldap_object.session == self.__ldap_object.session
self.__ldap_object.attr_append(self.__name, value.dn)
def discard(self, value):
self.__modify_check(value)
self.__ldap_object.attr_remove(self.__name, value.dn)
def update(self, values):
for value in values:
self.add(value)
class Relationship:
def __init__(self, name, destmodel, backref=None):
self.name = name
self.destmodel = destmodel
self.backref = backref
def __set_name__(self, cls, name):
if self.backref is not None:
setattr(self.destmodel, self.backref, Backreference(self.name, cls))
def __get__(self, obj, objtype=None):
if obj is None:
return self
return RelationshipSet(obj.ldap_object, self.name, type(obj), self.destmodel)
def __set__(self, obj, values):
tmp = self.__get__(obj)
tmp.clear()
for value in values:
tmp.add(value)
class BackreferenceSet(MutableSet):
def __init__(self, ldap_object, name, model, srcmodel):
self.__ldap_object = ldap_object
self.__name = name
self.__model = model # pylint: disable=unused-private-member
self.__srcmodel = srcmodel
def __modify_check(self, value):
if self.__ldap_object.session is None:
raise UnboundObjectError()
if not isinstance(value, self.__srcmodel):
raise TypeError()
def __get(self):
if self.__ldap_object.session is None:
return set()
filter_params = list(self.__srcmodel.ldap_filter_params) + [(self.__name, self.__ldap_object.dn)]
objs = self.__ldap_object.session.filter(self.__srcmodel.ldap_search_base, filter_params)
return set(make_modelobjs(objs, self.__srcmodel))
def __repr__(self):
return repr(self.__get())
def __contains__(self, value):
return value in self.__get()
def __iter__(self):
return iter(self.__get())
def __len__(self):
return len(self.__get())
def add(self, value):
self.__modify_check(value)
if value.ldap_object.session is None:
add_to_session(value, self.__ldap_object.session)
assert value.ldap_object.session == self.__ldap_object.session
if self.__ldap_object.dn not in value.ldap_object.getattr(self.__name):
value.ldap_object.attr_append(self.__name, self.__ldap_object.dn)
def discard(self, value):
self.__modify_check(value)
value.ldap_object.attr_remove(self.__name, self.__ldap_object.dn)
def update(self, values):
for value in values:
self.add(value)
class Backreference:
def __init__(self, name, srcmodel):
self.name = name
self.srcmodel = srcmodel
def __get__(self, obj, objtype=None):
if obj is None:
return self
return BackreferenceSet(obj.ldap_object, self.name, type(obj), self.srcmodel)
def __set__(self, obj, values):
tmp = self.__get__(obj)
tmp.clear()
for value in values:
tmp.add(value)
from .views import bp as bp_ui
bp = [bp_ui]
from uffd.ldap import ldap
from uffd.lazyconfig import lazyconfig_str, lazyconfig_list
class Mail(ldap.Model):
ldap_search_base = lazyconfig_str('LDAP_MAIL_SEARCH_BASE')
ldap_filter_params = lazyconfig_list('LDAP_MAIL_SEARCH_FILTER')
ldap_object_classes = lazyconfig_list('LDAP_MAIL_OBJECTCLASSES')
ldap_dn_attribute = lazyconfig_str('LDAP_MAIL_DN_ATTRIBUTE')
ldap_dn_base = lazyconfig_str('LDAP_MAIL_SEARCH_BASE')
uid = ldap.Attribute(lazyconfig_str('LDAP_MAIL_UID_ATTRIBUTE'))
receivers = ldap.Attribute(lazyconfig_str('LDAP_MAIL_RECEIVERS_ATTRIBUTE'), multi=True)
destinations = ldap.Attribute(lazyconfig_str('LDAP_MAIL_DESTINATIONS_ATTRIBUTE'), multi=True)
from .views import bp as _bp
bp = [_bp]
from warnings import warn
import urllib.parse
from flask import Blueprint, render_template, session, request, redirect, url_for, flash, current_app, abort
from flask_babel import gettext as _
from uffd.database import db
from uffd.ldap import ldap
from uffd.mfa.models import MFAMethod, TOTPMethod, WebauthnMethod, RecoveryCodeMethod
from uffd.session.views import login_required, login_required_pre_mfa, set_request_user
from uffd.user.models import User
from uffd.csrf import csrf_protect
from uffd.secure_redirect import secure_local_redirect
from uffd.ratelimit import Ratelimit, format_delay
bp = Blueprint('mfa', __name__, template_folder='templates', url_prefix='/mfa/')
mfa_ratelimit = Ratelimit('mfa', 1*60, 3)
@bp.route('/', methods=['GET'])
@login_required()
def setup():
return render_template('mfa/setup.html')
@bp.route('/setup/disable', methods=['GET'])
@login_required()
def disable():
return render_template('mfa/disable.html')
@bp.route('/setup/disable', methods=['POST'])
@login_required()
@csrf_protect(blueprint=bp)
def disable_confirm():
MFAMethod.query.filter_by(dn=request.user.dn).delete()
db.session.commit()
request.user.update_groups()
ldap.session.commit()
return redirect(url_for('mfa.setup'))
@bp.route('/admin/<int:uid>/disable')
@login_required()
@csrf_protect(blueprint=bp)
def admin_disable(uid):
# Group cannot be checked with login_required kwarg, because the config
# variable is not available when the decorator is processed
if not request.user.is_in_group(current_app.config['ACL_ADMIN_GROUP']):
flash('Access denied')
return redirect(url_for('index'))
user = User.query.filter_by(uid=uid).one()
MFAMethod.query.filter_by(dn=user.dn).delete()
db.session.commit()
user.update_groups()
ldap.session.commit()
flash(_('Two-factor authentication was reset'))
return redirect(url_for('user.show', uid=uid))
@bp.route('/setup/recovery', methods=['POST'])
@login_required()
@csrf_protect(blueprint=bp)
def setup_recovery():
for method in RecoveryCodeMethod.query.filter_by(dn=request.user.dn).all():
db.session.delete(method)
methods = []
for _ in range(10):
method = RecoveryCodeMethod(request.user)
methods.append(method)
db.session.add(method)
db.session.commit()
return render_template('mfa/setup_recovery.html', methods=methods)
@bp.route('/setup/totp', methods=['GET'])
@login_required()
def setup_totp():
method = TOTPMethod(request.user)
session['mfa_totp_key'] = method.key
return render_template('mfa/setup_totp.html', method=method, name=request.values['name'])
@bp.route('/setup/totp', methods=['POST'])
@login_required()
@csrf_protect(blueprint=bp)
def setup_totp_finish():
if not RecoveryCodeMethod.query.filter_by(dn=request.user.dn).all():
flash(_('Generate recovery codes first!'))
return redirect(url_for('mfa.setup'))
method = TOTPMethod(request.user, name=request.values['name'], key=session.pop('mfa_totp_key'))
if method.verify(request.form['code']):
db.session.add(method)
db.session.commit()
request.user.update_groups()
ldap.session.commit()
return redirect(url_for('mfa.setup'))
flash(_('Code is invalid'))
return redirect(url_for('mfa.setup_totp', name=request.values['name']))
@bp.route('/setup/totp/<int:id>/delete')
@login_required()
@csrf_protect(blueprint=bp)
def delete_totp(id): #pylint: disable=redefined-builtin
method = TOTPMethod.query.filter_by(dn=request.user.dn, id=id).first_or_404()
db.session.delete(method)
db.session.commit()
request.user.update_groups()
ldap.session.commit()
return redirect(url_for('mfa.setup'))
# WebAuthn support is optional because fido2 has a pretty unstable
# interface (v0.5.0 on buster and current version are completely
# incompatible) and might be difficult to install with the correct version
try:
from fido2.client import ClientData
from fido2.server import Fido2Server, RelyingParty
from fido2.ctap2 import AttestationObject, AuthenticatorData
from fido2 import cbor
WEBAUTHN_SUPPORTED = True
except ImportError as err:
warn(_('2FA WebAuthn support disabled because import of the fido2 module failed (%s)')%err)
WEBAUTHN_SUPPORTED = False
bp.add_app_template_global(WEBAUTHN_SUPPORTED, name='webauthn_supported')
if WEBAUTHN_SUPPORTED:
def get_webauthn_server():
return Fido2Server(RelyingParty(current_app.config.get('MFA_RP_ID', urllib.parse.urlsplit(request.url).hostname), current_app.config['MFA_RP_NAME']))
@bp.route('/setup/webauthn/begin', methods=['POST'])
@login_required()
@csrf_protect(blueprint=bp)
def setup_webauthn_begin():
if not RecoveryCodeMethod.query.filter_by(dn=request.user.dn).all():
abort(403)
methods = WebauthnMethod.query.filter_by(dn=request.user.dn).all()
creds = [method.cred for method in methods]
server = get_webauthn_server()
registration_data, state = server.register_begin(
{
"id": request.user.dn.encode(),
"name": request.user.loginname,
"displayName": request.user.displayname,
},
creds,
user_verification='discouraged',
)
session["webauthn-state"] = state
return cbor.dumps(registration_data)
@bp.route('/setup/webauthn/complete', methods=['POST'])
@login_required()
@csrf_protect(blueprint=bp)
def setup_webauthn_complete():
server = get_webauthn_server()
data = cbor.loads(request.get_data())[0]
client_data = ClientData(data["clientDataJSON"])
att_obj = AttestationObject(data["attestationObject"])
auth_data = server.register_complete(session["webauthn-state"], client_data, att_obj)
method = WebauthnMethod(request.user, auth_data.credential_data, name=data['name'])
db.session.add(method)
db.session.commit()
request.user.update_groups()
ldap.session.commit()
return cbor.dumps({"status": "OK"})
@bp.route("/auth/webauthn/begin", methods=["POST"])
@login_required_pre_mfa(no_redirect=True)
def auth_webauthn_begin():
server = get_webauthn_server()
creds = [method.cred for method in request.user_pre_mfa.mfa_webauthn_methods]
if not creds:
abort(404)
auth_data, state = server.authenticate_begin(creds, user_verification='discouraged')
session["webauthn-state"] = state
return cbor.dumps(auth_data)
@bp.route("/auth/webauthn/complete", methods=["POST"])
@login_required_pre_mfa(no_redirect=True)
def auth_webauthn_complete():
server = get_webauthn_server()
creds = [method.cred for method in request.user_pre_mfa.mfa_webauthn_methods]
if not creds:
abort(404)
data = cbor.loads(request.get_data())[0]
credential_id = data["credentialId"]
client_data = ClientData(data["clientDataJSON"])
auth_data = AuthenticatorData(data["authenticatorData"])
signature = data["signature"]
# authenticate_complete() (as of python-fido2 v0.5.0, the version in Debian Buster)
# does not check signCount, although the spec recommends it
server.authenticate_complete(
session.pop("webauthn-state"),
creds,
credential_id,
client_data,
auth_data,
signature,
)
session['user_mfa'] = True
set_request_user()
return cbor.dumps({"status": "OK"})
@bp.route('/setup/webauthn/<int:id>/delete')
@login_required()
@csrf_protect(blueprint=bp)
def delete_webauthn(id): #pylint: disable=redefined-builtin
method = WebauthnMethod.query.filter_by(dn=request.user.dn, id=id).first_or_404()
db.session.delete(method)
db.session.commit()
request.user.update_groups()
ldap.session.commit()
return redirect(url_for('mfa.setup'))
@bp.route('/auth', methods=['GET'])
@login_required_pre_mfa()
def auth():
if not request.user_pre_mfa.mfa_enabled:
session['user_mfa'] = True
set_request_user()
if session.get('user_mfa'):
return secure_local_redirect(request.values.get('ref', url_for('index')))
return render_template('mfa/auth.html', ref=request.values.get('ref'))
@bp.route('/auth', methods=['POST'])
@login_required_pre_mfa()
def auth_finish():
delay = mfa_ratelimit.get_delay(request.user_pre_mfa.dn)
if delay:
flash(_('We received too many invalid attempts! Please wait at least %s.')%format_delay(delay))
return redirect(url_for('mfa.auth', ref=request.values.get('ref')))
for method in request.user_pre_mfa.mfa_totp_methods:
if method.verify(request.form['code']):
session['user_mfa'] = True
set_request_user()
return secure_local_redirect(request.values.get('ref', url_for('index')))
for method in request.user_pre_mfa.mfa_recovery_codes:
if method.verify(request.form['code']):
db.session.delete(method)
db.session.commit()
session['user_mfa'] = True
set_request_user()
if len(request.user_pre_mfa.mfa_recovery_codes) <= 1:
flash(_('You have exhausted your recovery codes. Please generate new ones now!'))
return redirect(url_for('mfa.setup'))
if len(request.user_pre_mfa.mfa_recovery_codes) <= 5:
flash(_('You only have a few recovery codes remaining. Make sure to generate new ones before they run out.'))
return redirect(url_for('mfa.setup'))
return secure_local_redirect(request.values.get('ref', url_for('index')))
mfa_ratelimit.log(request.user_pre_mfa.dn)
flash(_('Two-factor authentication failed'))
return redirect(url_for('mfa.auth', ref=request.values.get('ref')))