diff --git a/app/core/config.py b/app/core/config.py index d63b5c5..a0d9dd7 100644 --- a/app/core/config.py +++ b/app/core/config.py @@ -1,6 +1,6 @@ """Application configuration settings.""" -from pydantic_settings import BaseSettings +from pydantic_settings import BaseSettings, SettingsConfigDict from typing import Optional @@ -18,9 +18,10 @@ class Settings(BaseSettings): # Security SECRET_KEY: Optional[str] = None - class Config: - env_file = ".env" - case_sensitive = True + model_config = SettingsConfigDict( + env_file=".env", + case_sensitive=True, + ) settings = Settings() diff --git a/app/deps.py b/app/deps.py new file mode 100644 index 0000000..a544338 --- /dev/null +++ b/app/deps.py @@ -0,0 +1,55 @@ +"""FastAPI dependencies for authentication and database access.""" +from typing import Annotated + +from fastapi import Depends, HTTPException, status +from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer +from sqlalchemy.orm import Session + +from app.db.database import get_db +from app.models.user import User +from app.services.auth_service import AuthService, decode_token + + +security = HTTPBearer() + + +async def get_current_user( + credentials: Annotated[HTTPAuthorizationCredentials, Depends(security)], + db: Annotated[Session, Depends(get_db)], +) -> User: + """Get the current authenticated user from JWT token. + + Args: + credentials: The HTTP Bearer credentials containing the JWT token. + db: Database session. + + Returns: + The authenticated User object. + + Raises: + HTTPException: If token is invalid or user not found. + """ + credentials_exception = HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Invalid authentication credentials", + headers={"WWW-Authenticate": "Bearer"}, + ) + + try: + payload = decode_token(credentials.credentials) + user_id_str: str = payload.get("sub") + if user_id_str is None: + raise credentials_exception + user_id = int(user_id_str) + except (JWTError, ValueError): + raise credentials_exception + + user = AuthService.get_user_by_id(db, user_id) + if user is None: + raise credentials_exception + + return user + + +# Type alias for dependency injection +CurrentUser = Annotated[User, Depends(get_current_user)] diff --git a/app/main.py b/app/main.py index 7941e45..3849dc8 100644 --- a/app/main.py +++ b/app/main.py @@ -1,15 +1,29 @@ """FastAPI application entry point.""" +from contextlib import asynccontextmanager from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware + from app.core.config import settings from app.db.database import engine, Base +from app.routers.auth import router as auth_router + + +@asynccontextmanager +async def lifespan(app: FastAPI): + """Application lifespan manager for startup/shutdown events.""" + # Startup: Create database tables + Base.metadata.create_all(bind=engine) + yield + # Shutdown: cleanup if needed + # Create FastAPI app app = FastAPI( title=settings.APP_NAME, version=settings.APP_VERSION, debug=settings.DEBUG, + lifespan=lifespan, ) # CORS middleware @@ -21,6 +35,9 @@ app.add_middleware( allow_headers=["*"], ) +# Include routers +app.include_router(auth_router) + @app.get("/") async def root(): diff --git a/app/models/__init__.py b/app/models/__init__.py index 53f3a7c..48da878 100644 --- a/app/models/__init__.py +++ b/app/models/__init__.py @@ -1 +1,4 @@ """Models package.""" +from app.models.user import User + +__all__ = ["User"] diff --git a/app/models/user.py b/app/models/user.py new file mode 100644 index 0000000..be49488 --- /dev/null +++ b/app/models/user.py @@ -0,0 +1,32 @@ +"""User SQLAlchemy model.""" +from datetime import datetime +from typing import Optional + +from sqlalchemy import DateTime, String, Text +from sqlalchemy.orm import Mapped, mapped_column, relationship +from sqlalchemy.sql import func + +from app.db.database import Base + + +class User(Base): + """User model for authentication and profile information.""" + + __tablename__ = "users" + + id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True) + username: Mapped[str] = mapped_column(String(50), unique=True, nullable=False) + email: Mapped[str] = mapped_column(String(255), unique=True, nullable=False) + password_hash: Mapped[str] = mapped_column(String(255), nullable=False) + avatar_url: Mapped[Optional[str]] = mapped_column( + String(500), default="/static/default-avatar.png" + ) + bio: Mapped[Optional[str]] = mapped_column(Text, default="") + created_at: Mapped[datetime] = mapped_column( + DateTime(timezone=True), + server_default=func.now(), + nullable=False, + ) + + def __repr__(self) -> str: + return f"" diff --git a/app/routers/__init__.py b/app/routers/__init__.py new file mode 100644 index 0000000..aea9060 --- /dev/null +++ b/app/routers/__init__.py @@ -0,0 +1,4 @@ +"""Routers package.""" +from app.routers.auth import router as auth_router + +__all__ = ["auth_router"] diff --git a/app/routers/auth.py b/app/routers/auth.py new file mode 100644 index 0000000..ee4532c --- /dev/null +++ b/app/routers/auth.py @@ -0,0 +1,81 @@ +"""Authentication routes for SocialPhoto API.""" +from typing import Annotated + +from fastapi import APIRouter, Depends, HTTPException, status +from sqlalchemy.orm import Session + +from app.db.database import get_db +from app.schemas.auth import Token, UserLogin, UserRegister +from app.services.auth_service import AuthService, create_access_token + + +router = APIRouter(prefix="/auth", tags=["Authentication"]) + + +@router.post("/register", response_model=Token, status_code=status.HTTP_201_CREATED) +async def register( + user_data: UserRegister, + db: Annotated[Session, Depends(get_db)], +) -> Token: + """Register a new user. + + Args: + user_data: User registration data (username, email, password). + db: Database session. + + Returns: + Token object with access token. + + Raises: + HTTPException: If username or email already exists. + """ + # Check if username exists + existing_user = AuthService.get_user_by_username(db, user_data.username) + if existing_user: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Username already registered", + ) + + # Check if email exists + existing_email = AuthService.get_user_by_email(db, user_data.email) + if existing_email: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Email already registered", + ) + + # Create user + user = AuthService.create_user(db, user_data) + + # Create access token + access_token = create_access_token(data={"sub": str(user.id)}) + return Token(access_token=access_token) + + +@router.post("/login", response_model=Token) +async def login( + user_data: UserLogin, + db: Annotated[Session, Depends(get_db)], +) -> Token: + """Login and get access token. + + Args: + user_data: User login data (username, password). + db: Database session. + + Returns: + Token object with access token. + + Raises: + HTTPException: If credentials are invalid. + """ + user = AuthService.authenticate_user(db, user_data.username, user_data.password) + if not user: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Incorrect username or password", + ) + + access_token = create_access_token(data={"sub": str(user.id)}) + return Token(access_token=access_token) diff --git a/app/schemas/__init__.py b/app/schemas/__init__.py index 3a8b2f5..055a844 100644 --- a/app/schemas/__init__.py +++ b/app/schemas/__init__.py @@ -1 +1,9 @@ """Schemas package.""" +from app.schemas.auth import ( + Token, + UserLogin, + UserRegister, + UserResponse, +) + +__all__ = ["Token", "UserLogin", "UserRegister", "UserResponse"] diff --git a/app/schemas/auth.py b/app/schemas/auth.py new file mode 100644 index 0000000..146e293 --- /dev/null +++ b/app/schemas/auth.py @@ -0,0 +1,40 @@ +"""Authentication Pydantic schemas.""" +from datetime import datetime +from typing import Optional + +from pydantic import BaseModel, ConfigDict, EmailStr, Field + + +class UserRegister(BaseModel): + """Request model for user registration.""" + + username: str = Field(..., min_length=3, max_length=50) + email: EmailStr + password: str = Field(..., min_length=6) + + +class UserLogin(BaseModel): + """Request model for user login.""" + + username: str + password: str + + +class Token(BaseModel): + """Response model for JWT token.""" + + access_token: str + token_type: str = "bearer" + + +class UserResponse(BaseModel): + """Response model for user data.""" + + id: int + username: str + email: str + avatar_url: Optional[str] = "/static/default-avatar.png" + bio: Optional[str] = "" + created_at: datetime + + model_config = ConfigDict(from_attributes=True) diff --git a/app/services/__init__.py b/app/services/__init__.py index c7775ec..17a9d15 100644 --- a/app/services/__init__.py +++ b/app/services/__init__.py @@ -1 +1,4 @@ """Services package.""" +from app.services.auth_service import AuthService + +__all__ = ["AuthService"] diff --git a/app/services/auth_service.py b/app/services/auth_service.py new file mode 100644 index 0000000..9bcffeb --- /dev/null +++ b/app/services/auth_service.py @@ -0,0 +1,163 @@ +"""Authentication service with password hashing and JWT token management.""" +from datetime import datetime, timedelta, timezone +from typing import Optional + +from jose import JWTError, jwt +from passlib.context import CryptContext +from sqlalchemy.orm import Session + +from app.core.config import settings +from app.models.user import User +from app.schemas.auth import Token, UserRegister + + +pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") +ALGORITHM = "HS256" +ACCESS_TOKEN_EXPIRE_HOURS = 24 + + +def verify_password(plain_password: str, hashed_password: str) -> bool: + """Verify a password against its hash. + + Args: + plain_password: The plain text password. + hashed_password: The hashed password to verify against. + + Returns: + True if password matches, False otherwise. + """ + return pwd_context.verify(plain_password, hashed_password) + + +def hash_password(password: str) -> str: + """Hash a password using bcrypt. + + Args: + password: The plain text password to hash. + + Returns: + The hashed password string. + """ + return pwd_context.hash(password) + + +def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str: + """Create a JWT access token. + + Args: + data: The payload data to encode in the token. + expires_delta: Optional custom expiration time delta. + + Returns: + The encoded JWT token string. + """ + to_encode = data.copy() + if expires_delta: + expire = datetime.now(timezone.utc) + expires_delta + else: + expire = datetime.now(timezone.utc) + timedelta(hours=ACCESS_TOKEN_EXPIRE_HOURS) + to_encode.update({"exp": expire}) + secret_key = settings.SECRET_KEY or "fallback-secret-key-change-in-production" + encoded_jwt = jwt.encode(to_encode, secret_key, algorithm=ALGORITHM) + return encoded_jwt + + +def decode_token(token: str) -> dict: + """Decode and verify a JWT token. + + Args: + token: The JWT token string to decode. + + Returns: + The decoded token payload. + + Raises: + JWTError: If the token is invalid or expired. + """ + secret_key = settings.SECRET_KEY or "fallback-secret-key-change-in-production" + payload = jwt.decode(token, secret_key, algorithms=[ALGORITHM]) + return payload + + +class AuthService: + """Service class for authentication operations.""" + + @staticmethod + def get_user_by_username(db: Session, username: str) -> Optional[User]: + """Get a user by username. + + Args: + db: Database session. + username: The username to search for. + + Returns: + User object if found, None otherwise. + """ + return db.query(User).filter(User.username == username).first() + + @staticmethod + def get_user_by_email(db: Session, email: str) -> Optional[User]: + """Get a user by email. + + Args: + db: Database session. + email: The email to search for. + + Returns: + User object if found, None otherwise. + """ + return db.query(User).filter(User.email == email).first() + + @staticmethod + def get_user_by_id(db: Session, user_id: int) -> Optional[User]: + """Get a user by ID. + + Args: + db: Database session. + user_id: The user ID to search for. + + Returns: + User object if found, None otherwise. + """ + return db.query(User).filter(User.id == user_id).first() + + @staticmethod + def create_user(db: Session, user_data: UserRegister) -> User: + """Create a new user. + + Args: + db: Database session. + user_data: The user registration data. + + Returns: + The created User object. + """ + hashed_password = hash_password(user_data.password) + db_user = User( + username=user_data.username, + email=user_data.email, + password_hash=hashed_password, + ) + db.add(db_user) + db.commit() + db.refresh(db_user) + return db_user + + @staticmethod + def authenticate_user(db: Session, username: str, password: str) -> Optional[User]: + """Authenticate a user with username and password. + + Args: + db: Database session. + username: The username. + password: The plain text password. + + Returns: + User object if authentication successful, None otherwise. + """ + user = AuthService.get_user_by_username(db, username) + if not user: + return None + if not verify_password(password, user.password_hash): + return None + return user diff --git a/tests/conftest.py b/tests/conftest.py index 7fb5338..09fc9d3 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,9 +1,8 @@ """Pytest configuration and fixtures.""" - import pytest from fastapi.testclient import TestClient from sqlalchemy import create_engine -from sqlalchemy.orm import sessionmaker +from sqlalchemy.orm import sessionmaker, Session from sqlalchemy.pool import StaticPool from app.main import app @@ -22,8 +21,11 @@ TestingSessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engin @pytest.fixture(scope="function") -def db_session(): +def db_session() -> Session: """Create a fresh database for each test.""" + # Import models to ensure they're registered with Base + from app.models.user import User # noqa: F401 + Base.metadata.create_all(bind=engine) db = TestingSessionLocal() try: @@ -34,7 +36,7 @@ def db_session(): @pytest.fixture(scope="function") -def client(db_session): +def client(db_session: Session) -> TestClient: """Create a test client with fresh database.""" def override_get_db():