Files
claudia-docs-api/app/services/auth.py
Motoko 204badb964 feat: Add role-based API tokens for Claudia Docs
- Add api_tokens table with role-based access (researcher, developer, viewer)
- Add POST /auth/token/generate endpoint for creating tokens
- Add GET /auth/tokens endpoint for listing user's tokens
- Add DELETE /auth/tokens/{token_id} endpoint for revoking tokens
- Add agent_type field to documents (research, development, general)
- Implement role-based access control for documents:
  - researcher: access to research and general documents
  - developer: access to development and general documents
  - viewer: read-only access
- Update document model and schemas with agent_type field
- Add comprehensive tests for API token functionality
- All existing tests pass (73 total)
2026-03-31 01:46:51 +00:00

317 lines
9.1 KiB
Python

import hashlib
import secrets
import uuid
from datetime import datetime, timedelta
import bcrypt
from jose import JWTError, jwt
from sqlalchemy import select, text, update
from sqlalchemy.ext.asyncio import AsyncSession
from app.config import settings
from app.models.agent import Agent
from app.models.refresh_token import JwtBlocklist, RefreshToken
def hash_password(password: str) -> str:
    """Hash a plaintext password with bcrypt using a freshly generated salt.

    Args:
        password: Plaintext password; UTF-8 encoded before hashing.

    Returns:
        The bcrypt hash (with its salt embedded), decoded back to a UTF-8
        string so it can be stored as text.
    """
    salt = bcrypt.gensalt()
    digest = bcrypt.hashpw(password.encode("utf-8"), salt)
    return digest.decode("utf-8")
def verify_password(plain: str, hashed: str) -> bool:
    """Check a plaintext password against a stored bcrypt hash.

    Args:
        plain: Candidate plaintext password.
        hashed: Stored hash, as produced by ``hash_password``.

    Returns:
        True when ``plain`` matches ``hashed``, otherwise False.
    """
    candidate = plain.encode("utf-8")
    stored = hashed.encode("utf-8")
    return bcrypt.checkpw(candidate, stored)
def create_access_token(agent_id: str, role: str) -> tuple[str, str]:
    """Create a signed JWT access token for an agent.

    The payload carries the agent id (``sub``), its ``role``, a random
    UUID4 ``jti`` (the id used for blocklisting), and naive-UTC ``iat`` /
    ``exp`` timestamps. The lifetime is
    ``settings.JWT_ACCESS_TOKEN_EXPIRE_MINUTES``.

    Returns:
        ``(token, jti)``: the encoded JWT and its unique identifier.
    """
    token_id = str(uuid.uuid4())
    issued_at = datetime.utcnow()
    lifetime = timedelta(minutes=settings.JWT_ACCESS_TOKEN_EXPIRE_MINUTES)
    claims = {
        "sub": agent_id,
        "role": role,
        "jti": token_id,
        "iat": issued_at,
        "exp": issued_at + lifetime,
    }
    encoded = jwt.encode(claims, settings.JWT_SECRET_KEY, algorithm=settings.JWT_ALGORITHM)
    return encoded, token_id
def create_refresh_token() -> str:
    """Return a new opaque refresh token.

    The token is 64 random bytes from :mod:`secrets`, URL-safe base64
    encoded without padding (86 characters). Callers persist only its hash.
    """
    entropy_bytes = 64
    return secrets.token_urlsafe(entropy_bytes)
def hash_token(token: str) -> str:
    """Return the lowercase hex SHA-256 digest of ``token`` (UTF-8 encoded).

    Raw refresh and API tokens are never stored; rows hold this digest and
    lookups compare digests.
    """
    hasher = hashlib.sha256()
    hasher.update(token.encode())
    return hasher.hexdigest()
def decode_token(token: str) -> dict | None:
    """Decode and verify a JWT signed with the configured secret.

    Only ``settings.JWT_ALGORITHM`` is accepted. Any ``JWTError`` raised by
    ``jose.jwt.decode`` maps to ``None``. That covers a bad signature or a
    malformed token, and with jose's default options also an expired
    ``exp``.

    Returns:
        The decoded claims dict, or ``None`` if the token is not valid.
    """
    allowed_algorithms = [settings.JWT_ALGORITHM]
    try:
        claims = jwt.decode(token, settings.JWT_SECRET_KEY, algorithms=allowed_algorithms)
    except JWTError:
        return None
    return claims
