2a9ebe24: Implement auth endpoints (register/login/JWT) with SQLAlchemy

This commit is contained in:
OpenClaw Agent
2026-04-16 12:51:02 +00:00
parent ef5b32143a
commit 135d4111bb
12 changed files with 417 additions and 8 deletions

View File

@@ -1,6 +1,6 @@
"""Application configuration settings."""
from pydantic_settings import BaseSettings
from pydantic_settings import BaseSettings, SettingsConfigDict
from typing import Optional
@@ -18,9 +18,10 @@ class Settings(BaseSettings):
# Security
SECRET_KEY: Optional[str] = None
class Config:
env_file = ".env"
case_sensitive = True
model_config = SettingsConfigDict(
env_file=".env",
case_sensitive=True,
)
settings = Settings()

55
app/deps.py Normal file
View File

@@ -0,0 +1,55 @@
"""FastAPI dependencies for authentication and database access."""
from typing import Annotated
from fastapi import Depends, HTTPException, status
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from sqlalchemy.orm import Session
from app.db.database import get_db
from app.models.user import User
from app.services.auth_service import AuthService, decode_token
security = HTTPBearer()
async def get_current_user(
credentials: Annotated[HTTPAuthorizationCredentials, Depends(security)],
db: Annotated[Session, Depends(get_db)],
) -> User:
"""Get the current authenticated user from JWT token.
Args:
credentials: The HTTP Bearer credentials containing the JWT token.
db: Database session.
Returns:
The authenticated User object.
Raises:
HTTPException: If token is invalid or user not found.
"""
credentials_exception = HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid authentication credentials",
headers={"WWW-Authenticate": "Bearer"},
)
try:
payload = decode_token(credentials.credentials)
user_id_str: str = payload.get("sub")
if user_id_str is None:
raise credentials_exception
user_id = int(user_id_str)
except (JWTError, ValueError):
raise credentials_exception
user = AuthService.get_user_by_id(db, user_id)
if user is None:
raise credentials_exception
return user
# Type alias for dependency injection
CurrentUser = Annotated[User, Depends(get_current_user)]

View File

@@ -1,15 +1,29 @@
"""FastAPI application entry point."""
from contextlib import asynccontextmanager
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from app.core.config import settings
from app.db.database import engine, Base
from app.routers.auth import router as auth_router
@asynccontextmanager
async def lifespan(app: FastAPI):
"""Application lifespan manager for startup/shutdown events."""
# Startup: Create database tables
Base.metadata.create_all(bind=engine)
yield
# Shutdown: cleanup if needed
# Create FastAPI app
app = FastAPI(
title=settings.APP_NAME,
version=settings.APP_VERSION,
debug=settings.DEBUG,
lifespan=lifespan,
)
# CORS middleware
@@ -21,6 +35,9 @@ app.add_middleware(
allow_headers=["*"],
)
# Include routers
app.include_router(auth_router)
@app.get("/")
async def root():

View File

@@ -1 +1,4 @@
"""Models package."""
from app.models.user import User
__all__ = ["User"]

32
app/models/user.py Normal file
View File

@@ -0,0 +1,32 @@
"""User SQLAlchemy model."""
from datetime import datetime
from typing import Optional
from sqlalchemy import DateTime, String, Text
from sqlalchemy.orm import Mapped, mapped_column, relationship
from sqlalchemy.sql import func
from app.db.database import Base
class User(Base):
"""User model for authentication and profile information."""
__tablename__ = "users"
id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
username: Mapped[str] = mapped_column(String(50), unique=True, nullable=False)
email: Mapped[str] = mapped_column(String(255), unique=True, nullable=False)
password_hash: Mapped[str] = mapped_column(String(255), nullable=False)
avatar_url: Mapped[Optional[str]] = mapped_column(
String(500), default="/static/default-avatar.png"
)
bio: Mapped[Optional[str]] = mapped_column(Text, default="")
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True),
server_default=func.now(),
nullable=False,
)
def __repr__(self) -> str:
return f"<User(id={self.id}, username='{self.username}')>"

