Files
claudia-docs-api/app/services/auth.py
Motoko 7f3e8a8f53 Phase 1 MVP - Complete implementation
- Auth: register, login, JWT with refresh tokens, blocklist
- Projects/Folders/Documents CRUD with soft deletes
- Tags CRUD and assignment
- FTS5 search with highlights and tag filtering
- ADR-001, ADR-002, ADR-003 compliant
- Security fixes applied (JWT_SECRET_KEY, exception handler, cookie secure)
- 25 tests passing
2026-03-30 15:17:27 +00:00

234 lines
6.8 KiB
Python

import hashlib
import secrets
import uuid
from datetime import datetime, timedelta

import bcrypt
from jose import JWTError, jwt
from sqlalchemy import delete, select, text, update
from sqlalchemy.ext.asyncio import AsyncSession

from app.config import settings
from app.models.agent import Agent
from app.models.refresh_token import JwtBlocklist, RefreshToken
def hash_password(password: str) -> str:
    """Hash *password* with bcrypt using a freshly generated salt.

    The password is UTF-8 encoded before hashing; the bcrypt hash bytes are
    decoded back to a UTF-8 string for storage.
    """
    password_bytes = password.encode("utf-8")
    salt = bcrypt.gensalt()
    return bcrypt.hashpw(password_bytes, salt).decode("utf-8")
def verify_password(plain: str, hashed: str) -> bool:
    """Check a plaintext password against a stored bcrypt hash.

    Args:
        plain: The candidate password as supplied by the user.
        hashed: The stored bcrypt hash string (as produced by ``hash_password``).

    Returns:
        ``True`` if the password matches, ``False`` otherwise. This includes
        the case where bcrypt rejects the input with ``ValueError`` (for
        example a malformed or corrupt stored hash, "Invalid salt"), so a bad
        row produces a failed login instead of an unhandled server error.
    """
    try:
        return bcrypt.checkpw(plain.encode("utf-8"), hashed.encode("utf-8"))
    except ValueError:
        # NOTE(review): recent bcrypt releases may also raise ValueError for
        # passwords longer than 72 bytes; those are treated as a mismatch too.
        return False
def create_access_token(agent_id: str, role: str) -> tuple[str, str]:
    """Create a signed JWT access token for an agent.

    The payload carries ``sub`` (the agent id), ``role``, a random UUID4
    ``jti``, ``iat`` (issue time) and ``exp`` (issue time plus
    ``settings.JWT_ACCESS_TOKEN_EXPIRE_MINUTES``). Times are naive values
    from ``datetime.utcnow()``. The token is signed with
    ``settings.JWT_SECRET_KEY`` using ``settings.JWT_ALGORITHM``.

    Returns:
        ``(token, jti)``: the encoded token and its id (the id is what
        ``add_to_blocklist`` / ``is_token_blocklisted`` key on).
    """
    token_id = str(uuid.uuid4())
    issued_at = datetime.utcnow()
    lifetime = timedelta(minutes=settings.JWT_ACCESS_TOKEN_EXPIRE_MINUTES)
    claims = dict(
        sub=agent_id,
        role=role,
        jti=token_id,
        iat=issued_at,
        exp=issued_at + lifetime,
    )
    encoded = jwt.encode(claims, settings.JWT_SECRET_KEY, algorithm=settings.JWT_ALGORITHM)
    return encoded, token_id
def create_refresh_token() -> str:
    """Return a new opaque refresh token.

    The token is 64 random bytes from :mod:`secrets`, encoded as URL-safe
    base64. Callers persist only its SHA-256 hash (see ``hash_token``).
    """
    random_byte_count = 64
    return secrets.token_urlsafe(random_byte_count)
def hash_token(token: str) -> str:
    """Return the hex-encoded SHA-256 digest of *token* (UTF-8 encoded).

    Refresh tokens are stored and looked up by this digest rather than in
    plain form.
    """
    digest = hashlib.sha256()
    digest.update(token.encode("utf-8"))
    return digest.hexdigest()
def decode_token(token: str) -> dict | None:
    """Verify and decode a JWT signed with the configured key and algorithm.

    Only ``settings.JWT_ALGORITHM`` is accepted. Returns the claims dict, or
    ``None`` whenever python-jose raises ``JWTError`` (bad signature,
    malformed token, expired ``exp``, ...).
    """
    allowed_algorithms = [settings.JWT_ALGORITHM]
    try:
        claims = jwt.decode(token, settings.JWT_SECRET_KEY, algorithms=allowed_algorithms)
    except JWTError:
        return None
    return claims