async def get_agent_by_username(db: AsyncSession, username: str) -> Agent | None:
    """Fetch the non-deleted agent with ``username``, or ``None``.

    Soft-deleted agents (``is_deleted`` true) are excluded. Because
    ``scalar_one_or_none`` is used, more than one matching live row raises
    ``MultipleResultsFound``.
    """
    stmt = select(Agent).where(
        Agent.username == username,
        Agent.is_deleted == False,  # noqa: E712 - builds a SQL comparison
    )
    rows = await db.execute(stmt)
    return rows.scalar_one_or_none()
async def get_agent_by_id(db: AsyncSession, agent_id: str) -> Agent | None:
    """Fetch the non-deleted agent whose primary key is ``agent_id``, or ``None``.

    Soft-deleted agents (``is_deleted`` true) are treated as absent.
    """
    live_agent = (
        select(Agent)
        .where(Agent.id == agent_id)
        .where(Agent.is_deleted == False)  # noqa: E712 - SQL expression
    )
    found = await db.execute(live_agent)
    return found.scalar_one_or_none()
async def create_agent(db: AsyncSession, username: str, password: str, role: str = "agent") -> Agent:
    """Create a new agent with a bcrypt-hashed password and stage it in ``db``.

    The agent gets a random UUID4 id. The row is added and flushed, so
    database constraint errors surface here, but it is not committed.

    Args:
        db: Active async session.
        username: Login name for the new agent.
        password: Plaintext password; only its hash is stored.
        role: Role string stored on the agent (defaults to ``"agent"``).

    Returns:
        The newly created, flushed ``Agent`` instance.
    """
    agent_id = str(uuid.uuid4())
    hashed = hash_password(password)
    new_agent = Agent(id=agent_id, username=username, password_hash=hashed, role=role)
    db.add(new_agent)
    await db.flush()
    return new_agent
async def save_refresh_token(
    db: AsyncSession,
    user_id: str,
    token: str,
    user_agent: str | None,
    ip_address: str | None,
) -> RefreshToken:
    """Persist a freshly issued refresh token as version 1 of a new token family.

    Only the SHA-256 hash of ``token`` is stored. Each call starts a new
    family (random UUID4); ``rotate_refresh_token`` later adds
    higher-numbered versions to the same family. Expiry is the current
    naive-UTC time plus ``settings.JWT_REFRESH_TOKEN_EXPIRE_DAYS``.

    The row is added and flushed but not committed.

    Returns:
        The new ``RefreshToken`` row.
    """
    lifetime = timedelta(days=settings.JWT_REFRESH_TOKEN_EXPIRE_DAYS)
    record = RefreshToken(
        id=str(uuid.uuid4()),
        user_id=user_id,
        token_hash=hash_token(token),
        token_family_id=str(uuid.uuid4()),
        token_version=1,
        user_agent=user_agent,
        ip_address=ip_address,
        expires_at=datetime.utcnow() + lifetime,
    )
    db.add(record)
    await db.flush()
    return record