4
app/routers/__init__.py Normal file
View File

@@ -0,0 +1,4 @@
"""Routers package."""
from app.routers.auth import router as auth_router
__all__ = ["auth_router"]

81
app/routers/auth.py Normal file
View File

@@ -0,0 +1,81 @@
"""Authentication routes for SocialPhoto API."""
from typing import Annotated
from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy.orm import Session
from app.db.database import get_db
from app.schemas.auth import Token, UserLogin, UserRegister
from app.services.auth_service import AuthService, create_access_token
router = APIRouter(prefix="/auth", tags=["Authentication"])
@router.post("/register", response_model=Token, status_code=status.HTTP_201_CREATED)
async def register(
user_data: UserRegister,
db: Annotated[Session, Depends(get_db)],
) -> Token:
"""Register a new user.
Args:
user_data: User registration data (username, email, password).
db: Database session.
Returns:
Token object with access token.
Raises:
HTTPException: If username or email already exists.
"""
# Check if username exists
existing_user = AuthService.get_user_by_username(db, user_data.username)
if existing_user:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Username already registered",
)
# Check if email exists
existing_email = AuthService.get_user_by_email(db, user_data.email)
if existing_email:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Email already registered",
)
# Create user
user = AuthService.create_user(db, user_data)
# Create access token
access_token = create_access_token(data={"sub": str(user.id)})
return Token(access_token=access_token)
@router.post("/login", response_model=Token)
async def login(
user_data: UserLogin,
db: Annotated[Session, Depends(get_db)],
) -> Token:
"""Login and get access token.
Args:
user_data: User login data (username, password).
db: Database session.
Returns:
Token object with access token.
Raises:
HTTPException: If credentials are invalid.
"""
user = AuthService.authenticate_user(db, user_data.username, user_data.password)
if not user:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Incorrect username or password",
)
access_token = create_access_token(data={"sub": str(user.id)})
return Token(access_token=access_token)

View File

@@ -1 +1,9 @@
"""Schemas package."""
from app.schemas.auth import (
Token,
UserLogin,
UserRegister,
UserResponse,
)
__all__ = ["Token", "UserLogin", "UserRegister", "UserResponse"]

40
app/schemas/auth.py Normal file
View File

@@ -0,0 +1,40 @@
"""Authentication Pydantic schemas."""
from datetime import datetime
from typing import Optional
from pydantic import BaseModel, ConfigDict, EmailStr, Field
class UserRegister(BaseModel):
"""Request model for user registration."""
username: str = Field(..., min_length=3, max_length=50)
email: EmailStr
password: str = Field(..., min_length=6)
class UserLogin(BaseModel):
"""Request model for user login."""
username: str
password: str
class Token(BaseModel):
"""Response model for JWT token."""
access_token: str
token_type: str = "bearer"
class UserResponse(BaseModel):
"""Response model for user data."""
id: int
username: str
email: str
avatar_url: Optional[str] = "/static/default-avatar.png"
bio: Optional[str] = ""
created_at: datetime
model_config = ConfigDict(from_attributes=True)

View File

@@ -1 +1,4 @@
"""Services package."""
from app.services.auth_service import AuthService
__all__ = ["AuthService"]

View File

