import datetime
import json
import secrets
import base64

from sqlalchemy import Column, Integer, String, DateTime, Text, ForeignKey, Boolean
from sqlalchemy.orm import relationship
from sqlalchemy.ext.hybrid import hybrid_property
from sqlalchemy.ext.associationproxy import association_proxy
import jwt

from uffd.database import db, CommaSeparatedList
from uffd.tasks import cleanup_task
from uffd.password_hash import PasswordHashAttribute, HighEntropyPasswordHash
from uffd.utils import token_urlfriendly
from .session import DeviceLoginInitiation, DeviceLoginType
from .service import ServiceUser

# pyjwt v1.7.x compat (Buster/Bullseye)
if not hasattr(jwt, 'get_algorithm_by_name'):
    jwt.get_algorithm_by_name = lambda name: jwt.algorithms.get_default_algorithms()[name]


def _secrets_equal(expected, provided):
    """Compare two text secrets in constant time.

    ``secrets.compare_digest`` raises ``TypeError`` for ``str`` arguments that
    contain non-ASCII characters. ``provided`` comes from the client, so both
    values are encoded to bytes first; a non-ASCII value then simply fails the
    comparison instead of raising.
    """
    return secrets.compare_digest(expected.encode('utf-8'), provided.encode('utf-8'))


class OAuth2Client(db.Model):
    """OAuth2 client that belongs to a service."""
    __tablename__ = 'oauth2client'
    # Inconsistently named "db_id" instead of "id" because of the naming conflict
    # with "client_id" in the OAuth2 standard
    db_id = Column(Integer, primary_key=True, autoincrement=True)

    service_id = Column(Integer, ForeignKey('service.id', onupdate='CASCADE', ondelete='CASCADE'), nullable=False)
    service = relationship('Service', back_populates='oauth2_clients')

    client_id = Column(String(40), unique=True, nullable=False)
    # The secret is accessed through the PasswordHashAttribute wrapper; the
    # underlying column is named "client_secret".
    _client_secret = Column('client_secret', Text(), nullable=False)
    client_secret = PasswordHashAttribute('_client_secret', HighEntropyPasswordHash)
    _redirect_uris = relationship('OAuth2RedirectURI', cascade='all, delete-orphan')
    # Proxy exposing the "uri" attribute of the related OAuth2RedirectURI objects
    redirect_uris = association_proxy('_redirect_uris', 'uri')
    logout_uris = relationship('OAuth2LogoutURI', cascade='all, delete-orphan')

    @property
    def default_redirect_uri(self):
        """Return the redirect URI if exactly one is registered, else None."""
        return self.redirect_uris[0] if len(self.redirect_uris) == 1 else None

    def access_allowed(self, user):
        """Return a truthy value if `user` has access to this client's service.

        Looks up the ServiceUser for (service_id, user.id). The result is falsy
        (possibly None) if no ServiceUser exists or it has no access.
        """
        service_user = ServiceUser.query.get((self.service_id, user.id))
        return service_user and service_user.has_access

    @property
    def logout_uris_json(self):
        """Return the logout URIs as a JSON list of [method, uri] pairs."""
        return json.dumps([[item.method, item.uri] for item in self.logout_uris])


class OAuth2RedirectURI(db.Model):
    """Redirect URI registered for an OAuth2 client."""
    __tablename__ = 'oauth2redirect_uri'
    id = Column(Integer, primary_key=True, autoincrement=True)
    client_db_id = Column(Integer, ForeignKey('oauth2client.db_id', onupdate='CASCADE', ondelete='CASCADE'), nullable=False)
    uri = Column(String(255), nullable=False)

    def __init__(self, uri):
        # Single positional argument so that the association proxy
        # OAuth2Client.redirect_uris can create instances from plain strings.
        self.uri = uri


class OAuth2LogoutURI(db.Model):
    """Logout URI (with request method) registered for an OAuth2 client."""
    __tablename__ = 'oauth2logout_uri'
    id = Column(Integer, primary_key=True, autoincrement=True)
    client_db_id = Column(Integer, ForeignKey('oauth2client.db_id', onupdate='CASCADE', ondelete='CASCADE'), nullable=False)
    method = Column(String(40), nullable=False, default='GET')
    uri = Column(String(255), nullable=False)


@cleanup_task.delete_by_attribute('expired')
class OAuth2Grant(db.Model):
    """Authorization grant identified by an authorization code."""
    __tablename__ = 'oauth2grant'
    id = Column(Integer, primary_key=True, autoincrement=True)

    # Lifetime in seconds, applied by the default of the "expires" column
    EXPIRES_IN = 100
    expires = Column(DateTime, nullable=False,
                     default=lambda: datetime.datetime.utcnow() + datetime.timedelta(seconds=OAuth2Grant.EXPIRES_IN))

    user_id = Column(Integer(), ForeignKey('user.id', onupdate='CASCADE', ondelete='CASCADE'), nullable=False)
    user = relationship('User')

    client_db_id = Column(Integer, ForeignKey('oauth2client.db_id', onupdate='CASCADE', ondelete='CASCADE'), nullable=False)
    client = relationship('OAuth2Client')

    _code = Column('code', String(255), nullable=False, default=token_urlfriendly)
    # Externally visible authorization code: "<id>-<secret>"
    code = property(lambda self: f'{self.id}-{self._code}')
    redirect_uri = Column(String(255), nullable=True)
    nonce = Column(Text(), nullable=True)
    scopes = Column('_scopes', CommaSeparatedList(), nullable=False, default=tuple())
    # JSON-encoded claims; use the "claims" property for decoded access
    _claims = Column('claims', Text(), nullable=True)

    @property
    def claims(self):
        """Return the claims decoded from JSON, or None if unset."""
        return json.loads(self._claims) if self._claims is not None else None

    @claims.setter
    def claims(self, value):
        self._claims = json.dumps(value) if value is not None else None

    @property
    def service_user(self):
        """Return the ServiceUser for this grant's client service and user.

        Raises:
            Exception: if no matching ServiceUser exists.
        """
        service_user = ServiceUser.query.get((self.client.service_id, self.user.id))
        if service_user is None:
            raise Exception('ServiceUser lookup failed')
        return service_user

    @hybrid_property
    def expired(self):
        """Whether `expires` lies in the past (False while `expires` is unset)."""
        if self.expires is None:
            return False
        return self.expires < datetime.datetime.utcnow()

    @classmethod
    def get_by_authorization_code(cls, code):
        """Return the valid grant for authorization code `code`, or None.

        `code` has the form "<id>-<secret>". None is returned if the code is
        malformed, no unexpired grant with that id exists, the secret does not
        match, the user is deactivated or the user has no access to the client.
        """
        # pylint: disable=protected-access
        # Split on the first dash only: extra dashes in client-supplied input
        # must lead to a failed comparison, not to an unpacking error.
        grant_id, separator, grant_code = code.partition('-')
        if not separator:
            return None
        grant = cls.query.filter_by(id=grant_id, expired=False).first()
        if not grant or not _secrets_equal(grant._code, grant_code):
            return None
        if grant.user.is_deactivated or not grant.client.access_allowed(grant.user):
            return None
        return grant

    def make_token(self, **kwargs):
        """Create an OAuth2Token with this grant's user, client, scopes and claims.

        Additional keyword arguments are passed on to the OAuth2Token constructor.
        """
        return OAuth2Token(
            user=self.user,
            client=self.client,
            scopes=self.scopes,
            claims=self.claims,
            **kwargs
        )