async def get_agent_by_username(db: AsyncSession, username: str) -> Agent | None:
    """Fetch the non-soft-deleted agent with this *username*.

    Returns ``None`` when no live agent matches. Uses ``scalar_one_or_none``,
    so more than one matching live row raises.
    """
    stmt = (
        select(Agent)
        .where(Agent.username == username)
        .where(Agent.is_deleted == False)  # noqa: E712 - SQL comparison, not a Python bool test
    )
    rows = await db.execute(stmt)
    return rows.scalar_one_or_none()
async def get_agent_by_id(db: AsyncSession, agent_id: str) -> Agent | None:
    """Fetch the agent whose ``id`` is *agent_id*, skipping soft-deleted rows.

    Returns ``None`` when no live agent matches.
    """
    not_deleted = Agent.is_deleted == False  # noqa: E712 - builds a SQL expression
    stmt = select(Agent).where(Agent.id == agent_id, not_deleted)
    return (await db.execute(stmt)).scalar_one_or_none()
async def create_agent(db: AsyncSession, username: str, password: str, role: str = "agent") -> Agent:
    """Create a new agent with a bcrypt-hashed password and flush it.

    The agent gets a UUID4 string id. The row is added and flushed to the
    session but not committed here; the caller owns the transaction.
    """
    agent_id = str(uuid.uuid4())
    hashed = hash_password(password)
    new_agent = Agent(id=agent_id, username=username, password_hash=hashed, role=role)
    db.add(new_agent)
    await db.flush()
    return new_agent
async def save_refresh_token(
    db: AsyncSession,
    user_id: str,
    token: str,
    user_agent: str | None,
    ip_address: str | None,
) -> RefreshToken:
    """Persist *token* as version 1 of a brand-new token family.

    Only the SHA-256 hash of *token* is stored. The record expires
    ``settings.JWT_REFRESH_TOKEN_EXPIRE_DAYS`` days after now (naive
    ``datetime.utcnow()``). The row is flushed, not committed.
    """
    lifetime = timedelta(days=settings.JWT_REFRESH_TOKEN_EXPIRE_DAYS)
    record = RefreshToken(
        id=str(uuid.uuid4()),
        user_id=user_id,
        token_hash=hash_token(token),
        token_family_id=str(uuid.uuid4()),
        token_version=1,
        user_agent=user_agent,
        ip_address=ip_address,
        expires_at=datetime.utcnow() + lifetime,
    )
    db.add(record)
    await db.flush()
    return record
async def _revoke_token_family(db: AsyncSession, family_id: str, revoked_at: datetime) -> None:
    """Revoke every still-active refresh token in the family *family_id*."""
    await db.execute(
        update(RefreshToken)
        .where(
            RefreshToken.token_family_id == family_id,
            RefreshToken.revoked_at.is_(None),
        )
        .values(is_global_logout=True, revoked_at=revoked_at)
    )
    await db.flush()