@@ -0,0 +1,163 @@
"""Authentication service with password hashing and JWT token management."""
from datetime import datetime, timedelta, timezone
from typing import Optional
from jose import JWTError, jwt
from passlib.context import CryptContext
from sqlalchemy.orm import Session
from app.core.config import settings
from app.models.user import User
from app.schemas.auth import Token, UserRegister
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_HOURS = 24
def verify_password(plain_password: str, hashed_password: str) -> bool:
"""Verify a password against its hash.
Args:
plain_password: The plain text password.
hashed_password: The hashed password to verify against.
Returns:
True if password matches, False otherwise.
"""
return pwd_context.verify(plain_password, hashed_password)
def hash_password(password: str) -> str:
"""Hash a password using bcrypt.
Args:
password: The plain text password to hash.
Returns:
The hashed password string.
"""
return pwd_context.hash(password)
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
"""Create a JWT access token.
Args:
data: The payload data to encode in the token.
expires_delta: Optional custom expiration time delta.
Returns:
The encoded JWT token string.
"""
to_encode = data.copy()
if expires_delta:
expire = datetime.now(timezone.utc) + expires_delta
else:
expire = datetime.now(timezone.utc) + timedelta(hours=ACCESS_TOKEN_EXPIRE_HOURS)
to_encode.update({"exp": expire})
secret_key = settings.SECRET_KEY or "fallback-secret-key-change-in-production"
encoded_jwt = jwt.encode(to_encode, secret_key, algorithm=ALGORITHM)
return encoded_jwt
def decode_token(token: str) -> dict:
"""Decode and verify a JWT token.
Args:
token: The JWT token string to decode.
Returns:
The decoded token payload.
Raises:
JWTError: If the token is invalid or expired.
"""
secret_key = settings.SECRET_KEY or "fallback-secret-key-change-in-production"
payload = jwt.decode(token, secret_key, algorithms=[ALGORITHM])
return payload
class AuthService:
"""Service class for authentication operations."""
@staticmethod
def get_user_by_username(db: Session, username: str) -> Optional[User]:
"""Get a user by username.
Args:
db: Database session.
username: The username to search for.
Returns:
User object if found, None otherwise.
"""
return db.query(User).filter(User.username == username).first()
@staticmethod
def get_user_by_email(db: Session, email: str) -> Optional[User]:
"""Get a user by email.
Args:
db: Database session.
email: The email to search for.
Returns:
User object if found, None otherwise.
"""
return db.query(User).filter(User.email == email).first()
@staticmethod
def get_user_by_id(db: Session, user_id: int) -> Optional[User]:
"""Get a user by ID.
Args:
db: Database session.
user_id: The user ID to search for.
Returns:
User object if found, None otherwise.
"""
return db.query(User).filter(User.id == user_id).first()
@staticmethod
def create_user(db: Session, user_data: UserRegister) -> User:
"""Create a new user.
Args:
db: Database session.
user_data: The user registration data.
Returns:
The created User object.
"""
hashed_password = hash_password(user_data.password)
db_user = User(
username=user_data.username,
email=user_data.email,
password_hash=hashed_password,
)
db.add(db_user)
db.commit()
db.refresh(db_user)
return db_user
@staticmethod
def authenticate_user(db: Session, username: str, password: str) -> Optional[User]:
"""Authenticate a user with username and password.
Args:
db: Database session.
username: The username.
password: The plain text password.
Returns:
User object if authentication successful, None otherwise.
"""
user = AuthService.get_user_by_username(db, username)
if not user:
return None
if not verify_password(password, user.password_hash):
return None
return user

View File

@@ -1,9 +1,8 @@
"""Pytest configuration and fixtures."""
import pytest
from fastapi.testclient import TestClient
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from sqlalchemy.orm import sessionmaker, Session
from sqlalchemy.pool import StaticPool
from app.main import app
@@ -22,8 +21,11 @@ TestingSessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engin
@pytest.fixture(scope="function")
def db_session():
def db_session() -> Session:
"""Create a fresh database for each test."""
# Import models to ensure they're registered with Base
from app.models.user import User # noqa: F401
Base.metadata.create_all(bind=engine)
db = TestingSessionLocal()
try:
@@ -34,7 +36,7 @@ def db_session():
@pytest.fixture(scope="function")
def client(db_session):
def client(db_session: Session) -> TestClient:
"""Create a test client with fresh database."""
def override_get_db():