diff --git a/tests/migrations/test_fuzzy.py b/tests/migrations/test_fuzzy.py index 310889b99dc9c2c88b4a2a01b4d2e880d145ebfe..597e9b33c5a2614b452739ad74d998d29db4fa46 100644 --- a/tests/migrations/test_fuzzy.py +++ b/tests/migrations/test_fuzzy.py @@ -61,8 +61,6 @@ class TestFuzzy(MigrationTestCase): service = Service(name='testservice', access_group=group) oauth2_client = OAuth2Client(service=service, client_id='testclient', client_secret='testsecret', redirect_uris=['http://localhost:1234/callback'], logout_uris=[OAuth2LogoutURI(method='GET', uri='http://localhost:1234/callback')]) db.session.add_all([service, oauth2_client]) - db.session.add(OAuth2Grant(user=user, client=oauth2_client, _code='testcode', redirect_uri='http://example.com/callback', expires=datetime.datetime.now())) - db.session.add(OAuth2Token(user=user, client=oauth2_client, token_type='Bearer', _access_token='testcode', _refresh_token='testcode', expires=datetime.datetime.now())) session = Session( user=user, secret='0919de9da3f7dc6c33ab849f44c20e8221b673ca701030de17488f3269fc5469f100e2ce56e5fd71305b23d8ecbb06d80d22004adcd3fefc5f5fcb80a436e31f2c2d9cc8fe8c59ae44871ae4524408d312474570280bf29d3ba145a4bd00010ca758eaa0795b180ec12978b42d13bf4c4f06f72103d44077022ce656610be855', @@ -71,6 +69,8 @@ class TestFuzzy(MigrationTestCase): mfa_done=True, ) db.session.add(session) + db.session.add(OAuth2Grant(session=session, client=oauth2_client, _code='testcode', redirect_uri='http://example.com/callback', expires=datetime.datetime.now())) + db.session.add(OAuth2Token(session=session, client=oauth2_client, token_type='Bearer', _access_token='testcode', _refresh_token='testcode', expires=datetime.datetime.now())) db.session.add(OAuth2DeviceLoginInitiation(client=oauth2_client, confirmations=[DeviceLoginConfirmation(session=session)])) db.session.add(PasswordToken(user=user)) db.session.commit() diff --git a/uffd/migrations/versions/e71e29cc605a_migrate_oauth2_state_from_user_to_session.py b/uffd/migrations/versions/e71e29cc605a_migrate_oauth2_state_from_user_to_session.py new file mode 100644 index 0000000000000000000000000000000000000000..f70dadce200c24a46d10a3203c37fc6b4f85e9d1 --- /dev/null +++ b/uffd/migrations/versions/e71e29cc605a_migrate_oauth2_state_from_user_to_session.py @@ -0,0 +1,108 @@ +"""Migrate oauth2 state from user to session + +Revision ID: e71e29cc605a +Revises: 99df71f0f4a0 +Create Date: 2024-05-18 21:59:20.435912 + +""" +from alembic import op +import sqlalchemy as sa + +revision = 'e71e29cc605a' +down_revision = '99df71f0f4a0' +branch_labels = None +depends_on = None + +def upgrade(): + op.drop_table('oauth2grant') + op.drop_table('oauth2token') + op.create_table('oauth2grant', + sa.Column('id', sa.Integer(), autoincrement=True, nullable=False), + sa.Column('expires', sa.DateTime(), nullable=False), + sa.Column('session_id', sa.Integer(), nullable=False), + sa.Column('client_db_id', sa.Integer(), nullable=False), + sa.Column('code', sa.String(length=255), nullable=False), + sa.Column('redirect_uri', sa.String(length=255), nullable=True), + sa.Column('nonce', sa.Text(), nullable=True), + sa.Column('_scopes', sa.Text(), nullable=False), + sa.Column('claims', sa.Text(), nullable=True), + sa.ForeignKeyConstraint(['client_db_id'], ['oauth2client.db_id'], name=op.f('fk_oauth2grant_client_db_id_oauth2client'), onupdate='CASCADE', ondelete='CASCADE'), + sa.ForeignKeyConstraint(['session_id'], ['session.id'], name=op.f('fk_oauth2grant_session_id_session'), onupdate='CASCADE', ondelete='CASCADE'), + sa.PrimaryKeyConstraint('id', name=op.f('pk_oauth2grant')) + ) + op.create_table('oauth2token', + sa.Column('id', sa.Integer(), autoincrement=True, nullable=False), + sa.Column('expires', sa.DateTime(), nullable=False), + sa.Column('session_id', sa.Integer(), nullable=False), + sa.Column('client_db_id', sa.Integer(), nullable=False), + sa.Column('token_type', sa.String(length=40), nullable=False), + sa.Column('access_token', sa.String(length=255), nullable=False), + sa.Column('refresh_token', sa.String(length=255), nullable=False), + sa.Column('_scopes', sa.Text(), nullable=False), + sa.Column('claims', sa.Text(), nullable=True), + sa.ForeignKeyConstraint(['client_db_id'], ['oauth2client.db_id'], name=op.f('fk_oauth2token_client_db_id_oauth2client'), onupdate='CASCADE', ondelete='CASCADE'), + sa.ForeignKeyConstraint(['session_id'], ['session.id'], name=op.f('fk_oauth2token_session_id_session'), onupdate='CASCADE', ondelete='CASCADE'), + sa.PrimaryKeyConstraint('id', name=op.f('pk_oauth2token')), + sa.UniqueConstraint('access_token', name=op.f('uq_oauth2token_access_token')), + sa.UniqueConstraint('refresh_token', name=op.f('uq_oauth2token_refresh_token')) + ) + +def downgrade(): + # We don't drop and recreate the table here to improve fuzzy migration test coverage + meta = sa.MetaData(bind=op.get_bind()) + session = sa.table('session', + sa.column('id', sa.Integer), + sa.column('user_id', sa.Integer()), + ) + + with op.batch_alter_table('oauth2token', schema=None) as batch_op: + batch_op.add_column(sa.Column('user_id', sa.INTEGER(), nullable=True)) + oauth2token = sa.Table('oauth2token', meta, + sa.Column('id', sa.Integer(), autoincrement=True, nullable=False), + sa.Column('expires', sa.DateTime(), nullable=False), + sa.Column('session_id', sa.Integer(), nullable=False), + sa.Column('user_id', sa.Integer(), nullable=True), + sa.Column('client_db_id', sa.Integer(), nullable=False), + sa.Column('token_type', sa.String(length=40), nullable=False), + sa.Column('access_token', sa.String(length=255), nullable=False), + sa.Column('refresh_token', sa.String(length=255), nullable=False), + sa.Column('_scopes', sa.Text(), nullable=False), + sa.Column('claims', sa.Text(), nullable=True), + sa.ForeignKeyConstraint(['client_db_id'], ['oauth2client.db_id'], name=op.f('fk_oauth2token_client_db_id_oauth2client'), onupdate='CASCADE', ondelete='CASCADE'), + sa.ForeignKeyConstraint(['session_id'], ['session.id'], name=op.f('fk_oauth2token_session_id_session'), onupdate='CASCADE', ondelete='CASCADE'), + sa.PrimaryKeyConstraint('id', name=op.f('pk_oauth2token')), + sa.UniqueConstraint('access_token', name=op.f('uq_oauth2token_access_token')), + sa.UniqueConstraint('refresh_token', name=op.f('uq_oauth2token_refresh_token')) + ) + op.execute(oauth2token.update().values(user_id=sa.select([session.c.user_id]).where(oauth2token.c.session_id==session.c.id).as_scalar())) + op.execute(oauth2token.delete().where(oauth2token.c.user_id==None)) + with op.batch_alter_table('oauth2token', copy_from=oauth2token) as batch_op: + batch_op.alter_column('user_id', nullable=False, existing_type=sa.Integer()) + batch_op.create_foreign_key('fk_oauth2token_user_id_user', 'user', ['user_id'], ['id'], onupdate='CASCADE', ondelete='CASCADE') + batch_op.drop_constraint(batch_op.f('fk_oauth2token_session_id_session'), type_='foreignkey') + batch_op.drop_column('session_id') + + with op.batch_alter_table('oauth2grant', schema=None) as batch_op: + batch_op.add_column(sa.Column('user_id', sa.INTEGER(), nullable=True)) + oauth2grant = sa.Table('oauth2grant', meta, + sa.Column('id', sa.Integer(), autoincrement=True, nullable=False), + sa.Column('expires', sa.DateTime(), nullable=False), + sa.Column('session_id', sa.Integer(), nullable=False), + sa.Column('user_id', sa.Integer(), nullable=True), + sa.Column('client_db_id', sa.Integer(), nullable=False), + sa.Column('code', sa.String(length=255), nullable=False), + sa.Column('redirect_uri', sa.String(length=255), nullable=True), + sa.Column('nonce', sa.Text(), nullable=True), + sa.Column('_scopes', sa.Text(), nullable=False), + sa.Column('claims', sa.Text(), nullable=True), + sa.ForeignKeyConstraint(['client_db_id'], ['oauth2client.db_id'], name=op.f('fk_oauth2grant_client_db_id_oauth2client'), onupdate='CASCADE', ondelete='CASCADE'), + sa.ForeignKeyConstraint(['session_id'], ['session.id'], name=op.f('fk_oauth2grant_session_id_session'), onupdate='CASCADE', ondelete='CASCADE'), + sa.PrimaryKeyConstraint('id', name=op.f('pk_oauth2grant')) + ) + op.execute(oauth2grant.update().values(user_id=sa.select([session.c.user_id]).where(oauth2grant.c.session_id==session.c.id).as_scalar())) + op.execute(oauth2grant.delete().where(oauth2grant.c.user_id==None)) + with op.batch_alter_table('oauth2grant', copy_from=oauth2grant) as batch_op: + batch_op.alter_column('user_id', nullable=False, existing_type=sa.Integer()) + batch_op.create_foreign_key('fk_oauth2grant_user_id_user', 'user', ['user_id'], ['id'], onupdate='CASCADE', ondelete='CASCADE') + batch_op.drop_constraint(batch_op.f('fk_oauth2grant_session_id_session'), type_='foreignkey') + batch_op.drop_column('session_id') diff --git a/uffd/models/oauth2.py b/uffd/models/oauth2.py index e8fd0cdd0a79cf5669d13c9fcf6f3d2c065c228a..bdd503872384168982aa823aa85795752381bc94 100644 --- a/uffd/models/oauth2.py +++ b/uffd/models/oauth2.py @@ -72,8 +72,8 @@ class OAuth2Grant(db.Model): EXPIRES_IN = 100 expires = Column(DateTime, nullable=False, default=lambda: datetime.datetime.utcnow() + datetime.timedelta(seconds=OAuth2Grant.EXPIRES_IN)) - user_id = Column(Integer(), ForeignKey('user.id', onupdate='CASCADE', ondelete='CASCADE'), nullable=False) - user = relationship('User') + session_id = Column(Integer(), ForeignKey('session.id', onupdate='CASCADE', ondelete='CASCADE'), nullable=False) + session = relationship('Session') client_db_id = Column(Integer, ForeignKey('oauth2client.db_id', onupdate='CASCADE', ondelete='CASCADE'), nullable=False) client = relationship('OAuth2Client') @@ -96,10 +96,7 @@ class OAuth2Grant(db.Model): @property def service_user(self): - service_user = ServiceUser.query.get((self.client.service_id, self.user.id)) - if service_user is None: - raise Exception('ServiceUser lookup failed') - return service_user + return ServiceUser.query.get((self.client.service_id, self.session.user_id)) @hybrid_property def expired(self): @@ -116,20 +113,23 @@ class OAuth2Grant(db.Model): grant = cls.query.filter_by(id=grant_id, expired=False).first() if not grant or not secrets.compare_digest(grant._code, grant_code): return None - if grant.user.is_deactivated or not grant.client.access_allowed(grant.user): + if grant.session.expired or grant.session.user.is_deactivated: + return None + if not grant.service_user or not grant.service_user.has_access: return None return grant def make_token(self, **kwargs): return OAuth2Token( - user=self.user, + session=self.session, client=self.client, scopes=self.scopes, claims=self.claims, **kwargs ) -@cleanup_task.delete_by_attribute('expired') +# OAuth2Token objects are cleaned-up when the session expires and is +# auto-deleted (or the user manually revokes it). class OAuth2Token(db.Model): __tablename__ = 'oauth2token' id = Column(Integer, primary_key=True, autoincrement=True) @@ -137,8 +137,8 @@ class OAuth2Token(db.Model): EXPIRES_IN = 3600 expires = Column(DateTime, nullable=False, default=lambda: datetime.datetime.utcnow() + datetime.timedelta(seconds=OAuth2Token.EXPIRES_IN)) - user_id = Column(Integer(), ForeignKey('user.id', onupdate='CASCADE', ondelete='CASCADE'), nullable=False) - user = relationship('User') + session_id = Column(Integer(), ForeignKey('session.id', onupdate='CASCADE', ondelete='CASCADE'), nullable=False) + session = relationship('Session') client_db_id = Column(Integer, ForeignKey('oauth2client.db_id', onupdate='CASCADE', ondelete='CASCADE'), nullable=False) client = relationship('OAuth2Client') @@ -163,10 +163,7 @@ class OAuth2Token(db.Model): @property def service_user(self): - service_user = ServiceUser.query.get((self.client.service_id, self.user.id)) - if service_user is None: - raise Exception('ServiceUser lookup failed') - return service_user + return ServiceUser.query.get((self.client.service_id, self.session.user_id)) @hybrid_property def expired(self): @@ -181,7 +178,9 @@ class OAuth2Token(db.Model): token = cls.query.filter_by(id=token_id, expired=False).first() if not token or not secrets.compare_digest(token._access_token, token_secret): return None - if token.user.is_deactivated or not token.client.access_allowed(token.user): + if token.session.expired or token.session.user.is_deactivated: + return None + if not token.service_user or not token.service_user.has_access: return None return token diff --git a/uffd/models/session.py b/uffd/models/session.py index b29ba3e245153a530961e1e3433ed5938318f990..368dbaaefece27c27bbe8d8c34dcb1aa3f2019de 100644 --- a/uffd/models/session.py +++ b/uffd/models/session.py @@ -29,6 +29,8 @@ class Session(db.Model): user_id = Column(Integer(), ForeignKey('user.id', onupdate='CASCADE', ondelete='CASCADE'), nullable=False) user = relationship('User', back_populates='sessions') + oauth2_grants = relationship('OAuth2Grant', back_populates='session', cascade='all, delete-orphan') + oauth2_tokens = relationship('OAuth2Token', back_populates='session', cascade='all, delete-orphan') created = Column(DateTime, default=datetime.datetime.utcnow, nullable=False) last_used = Column(DateTime, default=datetime.datetime.utcnow, nullable=False) diff --git a/uffd/views/oauth2.py b/uffd/views/oauth2.py index f9a0b7b802cb07ad34af4b56635e49a85ad36392..8da8a68d593463ef36f0f351a2ea92e02f2ee0ae 100644 --- a/uffd/views/oauth2.py +++ b/uffd/views/oauth2.py @@ -254,8 +254,8 @@ def authorize_validate_request_oidc(grant): return grant, sub_value, prompt_values def authorize_user(client): - if request.user: - return request.user + if request.session: + return request.session if 'devicelogin_started' in session: del session['devicelogin_started'] @@ -298,7 +298,7 @@ def authorize_user(client): ) db.session.delete(initiation) db.session.commit() - return confirmation.session.user + return confirmation.session raise LoginRequiredError( flash_message=_('You need to login to access this service'), @@ -325,18 +325,18 @@ def authorize(): return render_template('oauth2/error.html', **err.params), 400 try: - user = authorize_user(grant.client) - if sub_value is not None and str(user.unix_uid) != sub_value: + _session = authorize_user(grant.client) + if sub_value is not None and str(_session.user.unix_uid) != sub_value: # We only reach this point in OIDC requests with prompt=none, see # authorize_validate_request_oidc. So this LoginRequiredError is # always returned as a redirect back to the client. raise LoginRequiredError() - if not grant.client.access_allowed(user): + if not grant.client.access_allowed(_session.user): raise AccessDeniedError(flash_message=_( "You don't have the permission to access the service <b>%(service_name)s</b>.", service_name=grant.client.service.name )) - grant.user = user + grant.session = _session except LoginRequiredError as err: # We abuse LoginRequiredError to signal a redirect to the login page if is_oidc and 'none' in prompt_values: @@ -350,9 +350,6 @@ def authorize(): return oauth2_redirect(**err.params) abort(403, description=err.flash_message) - session['oauth2-clients'] = session.get('oauth2-clients', []) - if grant.client.client_id not in session['oauth2-clients']: - session['oauth2-clients'].append(grant.client.client_id) db.session.add(grant) db.session.commit() return oauth2_redirect(code=grant.code) @@ -488,9 +485,11 @@ def userinfo(): @bp.app_url_defaults def inject_logout_params(endpoint, values): - if endpoint != 'oauth2.logout' or not session.get('oauth2-clients'): + if endpoint != 'oauth2.logout' or not request.session: return - values['client_ids'] = ','.join(session['oauth2-clients']) + client_ids = set(token.client.client_id for token in request.session.oauth2_tokens) + if client_ids: + values['client_ids'] = ','.join(client_ids) @bp.route('/oauth2/logout') def logout():