diff --git a/tests/migrations/test_fuzzy.py b/tests/migrations/test_fuzzy.py
index 310889b99dc9c2c88b4a2a01b4d2e880d145ebfe..597e9b33c5a2614b452739ad74d998d29db4fa46 100644
--- a/tests/migrations/test_fuzzy.py
+++ b/tests/migrations/test_fuzzy.py
@@ -61,8 +61,6 @@ class TestFuzzy(MigrationTestCase):
 		service = Service(name='testservice', access_group=group)
 		oauth2_client = OAuth2Client(service=service, client_id='testclient', client_secret='testsecret', redirect_uris=['http://localhost:1234/callback'], logout_uris=[OAuth2LogoutURI(method='GET', uri='http://localhost:1234/callback')])
 		db.session.add_all([service, oauth2_client])
-		db.session.add(OAuth2Grant(user=user, client=oauth2_client, _code='testcode', redirect_uri='http://example.com/callback', expires=datetime.datetime.now()))
-		db.session.add(OAuth2Token(user=user, client=oauth2_client, token_type='Bearer', _access_token='testcode', _refresh_token='testcode', expires=datetime.datetime.now()))
 		session = Session(
 			user=user,
 			secret='0919de9da3f7dc6c33ab849f44c20e8221b673ca701030de17488f3269fc5469f100e2ce56e5fd71305b23d8ecbb06d80d22004adcd3fefc5f5fcb80a436e31f2c2d9cc8fe8c59ae44871ae4524408d312474570280bf29d3ba145a4bd00010ca758eaa0795b180ec12978b42d13bf4c4f06f72103d44077022ce656610be855',
@@ -71,6 +69,8 @@ class TestFuzzy(MigrationTestCase):
 			mfa_done=True,
 		)
 		db.session.add(session)
+		db.session.add(OAuth2Grant(session=session, client=oauth2_client, _code='testcode', redirect_uri='http://example.com/callback', expires=datetime.datetime.now()))
+		db.session.add(OAuth2Token(session=session, client=oauth2_client, token_type='Bearer', _access_token='testcode', _refresh_token='testcode', expires=datetime.datetime.now()))
 		db.session.add(OAuth2DeviceLoginInitiation(client=oauth2_client, confirmations=[DeviceLoginConfirmation(session=session)]))
 		db.session.add(PasswordToken(user=user))
 		db.session.commit()
diff --git a/uffd/migrations/versions/e71e29cc605a_migrate_oauth2_state_from_user_to_session.py b/uffd/migrations/versions/e71e29cc605a_migrate_oauth2_state_from_user_to_session.py
new file mode 100644
index 0000000000000000000000000000000000000000..f70dadce200c24a46d10a3203c37fc6b4f85e9d1
--- /dev/null
+++ b/uffd/migrations/versions/e71e29cc605a_migrate_oauth2_state_from_user_to_session.py
@@ -0,0 +1,108 @@
+"""Migrate oauth2 state from user to session
+
+Revision ID: e71e29cc605a
+Revises: 99df71f0f4a0
+Create Date: 2024-05-18 21:59:20.435912
+
+"""
+from alembic import op
+import sqlalchemy as sa
+
+revision = 'e71e29cc605a'
+down_revision = '99df71f0f4a0'
+branch_labels = None
+depends_on = None
+
+def upgrade():
+	op.drop_table('oauth2grant')
+	op.drop_table('oauth2token')
+	op.create_table('oauth2grant',
+		sa.Column('id', sa.Integer(), autoincrement=True, nullable=False),
+		sa.Column('expires', sa.DateTime(), nullable=False),
+		sa.Column('session_id', sa.Integer(), nullable=False),
+		sa.Column('client_db_id', sa.Integer(), nullable=False),
+		sa.Column('code', sa.String(length=255), nullable=False),
+		sa.Column('redirect_uri', sa.String(length=255), nullable=True),
+		sa.Column('nonce', sa.Text(), nullable=True),
+		sa.Column('_scopes', sa.Text(), nullable=False),
+		sa.Column('claims', sa.Text(), nullable=True),
+		sa.ForeignKeyConstraint(['client_db_id'], ['oauth2client.db_id'], name=op.f('fk_oauth2grant_client_db_id_oauth2client'), onupdate='CASCADE', ondelete='CASCADE'),
+		sa.ForeignKeyConstraint(['session_id'], ['session.id'], name=op.f('fk_oauth2grant_session_id_session'), onupdate='CASCADE', ondelete='CASCADE'),
+		sa.PrimaryKeyConstraint('id', name=op.f('pk_oauth2grant'))
+	)
+	op.create_table('oauth2token',
+		sa.Column('id', sa.Integer(), autoincrement=True, nullable=False),
+		sa.Column('expires', sa.DateTime(), nullable=False),
+		sa.Column('session_id', sa.Integer(), nullable=False),
+		sa.Column('client_db_id', sa.Integer(), nullable=False),
+		sa.Column('token_type', sa.String(length=40), nullable=False),
+		sa.Column('access_token', sa.String(length=255), nullable=False),
+		sa.Column('refresh_token', sa.String(length=255), nullable=False),
+		sa.Column('_scopes', sa.Text(), nullable=False),
+		sa.Column('claims', sa.Text(), nullable=True),
+		sa.ForeignKeyConstraint(['client_db_id'], ['oauth2client.db_id'], name=op.f('fk_oauth2token_client_db_id_oauth2client'), onupdate='CASCADE', ondelete='CASCADE'),
+		sa.ForeignKeyConstraint(['session_id'], ['session.id'], name=op.f('fk_oauth2token_session_id_session'), onupdate='CASCADE', ondelete='CASCADE'),
+		sa.PrimaryKeyConstraint('id', name=op.f('pk_oauth2token')),
+		sa.UniqueConstraint('access_token', name=op.f('uq_oauth2token_access_token')),
+		sa.UniqueConstraint('refresh_token', name=op.f('uq_oauth2token_refresh_token'))
+	)
+
+def downgrade():
+	# We don't drop and recreate the table here to improve fuzzy migration test coverage
+	meta = sa.MetaData(bind=op.get_bind())
+	session = sa.table('session',
+		sa.column('id', sa.Integer),
+		sa.column('user_id', sa.Integer()),
+	)
+
+	with op.batch_alter_table('oauth2token', schema=None) as batch_op:
+		batch_op.add_column(sa.Column('user_id', sa.INTEGER(), nullable=True))
+	oauth2token = sa.Table('oauth2token', meta,
+		sa.Column('id', sa.Integer(), autoincrement=True, nullable=False),
+		sa.Column('expires', sa.DateTime(), nullable=False),
+		sa.Column('session_id', sa.Integer(), nullable=False),
+		sa.Column('user_id', sa.Integer(), nullable=True),
+		sa.Column('client_db_id', sa.Integer(), nullable=False),
+		sa.Column('token_type', sa.String(length=40), nullable=False),
+		sa.Column('access_token', sa.String(length=255), nullable=False),
+		sa.Column('refresh_token', sa.String(length=255), nullable=False),
+		sa.Column('_scopes', sa.Text(), nullable=False),
+		sa.Column('claims', sa.Text(), nullable=True),
+		sa.ForeignKeyConstraint(['client_db_id'], ['oauth2client.db_id'], name=op.f('fk_oauth2token_client_db_id_oauth2client'), onupdate='CASCADE', ondelete='CASCADE'),
+		sa.ForeignKeyConstraint(['session_id'], ['session.id'], name=op.f('fk_oauth2token_session_id_session'), onupdate='CASCADE', ondelete='CASCADE'),
+		sa.PrimaryKeyConstraint('id', name=op.f('pk_oauth2token')),
+		sa.UniqueConstraint('access_token', name=op.f('uq_oauth2token_access_token')),
+		sa.UniqueConstraint('refresh_token', name=op.f('uq_oauth2token_refresh_token'))
+	)
+	op.execute(oauth2token.update().values(user_id=sa.select([session.c.user_id]).where(oauth2token.c.session_id==session.c.id).as_scalar()))
+	op.execute(oauth2token.delete().where(oauth2token.c.user_id==None))
+	with op.batch_alter_table('oauth2token', copy_from=oauth2token) as batch_op:
+		batch_op.alter_column('user_id', nullable=False, existing_type=sa.Integer())
+		batch_op.create_foreign_key('fk_oauth2token_user_id_user', 'user', ['user_id'], ['id'], onupdate='CASCADE', ondelete='CASCADE')
+		batch_op.drop_constraint(batch_op.f('fk_oauth2token_session_id_session'), type_='foreignkey')
+		batch_op.drop_column('session_id')
+
+	with op.batch_alter_table('oauth2grant', schema=None) as batch_op:
+		batch_op.add_column(sa.Column('user_id', sa.INTEGER(), nullable=True))
+	oauth2grant = sa.Table('oauth2grant', meta,
+		sa.Column('id', sa.Integer(), autoincrement=True, nullable=False),
+		sa.Column('expires', sa.DateTime(), nullable=False),
+		sa.Column('session_id', sa.Integer(), nullable=False),
+		sa.Column('user_id', sa.Integer(), nullable=True),
+		sa.Column('client_db_id', sa.Integer(), nullable=False),
+		sa.Column('code', sa.String(length=255), nullable=False),
+		sa.Column('redirect_uri', sa.String(length=255), nullable=True),
+		sa.Column('nonce', sa.Text(), nullable=True),
+		sa.Column('_scopes', sa.Text(), nullable=False),
+		sa.Column('claims', sa.Text(), nullable=True),
+		sa.ForeignKeyConstraint(['client_db_id'], ['oauth2client.db_id'], name=op.f('fk_oauth2grant_client_db_id_oauth2client'), onupdate='CASCADE', ondelete='CASCADE'),
+		sa.ForeignKeyConstraint(['session_id'], ['session.id'], name=op.f('fk_oauth2grant_session_id_session'), onupdate='CASCADE', ondelete='CASCADE'),
+		sa.PrimaryKeyConstraint('id', name=op.f('pk_oauth2grant'))
+	)
+	op.execute(oauth2grant.update().values(user_id=sa.select([session.c.user_id]).where(oauth2grant.c.session_id==session.c.id).as_scalar()))
+	op.execute(oauth2grant.delete().where(oauth2grant.c.user_id==None))
+	with op.batch_alter_table('oauth2grant', copy_from=oauth2grant) as batch_op:
+		batch_op.alter_column('user_id', nullable=False, existing_type=sa.Integer())
+		batch_op.create_foreign_key('fk_oauth2grant_user_id_user', 'user', ['user_id'], ['id'], onupdate='CASCADE', ondelete='CASCADE')
+		batch_op.drop_constraint(batch_op.f('fk_oauth2grant_session_id_session'), type_='foreignkey')
+		batch_op.drop_column('session_id')
diff --git a/uffd/models/oauth2.py b/uffd/models/oauth2.py
index e8fd0cdd0a79cf5669d13c9fcf6f3d2c065c228a..bdd503872384168982aa823aa85795752381bc94 100644
--- a/uffd/models/oauth2.py
+++ b/uffd/models/oauth2.py
@@ -72,8 +72,8 @@ class OAuth2Grant(db.Model):
 	EXPIRES_IN = 100
 	expires = Column(DateTime, nullable=False, default=lambda: datetime.datetime.utcnow() + datetime.timedelta(seconds=OAuth2Grant.EXPIRES_IN))
 
