Skip to content
Snippets Groups Projects
Commit 89f1ecdd authored by Julian's avatar Julian
Browse files

Bind OAuth2 state to sessions instead of users

Prerequisite for implementing missing OIDC features.
parent 08926d1f
Branches
No related tags found
No related merge requests found
...@@ -61,8 +61,6 @@ class TestFuzzy(MigrationTestCase): ...@@ -61,8 +61,6 @@ class TestFuzzy(MigrationTestCase):
service = Service(name='testservice', access_group=group) service = Service(name='testservice', access_group=group)
oauth2_client = OAuth2Client(service=service, client_id='testclient', client_secret='testsecret', redirect_uris=['http://localhost:1234/callback'], logout_uris=[OAuth2LogoutURI(method='GET', uri='http://localhost:1234/callback')]) oauth2_client = OAuth2Client(service=service, client_id='testclient', client_secret='testsecret', redirect_uris=['http://localhost:1234/callback'], logout_uris=[OAuth2LogoutURI(method='GET', uri='http://localhost:1234/callback')])
db.session.add_all([service, oauth2_client]) db.session.add_all([service, oauth2_client])
db.session.add(OAuth2Grant(user=user, client=oauth2_client, _code='testcode', redirect_uri='http://example.com/callback', expires=datetime.datetime.now()))
db.session.add(OAuth2Token(user=user, client=oauth2_client, token_type='Bearer', _access_token='testcode', _refresh_token='testcode', expires=datetime.datetime.now()))
session = Session( session = Session(
user=user, user=user,
secret='0919de9da3f7dc6c33ab849f44c20e8221b673ca701030de17488f3269fc5469f100e2ce56e5fd71305b23d8ecbb06d80d22004adcd3fefc5f5fcb80a436e31f2c2d9cc8fe8c59ae44871ae4524408d312474570280bf29d3ba145a4bd00010ca758eaa0795b180ec12978b42d13bf4c4f06f72103d44077022ce656610be855', secret='0919de9da3f7dc6c33ab849f44c20e8221b673ca701030de17488f3269fc5469f100e2ce56e5fd71305b23d8ecbb06d80d22004adcd3fefc5f5fcb80a436e31f2c2d9cc8fe8c59ae44871ae4524408d312474570280bf29d3ba145a4bd00010ca758eaa0795b180ec12978b42d13bf4c4f06f72103d44077022ce656610be855',
...@@ -71,6 +69,8 @@ class TestFuzzy(MigrationTestCase): ...@@ -71,6 +69,8 @@ class TestFuzzy(MigrationTestCase):
mfa_done=True, mfa_done=True,
) )
db.session.add(session) db.session.add(session)
db.session.add(OAuth2Grant(session=session, client=oauth2_client, _code='testcode', redirect_uri='http://example.com/callback', expires=datetime.datetime.now()))
db.session.add(OAuth2Token(session=session, client=oauth2_client, token_type='Bearer', _access_token='testcode', _refresh_token='testcode', expires=datetime.datetime.now()))
db.session.add(OAuth2DeviceLoginInitiation(client=oauth2_client, confirmations=[DeviceLoginConfirmation(session=session)])) db.session.add(OAuth2DeviceLoginInitiation(client=oauth2_client, confirmations=[DeviceLoginConfirmation(session=session)]))
db.session.add(PasswordToken(user=user)) db.session.add(PasswordToken(user=user))
db.session.commit() db.session.commit()
......
"""Migrate oauth2 state from user to session
Revision ID: e71e29cc605a
Revises: 99df71f0f4a0
Create Date: 2024-05-18 21:59:20.435912
"""
from alembic import op
import sqlalchemy as sa
revision = 'e71e29cc605a'
down_revision = '99df71f0f4a0'
branch_labels = None
depends_on = None
def upgrade():
op.drop_table('oauth2grant')
op.drop_table('oauth2token')
op.create_table('oauth2grant',
sa.Column('id', sa.Integer(), autoincrement=True, nullable=False),
sa.Column('expires', sa.DateTime(), nullable=False),
sa.Column('session_id', sa.Integer(), nullable=False),
sa.Column('client_db_id', sa.Integer(), nullable=False),
sa.Column('code', sa.String(length=255), nullable=False),
sa.Column('redirect_uri', sa.String(length=255), nullable=True),
sa.Column('nonce', sa.Text(), nullable=True),
sa.Column('_scopes', sa.Text(), nullable=False),
sa.Column('claims', sa.Text(), nullable=True),
sa.ForeignKeyConstraint(['client_db_id'], ['oauth2client.db_id'], name=op.f('fk_oauth2grant_client_db_id_oauth2client'), onupdate='CASCADE', ondelete='CASCADE'),
sa.ForeignKeyConstraint(['session_id'], ['session.id'], name=op.f('fk_oauth2grant_session_id_session'), onupdate='CASCADE', ondelete='CASCADE'),
sa.PrimaryKeyConstraint('id', name=op.f('pk_oauth2grant'))
)
op.create_table('oauth2token',
sa.Column('id', sa.Integer(), autoincrement=True, nullable=False),
sa.Column('expires', sa.DateTime(), nullable=False),
sa.Column('session_id', sa.Integer(), nullable=False),
sa.Column('client_db_id', sa.Integer(), nullable=False),
sa.Column('token_type', sa.String(length=40), nullable=False),
sa.Column('access_token', sa.String(length=255), nullable=False),
sa.Column('refresh_token', sa.String(length=255), nullable=False),
sa.Column('_scopes', sa.Text(), nullable=False),
sa.Column('claims', sa.Text(), nullable=True),
sa.ForeignKeyConstraint(['client_db_id'], ['oauth2client.db_id'], name=op.f('fk_oauth2token_client_db_id_oauth2client'), onupdate='CASCADE', ondelete='CASCADE'),
sa.ForeignKeyConstraint(['session_id'], ['session.id'], name=op.f('fk_oauth2token_session_id_session'), onupdate='CASCADE', ondelete='CASCADE'),
sa.PrimaryKeyConstraint('id', name=op.f('pk_oauth2token')),
sa.UniqueConstraint('access_token', name=op.f('uq_oauth2token_access_token')),
sa.UniqueConstraint('refresh_token', name=op.f('uq_oauth2token_refresh_token'))
)
def downgrade():
# We don't drop and recreate the table here to improve fuzzy migration test coverage
meta = sa.MetaData(bind=op.get_bind())
session = sa.table('session',
sa.column('id', sa.Integer),
sa.column('user_id', sa.Integer()),
)
with op.batch_alter_table('oauth2token', schema=None) as batch_op:
batch_op.add_column(sa.Column('user_id', sa.INTEGER(), nullable=True))
oauth2token = sa.Table('oauth2token', meta,
sa.Column('id', sa.Integer(), autoincrement=True, nullable=False),
sa.Column('expires', sa.DateTime(), nullable=False),
sa.Column('session_id', sa.Integer(), nullable=False),
sa.Column('user_id', sa.Integer(), nullable=True),
sa.Column('client_db_id', sa.Integer(), nullable=False),
sa.Column('token_type', sa.String(length=40), nullable=False),
sa.Column('access_token', sa.String(length=255), nullable=False),
sa.Column('refresh_token', sa.String(length=255), nullable=False),
sa.Column('_scopes', sa.Text(), nullable=False),
sa.Column('claims', sa.Text(), nullable=True),
sa.ForeignKeyConstraint(['client_db_id'], ['oauth2client.db_id'], name=op.f('fk_oauth2token_client_db_id_oauth2client'), onupdate='CASCADE', ondelete='CASCADE'),
sa.ForeignKeyConstraint(['session_id'], ['session.id'], name=op.f('fk_oauth2token_session_id_session'), onupdate='CASCADE', ondelete='CASCADE'),
sa.PrimaryKeyConstraint('id', name=op.f('pk_oauth2token')),
sa.UniqueConstraint('access_token', name=op.f('uq_oauth2token_access_token')),
sa.UniqueConstraint('refresh_token', name=op.f('uq_oauth2token_refresh_token'))
)
op.execute(oauth2token.update().values(user_id=sa.select([session.c.user_id]).where(oauth2token.c.session_id==session.c.id).as_scalar()))
op.execute(oauth2token.delete().where(oauth2token.c.user_id==None))
with op.batch_alter_table('oauth2token', copy_from=oauth2token) as batch_op:
batch_op.alter_column('user_id', nullable=False, existing_type=sa.Integer())
batch_op.create_foreign_key('fk_oauth2token_user_id_user', 'user', ['user_id'], ['id'], onupdate='CASCADE', ondelete='CASCADE')
batch_op.drop_constraint(batch_op.f('fk_oauth2token_session_id_session'), type_='foreignkey')
batch_op.drop_column('session_id')
with op.batch_alter_table('oauth2grant', schema=None) as batch_op:
batch_op.add_column(sa.Column('user_id', sa.INTEGER(), nullable=True))
oauth2grant = sa.Table('oauth2grant', meta,
sa.Column('id', sa.Integer(), autoincrement=True, nullable=False),
sa.Column('expires', sa.DateTime(), nullable=False),
sa.Column('session_id', sa.Integer(), nullable=False),
sa.Column('user_id', sa.Integer(), nullable=True),
sa.Column('client_db_id', sa.Integer(), nullable=False),
sa.Column('code', sa.String(length=255), nullable=False),
sa.Column('redirect_uri', sa.String(length=255), nullable=True),
sa.Column('nonce', sa.Text(), nullable=True),
sa.Column('_scopes', sa.Text(), nullable=False),
sa.Column('claims', sa.Text(), nullable=True),
sa.ForeignKeyConstraint(['client_db_id'], ['oauth2client.db_id'], name=op.f('fk_oauth2grant_client_db_id_oauth2client'), onupdate='CASCADE', ondelete='CASCADE'),
sa.ForeignKeyConstraint(['session_id'], ['session.id'], name=op.f('fk_oauth2grant_session_id_session'), onupdate='CASCADE', ondelete='CASCADE'),
sa.PrimaryKeyConstraint('id', name=op.f('pk_oauth2grant'))
)
op.execute(oauth2grant.update().values(user_id=sa.select([session.c.user_id]).where(oauth2grant.c.session_id==session.c.id).as_scalar()))
op.execute(oauth2grant.delete().where(oauth2grant.c.user_id==None))
with op.batch_alter_table('oauth2grant', copy_from=oauth2grant) as batch_op:
batch_op.alter_column('user_id', nullable=False, existing_type=sa.Integer())
batch_op.create_foreign_key('fk_oauth2grant_user_id_user', 'user', ['user_id'], ['id'], onupdate='CASCADE', ondelete='CASCADE')
batch_op.drop_constraint(batch_op.f('fk_oauth2grant_session_id_session'), type_='foreignkey')
batch_op.drop_column('session_id')
...@@ -72,8 +72,8 @@ class OAuth2Grant(db.Model): ...@@ -72,8 +72,8 @@ class OAuth2Grant(db.Model):
EXPIRES_IN = 100 EXPIRES_IN = 100
expires = Column(DateTime, nullable=False, default=lambda: datetime.datetime.utcnow() + datetime.timedelta(seconds=OAuth2Grant.EXPIRES_IN)) expires = Column(DateTime, nullable=False, default=lambda: datetime.datetime.utcnow() + datetime.timedelta(seconds=OAuth2Grant.EXPIRES_IN))
user_id = Column(Integer(), ForeignKey('user.id', onupdate='CASCADE', ondelete='CASCADE'), nullable=False) session_id = Column(Integer(), ForeignKey('session.id', onupdate='CASCADE', ondelete='CASCADE'), nullable=False)
user = relationship('User') session = relationship('Session')
client_db_id = Column(Integer, ForeignKey('oauth2client.db_id', onupdate='CASCADE', ondelete='CASCADE'), nullable=False) client_db_id = Column(Integer, ForeignKey('oauth2client.db_id', onupdate='CASCADE', ondelete='CASCADE'), nullable=False)
client = relationship('OAuth2Client') client = relationship('OAuth2Client')
...@@ -96,10 +96,7 @@ class OAuth2Grant(db.Model): ...@@ -96,10 +96,7 @@ class OAuth2Grant(db.Model):
@property @property
def service_user(self): def service_user(self):
service_user = ServiceUser.query.get((self.client.service_id, self.user.id)) return ServiceUser.query.get((self.client.service_id, self.session.user_id))
if service_user is None:
raise Exception('ServiceUser lookup failed')
return service_user
@hybrid_property @hybrid_property
def expired(self): def expired(self):
...@@ -116,20 +113,23 @@ class OAuth2Grant(db.Model): ...@@ -116,20 +113,23 @@ class OAuth2Grant(db.Model):
grant = cls.query.filter_by(id=grant_id, expired=False).first() grant = cls.query.filter_by(id=grant_id, expired=False).first()
if not grant or not secrets.compare_digest(grant._code, grant_code): if not grant or not secrets.compare_digest(grant._code, grant_code):
return None return None
if grant.user.is_deactivated or not grant.client.access_allowed(grant.user): if grant.session.expired or grant.session.user.is_deactivated:
return None
if not grant.service_user or not grant.service_user.has_access:
return None return None
return grant return grant
def make_token(self, **kwargs): def make_token(self, **kwargs):
return OAuth2Token( return OAuth2Token(
user=self.user, session=self.session,
client=self.client, client=self.client,
scopes=self.scopes, scopes=self.scopes,
claims=self.claims, claims=self.claims,
**kwargs **kwargs
) )
@cleanup_task.delete_by_attribute('expired') # OAuth2Token objects are cleaned-up when the session expires and is
# auto-deleted (or the user manually revokes it).
class OAuth2Token(db.Model): class OAuth2Token(db.Model):
__tablename__ = 'oauth2token' __tablename__ = 'oauth2token'
id = Column(Integer, primary_key=True, autoincrement=True) id = Column(Integer, primary_key=True, autoincrement=True)
...@@ -137,8 +137,8 @@ class OAuth2Token(db.Model): ...@@ -137,8 +137,8 @@ class OAuth2Token(db.Model):
EXPIRES_IN = 3600 EXPIRES_IN = 3600
expires = Column(DateTime, nullable=False, default=lambda: datetime.datetime.utcnow() + datetime.timedelta(seconds=OAuth2Token.EXPIRES_IN)) expires = Column(DateTime, nullable=False, default=lambda: datetime.datetime.utcnow() + datetime.timedelta(seconds=OAuth2Token.EXPIRES_IN))
user_id = Column(Integer(), ForeignKey('user.id', onupdate='CASCADE', ondelete='CASCADE'), nullable=False) session_id = Column(Integer(), ForeignKey('session.id', onupdate='CASCADE', ondelete='CASCADE'), nullable=False)
user = relationship('User') session = relationship('Session')
client_db_id = Column(Integer, ForeignKey('oauth2client.db_id', onupdate='CASCADE', ondelete='CASCADE'), nullable=False) client_db_id = Column(Integer, ForeignKey('oauth2client.db_id', onupdate='CASCADE', ondelete='CASCADE'), nullable=False)
client = relationship('OAuth2Client') client = relationship('OAuth2Client')
...@@ -163,10 +163,7 @@ class OAuth2Token(db.Model): ...@@ -163,10 +163,7 @@ class OAuth2Token(db.Model):
@property @property
def service_user(self): def service_user(self):
service_user = ServiceUser.query.get((self.client.service_id, self.user.id)) return ServiceUser.query.get((self.client.service_id, self.session.user_id))
if service_user is None:
raise Exception('ServiceUser lookup failed')
return service_user
@hybrid_property @hybrid_property
def expired(self): def expired(self):
...@@ -181,7 +178,9 @@ class OAuth2Token(db.Model): ...@@ -181,7 +178,9 @@ class OAuth2Token(db.Model):
token = cls.query.filter_by(id=token_id, expired=False).first() token = cls.query.filter_by(id=token_id, expired=False).first()
if not token or not secrets.compare_digest(token._access_token, token_secret): if not token or not secrets.compare_digest(token._access_token, token_secret):
return None return None
if token.user.is_deactivated or not token.client.access_allowed(token.user): if token.session.expired or token.session.user.is_deactivated:
return None
if not token.service_user or not token.service_user.has_access:
return None return None
return token return token
......
...@@ -29,6 +29,8 @@ class Session(db.Model): ...@@ -29,6 +29,8 @@ class Session(db.Model):
user_id = Column(Integer(), ForeignKey('user.id', onupdate='CASCADE', ondelete='CASCADE'), nullable=False) user_id = Column(Integer(), ForeignKey('user.id', onupdate='CASCADE', ondelete='CASCADE'), nullable=False)
user = relationship('User', back_populates='sessions') user = relationship('User', back_populates='sessions')
oauth2_grants = relationship('OAuth2Grant', back_populates='session', cascade='all, delete-orphan')
oauth2_tokens = relationship('OAuth2Token', back_populates='session', cascade='all, delete-orphan')
created = Column(DateTime, default=datetime.datetime.utcnow, nullable=False) created = Column(DateTime, default=datetime.datetime.utcnow, nullable=False)
last_used = Column(DateTime, default=datetime.datetime.utcnow, nullable=False) last_used = Column(DateTime, default=datetime.datetime.utcnow, nullable=False)
......
...@@ -254,8 +254,8 @@ def authorize_validate_request_oidc(grant): ...@@ -254,8 +254,8 @@ def authorize_validate_request_oidc(grant):
return grant, sub_value, prompt_values return grant, sub_value, prompt_values
def authorize_user(client): def authorize_user(client):
if request.user: if request.session:
return request.user return request.session
if 'devicelogin_started' in session: if 'devicelogin_started' in session:
del session['devicelogin_started'] del session['devicelogin_started']
...@@ -298,7 +298,7 @@ def authorize_user(client): ...@@ -298,7 +298,7 @@ def authorize_user(client):
) )
db.session.delete(initiation) db.session.delete(initiation)
db.session.commit() db.session.commit()
return confirmation.session.user return confirmation.session
raise LoginRequiredError( raise LoginRequiredError(
flash_message=_('You need to login to access this service'), flash_message=_('You need to login to access this service'),
...@@ -325,18 +325,18 @@ def authorize(): ...@@ -325,18 +325,18 @@ def authorize():
return render_template('oauth2/error.html', **err.params), 400 return render_template('oauth2/error.html', **err.params), 400
try: try:
user = authorize_user(grant.client) _session = authorize_user(grant.client)
if sub_value is not None and str(user.unix_uid) != sub_value: if sub_value is not None and str(_session.user.unix_uid) != sub_value:
# We only reach this point in OIDC requests with prompt=none, see # We only reach this point in OIDC requests with prompt=none, see
# authorize_validate_request_oidc. So this LoginRequiredError is # authorize_validate_request_oidc. So this LoginRequiredError is
# always returned as a redirect back to the client. # always returned as a redirect back to the client.
raise LoginRequiredError() raise LoginRequiredError()
if not grant.client.access_allowed(user): if not grant.client.access_allowed(_session.user):
raise AccessDeniedError(flash_message=_( raise AccessDeniedError(flash_message=_(
"You don't have the permission to access the service <b>%(service_name)s</b>.", "You don't have the permission to access the service <b>%(service_name)s</b>.",
service_name=grant.client.service.name service_name=grant.client.service.name
)) ))
grant.user = user grant.session = _session
except LoginRequiredError as err: except LoginRequiredError as err:
# We abuse LoginRequiredError to signal a redirect to the login page # We abuse LoginRequiredError to signal a redirect to the login page
if is_oidc and 'none' in prompt_values: if is_oidc and 'none' in prompt_values:
...@@ -350,9 +350,6 @@ def authorize(): ...@@ -350,9 +350,6 @@ def authorize():
return oauth2_redirect(**err.params) return oauth2_redirect(**err.params)
abort(403, description=err.flash_message) abort(403, description=err.flash_message)
session['oauth2-clients'] = session.get('oauth2-clients', [])
if grant.client.client_id not in session['oauth2-clients']:
session['oauth2-clients'].append(grant.client.client_id)
db.session.add(grant) db.session.add(grant)
db.session.commit() db.session.commit()
return oauth2_redirect(code=grant.code) return oauth2_redirect(code=grant.code)
...@@ -488,9 +485,11 @@ def userinfo(): ...@@ -488,9 +485,11 @@ def userinfo():
@bp.app_url_defaults @bp.app_url_defaults
def inject_logout_params(endpoint, values): def inject_logout_params(endpoint, values):
if endpoint != 'oauth2.logout' or not session.get('oauth2-clients'): if endpoint != 'oauth2.logout' or not request.session:
return return
values['client_ids'] = ','.join(session['oauth2-clients']) client_ids = set(token.client.client_id for token in request.session.oauth2_tokens)
if client_ids:
values['client_ids'] = ','.join(client_ids)
@bp.route('/oauth2/logout') @bp.route('/oauth2/logout')
def logout(): def logout():
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment