Source code for pydotorg.domains.users.repositories

"""User domain repositories for database access."""

from __future__ import annotations

from typing import TYPE_CHECKING

from advanced_alchemy.repository import SQLAlchemyAsyncRepository
from sqlalchemy import select

from pydotorg.domains.users.api_keys import APIKey
from pydotorg.domains.users.models import Membership, User, UserGroup

if TYPE_CHECKING:
    from uuid import UUID


class UserRepository(SQLAlchemyAsyncRepository[User]):
    """Repository for User database operations."""

    model_type = User

    async def get_by_email(self, email: str) -> User | None:
        """Get a user by email address.

        Args:
            email: The email address to search for.

        Returns:
            The user if found, None otherwise.

        Raises:
            sqlalchemy.exc.MultipleResultsFound: If more than one row matches.
        """
        statement = select(User).where(User.email == email)
        result = await self.session.execute(statement)
        return result.scalar_one_or_none()

    async def get_by_username(self, username: str) -> User | None:
        """Get a user by username.

        Args:
            username: The username to search for.

        Returns:
            The user if found, None otherwise.

        Raises:
            sqlalchemy.exc.MultipleResultsFound: If more than one row matches.
        """
        statement = select(User).where(User.username == username)
        result = await self.session.execute(statement)
        return result.scalar_one_or_none()

    async def exists_by_email(self, email: str) -> bool:
        """Check if a user exists by email.

        Only the matched column of at most one row is fetched, so no ``User``
        entity is built just to answer a yes/no question. Unlike
        ``get_by_email``, this does not raise when several rows match.

        Args:
            email: The email address to check.

        Returns:
            True if a user with this email exists, False otherwise.
        """
        # LIMIT 1: a single hit is enough to prove existence.
        statement = select(User.email).where(User.email == email).limit(1)
        result = await self.session.execute(statement)
        return result.first() is not None

    async def exists_by_username(self, username: str) -> bool:
        """Check if a user exists by username.

        Only the matched column of at most one row is fetched, so no ``User``
        entity is built just to answer a yes/no question. Unlike
        ``get_by_username``, this does not raise when several rows match.

        Args:
            username: The username to check.

        Returns:
            True if a user with this username exists, False otherwise.
        """
        # LIMIT 1: a single hit is enough to prove existence.
        statement = select(User.username).where(User.username == username).limit(1)
        result = await self.session.execute(statement)
        return result.first() is not None
class MembershipRepository(SQLAlchemyAsyncRepository[Membership]):
    """Data-access layer for ``Membership`` rows."""

    model_type = Membership

    async def get_by_user_id(self, user_id: UUID) -> Membership | None:
        """Look up the membership attached to a given user.

        Args:
            user_id: Identifier of the user whose membership is wanted.

        Returns:
            The matching membership, or None when no row matches.

        Raises:
            sqlalchemy.exc.MultipleResultsFound: If more than one row matches.
        """
        owned_by_user = Membership.user_id == user_id
        query = select(Membership).where(owned_by_user)
        rows = await self.session.execute(query)
        return rows.scalars().one_or_none()
class UserGroupRepository(SQLAlchemyAsyncRepository[UserGroup]):
    """Repository for UserGroup database operations."""

    model_type = UserGroup

    async def list_approved(self, limit: int = 100, offset: int = 0) -> list[UserGroup]:
        """List approved user groups.

        Rows are ordered by the repository's identity column so that
        successive ``limit``/``offset`` pages are stable: without an explicit
        ordering the database may return rows in any order, which lets pages
        overlap or skip rows.

        Args:
            limit: Maximum number of groups to return.
            offset: Number of groups to skip.

        Returns:
            List of approved user groups.
        """
        statement = (
            select(UserGroup)
            .where(UserGroup.approved.is_(True))
            # Deterministic order is required for LIMIT/OFFSET paging.
            .order_by(getattr(UserGroup, self.id_attribute))
            .limit(limit)
            .offset(offset)
        )
        result = await self.session.execute(statement)
        return list(result.scalars().all())

    async def list_trusted(self, limit: int = 100, offset: int = 0) -> list[UserGroup]:
        """List trusted user groups.

        Rows are ordered by the repository's identity column so that
        successive ``limit``/``offset`` pages are stable: without an explicit
        ordering the database may return rows in any order, which lets pages
        overlap or skip rows.

        Args:
            limit: Maximum number of groups to return.
            offset: Number of groups to skip.

        Returns:
            List of trusted user groups.
        """
        statement = (
            select(UserGroup)
            .where(UserGroup.trusted.is_(True))
            # Deterministic order is required for LIMIT/OFFSET paging.
            .order_by(getattr(UserGroup, self.id_attribute))
            .limit(limit)
            .offset(offset)
        )
        result = await self.session.execute(statement)
        return list(result.scalars().all())
class APIKeyRepository(SQLAlchemyAsyncRepository[APIKey]):
    """Data-access layer for ``APIKey`` rows."""

    model_type = APIKey

    async def get_by_hash(self, key_hash: str) -> APIKey | None:
        """Find the API key whose stored hash equals ``key_hash``.

        Args:
            key_hash: The SHA-256 hash of the API key.

        Returns:
            The matching API key, or None when no row matches.

        Raises:
            sqlalchemy.exc.MultipleResultsFound: If more than one row matches.
        """
        hash_matches = APIKey.key_hash == key_hash
        rows = await self.session.execute(select(APIKey).where(hash_matches))
        return rows.scalars().one_or_none()

    async def get_by_prefix(self, key_prefix: str) -> APIKey | None:
        """Find the API key whose stored prefix equals ``key_prefix``.

        Args:
            key_prefix: The first 12 characters of the key.

        Returns:
            The matching API key, or None when no row matches.

        Raises:
            sqlalchemy.exc.MultipleResultsFound: If more than one row matches.
        """
        # NOTE(review): raises if two keys share a prefix — confirm the
        # column is unique in the model.
        prefix_matches = APIKey.key_prefix == key_prefix
        rows = await self.session.execute(select(APIKey).where(prefix_matches))
        return rows.scalars().one_or_none()

    async def list_by_user(self, user_id: UUID) -> list[APIKey]:
        """Return every API key owned by a user, newest first.

        Args:
            user_id: The user ID.

        Returns:
            List of API keys ordered by ``created_at`` descending.
        """
        newest_first = APIKey.created_at.desc()
        query = select(APIKey).where(APIKey.user_id == user_id).order_by(newest_first)
        rows = await self.session.execute(query)
        return [*rows.scalars().all()]

    async def list_active_by_user(self, user_id: UUID) -> list[APIKey]:
        """Return a user's API keys that are flagged active, newest first.

        Args:
            user_id: The user ID.

        Returns:
            List of active API keys ordered by ``created_at`` descending.
        """
        owned_by_user = APIKey.user_id == user_id
        still_active = APIKey.is_active.is_(True)
        query = select(APIKey).where(owned_by_user, still_active).order_by(APIKey.created_at.desc())
        rows = await self.session.execute(query)
        return [*rows.scalars().all()]

    async def revoke_all_for_user(self, user_id: UUID) -> int:
        """Mark every active API key of a user as inactive.

        The flag is changed on the loaded objects only; this method does not
        flush or commit the session itself.

        Args:
            user_id: The user ID.

        Returns:
            Number of keys revoked.
        """
        active_keys = await self.list_active_by_user(user_id)
        revoked = 0
        for api_key in active_keys:
            api_key.is_active = False
            revoked += 1
        return revoked