-	user_id = Column(Integer(), ForeignKey('user.id', onupdate='CASCADE', ondelete='CASCADE'), nullable=False)
-	user = relationship('User')
+	session_id = Column(Integer(), ForeignKey('session.id', onupdate='CASCADE', ondelete='CASCADE'), nullable=False)
+	session = relationship('Session')
 
 	client_db_id = Column(Integer, ForeignKey('oauth2client.db_id', onupdate='CASCADE', ondelete='CASCADE'), nullable=False)
 	client = relationship('OAuth2Client')
@@ -96,10 +96,7 @@ class OAuth2Grant(db.Model):
 
 	@property
 	def service_user(self):
-		service_user = ServiceUser.query.get((self.client.service_id, self.user.id))
-		if service_user is None:
-			raise Exception('ServiceUser lookup failed')
-		return service_user
+		return ServiceUser.query.get((self.client.service_id, self.session.user_id))
 
 	@hybrid_property
 	def expired(self):
@@ -116,20 +113,23 @@ class OAuth2Grant(db.Model):
 		grant = cls.query.filter_by(id=grant_id, expired=False).first()
 		if not grant or not secrets.compare_digest(grant._code, grant_code):
 			return None
-		if grant.user.is_deactivated or not grant.client.access_allowed(grant.user):
+		if grant.session.expired or grant.session.user.is_deactivated:
+			return None
+		if not grant.service_user or not grant.service_user.has_access:
 			return None
 		return grant
 
 	def make_token(self, **kwargs):
 		return OAuth2Token(
-			user=self.user,
+			session=self.session,
 			client=self.client,
 			scopes=self.scopes,
 			claims=self.claims,
 			**kwargs
 		)
 
