From 524524394aa091e7a39bd58d0996d4147cbf0d9e Mon Sep 17 00:00:00 2001 From: Pdzly <34943569+Pdzly@users.noreply.github.com> Date: Sat, 21 Mar 2026 20:56:43 +0100 Subject: [PATCH 1/4] feat: add authentication backend with JWT, email verification, and password reset Implement full auth system including user registration, login with Argon2id password hashing, JWT access/refresh tokens, email verification flow, password reset flow, and per-endpoint rate limiting for auth routes. --- backend/alembic.ini | 2 +- .../versions/20250125_add_auth_tables.py | 192 ++++++++++++ backend/app/api/routes.py | 7 +- backend/app/api/subroutes/auth.py | 286 ++++++++++++++++++ backend/app/containers.py | 33 +- backend/app/db/models.py | 150 ++++++++- backend/app/exceptions.py | 91 ++++++ backend/app/services/email_service.py | 192 ++++++++++++ backend/pyproject.toml | 3 + backend/uv.lock | 46 +++ docker-compose.yml | 2 +- 11 files changed, 997 insertions(+), 7 deletions(-) create mode 100644 backend/alembic/versions/20250125_add_auth_tables.py create mode 100644 backend/app/api/subroutes/auth.py create mode 100644 backend/app/services/email_service.py diff --git a/backend/alembic.ini b/backend/alembic.ini index 5d74f27..fddd89f 100644 --- a/backend/alembic.ini +++ b/backend/alembic.ini @@ -84,7 +84,7 @@ path_separator = os # database URL. This is consumed by the user-maintained env.py script only. # other means of configuring database URLs may be customized within the env.py # file. -sqlalchemy.url = postgresql://postgres:postgres@localhost/devbin +sqlalchemy.url = postgresql://postgres:postgres@localhost:5433/devbin [post_write_hooks] diff --git a/backend/alembic/versions/20250125_add_auth_tables.py b/backend/alembic/versions/20250125_add_auth_tables.py new file mode 100644 index 0000000..b39bb04 --- /dev/null +++ b/backend/alembic/versions/20250125_add_auth_tables.py @@ -0,0 +1,192 @@ +"""Add authentication tables + +Revision ID: add_auth_tables +Revises: 20251219_203917_hash_existing_tokens +Create Date: 2025-01-25 + +""" + +from collections.abc import Sequence + +import sqlalchemy as sa + +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "add_auth_tables" +down_revision: str | None = "4e57d32ab2ac" +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None + + +def upgrade() -> None: + # Create users table + op.create_table( + "users", + sa.Column( + "id", + sa.UUID(as_uuid=True), + primary_key=True, + server_default=sa.text("gen_random_uuid()"), + ), + sa.Column("username", sa.String(50), unique=True, nullable=False), + sa.Column("email", sa.String(255), unique=True, nullable=False), + sa.Column("password_hash", sa.String(255), nullable=False), + sa.Column( + "is_verified", sa.Boolean(), nullable=False, server_default=sa.text("false") + ), + sa.Column( + "is_active", sa.Boolean(), nullable=False, server_default=sa.text("true") + ), + sa.Column( + "created_at", + sa.TIMESTAMP(timezone=True), + nullable=False, + server_default=sa.func.now(), + ), + sa.Column("updated_at", sa.TIMESTAMP(timezone=True), nullable=True), + sa.Column("last_login_at", sa.TIMESTAMP(timezone=True), nullable=True), + ) + + # Create indexes for users table + op.create_index("idx_users_email", "users", ["email"]) + op.create_index("idx_users_username", "users", ["username"]) + op.create_index("idx_users_created_at", "users", ["created_at"]) + + # Create email_verification_tokens table + op.create_table( + "email_verification_tokens", + sa.Column( + "id", + sa.UUID(as_uuid=True), + primary_key=True, + server_default=sa.text("gen_random_uuid()"), + ), + sa.Column( + "user_id", + sa.UUID(as_uuid=True), + sa.ForeignKey("users.id", ondelete="CASCADE"), + nullable=False, + ), + sa.Column("token_hash", sa.String(255), nullable=False), + sa.Column("expires_at", sa.TIMESTAMP(timezone=True), nullable=False), + sa.Column("used_at", sa.TIMESTAMP(timezone=True), nullable=True), + sa.Column( + "created_at", + sa.TIMESTAMP(timezone=True), + nullable=False, + server_default=sa.func.now(), + ), + ) + + # Create indexes for email_verification_tokens table + op.create_index( + "idx_email_verification_tokens_user_id", + "email_verification_tokens", + ["user_id"], + ) + op.create_index( + "idx_email_verification_tokens_expires_at", + "email_verification_tokens", + ["expires_at"], + ) + + # Create password_reset_tokens table + op.create_table( + "password_reset_tokens", + sa.Column( + "id", + sa.UUID(as_uuid=True), + primary_key=True, + server_default=sa.text("gen_random_uuid()"), + ), + sa.Column( + "user_id", + sa.UUID(as_uuid=True), + sa.ForeignKey("users.id", ondelete="CASCADE"), + nullable=False, + ), + sa.Column("token_hash", sa.String(255), nullable=False), + sa.Column("expires_at", sa.TIMESTAMP(timezone=True), nullable=False), + sa.Column("used_at", sa.TIMESTAMP(timezone=True), nullable=True), + sa.Column( + "created_at", + sa.TIMESTAMP(timezone=True), + nullable=False, + server_default=sa.func.now(), + ), + ) + + # Create indexes for password_reset_tokens table + op.create_index( + "idx_password_reset_tokens_user_id", "password_reset_tokens", ["user_id"] + ) + op.create_index( + "idx_password_reset_tokens_expires_at", "password_reset_tokens", ["expires_at"] + ) + + # Create refresh_tokens table + op.create_table( + "refresh_tokens", + sa.Column( + "id", + sa.UUID(as_uuid=True), + primary_key=True, + server_default=sa.text("gen_random_uuid()"), + ), + sa.Column( + "user_id", + sa.UUID(as_uuid=True), + sa.ForeignKey("users.id", ondelete="CASCADE"), + nullable=False, + ), + sa.Column("token_hash", sa.String(255), nullable=False), + sa.Column("expires_at", sa.TIMESTAMP(timezone=True), nullable=False), + sa.Column("revoked_at", sa.TIMESTAMP(timezone=True), nullable=True), + sa.Column( + "created_at", + sa.TIMESTAMP(timezone=True), + nullable=False, + server_default=sa.func.now(), + ), + sa.Column("user_agent", sa.String(512), nullable=True), + sa.Column("ip_address", sa.String(45), nullable=True), # IPv6 max length + ) + + # Create indexes for refresh_tokens table + op.create_index("idx_refresh_tokens_user_id", "refresh_tokens", ["user_id"]) + op.create_index("idx_refresh_tokens_expires_at", "refresh_tokens", ["expires_at"]) + op.create_index("idx_refresh_tokens_revoked_at", "refresh_tokens", ["revoked_at"]) + + +def downgrade() -> None: + # Drop refresh_tokens table and indexes + op.drop_index("idx_refresh_tokens_revoked_at", table_name="refresh_tokens") + op.drop_index("idx_refresh_tokens_expires_at", table_name="refresh_tokens") + op.drop_index("idx_refresh_tokens_user_id", table_name="refresh_tokens") + op.drop_table("refresh_tokens") + + # Drop password_reset_tokens table and indexes + op.drop_index( + "idx_password_reset_tokens_expires_at", table_name="password_reset_tokens" + ) + op.drop_index( + "idx_password_reset_tokens_user_id", table_name="password_reset_tokens" + ) + op.drop_table("password_reset_tokens") + + # Drop email_verification_tokens table and indexes + op.drop_index( + "idx_email_verification_tokens_expires_at", + table_name="email_verification_tokens", + ) + op.drop_index( + "idx_email_verification_tokens_user_id", table_name="email_verification_tokens" + ) + op.drop_table("email_verification_tokens") + + # Drop users table and indexes + op.drop_index("idx_users_created_at", table_name="users") + op.drop_index("idx_users_username", table_name="users") + op.drop_index("idx_users_email", table_name="users") + op.drop_table("users") diff --git a/backend/app/api/routes.py b/backend/app/api/routes.py index 7c7a0da..849e227 100644 --- a/backend/app/api/routes.py +++ b/backend/app/api/routes.py @@ -8,6 +8,7 @@ from starlette.requests import Request from starlette.responses import Response +from app.api.subroutes.auth import auth_route from app.api.subroutes.pastes import pastes_route from app.config import config from app.containers import Container @@ -24,7 +25,9 @@ metrics_security = HTTPBearer(auto_error=False, description="Metrics access token") -def verify_metrics_token(credentials: HTTPAuthorizationCredentials | None = Depends(metrics_security)) -> None: +def verify_metrics_token( + credentials: HTTPAuthorizationCredentials | None = Depends(metrics_security), +) -> None: """ Verify Bearer token for metrics endpoint. @@ -86,4 +89,6 @@ async def metrics(_: None = Depends(verify_metrics_token)): ) + router.include_router(pastes_route) +router.include_router(auth_route) diff --git a/backend/app/api/subroutes/auth.py b/backend/app/api/subroutes/auth.py new file mode 100644 index 0000000..4c84b09 --- /dev/null +++ b/backend/app/api/subroutes/auth.py @@ -0,0 +1,286 @@ +"""Authentication API routes.""" + +import logging + +from dependency_injector.wiring import Provide, inject +from fastapi import APIRouter, Depends +from starlette.requests import Request + +from app.api.dto.auth_dto import ( + ForgotPasswordRequest, + LoginRequest, + MessageResponse, + RefreshTokenRequest, + RegisterRequest, + ResendVerificationRequest, + ResetPasswordRequest, + TokenResponse, + UserResponse, + VerifyEmailRequest, +) +from app.api.dto.Error import ErrorResponse +from app.config import config +from app.containers import Container +from app.db.models import UserEntity +from app.dependencies.auth import get_current_user +from app.ratelimit import create_limit_resolver, limiter +from app.services.auth_service import AuthService + +logger = logging.getLogger(__name__) + +auth_route = APIRouter(prefix="/auth", tags=["Authentication"]) + + +@auth_route.post( + "/register", + response_model=MessageResponse, + responses={ + 200: {"model": MessageResponse}, + 409: {"model": ErrorResponse, "description": "User already exists"}, + 400: {"model": ErrorResponse, "description": "Validation error"}, + }, + summary="Register a new user", + description="Create a new user account. A verification email will be sent.", +) +@limiter.limit(create_limit_resolver(config, "auth_register")) +@inject +async def register( + request: Request, + body: RegisterRequest, + auth_service: AuthService = Depends(Provide[Container.auth_service]), +) -> MessageResponse: + """Register a new user account.""" + await auth_service.register( + username=body.username, + email=body.email, + password=body.password, + ) + return MessageResponse( + message="Registration successful. Please check your email to verify your account." + ) + + +@auth_route.post( + "/login", + response_model=TokenResponse, + responses={ + 200: {"model": TokenResponse}, + 401: {"model": ErrorResponse, "description": "Invalid credentials"}, + 403: { + "model": ErrorResponse, + "description": "Email not verified or account inactive", + }, + }, + summary="Login", + description="Authenticate with username/email and password to receive access and refresh tokens.", +) +@limiter.limit(create_limit_resolver(config, "auth_login")) +@inject +async def login( + request: Request, + body: LoginRequest, + auth_service: AuthService = Depends(Provide[Container.auth_service]), +) -> TokenResponse: + """Authenticate user and return tokens.""" + user_agent = request.headers.get("user-agent") + ip_address = str(request.state.user_metadata.ip) + + access_token, refresh_token, expires_in = await auth_service.login( + username=body.username, + password=body.password, + user_agent=user_agent, + ip_address=ip_address, + ) + + return TokenResponse( + access_token=access_token, + refresh_token=refresh_token, + expires_in=expires_in, + ) + + +@auth_route.post( + "/refresh", + response_model=TokenResponse, + responses={ + 200: {"model": TokenResponse}, + 401: { + "model": ErrorResponse, + "description": "Invalid or expired refresh token", + }, + }, + summary="Refresh tokens", + description="Exchange a refresh token for new access and refresh tokens.", +) +@limiter.limit(create_limit_resolver(config, "auth_refresh")) +@inject +async def refresh_tokens( + request: Request, + body: RefreshTokenRequest, + auth_service: AuthService = Depends(Provide[Container.auth_service]), +) -> TokenResponse: + """Refresh access token using refresh token.""" + user_agent = request.headers.get("user-agent") + ip_address = str(request.state.user_metadata.ip) + + access_token, refresh_token, expires_in = await auth_service.refresh_tokens( + refresh_token=body.refresh_token, + user_agent=user_agent, + ip_address=ip_address, + ) + + return TokenResponse( + access_token=access_token, + refresh_token=refresh_token, + expires_in=expires_in, + ) + + +@auth_route.post( + "/verify-email", + response_model=MessageResponse, + responses={ + 200: {"model": MessageResponse}, + 401: {"model": ErrorResponse, "description": "Invalid or expired token"}, + }, + summary="Verify email", + description="Verify email address using the token sent via email.", +) +@limiter.limit(create_limit_resolver(config, "auth_verify_email")) +@inject +async def verify_email( + request: Request, + body: VerifyEmailRequest, + auth_service: AuthService = Depends(Provide[Container.auth_service]), +) -> MessageResponse: + """Verify user's email address.""" + await auth_service.verify_email(token=body.token) + return MessageResponse(message="Email verified successfully. You can now login.") + + +@auth_route.post( + "/resend-verification", + response_model=MessageResponse, + responses={ + 200: {"model": MessageResponse}, + }, + summary="Resend verification email", + description="Resend the email verification link. Rate limited to prevent abuse.", +) +@limiter.limit(create_limit_resolver(config, "auth_resend_verification")) +@inject +async def resend_verification( + request: Request, + body: ResendVerificationRequest, + auth_service: AuthService = Depends(Provide[Container.auth_service]), +) -> MessageResponse: + """Resend email verification link.""" + await auth_service.resend_verification_email(email=body.email) + # Always return success to prevent email enumeration + return MessageResponse( + message="If an unverified account exists with this email, a verification link has been sent." + ) + + +@auth_route.post( + "/forgot-password", + response_model=MessageResponse, + responses={ + 200: {"model": MessageResponse}, + }, + summary="Forgot password", + description="Request a password reset email. Rate limited to prevent abuse.", +) +@limiter.limit(create_limit_resolver(config, "auth_forgot_password")) +@inject +async def forgot_password( + request: Request, + body: ForgotPasswordRequest, + auth_service: AuthService = Depends(Provide[Container.auth_service]), +) -> MessageResponse: + """Request password reset email.""" + await auth_service.forgot_password(email=body.email) + # Always return success to prevent email enumeration + return MessageResponse( + message="If an account exists with this email, a password reset link has been sent." + ) + + +@auth_route.post( + "/reset-password", + response_model=MessageResponse, + responses={ + 200: {"model": MessageResponse}, + 401: {"model": ErrorResponse, "description": "Invalid or expired token"}, + }, + summary="Reset password", + description="Reset password using the token sent via email.", +) +@limiter.limit(create_limit_resolver(config, "auth_reset_password")) +@inject +async def reset_password( + request: Request, + body: ResetPasswordRequest, + auth_service: AuthService = Depends(Provide[Container.auth_service]), +) -> MessageResponse: + """Reset user's password.""" + await auth_service.reset_password( + token=body.token, + new_password=body.new_password, + ) + return MessageResponse( + message="Password reset successful. You can now login with your new password." + ) + + +@auth_route.get( + "/me", + response_model=UserResponse, + responses={ + 200: {"model": UserResponse}, + 401: {"model": ErrorResponse, "description": "Not authenticated"}, + }, + summary="Get current user", + description="Get the currently authenticated user's profile.", +) +@limiter.limit(create_limit_resolver(config, "auth_me")) +async def get_me( + request: Request, + user: UserEntity = Depends(get_current_user), +) -> UserResponse: + """Get current authenticated user's profile.""" + return UserResponse( + id=user.id, + username=user.username, + email=user.email, + is_verified=user.is_verified, + created_at=user.created_at, + last_login_at=user.last_login_at, + ) + + +@auth_route.post( + "/logout", + response_model=MessageResponse, + responses={ + 200: {"model": MessageResponse}, + 401: {"model": ErrorResponse, "description": "Not authenticated"}, + }, + summary="Logout", + description="Logout by revoking the current refresh token or all tokens.", +) +@limiter.limit(create_limit_resolver(config, "auth_logout")) +@inject +async def logout( + request: Request, + user: UserEntity = Depends(get_current_user), + body: RefreshTokenRequest | None = None, + auth_service: AuthService = Depends(Provide[Container.auth_service]), +) -> MessageResponse: + """Logout user by revoking refresh tokens.""" + refresh_token = body.refresh_token if body else None + await auth_service.logout( + user_id=user.id, + refresh_token=refresh_token, + ) + return MessageResponse(message="Logged out successfully.") diff --git a/backend/app/containers.py b/backend/app/containers.py index 6f19e5e..5e63324 100644 --- a/backend/app/containers.py +++ b/backend/app/containers.py @@ -19,7 +19,9 @@ @asynccontextmanager -async def _engine_resource(db_url: str, echo: bool = False) -> AsyncIterator[AsyncEngine]: +async def _engine_resource( + db_url: str, echo: bool = False +) -> AsyncIterator[AsyncEngine]: engine = create_async_engine(db_url, echo=echo, future=True) try: yield engine @@ -154,8 +156,10 @@ class Container(containers.DeclarativeContainer): modules=[ "app.api.routes", "app.api.subroutes.pastes", + "app.api.subroutes.auth", "app.services", "app.dependencies.db", + "app.dependencies.auth", ] ) @@ -213,4 +217,29 @@ class Container(containers.DeclarativeContainer): distributed_lock, ) - paste_service = providers.Factory(PasteService, session_factory, cleanup_service, storage_client) + paste_service = providers.Factory( + PasteService, session_factory, cleanup_service, storage_client + ) + + # Auth services + from app.services.auth_service import AuthService + from app.services.email_service import EmailService + from app.services.jwt_service import JWTService + + jwt_service = providers.Singleton( + JWTService, + config=config, + ) + + email_service = providers.Singleton( + EmailService, + config=config, + ) + + auth_service = providers.Factory( + AuthService, + session_factory=session_factory, + jwt_service=jwt_service, + email_service=email_service, + config=config, + ) diff --git a/backend/app/db/models.py b/backend/app/db/models.py index 69cd62f..9398f3b 100644 --- a/backend/app/db/models.py +++ b/backend/app/db/models.py @@ -1,4 +1,16 @@ -from sqlalchemy import TIMESTAMP, UUID, Boolean, Column, Index, Integer, String, func, text +from sqlalchemy import ( + TIMESTAMP, + UUID, + Boolean, + Column, + ForeignKey, + Index, + Integer, + String, + func, + text, +) +from sqlalchemy.orm import relationship from app.db.base import Base @@ -18,7 +30,9 @@ class PasteEntity(Base): content_path = Column(String, nullable=False) content_language = Column(String, nullable=False, server_default="plain_text") expires_at = Column(TIMESTAMP(timezone=True), nullable=True) - created_at = Column(TIMESTAMP(timezone=True), nullable=False, server_default=func.now()) + created_at = Column( + TIMESTAMP(timezone=True), nullable=False, server_default=func.now() + ) content_size = Column(Integer, nullable=False) is_compressed = Column(Boolean, nullable=False, server_default="false") @@ -38,3 +52,135 @@ def __repr__(self): def __str__(self): return self.title + + +# ───────────────────────────────────────────────────────────────────────────── +# Authentication Models +# ───────────────────────────────────────────────────────────────────────────── + + +class UserEntity(Base): + """User account for authentication.""" + + __tablename__ = "users" + __table_args__ = ( + Index("idx_users_email", "email"), + Index("idx_users_username", "username"), + Index("idx_users_created_at", "created_at"), + ) + + id = Column(UUID(as_uuid=True), primary_key=True, server_default=UUID_DEFAULT) + username = Column(String(50), unique=True, nullable=False) + email = Column(String(255), unique=True, nullable=False) + password_hash = Column(String(255), nullable=False) + + is_verified = Column(Boolean, nullable=False, server_default="false") + is_active = Column(Boolean, nullable=False, server_default="true") + + created_at = Column( + TIMESTAMP(timezone=True), nullable=False, server_default=func.now() + ) + updated_at = Column(TIMESTAMP(timezone=True), onupdate=func.now()) + last_login_at = Column(TIMESTAMP(timezone=True)) + + # Relationships + email_verification_tokens = relationship( + "EmailVerificationTokenEntity", + back_populates="user", + cascade="all, delete-orphan", + ) + password_reset_tokens = relationship( + "PasswordResetTokenEntity", back_populates="user", cascade="all, delete-orphan" + ) + refresh_tokens = relationship( + "RefreshTokenEntity", back_populates="user", cascade="all, delete-orphan" + ) + + def __repr__(self): + return f"" + + +class EmailVerificationTokenEntity(Base): + """Token for email verification.""" + + __tablename__ = "email_verification_tokens" + __table_args__ = ( + Index("idx_email_verification_tokens_user_id", "user_id"), + Index("idx_email_verification_tokens_expires_at", "expires_at"), + ) + + id = Column(UUID(as_uuid=True), primary_key=True, server_default=UUID_DEFAULT) + user_id = Column( + UUID(as_uuid=True), ForeignKey("users.id", ondelete="CASCADE"), nullable=False + ) + token_hash = Column(String(255), nullable=False) + expires_at = Column(TIMESTAMP(timezone=True), nullable=False) + used_at = Column(TIMESTAMP(timezone=True)) + created_at = Column( + TIMESTAMP(timezone=True), nullable=False, server_default=func.now() + ) + + # Relationship + user = relationship("UserEntity", back_populates="email_verification_tokens") + + def __repr__(self): + return f"" + + +class PasswordResetTokenEntity(Base): + """Token for password reset.""" + + __tablename__ = "password_reset_tokens" + __table_args__ = ( + Index("idx_password_reset_tokens_user_id", "user_id"), + Index("idx_password_reset_tokens_expires_at", "expires_at"), + ) + + id = Column(UUID(as_uuid=True), primary_key=True, server_default=UUID_DEFAULT) + user_id = Column( + UUID(as_uuid=True), ForeignKey("users.id", ondelete="CASCADE"), nullable=False + ) + token_hash = Column(String(255), nullable=False) + expires_at = Column(TIMESTAMP(timezone=True), nullable=False) + used_at = Column(TIMESTAMP(timezone=True)) + created_at = Column( + TIMESTAMP(timezone=True), nullable=False, server_default=func.now() + ) + + # Relationship + user = relationship("UserEntity", back_populates="password_reset_tokens") + + def __repr__(self): + return f"" + + +class RefreshTokenEntity(Base): + """JWT refresh token for session management.""" + + __tablename__ = "refresh_tokens" + __table_args__ = ( + Index("idx_refresh_tokens_user_id", "user_id"), + Index("idx_refresh_tokens_expires_at", "expires_at"), + Index("idx_refresh_tokens_revoked_at", "revoked_at"), + ) + + id = Column(UUID(as_uuid=True), primary_key=True, server_default=UUID_DEFAULT) + user_id = Column( + UUID(as_uuid=True), ForeignKey("users.id", ondelete="CASCADE"), nullable=False + ) + token_hash = Column(String(255), nullable=False) + expires_at = Column(TIMESTAMP(timezone=True), nullable=False) + revoked_at = Column(TIMESTAMP(timezone=True)) + created_at = Column( + TIMESTAMP(timezone=True), nullable=False, server_default=func.now() + ) + + # Session tracking + user_agent = Column(String(512)) + ip_address = Column(String(45)) # IPv6 max length + + # Relationship + user = relationship("UserEntity", back_populates="refresh_tokens") + + def __repr__(self): + return f"" diff --git a/backend/app/exceptions.py b/backend/app/exceptions.py index 5756aff..ffef0ab 100644 --- a/backend/app/exceptions.py +++ b/backend/app/exceptions.py @@ -120,3 +120,94 @@ def __init__(self, message: str, operation: str = "unknown"): status_code=500, ) self.operation = operation + + +# ───────────────────────────────────────────────────────────────────────────── +# Authentication Exceptions +# ───────────────────────────────────────────────────────────────────────────── + + +class UserNotFoundError(DevBinException): + """Raised when a user does not exist.""" + + def __init__(self, identifier: str = "user"): + super().__init__( + message=f"User '{identifier}' not found", + status_code=404, + ) + self.identifier = identifier + + +class UserAlreadyExistsError(DevBinException): + """Raised when attempting to create a user that already exists.""" + + def __init__(self, field: str = "username or email"): + super().__init__( + message=f"User with this {field} already exists", + status_code=409, + ) + self.field = field + + +class EmailNotVerifiedError(DevBinException): + """Raised when user attempts to login without verified email.""" + + def __init__(self): + super().__init__( + message="Email address has not been verified", + status_code=403, + ) + + +class InvalidCredentialsError(DevBinException): + """Raised when login credentials are invalid.""" + + def __init__(self): + super().__init__( + message="Invalid username or password", + status_code=401, + ) + + +class TokenExpiredError(DevBinException): + """Raised when a token has expired.""" + + def __init__(self, token_type: str = "token"): + super().__init__( + message=f"The {token_type} has expired", + status_code=401, + ) + self.token_type = token_type + + +class InvalidJWTError(DevBinException): + """Raised when JWT token is invalid or malformed.""" + + def __init__(self, reason: str = "Invalid token"): + super().__init__( + message=reason, + status_code=401, + ) + self.reason = reason + + +class PasswordValidationError(DevBinException): + """Raised when password does not meet requirements.""" + + def __init__(self, requirements: list[str]): + message = "Password does not meet requirements: " + ", ".join(requirements) + super().__init__( + message=message, + status_code=400, + ) + self.requirements = requirements + + +class UserInactiveError(DevBinException): + """Raised when user account is deactivated.""" + + def __init__(self): + super().__init__( + message="User account is inactive", + status_code=403, + ) diff --git a/backend/app/services/email_service.py b/backend/app/services/email_service.py new file mode 100644 index 0000000..73e82be --- /dev/null +++ b/backend/app/services/email_service.py @@ -0,0 +1,192 @@ +"""Email service for sending authentication emails.""" + +import logging +from email.message import EmailMessage + +import aiosmtplib + +from app.config import Config + + +class EmailService: + """Service for sending authentication-related emails.""" + + def __init__(self, config: Config): + self.config = config + self.logger = logging.getLogger(self.__class__.__name__) + + def _is_configured(self) -> bool: + """Check if SMTP is properly configured.""" + return bool(self.config.SMTP_HOST and self.config.SMTP_PORT) + + async def _send_email(self, to: str, subject: str, html_content: str) -> bool: + """ + Send an email using SMTP. + + Args: + to: Recipient email address + subject: Email subject + html_content: HTML email body + + Returns: + True if email was sent successfully, False otherwise + """ + if not self._is_configured(): + self.logger.warning( + "SMTP not configured. Email to %s with subject '%s' not sent. " + "Set APP_SMTP_HOST and APP_SMTP_PORT to enable email.", + to, + subject, + ) + # In development, log the email content for testing + if self.config.ENVIRONMENT == "dev": + self.logger.info("Email content (dev mode):\n%s", html_content) + return False + + message = EmailMessage() + message["From"] = ( + f"{self.config.SMTP_FROM_NAME} <{self.config.SMTP_FROM_EMAIL}>" + ) + message["To"] = to + message["Subject"] = subject + message.set_content(html_content, subtype="html") + + try: + await aiosmtplib.send( + message, + hostname=self.config.SMTP_HOST, + port=self.config.SMTP_PORT, + username=self.config.SMTP_USERNAME or None, + password=self.config.SMTP_PASSWORD or None, + start_tls=self.config.SMTP_USE_TLS, + ) + self.logger.info("Email sent to %s: %s", to, subject) + return True + + except aiosmtplib.SMTPException as e: + self.logger.error("Failed to send email to %s: %s", to, str(e)) + return False + + def _build_verification_url(self, token: str) -> str: + """Build the email verification URL.""" + base = self.config.FRONTEND_URL.rstrip("/") + path = self.config.EMAIL_VERIFY_PATH.lstrip("/") + return f"{base}/{path}?token={token}" + + def _build_password_reset_url(self, token: str) -> str: + """Build the password reset URL.""" + base = self.config.FRONTEND_URL.rstrip("/") + path = self.config.PASSWORD_RESET_PATH.lstrip("/") + return f"{base}/{path}?token={token}" + + async def send_verification_email(self, to: str, username: str, token: str) -> bool: + """ + Send email verification email. + + Args: + to: Recipient email address + username: User's username + token: Verification token + + Returns: + True if email was sent successfully + """ + verification_url = self._build_verification_url(token) + + html_content = f""" + + + + + Verify Your Email - DevBin + + +

