feat: implement Instagram clone SocialPhoto API
- FastAPI backend with SQLite database - JWT authentication (register, login) - User profiles with follow/unfollow - Posts with image upload and likes/dislikes - Comments with likes - Global and personalized feed - Comprehensive pytest test suite (37 tests) TASK-ID: 758f4029-702
This commit is contained in:
1
app/__init__.py
Normal file
1
app/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""SocialPhoto - Instagram Clone API."""
|
||||
66
app/auth.py
Normal file
66
app/auth.py
Normal file
@@ -0,0 +1,66 @@
|
||||
"""Authentication utilities for SocialPhoto."""
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import Depends, HTTPException, status
|
||||
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
|
||||
from jose import JWTError, jwt
|
||||
from passlib.context import CryptContext
|
||||
|
||||
# Configuration
|
||||
SECRET_KEY = "your-secret-key-change-in-production" # TODO: Move to env
|
||||
ALGORITHM = "HS256"
|
||||
ACCESS_TOKEN_EXPIRE_HOURS = 24
|
||||
|
||||
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
||||
security = HTTPBearer()
|
||||
|
||||
|
||||
def verify_password(plain_password: str, hashed_password: str) -> bool:
|
||||
"""Verify a password against its hash."""
|
||||
return pwd_context.verify(plain_password, hashed_password)
|
||||
|
||||
|
||||
def hash_password(password: str) -> str:
|
||||
"""Hash a password using bcrypt."""
|
||||
return pwd_context.hash(password)
|
||||
|
||||
|
||||
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
|
||||
"""Create a JWT access token."""
|
||||
to_encode = data.copy()
|
||||
if expires_delta:
|
||||
expire = datetime.now(timezone.utc) + expires_delta
|
||||
else:
|
||||
expire = datetime.now(timezone.utc) + timedelta(hours=ACCESS_TOKEN_EXPIRE_HOURS)
|
||||
to_encode.update({"exp": expire})
|
||||
encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
|
||||
return encoded_jwt
|
||||
|
||||
|
||||
def decode_token(token: str) -> dict:
|
||||
"""Decode and verify a JWT token."""
|
||||
try:
|
||||
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
|
||||
return payload
|
||||
except JWTError:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid authentication credentials",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
|
||||
async def get_current_user_id(
|
||||
credentials: HTTPAuthorizationCredentials = Depends(security),
|
||||
) -> int:
|
||||
"""Extract user ID from JWT token."""
|
||||
token = credentials.credentials
|
||||
payload = decode_token(token)
|
||||
user_id: int = payload.get("sub")
|
||||
if user_id is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid token payload",
|
||||
)
|
||||
return int(user_id)
|
||||
120
app/database.py
Normal file
120
app/database.py
Normal file
@@ -0,0 +1,120 @@
|
||||
"""Database module for SocialPhoto."""
|
||||
import sqlite3
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
DATABASE_PATH = Path(__file__).parent.parent / "socialphoto.db"
|
||||
|
||||
|
||||
def get_db_connection() -> sqlite3.Connection:
|
||||
"""Get a database connection with row factory."""
|
||||
conn = sqlite3.connect(DATABASE_PATH, check_same_thread=False)
|
||||
conn.row_factory = sqlite3.Row
|
||||
return conn
|
||||
|
||||
|
||||
def get_db() -> sqlite3.Connection:
|
||||
"""Dependency for FastAPI routes."""
|
||||
conn = get_db_connection()
|
||||
try:
|
||||
yield conn
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
|
||||
def init_db() -> None:
|
||||
"""Initialize the database with all tables."""
|
||||
conn = get_db_connection()
|
||||
cursor = conn.cursor()
|
||||
|
||||
# Users table
|
||||
cursor.execute("""
|
||||
CREATE TABLE IF NOT EXISTS users (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
username TEXT UNIQUE NOT NULL,
|
||||
email TEXT UNIQUE NOT NULL,
|
||||
password_hash TEXT NOT NULL,
|
||||
avatar_url TEXT DEFAULT '/static/default-avatar.png',
|
||||
bio TEXT DEFAULT '',
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
|
||||
)
|
||||
""")
|
||||
|
||||
# Posts table
|
||||
cursor.execute("""
|
||||
CREATE TABLE IF NOT EXISTS posts (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
user_id INTEGER NOT NULL REFERENCES users(id),
|
||||
image_path TEXT NOT NULL,
|
||||
caption TEXT DEFAULT '',
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
|
||||
)
|
||||
""")
|
||||
|
||||
# Comments table
|
||||
cursor.execute("""
|
||||
CREATE TABLE IF NOT EXISTS comments (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
post_id INTEGER NOT NULL REFERENCES posts(id),
|
||||
user_id INTEGER NOT NULL REFERENCES users(id),
|
||||
content TEXT NOT NULL,
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
|
||||
)
|
||||
""")
|
||||
|
||||
# Likes table
|
||||
cursor.execute("""
|
||||
CREATE TABLE IF NOT EXISTS likes (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
post_id INTEGER NOT NULL REFERENCES posts(id),
|
||||
user_id INTEGER NOT NULL REFERENCES users(id),
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||
UNIQUE(post_id, user_id)
|
||||
)
|
||||
""")
|
||||
|
||||
# Dislikes table
|
||||
cursor.execute("""
|
||||
CREATE TABLE IF NOT EXISTS dislikes (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
post_id INTEGER NOT NULL REFERENCES posts(id),
|
||||
user_id INTEGER NOT NULL REFERENCES users(id),
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||
UNIQUE(post_id, user_id)
|
||||
)
|
||||
""")
|
||||
|
||||
# Follows table
|
||||
cursor.execute("""
|
||||
CREATE TABLE IF NOT EXISTS follows (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
follower_id INTEGER NOT NULL REFERENCES users(id),
|
||||
following_id INTEGER NOT NULL REFERENCES users(id),
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||
UNIQUE(follower_id, following_id)
|
||||
)
|
||||
""")
|
||||
|
||||
# Comment likes table
|
||||
cursor.execute("""
|
||||
CREATE TABLE IF NOT EXISTS comment_likes (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
comment_id INTEGER NOT NULL REFERENCES comments(id),
|
||||
user_id INTEGER NOT NULL REFERENCES users(id),
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||
UNIQUE(comment_id, user_id)
|
||||
)
|
||||
""")
|
||||
|
||||
conn.commit()
|
||||
conn.close()
|
||||
|
||||
|
||||
def row_to_dict(row: sqlite3.Row) -> dict:
|
||||
"""Convert a sqlite Row to a dictionary."""
|
||||
return dict(row)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
init_db()
|
||||
print("Database initialized successfully.")
|
||||
69
app/main.py
Normal file
69
app/main.py
Normal file
@@ -0,0 +1,69 @@
|
||||
"""SocialPhoto - Instagram Clone API.
|
||||
|
||||
A simple social media API for sharing images with likes, comments, and user follows.
|
||||
"""
|
||||
from pathlib import Path
|
||||
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
|
||||
from app.database import init_db
|
||||
from app.routes import auth, comments, feed, posts, users
|
||||
|
||||
# Create FastAPI app
|
||||
app = FastAPI(
|
||||
title="SocialPhoto",
|
||||
description="Instagram Clone API - Share images with likes, comments, and follows",
|
||||
version="1.0.0",
|
||||
)
|
||||
|
||||
# CORS middleware
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"],
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
# Mount uploads directory
|
||||
UPLOAD_DIR = Path(__file__).parent.parent / "uploads"
|
||||
UPLOAD_DIR.mkdir(parents=True, exist_ok=True)
|
||||
app.mount("/uploads", StaticFiles(directory=str(UPLOAD_DIR)), name="uploads")
|
||||
|
||||
|
||||
@app.on_event("startup")
|
||||
async def startup_event():
|
||||
"""Initialize database on startup."""
|
||||
init_db()
|
||||
|
||||
|
||||
# Include routers
|
||||
app.include_router(auth.router)
|
||||
app.include_router(users.router)
|
||||
app.include_router(posts.router)
|
||||
app.include_router(comments.router)
|
||||
app.include_router(feed.router)
|
||||
|
||||
|
||||
@app.get("/", tags=["Root"])
|
||||
async def root():
|
||||
"""Root endpoint."""
|
||||
return {
|
||||
"name": "SocialPhoto",
|
||||
"version": "1.0.0",
|
||||
"description": "Instagram Clone API",
|
||||
}
|
||||
|
||||
|
||||
@app.get("/health", tags=["Health"])
|
||||
async def health_check():
|
||||
"""Health check endpoint."""
|
||||
return {"status": "healthy"}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
|
||||
uvicorn.run(app, host="0.0.0.0", port=8000)
|
||||
109
app/models.py
Normal file
109
app/models.py
Normal file
@@ -0,0 +1,109 @@
|
||||
"""Pydantic models for SocialPhoto API."""
|
||||
from datetime import datetime
|
||||
from typing import Optional, List
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, EmailStr, Field
|
||||
|
||||
|
||||
# Auth Models
|
||||
class UserRegister(BaseModel):
|
||||
"""Request model for user registration."""
|
||||
username: str = Field(..., min_length=3, max_length=50)
|
||||
email: EmailStr
|
||||
password: str = Field(..., min_length=6)
|
||||
|
||||
|
||||
class UserLogin(BaseModel):
|
||||
"""Request model for user login."""
|
||||
username: str
|
||||
password: str
|
||||
|
||||
|
||||
class Token(BaseModel):
|
||||
"""Response model for JWT token."""
|
||||
access_token: str
|
||||
token_type: str = "bearer"
|
||||
|
||||
|
||||
# User Models
|
||||
class UserBase(BaseModel):
|
||||
"""Base user model."""
|
||||
username: str
|
||||
email: str
|
||||
avatar_url: Optional[str] = "/static/default-avatar.png"
|
||||
bio: Optional[str] = ""
|
||||
|
||||
|
||||
class UserResponse(UserBase):
|
||||
"""Response model for user data."""
|
||||
id: int
|
||||
created_at: datetime
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class UserStats(BaseModel):
|
||||
"""User statistics model."""
|
||||
posts_count: int
|
||||
followers_count: int
|
||||
following_count: int
|
||||
|
||||
|
||||
# Post Models
|
||||
class PostCreate(BaseModel):
|
||||
"""Request model for creating a post."""
|
||||
caption: Optional[str] = ""
|
||||
|
||||
|
||||
class PostResponse(BaseModel):
|
||||
"""Response model for post data."""
|
||||
id: int
|
||||
user_id: int
|
||||
username: str
|
||||
image_url: str
|
||||
caption: str
|
||||
likes_count: int
|
||||
dislikes_count: int
|
||||
comments_count: int
|
||||
created_at: datetime
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class PostDetail(PostResponse):
|
||||
"""Detailed post response with user info."""
|
||||
user: UserResponse
|
||||
|
||||
|
||||
# Comment Models
|
||||
class CommentCreate(BaseModel):
|
||||
"""Request model for creating a comment."""
|
||||
content: str = Field(..., min_length=1, max_length=500)
|
||||
|
||||
|
||||
class CommentResponse(BaseModel):
|
||||
"""Response model for comment data."""
|
||||
id: int
|
||||
post_id: int
|
||||
user_id: int
|
||||
username: str
|
||||
content: str
|
||||
likes_count: int
|
||||
created_at: datetime
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
# Feed Models
|
||||
class FeedResponse(BaseModel):
|
||||
"""Response model for feed."""
|
||||
posts: List[PostResponse]
|
||||
total: int
|
||||
limit: int
|
||||
offset: int
|
||||
|
||||
|
||||
# Error Models
|
||||
class ErrorResponse(BaseModel):
|
||||
"""Standard error response."""
|
||||
detail: str
|
||||
1
app/routes/__init__.py
Normal file
1
app/routes/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Routes package for SocialPhoto."""
|
||||
81
app/routes/auth.py
Normal file
81
app/routes/auth.py
Normal file
@@ -0,0 +1,81 @@
|
||||
"""Authentication routes for SocialPhoto."""
|
||||
import sqlite3
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
|
||||
|
||||
from app.auth import create_access_token, hash_password, verify_password
|
||||
from app.database import get_db, row_to_dict
|
||||
from app.models import Token, UserLogin, UserRegister
|
||||
|
||||
router = APIRouter(prefix="/auth", tags=["Authentication"])
|
||||
security = HTTPBearer()
|
||||
|
||||
|
||||
@router.post("/register", response_model=Token, status_code=status.HTTP_201_CREATED)
|
||||
async def register(
|
||||
user_data: UserRegister,
|
||||
conn: sqlite3.Connection = Depends(get_db),
|
||||
) -> Token:
|
||||
"""Register a new user."""
|
||||
cursor = conn.cursor()
|
||||
|
||||
# Check if username exists
|
||||
cursor.execute("SELECT id FROM users WHERE username = ?", (user_data.username,))
|
||||
if cursor.fetchone():
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Username already registered",
|
||||
)
|
||||
|
||||
# Check if email exists
|
||||
cursor.execute("SELECT id FROM users WHERE email = ?", (user_data.email,))
|
||||
if cursor.fetchone():
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Email already registered",
|
||||
)
|
||||
|
||||
# Hash password and create user
|
||||
password_hash = hash_password(user_data.password)
|
||||
cursor.execute(
|
||||
"INSERT INTO users (username, email, password_hash) VALUES (?, ?, ?)",
|
||||
(user_data.username, user_data.email, password_hash),
|
||||
)
|
||||
conn.commit()
|
||||
|
||||
# Get the new user's ID
|
||||
cursor.execute("SELECT id FROM users WHERE username = ?", (user_data.username,))
|
||||
user = row_to_dict(cursor.fetchone())
|
||||
user_id = user["id"]
|
||||
|
||||
# Create access token
|
||||
access_token = create_access_token(data={"sub": str(user_id)})
|
||||
return Token(access_token=access_token)
|
||||
|
||||
|
||||
@router.post("/login", response_model=Token)
|
||||
async def login(
|
||||
user_data: UserLogin,
|
||||
conn: sqlite3.Connection = Depends(get_db),
|
||||
) -> Token:
|
||||
"""Login and get access token."""
|
||||
cursor = conn.cursor()
|
||||
|
||||
# Find user by username
|
||||
cursor.execute(
|
||||
"SELECT id, password_hash FROM users WHERE username = ?",
|
||||
(user_data.username,),
|
||||
)
|
||||
row = cursor.fetchone()
|
||||
|
||||
if not row or not verify_password(user_data.password, row["password_hash"]):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Incorrect username or password",
|
||||
)
|
||||
|
||||
user = row_to_dict(row)
|
||||
access_token = create_access_token(data={"sub": user["id"]})
|
||||
return Token(access_token=access_token)
|
||||
79
app/routes/comments.py
Normal file
79
app/routes/comments.py
Normal file
@@ -0,0 +1,79 @@
|
||||
"""Comment routes for SocialPhoto - comment-specific operations."""
|
||||
import sqlite3
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
|
||||
|
||||
from app.auth import get_current_user_id
|
||||
from app.database import get_db, row_to_dict
|
||||
|
||||
router = APIRouter(prefix="/comments", tags=["Comments"])
|
||||
security = HTTPBearer()
|
||||
|
||||
|
||||
@router.delete("/{comment_id}", status_code=status.HTTP_204_NO_CONTENT)
|
||||
async def delete_comment(
|
||||
comment_id: int,
|
||||
credentials: HTTPAuthorizationCredentials = Depends(security),
|
||||
conn: sqlite3.Connection = Depends(get_db),
|
||||
) -> None:
|
||||
"""Delete a comment (only by owner)."""
|
||||
user_id = await get_current_user_id(credentials)
|
||||
cursor = conn.cursor()
|
||||
|
||||
# Check comment exists and belongs to user
|
||||
cursor.execute(
|
||||
"SELECT user_id FROM comments WHERE id = ?",
|
||||
(comment_id,),
|
||||
)
|
||||
row = cursor.fetchone()
|
||||
if not row:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Comment not found",
|
||||
)
|
||||
|
||||
comment = row_to_dict(row)
|
||||
if comment["user_id"] != user_id:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="You can only delete your own comments",
|
||||
)
|
||||
|
||||
# Delete comment
|
||||
cursor.execute("DELETE FROM comments WHERE id = ?", (comment_id,))
|
||||
conn.commit()
|
||||
|
||||
|
||||
@router.post("/{comment_id}/like", status_code=status.HTTP_201_CREATED)
|
||||
async def like_comment(
|
||||
comment_id: int,
|
||||
credentials: HTTPAuthorizationCredentials = Depends(security),
|
||||
conn: sqlite3.Connection = Depends(get_db),
|
||||
) -> dict:
|
||||
"""Like a comment."""
|
||||
user_id = await get_current_user_id(credentials)
|
||||
cursor = conn.cursor()
|
||||
|
||||
# Check comment exists
|
||||
cursor.execute("SELECT id FROM comments WHERE id = ?", (comment_id,))
|
||||
if not cursor.fetchone():
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Comment not found",
|
||||
)
|
||||
|
||||
# Add like
|
||||
try:
|
||||
cursor.execute(
|
||||
"INSERT INTO comment_likes (comment_id, user_id) VALUES (?, ?)",
|
||||
(comment_id, user_id),
|
||||
)
|
||||
conn.commit()
|
||||
except sqlite3.IntegrityError:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="You already liked this comment",
|
||||
)
|
||||
|
||||
return {"message": "Comment liked"}
|
||||
124
app/routes/feed.py
Normal file
124
app/routes/feed.py
Normal file
@@ -0,0 +1,124 @@
|
||||
"""Feed routes for SocialPhoto."""
|
||||
import sqlite3
|
||||
from typing import List
|
||||
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
|
||||
|
||||
from app.auth import get_current_user_id
|
||||
from app.database import get_db, row_to_dict
|
||||
from app.models import FeedResponse, PostResponse
|
||||
|
||||
router = APIRouter(prefix="/feed", tags=["Feed"])
|
||||
security = HTTPBearer()
|
||||
|
||||
|
||||
@router.get("", response_model=FeedResponse)
|
||||
async def get_followed_feed(
|
||||
limit: int = Query(default=20, ge=1, le=100),
|
||||
offset: int = Query(default=0, ge=0),
|
||||
credentials: HTTPAuthorizationCredentials = Depends(security),
|
||||
conn: sqlite3.Connection = Depends(get_db),
|
||||
) -> FeedResponse:
|
||||
"""Get feed of posts from users you follow."""
|
||||
user_id = await get_current_user_id(credentials)
|
||||
cursor = conn.cursor()
|
||||
|
||||
# Get posts from followed users
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT
|
||||
p.id, p.user_id, u.username, p.image_path, p.caption, p.created_at,
|
||||
(SELECT COUNT(*) FROM likes WHERE post_id = p.id) as likes_count,
|
||||
(SELECT COUNT(*) FROM dislikes WHERE post_id = p.id) as dislikes_count,
|
||||
(SELECT COUNT(*) FROM comments WHERE post_id = p.id) as comments_count
|
||||
FROM posts p
|
||||
JOIN users u ON p.user_id = u.id
|
||||
WHERE p.user_id IN (
|
||||
SELECT following_id FROM follows WHERE follower_id = ?
|
||||
)
|
||||
ORDER BY p.created_at DESC
|
||||
LIMIT ? OFFSET ?
|
||||
""",
|
||||
(user_id, limit, offset),
|
||||
)
|
||||
|
||||
posts = []
|
||||
for row in cursor.fetchall():
|
||||
post = row_to_dict(row)
|
||||
posts.append(
|
||||
PostResponse(
|
||||
id=post["id"],
|
||||
user_id=post["user_id"],
|
||||
username=post["username"],
|
||||
image_url=f"/uploads/{post['image_path'].split('/')[-1]}",
|
||||
caption=post["caption"],
|
||||
likes_count=post["likes_count"],
|
||||
dislikes_count=post["dislikes_count"],
|
||||
comments_count=post["comments_count"],
|
||||
created_at=post["created_at"],
|
||||
)
|
||||
)
|
||||
|
||||
# Get total count
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT COUNT(*) as total
|
||||
FROM posts p
|
||||
WHERE p.user_id IN (
|
||||
SELECT following_id FROM follows WHERE follower_id = ?
|
||||
)
|
||||
""",
|
||||
(user_id,),
|
||||
)
|
||||
total = cursor.fetchone()["total"]
|
||||
|
||||
return FeedResponse(posts=posts, total=total, limit=limit, offset=offset)
|
||||
|
||||
|
||||
@router.get("/global", response_model=FeedResponse)
|
||||
async def get_global_feed(
|
||||
limit: int = Query(default=20, ge=1, le=100),
|
||||
offset: int = Query(default=0, ge=0),
|
||||
conn: sqlite3.Connection = Depends(get_db),
|
||||
) -> FeedResponse:
|
||||
"""Get global feed of all posts."""
|
||||
cursor = conn.cursor()
|
||||
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT
|
||||
p.id, p.user_id, u.username, p.image_path, p.caption, p.created_at,
|
||||
(SELECT COUNT(*) FROM likes WHERE post_id = p.id) as likes_count,
|
||||
(SELECT COUNT(*) FROM dislikes WHERE post_id = p.id) as dislikes_count,
|
||||
(SELECT COUNT(*) FROM comments WHERE post_id = p.id) as comments_count
|
||||
FROM posts p
|
||||
JOIN users u ON p.user_id = u.id
|
||||
ORDER BY p.created_at DESC
|
||||
LIMIT ? OFFSET ?
|
||||
""",
|
||||
(limit, offset),
|
||||
)
|
||||
|
||||
posts = []
|
||||
for row in cursor.fetchall():
|
||||
post = row_to_dict(row)
|
||||
posts.append(
|
||||
PostResponse(
|
||||
id=post["id"],
|
||||
user_id=post["user_id"],
|
||||
username=post["username"],
|
||||
image_url=f"/uploads/{post['image_path'].split('/')[-1]}",
|
||||
caption=post["caption"],
|
||||
likes_count=post["likes_count"],
|
||||
dislikes_count=post["dislikes_count"],
|
||||
comments_count=post["comments_count"],
|
||||
created_at=post["created_at"],
|
||||
)
|
||||
)
|
||||
|
||||
# Get total count
|
||||
cursor.execute("SELECT COUNT(*) as total FROM posts")
|
||||
total = cursor.fetchone()["total"]
|
||||
|
||||
return FeedResponse(posts=posts, total=total, limit=limit, offset=offset)
|
||||
400
app/routes/posts.py
Normal file
400
app/routes/posts.py
Normal file
@@ -0,0 +1,400 @@
|
||||
"""Post routes for SocialPhoto."""
|
||||
import os
|
||||
import sqlite3
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from typing import List, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, File, Form, HTTPException, UploadFile, status
|
||||
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
from fastapi.responses import FileResponse
|
||||
|
||||
from app.auth import get_current_user_id
|
||||
from app.database import get_db, row_to_dict
|
||||
from app.models import CommentCreate, CommentResponse, PostResponse
|
||||
|
||||
router = APIRouter(prefix="/posts", tags=["Posts"])
|
||||
security = HTTPBearer()
|
||||
|
||||
# Configuration
|
||||
UPLOAD_DIR = Path(__file__).parent.parent.parent / "uploads"
|
||||
MAX_FILE_SIZE = 10 * 1024 * 1024 # 10MB
|
||||
ALLOWED_EXTENSIONS = {".jpg", ".jpeg", ".png", ".gif", ".webp"}
|
||||
|
||||
# Ensure upload directory exists
|
||||
UPLOAD_DIR.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
def _get_post_with_counts(conn: sqlite3.Connection, post_id: int) -> Optional[dict]:
|
||||
"""Get post data with like/dislike/comment counts."""
|
||||
cursor = conn.cursor()
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT
|
||||
p.id, p.user_id, u.username, p.image_path, p.caption, p.created_at,
|
||||
(SELECT COUNT(*) FROM likes WHERE post_id = p.id) as likes_count,
|
||||
(SELECT COUNT(*) FROM dislikes WHERE post_id = p.id) as dislikes_count,
|
||||
(SELECT COUNT(*) FROM comments WHERE post_id = p.id) as comments_count
|
||||
FROM posts p
|
||||
JOIN users u ON p.user_id = u.id
|
||||
WHERE p.id = ?
|
||||
""",
|
||||
(post_id,),
|
||||
)
|
||||
row = cursor.fetchone()
|
||||
if not row:
|
||||
return None
|
||||
post = row_to_dict(row)
|
||||
post["image_url"] = f"/uploads/{post['image_path'].split('/')[-1]}"
|
||||
return post
|
||||
|
||||
|
||||
@router.post("", response_model=PostResponse, status_code=status.HTTP_201_CREATED)
|
||||
async def create_post(
|
||||
caption: str = Form(""),
|
||||
image: UploadFile = File(...),
|
||||
credentials: HTTPAuthorizationCredentials = Depends(security),
|
||||
conn: sqlite3.Connection = Depends(get_db),
|
||||
) -> PostResponse:
|
||||
"""Create a new post with image."""
|
||||
user_id = await get_current_user_id(credentials)
|
||||
|
||||
# Validate file type
|
||||
file_ext = Path(image.filename).suffix.lower()
|
||||
if file_ext not in ALLOWED_EXTENSIONS:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"File type not allowed. Allowed: {', '.join(ALLOWED_EXTENSIONS)}",
|
||||
)
|
||||
|
||||
# Generate unique filename
|
||||
unique_filename = f"{uuid.uuid4()}{file_ext}"
|
||||
file_path = UPLOAD_DIR / unique_filename
|
||||
|
||||
# Save file
|
||||
contents = await image.read()
|
||||
if len(contents) > MAX_FILE_SIZE:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="File too large. Maximum size is 10MB",
|
||||
)
|
||||
|
||||
with open(file_path, "wb") as f:
|
||||
f.write(contents)
|
||||
|
||||
# Insert into database
|
||||
cursor = conn.cursor()
|
||||
cursor.execute(
|
||||
"INSERT INTO posts (user_id, image_path, caption) VALUES (?, ?, ?)",
|
||||
(user_id, str(file_path), caption),
|
||||
)
|
||||
conn.commit()
|
||||
post_id = cursor.lastrowid
|
||||
|
||||
# Get the created post
|
||||
post = _get_post_with_counts(conn, post_id)
|
||||
return PostResponse(**post)
|
||||
|
||||
|
||||
@router.get("", response_model=List[PostResponse])
|
||||
async def get_posts(
|
||||
limit: int = 20,
|
||||
offset: int = 0,
|
||||
conn: sqlite3.Connection = Depends(get_db),
|
||||
) -> List[PostResponse]:
|
||||
"""Get global feed of posts."""
|
||||
cursor = conn.cursor()
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT
|
||||
p.id, p.user_id, u.username, p.image_path, p.caption, p.created_at,
|
||||
(SELECT COUNT(*) FROM likes WHERE post_id = p.id) as likes_count,
|
||||
(SELECT COUNT(*) FROM dislikes WHERE post_id = p.id) as dislikes_count,
|
||||
(SELECT COUNT(*) FROM comments WHERE post_id = p.id) as comments_count
|
||||
FROM posts p
|
||||
JOIN users u ON p.user_id = u.id
|
||||
ORDER BY p.created_at DESC
|
||||
LIMIT ? OFFSET ?
|
||||
""",
|
||||
(limit, offset),
|
||||
)
|
||||
|
||||
posts = []
|
||||
for row in cursor.fetchall():
|
||||
post = row_to_dict(row)
|
||||
posts.append(
|
||||
PostResponse(
|
||||
id=post["id"],
|
||||
user_id=post["user_id"],
|
||||
username=post["username"],
|
||||
image_url=f"/uploads/{post['image_path'].split('/')[-1]}",
|
||||
caption=post["caption"],
|
||||
likes_count=post["likes_count"],
|
||||
dislikes_count=post["dislikes_count"],
|
||||
comments_count=post["comments_count"],
|
||||
created_at=post["created_at"],
|
||||
)
|
||||
)
|
||||
|
||||
return posts
|
||||
|
||||
|
||||
@router.get("/{post_id}", response_model=PostResponse)
|
||||
async def get_post(
|
||||
post_id: int,
|
||||
conn: sqlite3.Connection = Depends(get_db),
|
||||
) -> PostResponse:
|
||||
"""Get a specific post."""
|
||||
post = _get_post_with_counts(conn, post_id)
|
||||
if not post:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Post not found",
|
||||
)
|
||||
return PostResponse(**post)
|
||||
|
||||
|
||||
@router.delete("/{post_id}", status_code=status.HTTP_204_NO_CONTENT)
|
||||
async def delete_post(
|
||||
post_id: int,
|
||||
credentials: HTTPAuthorizationCredentials = Depends(security),
|
||||
conn: sqlite3.Connection = Depends(get_db),
|
||||
) -> None:
|
||||
"""Delete a post (only by owner)."""
|
||||
user_id = await get_current_user_id(credentials)
|
||||
cursor = conn.cursor()
|
||||
|
||||
# Check post exists and belongs to user
|
||||
cursor.execute("SELECT user_id, image_path FROM posts WHERE id = ?", (post_id,))
|
||||
row = cursor.fetchone()
|
||||
if not row:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Post not found",
|
||||
)
|
||||
|
||||
post = row_to_dict(row)
|
||||
if post["user_id"] != user_id:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="You can only delete your own posts",
|
||||
)
|
||||
|
||||
# Delete image file
|
||||
image_path = Path(post["image_path"])
|
||||
if image_path.exists():
|
||||
image_path.unlink()
|
||||
|
||||
# Delete post (cascade deletes comments, likes, dislikes)
|
||||
cursor.execute("DELETE FROM posts WHERE id = ?", (post_id,))
|
||||
conn.commit()
|
||||
|
||||
|
||||
@router.post("/{post_id}/like", status_code=status.HTTP_201_CREATED)
|
||||
async def like_post(
|
||||
post_id: int,
|
||||
credentials: HTTPAuthorizationCredentials = Depends(security),
|
||||
conn: sqlite3.Connection = Depends(get_db),
|
||||
) -> dict:
|
||||
"""Like a post."""
|
||||
user_id = await get_current_user_id(credentials)
|
||||
cursor = conn.cursor()
|
||||
|
||||
# Check post exists
|
||||
cursor.execute("SELECT id FROM posts WHERE id = ?", (post_id,))
|
||||
if not cursor.fetchone():
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Post not found",
|
||||
)
|
||||
|
||||
# Remove any existing dislike first
|
||||
cursor.execute(
|
||||
"DELETE FROM dislikes WHERE post_id = ? AND user_id = ?",
|
||||
(post_id, user_id),
|
||||
)
|
||||
|
||||
# Add like
|
||||
try:
|
||||
cursor.execute(
|
||||
"INSERT INTO likes (post_id, user_id) VALUES (?, ?)",
|
||||
(post_id, user_id),
|
||||
)
|
||||
conn.commit()
|
||||
except sqlite3.IntegrityError:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="You already liked this post",
|
||||
)
|
||||
|
||||
return {"message": "Post liked"}
|
||||
|
||||
|
||||
@router.delete("/{post_id}/like", status_code=status.HTTP_200_OK)
|
||||
async def unlike_post(
|
||||
post_id: int,
|
||||
credentials: HTTPAuthorizationCredentials = Depends(security),
|
||||
conn: sqlite3.Connection = Depends(get_db),
|
||||
) -> dict:
|
||||
"""Remove like from a post."""
|
||||
user_id = await get_current_user_id(credentials)
|
||||
cursor = conn.cursor()
|
||||
|
||||
cursor.execute(
|
||||
"DELETE FROM likes WHERE post_id = ? AND user_id = ?",
|
||||
(post_id, user_id),
|
||||
)
|
||||
if cursor.rowcount == 0:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="You haven't liked this post",
|
||||
)
|
||||
conn.commit()
|
||||
|
||||
return {"message": "Like removed"}
|
||||
|
||||
|
||||
@router.post("/{post_id}/dislike", status_code=status.HTTP_201_CREATED)
|
||||
async def dislike_post(
|
||||
post_id: int,
|
||||
credentials: HTTPAuthorizationCredentials = Depends(security),
|
||||
conn: sqlite3.Connection = Depends(get_db),
|
||||
) -> dict:
|
||||
"""Dislike a post."""
|
||||
user_id = await get_current_user_id(credentials)
|
||||
cursor = conn.cursor()
|
||||
|
||||
# Check post exists
|
||||
cursor.execute("SELECT id FROM posts WHERE id = ?", (post_id,))
|
||||
if not cursor.fetchone():
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Post not found",
|
||||
)
|
||||
|
||||
# Remove any existing like first
|
||||
cursor.execute(
|
||||
"DELETE FROM likes WHERE post_id = ? AND user_id = ?",
|
||||
(post_id, user_id),
|
||||
)
|
||||
|
||||
# Add dislike
|
||||
try:
|
||||
cursor.execute(
|
||||
"INSERT INTO dislikes (post_id, user_id) VALUES (?, ?)",
|
||||
(post_id, user_id),
|
||||
)
|
||||
conn.commit()
|
||||
except sqlite3.IntegrityError:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="You already disliked this post",
|
||||
)
|
||||
|
||||
return {"message": "Post disliked"}
|
||||
|
||||
|
||||
@router.delete("/{post_id}/dislike", status_code=status.HTTP_200_OK)
|
||||
async def undislike_post(
|
||||
post_id: int,
|
||||
credentials: HTTPAuthorizationCredentials = Depends(security),
|
||||
conn: sqlite3.Connection = Depends(get_db),
|
||||
) -> dict:
|
||||
"""Remove dislike from a post."""
|
||||
user_id = await get_current_user_id(credentials)
|
||||
cursor = conn.cursor()
|
||||
|
||||
cursor.execute(
|
||||
"DELETE FROM dislikes WHERE post_id = ? AND user_id = ?",
|
||||
(post_id, user_id),
|
||||
)
|
||||
if cursor.rowcount == 0:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="You haven't disliked this post",
|
||||
)
|
||||
conn.commit()
|
||||
|
||||
return {"message": "Dislike removed"}
|
||||
|
||||
|
||||
@router.get("/{post_id}/comments", response_model=list[CommentResponse])
|
||||
async def get_post_comments(
|
||||
post_id: int,
|
||||
conn: sqlite3.Connection = Depends(get_db),
|
||||
) -> list[CommentResponse]:
|
||||
"""Get all comments for a post."""
|
||||
cursor = conn.cursor()
|
||||
|
||||
# Check post exists
|
||||
cursor.execute("SELECT id FROM posts WHERE id = ?", (post_id,))
|
||||
if not cursor.fetchone():
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Post not found",
|
||||
)
|
||||
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT
|
||||
c.id, c.post_id, c.user_id, u.username, c.content, c.created_at,
|
||||
(SELECT COUNT(*) FROM comment_likes WHERE comment_id = c.id) as likes_count
|
||||
FROM comments c
|
||||
JOIN users u ON c.user_id = u.id
|
||||
WHERE c.post_id = ?
|
||||
ORDER BY c.created_at ASC
|
||||
""",
|
||||
(post_id,),
|
||||
)
|
||||
|
||||
comments = []
|
||||
for row in cursor.fetchall():
|
||||
comment = row_to_dict(row)
|
||||
comments.append(CommentResponse(**comment))
|
||||
|
||||
return comments
|
||||
|
||||
|
||||
@router.post("/{post_id}/comments", response_model=CommentResponse, status_code=status.HTTP_201_CREATED)
|
||||
async def create_comment(
|
||||
post_id: int,
|
||||
comment_data: CommentCreate,
|
||||
credentials: HTTPAuthorizationCredentials = Depends(security),
|
||||
conn: sqlite3.Connection = Depends(get_db),
|
||||
) -> CommentResponse:
|
||||
"""Create a comment on a post."""
|
||||
user_id = await get_current_user_id(credentials)
|
||||
cursor = conn.cursor()
|
||||
|
||||
# Check post exists
|
||||
cursor.execute("SELECT id FROM posts WHERE id = ?", (post_id,))
|
||||
if not cursor.fetchone():
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Post not found",
|
||||
)
|
||||
|
||||
# Create comment
|
||||
cursor.execute(
|
||||
"INSERT INTO comments (post_id, user_id, content) VALUES (?, ?, ?)",
|
||||
(post_id, user_id, comment_data.content),
|
||||
)
|
||||
conn.commit()
|
||||
comment_id = cursor.lastrowid
|
||||
|
||||
# Get the created comment with user info
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT
|
||||
c.id, c.post_id, c.user_id, u.username, c.content, c.created_at,
|
||||
(SELECT COUNT(*) FROM comment_likes WHERE comment_id = c.id) as likes_count
|
||||
FROM comments c
|
||||
JOIN users u ON c.user_id = u.id
|
||||
WHERE c.id = ?
|
||||
""",
|
||||
(comment_id,),
|
||||
)
|
||||
row = cursor.fetchone()
|
||||
comment = row_to_dict(row)
|
||||
return CommentResponse(**comment)
|
||||
205
app/routes/users.py
Normal file
205
app/routes/users.py
Normal file
@@ -0,0 +1,205 @@
|
||||
"""User routes for SocialPhoto."""
|
||||
import sqlite3
|
||||
from typing import List
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
|
||||
|
||||
from app.auth import get_current_user_id
|
||||
from app.database import get_db, row_to_dict
|
||||
from app.models import PostResponse, UserResponse, UserStats
|
||||
|
||||
router = APIRouter(prefix="/users", tags=["Users"])
|
||||
security = HTTPBearer()
|
||||
|
||||
|
||||
def _get_user_with_stats(conn: sqlite3.Connection, user_id: int) -> dict:
|
||||
"""Get user data with stats."""
|
||||
cursor = conn.cursor()
|
||||
|
||||
# Get user
|
||||
cursor.execute("SELECT * FROM users WHERE id = ?", (user_id,))
|
||||
row = cursor.fetchone()
|
||||
if not row:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="User not found",
|
||||
)
|
||||
user = row_to_dict(row)
|
||||
|
||||
# Get posts count
|
||||
cursor.execute("SELECT COUNT(*) as count FROM posts WHERE user_id = ?", (user_id,))
|
||||
posts_count = cursor.fetchone()["count"]
|
||||
|
||||
# Get followers count
|
||||
cursor.execute(
|
||||
"SELECT COUNT(*) as count FROM follows WHERE following_id = ?",
|
||||
(user_id,),
|
||||
)
|
||||
followers_count = cursor.fetchone()["count"]
|
||||
|
||||
# Get following count
|
||||
cursor.execute(
|
||||
"SELECT COUNT(*) as count FROM follows WHERE follower_id = ?",
|
||||
(user_id,),
|
||||
)
|
||||
following_count = cursor.fetchone()["count"]
|
||||
|
||||
user["posts_count"] = posts_count
|
||||
user["followers_count"] = followers_count
|
||||
user["following_count"] = following_count
|
||||
|
||||
return user
|
||||
|
||||
|
||||
@router.get("/{user_id}", response_model=UserResponse)
|
||||
async def get_user(
|
||||
user_id: int,
|
||||
conn: sqlite3.Connection = Depends(get_db),
|
||||
) -> UserResponse:
|
||||
"""Get user profile."""
|
||||
user = _get_user_with_stats(conn, user_id)
|
||||
return UserResponse(**user)
|
||||
|
||||
|
||||
@router.get("/{user_id}/posts", response_model=List[PostResponse])
|
||||
async def get_user_posts(
|
||||
user_id: int,
|
||||
conn: sqlite3.Connection = Depends(get_db),
|
||||
) -> List[PostResponse]:
|
||||
"""Get all posts by a user."""
|
||||
cursor = conn.cursor()
|
||||
|
||||
# Check user exists
|
||||
cursor.execute("SELECT id FROM users WHERE id = ?", (user_id,))
|
||||
if not cursor.fetchone():
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="User not found",
|
||||
)
|
||||
|
||||
# Get posts with counts
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT
|
||||
p.id, p.user_id, u.username, p.image_path, p.caption, p.created_at,
|
||||
(SELECT COUNT(*) FROM likes WHERE post_id = p.id) as likes_count,
|
||||
(SELECT COUNT(*) FROM dislikes WHERE post_id = p.id) as dislikes_count,
|
||||
(SELECT COUNT(*) FROM comments WHERE post_id = p.id) as comments_count
|
||||
FROM posts p
|
||||
JOIN users u ON p.user_id = u.id
|
||||
WHERE p.user_id = ?
|
||||
ORDER BY p.created_at DESC
|
||||
""",
|
||||
(user_id,),
|
||||
)
|
||||
|
||||
posts = []
|
||||
for row in cursor.fetchall():
|
||||
post = row_to_dict(row)
|
||||
posts.append(
|
||||
PostResponse(
|
||||
id=post["id"],
|
||||
user_id=post["user_id"],
|
||||
username=post["username"],
|
||||
image_url=f"/uploads/{post['image_path'].split('/')[-1]}",
|
||||
caption=post["caption"],
|
||||
likes_count=post["likes_count"],
|
||||
dislikes_count=post["dislikes_count"],
|
||||
comments_count=post["comments_count"],
|
||||
created_at=post["created_at"],
|
||||
)
|
||||
)
|
||||
|
||||
return posts
|
||||
|
||||
|
||||
@router.get("/{user_id}/stats", response_model=UserStats)
|
||||
async def get_user_stats(
|
||||
user_id: int,
|
||||
conn: sqlite3.Connection = Depends(get_db),
|
||||
) -> UserStats:
|
||||
"""Get user statistics."""
|
||||
user = _get_user_with_stats(conn, user_id)
|
||||
return UserStats(
|
||||
posts_count=user["posts_count"],
|
||||
followers_count=user["followers_count"],
|
||||
following_count=user["following_count"],
|
||||
)
|
||||
|
||||
|
||||
@router.post("/{user_id}/follow", status_code=status.HTTP_201_CREATED)
|
||||
async def follow_user(
|
||||
user_id: int,
|
||||
credentials: HTTPAuthorizationCredentials = Depends(security),
|
||||
conn: sqlite3.Connection = Depends(get_db),
|
||||
) -> dict:
|
||||
"""Follow a user."""
|
||||
current_user_id = await get_current_user_id(credentials)
|
||||
cursor = conn.cursor()
|
||||
|
||||
# Check user exists
|
||||
cursor.execute("SELECT id FROM users WHERE id = ?", (user_id,))
|
||||
if not cursor.fetchone():
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="User not found",
|
||||
)
|
||||
|
||||
# Cannot follow yourself
|
||||
if current_user_id == user_id:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="You cannot follow yourself",
|
||||
)
|
||||
|
||||
# Check if already following
|
||||
cursor.execute(
|
||||
"SELECT id FROM follows WHERE follower_id = ? AND following_id = ?",
|
||||
(current_user_id, user_id),
|
||||
)
|
||||
if cursor.fetchone():
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="You already follow this user",
|
||||
)
|
||||
|
||||
# Create follow
|
||||
cursor.execute(
|
||||
"INSERT INTO follows (follower_id, following_id) VALUES (?, ?)",
|
||||
(current_user_id, user_id),
|
||||
)
|
||||
conn.commit()
|
||||
|
||||
return {"message": "Successfully followed user"}
|
||||
|
||||
|
||||
@router.delete("/{user_id}/follow", status_code=status.HTTP_200_OK)
|
||||
async def unfollow_user(
|
||||
user_id: int,
|
||||
credentials: HTTPAuthorizationCredentials = Depends(security),
|
||||
conn: sqlite3.Connection = Depends(get_db),
|
||||
) -> dict:
|
||||
"""Unfollow a user."""
|
||||
current_user_id = await get_current_user_id(credentials)
|
||||
cursor = conn.cursor()
|
||||
|
||||
# Check if following
|
||||
cursor.execute(
|
||||
"SELECT id FROM follows WHERE follower_id = ? AND following_id = ?",
|
||||
(current_user_id, user_id),
|
||||
)
|
||||
if not cursor.fetchone():
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="You are not following this user",
|
||||
)
|
||||
|
||||
# Delete follow
|
||||
cursor.execute(
|
||||
"DELETE FROM follows WHERE follower_id = ? AND following_id = ?",
|
||||
(current_user_id, user_id),
|
||||
)
|
||||
conn.commit()
|
||||
|
||||
return {"message": "Successfully unfollowed user"}
|
||||
Reference in New Issue
Block a user