-@cleanup_task.delete_by_attribute('expired')
+# OAuth2Token objects are cleaned-up when the session expires and is
+# auto-deleted (or the user manually revokes it).
 class OAuth2Token(db.Model):
 	__tablename__ = 'oauth2token'
 	id = Column(Integer, primary_key=True, autoincrement=True)
@@ -137,8 +137,8 @@ class OAuth2Token(db.Model):
 	EXPIRES_IN = 3600
 	expires = Column(DateTime, nullable=False, default=lambda: datetime.datetime.utcnow() + datetime.timedelta(seconds=OAuth2Token.EXPIRES_IN))
 
-	user_id = Column(Integer(), ForeignKey('user.id', onupdate='CASCADE', ondelete='CASCADE'), nullable=False)
-	user = relationship('User')
+	session_id = Column(Integer(), ForeignKey('session.id', onupdate='CASCADE', ondelete='CASCADE'), nullable=False)
+	session = relationship('Session')
 
 	client_db_id = Column(Integer, ForeignKey('oauth2client.db_id', onupdate='CASCADE', ondelete='CASCADE'), nullable=False)
 	client = relationship('OAuth2Client')
@@ -163,10 +163,7 @@ class OAuth2Token(db.Model):
 
 	@property
 	def service_user(self):
-		service_user = ServiceUser.query.get((self.client.service_id, self.user.id))
-		if service_user is None:
-			raise Exception('ServiceUser lookup failed')
-		return service_user
+		return ServiceUser.query.get((self.client.service_id, self.session.user_id))
 
 	@hybrid_property
 	def expired(self):
@@ -181,7 +178,9 @@ class OAuth2Token(db.Model):
 		token = cls.query.filter_by(id=token_id, expired=False).first()
 		if not token or not secrets.compare_digest(token._access_token, token_secret):
 			return None
-		if token.user.is_deactivated or not token.client.access_allowed(token.user):
+		if token.session.expired or token.session.user.is_deactivated:
+			return None
+		if not token.service_user or not token.service_user.has_access:
 			return None
 		return token
 