async def rotate_refresh_token(
    db: AsyncSession,
    old_token: str,
    user_agent: str | None,
    ip_address: str | None,
) -> tuple[RefreshToken, str] | None:
    """
    Rotate a refresh token: revoke old, create new.

    The presented token is looked up by its SHA-256 hash. Outcomes:

    * unknown hash -> ``None``;
    * token already revoked (rotated earlier, logged out, or family-revoked):
      treated as reuse. Every still-active token in its family is revoked and
      ``None`` is returned;
    * token expired -> ``None``;
    * token active but a newer version already exists in its family: also
      treated as reuse (family revoked, ``None``);
    * otherwise the old token is revoked and a new token with
      ``token_version + 1`` is created in the same family.

    NOTE(review): family revocation is only flushed here. It persists only if
    the caller commits the session even when this returns ``None``; confirm
    the refresh endpoint does not roll back on that path.

    Returns (new_rt, new_token) or None if invalid.
    """
    now = datetime.utcnow()

    # Deliberately do NOT filter on revoked_at here: a revoked token being
    # presented again is exactly the reuse signal rotation exists to catch.
    result = await db.execute(
        select(RefreshToken).where(RefreshToken.token_hash == hash_token(old_token))
    )
    old_rt: RefreshToken | None = result.scalar_one_or_none()
    if old_rt is None:
        return None

    family_id = old_rt.token_family_id

    if old_rt.revoked_at is not None or old_rt.is_global_logout:
        # Possible theft: a token that was already used or revoked is being
        # replayed. Revoke the whole family so no copy can keep refreshing.
        await _revoke_token_family(db, family_id, now)
        return None

    if old_rt.expires_at < now:
        return None

    # Defense in depth: an active token should be the newest in its family.
    # A newer sibling means it was rotated without being revoked (presumably
    # only via a concurrent rotation that committed first), so treat it as
    # reuse too. limit(1) + first() avoids MultipleResultsFound.
    newer = await db.execute(
        select(RefreshToken.id)
        .where(
            RefreshToken.token_family_id == family_id,
            RefreshToken.token_version > old_rt.token_version,
        )
        .limit(1)
    )
    if newer.first() is not None:
        await _revoke_token_family(db, family_id, now)
        return None

    # Revoke the presented token and issue its successor in the same family.
    old_rt.revoked_at = now
    new_token = create_refresh_token()
    new_rt = RefreshToken(
        id=str(uuid.uuid4()),
        user_id=old_rt.user_id,
        token_hash=hash_token(new_token),
        token_family_id=family_id,
        token_version=old_rt.token_version + 1,
        user_agent=user_agent,
        ip_address=ip_address,
        expires_at=now + timedelta(days=settings.JWT_REFRESH_TOKEN_EXPIRE_DAYS),
    )
    db.add(new_rt)
    await db.flush()
    return new_rt, new_token
async def revoke_refresh_token(db: AsyncSession, token: str) -> bool:
    """Mark a single, currently-unrevoked refresh token as revoked.

    Returns ``True`` if a row with this token's hash and no ``revoked_at``
    was found and revoked (flushed, not committed). Returns ``False`` if the
    token is unknown or already revoked.
    """
    lookup = select(RefreshToken).where(
        RefreshToken.token_hash == hash_token(token),
        RefreshToken.revoked_at.is_(None),
    )
    match: RefreshToken | None = (await db.execute(lookup)).scalar_one_or_none()
    if match:
        match.revoked_at = datetime.utcnow()
        await db.flush()
        return True
    return False
async def revoke_all_user_tokens(db: AsyncSession, user_id: str) -> bool:
    """Revoke every still-active refresh token of *user_id* (logout-all).

    Each unrevoked token gets ``revoked_at`` set to now and
    ``is_global_logout`` set to ``True``; the change is flushed, not
    committed. Always returns ``True``, whether or not any rows matched.
    """
    revoked_now = datetime.utcnow()
    stmt = (
        update(RefreshToken)
        .where(RefreshToken.user_id == user_id)
        .where(RefreshToken.revoked_at.is_(None))
        .values(is_global_logout=True, revoked_at=revoked_now)
    )
    await db.execute(stmt)
    await db.flush()
    return True
async def add_to_blocklist(db: AsyncSession, jti: str, expires_at: datetime) -> None:
    """Store token id *jti* in the JWT blocklist together with *expires_at*.

    Rows whose ``expires_at`` has passed are later removed by
    ``cleanup_expired_blocklist``. Flushes; does not commit.
    """
    db.add(JwtBlocklist(token_id=jti, expires_at=expires_at))
    await db.flush()
async def is_token_blocklisted(db: AsyncSession, jti: str) -> bool:
    """Return ``True`` if a blocklist row exists for token id *jti*.

    ``expires_at`` is not consulted here; an expired row still counts until
    it is removed.
    """
    query = select(JwtBlocklist).where(JwtBlocklist.token_id == jti)
    found = (await db.execute(query)).scalar_one_or_none()
    return found is not None
async def cleanup_expired_blocklist(db: AsyncSession) -> None:
    """Delete JWT blocklist entries whose ``expires_at`` is in the past.

    Issues an ORM ``DELETE`` against ``JwtBlocklist`` rather than raw,
    SQLite-specific SQL. The table name therefore comes from the model, and
    the cutoff uses the same naive ``datetime.utcnow()`` clock as the rest of
    this module. Flushes; does not commit.
    """
    cutoff = datetime.utcnow()
    stmt = (
        delete(JwtBlocklist)
        .where(JwtBlocklist.expires_at < cutoff)
        # Match the previous raw-SQL behaviour: a plain bulk DELETE with no
        # attempt to reconcile blocklist objects already loaded in the session.
        .execution_options(synchronize_session=False)
    )
    await db.execute(stmt)
    await db.flush()