Welcome to DevBin, {username}!

+

+ Thank you for registering. Please verify your email address by clicking the button below: +

+

+ + Verify Email Address + +

+

+ Or copy and paste this link into your browser:
+ {verification_url} +

+

+ This link will expire in {self.config.EMAIL_VERIFICATION_EXPIRE_HOURS} hours. + If you didn't create an account, you can safely ignore this email. +

+
+

+ DevBin - Share code snippets easily +

+ + +""" + + return await self._send_email( + to=to, + subject="Verify Your Email - DevBin", + html_content=html_content, + ) + + async def send_password_reset_email( + self, to: str, username: str, token: str + ) -> bool: + """ + Send password reset email. + + Args: + to: Recipient email address + username: User's username + token: Password reset token + + Returns: + True if email was sent successfully + """ + reset_url = self._build_password_reset_url(token) + + html_content = f""" + + + + + Reset Your Password - DevBin + + +

Password Reset Request

+

+ Hi {username}, we received a request to reset your password. + Click the button below to create a new password: +

+

+ + Reset Password + +

+

+ Or copy and paste this link into your browser:
+ {reset_url} +

+

+ This link will expire in {self.config.PASSWORD_RESET_EXPIRE_HOURS} hour(s). + If you didn't request a password reset, you can safely ignore this email. + Your password will remain unchanged. +

