2a9ebe24: Implement auth endpoints (register/login/JWT) with SQLAlchemy
This commit is contained in:
@@ -1 +1,4 @@
|
||||
"""Services package."""
|
||||
from app.services.auth_service import AuthService
|
||||
|
||||
__all__ = ["AuthService"]
|
||||
|
||||
163
app/services/auth_service.py
Normal file
163
app/services/auth_service.py
Normal file
@@ -0,0 +1,163 @@
|
||||
"""Authentication service with password hashing and JWT token management."""
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Optional
|
||||
|
||||
from jose import JWTError, jwt
|
||||
from passlib.context import CryptContext
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.config import settings
|
||||
from app.models.user import User
|
||||
from app.schemas.auth import Token, UserRegister
|
||||
|
||||
|
||||
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
||||
ALGORITHM = "HS256"
|
||||
ACCESS_TOKEN_EXPIRE_HOURS = 24
|
||||
|
||||
|
||||
def verify_password(plain_password: str, hashed_password: str) -> bool:
|
||||
"""Verify a password against its hash.
|
||||
|
||||
Args:
|
||||
plain_password: The plain text password.
|
||||
hashed_password: The hashed password to verify against.
|
||||
|
||||
Returns:
|
||||
True if password matches, False otherwise.
|
||||
"""
|
||||
return pwd_context.verify(plain_password, hashed_password)
|
||||
|
||||
|
||||
def hash_password(password: str) -> str:
|
||||
"""Hash a password using bcrypt.
|
||||
|
||||
Args:
|
||||
password: The plain text password to hash.
|
||||
|
||||
Returns:
|
||||
The hashed password string.
|
||||
"""
|
||||
return pwd_context.hash(password)
|
||||
|
||||
|
||||
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
|
||||
"""Create a JWT access token.
|
||||
|
||||
Args:
|
||||
data: The payload data to encode in the token.
|
||||
expires_delta: Optional custom expiration time delta.
|
||||
|
||||
Returns:
|
||||
The encoded JWT token string.
|
||||
"""
|
||||
to_encode = data.copy()
|
||||
if expires_delta:
|
||||
expire = datetime.now(timezone.utc) + expires_delta
|
||||
else:
|
||||
expire = datetime.now(timezone.utc) + timedelta(hours=ACCESS_TOKEN_EXPIRE_HOURS)
|
||||
to_encode.update({"exp": expire})
|
||||
secret_key = settings.SECRET_KEY or "fallback-secret-key-change-in-production"
|
||||
encoded_jwt = jwt.encode(to_encode, secret_key, algorithm=ALGORITHM)
|
||||
return encoded_jwt
|
||||
|
||||
|
||||
def decode_token(token: str) -> dict:
|
||||
"""Decode and verify a JWT token.
|
||||
|
||||
Args:
|
||||
token: The JWT token string to decode.
|
||||
|
||||
Returns:
|
||||
The decoded token payload.
|
||||
|
||||
Raises:
|
||||
JWTError: If the token is invalid or expired.
|
||||
"""
|
||||
secret_key = settings.SECRET_KEY or "fallback-secret-key-change-in-production"
|
||||
payload = jwt.decode(token, secret_key, algorithms=[ALGORITHM])
|
||||
return payload
|
||||
|
||||
|
||||
class AuthService:
|
||||
"""Service class for authentication operations."""
|
||||
|
||||
@staticmethod
|
||||
def get_user_by_username(db: Session, username: str) -> Optional[User]:
|
||||
"""Get a user by username.
|
||||
|
||||
Args:
|
||||
db: Database session.
|
||||
username: The username to search for.
|
||||
|
||||
Returns:
|
||||
User object if found, None otherwise.
|
||||
"""
|
||||
return db.query(User).filter(User.username == username).first()
|
||||
|
||||
@staticmethod
|
||||
def get_user_by_email(db: Session, email: str) -> Optional[User]:
|
||||
"""Get a user by email.
|
||||
|
||||
Args:
|
||||
db: Database session.
|
||||
email: The email to search for.
|
||||
|
||||
Returns:
|
||||
User object if found, None otherwise.
|
||||
"""
|
||||
return db.query(User).filter(User.email == email).first()
|
||||
|
||||
@staticmethod
|
||||
def get_user_by_id(db: Session, user_id: int) -> Optional[User]:
|
||||
"""Get a user by ID.
|
||||
|
||||
Args:
|
||||
db: Database session.
|
||||
user_id: The user ID to search for.
|
||||
|
||||
Returns:
|
||||
User object if found, None otherwise.
|
||||
"""
|
||||
return db.query(User).filter(User.id == user_id).first()
|
||||
|
||||
@staticmethod
|
||||
def create_user(db: Session, user_data: UserRegister) -> User:
|
||||
"""Create a new user.
|
||||
|
||||
Args:
|
||||
db: Database session.
|
||||
user_data: The user registration data.
|
||||
|
||||
Returns:
|
||||
The created User object.
|
||||
"""
|
||||
hashed_password = hash_password(user_data.password)
|
||||
db_user = User(
|
||||
username=user_data.username,
|
||||
email=user_data.email,
|
||||
password_hash=hashed_password,
|
||||
)
|
||||
db.add(db_user)
|
||||
db.commit()
|
||||
db.refresh(db_user)
|
||||
return db_user
|
||||
|
||||
@staticmethod
|
||||
def authenticate_user(db: Session, username: str, password: str) -> Optional[User]:
|
||||
"""Authenticate a user with username and password.
|
||||
|
||||
Args:
|
||||
db: Database session.
|
||||
username: The username.
|
||||
password: The plain text password.
|
||||
|
||||
Returns:
|
||||
User object if authentication successful, None otherwise.
|
||||
"""
|
||||
user = AuthService.get_user_by_username(db, username)
|
||||
if not user:
|
||||
return None
|
||||
if not verify_password(password, user.password_hash):
|
||||
return None
|
||||
return user
|
||||
Reference in New Issue
Block a user