diff --git a/uffd/models/session.py b/uffd/models/session.py
index b29ba3e245153a530961e1e3433ed5938318f990..368dbaaefece27c27bbe8d8c34dcb1aa3f2019de 100644
--- a/uffd/models/session.py
+++ b/uffd/models/session.py
@@ -29,6 +29,8 @@ class Session(db.Model):
 
 	user_id = Column(Integer(), ForeignKey('user.id', onupdate='CASCADE', ondelete='CASCADE'), nullable=False)
 	user = relationship('User', back_populates='sessions')
+	oauth2_grants = relationship('OAuth2Grant', back_populates='session', cascade='all, delete-orphan')
+	oauth2_tokens = relationship('OAuth2Token', back_populates='session', cascade='all, delete-orphan')
 
 	created = Column(DateTime, default=datetime.datetime.utcnow, nullable=False)
 	last_used = Column(DateTime, default=datetime.datetime.utcnow, nullable=False)
diff --git a/uffd/views/oauth2.py b/uffd/views/oauth2.py
index f9a0b7b802cb07ad34af4b56635e49a85ad36392..8da8a68d593463ef36f0f351a2ea92e02f2ee0ae 100644
--- a/uffd/views/oauth2.py
+++ b/uffd/views/oauth2.py
@@ -254,8 +254,8 @@ def authorize_validate_request_oidc(grant):
 	return grant, sub_value, prompt_values
 
 def authorize_user(client):
-	if request.user:
-		return request.user
+	if request.session:
+		return request.session
 
 	if 'devicelogin_started' in session:
 		del session['devicelogin_started']
@@ -298,7 +298,7 @@ def authorize_user(client):
 			)
 		db.session.delete(initiation)
 		db.session.commit()
-		return confirmation.session.user
+		return confirmation.session
 
 	raise LoginRequiredError(
 		flash_message=_('You need to login to access this service'),
@@ -325,18 +325,18 @@ def authorize():
 		return render_template('oauth2/error.html', **err.params), 400
 
 	try:
-		user = authorize_user(grant.client)
-		if sub_value is not None and str(user.unix_uid) != sub_value:
+		_session = authorize_user(grant.client)
+		if sub_value is not None and str(_session.user.unix_uid) != sub_value:
 			# We only reach this point in OIDC requests with prompt=none, see
 			# authorize_validate_request_oidc. So this LoginRequiredError is
 			# always returned as a redirect back to the client.
 			raise LoginRequiredError()
-		if not grant.client.access_allowed(user):
+		if not grant.client.access_allowed(_session.user):
 			raise AccessDeniedError(flash_message=_(
 				"You don't have the permission to access the service <b>%(service_name)s</b>.",
 				service_name=grant.client.service.name
 			))
-		grant.user = user
+		grant.session = _session
 	except LoginRequiredError as err:
 		# We abuse LoginRequiredError to signal a redirect to the login page
 		if is_oidc and 'none' in prompt_values:
@@ -350,9 +350,6 @@ def authorize():
 			return oauth2_redirect(**err.params)
 		abort(403, description=err.flash_message)
 
-	session['oauth2-clients'] = session.get('oauth2-clients', [])
-	if grant.client.client_id not in session['oauth2-clients']:
-		session['oauth2-clients'].append(grant.client.client_id)
 	db.session.add(grant)
 	db.session.commit()
 	return oauth2_redirect(code=grant.code)
@@ -488,9 +485,11 @@ def userinfo():
 
 @bp.app_url_defaults
 def inject_logout_params(endpoint, values):
-	if endpoint != 'oauth2.logout' or not session.get('oauth2-clients'):
+	if endpoint != 'oauth2.logout' or not request.session:
 		return
-	values['client_ids'] = ','.join(session['oauth2-clients'])
+	client_ids = set(token.client.client_id for token in request.session.oauth2_tokens)
+	if client_ids:
+		values['client_ids'] = ','.join(client_ids)
 
 @bp.route('/oauth2/logout')
 def logout():