async def rotate_refresh_token(
    db: AsyncSession,
    old_token: str,
    user_agent: str | None,
    ip_address: str | None,
) -> tuple[RefreshToken, str] | None:
    """
    Rotate a refresh token: revoke the presented one and issue its successor.

    Returns ``(new_rt, new_token)`` on success, or ``None`` when the token is
    invalid for any of these reasons:

    * unknown hash, or part of a globally logged-out family;
    * already revoked. If it was revoked because it had been rotated (a
      newer version exists in its family), this is treated as token replay.
      The whole family is then revoked (``is_global_logout=True``), so
      whoever holds the newest token is logged out too;
    * expired.

    The successor belongs to the same family with ``token_version + 1`` and a
    fresh expiry of ``settings.JWT_REFRESH_TOKEN_EXPIRE_DAYS``. All
    timestamps are naive UTC, consistent with the rest of this module.
    Changes are flushed, not committed.
    """
    now = datetime.utcnow()

    # Deliberately do NOT filter on revoked_at: a revoked token being presented
    # again is exactly the signal reuse detection needs.
    result = await db.execute(
        select(RefreshToken).where(
            RefreshToken.token_hash == hash_token(old_token),
            RefreshToken.is_global_logout == False,  # noqa: E712 - SQL expression
        )
    )
    old_rt: RefreshToken | None = result.scalar_one_or_none()
    if old_rt is None:
        return None

    if old_rt.revoked_at is not None:
        # Already used. A successor in the family means it was rotated, so a
        # second presentation is a replay: possible theft, revoke the family.
        # NOTE(review): two concurrent refreshes with the same token by a
        # legitimate client will also trip this path.
        successor = await db.execute(
            select(RefreshToken.id)
            .where(
                RefreshToken.token_family_id == old_rt.token_family_id,
                RefreshToken.token_version > old_rt.token_version,
            )
            .limit(1)
        )
        if successor.first() is not None:
            await db.execute(
                update(RefreshToken)
                .where(RefreshToken.token_family_id == old_rt.token_family_id)
                .values(is_global_logout=True, revoked_at=now)
            )
            # NOTE(review): this revocation only sticks if the caller commits
            # even though None is returned; confirm the session is not rolled
            # back on the resulting 401.
        return None

    if old_rt.expires_at < now:
        return None

    # Normal rotation: retire the presented token, issue the next version.
    old_rt.revoked_at = now
    new_token = create_refresh_token()
    new_rt = RefreshToken(
        id=str(uuid.uuid4()),
        user_id=old_rt.user_id,
        token_hash=hash_token(new_token),
        token_family_id=old_rt.token_family_id,
        token_version=old_rt.token_version + 1,
        user_agent=user_agent,
        ip_address=ip_address,
        expires_at=now + timedelta(days=settings.JWT_REFRESH_TOKEN_EXPIRE_DAYS),
    )
    db.add(new_rt)
    await db.flush()
    return new_rt, new_token
async def revoke_refresh_token(db: AsyncSession, token: str) -> bool:
    """Revoke a single live refresh token identified by its raw value.

    Returns:
        True if an unrevoked token with a matching hash was found and
        stamped with ``revoked_at`` (naive UTC now). False if no such token
        exists or it was already revoked. Changes are flushed, not committed.
    """
    lookup = select(RefreshToken).where(
        RefreshToken.token_hash == hash_token(token),
        RefreshToken.revoked_at.is_(None),
    )
    target: RefreshToken | None = (await db.execute(lookup)).scalar_one_or_none()
    if target is None:
        return False
    target.revoked_at = datetime.utcnow()
    await db.flush()
    return True
async def revoke_all_user_tokens(db: AsyncSession, user_id: str) -> bool:
    """Revoke every still-live refresh token of ``user_id`` (logout everywhere).

    Matching rows get ``revoked_at`` set to naive-UTC now and are flagged
    ``is_global_logout``. Rows that were already revoked are left untouched.

    Returns:
        Always True.
    """
    bulk_revoke = (
        update(RefreshToken)
        .where(RefreshToken.user_id == user_id)
        .where(RefreshToken.revoked_at.is_(None))
        .values(revoked_at=datetime.utcnow(), is_global_logout=True)
    )
    await db.execute(bulk_revoke)
    await db.flush()
    return True
async def add_to_blocklist(db: AsyncSession, jti: str, expires_at: datetime) -> None:
    """Record an access-token ``jti`` as blocked until ``expires_at``.

    The entry is added and flushed but not committed.
    ``cleanup_expired_blocklist`` purges entries whose ``expires_at`` has
    passed.
    """
    db.add(JwtBlocklist(token_id=jti, expires_at=expires_at))
    await db.flush()
async def is_token_blocklisted(db: AsyncSession, jti: str) -> bool:
    """Return True if an access-token ``jti`` has a blocklist entry.

    Entry expiry is not checked here; any stored row counts.
    """
    query = select(JwtBlocklist).where(JwtBlocklist.token_id == jti)
    hit = (await db.execute(query)).scalar_one_or_none()
    return hit is not None
