diff --git a/uffd/database.py b/uffd/database.py index 68800c741860a3a3fb03eabba3da873506aaecb1..83d5446d958898b1429f0119e66cfba91642b44a 100644 --- a/uffd/database.py +++ b/uffd/database.py @@ -1,6 +1,8 @@ from collections import OrderedDict from sqlalchemy import MetaData, event +from sqlalchemy.types import TypeDecorator, Text +from sqlalchemy.ext.mutable import MutableList from flask_sqlalchemy import SQLAlchemy from flask.json import JSONEncoder @@ -40,3 +42,24 @@ class SQLAlchemyJSON(JSONEncoder): result[key] = getattr(o, key) return result return JSONEncoder.default(self, o) + +class CommaSeparatedList(TypeDecorator): + # For some reason TypeDecorator.process_literal_param and + # TypeEngine.python_type are abstract but not actually required + # pylint: disable=abstract-method + + impl = Text + cache_ok = True + + def process_bind_param(self, value, dialect): + if value is None: + return None + for item in value: + if ',' in item: + raise ValueError('Items of comma-separated list must not contain commas') + return ','.join(value) + + def process_result_value(self, value, dialect): + if value is None: + return None + return MutableList(value.split(',')) diff --git a/uffd/oauth2/models.py b/uffd/oauth2/models.py index 374707514f18c1b5780bab44d53bb0e5420d9992..bae0f7ceec35aebf2067edcd13bed5ddd4d5ddd7 100644 --- a/uffd/oauth2/models.py +++ b/uffd/oauth2/models.py @@ -2,11 +2,11 @@ import datetime from flask import current_app from flask_babel import get_locale, gettext as _ -from sqlalchemy import Column, Integer, String, DateTime, Text, ForeignKey +from sqlalchemy import Column, Integer, String, DateTime, ForeignKey from sqlalchemy.orm import relationship from sqlalchemy.ext.hybrid import hybrid_property -from uffd.database import db +from uffd.database import db, CommaSeparatedList from uffd.tasks import cleanup_task from uffd.session.models import DeviceLoginInitiation, DeviceLoginType @@ -64,18 +64,7 @@ class OAuth2Grant(db.Model): code = Column(String(255), index=True, nullable=False) redirect_uri = Column(String(255), nullable=False) expires = Column(DateTime, nullable=False, default=lambda: datetime.datetime.utcnow() + datetime.timedelta(seconds=100)) - - _scopes = Column(Text, nullable=False, default='') - @property - def scopes(self): - if self._scopes: - return self._scopes.split() - return [] - - def delete(self): - db.session.delete(self) - db.session.commit() - return self + scopes = Column('_scopes', CommaSeparatedList(), nullable=False, default=tuple()) @hybrid_property def expired(self): @@ -106,18 +95,7 @@ class OAuth2Token(db.Model): access_token = Column(String(255), unique=True, nullable=False) refresh_token = Column(String(255), unique=True, nullable=False) expires = Column(DateTime, nullable=False) - - _scopes = Column(Text, nullable=False, default='') - @property - def scopes(self): - if self._scopes: - return self._scopes.split() - return [] - - def delete(self): - db.session.delete(self) - db.session.commit() - return self + scopes = Column('_scopes', CommaSeparatedList(), nullable=False, default=tuple()) @hybrid_property def expired(self): diff --git a/uffd/oauth2/views.py b/uffd/oauth2/views.py index db5b69783f6dbc6fd4351d575ff0dfff7b5f2197..b52e3600f600f463595f4bf4738aa8bed1ac2228 100644 --- a/uffd/oauth2/views.py +++ b/uffd/oauth2/views.py @@ -66,7 +66,7 @@ class UffdRequestValidator(oauthlib.oauth2.RequestValidator): def save_authorization_code(self, client_id, code, oauthreq, *args, **kwargs): grant = OAuth2Grant(user=oauthreq.user, client_id=client_id, code=code['code'], - redirect_uri=oauthreq.redirect_uri, _scopes=' '.join(oauthreq.scopes)) + redirect_uri=oauthreq.redirect_uri, scopes=oauthreq.scopes) db.session.add(grant) db.session.commit() # Oauthlib does not really provide a way to customize grant code generation. @@ -95,7 +95,6 @@ class UffdRequestValidator(oauthlib.oauth2.RequestValidator): db.session.commit() def save_bearer_token(self, token_data, oauthreq, *args, **kwargs): - OAuth2Token.query.filter_by(client_id=oauthreq.client.client_id, user=oauthreq.user).delete() tok = OAuth2Token( user=oauthreq.user, client_id=oauthreq.client.client_id, @@ -103,7 +102,7 @@ class UffdRequestValidator(oauthlib.oauth2.RequestValidator): access_token=token_data['access_token'], refresh_token=token_data['refresh_token'], expires_in_seconds=token_data['expires_in'], - _scopes=' '.join(oauthreq.scopes) + scopes=oauthreq.scopes ) db.session.add(tok) db.session.commit()