@cleanup_task.delete_by_attribute('expired')
class OAuth2Token(db.Model):
    """Access/refresh token pair issued to a client for a user."""
    __tablename__ = 'oauth2token'
    id = Column(Integer, primary_key=True, autoincrement=True)

    # Lifetime in seconds, applied by the default of the "expires" column
    EXPIRES_IN = 3600
    expires = Column(DateTime, nullable=False,
                     default=lambda: datetime.datetime.utcnow() + datetime.timedelta(seconds=OAuth2Token.EXPIRES_IN))

    user_id = Column(Integer(), ForeignKey('user.id', onupdate='CASCADE', ondelete='CASCADE'), nullable=False)
    user = relationship('User')

    client_db_id = Column(Integer, ForeignKey('oauth2client.db_id', onupdate='CASCADE', ondelete='CASCADE'), nullable=False)
    client = relationship('OAuth2Client')

    # currently only bearer is supported
    token_type = Column(String(40), nullable=False, default='bearer')
    _access_token = Column('access_token', String(255), unique=True, nullable=False, default=token_urlfriendly)
    # Externally visible tokens: "<id>-<secret>"
    access_token = property(lambda self: f'{self.id}-{self._access_token}')
    _refresh_token = Column('refresh_token', String(255), unique=True, nullable=False, default=token_urlfriendly)
    refresh_token = property(lambda self: f'{self.id}-{self._refresh_token}')
    scopes = Column('_scopes', CommaSeparatedList(), nullable=False, default=tuple())
    # JSON-encoded claims; use the "claims" property for decoded access
    _claims = Column('claims', Text(), nullable=True)

    @property
    def claims(self):
        """Return the claims decoded from JSON, or None if unset."""
        return json.loads(self._claims) if self._claims is not None else None

    @claims.setter
    def claims(self, value):
        self._claims = json.dumps(value) if value is not None else None

    @property
    def service_user(self):
        """Return the ServiceUser for this token's client service and user.

        Raises:
            Exception: if no matching ServiceUser exists.
        """
        service_user = ServiceUser.query.get((self.client.service_id, self.user.id))
        if service_user is None:
            raise Exception('ServiceUser lookup failed')
        return service_user

    @hybrid_property
    def expired(self):
        """Whether `expires` lies in the past (False while `expires` is unset)."""
        # Same guard as OAuth2Grant.expired: "expires" is None on an instance
        # until a value is assigned or the column default has been applied.
        if self.expires is None:
            return False
        return self.expires < datetime.datetime.utcnow()

    @classmethod
    def get_by_access_token(cls, access_token):
        """Return the valid token for `access_token`, or None.

        `access_token` has the form "<id>-<secret>". None is returned if the
        value is malformed, no unexpired token with that id exists, the secret
        does not match, the user is deactivated or the user has no access to
        the client.
        """
        # pylint: disable=protected-access
        # Split on the first dash only: extra dashes in client-supplied input
        # must lead to a failed comparison, not to an unpacking error.
        token_id, separator, token_secret = access_token.partition('-')
        if not separator:
            return None
        token = cls.query.filter_by(id=token_id, expired=False).first()
        if not token or not _secrets_equal(token._access_token, token_secret):
            return None
        if token.user.is_deactivated or not token.client.access_allowed(token.user):
            return None
        return token


class OAuth2DeviceLoginInitiation(DeviceLoginInitiation):
    """Device login initiation bound to an OAuth2 client."""
    __mapper_args__ = {
        'polymorphic_identity': DeviceLoginType.OAUTH2
    }
    client_db_id = Column('oauth2_client_db_id', Integer,
                          ForeignKey('oauth2client.db_id', onupdate='CASCADE', ondelete='CASCADE'))
    client = relationship('OAuth2Client')

    @property
    def description(self):
        """Return the name of the service the client belongs to."""
        return self.client.service.name