async def cleanup_expired_blocklist(db: AsyncSession) -> None:
    """Remove blocklist entries whose ``expires_at`` is in the past.

    Uses an ORM ``DELETE`` rather than raw SQL, which has two benefits. It
    works on any SQLAlchemy dialect, whereas ``datetime('now')`` is
    SQLite-only. It also follows the ``JwtBlocklist`` model's table name
    instead of hard-coding it.

    The cutoff is naive-UTC now, matching the UTC comparison the previous
    SQLite ``datetime('now')`` performed and the naive-UTC timestamps used
    throughout this module.

    NOTE(review): presumably ``expires_at`` mirrors the blocked token's own
    ``exp``, so a purged entry could no longer be used anyway. Confirm at the
    ``add_to_blocklist`` call sites.

    Changes are flushed, not committed.
    """
    from sqlalchemy import delete  # local import keeps this edit self-contained

    await db.execute(
        delete(JwtBlocklist).where(JwtBlocklist.expires_at < datetime.utcnow())
    )
    await db.flush()
# =============================================================================
# API Token Management (Role-based)
# =============================================================================
def generate_api_token() -> str:
    """Return a new random API token.

    The token is 48 bytes from :mod:`secrets`, URL-safe base64 encoded, which
    gives exactly 64 characters with no padding.
    """
    entropy_bytes = 48
    return secrets.token_urlsafe(entropy_bytes)
async def create_api_token(
    db: AsyncSession,
    name: str,
    role: str,
    agent_id: str,
) -> tuple[str, "ApiToken"]:
    """Issue a new API token owned by ``agent_id``.

    A random token is generated, and only its SHA-256 hash is stored on a new
    ``ApiToken`` row with a UUID4 id. ``role`` is stored as given and is not
    validated here. The row is added and flushed but not committed.

    Returns:
        ``(raw_token, token_record)``. The raw token is only available from
        this return value; the database holds just its hash.
    """
    # Function-level import kept from the original (presumably to avoid an
    # import cycle with the models package; confirm).
    from app.models.api_token import ApiToken

    secret = generate_api_token()
    record = ApiToken(
        id=str(uuid.uuid4()),
        name=name,
        token_hash=hash_token(secret),
        role=role,
        agent_id=agent_id,
    )
    db.add(record)
    await db.flush()
    return secret, record
async def list_api_tokens(db: AsyncSession, agent_id: str) -> list["ApiToken"]:
    """Return all API token records owned by ``agent_id``, newest first.

    Records hold only the token hash, never the raw token. The list is
    ordered by ``created_at`` descending.
    """
    from app.models.api_token import ApiToken

    newest_first = ApiToken.created_at.desc()
    query = select(ApiToken).where(ApiToken.agent_id == agent_id).order_by(newest_first)
    rows = await db.execute(query)
    return list(rows.scalars())
async def revoke_api_token(db: AsyncSession, token_id: str, agent_id: str) -> bool:
    """Revoke an API token by hard-deleting its row.

    The token must belong to ``agent_id``. A token owned by another agent is
    treated the same as a missing one.

    Returns:
        True if a matching token was found and deleted (flushed, not
        committed), otherwise False.
    """
    from app.models.api_token import ApiToken

    owned_token = select(ApiToken).where(
        ApiToken.id == token_id,
        ApiToken.agent_id == agent_id,
    )
    found = (await db.execute(owned_token)).scalar_one_or_none()
    if found is None:
        return False
    await db.delete(found)
    await db.flush()
    return True
async def verify_api_token(db: AsyncSession, raw_token: str) -> tuple[str, str] | None:
    """
    Resolve a raw API token to its owner and role.

    The token is hashed and looked up by hash. On a match, ``last_used_at``
    is set to naive-UTC now and flushed (not committed).

    NOTE(review): this does not check that the owning agent still exists or
    is not soft-deleted, and it applies no expiry. Confirm the caller handles
    this if required.

    Returns:
        ``(agent_id, role)`` for a known token, otherwise ``None``.
    """
    from app.models.api_token import ApiToken

    lookup = select(ApiToken).where(ApiToken.token_hash == hash_token(raw_token))
    record: ApiToken | None = (await db.execute(lookup)).scalar_one_or_none()
    if record is None:
        return None
    record.last_used_at = datetime.utcnow()
    await db.flush()
    return record.agent_id, record.role