+
+

+ DevBin - Share code snippets easily +

+ + +""" + + return await self._send_email( + to=to, + subject="Reset Your Password - DevBin", + html_content=html_content, + ) diff --git a/backend/pyproject.toml b/backend/pyproject.toml index 10d3deb..40883cd 100644 --- a/backend/pyproject.toml +++ b/backend/pyproject.toml @@ -22,6 +22,9 @@ dependencies = [ "aiocache>=0.12.3", "fastapi-cors>=0.0.6", "argon2-cffi>=23.1.0", + "PyJWT>=2.8.0", + "aiosmtplib>=3.0.0", + "email-validator>=2.0.0", ] [project.optional-dependencies] diff --git a/backend/uv.lock b/backend/uv.lock index d686c9e..09bcbb1 100644 --- a/backend/uv.lock +++ b/backend/uv.lock @@ -163,6 +163,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/fb/76/641ae371508676492379f16e2fa48f4e2c11741bd63c48be4b12a6b09cba/aiosignal-1.4.0-py3-none-any.whl", hash = "sha256:053243f8b92b990551949e63930a839ff0cf0b0ebbe0597b0f3fb19e1a0fe82e", size = 7490, upload-time = "2025-07-03T22:54:42.156Z" }, ] +[[package]] +name = "aiosmtplib" +version = "5.1.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/e7/ad/240a7ce4e50713b111dff8b781a898d8d4770e5d6ad4899103f84c86005c/aiosmtplib-5.1.0.tar.gz", hash = "sha256:2504a23b2b63c9de6bc4ea719559a38996dba68f73f6af4eb97be20ee4c5e6c4", size = 66176, upload-time = "2026-01-25T01:51:11.408Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/37/82/70f2c452acd7ed18c558c8ace9a8cf4fdcc70eae9a41749b5bdc53eb6f45/aiosmtplib-5.1.0-py3-none-any.whl", hash = "sha256:368029440645b486b69db7029208a7a78c6691b90d24a5332ddba35d9109d55b", size = 27778, upload-time = "2026-01-25T01:51:10.026Z" }, +] + [[package]] name = "alembic" version = "1.17.2" @@ -298,15 +307,18 @@ source = { editable = "." } dependencies = [ { name = "aiocache" }, { name = "aiofiles" }, + { name = "aiosmtplib" }, { name = "alembic" }, { name = "argon2-cffi" }, { name = "asyncpg" }, { name = "dependency-injector" }, + { name = "email-validator" }, { name = "fastapi" }, { name = "fastapi-cors" }, { name = "orjson" }, { name = "pydantic" }, { name = "pydantic-settings" }, + { name = "pyjwt" }, { name = "python-dotenv" }, { name = "slowapi" }, { name = "sqlalchemy" }, @@ -364,10 +376,12 @@ requires-dist = [ { name = "aiocache", extras = ["redis"], marker = "extra == 'production'", specifier = ">=0.12.3" }, { name = "aiocache", extras = ["redis"], marker = "extra == 'redis'", specifier = ">=0.12.3" }, { name = "aiofiles", specifier = ">=25.1.0" }, + { name = "aiosmtplib", specifier = ">=3.0.0" }, { name = "alembic", specifier = ">=1.17.2" }, { name = "argon2-cffi", specifier = ">=23.1.0" }, { name = "asyncpg", specifier = ">=0.29" }, { name = "dependency-injector", specifier = ">=4.41" }, + { name = "email-validator", specifier = ">=2.0.0" }, { name = "fastapi", specifier = ">=0.124" }, { name = "fastapi-cors", specifier = ">=0.0.6" }, { name = "orjson", specifier = ">=3.11.5" }, @@ -376,6 +390,7 @@ requires-dist = [ { name = "psycopg2-binary", marker = "extra == 'migrations'" }, { name = "pydantic", specifier = ">=2.7" }, { name = "pydantic-settings", specifier = ">=2.12.0" }, + { name = "pyjwt", specifier = ">=2.8.0" }, { name = "python-dotenv", specifier = ">=1.0.0" }, { name = "redis", marker = "extra == 'production'", specifier = ">=5.0.0" }, { name = "redis", marker = "extra == 'redis'", specifier = ">=5.0.0" }, @@ -594,6 +609,28 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/84/d0/205d54408c08b13550c733c4b85429e7ead111c7f0014309637425520a9a/deprecated-1.3.1-py2.py3-none-any.whl", hash = "sha256:597bfef186b6f60181535a29fbe44865ce137a5079f295b479886c82729d5f3f", size = 11298, upload-time = "2025-10-30T08:19:00.758Z" }, ] +[[package]] +name = "dnspython" +version = "2.8.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/8c/8b/57666417c0f90f08bcafa776861060426765fdb422eb10212086fb811d26/dnspython-2.8.0.tar.gz", hash = "sha256:181d3c6996452cb1189c4046c61599b84a5a86e099562ffde77d26984ff26d0f", size = 368251, upload-time = "2025-09-07T18:58:00.022Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ba/5a/18ad964b0086c6e62e2e7500f7edc89e3faa45033c71c1893d34eed2b2de/dnspython-2.8.0-py3-none-any.whl", hash = "sha256:01d9bbc4a2d76bf0db7c1f729812ded6d912bd318d3b1cf81d30c0f845dbf3af", size = 331094, upload-time = "2025-09-07T18:57:58.071Z" }, +] + +[[package]] +name = "email-validator" +version = "2.3.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "dnspython" }, + { name = "idna" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/f5/22/900cb125c76b7aaa450ce02fd727f452243f2e91a61af068b40adba60ea9/email_validator-2.3.0.tar.gz", hash = "sha256:9fc05c37f2f6cf439ff414f8fc46d917929974a82244c20eb10231ba60c54426", size = 51238, upload-time = "2025-08-26T13:09:06.831Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/de/15/545e2b6cf2e3be84bc1ed85613edd75b8aea69807a71c26f4ca6a9258e82/email_validator-2.3.0-py3-none-any.whl", hash = "sha256:80f13f623413e6b197ae73bb10bf4eb0908faf509ad8362c5edeb0be7fd450b4", size = 35604, upload-time = "2025-08-26T13:09:05.858Z" }, +] + [[package]] name = "environs" version = "14.5.0" @@ -1290,6 +1327,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/c7/21/705964c7812476f378728bdf590ca4b771ec72385c533964653c68e86bdc/pygments-2.19.2-py3-none-any.whl", hash = "sha256:86540386c03d588bb81d44bc3928634ff26449851e99741617ecb9037ee5ec0b", size = 1225217, upload-time = "2025-06-21T13:39:07.939Z" }, ] +[[package]] +name = "pyjwt" +version = "2.10.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/e7/46/bd74733ff231675599650d3e47f361794b22ef3e3770998dda30d3b63726/pyjwt-2.10.1.tar.gz", hash = "sha256:3cc5772eb20009233caf06e9d8a0577824723b44e6648ee0a2aedb6cf9381953", size = 87785, upload-time = "2024-11-28T03:43:29.933Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/61/ad/689f02752eeec26aed679477e80e632ef1b682313be70793d798c1d5fc8f/PyJWT-2.10.1-py3-none-any.whl", hash = "sha256:dcdd193e30abefd5debf142f9adfcdd2b58004e644f25406ffaebd50bd98dacb", size = 22997, upload-time = "2024-11-28T03:43:27.893Z" }, +] + [[package]] name = "pytest" version = "9.0.2" diff --git a/docker-compose.yml b/docker-compose.yml index dcefd50..51c9ff0 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -46,7 +46,7 @@ services: - .env image: postgres:16 ports: - - "5432:5432" + - "5433:5432" volumes: - postgres_data:/var/lib/postgresql/data healthcheck: From 307d0692565f0d6349633d1a1c3d8478f20a398e Mon Sep 17 00:00:00 2001 From: Pdzly <34943569+Pdzly@users.noreply.github.com> Date: Sat, 21 Mar 2026 20:56:57 +0100 Subject: [PATCH 2/4] fix: stabilize auth with token invalidation, security hardening, and differentiated paste rate limits - Invalidate old reset tokens in forgot_password (was select-only, no-op) - Invalidate old verification tokens in resend_verification_email - Fix TokenExpiredError to use actual token type instead of hardcoded "access token" - Narrow get_optional_current_user exception catch to auth-specific errors - Fix stale get_exempt_key reference in paste routes via module-level access - Add max_length=512 to LoginRequest fields to prevent Argon2 DoS - Broaden _verify_password to catch InvalidHashError and HashingError - Deduplicate password validation logic into shared function - Add differentiated rate limits: authenticated users get 20/min vs 4/min anonymous for paste creation, using slowapi key-based limit resolution --- backend/.env.example | 63 +++ backend/app/api/dto/auth_dto.py | 161 ++++++++ backend/app/api/subroutes/pastes.py | 29 +- backend/app/config.py | 298 ++++++++++++-- backend/app/dependencies/auth.py | 148 +++++++ backend/app/ratelimit.py | 66 +++- backend/app/services/auth_service.py | 558 +++++++++++++++++++++++++++ backend/app/services/jwt_service.py | 159 ++++++++ 8 files changed, 1447 insertions(+), 35 deletions(-) create mode 100644 backend/app/api/dto/auth_dto.py create mode 100644 backend/app/dependencies/auth.py create mode 100644 backend/app/services/auth_service.py create mode 100644 backend/app/services/jwt_service.py diff --git a/backend/.env.example b/backend/.env.example index 5b8acf2..d35c349 100644 --- a/backend/.env.example +++ b/backend/.env.example @@ -103,6 +103,7 @@ APP_RATELIMIT_DEFAULT=60/minute # APP_RATELIMIT_GET_PASTE=10/minute # APP_RATELIMIT_GET_PASTE_LEGACY=10/minute # APP_RATELIMIT_CREATE_PASTE=4/minute +# APP_RATELIMIT_CREATE_PASTE_AUTHENTICATED=20/minute # APP_RATELIMIT_EDIT_PASTE=4/minute # APP_RATELIMIT_DELETE_PASTE=4/minute @@ -116,3 +117,65 @@ APP_RATELIMIT_DEFAULT=60/minute # Production deployments should set this to a strong random token # APP_METRICS_TOKEN=your_secure_random_token_here # Example generation: openssl rand -hex 32 + +# ============================================================================= +# Authentication Configuration +# ============================================================================= + +# JWT Configuration (REQUIRED for auth) +# IMPORTANT: Change this in production! Must be at least 32 characters. +# Generate with: openssl rand -hex 32 +APP_JWT_SECRET_KEY=CHANGE_ME_IN_PRODUCTION_32_CHARS_MIN +# Access token lifetime (default: 15 minutes) +APP_JWT_ACCESS_TOKEN_EXPIRE_MINUTES=15 +# Refresh token lifetime (default: 7 days) +APP_JWT_REFRESH_TOKEN_EXPIRE_DAYS=7 +# JWT algorithm (default: HS256) +APP_JWT_ALGORITHM=HS256 + +# SMTP Configuration (REQUIRED for email verification and password reset) +# Leave empty to disable email sending (emails logged instead in dev) +APP_SMTP_HOST=smtp.example.com +APP_SMTP_PORT=587 +APP_SMTP_USERNAME=your_smtp_username +APP_SMTP_PASSWORD=your_smtp_password +APP_SMTP_FROM_EMAIL=noreply@devbin.dev +APP_SMTP_USE_TLS=true + +# Frontend URL (REQUIRED for email links) +# Used for verification and password reset links in emails +APP_FRONTEND_URL=http://localhost:3000 +# Path for email verification links (default: /verify-email) +APP_EMAIL_VERIFY_PATH=/verify-email +# Path for password reset links (default: /reset-password) +APP_PASSWORD_RESET_PATH=/reset-password + +# Email Token Expiry +# Verification token lifetime in hours (default: 24) +APP_EMAIL_VERIFICATION_EXPIRE_HOURS=24 +# Password reset token lifetime in hours (default: 1) +APP_PASSWORD_RESET_EXPIRE_HOURS=1 + +# Auth Rate Limits (optional, these are defaults) +# Format: / +# APP_RATELIMIT_AUTH_REGISTER=5/hour +# APP_RATELIMIT_AUTH_LOGIN=10/minute +# APP_RATELIMIT_AUTH_REFRESH=20/minute +# APP_RATELIMIT_AUTH_VERIFY_EMAIL=10/minute +# APP_RATELIMIT_AUTH_RESEND_VERIFICATION=3/hour +# APP_RATELIMIT_AUTH_FORGOT_PASSWORD=3/hour +# APP_RATELIMIT_AUTH_RESET_PASSWORD=5/hour +# APP_RATELIMIT_AUTH_ME=60/minute +# APP_RATELIMIT_AUTH_LOGOUT=20/minute + +# Password Requirements (optional, these are defaults) +# Minimum password length +# APP_PASSWORD_MIN_LENGTH=8 +# Require at least one uppercase letter +# APP_PASSWORD_REQUIRE_UPPERCASE=true +# Require at least one lowercase letter +# APP_PASSWORD_REQUIRE_LOWERCASE=true +# Require at least one digit +# APP_PASSWORD_REQUIRE_DIGIT=true +# Require at least one special character +# APP_PASSWORD_REQUIRE_SPECIAL=false diff --git a/backend/app/api/dto/auth_dto.py b/backend/app/api/dto/auth_dto.py new file mode 100644 index 0000000..4238feb --- /dev/null +++ b/backend/app/api/dto/auth_dto.py @@ -0,0 +1,161 @@ +"""Authentication DTOs for request/response models.""" + +import re +from datetime import datetime + +from pydantic import UUID4, BaseModel, EmailStr, Field, field_validator + +from app.config import config + + +def validate_password_requirements(v: str) -> str: + """Validate password meets configured requirements.""" + errors = [] + + if config.PASSWORD_REQUIRE_UPPERCASE and not re.search(r"[A-Z]", v): + errors.append("at least one uppercase letter") + + if config.PASSWORD_REQUIRE_LOWERCASE and not re.search(r"[a-z]", v): + errors.append("at least one lowercase letter") + + if config.PASSWORD_REQUIRE_DIGIT and not re.search(r"\d", v): + errors.append("at least one digit") + + if config.PASSWORD_REQUIRE_SPECIAL and not re.search( + r"[!@#$%^&*(),.?\":{}|<>]", v + ): + errors.append("at least one special character") + + if errors: + raise ValueError(f"Password must contain {', '.join(errors)}") + + return v + + +class RegisterRequest(BaseModel): + """Request model for user registration.""" + + username: str = Field( + min_length=3, + max_length=50, + description="Username (3-50 characters, alphanumeric and underscores)", + examples=["john_doe"], + ) + email: EmailStr = Field( + description="Valid email address", + examples=["john@example.com"], + ) + password: str = Field( + min_length=config.PASSWORD_MIN_LENGTH, + max_length=128, + description="Password meeting security requirements", + ) + + @field_validator("username") + @classmethod + def validate_username(cls, v: str) -> str: + """Validate username format.""" + if not re.match(r"^[a-zA-Z][a-zA-Z0-9_]*$", v): + raise ValueError( + "Username must start with a letter and contain only letters, numbers, and underscores" + ) + return v.lower() + + @field_validator("password") + @classmethod + def validate_password(cls, v: str) -> str: + """Validate password meets requirements.""" + return validate_password_requirements(v) + + +class LoginRequest(BaseModel): + """Request model for user login.""" + + username: str = Field( + min_length=1, + max_length=512, + description="Username or email", + examples=["john_doe"], + ) + password: str = Field( + min_length=1, + max_length=512, + description="User password", + ) + + +class TokenResponse(BaseModel): + """Response model for authentication tokens.""" + + access_token: str = Field(description="JWT access token") + refresh_token: str = Field( + description="JWT refresh token for obtaining new access tokens" + ) + token_type: str = Field(default="Bearer", description="Token type") + expires_in: int = Field(description="Access token expiration time in seconds") + + +class RefreshTokenRequest(BaseModel): + """Request model for token refresh.""" + + refresh_token: str = Field(description="Refresh token to exchange for new tokens") + + +class VerifyEmailRequest(BaseModel): + """Request model for email verification.""" + + token: str = Field(description="Email verification token") + + +class ForgotPasswordRequest(BaseModel): + """Request model for forgot password.""" + + email: EmailStr = Field( + description="Email address associated with the account", + examples=["john@example.com"], + ) + + +class ResetPasswordRequest(BaseModel): + """Request model for password reset.""" + + token: str = Field(description="Password reset token") + new_password: str = Field( + min_length=config.PASSWORD_MIN_LENGTH, + max_length=128, + description="New password meeting security requirements", + ) + + @field_validator("new_password") + @classmethod + def validate_password(cls, v: str) -> str: + """Validate password meets requirements.""" + return validate_password_requirements(v) + + +class ResendVerificationRequest(BaseModel): + """Request model for resending verification email.""" + + email: EmailStr = Field( + description="Email address to resend verification to", + examples=["john@example.com"], + ) + + +class UserResponse(BaseModel): + """Response model for user profile.""" + + id: UUID4 = Field(description="User UUID") + username: str = Field(description="Username") + email: EmailStr = Field(description="Email address") + is_verified: bool = Field(description="Whether email is verified") + created_at: datetime = Field(description="Account creation timestamp") + last_login_at: datetime | None = Field(description="Last login timestamp") + + model_config = {"from_attributes": True} + + +class MessageResponse(BaseModel): + """Generic message response.""" + + message: str = Field(description="Response message") diff --git a/backend/app/api/subroutes/pastes.py b/backend/app/api/subroutes/pastes.py index bcc0810..e288480 100644 --- a/backend/app/api/subroutes/pastes.py +++ b/backend/app/api/subroutes/pastes.py @@ -9,6 +9,7 @@ from starlette.requests import Request from starlette.responses import PlainTextResponse, Response +import app.ratelimit as ratelimit from app.api.dto.Error import ErrorResponse from app.api.dto.paste_dto import ( CreatePaste, @@ -19,8 +20,9 @@ ) from app.config import config from app.containers import Container +from app.dependencies.auth import get_optional_current_user from app.exceptions import PasteNotFoundError -from app.ratelimit import create_limit_resolver, get_exempt_key, limiter +from app.ratelimit import create_auth_aware_key_func, create_auth_aware_limit_resolver, create_limit_resolver, limiter from app.services.paste_service import PasteService from app.utils.LRUMemoryCache import LRUMemoryCache from app.utils.metrics import cache_operations @@ -43,6 +45,15 @@ def set_cache(cache_instance: "RedisCache | LRUMemoryCache"): cache = cache_instance +async def _resolve_optional_user( + request: Request, + user=Depends(get_optional_current_user), +): + """Resolve optional user and attach to request.state for rate limiter access.""" + request.state.current_user = user + return user + + edit_token_key_header = APIKeyHeader(name="Authorization", scheme_name="Edit Token") delete_token_key_header = APIKeyHeader(name="Authorization", scheme_name="Delete Token") @@ -53,7 +64,7 @@ def set_cache(cache_instance: "RedisCache | LRUMemoryCache"): summary="Get legacy Hastebin-format paste", description="Retrieve a paste stored in legacy Hastebin format by its ID.", ) -@limiter.limit(create_limit_resolver(config, "get_paste_legacy"), key_func=get_exempt_key) +@limiter.limit(create_limit_resolver(config, "get_paste_legacy"), key_func=lambda r: ratelimit.get_exempt_key(r)) @inject async def get_legacy_paste( request: Request, @@ -96,7 +107,7 @@ async def get_legacy_paste( summary="Get paste by UUID", description="Retrieve a paste by its UUID identifier.", ) -@limiter.limit(create_limit_resolver(config, "get_paste"), key_func=get_exempt_key) +@limiter.limit(create_limit_resolver(config, "get_paste"), key_func=lambda r: ratelimit.get_exempt_key(r)) @inject async def get_paste_by_uuid( request: Request, @@ -140,7 +151,7 @@ async def get_paste_by_uuid( summary="Get raw paste content", description="Retrieve only the raw text content of a paste. Useful for curl/wget users.", ) -@limiter.limit(create_limit_resolver(config, "get_paste"), key_func=get_exempt_key) +@limiter.limit(create_limit_resolver(config, "get_paste"), key_func=lambda r: ratelimit.get_exempt_key(r)) @inject async def get_paste_raw( request: Request, @@ -189,12 +200,16 @@ async def get_paste_raw( summary="Create a new paste", description="Create a new paste with the provided content and metadata.", ) -@limiter.limit(create_limit_resolver(config, "create_paste"), key_func=get_exempt_key) +@limiter.limit( + create_auth_aware_limit_resolver(config, "create_paste", "create_paste_authenticated"), + key_func=create_auth_aware_key_func(config), +) @inject async def create_paste( request: Request, create_paste_body: CreatePaste, paste_service: PasteService = Depends(Provide[Container.paste_service]), + _current_user=Depends(_resolve_optional_user), ): """Create a new paste and return edit/delete tokens.""" return await paste_service.create_paste(create_paste_body, request.state.user_metadata) @@ -206,7 +221,7 @@ async def create_paste( summary="Edit an existing paste", description="Update a paste's content or metadata. Requires a valid edit token.", ) -@limiter.limit(create_limit_resolver(config, "edit_paste"), key_func=get_exempt_key) +@limiter.limit(create_limit_resolver(config, "edit_paste"), key_func=lambda r: ratelimit.get_exempt_key(r)) @inject async def edit_paste( request: Request, @@ -230,7 +245,7 @@ async def edit_paste( summary="Delete a paste", description="Permanently delete a paste. Requires a valid delete token.", ) -@limiter.limit(create_limit_resolver(config, "delete_paste"), key_func=get_exempt_key) +@limiter.limit(create_limit_resolver(config, "delete_paste"), key_func=lambda r: ratelimit.get_exempt_key(r)) @inject async def delete_paste( request: Request, diff --git a/backend/app/config.py b/backend/app/config.py index 8d1bd44..bab98c6 100644 --- a/backend/app/config.py +++ b/backend/app/config.py @@ -6,7 +6,12 @@ from pydantic import AfterValidator, Field, ValidationError, field_validator from pydantic_settings import BaseSettings -from app.utils.ip import TrustedHost, parse_ip_or_network, resolve_hostname, validate_ip_address +from app.utils.ip import ( + TrustedHost, + parse_ip_or_network, + resolve_hostname, + validate_ip_address, +) # Rate limit format validation RATE_LIMIT_PATTERN = re.compile(r"^\d+/(second|minute|hour|day)$") @@ -15,7 +20,9 @@ def validate_rate_limit(value: str) -> str: """Validate rate limit format (e.g., '10/minute', '100/hour').""" if not RATE_LIMIT_PATTERN.match(value): - raise ValueError(f"Invalid rate limit format: '{value}'. Expected format: '/'") + raise ValueError( + f"Invalid rate limit format: '{value}'. Expected format: '/'" + ) return value @@ -28,11 +35,15 @@ def validate_rate_limit(value: str) -> str: class Config(BaseSettings): # Environment ENVIRONMENT: Literal["dev", "staging", "prod"] = Field( - default="dev", validation_alias="APP_ENVIRONMENT", description="Application environment (dev, staging, prod)" + default="dev", + validation_alias="APP_ENVIRONMENT", + description="Application environment (dev, staging, prod)", ) PORT: int = Field(default=8000, validation_alias="APP_PORT") - HOST: str = Field(default="0.0.0.0", validation_alias="APP_HOST") # noqa: S104 - Bind to all interfaces for container deployment + HOST: str = Field( + default="0.0.0.0", validation_alias="APP_HOST" + ) # noqa: S104 - Bind to all interfaces for container deployment # DB DATABASE_URL: str = Field( @@ -42,8 +53,12 @@ class Config(BaseSettings): SQLALCHEMY_ECHO: bool = Field(default=False, validation_alias="APP_SQLALCHEMY_ECHO") # Paste - MAX_CONTENT_LENGTH: int = Field(default=10000, validation_alias="APP_MAX_CONTENT_LENGTH") - BASE_FOLDER_PATH: str = Field(default="./files", validation_alias="APP_BASE_FOLDER_PATH") + MAX_CONTENT_LENGTH: int = Field( + default=10000, validation_alias="APP_MAX_CONTENT_LENGTH" + ) + BASE_FOLDER_PATH: str = Field( + default="./files", validation_alias="APP_BASE_FOLDER_PATH" + ) WORKERS: int | Literal[True] = Field(default=1, validation_alias="APP_WORKERS") METRICS_TOKEN: str | None = Field( default=None, @@ -67,18 +82,32 @@ class Config(BaseSettings): # Cache backend configuration CACHE_TYPE: Literal["memory", "redis"] = Field( - default="memory", validation_alias="APP_CACHE_TYPE", description="Cache backend type (memory, redis)" + default="memory", + validation_alias="APP_CACHE_TYPE", + description="Cache backend type (memory, redis)", + ) + REDIS_HOST: str = Field( + default="localhost", + validation_alias="APP_REDIS_HOST", + description="Redis server host", + ) + REDIS_PORT: int = Field( + default=6379, validation_alias="APP_REDIS_PORT", description="Redis server port" + ) + REDIS_DB: int = Field( + default=0, validation_alias="APP_REDIS_DB", description="Redis database number" ) - REDIS_HOST: str = Field(default="localhost", validation_alias="APP_REDIS_HOST", description="Redis server host") - REDIS_PORT: int = Field(default=6379, validation_alias="APP_REDIS_PORT", description="Redis server port") - REDIS_DB: int = Field(default=0, validation_alias="APP_REDIS_DB", description="Redis database number") REDIS_PASSWORD: str | None = Field( - default=None, validation_alias="APP_REDIS_PASSWORD", description="Redis password (optional)" + default=None, + validation_alias="APP_REDIS_PASSWORD", + description="Redis password (optional)", ) # Lock backend configuration LOCK_TYPE: Literal["file", "redis"] = Field( - default="file", validation_alias="APP_LOCK_TYPE", description="Lock backend type (file, redis)" + default="file", + validation_alias="APP_LOCK_TYPE", + description="Lock backend type (file, redis)", ) # Rate limiting configuration @@ -117,7 +146,12 @@ class Config(BaseSettings): RATELIMIT_CREATE_PASTE: RateLimit = Field( default="4/minute", validation_alias="APP_RATELIMIT_CREATE_PASTE", - description="Rate limit for POST /p/", + description="Rate limit for POST /p/ (anonymous users)", + ) + RATELIMIT_CREATE_PASTE_AUTHENTICATED: RateLimit = Field( + default="20/minute", + validation_alias="APP_RATELIMIT_CREATE_PASTE_AUTHENTICATED", + description="Rate limit for POST /p/ (authenticated users)", ) RATELIMIT_EDIT_PASTE: RateLimit = Field( default="4/minute", @@ -165,29 +199,57 @@ class Config(BaseSettings): description="Minimum content size in bytes to trigger compression (2KB+ shows 30-40% compression ratio)", ) COMPRESSION_LEVEL: int = Field( - default=6, validation_alias="APP_COMPRESSION_LEVEL", description="Gzip compression level (1-9, 6=balanced)" + default=6, + validation_alias="APP_COMPRESSION_LEVEL", + description="Gzip compression level (1-9, 6=balanced)", ) # Storage settings STORAGE_TYPE: Literal["local", "s3", "minio"] = Field( - default="local", validation_alias="APP_STORAGE_TYPE", description="Storage backend type (local, s3, minio)" + default="local", + validation_alias="APP_STORAGE_TYPE", + description="Storage backend type (local, s3, minio)", + ) + S3_BUCKET_NAME: str = Field( + default="", validation_alias="APP_S3_BUCKET_NAME", description="S3 bucket name" + ) + S3_REGION: str = Field( + default="us-east-1", + validation_alias="APP_S3_REGION", + description="AWS region for S3", + ) + S3_ACCESS_KEY: str = Field( + default="", validation_alias="APP_S3_ACCESS_KEY", description="S3 access key ID" + ) + S3_SECRET_KEY: str = Field( + default="", + validation_alias="APP_S3_SECRET_KEY", + description="S3 secret access key", ) - S3_BUCKET_NAME: str = Field(default="", validation_alias="APP_S3_BUCKET_NAME", description="S3 bucket name") - S3_REGION: str = Field(default="us-east-1", validation_alias="APP_S3_REGION", description="AWS region for S3") - S3_ACCESS_KEY: str = Field(default="", validation_alias="APP_S3_ACCESS_KEY", description="S3 access key ID") - S3_SECRET_KEY: str = Field(default="", validation_alias="APP_S3_SECRET_KEY", description="S3 secret access key") S3_ENDPOINT_URL: str | None = Field( default=None, validation_alias="APP_S3_ENDPOINT_URL", description="Custom S3 endpoint URL (for S3-compatible services)", ) MINIO_ENDPOINT: str = Field( - default="", validation_alias="APP_MINIO_ENDPOINT", description="MinIO server endpoint (e.g., 'minio:9000')" + default="", + validation_alias="APP_MINIO_ENDPOINT", + description="MinIO server endpoint (e.g., 'minio:9000')", + ) + MINIO_ACCESS_KEY: str = Field( + default="", + validation_alias="APP_MINIO_ACCESS_KEY", + description="MinIO access key", + ) + MINIO_SECRET_KEY: str = Field( + default="", + validation_alias="APP_MINIO_SECRET_KEY", + description="MinIO secret key", ) - MINIO_ACCESS_KEY: str = Field(default="", validation_alias="APP_MINIO_ACCESS_KEY", description="MinIO access key") - MINIO_SECRET_KEY: str = Field(default="", validation_alias="APP_MINIO_SECRET_KEY", description="MinIO secret key") MINIO_SECURE: bool = Field( - default=True, validation_alias="APP_MINIO_SECURE", description="Use HTTPS for MinIO connection" + default=True, + validation_alias="APP_MINIO_SECURE", + description="Use HTTPS for MinIO connection", ) KEEP_DELETED_PASTES_TIME_HOURS: int = Field( @@ -205,6 +267,172 @@ class Config(BaseSettings): RELOAD: bool = Field(default=False, validation_alias="APP_RELOAD") DEBUG: bool = Field(default=False, validation_alias="APP_DEBUG") + # ───────────────────────────────────────────────────────────────────────────── + # Authentication Configuration + # ───────────────────────────────────────────────────────────────────────────── + + # JWT Settings + JWT_SECRET_KEY: str = Field( + default="CHANGE_ME_IN_PRODUCTION_32_CHARS_MIN", + validation_alias="APP_JWT_SECRET_KEY", + description="Secret key for JWT token signing (minimum 32 characters)", + ) + JWT_ALGORITHM: str = Field( + default="HS256", + validation_alias="APP_JWT_ALGORITHM", + description="JWT signing algorithm", + ) + JWT_ACCESS_TOKEN_EXPIRE_MINUTES: int = Field( + default=15, + validation_alias="APP_JWT_ACCESS_TOKEN_EXPIRE_MINUTES", + description="Access token expiration time in minutes", + ) + JWT_REFRESH_TOKEN_EXPIRE_DAYS: int = Field( + default=7, + validation_alias="APP_JWT_REFRESH_TOKEN_EXPIRE_DAYS", + description="Refresh token expiration time in days", + ) + + # SMTP Settings for Email + SMTP_HOST: str = Field( + default="", + validation_alias="APP_SMTP_HOST", + description="SMTP server hostname", + ) + SMTP_PORT: int = Field( + default=587, + validation_alias="APP_SMTP_PORT", + description="SMTP server port (587 for TLS, 465 for SSL)", + ) + SMTP_USERNAME: str = Field( + default="", + validation_alias="APP_SMTP_USERNAME", + description="SMTP authentication username", + ) + SMTP_PASSWORD: str = Field( + default="", + validation_alias="APP_SMTP_PASSWORD", + description="SMTP authentication password", + ) + SMTP_FROM_EMAIL: str = Field( + default="noreply@devbin.dev", + validation_alias="APP_SMTP_FROM_EMAIL", + description="Email address for outgoing emails", + ) + SMTP_FROM_NAME: str = Field( + default="DevBin", + validation_alias="APP_SMTP_FROM_NAME", + description="Display name for outgoing emails", + ) + SMTP_USE_TLS: bool = Field( + default=True, + validation_alias="APP_SMTP_USE_TLS", + description="Use STARTTLS for SMTP connection", + ) + + # Frontend URL for email links + FRONTEND_URL: str = Field( + default="http://localhost:3000", + validation_alias="APP_FRONTEND_URL", + description="Frontend URL for email verification and password reset links", + ) + EMAIL_VERIFY_PATH: str = Field( + default="/auth/verify-email", + validation_alias="APP_EMAIL_VERIFY_PATH", + description="Frontend path for email verification", + ) + PASSWORD_RESET_PATH: str = Field( + default="/auth/reset-password", + validation_alias="APP_PASSWORD_RESET_PATH", + description="Frontend path for password reset", + ) + + # Email Token Expiration + EMAIL_VERIFICATION_EXPIRE_HOURS: int = Field( + default=24, + validation_alias="APP_EMAIL_VERIFICATION_EXPIRE_HOURS", + description="Email verification token expiration in hours", + ) + PASSWORD_RESET_EXPIRE_HOURS: int = Field( + default=1, + validation_alias="APP_PASSWORD_RESET_EXPIRE_HOURS", + description="Password reset token expiration in hours", + ) + + # Auth Rate Limits + RATELIMIT_AUTH_REGISTER: RateLimit = Field( + default="5/hour", + validation_alias="APP_RATELIMIT_AUTH_REGISTER", + description="Rate limit for registration endpoint", + ) + RATELIMIT_AUTH_LOGIN: RateLimit = Field( + default="10/minute", + validation_alias="APP_RATELIMIT_AUTH_LOGIN", + description="Rate limit for login endpoint", + ) + RATELIMIT_AUTH_REFRESH: RateLimit = Field( + default="20/minute", + validation_alias="APP_RATELIMIT_AUTH_REFRESH", + description="Rate limit for token refresh endpoint", + ) + RATELIMIT_AUTH_VERIFY_EMAIL: RateLimit = Field( + default="10/minute", + validation_alias="APP_RATELIMIT_AUTH_VERIFY_EMAIL", + description="Rate limit for email verification endpoint", + ) + RATELIMIT_AUTH_RESEND_VERIFICATION: RateLimit = Field( + default="3/hour", + validation_alias="APP_RATELIMIT_AUTH_RESEND_VERIFICATION", + description="Rate limit for resend verification endpoint", + ) + RATELIMIT_AUTH_FORGOT_PASSWORD: RateLimit = Field( + default="3/hour", + validation_alias="APP_RATELIMIT_AUTH_FORGOT_PASSWORD", + description="Rate limit for forgot password endpoint", + ) + RATELIMIT_AUTH_RESET_PASSWORD: RateLimit = Field( + default="5/hour", + validation_alias="APP_RATELIMIT_AUTH_RESET_PASSWORD", + description="Rate limit for password reset endpoint", + ) + RATELIMIT_AUTH_ME: RateLimit = Field( + default="60/minute", + validation_alias="APP_RATELIMIT_AUTH_ME", + description="Rate limit for user profile endpoint", + ) + RATELIMIT_AUTH_LOGOUT: RateLimit = Field( + default="20/minute", + validation_alias="APP_RATELIMIT_AUTH_LOGOUT", + description="Rate limit for logout endpoint", + ) + + # Password Requirements + PASSWORD_MIN_LENGTH: int = Field( + default=8, + validation_alias="APP_PASSWORD_MIN_LENGTH", + description="Minimum password length", + ) + PASSWORD_REQUIRE_UPPERCASE: bool = Field( + default=True, + validation_alias="APP_PASSWORD_REQUIRE_UPPERCASE", + description="Require at least one uppercase letter in password", + ) + PASSWORD_REQUIRE_LOWERCASE: bool = Field( + default=True, + validation_alias="APP_PASSWORD_REQUIRE_LOWERCASE", + description="Require at least one lowercase letter in password", + ) + PASSWORD_REQUIRE_DIGIT: bool = Field( + default=True, + validation_alias="APP_PASSWORD_REQUIRE_DIGIT", + description="Require at least one digit in password", + ) + PASSWORD_REQUIRE_SPECIAL: bool = Field( + default=False, + validation_alias="APP_PASSWORD_REQUIRE_SPECIAL", + description="Require at least one special character in password", + ) + ENFORCE_HTTPS: bool = Field( default=False, validation_alias="APP_ENFORCE_HTTPS", @@ -279,7 +507,9 @@ def validate_cors_domains(cls, domains: list[str], info) -> list[str]: def validate_compression_level(cls, level: int) -> int: """Validate compression level is in valid range.""" if not 1 <= level <= 9: - logging.warning("Invalid compression level %d, must be 1-9. Using default 6.", level) + logging.warning( + "Invalid compression level %d, must be 1-9. Using default 6.", level + ) return 6 return level @@ -287,7 +517,10 @@ def validate_compression_level(cls, level: int) -> int: def validate_compression_threshold(cls, threshold: int) -> int: """Validate compression threshold is reasonable.""" if threshold < 0: - logging.warning("Invalid compression threshold %d, must be >= 0. Using default 512.", threshold) + logging.warning( + "Invalid compression threshold %d, must be >= 0. Using default 512.", + threshold, + ) return 512 return threshold @@ -296,7 +529,9 @@ def model_post_init(self, __context): if self.ENVIRONMENT == "prod": # Production security validations if self.DEBUG: - logging.error("PRODUCTION ERROR: DEBUG mode is enabled in production. Set APP_DEBUG=false") + logging.error( + "PRODUCTION ERROR: DEBUG mode is enabled in production. Set APP_DEBUG=false" + ) raise ValueError("DEBUG must be False in production") if "*" in self.CORS_DOMAINS: @@ -325,6 +560,17 @@ def model_post_init(self, __context): "Set APP_METRICS_TOKEN to enable authentication for /metrics" ) + # Validate JWT secret key in production + if ( + self.JWT_SECRET_KEY == "CHANGE_ME_IN_PRODUCTION_32_CHARS_MIN" + or len(self.JWT_SECRET_KEY) < 32 + ): + logging.error( + "PRODUCTION ERROR: JWT_SECRET_KEY must be changed and at least 32 characters. " + "Set APP_JWT_SECRET_KEY to a secure random value" + ) + raise ValueError("JWT_SECRET_KEY must be changed in production") + logging.info("Production environment validated successfully") diff --git a/backend/app/dependencies/auth.py b/backend/app/dependencies/auth.py new file mode 100644 index 0000000..8126421 --- /dev/null +++ b/backend/app/dependencies/auth.py @@ -0,0 +1,148 @@ +"""Authentication dependencies for FastAPI routes.""" + +from uuid import UUID + +from dependency_injector.wiring import Provide, inject +from fastapi import Depends +from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer + +from app.containers import Container +from app.db.models import UserEntity +from app.exceptions import ( + EmailNotVerifiedError, + InvalidJWTError, + TokenExpiredError, + UserInactiveError, + UserNotFoundError, +) +from app.services.auth_service import AuthService +from app.services.jwt_service import JWTService, TokenType + +# HTTPBearer scheme for OpenAPI documentation +auth_scheme = HTTPBearer( + auto_error=True, + description="JWT Bearer token authentication", +) + +# Optional version that doesn't raise on missing token +auth_scheme_optional = HTTPBearer( + auto_error=False, + description="Optional JWT Bearer token authentication", +) + + +@inject +async def get_current_user_id( + credentials: HTTPAuthorizationCredentials = Depends(auth_scheme), + jwt_service: JWTService = Depends(Provide[Container.jwt_service]), +) -> UUID: + """ + Extract and validate user ID from JWT access token. + + This is a lightweight dependency that only validates the token + without hitting the database. + + Args: + credentials: Bearer token from Authorization header + jwt_service: JWT service for token validation + + Returns: + User UUID from the token + + Raises: + InvalidJWTError: If token is invalid + TokenExpiredError: If token has expired + """ + token = credentials.credentials + payload = jwt_service.decode_token(token, expected_type=TokenType.ACCESS) + return payload.user_id + + +@inject +async def get_current_user( + user_id: UUID = Depends(get_current_user_id), + auth_service: AuthService = Depends(Provide[Container.auth_service]), +) -> UserEntity: + """ + Get the current authenticated user from the database. + + Args: + user_id: User ID extracted from token + auth_service: Auth service for user lookup + + Returns: + User entity + + Raises: + UserNotFoundError: If user doesn't exist + UserInactiveError: If user account is deactivated + """ + user = await auth_service.get_user_by_id(user_id) + + if not user: + raise UserNotFoundError() + + if not user.is_active: + raise UserInactiveError() + + return user + + +async def get_current_active_verified_user( + user: UserEntity = Depends(get_current_user), +) -> UserEntity: + """ + Get the current authenticated user, ensuring email is verified. + + Args: + user: Current user from get_current_user + + Returns: + Verified user entity + + Raises: + EmailNotVerifiedError: If email is not verified + """ + if not user.is_verified: + raise EmailNotVerifiedError() + + return user + + +@inject +async def get_optional_current_user( + credentials: HTTPAuthorizationCredentials | None = Depends(auth_scheme_optional), + jwt_service: JWTService = Depends(Provide[Container.jwt_service]), + auth_service: AuthService = Depends(Provide[Container.auth_service]), +) -> UserEntity | None: + """ + Optionally get the current authenticated user. + + Returns None if no token is provided instead of raising an error. + Useful for routes that have different behavior for authenticated + vs anonymous users. + + Args: + credentials: Optional Bearer token + jwt_service: JWT service for token validation + auth_service: Auth service for user lookup + + Returns: + User entity or None if not authenticated + """ + if not credentials: + return None + + try: + token = credentials.credentials + payload = jwt_service.decode_token(token, expected_type=TokenType.ACCESS) + user = await auth_service.get_user_by_id(payload.user_id) + + if user and user.is_active: + return user + + return None + + except (InvalidJWTError, TokenExpiredError, UserNotFoundError, UserInactiveError): + # Invalid/expired token or inactive user treated as not authenticated + return None diff --git a/backend/app/ratelimit.py b/backend/app/ratelimit.py index ba593bb..78c4edb 100644 --- a/backend/app/ratelimit.py +++ b/backend/app/ratelimit.py @@ -15,8 +15,19 @@ "get_paste", "get_paste_legacy", "create_paste", + "create_paste_authenticated", "edit_paste", "delete_paste", + # Auth endpoints + "auth_register", + "auth_login", + "auth_refresh", + "auth_verify_email", + "auth_resend_verification", + "auth_forgot_password", + "auth_reset_password", + "auth_me", + "auth_logout", ] @@ -42,7 +53,11 @@ def create_exempt_key_func(config: Config) -> Callable[[Request], str]: def get_exempt_key(request: Request) -> str: auth_header = request.headers.get("Authorization") - if auth_header and config.RATELIMIT_BYPASS_TOKENS and auth_header in config.RATELIMIT_BYPASS_TOKENS: + if ( + auth_header + and config.RATELIMIT_BYPASS_TOKENS + and auth_header in config.RATELIMIT_BYPASS_TOKENS + ): # Return unique key for each request = effectively unlimited return str(uuid4()) return get_ip_address(request) @@ -57,8 +72,19 @@ def create_limit_resolver(config: Config, limit_name: LimitName) -> Callable[[], "get_paste": config.RATELIMIT_GET_PASTE, "get_paste_legacy": config.RATELIMIT_GET_PASTE_LEGACY, "create_paste": config.RATELIMIT_CREATE_PASTE, + "create_paste_authenticated": config.RATELIMIT_CREATE_PASTE_AUTHENTICATED, "edit_paste": config.RATELIMIT_EDIT_PASTE, "delete_paste": config.RATELIMIT_DELETE_PASTE, + # Auth endpoints + "auth_register": config.RATELIMIT_AUTH_REGISTER, + "auth_login": config.RATELIMIT_AUTH_LOGIN, + "auth_refresh": config.RATELIMIT_AUTH_REFRESH, + "auth_verify_email": config.RATELIMIT_AUTH_VERIFY_EMAIL, + "auth_resend_verification": config.RATELIMIT_AUTH_RESEND_VERIFICATION, + "auth_forgot_password": config.RATELIMIT_AUTH_FORGOT_PASSWORD, + "auth_reset_password": config.RATELIMIT_AUTH_RESET_PASSWORD, + "auth_me": config.RATELIMIT_AUTH_ME, + "auth_logout": config.RATELIMIT_AUTH_LOGOUT, } limit_value = limit_map.get(limit_name, config.RATELIMIT_DEFAULT) @@ -69,6 +95,38 @@ def resolver() -> str: return resolver +def create_auth_aware_limit_resolver( + config: Config, anon_name: LimitName, auth_name: LimitName +) -> Callable[[str], str]: + """Create a limit callable that returns different rate limits based on the key prefix. + + slowapi calls this with the result of key_func(request) when the callable has a `key` parameter. + Keys prefixed with "auth:" get the authenticated limit, others get the anonymous limit. + """ + anon_limit = create_limit_resolver(config, anon_name) + auth_limit = create_limit_resolver(config, auth_name) + + def resolver(key: str) -> str: + if key.startswith("auth:"): + return auth_limit() + return anon_limit() + + return resolver + + +def create_auth_aware_key_func(config: Config) -> Callable[[Request], str]: + """Create a key function that returns 'auth:{user_id}' for authenticated users and IP for anonymous.""" + exempt_key = create_exempt_key_func(config) + + def key_func(request: Request) -> str: + user = getattr(getattr(request, "state", None), "current_user", None) + if user is not None: + return f"auth:{user.id}" + return exempt_key(request) + + return key_func + + def _build_redis_uri(config: Config) -> str: """Build Redis URI from config.""" if config.REDIS_PASSWORD: @@ -93,7 +151,11 @@ def create_rate_limiter(config: Config) -> Limiter | NoOpLimiter: try: storage_uri = _build_redis_uri(config) limiter = Limiter(key_func=get_ip_address, storage_uri=storage_uri) - logger.info("Rate limiter using Redis backend: %s:%d", config.REDIS_HOST, config.REDIS_PORT) + logger.info( + "Rate limiter using Redis backend: %s:%d", + config.REDIS_HOST, + config.REDIS_PORT, + ) return limiter except Exception as e: logger.warning("Redis rate limiter failed, falling back to memory: %s", e) diff --git a/backend/app/services/auth_service.py b/backend/app/services/auth_service.py new file mode 100644 index 0000000..4dd97e9 --- /dev/null +++ b/backend/app/services/auth_service.py @@ -0,0 +1,558 @@ +"""Authentication service for user management and auth flows.""" + +import hashlib +import logging +import secrets +from datetime import UTC, datetime, timedelta +from uuid import UUID + +from argon2 import PasswordHasher +from argon2.exceptions import HashingError, InvalidHashError, VerifyMismatchError +from sqlalchemy import select, update +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.orm import sessionmaker + +from app.config import Config +from app.db.models import ( + EmailVerificationTokenEntity, + PasswordResetTokenEntity, + RefreshTokenEntity, + UserEntity, +) +from app.exceptions import ( + EmailNotVerifiedError, + InvalidCredentialsError, + InvalidJWTError, + TokenExpiredError, + UserAlreadyExistsError, + UserInactiveError, + UserNotFoundError, +) +from app.services.email_service import EmailService +from app.services.jwt_service import JWTService, TokenType + + +class AuthService: + """Service for handling authentication operations.""" + + def __init__( + self, + session_factory: sessionmaker[AsyncSession], + jwt_service: JWTService, + email_service: EmailService, + config: Config, + ): + self.session_factory = session_factory + self.jwt_service = jwt_service + self.email_service = email_service + self.config = config + self.logger = logging.getLogger(self.__class__.__name__) + # Argon2id with OWASP recommended parameters + self.password_hasher = PasswordHasher( + time_cost=2, + memory_cost=19456, # 19 MiB + parallelism=1, + ) + + def _hash_token(self, token: str) -> str: + """Hash a token for secure storage using SHA-256.""" + return hashlib.sha256(token.encode()).hexdigest() + + def _generate_token(self) -> str: + """Generate a cryptographically secure random token.""" + return secrets.token_urlsafe(32) + + def _hash_password(self, password: str) -> str: + """Hash a password using Argon2id.""" + return self.password_hasher.hash(password) + + def _verify_password(self, password: str, password_hash: str) -> bool: + """Verify a password against its hash.""" + try: + self.password_hasher.verify(password_hash, password) + return True + except (VerifyMismatchError, InvalidHashError, HashingError): + return False + + async def register( + self, + username: str, + email: str, + password: str, + ) -> UserEntity: + """ + Register a new user. + + Args: + username: Unique username + email: Unique email address + password: Plain text password + + Returns: + Created user entity + + Raises: + UserAlreadyExistsError: If username or email already exists + """ + async with self.session_factory() as session: + # Check for existing user + existing = await session.execute( + select(UserEntity).where( + (UserEntity.username == username.lower()) + | (UserEntity.email == email.lower()) + ) + ) + if existing.scalar_one_or_none(): + raise UserAlreadyExistsError() + + # Create user + user = UserEntity( + username=username.lower(), + email=email.lower(), + password_hash=self._hash_password(password), + is_verified=False, + is_active=True, + ) + session.add(user) + await session.flush() + + # Create verification token + token = self._generate_token() + verification_token = EmailVerificationTokenEntity( + user_id=user.id, + token_hash=self._hash_token(token), + expires_at=datetime.now(UTC) + + timedelta(hours=self.config.EMAIL_VERIFICATION_EXPIRE_HOURS), + ) + session.add(verification_token) + await session.commit() + + # Send verification email (don't wait for result) + await self.email_service.send_verification_email( + to=user.email, + username=user.username, + token=token, + ) + + self.logger.info("User registered: %s", user.username) + return user + + async def verify_email(self, token: str) -> UserEntity: + """ + Verify a user's email address. + + Args: + token: Email verification token + + Returns: + Verified user entity + + Raises: + TokenExpiredError: If token has expired + InvalidJWTError: If token is invalid + """ + token_hash = self._hash_token(token) + + async with self.session_factory() as session: + result = await session.execute( + select(EmailVerificationTokenEntity).where( + EmailVerificationTokenEntity.token_hash == token_hash, + EmailVerificationTokenEntity.used_at.is_(None), + ) + ) + verification = result.scalar_one_or_none() + + if not verification: + raise InvalidJWTError("Invalid or already used verification token") + + if verification.expires_at < datetime.now(UTC): + raise TokenExpiredError("verification token") + + # Mark token as used + verification.used_at = datetime.now(UTC) + + # Mark user as verified + user_result = await session.execute( + select(UserEntity).where(UserEntity.id == verification.user_id) + ) + user = user_result.scalar_one_or_none() + + if not user: + raise UserNotFoundError() + + user.is_verified = True + await session.commit() + + self.logger.info("Email verified for user: %s", user.username) + return user + + async def login( + self, + username: str, + password: str, + user_agent: str | None = None, + ip_address: str | None = None, + ) -> tuple[str, str, int]: + """ + Authenticate a user and return tokens. + + Args: + username: Username or email + password: Plain text password + user_agent: Optional user agent for session tracking + ip_address: Optional IP address for session tracking + + Returns: + Tuple of (access_token, refresh_token, expires_in_seconds) + + Raises: + InvalidCredentialsError: If credentials are invalid + EmailNotVerifiedError: If email is not verified + UserInactiveError: If user account is deactivated + """ + async with self.session_factory() as session: + # Find user by username or email + result = await session.execute( + select(UserEntity).where( + (UserEntity.username == username.lower()) + | (UserEntity.email == username.lower()) + ) + ) + user = result.scalar_one_or_none() + + if not user or not self._verify_password(password, user.password_hash): + raise InvalidCredentialsError() + + if not user.is_active: + raise UserInactiveError() + + if not user.is_verified: + raise EmailNotVerifiedError() + + # Create tokens + access_result = self.jwt_service.create_access_token(user.id) + refresh_result = self.jwt_service.create_refresh_token(user.id) + + # Store refresh token + refresh_token_entity = RefreshTokenEntity( + user_id=user.id, + token_hash=self._hash_token(refresh_result.token), + expires_at=refresh_result.expires_at, + user_agent=user_agent[:512] if user_agent else None, + ip_address=ip_address, + ) + session.add(refresh_token_entity) + + # Update last login + user.last_login_at = datetime.now(UTC) + await session.commit() + + self.logger.info("User logged in: %s", user.username) + + expires_in = int( + (access_result.expires_at - datetime.now(UTC)).total_seconds() + ) + return access_result.token, refresh_result.token, expires_in + + async def refresh_tokens( + self, + refresh_token: str, + user_agent: str | None = None, + ip_address: str | None = None, + ) -> tuple[str, str, int]: + """ + Refresh access token using refresh token. + + Implements token rotation - old refresh token is revoked. + + Args: + refresh_token: Current refresh token + user_agent: Optional user agent for session tracking + ip_address: Optional IP address for session tracking + + Returns: + Tuple of (new_access_token, new_refresh_token, expires_in_seconds) + + Raises: + InvalidJWTError: If refresh token is invalid + TokenExpiredError: If refresh token has expired + """ + # Decode and validate the refresh token + payload = self.jwt_service.decode_token(refresh_token, TokenType.REFRESH) + token_hash = self._hash_token(refresh_token) + + async with self.session_factory() as session: + # Find the refresh token in DB + result = await session.execute( + select(RefreshTokenEntity).where( + RefreshTokenEntity.token_hash == token_hash, + RefreshTokenEntity.revoked_at.is_(None), + ) + ) + stored_token = result.scalar_one_or_none() + + if not stored_token: + raise InvalidJWTError("Refresh token not found or already revoked") + + if stored_token.expires_at < datetime.now(UTC): + raise TokenExpiredError("refresh token") + + # Verify user exists and is active + user_result = await session.execute( + select(UserEntity).where(UserEntity.id == payload.user_id) + ) + user = user_result.scalar_one_or_none() + + if not user or not user.is_active: + raise InvalidJWTError("User not found or inactive") + + # Revoke old token (rotation) + stored_token.revoked_at = datetime.now(UTC) + + # Create new tokens + access_result = self.jwt_service.create_access_token(user.id) + new_refresh_result = self.jwt_service.create_refresh_token(user.id) + + # Store new refresh token + new_refresh_token = RefreshTokenEntity( + user_id=user.id, + token_hash=self._hash_token(new_refresh_result.token), + expires_at=new_refresh_result.expires_at, + user_agent=user_agent[:512] if user_agent else None, + ip_address=ip_address, + ) + session.add(new_refresh_token) + await session.commit() + + self.logger.debug("Tokens refreshed for user: %s", user.username) + + expires_in = int( + (access_result.expires_at - datetime.now(UTC)).total_seconds() + ) + return access_result.token, new_refresh_result.token, expires_in + + async def forgot_password(self, email: str) -> bool: + """ + Initiate password reset flow. + + Always returns True to prevent email enumeration. + + Args: + email: Email address + + Returns: + True (always, to prevent enumeration) + """ + async with self.session_factory() as session: + result = await session.execute( + select(UserEntity).where(UserEntity.email == email.lower()) + ) + user = result.scalar_one_or_none() + + if not user: + # Don't reveal if email exists + self.logger.debug("Password reset requested for unknown email") + return True + + # Invalidate existing reset tokens + await session.execute( + update(PasswordResetTokenEntity) + .where( + PasswordResetTokenEntity.user_id == user.id, + PasswordResetTokenEntity.used_at.is_(None), + ) + .values(used_at=datetime.now(UTC)) + ) + + # Create new reset token + token = self._generate_token() + reset_token = PasswordResetTokenEntity( + user_id=user.id, + token_hash=self._hash_token(token), + expires_at=datetime.now(UTC) + + timedelta(hours=self.config.PASSWORD_RESET_EXPIRE_HOURS), + ) + session.add(reset_token) + await session.commit() + + # Send reset email + await self.email_service.send_password_reset_email( + to=user.email, + username=user.username, + token=token, + ) + + self.logger.info("Password reset email sent to: %s", user.email) + return True + + async def reset_password(self, token: str, new_password: str) -> UserEntity: + """ + Reset user's password. + + Args: + token: Password reset token + new_password: New plain text password + + Returns: + Updated user entity + + Raises: + TokenExpiredError: If token has expired + InvalidJWTError: If token is invalid + """ + token_hash = self._hash_token(token) + + async with self.session_factory() as session: + result = await session.execute( + select(PasswordResetTokenEntity).where( + PasswordResetTokenEntity.token_hash == token_hash, + PasswordResetTokenEntity.used_at.is_(None), + ) + ) + reset_token = result.scalar_one_or_none() + + if not reset_token: + raise InvalidJWTError("Invalid or already used reset token") + + if reset_token.expires_at < datetime.now(UTC): + raise TokenExpiredError("password reset token") + + # Mark token as used + reset_token.used_at = datetime.now(UTC) + + # Update password + user_result = await session.execute( + select(UserEntity).where(UserEntity.id == reset_token.user_id) + ) + user = user_result.scalar_one_or_none() + + if not user: + raise UserNotFoundError() + + user.password_hash = self._hash_password(new_password) + + # Revoke all refresh tokens (force re-login) + refresh_result = await session.execute( + select(RefreshTokenEntity).where( + RefreshTokenEntity.user_id == user.id, + RefreshTokenEntity.revoked_at.is_(None), + ) + ) + for refresh_token_entity in refresh_result.scalars(): + refresh_token_entity.revoked_at = datetime.now(UTC) + + await session.commit() + + self.logger.info("Password reset for user: %s", user.username) + return user + + async def logout(self, user_id: UUID, refresh_token: str | None = None) -> bool: + """ + Logout user by revoking refresh tokens. + + Args: + user_id: User's UUID + refresh_token: Optional specific refresh token to revoke. + If None, revokes all user's refresh tokens. + + Returns: + True if tokens were revoked + """ + async with self.session_factory() as session: + if refresh_token: + # Revoke specific token + token_hash = self._hash_token(refresh_token) + result = await session.execute( + select(RefreshTokenEntity).where( + RefreshTokenEntity.user_id == user_id, + RefreshTokenEntity.token_hash == token_hash, + RefreshTokenEntity.revoked_at.is_(None), + ) + ) + token_entity = result.scalar_one_or_none() + if token_entity: + token_entity.revoked_at = datetime.now(UTC) + else: + # Revoke all tokens + result = await session.execute( + select(RefreshTokenEntity).where( + RefreshTokenEntity.user_id == user_id, + RefreshTokenEntity.revoked_at.is_(None), + ) + ) + for token_entity in result.scalars(): + token_entity.revoked_at = datetime.now(UTC) + + await session.commit() + self.logger.debug("Logged out user: %s", user_id) + return True + + async def resend_verification_email(self, email: str) -> bool: + """ + Resend verification email. + + Args: + email: Email address + + Returns: + True (always, to prevent enumeration) + """ + async with self.session_factory() as session: + result = await session.execute( + select(UserEntity).where( + UserEntity.email == email.lower(), + UserEntity.is_verified == False, # noqa: E712 + ) + ) + user = result.scalar_one_or_none() + + if not user: + # Don't reveal if email exists or is already verified + return True + + # Invalidate existing verification tokens + await session.execute( + update(EmailVerificationTokenEntity) + .where( + EmailVerificationTokenEntity.user_id == user.id, + EmailVerificationTokenEntity.used_at.is_(None), + ) + .values(used_at=datetime.now(UTC)) + ) + + # Create new verification token + token = self._generate_token() + verification_token = EmailVerificationTokenEntity( + user_id=user.id, + token_hash=self._hash_token(token), + expires_at=datetime.now(UTC) + + timedelta(hours=self.config.EMAIL_VERIFICATION_EXPIRE_HOURS), + ) + session.add(verification_token) + await session.commit() + + # Send verification email + await self.email_service.send_verification_email( + to=user.email, + username=user.username, + token=token, + ) + + self.logger.info("Verification email resent to: %s", user.email) + return True + + async def get_user_by_id(self, user_id: UUID) -> UserEntity | None: + """ + Get user by ID. + + Args: + user_id: User's UUID + + Returns: + User entity or None if not found + """ + async with self.session_factory() as session: + result = await session.execute( + select(UserEntity).where(UserEntity.id == user_id) + ) + return result.scalar_one_or_none() diff --git a/backend/app/services/jwt_service.py b/backend/app/services/jwt_service.py new file mode 100644 index 0000000..dbe7348 --- /dev/null +++ b/backend/app/services/jwt_service.py @@ -0,0 +1,159 @@ +"""JWT token service for authentication.""" + +import logging +import uuid +from datetime import UTC, datetime, timedelta +from enum import Enum +from typing import NamedTuple + +import jwt + +from app.config import Config +from app.exceptions import InvalidJWTError, TokenExpiredError + + +class TokenType(str, Enum): + """JWT token types.""" + + ACCESS = "access" + REFRESH = "refresh" + + +class TokenPayload(NamedTuple): + """Decoded JWT token payload.""" + + user_id: uuid.UUID + token_type: TokenType + jti: str # JWT ID for refresh token tracking + exp: datetime + iat: datetime + + +class TokenResult(NamedTuple): + """Result of token creation.""" + + token: str + expires_at: datetime + jti: str + + +class JWTService: + """Service for creating and validating JWT tokens.""" + + def __init__(self, config: Config): + self.config = config + self.secret_key = config.JWT_SECRET_KEY + self.algorithm = config.JWT_ALGORITHM + self.access_token_expire_minutes = config.JWT_ACCESS_TOKEN_EXPIRE_MINUTES + self.refresh_token_expire_days = config.JWT_REFRESH_TOKEN_EXPIRE_DAYS + self.logger = logging.getLogger(self.__class__.__name__) + + def create_access_token(self, user_id: uuid.UUID) -> TokenResult: + """ + Create a new access token. + + Args: + user_id: The user's UUID + + Returns: + TokenResult with the token string and expiration time + """ + now = datetime.now(UTC) + expires_at = now + timedelta(minutes=self.access_token_expire_minutes) + jti = str(uuid.uuid4()) + + payload = { + "sub": str(user_id), + "type": TokenType.ACCESS.value, + "jti": jti, + "iat": now, + "exp": expires_at, + } + + token = jwt.encode(payload, self.secret_key, algorithm=self.algorithm) + self.logger.debug("Created access token for user %s", user_id) + + return TokenResult(token=token, expires_at=expires_at, jti=jti) + + def create_refresh_token( + self, user_id: uuid.UUID, jti: str | None = None + ) -> TokenResult: + """ + Create a new refresh token. + + Args: + user_id: The user's UUID + jti: Optional JWT ID (generated if not provided) + + Returns: + TokenResult with the token string and expiration time + """ + now = datetime.now(UTC) + expires_at = now + timedelta(days=self.refresh_token_expire_days) + token_jti = jti or str(uuid.uuid4()) + + payload = { + "sub": str(user_id), + "type": TokenType.REFRESH.value, + "jti": token_jti, + "iat": now, + "exp": expires_at, + } + + token = jwt.encode(payload, self.secret_key, algorithm=self.algorithm) + self.logger.debug("Created refresh token for user %s", user_id) + + return TokenResult(token=token, expires_at=expires_at, jti=token_jti) + + def decode_token( + self, token: str, expected_type: TokenType | None = None + ) -> TokenPayload: + """ + Decode and validate a JWT token. + + Args: + token: The JWT token string + expected_type: Expected token type (access or refresh) + + Returns: + TokenPayload with decoded token information + + Raises: + TokenExpiredError: If the token has expired + InvalidJWTError: If the token is invalid or malformed + """ + try: + payload = jwt.decode( + token, + self.secret_key, + algorithms=[self.algorithm], + options={"require": ["sub", "type", "jti", "exp", "iat"]}, + ) + + token_type = TokenType(payload["type"]) + + if expected_type and token_type != expected_type: + raise InvalidJWTError( + f"Expected {expected_type.value} token, got {token_type.value}" + ) + + return TokenPayload( + user_id=uuid.UUID(payload["sub"]), + token_type=token_type, + jti=payload["jti"], + exp=datetime.fromtimestamp(payload["exp"], tz=UTC), + iat=datetime.fromtimestamp(payload["iat"], tz=UTC), + ) + + except jwt.ExpiredSignatureError as e: + self.logger.debug("Token expired") + token_type_name = expected_type.value if expected_type else "token" + raise TokenExpiredError(token_type_name) from e + + except jwt.InvalidTokenError as e: + self.logger.warning("Invalid token: %s", str(e)) + raise InvalidJWTError(f"Invalid token: {e}") from e + + except (KeyError, ValueError) as e: + self.logger.warning("Malformed token payload: %s", str(e)) + raise InvalidJWTError("Malformed token payload") from e From 0f2e9044b79bbac038081785f03298f131704e35 Mon Sep 17 00:00:00 2001 From: Pdzly <34943569+Pdzly@users.noreply.github.com> Date: Sun, 22 Mar 2026 13:42:25 +0100 Subject: [PATCH 3/4] docs: add auth endpoints to API table in README --- README.md | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/README.md b/README.md index e65a186..3c71b6d 100644 --- a/README.md +++ b/README.md @@ -56,6 +56,15 @@ Full reference: [`docs/configuration.md`](docs/configuration.md) | `POST /pastes` | Create paste | | `GET /pastes/{id}` | Get paste | | `DELETE /pastes/{id}` | Delete paste | +| `POST /auth/register` | Register a new user | +| `POST /auth/login` | Authenticate and get tokens | +| `POST /auth/refresh` | Refresh access token | +| `POST /auth/verify-email` | Verify email address | +| `POST /auth/resend-verification` | Resend verification email | +| `POST /auth/forgot-password` | Request password reset | +| `POST /auth/reset-password` | Reset password | +| `GET /auth/me` | Get current user profile | +| `POST /auth/logout` | Logout and revoke tokens | Interactive docs at `/docs` when running. From 0857063ac799431fc5fd569b44ecc2a5678a083c Mon Sep 17 00:00:00 2001 From: Pdzly <34943569+Pdzly@users.noreply.github.com> Date: Sat, 4 Apr 2026 14:07:50 +0200 Subject: [PATCH 4/4] feat: link pastes to user accounts with ownership tracking Add user_id FK to pastes table so authenticated users' pastes are associated with their account. Adds GET /pastes/me endpoint to retrieve a user's own pastes. --- README.md | 3 +- .../20260404_add_user_id_to_pastes.py | 38 +++++++++++++++++ backend/app/api/dto/paste_dto.py | 4 ++ backend/app/api/subroutes/pastes.py | 26 +++++++++++- backend/app/db/models.py | 8 ++++ backend/app/services/paste_service.py | 42 ++++++++++++++++++- 6 files changed, 117 insertions(+), 4 deletions(-) create mode 100644 backend/alembic/versions/20260404_add_user_id_to_pastes.py diff --git a/README.md b/README.md index 3c71b6d..00c11f8 100644 --- a/README.md +++ b/README.md @@ -53,8 +53,9 @@ Full reference: [`docs/configuration.md`](docs/configuration.md) |----------|-------------| | `GET /health` | Health check | | `GET /metrics` | Prometheus metrics | -| `POST /pastes` | Create paste | +| `POST /pastes` | Create paste (linked to account if authenticated) | | `GET /pastes/{id}` | Get paste | +| `GET /pastes/me` | Get authenticated user's pastes | | `DELETE /pastes/{id}` | Delete paste | | `POST /auth/register` | Register a new user | | `POST /auth/login` | Authenticate and get tokens | diff --git a/backend/alembic/versions/20260404_add_user_id_to_pastes.py b/backend/alembic/versions/20260404_add_user_id_to_pastes.py new file mode 100644 index 0000000..de17a23 --- /dev/null +++ b/backend/alembic/versions/20260404_add_user_id_to_pastes.py @@ -0,0 +1,38 @@ +"""Add user_id to pastes + +Revision ID: add_user_id_to_pastes +Revises: add_auth_tables +Create Date: 2026-04-04 + +""" + +from collections.abc import Sequence + +import sqlalchemy as sa + +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "add_user_id_to_pastes" +down_revision: str | None = "add_auth_tables" +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None + + +def upgrade() -> None: + op.add_column("pastes", sa.Column("user_id", sa.UUID(as_uuid=True), nullable=True)) + op.create_foreign_key( + "fk_pastes_user_id", + "pastes", + "users", + ["user_id"], + ["id"], + ondelete="SET NULL", + ) + op.create_index("idx_pastes_user_id", "pastes", ["user_id"]) + + +def downgrade() -> None: + op.drop_index("idx_pastes_user_id", table_name="pastes") + op.drop_constraint("fk_pastes_user_id", "pastes", type_="foreignkey") + op.drop_column("pastes", "user_id") diff --git a/backend/app/api/dto/paste_dto.py b/backend/app/api/dto/paste_dto.py index 32aeab4..337c86e 100644 --- a/backend/app/api/dto/paste_dto.py +++ b/backend/app/api/dto/paste_dto.py @@ -82,6 +82,10 @@ class PasteResponse(BaseModel): id: UUID4 = Field( description="The unique identifier of the paste", ) + user_id: UUID4 | None = Field( + None, + description="The ID of the user who created the paste (null if anonymous)", + ) title: str = Field( description="The title of the paste", ) diff --git a/backend/app/api/subroutes/pastes.py b/backend/app/api/subroutes/pastes.py index e288480..f9920c6 100644 --- a/backend/app/api/subroutes/pastes.py +++ b/backend/app/api/subroutes/pastes.py @@ -20,7 +20,7 @@ ) from app.config import config from app.containers import Container -from app.dependencies.auth import get_optional_current_user +from app.dependencies.auth import get_current_user, get_optional_current_user from app.exceptions import PasteNotFoundError from app.ratelimit import create_auth_aware_key_func, create_auth_aware_limit_resolver, create_limit_resolver, limiter from app.services.paste_service import PasteService @@ -58,6 +58,23 @@ async def _resolve_optional_user( delete_token_key_header = APIKeyHeader(name="Authorization", scheme_name="Delete Token") +@pastes_route.get( + "/me", + response_model=list[PasteResponse], + summary="Get pastes for the authenticated user", + description="Retrieve all pastes created by the currently authenticated user.", +) +@limiter.limit(create_limit_resolver(config, "get_paste"), key_func=lambda r: ratelimit.get_exempt_key(r)) +@inject +async def get_user_pastes( + request: Request, + paste_service: PasteService = Depends(Provide[Container.paste_service]), + current_user=Depends(get_current_user), +): + """Get all pastes belonging to the authenticated user.""" + return await paste_service.get_user_pastes(current_user.id) + + @pastes_route.get( "/legacy/{paste_id}", responses={404: {"model": ErrorResponse}, 200: {"model": LegacyPasteResponse}}, @@ -212,7 +229,12 @@ async def create_paste( _current_user=Depends(_resolve_optional_user), ): """Create a new paste and return edit/delete tokens.""" - return await paste_service.create_paste(create_paste_body, request.state.user_metadata) + user = request.state.current_user + return await paste_service.create_paste( + create_paste_body, + request.state.user_metadata, + user_id=user.id if user else None, + ) @pastes_route.put( diff --git a/backend/app/db/models.py b/backend/app/db/models.py index 9398f3b..7d1747e 100644 --- a/backend/app/db/models.py +++ b/backend/app/db/models.py @@ -23,6 +23,7 @@ class PasteEntity(Base): Index("idx_pastes_expires_at", "expires_at"), Index("idx_pastes_deleted_at", "deleted_at"), Index("idx_pastes_created_at", "created_at"), + Index("idx_pastes_user_id", "user_id"), ) id = Column(UUID(as_uuid=True), primary_key=True, server_default=UUID_DEFAULT) @@ -47,6 +48,12 @@ class PasteEntity(Base): delete_token = Column(String) deleted_at = Column(TIMESTAMP(timezone=True), nullable=True) + user_id = Column( + UUID(as_uuid=True), ForeignKey("users.id", ondelete="SET NULL"), nullable=True + ) + + user = relationship("UserEntity", back_populates="pastes") + def __repr__(self): return f"" @@ -95,6 +102,7 @@ class UserEntity(Base): refresh_tokens = relationship( "RefreshTokenEntity", back_populates="user", cascade="all, delete-orphan" ) + pastes = relationship("PasteEntity", back_populates="user") def __repr__(self): return f"" diff --git a/backend/app/services/paste_service.py b/backend/app/services/paste_service.py index 01358df..6ed8a12 100644 --- a/backend/app/services/paste_service.py +++ b/backend/app/services/paste_service.py @@ -240,6 +240,7 @@ async def get_paste_by_id(self, paste_id: UUID4) -> PasteResponse | None: paste_operations.labels(operation="get", status="success").inc() return PasteResponse( id=result.id, + user_id=result.user_id, title=result.title, content=content, content_language=PasteContentLanguage(result.content_language), @@ -326,6 +327,7 @@ async def edit_paste(self, paste_id: UUID4, edit_paste: EditPaste, edit_token: s paste_operations.labels(operation="edit", status="success").inc() return PasteResponse( id=result.id, + user_id=result.user_id, title=result.title, content=content, content_language=PasteContentLanguage(result.content_language), @@ -380,7 +382,7 @@ async def delete_paste(self, paste_id: UUID4, delete_token: str) -> bool: counter.dec() return True - async def create_paste(self, paste: CreatePaste, user_data: UserMetaData) -> PasteResponse: + async def create_paste(self, paste: CreatePaste, user_data: UserMetaData, user_id: uuid.UUID | None = None) -> PasteResponse: if not self.verify_storage_limit(): paste_operations.labels(operation="create", status="storage_limit").inc() raise HTTPException( @@ -426,6 +428,7 @@ async def create_paste(self, paste: CreatePaste, user_data: UserMetaData) -> Pas original_size=original_size, edit_token=edit_token_hashed, delete_token=delete_token_hashed, + user_id=user_id, ) session.add(entity) await session.commit() @@ -442,6 +445,7 @@ async def create_paste(self, paste: CreatePaste, user_data: UserMetaData) -> Pas return CreatePasteResponse( id=entity.id, + user_id=entity.user_id, title=entity.title, content=paste.content, content_language=PasteContentLanguage(entity.content_language), @@ -460,3 +464,39 @@ async def create_paste(self, paste: CreatePaste, user_data: UserMetaData) -> Pas detail="Failed to create paste", headers={"Retry-After": "60"}, ) from exc + + async def get_user_pastes(self, user_id: uuid.UUID) -> list[PasteResponse]: + async with self.session_maker() as session: + stmt = ( + select(PasteEntity) + .where( + PasteEntity.user_id == user_id, + PasteEntity.deleted_at.is_(None), + or_( + PasteEntity.expires_at > datetime.now(tz=UTC), + PasteEntity.expires_at.is_(None), + ), + ) + .order_by(PasteEntity.created_at.desc()) + ) + results = (await session.execute(stmt)).scalars().all() + + pastes = [] + for result in results: + content = await self._read_content( + result.content_path, + is_compressed=result.is_compressed, + ) + pastes.append( + PasteResponse( + id=result.id, + user_id=result.user_id, + title=result.title, + content=content, + content_language=PasteContentLanguage(result.content_language), + created_at=result.created_at, + expires_at=result.expires_at, + last_updated_at=result.last_updated_at, + ) + ) + return pastes