feat: Add role-based API tokens for Claudia Docs

- Add api_tokens table with role-based access (researcher, developer, viewer)
- Add POST /auth/token/generate endpoint for creating tokens
- Add GET /auth/tokens endpoint for listing user's tokens
- Add DELETE /auth/tokens/{token_id} endpoint for revoking tokens
- Add agent_type field to documents (research, development, general)
- Implement role-based access control for documents:
  - researcher: access to research and general documents
  - developer: access to development and general documents
  - viewer: read-only access
- Update document model and schemas with agent_type field
- Add comprehensive tests for API token functionality
- All existing tests pass (73 total)
This commit is contained in:
Motoko
2026-03-31 01:46:51 +00:00
parent 5beac2d673
commit 204badb964
10 changed files with 770 additions and 97 deletions

View File

@@ -0,0 +1,54 @@
"""Add api_tokens table and agent_type column
Revision ID: 002
Revises: 001
Create Date: 2026-03-31
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
revision: str = '002'
down_revision: Union[str, None] = '001'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# Add agent_type column to documents table
op.add_column('documents', sa.Column('agent_type', sa.String(20), sa.TEXT(), nullable=True))
# Set default value for existing documents
op.execute("UPDATE documents SET agent_type = 'general' WHERE agent_type IS NULL")
# Make the column NOT NULL after setting defaults
op.alter_column('documents', 'agent_type', nullable=False)
# Add CHECK constraint for agent_type
op.execute("""
CREATE INDEX IF NOT EXISTS idx_documents_agent_type ON documents(agent_type)
""")
# Create api_tokens table
op.execute("""
CREATE TABLE IF NOT EXISTS api_tokens (
id TEXT PRIMARY KEY,
name TEXT NOT NULL,
token_hash TEXT NOT NULL,
role TEXT NOT NULL CHECK (role IN ('researcher', 'developer', 'viewer')),
agent_id TEXT NOT NULL REFERENCES agents(id) ON DELETE CASCADE,
created_at TIMESTAMP NOT NULL DEFAULT (datetime('now')),
last_used_at TIMESTAMP NULL
)
""")
# Create index for api_tokens
op.execute("CREATE INDEX IF NOT EXISTS idx_api_tokens_agent ON api_tokens(agent_id)")
def downgrade() -> None:
op.drop_table('api_tokens')
op.drop_column('documents', 'agent_type')

View File

@@ -192,7 +192,8 @@ def _create_schema(sync_conn):
model_source TEXT, model_source TEXT,
tiptap_content TEXT, tiptap_content TEXT,
outgoing_links TEXT DEFAULT '[]', outgoing_links TEXT DEFAULT '[]',
backlinks_count INTEGER NOT NULL DEFAULT 0 backlinks_count INTEGER NOT NULL DEFAULT 0,
agent_type TEXT DEFAULT 'general' CHECK (agent_type IN ('research', 'development', 'general'))
) )
""")) """))
@@ -244,6 +245,19 @@ def _create_schema(sync_conn):
) )
""")) """))
# API tokens table (role-based)
sync_conn.execute(text("""
CREATE TABLE IF NOT EXISTS api_tokens (
id TEXT PRIMARY KEY,
name TEXT NOT NULL,
token_hash TEXT NOT NULL,
role TEXT NOT NULL CHECK (role IN ('researcher', 'developer', 'viewer')),
agent_id TEXT NOT NULL REFERENCES agents(id) ON DELETE CASCADE,
created_at TIMESTAMP NOT NULL DEFAULT (datetime('now')),
last_used_at TIMESTAMP NULL
)
"""))
# Indexes # Indexes
sync_conn.execute(text("CREATE INDEX IF NOT EXISTS idx_projects_agent ON projects(agent_id)")) sync_conn.execute(text("CREATE INDEX IF NOT EXISTS idx_projects_agent ON projects(agent_id)"))
sync_conn.execute(text("CREATE INDEX IF NOT EXISTS idx_folders_project ON folders(project_id)")) sync_conn.execute(text("CREATE INDEX IF NOT EXISTS idx_folders_project ON folders(project_id)"))
@@ -254,6 +268,7 @@ def _create_schema(sync_conn):
sync_conn.execute(text("CREATE INDEX IF NOT EXISTS idx_document_tags_tag ON document_tags(tag_id)")) sync_conn.execute(text("CREATE INDEX IF NOT EXISTS idx_document_tags_tag ON document_tags(tag_id)"))
sync_conn.execute(text("CREATE INDEX IF NOT EXISTS idx_refresh_tokens_hash ON refresh_tokens(token_hash)")) sync_conn.execute(text("CREATE INDEX IF NOT EXISTS idx_refresh_tokens_hash ON refresh_tokens(token_hash)"))
sync_conn.execute(text("CREATE INDEX IF NOT EXISTS idx_refresh_tokens_user_family ON refresh_tokens(user_id, token_family_id)")) sync_conn.execute(text("CREATE INDEX IF NOT EXISTS idx_refresh_tokens_user_family ON refresh_tokens(user_id, token_family_id)"))
sync_conn.execute(text("CREATE INDEX IF NOT EXISTS idx_api_tokens_agent ON api_tokens(agent_id)"))
# --- FTS5 virtual table --- # --- FTS5 virtual table ---
sync_conn.execute(text(""" sync_conn.execute(text("""

23
app/models/api_token.py Normal file
View File

@@ -0,0 +1,23 @@
import uuid
from datetime import datetime
from sqlalchemy import DateTime, ForeignKey, String, Text
from sqlalchemy.orm import Mapped, mapped_column
from app.database import Base
def generate_uuid() -> str:
return str(uuid.uuid4())
class ApiToken(Base):
__tablename__ = "api_tokens"
id: Mapped[str] = mapped_column(String(36), primary_key=True, default=generate_uuid)
name: Mapped[str] = mapped_column(String(255), nullable=False)
token_hash: Mapped[str] = mapped_column(Text, nullable=False) # SHA-256 hash of actual token
role: Mapped[str] = mapped_column(String(20), nullable=False) # researcher, developer, viewer
agent_id: Mapped[str] = mapped_column(String(36), ForeignKey("agents.id", ondelete="CASCADE"), nullable=False)
created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow, nullable=False)
last_used_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)

View File

@@ -42,3 +42,5 @@ class Document(Base):
# Phase 3: Link tracking # Phase 3: Link tracking
outgoing_links: Mapped[str] = mapped_column(Text, nullable=False, default="[]") # JSON array of document IDs outgoing_links: Mapped[str] = mapped_column(Text, nullable=False, default="[]") # JSON array of document IDs
backlinks_count: Mapped[int] = mapped_column(default=0, nullable=False) # Cached count of incoming links backlinks_count: Mapped[int] = mapped_column(default=0, nullable=False) # Cached count of incoming links
# Role-based access
agent_type: Mapped[str] = mapped_column(String(20), nullable=False, default="general") # research, development, general

View File

@@ -3,9 +3,19 @@ from datetime import datetime
from fastapi import APIRouter, Depends, HTTPException, Request, Response from fastapi import APIRouter, Depends, HTTPException, Request, Response
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from app.config import settings
from app.database import get_db from app.database import get_db
from app.models.agent import Agent from app.models.agent import Agent
from app.schemas.auth import AgentCreate, AgentLogin, AgentResponse, RefreshResponse, TokenResponse from app.schemas.auth import (
AgentCreate,
AgentLogin,
AgentResponse,
ApiTokenCreate,
ApiTokenGenerateResponse,
ApiTokenResponse,
RefreshResponse,
TokenResponse,
)
from app.services import auth as auth_service from app.services import auth as auth_service
router = APIRouter(prefix="/api/v1/auth", tags=["auth"]) router = APIRouter(prefix="/api/v1/auth", tags=["auth"])
@@ -40,7 +50,7 @@ def _clear_refresh_cookie(response: Response):
async def get_current_agent(request: Request, db: AsyncSession) -> Agent: async def get_current_agent(request: Request, db: AsyncSession) -> Agent:
"""Get the current authenticated agent from request.""" """Get the current authenticated agent from request (JWT only)."""
auth_header = request.headers.get("authorization", "") auth_header = request.headers.get("authorization", "")
if not auth_header.startswith("Bearer "): if not auth_header.startswith("Bearer "):
raise HTTPException(status_code=401, detail="Not authenticated") raise HTTPException(status_code=401, detail="Not authenticated")
@@ -67,6 +77,46 @@ async def get_current_agent(request: Request, db: AsyncSession) -> Agent:
return agent return agent
async def get_current_agent_or_api_token(request: Request, db: AsyncSession) -> tuple[Agent | None, str | None]:
"""
Get the current authenticated agent or validate an API token.
Returns (agent, api_role) where agent is None for API tokens.
For JWT tokens: (agent, None)
For API tokens: (None, role)
Raises HTTPException if neither is valid.
"""
auth_header = request.headers.get("authorization", "")
if not auth_header.startswith("Bearer "):
raise HTTPException(status_code=401, detail="Not authenticated")
token = auth_header[7:]
# First try JWT
payload = auth_service.decode_token(token)
if payload:
jti = payload.get("jti")
if jti:
is_blocked = await auth_service.is_token_blocklisted(db, jti)
if is_blocked:
raise HTTPException(status_code=401, detail="Token revoked")
agent_id = payload.get("sub")
if agent_id:
agent = await auth_service.get_agent_by_id(db, agent_id)
if agent:
return agent, None
# Try API token
result = await auth_service.verify_api_token(db, token)
if result:
agent_id, role = result
agent = await auth_service.get_agent_by_id(db, agent_id)
if agent:
return agent, role
raise HTTPException(status_code=401, detail="Invalid token")
@router.post("/register", response_model=AgentResponse, status_code=201) @router.post("/register", response_model=AgentResponse, status_code=201)
async def register(payload: AgentCreate, db: AsyncSession = Depends(get_db)): async def register(payload: AgentCreate, db: AsyncSession = Depends(get_db)):
if settings.DISABLE_REGISTRATION: if settings.DISABLE_REGISTRATION:
@@ -168,3 +218,54 @@ async def logout_all(request: Request, response: Response, db: AsyncSession = De
_clear_refresh_cookie(response) _clear_refresh_cookie(response)
return Response(status_code=204) return Response(status_code=204)
# =============================================================================
# Role-based API Token Management
# =============================================================================
@router.post("/token/generate", response_model=ApiTokenGenerateResponse, status_code=201)
async def generate_token(
request: Request,
payload: ApiTokenCreate,
db: AsyncSession = Depends(get_db),
):
"""Generate a new API token with a specific role."""
# Only admin agents can create API tokens
agent = await get_current_agent(request, db)
if agent.role != "admin":
raise HTTPException(status_code=403, detail="Only admin agents can create API tokens")
raw_token, token_record = await auth_service.create_api_token(
db, payload.name, payload.role, agent.id
)
return ApiTokenGenerateResponse(
token=raw_token,
name=token_record.name,
role=token_record.role,
)
@router.get("/tokens", response_model=list[ApiTokenResponse])
async def list_tokens(
request: Request,
db: AsyncSession = Depends(get_db),
):
"""List all API tokens for the current agent."""
agent = await get_current_agent(request, db)
tokens = await auth_service.list_api_tokens(db, agent.id)
return [ApiTokenResponse.model_validate(t) for t in tokens]
@router.delete("/tokens/{token_id}", status_code=204)
async def revoke_token(
request: Request,
token_id: str,
db: AsyncSession = Depends(get_db),
):
"""Revoke an API token."""
agent = await get_current_agent(request, db)
success = await auth_service.revoke_api_token(db, token_id, agent.id)
if not success:
raise HTTPException(status_code=404, detail="Token not found")
return None

View File

@@ -11,7 +11,7 @@ from app.models.document import Document, ReasoningType
from app.models.folder import Folder from app.models.folder import Folder
from app.models.project import Project from app.models.project import Project
from app.models.tag import DocumentTag, Tag from app.models.tag import DocumentTag, Tag
from app.routers.auth import get_current_agent from app.routers.auth import get_current_agent, get_current_agent_or_api_token
from app.schemas.document import ( from app.schemas.document import (
DocumentBriefResponse, DocumentBriefResponse,
DocumentContentUpdate, DocumentContentUpdate,
@@ -178,14 +178,60 @@ async def document_to_response(db: AsyncSession, doc: Document) -> DocumentRespo
) )
def _can_access_document(api_role: str | None, doc_agent_type: str | None, require_write: bool = False) -> bool:
"""
Check if a role can access a document based on agent_type.
Rules:
- JWT tokens (api_role is None) have full access via project ownership check
- researcher: can access 'research' and 'general' documents
- developer: can access 'development' and 'general' documents
- viewer: can only read (handled elsewhere), not create/modify
- admin: full access (but admin is a JWT role, not API token role)
For write operations, viewer is denied.
"""
if api_role is None:
# JWT token - access is controlled by project ownership
return True
# API token role-based access
doc_type = doc_agent_type or 'general'
if require_write:
# Viewers cannot create/update/delete
if api_role == "viewer":
return False
# Researchers can only write to research and general
if api_role == "researcher":
return doc_type in ("research", "general")
# Developers can only write to development and general
if api_role == "developer":
return doc_type in ("development", "general")
else:
# Read access - viewers can read research, development, and general
if api_role == "viewer":
return doc_type in ("research", "development", "general")
# Researchers can only read research and general
if api_role == "researcher":
return doc_type in ("research", "general")
# Developers can only read development and general
if api_role == "developer":
return doc_type in ("development", "general")
return False
@router.get("/api/v1/projects/{project_id}/documents", response_model=DocumentListResponse) @router.get("/api/v1/projects/{project_id}/documents", response_model=DocumentListResponse)
async def list_documents( async def list_documents(
request: Request, request: Request,
project_id: str, project_id: str,
db: AsyncSession = Depends(get_db), db: AsyncSession = Depends(get_db),
): ):
agent = await get_current_agent(request, db) agent, api_role = await get_current_agent_or_api_token(request, db)
# JWT tokens check project ownership
if api_role is None:
proj_result = await db.execute( proj_result = await db.execute(
select(Project).where( select(Project).where(
Project.id == project_id, Project.id == project_id,
@@ -195,6 +241,10 @@ async def list_documents(
) )
if not proj_result.scalar_one_or_none(): if not proj_result.scalar_one_or_none():
raise HTTPException(status_code=404, detail="Project not found") raise HTTPException(status_code=404, detail="Project not found")
else:
# API tokens don't have project-level access control here
# Access is controlled at document level via agent_type
pass
result = await db.execute( result = await db.execute(
select(Document).where( select(Document).where(
@@ -206,6 +256,9 @@ async def list_documents(
responses = [] responses = []
for doc in docs: for doc in docs:
# Apply role-based filtering for API tokens
if api_role is not None and not _can_access_document(api_role, doc.agent_type, require_write=False):
continue
tags = await get_document_tags(db, doc.id) tags = await get_document_tags(db, doc.id)
responses.append(DocumentBriefResponse( responses.append(DocumentBriefResponse(
id=doc.id, id=doc.id,
@@ -228,8 +281,10 @@ async def create_document(
payload: DocumentCreate, payload: DocumentCreate,
db: AsyncSession = Depends(get_db), db: AsyncSession = Depends(get_db),
): ):
agent = await get_current_agent(request, db) agent, api_role = await get_current_agent_or_api_token(request, db)
# JWT tokens check project ownership
if api_role is None:
proj_result = await db.execute( proj_result = await db.execute(
select(Project).where( select(Project).where(
Project.id == project_id, Project.id == project_id,
@@ -240,6 +295,15 @@ async def create_document(
if not proj_result.scalar_one_or_none(): if not proj_result.scalar_one_or_none():
raise HTTPException(status_code=404, detail="Project not found") raise HTTPException(status_code=404, detail="Project not found")
# Determine agent_type for the document
doc_agent_type = payload.agent_type or "general"
if doc_agent_type not in ("research", "development", "general"):
raise HTTPException(status_code=400, detail="Invalid agent_type")
# Check role-based write access
if api_role is not None and not _can_access_document(api_role, doc_agent_type, require_write=True):
raise HTTPException(status_code=403, detail="Forbidden")
folder_path = None folder_path = None
if payload.folder_id: if payload.folder_id:
folder_result = await db.execute( folder_result = await db.execute(
@@ -264,6 +328,7 @@ async def create_document(
project_id=project_id, project_id=project_id,
folder_id=payload.folder_id, folder_id=payload.folder_id,
path=path, path=path,
agent_type=doc_agent_type,
) )
db.add(doc) db.add(doc)
await db.flush() await db.flush()
@@ -276,7 +341,7 @@ async def get_document(
document_id: str, document_id: str,
db: AsyncSession = Depends(get_db), db: AsyncSession = Depends(get_db),
): ):
agent = await get_current_agent(request, db) agent, api_role = await get_current_agent_or_api_token(request, db)
result = await db.execute( result = await db.execute(
select(Document).where( select(Document).where(
@@ -288,6 +353,8 @@ async def get_document(
if not doc: if not doc:
raise HTTPException(status_code=404, detail="Document not found") raise HTTPException(status_code=404, detail="Document not found")
# JWT tokens check project ownership
if api_role is None:
proj_result = await db.execute( proj_result = await db.execute(
select(Project).where( select(Project).where(
Project.id == doc.project_id, Project.id == doc.project_id,
@@ -297,6 +364,10 @@ async def get_document(
) )
if not proj_result.scalar_one_or_none(): if not proj_result.scalar_one_or_none():
raise HTTPException(status_code=404, detail="Document not found") raise HTTPException(status_code=404, detail="Document not found")
else:
# API tokens check role-based access
if not _can_access_document(api_role, doc.agent_type, require_write=False):
raise HTTPException(status_code=403, detail="Forbidden")
return await document_to_response(db, doc) return await document_to_response(db, doc)
@@ -308,7 +379,7 @@ async def update_document(
payload: DocumentUpdate, payload: DocumentUpdate,
db: AsyncSession = Depends(get_db), db: AsyncSession = Depends(get_db),
): ):
agent = await get_current_agent(request, db) agent, api_role = await get_current_agent_or_api_token(request, db)
result = await db.execute( result = await db.execute(
select(Document).where( select(Document).where(
@@ -320,6 +391,8 @@ async def update_document(
if not doc: if not doc:
raise HTTPException(status_code=404, detail="Document not found") raise HTTPException(status_code=404, detail="Document not found")
# JWT tokens check project ownership
if api_role is None:
proj_result = await db.execute( proj_result = await db.execute(
select(Project).where( select(Project).where(
Project.id == doc.project_id, Project.id == doc.project_id,
@@ -329,6 +402,10 @@ async def update_document(
) )
if not proj_result.scalar_one_or_none(): if not proj_result.scalar_one_or_none():
raise HTTPException(status_code=403, detail="Forbidden") raise HTTPException(status_code=403, detail="Forbidden")
else:
# API tokens check role-based write access
if not _can_access_document(api_role, doc.agent_type, require_write=True):
raise HTTPException(status_code=403, detail="Forbidden")
if payload.title is not None: if payload.title is not None:
doc.title = payload.title doc.title = payload.title
@@ -360,7 +437,7 @@ async def delete_document(
document_id: str, document_id: str,
db: AsyncSession = Depends(get_db), db: AsyncSession = Depends(get_db),
): ):
agent = await get_current_agent(request, db) agent, api_role = await get_current_agent_or_api_token(request, db)
result = await db.execute( result = await db.execute(
select(Document).where( select(Document).where(
@@ -372,6 +449,8 @@ async def delete_document(
if not doc: if not doc:
raise HTTPException(status_code=404, detail="Document not found") raise HTTPException(status_code=404, detail="Document not found")
# JWT tokens check project ownership
if api_role is None:
proj_result = await db.execute( proj_result = await db.execute(
select(Project).where( select(Project).where(
Project.id == doc.project_id, Project.id == doc.project_id,
@@ -381,6 +460,10 @@ async def delete_document(
) )
if not proj_result.scalar_one_or_none(): if not proj_result.scalar_one_or_none():
raise HTTPException(status_code=403, detail="Forbidden") raise HTTPException(status_code=403, detail="Forbidden")
else:
# API tokens check role-based write access
if not _can_access_document(api_role, doc.agent_type, require_write=True):
raise HTTPException(status_code=403, detail="Forbidden")
doc.is_deleted = True doc.is_deleted = True
doc.deleted_at = datetime.utcnow() doc.deleted_at = datetime.utcnow()
@@ -401,7 +484,7 @@ async def update_document_content(
Phase 2: Now supports both TipTap JSON and Markdown formats via the 'format' field. Phase 2: Now supports both TipTap JSON and Markdown formats via the 'format' field.
Also backward-compatible with legacy string content (treated as markdown). Also backward-compatible with legacy string content (treated as markdown).
""" """
agent = await get_current_agent(request, db) agent, api_role = await get_current_agent_or_api_token(request, db)
result = await db.execute( result = await db.execute(
select(Document).where( select(Document).where(
@@ -413,6 +496,8 @@ async def update_document_content(
if not doc: if not doc:
raise HTTPException(status_code=404, detail="Document not found") raise HTTPException(status_code=404, detail="Document not found")
# JWT tokens check project ownership
if api_role is None:
proj_result = await db.execute( proj_result = await db.execute(
select(Project).where( select(Project).where(
Project.id == doc.project_id, Project.id == doc.project_id,
@@ -422,6 +507,10 @@ async def update_document_content(
) )
if not proj_result.scalar_one_or_none(): if not proj_result.scalar_one_or_none():
raise HTTPException(status_code=403, detail="Forbidden") raise HTTPException(status_code=403, detail="Forbidden")
else:
# API tokens check role-based write access
if not _can_access_document(api_role, doc.agent_type, require_write=True):
raise HTTPException(status_code=403, detail="Forbidden")
# Determine actual format based on content type (backward compatibility) # Determine actual format based on content type (backward compatibility)
# If content is a string, treat as markdown regardless of format field # If content is a string, treat as markdown regardless of format field
@@ -465,7 +554,7 @@ async def restore_document(
document_id: str, document_id: str,
db: AsyncSession = Depends(get_db), db: AsyncSession = Depends(get_db),
): ):
agent = await get_current_agent(request, db) agent, api_role = await get_current_agent_or_api_token(request, db)
result = await db.execute( result = await db.execute(
select(Document).where( select(Document).where(
@@ -477,6 +566,8 @@ async def restore_document(
if not doc: if not doc:
raise HTTPException(status_code=404, detail="Document not found") raise HTTPException(status_code=404, detail="Document not found")
# JWT tokens check project ownership
if api_role is None:
proj_result = await db.execute( proj_result = await db.execute(
select(Project).where( select(Project).where(
Project.id == doc.project_id, Project.id == doc.project_id,
@@ -486,6 +577,10 @@ async def restore_document(
) )
if not proj_result.scalar_one_or_none(): if not proj_result.scalar_one_or_none():
raise HTTPException(status_code=403, detail="Forbidden") raise HTTPException(status_code=403, detail="Forbidden")
else:
# API tokens check role-based write access
if not _can_access_document(api_role, doc.agent_type, require_write=True):
raise HTTPException(status_code=403, detail="Forbidden")
doc.is_deleted = False doc.is_deleted = False
doc.deleted_at = None doc.deleted_at = None
@@ -501,7 +596,7 @@ async def assign_tags(
payload: DocumentTagsAssign, payload: DocumentTagsAssign,
db: AsyncSession = Depends(get_db), db: AsyncSession = Depends(get_db),
): ):
agent = await get_current_agent(request, db) agent, api_role = await get_current_agent_or_api_token(request, db)
result = await db.execute( result = await db.execute(
select(Document).where( select(Document).where(
@@ -513,6 +608,8 @@ async def assign_tags(
if not doc: if not doc:
raise HTTPException(status_code=404, detail="Document not found") raise HTTPException(status_code=404, detail="Document not found")
# JWT tokens check project ownership
if api_role is None:
proj_result = await db.execute( proj_result = await db.execute(
select(Project).where( select(Project).where(
Project.id == doc.project_id, Project.id == doc.project_id,
@@ -522,6 +619,10 @@ async def assign_tags(
) )
if not proj_result.scalar_one_or_none(): if not proj_result.scalar_one_or_none():
raise HTTPException(status_code=403, detail="Forbidden") raise HTTPException(status_code=403, detail="Forbidden")
else:
# API tokens check role-based write access
if not _can_access_document(api_role, doc.agent_type, require_write=True):
raise HTTPException(status_code=403, detail="Forbidden")
for tag_id in payload.tag_ids: for tag_id in payload.tag_ids:
tag_result = await db.execute( tag_result = await db.execute(
@@ -555,7 +656,7 @@ async def remove_tag(
tag_id: str, tag_id: str,
db: AsyncSession = Depends(get_db), db: AsyncSession = Depends(get_db),
): ):
agent = await get_current_agent(request, db) agent, api_role = await get_current_agent_or_api_token(request, db)
result = await db.execute( result = await db.execute(
select(Document).where( select(Document).where(
@@ -567,6 +668,8 @@ async def remove_tag(
if not doc: if not doc:
raise HTTPException(status_code=404, detail="Document not found") raise HTTPException(status_code=404, detail="Document not found")
# JWT tokens check project ownership
if api_role is None:
proj_result = await db.execute( proj_result = await db.execute(
select(Project).where( select(Project).where(
Project.id == doc.project_id, Project.id == doc.project_id,
@@ -576,6 +679,10 @@ async def remove_tag(
) )
if not proj_result.scalar_one_or_none(): if not proj_result.scalar_one_or_none():
raise HTTPException(status_code=403, detail="Forbidden") raise HTTPException(status_code=403, detail="Forbidden")
else:
# API tokens check role-based write access
if not _can_access_document(api_role, doc.agent_type, require_write=True):
raise HTTPException(status_code=403, detail="Forbidden")
await db.execute( await db.execute(
delete(DocumentTag).where( delete(DocumentTag).where(
@@ -596,9 +703,9 @@ async def _get_doc_with_access(
document_id: str, document_id: str,
db: AsyncSession, db: AsyncSession,
require_write: bool = False, require_write: bool = False,
) -> tuple[Document, bool]: ) -> tuple[Document, str | None]:
"""Get document and check access. Returns (doc, has_access).""" """Get document and check access. Returns (doc, api_role)."""
agent = await get_current_agent(request, db) agent, api_role = await get_current_agent_or_api_token(request, db)
result = await db.execute( result = await db.execute(
select(Document).where( select(Document).where(
@@ -610,6 +717,8 @@ async def _get_doc_with_access(
if not doc: if not doc:
raise HTTPException(status_code=404, detail="Document not found") raise HTTPException(status_code=404, detail="Document not found")
# JWT tokens check project ownership
if api_role is None:
proj_result = await db.execute( proj_result = await db.execute(
select(Project).where( select(Project).where(
Project.id == doc.project_id, Project.id == doc.project_id,
@@ -619,8 +728,12 @@ async def _get_doc_with_access(
) )
if not proj_result.scalar_one_or_none(): if not proj_result.scalar_one_or_none():
raise HTTPException(status_code=403, detail="Forbidden") raise HTTPException(status_code=403, detail="Forbidden")
else:
# API tokens check role-based access
if not _can_access_document(api_role, doc.agent_type, require_write=require_write):
raise HTTPException(status_code=403, detail="Forbidden")
return doc, True return doc, api_role
@router.get("/api/v1/documents/{document_id}/reasoning") @router.get("/api/v1/documents/{document_id}/reasoning")

View File

@@ -30,3 +30,25 @@ class TokenResponse(BaseModel):
class RefreshResponse(BaseModel): class RefreshResponse(BaseModel):
access_token: str access_token: str
token_type: str = "bearer" token_type: str = "bearer"
class ApiTokenCreate(BaseModel):
name: str = Field(..., min_length=1, max_length=255)
role: str = Field(..., pattern="^(researcher|developer|viewer)$")
class ApiTokenResponse(BaseModel):
id: str
name: str
role: str
created_at: datetime
model_config = {"from_attributes": True}
class ApiTokenGenerateResponse(BaseModel):
token: str
name: str
role: str
model_config = {"from_attributes": True}

View File

@@ -16,6 +16,7 @@ class DocumentCreate(BaseModel):
title: str title: str
content: str = "" content: str = ""
folder_id: str | None = None folder_id: str | None = None
agent_type: str | None = "general" # research, development, general
class DocumentUpdate(BaseModel): class DocumentUpdate(BaseModel):

View File

@@ -231,3 +231,86 @@ async def cleanup_expired_blocklist(db: AsyncSession) -> None:
text("DELETE FROM jwt_blocklist WHERE expires_at < datetime('now')") text("DELETE FROM jwt_blocklist WHERE expires_at < datetime('now')")
) )
await db.flush() await db.flush()
# =============================================================================
# API Token Management (Role-based)
# =============================================================================
def generate_api_token() -> str:
"""Generate a random API token."""
return secrets.token_urlsafe(48) # ~64 chars
async def create_api_token(
db: AsyncSession,
name: str,
role: str,
agent_id: str,
) -> tuple[str, "ApiToken"]:
"""Create a new API token. Returns (raw_token, token_record)."""
from app.models.api_token import ApiToken
raw_token = generate_api_token()
token_hash = hash_token(raw_token)
token_record = ApiToken(
id=str(uuid.uuid4()),
name=name,
token_hash=token_hash,
role=role,
agent_id=agent_id,
)
db.add(token_record)
await db.flush()
return raw_token, token_record
async def list_api_tokens(db: AsyncSession, agent_id: str) -> list["ApiToken"]:
"""List all API tokens for an agent (without the actual token)."""
from app.models.api_token import ApiToken
result = await db.execute(
select(ApiToken).where(ApiToken.agent_id == agent_id).order_by(ApiToken.created_at.desc())
)
return list(result.scalars().all())
async def revoke_api_token(db: AsyncSession, token_id: str, agent_id: str) -> bool:
"""Revoke an API token. Returns True if found and deleted."""
from app.models.api_token import ApiToken
result = await db.execute(
select(ApiToken).where(
ApiToken.id == token_id,
ApiToken.agent_id == agent_id,
)
)
token = result.scalar_one_or_none()
if not token:
return False
await db.delete(token)
await db.flush()
return True
async def verify_api_token(db: AsyncSession, raw_token: str) -> tuple[str, str] | None:
"""
Verify an API token. Returns (agent_id, role) if valid, None otherwise.
Also updates last_used_at.
"""
from app.models.api_token import ApiToken
token_hash = hash_token(raw_token)
result = await db.execute(
select(ApiToken).where(ApiToken.token_hash == token_hash)
)
token: ApiToken | None = result.scalar_one_or_none()
if not token:
return None
# Update last_used_at
token.last_used_at = datetime.utcnow()
await db.flush()
return token.agent_id, token.role

259
tests/test_api_tokens.py Normal file
View File

@@ -0,0 +1,259 @@
import pytest
import pytest_asyncio
from httpx import AsyncClient
import asyncio
import os
os.environ["DATABASE_URL"] = "sqlite+aiosqlite:///:memory:"
from app.main import app
from app.database import Base, get_db, async_engine
@pytest.fixture(scope="session")
def event_loop():
loop = asyncio.get_event_loop_policy().new_event_loop()
yield loop
loop.close()
@pytest_asyncio.fixture(scope="function")
async def db_session():
"""Create a fresh in-memory database for each test."""
async with async_engine.begin() as conn:
from app.database import _create_schema
await conn.run_sync(_create_schema)
from sqlalchemy.ext.asyncio import async_sessionmaker, AsyncSession
async_session = async_sessionmaker(async_engine, class_=AsyncSession, expire_on_commit=False)
async with async_session() as session:
yield session
await session.rollback()
async with async_engine.begin() as conn:
for table in reversed(Base.metadata.sorted_tables):
await conn.execute(table.delete())
@pytest_asyncio.fixture(scope="function")
async def client(db_session):
"""Async HTTP client for testing."""
async def override_get_db():
yield db_session
app.dependency_overrides[get_db] = override_get_db
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as ac:
yield ac
app.dependency_overrides.clear()
@pytest_asyncio.fixture(scope="function")
async def admin_user(client):
"""Create an admin user for testing."""
# Create admin directly in database with admin role
import uuid
import bcrypt
from sqlalchemy import text
password_hash = bcrypt.hashpw("adminpass123".encode(), bcrypt.gensalt()).decode()
async with async_engine.begin() as conn:
await conn.execute(
text("""
INSERT INTO agents (id, username, password_hash, role, is_deleted, created_at, updated_at)
VALUES (:id, :username, :password_hash, 'admin', 0, datetime('now'), datetime('now'))
"""),
{
"id": str(uuid.uuid4()),
"username": "testadmin",
"password_hash": password_hash
}
)
# Login as admin
login_resp = await client.post(
"/api/v1/auth/login",
json={"username": "testadmin", "password": "adminpass123"}
)
return login_resp.json()["access_token"]
from httpx import ASGITransport
@pytest.mark.asyncio
async def test_generate_api_token(client, admin_user):
"""Test creating an API token as admin."""
# Generate an API token with researcher role
response = await client.post(
"/api/v1/auth/token/generate",
json={"name": "research-token", "role": "researcher"},
headers={"Authorization": f"Bearer {admin_user}"}
)
assert response.status_code == 201, f"Expected 201 but got {response.status_code}: {response.json()}"
data = response.json()
assert data["name"] == "research-token"
assert data["role"] == "researcher"
assert "token" in data
assert len(data["token"]) > 20 # Token should be long
@pytest.mark.asyncio
async def test_generate_api_token_non_admin_forbidden(client):
"""Test that non-admin cannot create API tokens."""
# Register and login as regular agent
await client.post("/api/v1/auth/register", json={"username": "agent1", "password": "pass123"})
login_resp = await client.post(
"/api/v1/auth/login",
json={"username": "agent1", "password": "pass123"}
)
token = login_resp.json()["access_token"]
# Try to generate API token
response = await client.post(
"/api/v1/auth/token/generate",
json={"name": "test-token", "role": "researcher"},
headers={"Authorization": f"Bearer {token}"}
)
assert response.status_code == 403
@pytest.mark.asyncio
async def test_list_api_tokens(client, admin_user):
"""Test listing API tokens."""
# Generate two tokens
await client.post(
"/api/v1/auth/token/generate",
json={"name": "token1", "role": "researcher"},
headers={"Authorization": f"Bearer {admin_user}"}
)
await client.post(
"/api/v1/auth/token/generate",
json={"name": "token2", "role": "developer"},
headers={"Authorization": f"Bearer {admin_user}"}
)
# List tokens
response = await client.get(
"/api/v1/auth/tokens",
headers={"Authorization": f"Bearer {admin_user}"}
)
assert response.status_code == 200
data = response.json()
assert len(data) == 2
assert data[0]["name"] == "token2" # Most recent first
assert data[1]["name"] == "token1"
# Tokens should not include the actual token value
assert "token" not in data[0]
assert "token_hash" not in data[0]
@pytest.mark.asyncio
async def test_revoke_api_token(client, admin_user):
"""Test revoking an API token."""
# Generate a token
gen_resp = await client.post(
"/api/v1/auth/token/generate",
json={"name": "revoke-me", "role": "viewer"},
headers={"Authorization": f"Bearer {admin_user}"}
)
researcher_token = gen_resp.json()["token"]
# List tokens - should have 1
list_resp = await client.get(
"/api/v1/auth/tokens",
headers={"Authorization": f"Bearer {admin_user}"}
)
assert len(list_resp.json()) == 1
api_token_id = list_resp.json()[0]["id"]
# Revoke the token
response = await client.delete(
f"/api/v1/auth/tokens/{api_token_id}",
headers={"Authorization": f"Bearer {admin_user}"}
)
assert response.status_code == 204
# List tokens - should be empty
list_resp = await client.get(
"/api/v1/auth/tokens",
headers={"Authorization": f"Bearer {admin_user}"}
)
assert len(list_resp.json()) == 0
@pytest.mark.asyncio
async def test_revoke_nonexistent_token(client, admin_user):
"""Test revoking a non-existent token returns 404."""
response = await client.delete(
"/api/v1/auth/tokens/nonexistent-id",
headers={"Authorization": f"Bearer {admin_user}"}
)
assert response.status_code == 404
@pytest.mark.asyncio
async def test_api_token_with_invalid_role(client, admin_user):
"""Test that invalid role is rejected."""
response = await client.post(
"/api/v1/auth/token/generate",
json={"name": "bad-role", "role": "invalid_role"},
headers={"Authorization": f"Bearer {admin_user}"}
)
assert response.status_code == 422 # Validation error
@pytest.mark.asyncio
async def test_api_token_auth_flow(client, admin_user):
"""Test full flow: create project with JWT, then access with API token."""
# Generate researcher API token
gen_resp = await client.post(
"/api/v1/auth/token/generate",
json={"name": "research-token", "role": "researcher"},
headers={"Authorization": f"Bearer {admin_user}"}
)
researcher_token = gen_resp.json()["token"]
# Create a project
proj_resp = await client.post(
"/api/v1/projects",
json={"name": "Test Project"},
headers={"Authorization": f"Bearer {admin_user}"}
)
proj_id = proj_resp.json()["id"]
# Create a research document
doc_resp = await client.post(
f"/api/v1/projects/{proj_id}/documents",
json={"title": "Research Doc", "content": "Research content", "agent_type": "research"},
headers={"Authorization": f"Bearer {admin_user}"}
)
assert doc_resp.status_code == 201
doc_id = doc_resp.json()["id"]
# Access document with researcher token - should work (research doc)
get_resp = await client.get(
f"/api/v1/documents/{doc_id}",
headers={"Authorization": f"Bearer {researcher_token}"}
)
assert get_resp.status_code == 200
# Create a development document
dev_doc_resp = await client.post(
f"/api/v1/projects/{proj_id}/documents",
json={"title": "Dev Doc", "content": "Dev content", "agent_type": "development"},
headers={"Authorization": f"Bearer {admin_user}"}
)
assert dev_doc_resp.status_code == 201
dev_doc_id = dev_doc_resp.json()["id"]
# Access dev document with researcher token - should fail (read access denied)
get_resp = await client.get(
f"/api/v1/documents/{dev_doc_id}",
headers={"Authorization": f"Bearer {researcher_token}"}
)
assert get_resp.status_code == 403