diff --git a/tests/test_user.py b/tests/test_user.py index 8973d16c6928d67d8b097a062a1f7253661ea3ce..d641eb2200328770130c775975717880c0eb1ebd 100644 --- a/tests/test_user.py +++ b/tests/test_user.py @@ -2,6 +2,7 @@ import datetime import unittest from flask import url_for, session +import sqlalchemy # These imports are required, because otherwise we get circular imports?! from uffd import user @@ -37,6 +38,66 @@ class TestUserModel(UffdTestCase): self.assertFalse(user_.has_permission(['uffd_admin', ['users', 'notagroup']])) self.assertTrue(admin.has_permission(['uffd_admin', ['users', 'notagroup']])) + def test_unix_uid_generation(self): + self.app.config['USER_MIN_UID'] = 10000 + self.app.config['USER_MAX_UID'] = 18999 + self.app.config['USER_SERVICE_MIN_UID'] = 19000 + self.app.config['USER_SERVICE_MAX_UID'] =19999 + User.query.delete() + db.session.commit() + user0 = User(loginname='user0', displayname='user0', mail='user0@example.com') + user1 = User(loginname='user1', displayname='user1', mail='user1@example.com') + user2 = User(loginname='user2', displayname='user2', mail='user2@example.com') + db.session.add_all([user0, user1, user2]) + db.session.commit() + self.assertEqual(user0.unix_uid, 10000) + self.assertEqual(user1.unix_uid, 10001) + self.assertEqual(user2.unix_uid, 10002) + db.session.delete(user1) + db.session.commit() + user3 = User(loginname='user3', displayname='user3', mail='user3@example.com') + db.session.add(user3) + db.session.commit() + self.assertEqual(user3.unix_uid, 10003) + service0 = User(loginname='service0', displayname='service0', mail='service0@example.com', is_service_user=True) + service1 = User(loginname='service1', displayname='service1', mail='service1@example.com', is_service_user=True) + db.session.add_all([service0, service1]) + db.session.commit() + self.assertEqual(service0.unix_uid, 19000) + self.assertEqual(service1.unix_uid, 19001) + + def test_unix_uid_generation_overlapping(self): + self.app.config['USER_MIN_UID'] = 10000 + self.app.config['USER_MAX_UID'] = 19999 + self.app.config['USER_SERVICE_MIN_UID'] = 10000 + self.app.config['USER_SERVICE_MAX_UID'] = 19999 + User.query.delete() + db.session.commit() + user0 = User(loginname='user0', displayname='user0', mail='user0@example.com') + service0 = User(loginname='service0', displayname='service0', mail='service0@example.com', is_service_user=True) + user1 = User(loginname='user1', displayname='user1', mail='user1@example.com') + db.session.add_all([user0, service0, user1]) + db.session.commit() + self.assertEqual(user0.unix_uid, 10000) + self.assertEqual(service0.unix_uid, 10001) + self.assertEqual(user1.unix_uid, 10002) + + def test_unix_uid_generation_overflow(self): + self.app.config['USER_MIN_UID'] = 10000 + self.app.config['USER_MAX_UID'] = 10001 + User.query.delete() + db.session.commit() + user0 = User(loginname='user0', displayname='user0', mail='user0@example.com') + user1 = User(loginname='user1', displayname='user1', mail='user1@example.com') + db.session.add_all([user0, user1]) + db.session.commit() + self.assertEqual(user0.unix_uid, 10000) + self.assertEqual(user1.unix_uid, 10001) + with self.assertRaises(sqlalchemy.exc.IntegrityError): + user2 = User(loginname='user2', displayname='user2', mail='user2@example.com') + db.session.add(user2) + db.session.commit() + class TestUserViews(UffdTestCase): def setUp(self): super().setUp() @@ -446,6 +507,44 @@ class TestUserCLI(UffdTestCase): result = self.app.test_cli_runner().invoke(args=['user', 'delete', 'doesnotexist']) self.assertEqual(result.exit_code, 1) +class TestGroupModel(UffdTestCase): + def test_unix_gid_generation(self): + self.app.config['GROUP_MIN_GID'] = 20000 + self.app.config['GROUP_MAX_GID'] = 49999 + Group.query.delete() + db.session.commit() + group0 = Group(name='group0', description='group0') + group1 = Group(name='group1', description='group1') + group2 = Group(name='group2', description='group2') + db.session.add_all([group0, group1, group2]) + db.session.commit() + self.assertEqual(group0.unix_gid, 20000) + self.assertEqual(group1.unix_gid, 20001) + self.assertEqual(group2.unix_gid, 20002) + db.session.delete(group1) + db.session.commit() + group3 = Group(name='group3', description='group3') + db.session.add(group3) + db.session.commit() + self.assertEqual(group3.unix_gid, 20003) + + def test_unix_gid_generation(self): + self.app.config['GROUP_MIN_GID'] = 20000 + self.app.config['GROUP_MAX_GID'] = 20001 + Group.query.delete() + db.session.commit() + group0 = Group(name='group0', description='group0') + group1 = Group(name='group1', description='group1') + db.session.add_all([group0, group1]) + db.session.commit() + self.assertEqual(group0.unix_gid, 20000) + self.assertEqual(group1.unix_gid, 20001) + db.session.commit() + with self.assertRaises(sqlalchemy.exc.IntegrityError): + group2 = Group(name='group2', description='group2') + db.session.add(group2) + db.session.commit() + class TestGroupViews(UffdTestCase): def setUp(self): super().setUp() diff --git a/uffd/default_config.cfg b/uffd/default_config.cfg index 8eadca78f51e0df218abf0932ddf418725801e3b..718e4933f674ad670546d804131a71c18adf8fb2 100644 --- a/uffd/default_config.cfg +++ b/uffd/default_config.cfg @@ -1,8 +1,11 @@ USER_GID=20001 + +# Service and non-service users must either have the same UID range or must not overlap USER_MIN_UID=10000 USER_MAX_UID=18999 USER_SERVICE_MIN_UID=19000 USER_SERVICE_MAX_UID=19999 + GROUP_MIN_GID=20000 GROUP_MAX_GID=49999 diff --git a/uffd/user/models.py b/uffd/user/models.py index eee54cd0405f1d4f43d555e91237b774487a8cff..900ce743865941dd29ded81a3d0d0c93882e51c7 100644 --- a/uffd/user/models.py +++ b/uffd/user/models.py @@ -5,27 +5,10 @@ from flask import current_app, escape from flask_babel import lazy_gettext from sqlalchemy import Column, Integer, String, ForeignKey, Boolean, Text from sqlalchemy.orm import relationship -from sqlalchemy.sql.expression import func from uffd.database import db from uffd.password_hash import PasswordHashAttribute, LowEntropyPasswordHash -def get_next_unix_uid(context): - is_service_user = bool(context.get_current_parameters().get('is_service_user', False)) - if is_service_user: - min_uid = current_app.config['USER_SERVICE_MIN_UID'] - max_uid = current_app.config['USER_SERVICE_MAX_UID'] - else: - min_uid = current_app.config['USER_MIN_UID'] - max_uid = current_app.config['USER_MAX_UID'] - next_uid = max(min_uid, - db.session.query(func.max(User.unix_uid + 1))\ - .filter(User.is_service_user==is_service_user)\ - .scalar() or 0) - if next_uid > max_uid: - raise Exception('No free uid found') - return next_uid - # pylint: disable=E1101 user_groups = db.Table('user_groups', Column('user_id', Integer(), ForeignKey('user.id', onupdate='CASCADE', ondelete='CASCADE'), primary_key=True), @@ -49,7 +32,8 @@ class User(db.Model): __tablename__ = 'user' id = Column(Integer(), primary_key=True, autoincrement=True) - unix_uid = Column(Integer(), unique=True, nullable=False, default=get_next_unix_uid) + # Default is set in event handler below + unix_uid = Column(Integer(), unique=True, nullable=False) loginname = Column(String(32), unique=True, nullable=False) displayname = Column(String(128), nullable=False) mail = Column(String(128), nullable=False) @@ -120,17 +104,41 @@ class User(db.Model): def update_groups(self): pass -def get_next_unix_gid(): - next_gid = max(current_app.config['GROUP_MIN_GID'], - db.session.query(func.max(Group.unix_gid + 1)).scalar() or 0) - if next_gid > current_app.config['GROUP_MAX_GID']: - raise Exception('No free gid found') - return next_gid +def next_id_expr(column, min_value, max_value): + # db.func.max(column) + 1: highest used value in range + 1, NULL if no values in range + # db.func.min(..., max_value): clip to range + # db.func.coalesce(..., min_value): if NULL use min_value + # if range is exhausted, evaluates to max_value that violates the UNIQUE constraint + return db.select([db.func.coalesce(db.func.min(db.func.max(column) + 1, max_value), min_value)])\ + .where(column >= min_value)\ + .where(column <= max_value) + +# Emulates the behaviour of Column.default. We cannot use a static SQL +# expression like we do for Group.unix_gid, because we need context +# information. We also cannot set Column.default to a callable, because +# SQLAlchemy always treats the return value as a literal value and does +# not allow SQL expressions. +@db.event.listens_for(User, 'before_insert') +def set_default_unix_uid(mapper, connect, target): + # pylint: disable=unused-argument + if target.unix_uid is not None: + return + if target.is_service_user: + min_uid = current_app.config['USER_SERVICE_MIN_UID'] + max_uid = current_app.config['USER_SERVICE_MAX_UID'] + else: + min_uid = current_app.config['USER_MIN_UID'] + max_uid = current_app.config['USER_MAX_UID'] + target.unix_uid = next_id_expr(User.unix_uid, min_uid, max_uid) + +group_table = db.table('group', db.column('unix_gid')) +min_gid = db.bindparam('min_gid', unique=True, callable_=lambda: current_app.config['GROUP_MIN_GID'], type_=db.Integer) +max_gid = db.bindparam('max_gid', unique=True, callable_=lambda: current_app.config['GROUP_MAX_GID'], type_=db.Integer) class Group(db.Model): __tablename__ = 'group' id = Column(Integer(), primary_key=True, autoincrement=True) - unix_gid = Column(Integer(), unique=True, nullable=False, default=get_next_unix_gid) + unix_gid = Column(Integer(), unique=True, nullable=False, default=next_id_expr(group_table.c.unix_gid, min_gid, max_gid)) name = Column(String(32), unique=True, nullable=False) description = Column(String(128), nullable=False, default='') members = relationship('User', secondary='user_groups', back_populates='groups')