class OAuth2Key(db.Model):
    """Signing key pair stored as JWKs, used to encode and decode JWTs."""
    __tablename__ = 'oauth2_key'
    id = Column(String(64), primary_key=True, default=token_urlfriendly)
    created = Column(DateTime, default=datetime.datetime.utcnow, nullable=False)
    active = Column(Boolean(create_constraint=False), default=True, nullable=False)
    algorithm = Column(String(32), nullable=False)
    private_key_jwk = Column(Text(), nullable=False)
    public_key_jwk = Column(Text(), nullable=False)

    def __init__(self, **kwargs):
        """Create a key.

        If `algorithm` and `private_key` are passed without `private_key_jwk`
        and `public_key_jwk`, both JWKs are derived from `private_key`.
        """
        if kwargs.get('algorithm') and kwargs.get('private_key') \
                and not kwargs.get('private_key_jwk') \
                and not kwargs.get('public_key_jwk'):
            algorithm = jwt.get_algorithm_by_name(kwargs['algorithm'])
            private_key = kwargs.pop('private_key')
            kwargs['private_key_jwk'] = algorithm.to_jwk(private_key)
            kwargs['public_key_jwk'] = algorithm.to_jwk(private_key.public_key())
        super().__init__(**kwargs)

    @property
    def private_key(self):
        """Return the private key object loaded from `private_key_jwk`."""
        # pylint: disable=protected-access,import-outside-toplevel
        # cryptography performs expensive checks when loading RSA private keys.
        # Since we only load keys we generated ourselves with help of cryptography,
        # these checks are unnecessary.
        import cryptography.hazmat.backends.openssl
        backend = cryptography.hazmat.backends.openssl.backend
        backend._rsa_skip_check_key = True
        try:
            return jwt.get_algorithm_by_name(self.algorithm).from_jwk(self.private_key_jwk)
        finally:
            # Always reset the flag on the shared backend object, even if
            # loading the key raised.
            backend._rsa_skip_check_key = False

    @property
    def public_key(self):
        """Return the public key object loaded from `public_key_jwk`."""
        return jwt.get_algorithm_by_name(self.algorithm).from_jwk(self.public_key_jwk)

    @property
    def public_key_jwks_dict(self):
        """Return the public JWK as a dict with "kid", "alg" and "use" set."""
        res = json.loads(self.public_key_jwk)
        res['kid'] = self.id
        res['alg'] = self.algorithm
        res['use'] = 'sig'
        # RFC7517 4.3 "The "use" and "key_ops" JWK members SHOULD NOT be used together [...]"
        res.pop('key_ops', None)
        return res

    def encode_jwt(self, payload):
        """Return `payload` encoded as a JWT signed with this key.

        The key id is placed in the "kid" header.

        Raises:
            jwt.exceptions.InvalidKeyError: if the key is not active.
        """
        if not self.active:
            raise jwt.exceptions.InvalidKeyError(f'Key {self.id} not active')
        return jwt.encode(payload, key=self.private_key, algorithm=self.algorithm, headers={'kid': self.id})

    # Hash algorithm for at_hash/c_hash from OpenID Connect Core 1.0 section 3.1.3.6
    def oidc_hash(self, value):
        """Return the unpadded urlsafe base64 of the first half of the digest of `value`.

        The digest uses the hash algorithm of this key's JWT algorithm.
        `value` is fed to the hash as-is.
        """
        # pylint: disable=import-outside-toplevel
        from cryptography.hazmat.primitives import hashes
        from cryptography.hazmat.backends import default_backend  # Only required for Buster
        hash_alg = jwt.get_algorithm_by_name(self.algorithm).hash_alg
        digest = hashes.Hash(hash_alg(), backend=default_backend())
        digest.update(value)
        return base64.urlsafe_b64encode(
            digest.finalize()[:hash_alg.digest_size // 2]
        ).decode('ascii').rstrip('=')

    @classmethod
    def get_preferred_key(cls, algorithm='RS256'):
        """Return the most recently created active key for `algorithm`, or None."""
        return cls.query.filter_by(active=True, algorithm=algorithm).order_by(OAuth2Key.created.desc()).first()

    @classmethod
    def get_available_algorithms(cls):
        """Return the list of supported JWT algorithm names."""
        return ['RS256']

    @classmethod
    def decode_jwt(cls, data, algorithms=('RS256',), **kwargs):
        """Decode and verify JWT `data` with the key named in its "kid" header.

        Additional keyword arguments are passed on to `jwt.decode`.

        Raises:
            jwt.exceptions.InvalidKeyError: if the JWT has no "kid" header or
                the referenced key does not exist or is not active.
        """
        headers = jwt.get_unverified_header(data)
        if 'kid' not in headers:
            raise jwt.exceptions.InvalidKeyError('JWT without kid')
        kid = headers['kid']
        key = cls.query.get(kid)
        if not key:
            raise jwt.exceptions.InvalidKeyError(f'Key {kid} not found')
        if not key.active:
            raise jwt.exceptions.InvalidKeyError(f'Key {kid} not active')
        return jwt.decode(data, key=key.public_key, algorithms=algorithms, **kwargs)

    @classmethod
    def generate_rsa_key(cls, public_exponent=65537, key_size=3072):
        """Return a new (unsaved) RS256 key with a freshly generated RSA key pair."""
        # pylint: disable=import-outside-toplevel
        from cryptography.hazmat.primitives.asymmetric import rsa
        from cryptography.hazmat.backends import default_backend  # Only required for Buster
        return cls(algorithm='RS256',
                   private_key=rsa.generate_private_key(public_exponent=public_exponent, key_size=key_size,
                                                        backend=default_backend()))