feat: Add role-based API tokens for Claudia Docs
- Add api_tokens table with role-based access (researcher, developer, viewer)
- Add POST /auth/token/generate endpoint for creating tokens
- Add GET /auth/tokens endpoint for listing user's tokens
- Add DELETE /auth/tokens/{token_id} endpoint for revoking tokens
- Add agent_type field to documents (research, development, general)
- Implement role-based access control for documents:
- researcher: access to research and general documents
- developer: access to development and general documents
- viewer: read-only access
- Update document model and schemas with agent_type field
- Add comprehensive tests for API token functionality
- All existing tests pass (73 total)
This commit is contained in:
54
alembic/versions/002_api_tokens.py
Normal file
54
alembic/versions/002_api_tokens.py
Normal file
@@ -0,0 +1,54 @@
|
|||||||
|
"""Add api_tokens table and agent_type column
|
||||||
|
|
||||||
|
Revision ID: 002
|
||||||
|
Revises: 001
|
||||||
|
Create Date: 2026-03-31
|
||||||
|
|
||||||
|
"""
|
||||||
|
from typing import Sequence, Union
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
import sqlalchemy as sa
|
||||||
|
|
||||||
|
|
||||||
|
revision: str = '002'
|
||||||
|
down_revision: Union[str, None] = '001'
|
||||||
|
branch_labels: Union[str, Sequence[str], None] = None
|
||||||
|
depends_on: Union[str, Sequence[str], None] = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
# Add agent_type column to documents table
|
||||||
|
op.add_column('documents', sa.Column('agent_type', sa.String(20), sa.TEXT(), nullable=True))
|
||||||
|
|
||||||
|
# Set default value for existing documents
|
||||||
|
op.execute("UPDATE documents SET agent_type = 'general' WHERE agent_type IS NULL")
|
||||||
|
|
||||||
|
# Make the column NOT NULL after setting defaults
|
||||||
|
op.alter_column('documents', 'agent_type', nullable=False)
|
||||||
|
|
||||||
|
# Add CHECK constraint for agent_type
|
||||||
|
op.execute("""
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_documents_agent_type ON documents(agent_type)
|
||||||
|
""")
|
||||||
|
|
||||||
|
# Create api_tokens table
|
||||||
|
op.execute("""
|
||||||
|
CREATE TABLE IF NOT EXISTS api_tokens (
|
||||||
|
id TEXT PRIMARY KEY,
|
||||||
|
name TEXT NOT NULL,
|
||||||
|
token_hash TEXT NOT NULL,
|
||||||
|
role TEXT NOT NULL CHECK (role IN ('researcher', 'developer', 'viewer')),
|
||||||
|
agent_id TEXT NOT NULL REFERENCES agents(id) ON DELETE CASCADE,
|
||||||
|
created_at TIMESTAMP NOT NULL DEFAULT (datetime('now')),
|
||||||
|
last_used_at TIMESTAMP NULL
|
||||||
|
)
|
||||||
|
""")
|
||||||
|
|
||||||
|
# Create index for api_tokens
|
||||||
|
op.execute("CREATE INDEX IF NOT EXISTS idx_api_tokens_agent ON api_tokens(agent_id)")
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
op.drop_table('api_tokens')
|
||||||
|
op.drop_column('documents', 'agent_type')
|
||||||
@@ -192,7 +192,8 @@ def _create_schema(sync_conn):
|
|||||||
model_source TEXT,
|
model_source TEXT,
|
||||||
tiptap_content TEXT,
|
tiptap_content TEXT,
|
||||||
outgoing_links TEXT DEFAULT '[]',
|
outgoing_links TEXT DEFAULT '[]',
|
||||||
backlinks_count INTEGER NOT NULL DEFAULT 0
|
backlinks_count INTEGER NOT NULL DEFAULT 0,
|
||||||
|
agent_type TEXT DEFAULT 'general' CHECK (agent_type IN ('research', 'development', 'general'))
|
||||||
)
|
)
|
||||||
"""))
|
"""))
|
||||||
|
|
||||||
@@ -244,6 +245,19 @@ def _create_schema(sync_conn):
|
|||||||
)
|
)
|
||||||
"""))
|
"""))
|
||||||
|
|
||||||
|
# API tokens table (role-based)
|
||||||
|
sync_conn.execute(text("""
|
||||||
|
CREATE TABLE IF NOT EXISTS api_tokens (
|
||||||
|
id TEXT PRIMARY KEY,
|
||||||
|
name TEXT NOT NULL,
|
||||||
|
token_hash TEXT NOT NULL,
|
||||||
|
role TEXT NOT NULL CHECK (role IN ('researcher', 'developer', 'viewer')),
|
||||||
|
agent_id TEXT NOT NULL REFERENCES agents(id) ON DELETE CASCADE,
|
||||||
|
created_at TIMESTAMP NOT NULL DEFAULT (datetime('now')),
|
||||||
|
last_used_at TIMESTAMP NULL
|
||||||
|
)
|
||||||
|
"""))
|
||||||
|
|
||||||
# Indexes
|
# Indexes
|
||||||
sync_conn.execute(text("CREATE INDEX IF NOT EXISTS idx_projects_agent ON projects(agent_id)"))
|
sync_conn.execute(text("CREATE INDEX IF NOT EXISTS idx_projects_agent ON projects(agent_id)"))
|
||||||
sync_conn.execute(text("CREATE INDEX IF NOT EXISTS idx_folders_project ON folders(project_id)"))
|
sync_conn.execute(text("CREATE INDEX IF NOT EXISTS idx_folders_project ON folders(project_id)"))
|
||||||
@@ -254,6 +268,7 @@ def _create_schema(sync_conn):
|
|||||||
sync_conn.execute(text("CREATE INDEX IF NOT EXISTS idx_document_tags_tag ON document_tags(tag_id)"))
|
sync_conn.execute(text("CREATE INDEX IF NOT EXISTS idx_document_tags_tag ON document_tags(tag_id)"))
|
||||||
sync_conn.execute(text("CREATE INDEX IF NOT EXISTS idx_refresh_tokens_hash ON refresh_tokens(token_hash)"))
|
sync_conn.execute(text("CREATE INDEX IF NOT EXISTS idx_refresh_tokens_hash ON refresh_tokens(token_hash)"))
|
||||||
sync_conn.execute(text("CREATE INDEX IF NOT EXISTS idx_refresh_tokens_user_family ON refresh_tokens(user_id, token_family_id)"))
|
sync_conn.execute(text("CREATE INDEX IF NOT EXISTS idx_refresh_tokens_user_family ON refresh_tokens(user_id, token_family_id)"))
|
||||||
|
sync_conn.execute(text("CREATE INDEX IF NOT EXISTS idx_api_tokens_agent ON api_tokens(agent_id)"))
|
||||||
|
|
||||||
# --- FTS5 virtual table ---
|
# --- FTS5 virtual table ---
|
||||||
sync_conn.execute(text("""
|
sync_conn.execute(text("""
|
||||||
|
|||||||
23
app/models/api_token.py
Normal file
23
app/models/api_token.py
Normal file
@@ -0,0 +1,23 @@
|
|||||||
|
import uuid
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
from sqlalchemy import DateTime, ForeignKey, String, Text
|
||||||
|
from sqlalchemy.orm import Mapped, mapped_column
|
||||||
|
|
||||||
|
from app.database import Base
|
||||||
|
|
||||||
|
|
||||||
|
def generate_uuid() -> str:
|
||||||
|
return str(uuid.uuid4())
|
||||||
|
|
||||||
|
|
||||||
|
class ApiToken(Base):
|
||||||
|
__tablename__ = "api_tokens"
|
||||||
|
|
||||||
|
id: Mapped[str] = mapped_column(String(36), primary_key=True, default=generate_uuid)
|
||||||
|
name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||||
|
token_hash: Mapped[str] = mapped_column(Text, nullable=False) # SHA-256 hash of actual token
|
||||||
|
role: Mapped[str] = mapped_column(String(20), nullable=False) # researcher, developer, viewer
|
||||||
|
agent_id: Mapped[str] = mapped_column(String(36), ForeignKey("agents.id", ondelete="CASCADE"), nullable=False)
|
||||||
|
created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow, nullable=False)
|
||||||
|
last_used_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
|
||||||
@@ -42,3 +42,5 @@ class Document(Base):
|
|||||||
# Phase 3: Link tracking
|
# Phase 3: Link tracking
|
||||||
outgoing_links: Mapped[str] = mapped_column(Text, nullable=False, default="[]") # JSON array of document IDs
|
outgoing_links: Mapped[str] = mapped_column(Text, nullable=False, default="[]") # JSON array of document IDs
|
||||||
backlinks_count: Mapped[int] = mapped_column(default=0, nullable=False) # Cached count of incoming links
|
backlinks_count: Mapped[int] = mapped_column(default=0, nullable=False) # Cached count of incoming links
|
||||||
|
# Role-based access
|
||||||
|
agent_type: Mapped[str] = mapped_column(String(20), nullable=False, default="general") # research, development, general
|
||||||
|
|||||||
@@ -3,9 +3,19 @@ from datetime import datetime
|
|||||||
from fastapi import APIRouter, Depends, HTTPException, Request, Response
|
from fastapi import APIRouter, Depends, HTTPException, Request, Response
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.config import settings
|
||||||
from app.database import get_db
|
from app.database import get_db
|
||||||
from app.models.agent import Agent
|
from app.models.agent import Agent
|
||||||
from app.schemas.auth import AgentCreate, AgentLogin, AgentResponse, RefreshResponse, TokenResponse
|
from app.schemas.auth import (
|
||||||
|
AgentCreate,
|
||||||
|
AgentLogin,
|
||||||
|
AgentResponse,
|
||||||
|
ApiTokenCreate,
|
||||||
|
ApiTokenGenerateResponse,
|
||||||
|
ApiTokenResponse,
|
||||||
|
RefreshResponse,
|
||||||
|
TokenResponse,
|
||||||
|
)
|
||||||
from app.services import auth as auth_service
|
from app.services import auth as auth_service
|
||||||
|
|
||||||
router = APIRouter(prefix="/api/v1/auth", tags=["auth"])
|
router = APIRouter(prefix="/api/v1/auth", tags=["auth"])
|
||||||
@@ -40,7 +50,7 @@ def _clear_refresh_cookie(response: Response):
|
|||||||
|
|
||||||
|
|
||||||
async def get_current_agent(request: Request, db: AsyncSession) -> Agent:
|
async def get_current_agent(request: Request, db: AsyncSession) -> Agent:
|
||||||
"""Get the current authenticated agent from request."""
|
"""Get the current authenticated agent from request (JWT only)."""
|
||||||
auth_header = request.headers.get("authorization", "")
|
auth_header = request.headers.get("authorization", "")
|
||||||
if not auth_header.startswith("Bearer "):
|
if not auth_header.startswith("Bearer "):
|
||||||
raise HTTPException(status_code=401, detail="Not authenticated")
|
raise HTTPException(status_code=401, detail="Not authenticated")
|
||||||
@@ -67,6 +77,46 @@ async def get_current_agent(request: Request, db: AsyncSession) -> Agent:
|
|||||||
return agent
|
return agent
|
||||||
|
|
||||||
|
|
||||||
|
async def get_current_agent_or_api_token(request: Request, db: AsyncSession) -> tuple[Agent | None, str | None]:
|
||||||
|
"""
|
||||||
|
Get the current authenticated agent or validate an API token.
|
||||||
|
Returns (agent, api_role) where agent is None for API tokens.
|
||||||
|
For JWT tokens: (agent, None)
|
||||||
|
For API tokens: (None, role)
|
||||||
|
Raises HTTPException if neither is valid.
|
||||||
|
"""
|
||||||
|
auth_header = request.headers.get("authorization", "")
|
||||||
|
if not auth_header.startswith("Bearer "):
|
||||||
|
raise HTTPException(status_code=401, detail="Not authenticated")
|
||||||
|
|
||||||
|
token = auth_header[7:]
|
||||||
|
|
||||||
|
# First try JWT
|
||||||
|
payload = auth_service.decode_token(token)
|
||||||
|
if payload:
|
||||||
|
jti = payload.get("jti")
|
||||||
|
if jti:
|
||||||
|
is_blocked = await auth_service.is_token_blocklisted(db, jti)
|
||||||
|
if is_blocked:
|
||||||
|
raise HTTPException(status_code=401, detail="Token revoked")
|
||||||
|
|
||||||
|
agent_id = payload.get("sub")
|
||||||
|
if agent_id:
|
||||||
|
agent = await auth_service.get_agent_by_id(db, agent_id)
|
||||||
|
if agent:
|
||||||
|
return agent, None
|
||||||
|
|
||||||
|
# Try API token
|
||||||
|
result = await auth_service.verify_api_token(db, token)
|
||||||
|
if result:
|
||||||
|
agent_id, role = result
|
||||||
|
agent = await auth_service.get_agent_by_id(db, agent_id)
|
||||||
|
if agent:
|
||||||
|
return agent, role
|
||||||
|
|
||||||
|
raise HTTPException(status_code=401, detail="Invalid token")
|
||||||
|
|
||||||
|
|
||||||
@router.post("/register", response_model=AgentResponse, status_code=201)
|
@router.post("/register", response_model=AgentResponse, status_code=201)
|
||||||
async def register(payload: AgentCreate, db: AsyncSession = Depends(get_db)):
|
async def register(payload: AgentCreate, db: AsyncSession = Depends(get_db)):
|
||||||
if settings.DISABLE_REGISTRATION:
|
if settings.DISABLE_REGISTRATION:
|
||||||
@@ -168,3 +218,54 @@ async def logout_all(request: Request, response: Response, db: AsyncSession = De
|
|||||||
|
|
||||||
_clear_refresh_cookie(response)
|
_clear_refresh_cookie(response)
|
||||||
return Response(status_code=204)
|
return Response(status_code=204)
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# Role-based API Token Management
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
@router.post("/token/generate", response_model=ApiTokenGenerateResponse, status_code=201)
|
||||||
|
async def generate_token(
|
||||||
|
request: Request,
|
||||||
|
payload: ApiTokenCreate,
|
||||||
|
db: AsyncSession = Depends(get_db),
|
||||||
|
):
|
||||||
|
"""Generate a new API token with a specific role."""
|
||||||
|
# Only admin agents can create API tokens
|
||||||
|
agent = await get_current_agent(request, db)
|
||||||
|
if agent.role != "admin":
|
||||||
|
raise HTTPException(status_code=403, detail="Only admin agents can create API tokens")
|
||||||
|
|
||||||
|
raw_token, token_record = await auth_service.create_api_token(
|
||||||
|
db, payload.name, payload.role, agent.id
|
||||||
|
)
|
||||||
|
return ApiTokenGenerateResponse(
|
||||||
|
token=raw_token,
|
||||||
|
name=token_record.name,
|
||||||
|
role=token_record.role,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/tokens", response_model=list[ApiTokenResponse])
|
||||||
|
async def list_tokens(
|
||||||
|
request: Request,
|
||||||
|
db: AsyncSession = Depends(get_db),
|
||||||
|
):
|
||||||
|
"""List all API tokens for the current agent."""
|
||||||
|
agent = await get_current_agent(request, db)
|
||||||
|
tokens = await auth_service.list_api_tokens(db, agent.id)
|
||||||
|
return [ApiTokenResponse.model_validate(t) for t in tokens]
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/tokens/{token_id}", status_code=204)
|
||||||
|
async def revoke_token(
|
||||||
|
request: Request,
|
||||||
|
token_id: str,
|
||||||
|
db: AsyncSession = Depends(get_db),
|
||||||
|
):
|
||||||
|
"""Revoke an API token."""
|
||||||
|
agent = await get_current_agent(request, db)
|
||||||
|
success = await auth_service.revoke_api_token(db, token_id, agent.id)
|
||||||
|
if not success:
|
||||||
|
raise HTTPException(status_code=404, detail="Token not found")
|
||||||
|
return None
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ from app.models.document import Document, ReasoningType
|
|||||||
from app.models.folder import Folder
|
from app.models.folder import Folder
|
||||||
from app.models.project import Project
|
from app.models.project import Project
|
||||||
from app.models.tag import DocumentTag, Tag
|
from app.models.tag import DocumentTag, Tag
|
||||||
from app.routers.auth import get_current_agent
|
from app.routers.auth import get_current_agent, get_current_agent_or_api_token
|
||||||
from app.schemas.document import (
|
from app.schemas.document import (
|
||||||
DocumentBriefResponse,
|
DocumentBriefResponse,
|
||||||
DocumentContentUpdate,
|
DocumentContentUpdate,
|
||||||
@@ -178,23 +178,73 @@ async def document_to_response(db: AsyncSession, doc: Document) -> DocumentRespo
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _can_access_document(api_role: str | None, doc_agent_type: str | None, require_write: bool = False) -> bool:
|
||||||
|
"""
|
||||||
|
Check if a role can access a document based on agent_type.
|
||||||
|
|
||||||
|
Rules:
|
||||||
|
- JWT tokens (api_role is None) have full access via project ownership check
|
||||||
|
- researcher: can access 'research' and 'general' documents
|
||||||
|
- developer: can access 'development' and 'general' documents
|
||||||
|
- viewer: can only read (handled elsewhere), not create/modify
|
||||||
|
- admin: full access (but admin is a JWT role, not API token role)
|
||||||
|
|
||||||
|
For write operations, viewer is denied.
|
||||||
|
"""
|
||||||
|
if api_role is None:
|
||||||
|
# JWT token - access is controlled by project ownership
|
||||||
|
return True
|
||||||
|
|
||||||
|
# API token role-based access
|
||||||
|
doc_type = doc_agent_type or 'general'
|
||||||
|
|
||||||
|
if require_write:
|
||||||
|
# Viewers cannot create/update/delete
|
||||||
|
if api_role == "viewer":
|
||||||
|
return False
|
||||||
|
# Researchers can only write to research and general
|
||||||
|
if api_role == "researcher":
|
||||||
|
return doc_type in ("research", "general")
|
||||||
|
# Developers can only write to development and general
|
||||||
|
if api_role == "developer":
|
||||||
|
return doc_type in ("development", "general")
|
||||||
|
else:
|
||||||
|
# Read access - viewers can read research, development, and general
|
||||||
|
if api_role == "viewer":
|
||||||
|
return doc_type in ("research", "development", "general")
|
||||||
|
# Researchers can only read research and general
|
||||||
|
if api_role == "researcher":
|
||||||
|
return doc_type in ("research", "general")
|
||||||
|
# Developers can only read development and general
|
||||||
|
if api_role == "developer":
|
||||||
|
return doc_type in ("development", "general")
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
@router.get("/api/v1/projects/{project_id}/documents", response_model=DocumentListResponse)
|
@router.get("/api/v1/projects/{project_id}/documents", response_model=DocumentListResponse)
|
||||||
async def list_documents(
|
async def list_documents(
|
||||||
request: Request,
|
request: Request,
|
||||||
project_id: str,
|
project_id: str,
|
||||||
db: AsyncSession = Depends(get_db),
|
db: AsyncSession = Depends(get_db),
|
||||||
):
|
):
|
||||||
agent = await get_current_agent(request, db)
|
agent, api_role = await get_current_agent_or_api_token(request, db)
|
||||||
|
|
||||||
proj_result = await db.execute(
|
# JWT tokens check project ownership
|
||||||
select(Project).where(
|
if api_role is None:
|
||||||
Project.id == project_id,
|
proj_result = await db.execute(
|
||||||
Project.agent_id == agent.id,
|
select(Project).where(
|
||||||
Project.is_deleted == False,
|
Project.id == project_id,
|
||||||
|
Project.agent_id == agent.id,
|
||||||
|
Project.is_deleted == False,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
)
|
if not proj_result.scalar_one_or_none():
|
||||||
if not proj_result.scalar_one_or_none():
|
raise HTTPException(status_code=404, detail="Project not found")
|
||||||
raise HTTPException(status_code=404, detail="Project not found")
|
else:
|
||||||
|
# API tokens don't have project-level access control here
|
||||||
|
# Access is controlled at document level via agent_type
|
||||||
|
pass
|
||||||
|
|
||||||
result = await db.execute(
|
result = await db.execute(
|
||||||
select(Document).where(
|
select(Document).where(
|
||||||
@@ -206,6 +256,9 @@ async def list_documents(
|
|||||||
|
|
||||||
responses = []
|
responses = []
|
||||||
for doc in docs:
|
for doc in docs:
|
||||||
|
# Apply role-based filtering for API tokens
|
||||||
|
if api_role is not None and not _can_access_document(api_role, doc.agent_type, require_write=False):
|
||||||
|
continue
|
||||||
tags = await get_document_tags(db, doc.id)
|
tags = await get_document_tags(db, doc.id)
|
||||||
responses.append(DocumentBriefResponse(
|
responses.append(DocumentBriefResponse(
|
||||||
id=doc.id,
|
id=doc.id,
|
||||||
@@ -228,17 +281,28 @@ async def create_document(
|
|||||||
payload: DocumentCreate,
|
payload: DocumentCreate,
|
||||||
db: AsyncSession = Depends(get_db),
|
db: AsyncSession = Depends(get_db),
|
||||||
):
|
):
|
||||||
agent = await get_current_agent(request, db)
|
agent, api_role = await get_current_agent_or_api_token(request, db)
|
||||||
|
|
||||||
proj_result = await db.execute(
|
# JWT tokens check project ownership
|
||||||
select(Project).where(
|
if api_role is None:
|
||||||
Project.id == project_id,
|
proj_result = await db.execute(
|
||||||
Project.agent_id == agent.id,
|
select(Project).where(
|
||||||
Project.is_deleted == False,
|
Project.id == project_id,
|
||||||
|
Project.agent_id == agent.id,
|
||||||
|
Project.is_deleted == False,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
)
|
if not proj_result.scalar_one_or_none():
|
||||||
if not proj_result.scalar_one_or_none():
|
raise HTTPException(status_code=404, detail="Project not found")
|
||||||
raise HTTPException(status_code=404, detail="Project not found")
|
|
||||||
|
# Determine agent_type for the document
|
||||||
|
doc_agent_type = payload.agent_type or "general"
|
||||||
|
if doc_agent_type not in ("research", "development", "general"):
|
||||||
|
raise HTTPException(status_code=400, detail="Invalid agent_type")
|
||||||
|
|
||||||
|
# Check role-based write access
|
||||||
|
if api_role is not None and not _can_access_document(api_role, doc_agent_type, require_write=True):
|
||||||
|
raise HTTPException(status_code=403, detail="Forbidden")
|
||||||
|
|
||||||
folder_path = None
|
folder_path = None
|
||||||
if payload.folder_id:
|
if payload.folder_id:
|
||||||
@@ -264,6 +328,7 @@ async def create_document(
|
|||||||
project_id=project_id,
|
project_id=project_id,
|
||||||
folder_id=payload.folder_id,
|
folder_id=payload.folder_id,
|
||||||
path=path,
|
path=path,
|
||||||
|
agent_type=doc_agent_type,
|
||||||
)
|
)
|
||||||
db.add(doc)
|
db.add(doc)
|
||||||
await db.flush()
|
await db.flush()
|
||||||
@@ -276,7 +341,7 @@ async def get_document(
|
|||||||
document_id: str,
|
document_id: str,
|
||||||
db: AsyncSession = Depends(get_db),
|
db: AsyncSession = Depends(get_db),
|
||||||
):
|
):
|
||||||
agent = await get_current_agent(request, db)
|
agent, api_role = await get_current_agent_or_api_token(request, db)
|
||||||
|
|
||||||
result = await db.execute(
|
result = await db.execute(
|
||||||
select(Document).where(
|
select(Document).where(
|
||||||
@@ -288,15 +353,21 @@ async def get_document(
|
|||||||
if not doc:
|
if not doc:
|
||||||
raise HTTPException(status_code=404, detail="Document not found")
|
raise HTTPException(status_code=404, detail="Document not found")
|
||||||
|
|
||||||
proj_result = await db.execute(
|
# JWT tokens check project ownership
|
||||||
select(Project).where(
|
if api_role is None:
|
||||||
Project.id == doc.project_id,
|
proj_result = await db.execute(
|
||||||
Project.agent_id == agent.id,
|
select(Project).where(
|
||||||
Project.is_deleted == False,
|
Project.id == doc.project_id,
|
||||||
|
Project.agent_id == agent.id,
|
||||||
|
Project.is_deleted == False,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
)
|
if not proj_result.scalar_one_or_none():
|
||||||
if not proj_result.scalar_one_or_none():
|
raise HTTPException(status_code=404, detail="Document not found")
|
||||||
raise HTTPException(status_code=404, detail="Document not found")
|
else:
|
||||||
|
# API tokens check role-based access
|
||||||
|
if not _can_access_document(api_role, doc.agent_type, require_write=False):
|
||||||
|
raise HTTPException(status_code=403, detail="Forbidden")
|
||||||
|
|
||||||
return await document_to_response(db, doc)
|
return await document_to_response(db, doc)
|
||||||
|
|
||||||
@@ -308,7 +379,7 @@ async def update_document(
|
|||||||
payload: DocumentUpdate,
|
payload: DocumentUpdate,
|
||||||
db: AsyncSession = Depends(get_db),
|
db: AsyncSession = Depends(get_db),
|
||||||
):
|
):
|
||||||
agent = await get_current_agent(request, db)
|
agent, api_role = await get_current_agent_or_api_token(request, db)
|
||||||
|
|
||||||
result = await db.execute(
|
result = await db.execute(
|
||||||
select(Document).where(
|
select(Document).where(
|
||||||
@@ -320,15 +391,21 @@ async def update_document(
|
|||||||
if not doc:
|
if not doc:
|
||||||
raise HTTPException(status_code=404, detail="Document not found")
|
raise HTTPException(status_code=404, detail="Document not found")
|
||||||
|
|
||||||
proj_result = await db.execute(
|
# JWT tokens check project ownership
|
||||||
select(Project).where(
|
if api_role is None:
|
||||||
Project.id == doc.project_id,
|
proj_result = await db.execute(
|
||||||
Project.agent_id == agent.id,
|
select(Project).where(
|
||||||
Project.is_deleted == False,
|
Project.id == doc.project_id,
|
||||||
|
Project.agent_id == agent.id,
|
||||||
|
Project.is_deleted == False,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
)
|
if not proj_result.scalar_one_or_none():
|
||||||
if not proj_result.scalar_one_or_none():
|
raise HTTPException(status_code=403, detail="Forbidden")
|
||||||
raise HTTPException(status_code=403, detail="Forbidden")
|
else:
|
||||||
|
# API tokens check role-based write access
|
||||||
|
if not _can_access_document(api_role, doc.agent_type, require_write=True):
|
||||||
|
raise HTTPException(status_code=403, detail="Forbidden")
|
||||||
|
|
||||||
if payload.title is not None:
|
if payload.title is not None:
|
||||||
doc.title = payload.title
|
doc.title = payload.title
|
||||||
@@ -360,7 +437,7 @@ async def delete_document(
|
|||||||
document_id: str,
|
document_id: str,
|
||||||
db: AsyncSession = Depends(get_db),
|
db: AsyncSession = Depends(get_db),
|
||||||
):
|
):
|
||||||
agent = await get_current_agent(request, db)
|
agent, api_role = await get_current_agent_or_api_token(request, db)
|
||||||
|
|
||||||
result = await db.execute(
|
result = await db.execute(
|
||||||
select(Document).where(
|
select(Document).where(
|
||||||
@@ -372,15 +449,21 @@ async def delete_document(
|
|||||||
if not doc:
|
if not doc:
|
||||||
raise HTTPException(status_code=404, detail="Document not found")
|
raise HTTPException(status_code=404, detail="Document not found")
|
||||||
|
|
||||||
proj_result = await db.execute(
|
# JWT tokens check project ownership
|
||||||
select(Project).where(
|
if api_role is None:
|
||||||
Project.id == doc.project_id,
|
proj_result = await db.execute(
|
||||||
Project.agent_id == agent.id,
|
select(Project).where(
|
||||||
Project.is_deleted == False,
|
Project.id == doc.project_id,
|
||||||
|
Project.agent_id == agent.id,
|
||||||
|
Project.is_deleted == False,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
)
|
if not proj_result.scalar_one_or_none():
|
||||||
if not proj_result.scalar_one_or_none():
|
raise HTTPException(status_code=403, detail="Forbidden")
|
||||||
raise HTTPException(status_code=403, detail="Forbidden")
|
else:
|
||||||
|
# API tokens check role-based write access
|
||||||
|
if not _can_access_document(api_role, doc.agent_type, require_write=True):
|
||||||
|
raise HTTPException(status_code=403, detail="Forbidden")
|
||||||
|
|
||||||
doc.is_deleted = True
|
doc.is_deleted = True
|
||||||
doc.deleted_at = datetime.utcnow()
|
doc.deleted_at = datetime.utcnow()
|
||||||
@@ -401,7 +484,7 @@ async def update_document_content(
|
|||||||
Phase 2: Now supports both TipTap JSON and Markdown formats via the 'format' field.
|
Phase 2: Now supports both TipTap JSON and Markdown formats via the 'format' field.
|
||||||
Also backward-compatible with legacy string content (treated as markdown).
|
Also backward-compatible with legacy string content (treated as markdown).
|
||||||
"""
|
"""
|
||||||
agent = await get_current_agent(request, db)
|
agent, api_role = await get_current_agent_or_api_token(request, db)
|
||||||
|
|
||||||
result = await db.execute(
|
result = await db.execute(
|
||||||
select(Document).where(
|
select(Document).where(
|
||||||
@@ -413,15 +496,21 @@ async def update_document_content(
|
|||||||
if not doc:
|
if not doc:
|
||||||
raise HTTPException(status_code=404, detail="Document not found")
|
raise HTTPException(status_code=404, detail="Document not found")
|
||||||
|
|
||||||
proj_result = await db.execute(
|
# JWT tokens check project ownership
|
||||||
select(Project).where(
|
if api_role is None:
|
||||||
Project.id == doc.project_id,
|
proj_result = await db.execute(
|
||||||
Project.agent_id == agent.id,
|
select(Project).where(
|
||||||
Project.is_deleted == False,
|
Project.id == doc.project_id,
|
||||||
|
Project.agent_id == agent.id,
|
||||||
|
Project.is_deleted == False,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
)
|
if not proj_result.scalar_one_or_none():
|
||||||
if not proj_result.scalar_one_or_none():
|
raise HTTPException(status_code=403, detail="Forbidden")
|
||||||
raise HTTPException(status_code=403, detail="Forbidden")
|
else:
|
||||||
|
# API tokens check role-based write access
|
||||||
|
if not _can_access_document(api_role, doc.agent_type, require_write=True):
|
||||||
|
raise HTTPException(status_code=403, detail="Forbidden")
|
||||||
|
|
||||||
# Determine actual format based on content type (backward compatibility)
|
# Determine actual format based on content type (backward compatibility)
|
||||||
# If content is a string, treat as markdown regardless of format field
|
# If content is a string, treat as markdown regardless of format field
|
||||||
@@ -465,7 +554,7 @@ async def restore_document(
|
|||||||
document_id: str,
|
document_id: str,
|
||||||
db: AsyncSession = Depends(get_db),
|
db: AsyncSession = Depends(get_db),
|
||||||
):
|
):
|
||||||
agent = await get_current_agent(request, db)
|
agent, api_role = await get_current_agent_or_api_token(request, db)
|
||||||
|
|
||||||
result = await db.execute(
|
result = await db.execute(
|
||||||
select(Document).where(
|
select(Document).where(
|
||||||
@@ -477,15 +566,21 @@ async def restore_document(
|
|||||||
if not doc:
|
if not doc:
|
||||||
raise HTTPException(status_code=404, detail="Document not found")
|
raise HTTPException(status_code=404, detail="Document not found")
|
||||||
|
|
||||||
proj_result = await db.execute(
|
# JWT tokens check project ownership
|
||||||
select(Project).where(
|
if api_role is None:
|
||||||
Project.id == doc.project_id,
|
proj_result = await db.execute(
|
||||||
Project.agent_id == agent.id,
|
select(Project).where(
|
||||||
Project.is_deleted == False,
|
Project.id == doc.project_id,
|
||||||
|
Project.agent_id == agent.id,
|
||||||
|
Project.is_deleted == False,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
)
|
if not proj_result.scalar_one_or_none():
|
||||||
if not proj_result.scalar_one_or_none():
|
raise HTTPException(status_code=403, detail="Forbidden")
|
||||||
raise HTTPException(status_code=403, detail="Forbidden")
|
else:
|
||||||
|
# API tokens check role-based write access
|
||||||
|
if not _can_access_document(api_role, doc.agent_type, require_write=True):
|
||||||
|
raise HTTPException(status_code=403, detail="Forbidden")
|
||||||
|
|
||||||
doc.is_deleted = False
|
doc.is_deleted = False
|
||||||
doc.deleted_at = None
|
doc.deleted_at = None
|
||||||
@@ -501,7 +596,7 @@ async def assign_tags(
|
|||||||
payload: DocumentTagsAssign,
|
payload: DocumentTagsAssign,
|
||||||
db: AsyncSession = Depends(get_db),
|
db: AsyncSession = Depends(get_db),
|
||||||
):
|
):
|
||||||
agent = await get_current_agent(request, db)
|
agent, api_role = await get_current_agent_or_api_token(request, db)
|
||||||
|
|
||||||
result = await db.execute(
|
result = await db.execute(
|
||||||
select(Document).where(
|
select(Document).where(
|
||||||
@@ -513,15 +608,21 @@ async def assign_tags(
|
|||||||
if not doc:
|
if not doc:
|
||||||
raise HTTPException(status_code=404, detail="Document not found")
|
raise HTTPException(status_code=404, detail="Document not found")
|
||||||
|
|
||||||
proj_result = await db.execute(
|
# JWT tokens check project ownership
|
||||||
select(Project).where(
|
if api_role is None:
|
||||||
Project.id == doc.project_id,
|
proj_result = await db.execute(
|
||||||
Project.agent_id == agent.id,
|
select(Project).where(
|
||||||
Project.is_deleted == False,
|
Project.id == doc.project_id,
|
||||||
|
Project.agent_id == agent.id,
|
||||||
|
Project.is_deleted == False,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
)
|
if not proj_result.scalar_one_or_none():
|
||||||
if not proj_result.scalar_one_or_none():
|
raise HTTPException(status_code=403, detail="Forbidden")
|
||||||
raise HTTPException(status_code=403, detail="Forbidden")
|
else:
|
||||||
|
# API tokens check role-based write access
|
||||||
|
if not _can_access_document(api_role, doc.agent_type, require_write=True):
|
||||||
|
raise HTTPException(status_code=403, detail="Forbidden")
|
||||||
|
|
||||||
for tag_id in payload.tag_ids:
|
for tag_id in payload.tag_ids:
|
||||||
tag_result = await db.execute(
|
tag_result = await db.execute(
|
||||||
@@ -555,7 +656,7 @@ async def remove_tag(
|
|||||||
tag_id: str,
|
tag_id: str,
|
||||||
db: AsyncSession = Depends(get_db),
|
db: AsyncSession = Depends(get_db),
|
||||||
):
|
):
|
||||||
agent = await get_current_agent(request, db)
|
agent, api_role = await get_current_agent_or_api_token(request, db)
|
||||||
|
|
||||||
result = await db.execute(
|
result = await db.execute(
|
||||||
select(Document).where(
|
select(Document).where(
|
||||||
@@ -567,15 +668,21 @@ async def remove_tag(
|
|||||||
if not doc:
|
if not doc:
|
||||||
raise HTTPException(status_code=404, detail="Document not found")
|
raise HTTPException(status_code=404, detail="Document not found")
|
||||||
|
|
||||||
proj_result = await db.execute(
|
# JWT tokens check project ownership
|
||||||
select(Project).where(
|
if api_role is None:
|
||||||
Project.id == doc.project_id,
|
proj_result = await db.execute(
|
||||||
Project.agent_id == agent.id,
|
select(Project).where(
|
||||||
Project.is_deleted == False,
|
Project.id == doc.project_id,
|
||||||
|
Project.agent_id == agent.id,
|
||||||
|
Project.is_deleted == False,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
)
|
if not proj_result.scalar_one_or_none():
|
||||||
if not proj_result.scalar_one_or_none():
|
raise HTTPException(status_code=403, detail="Forbidden")
|
||||||
raise HTTPException(status_code=403, detail="Forbidden")
|
else:
|
||||||
|
# API tokens check role-based write access
|
||||||
|
if not _can_access_document(api_role, doc.agent_type, require_write=True):
|
||||||
|
raise HTTPException(status_code=403, detail="Forbidden")
|
||||||
|
|
||||||
await db.execute(
|
await db.execute(
|
||||||
delete(DocumentTag).where(
|
delete(DocumentTag).where(
|
||||||
@@ -596,9 +703,9 @@ async def _get_doc_with_access(
|
|||||||
document_id: str,
|
document_id: str,
|
||||||
db: AsyncSession,
|
db: AsyncSession,
|
||||||
require_write: bool = False,
|
require_write: bool = False,
|
||||||
) -> tuple[Document, bool]:
|
) -> tuple[Document, str | None]:
|
||||||
"""Get document and check access. Returns (doc, has_access)."""
|
"""Get document and check access. Returns (doc, api_role)."""
|
||||||
agent = await get_current_agent(request, db)
|
agent, api_role = await get_current_agent_or_api_token(request, db)
|
||||||
|
|
||||||
result = await db.execute(
|
result = await db.execute(
|
||||||
select(Document).where(
|
select(Document).where(
|
||||||
@@ -610,17 +717,23 @@ async def _get_doc_with_access(
|
|||||||
if not doc:
|
if not doc:
|
||||||
raise HTTPException(status_code=404, detail="Document not found")
|
raise HTTPException(status_code=404, detail="Document not found")
|
||||||
|
|
||||||
proj_result = await db.execute(
|
# JWT tokens check project ownership
|
||||||
select(Project).where(
|
if api_role is None:
|
||||||
Project.id == doc.project_id,
|
proj_result = await db.execute(
|
||||||
Project.agent_id == agent.id,
|
select(Project).where(
|
||||||
Project.is_deleted == False,
|
Project.id == doc.project_id,
|
||||||
|
Project.agent_id == agent.id,
|
||||||
|
Project.is_deleted == False,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
)
|
if not proj_result.scalar_one_or_none():
|
||||||
if not proj_result.scalar_one_or_none():
|
raise HTTPException(status_code=403, detail="Forbidden")
|
||||||
raise HTTPException(status_code=403, detail="Forbidden")
|
else:
|
||||||
|
# API tokens check role-based access
|
||||||
|
if not _can_access_document(api_role, doc.agent_type, require_write=require_write):
|
||||||
|
raise HTTPException(status_code=403, detail="Forbidden")
|
||||||
|
|
||||||
return doc, True
|
return doc, api_role
|
||||||
|
|
||||||
|
|
||||||
@router.get("/api/v1/documents/{document_id}/reasoning")
|
@router.get("/api/v1/documents/{document_id}/reasoning")
|
||||||
|
|||||||
@@ -30,3 +30,25 @@ class TokenResponse(BaseModel):
|
|||||||
class RefreshResponse(BaseModel):
|
class RefreshResponse(BaseModel):
|
||||||
access_token: str
|
access_token: str
|
||||||
token_type: str = "bearer"
|
token_type: str = "bearer"
|
||||||
|
|
||||||
|
|
||||||
|
class ApiTokenCreate(BaseModel):
|
||||||
|
name: str = Field(..., min_length=1, max_length=255)
|
||||||
|
role: str = Field(..., pattern="^(researcher|developer|viewer)$")
|
||||||
|
|
||||||
|
|
||||||
|
class ApiTokenResponse(BaseModel):
|
||||||
|
id: str
|
||||||
|
name: str
|
||||||
|
role: str
|
||||||
|
created_at: datetime
|
||||||
|
|
||||||
|
model_config = {"from_attributes": True}
|
||||||
|
|
||||||
|
|
||||||
|
class ApiTokenGenerateResponse(BaseModel):
|
||||||
|
token: str
|
||||||
|
name: str
|
||||||
|
role: str
|
||||||
|
|
||||||
|
model_config = {"from_attributes": True}
|
||||||
|
|||||||
@@ -16,6 +16,7 @@ class DocumentCreate(BaseModel):
|
|||||||
title: str
|
title: str
|
||||||
content: str = ""
|
content: str = ""
|
||||||
folder_id: str | None = None
|
folder_id: str | None = None
|
||||||
|
agent_type: str | None = "general" # research, development, general
|
||||||
|
|
||||||
|
|
||||||
class DocumentUpdate(BaseModel):
|
class DocumentUpdate(BaseModel):
|
||||||
|
|||||||
@@ -231,3 +231,86 @@ async def cleanup_expired_blocklist(db: AsyncSession) -> None:
|
|||||||
text("DELETE FROM jwt_blocklist WHERE expires_at < datetime('now')")
|
text("DELETE FROM jwt_blocklist WHERE expires_at < datetime('now')")
|
||||||
)
|
)
|
||||||
await db.flush()
|
await db.flush()
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# API Token Management (Role-based)
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
def generate_api_token() -> str:
|
||||||
|
"""Generate a random API token."""
|
||||||
|
return secrets.token_urlsafe(48) # ~64 chars
|
||||||
|
|
||||||
|
|
||||||
|
async def create_api_token(
|
||||||
|
db: AsyncSession,
|
||||||
|
name: str,
|
||||||
|
role: str,
|
||||||
|
agent_id: str,
|
||||||
|
) -> tuple[str, "ApiToken"]:
|
||||||
|
"""Create a new API token. Returns (raw_token, token_record)."""
|
||||||
|
from app.models.api_token import ApiToken
|
||||||
|
|
||||||
|
raw_token = generate_api_token()
|
||||||
|
token_hash = hash_token(raw_token)
|
||||||
|
|
||||||
|
token_record = ApiToken(
|
||||||
|
id=str(uuid.uuid4()),
|
||||||
|
name=name,
|
||||||
|
token_hash=token_hash,
|
||||||
|
role=role,
|
||||||
|
agent_id=agent_id,
|
||||||
|
)
|
||||||
|
db.add(token_record)
|
||||||
|
await db.flush()
|
||||||
|
return raw_token, token_record
|
||||||
|
|
||||||
|
|
||||||
|
async def list_api_tokens(db: AsyncSession, agent_id: str) -> list["ApiToken"]:
|
||||||
|
"""List all API tokens for an agent (without the actual token)."""
|
||||||
|
from app.models.api_token import ApiToken
|
||||||
|
|
||||||
|
result = await db.execute(
|
||||||
|
select(ApiToken).where(ApiToken.agent_id == agent_id).order_by(ApiToken.created_at.desc())
|
||||||
|
)
|
||||||
|
return list(result.scalars().all())
|
||||||
|
|
||||||
|
|
||||||
|
async def revoke_api_token(db: AsyncSession, token_id: str, agent_id: str) -> bool:
|
||||||
|
"""Revoke an API token. Returns True if found and deleted."""
|
||||||
|
from app.models.api_token import ApiToken
|
||||||
|
|
||||||
|
result = await db.execute(
|
||||||
|
select(ApiToken).where(
|
||||||
|
ApiToken.id == token_id,
|
||||||
|
ApiToken.agent_id == agent_id,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
token = result.scalar_one_or_none()
|
||||||
|
if not token:
|
||||||
|
return False
|
||||||
|
|
||||||
|
await db.delete(token)
|
||||||
|
await db.flush()
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
async def verify_api_token(db: AsyncSession, raw_token: str) -> tuple[str, str] | None:
|
||||||
|
"""
|
||||||
|
Verify an API token. Returns (agent_id, role) if valid, None otherwise.
|
||||||
|
Also updates last_used_at.
|
||||||
|
"""
|
||||||
|
from app.models.api_token import ApiToken
|
||||||
|
|
||||||
|
token_hash = hash_token(raw_token)
|
||||||
|
result = await db.execute(
|
||||||
|
select(ApiToken).where(ApiToken.token_hash == token_hash)
|
||||||
|
)
|
||||||
|
token: ApiToken | None = result.scalar_one_or_none()
|
||||||
|
if not token:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Update last_used_at
|
||||||
|
token.last_used_at = datetime.utcnow()
|
||||||
|
await db.flush()
|
||||||
|
return token.agent_id, token.role
|
||||||
|
|||||||
259
tests/test_api_tokens.py
Normal file
259
tests/test_api_tokens.py
Normal file
@@ -0,0 +1,259 @@
|
|||||||
|
import pytest
|
||||||
|
import pytest_asyncio
|
||||||
|
from httpx import AsyncClient
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import os
|
||||||
|
|
||||||
|
os.environ["DATABASE_URL"] = "sqlite+aiosqlite:///:memory:"
|
||||||
|
|
||||||
|
from app.main import app
|
||||||
|
from app.database import Base, get_db, async_engine
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="session")
|
||||||
|
def event_loop():
|
||||||
|
loop = asyncio.get_event_loop_policy().new_event_loop()
|
||||||
|
yield loop
|
||||||
|
loop.close()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture(scope="function")
|
||||||
|
async def db_session():
|
||||||
|
"""Create a fresh in-memory database for each test."""
|
||||||
|
async with async_engine.begin() as conn:
|
||||||
|
from app.database import _create_schema
|
||||||
|
await conn.run_sync(_create_schema)
|
||||||
|
|
||||||
|
from sqlalchemy.ext.asyncio import async_sessionmaker, AsyncSession
|
||||||
|
async_session = async_sessionmaker(async_engine, class_=AsyncSession, expire_on_commit=False)
|
||||||
|
|
||||||
|
async with async_session() as session:
|
||||||
|
yield session
|
||||||
|
await session.rollback()
|
||||||
|
|
||||||
|
async with async_engine.begin() as conn:
|
||||||
|
for table in reversed(Base.metadata.sorted_tables):
|
||||||
|
await conn.execute(table.delete())
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture(scope="function")
|
||||||
|
async def client(db_session):
|
||||||
|
"""Async HTTP client for testing."""
|
||||||
|
async def override_get_db():
|
||||||
|
yield db_session
|
||||||
|
|
||||||
|
app.dependency_overrides[get_db] = override_get_db
|
||||||
|
|
||||||
|
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as ac:
|
||||||
|
yield ac
|
||||||
|
|
||||||
|
app.dependency_overrides.clear()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture(scope="function")
|
||||||
|
async def admin_user(client):
|
||||||
|
"""Create an admin user for testing."""
|
||||||
|
# Create admin directly in database with admin role
|
||||||
|
import uuid
|
||||||
|
import bcrypt
|
||||||
|
from sqlalchemy import text
|
||||||
|
|
||||||
|
password_hash = bcrypt.hashpw("adminpass123".encode(), bcrypt.gensalt()).decode()
|
||||||
|
|
||||||
|
async with async_engine.begin() as conn:
|
||||||
|
await conn.execute(
|
||||||
|
text("""
|
||||||
|
INSERT INTO agents (id, username, password_hash, role, is_deleted, created_at, updated_at)
|
||||||
|
VALUES (:id, :username, :password_hash, 'admin', 0, datetime('now'), datetime('now'))
|
||||||
|
"""),
|
||||||
|
{
|
||||||
|
"id": str(uuid.uuid4()),
|
||||||
|
"username": "testadmin",
|
||||||
|
"password_hash": password_hash
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Login as admin
|
||||||
|
login_resp = await client.post(
|
||||||
|
"/api/v1/auth/login",
|
||||||
|
json={"username": "testadmin", "password": "adminpass123"}
|
||||||
|
)
|
||||||
|
return login_resp.json()["access_token"]
|
||||||
|
|
||||||
|
|
||||||
|
from httpx import ASGITransport
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_generate_api_token(client, admin_user):
|
||||||
|
"""Test creating an API token as admin."""
|
||||||
|
# Generate an API token with researcher role
|
||||||
|
response = await client.post(
|
||||||
|
"/api/v1/auth/token/generate",
|
||||||
|
json={"name": "research-token", "role": "researcher"},
|
||||||
|
headers={"Authorization": f"Bearer {admin_user}"}
|
||||||
|
)
|
||||||
|
assert response.status_code == 201, f"Expected 201 but got {response.status_code}: {response.json()}"
|
||||||
|
data = response.json()
|
||||||
|
assert data["name"] == "research-token"
|
||||||
|
assert data["role"] == "researcher"
|
||||||
|
assert "token" in data
|
||||||
|
assert len(data["token"]) > 20 # Token should be long
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_generate_api_token_non_admin_forbidden(client):
|
||||||
|
"""Test that non-admin cannot create API tokens."""
|
||||||
|
# Register and login as regular agent
|
||||||
|
await client.post("/api/v1/auth/register", json={"username": "agent1", "password": "pass123"})
|
||||||
|
login_resp = await client.post(
|
||||||
|
"/api/v1/auth/login",
|
||||||
|
json={"username": "agent1", "password": "pass123"}
|
||||||
|
)
|
||||||
|
token = login_resp.json()["access_token"]
|
||||||
|
|
||||||
|
# Try to generate API token
|
||||||
|
response = await client.post(
|
||||||
|
"/api/v1/auth/token/generate",
|
||||||
|
json={"name": "test-token", "role": "researcher"},
|
||||||
|
headers={"Authorization": f"Bearer {token}"}
|
||||||
|
)
|
||||||
|
assert response.status_code == 403
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_list_api_tokens(client, admin_user):
|
||||||
|
"""Test listing API tokens."""
|
||||||
|
# Generate two tokens
|
||||||
|
await client.post(
|
||||||
|
"/api/v1/auth/token/generate",
|
||||||
|
json={"name": "token1", "role": "researcher"},
|
||||||
|
headers={"Authorization": f"Bearer {admin_user}"}
|
||||||
|
)
|
||||||
|
await client.post(
|
||||||
|
"/api/v1/auth/token/generate",
|
||||||
|
json={"name": "token2", "role": "developer"},
|
||||||
|
headers={"Authorization": f"Bearer {admin_user}"}
|
||||||
|
)
|
||||||
|
|
||||||
|
# List tokens
|
||||||
|
response = await client.get(
|
||||||
|
"/api/v1/auth/tokens",
|
||||||
|
headers={"Authorization": f"Bearer {admin_user}"}
|
||||||
|
)
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
assert len(data) == 2
|
||||||
|
assert data[0]["name"] == "token2" # Most recent first
|
||||||
|
assert data[1]["name"] == "token1"
|
||||||
|
# Tokens should not include the actual token value
|
||||||
|
assert "token" not in data[0]
|
||||||
|
assert "token_hash" not in data[0]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_revoke_api_token(client, admin_user):
|
||||||
|
"""Test revoking an API token."""
|
||||||
|
# Generate a token
|
||||||
|
gen_resp = await client.post(
|
||||||
|
"/api/v1/auth/token/generate",
|
||||||
|
json={"name": "revoke-me", "role": "viewer"},
|
||||||
|
headers={"Authorization": f"Bearer {admin_user}"}
|
||||||
|
)
|
||||||
|
researcher_token = gen_resp.json()["token"]
|
||||||
|
|
||||||
|
# List tokens - should have 1
|
||||||
|
list_resp = await client.get(
|
||||||
|
"/api/v1/auth/tokens",
|
||||||
|
headers={"Authorization": f"Bearer {admin_user}"}
|
||||||
|
)
|
||||||
|
assert len(list_resp.json()) == 1
|
||||||
|
api_token_id = list_resp.json()[0]["id"]
|
||||||
|
|
||||||
|
# Revoke the token
|
||||||
|
response = await client.delete(
|
||||||
|
f"/api/v1/auth/tokens/{api_token_id}",
|
||||||
|
headers={"Authorization": f"Bearer {admin_user}"}
|
||||||
|
)
|
||||||
|
assert response.status_code == 204
|
||||||
|
|
||||||
|
# List tokens - should be empty
|
||||||
|
list_resp = await client.get(
|
||||||
|
"/api/v1/auth/tokens",
|
||||||
|
headers={"Authorization": f"Bearer {admin_user}"}
|
||||||
|
)
|
||||||
|
assert len(list_resp.json()) == 0
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_revoke_nonexistent_token(client, admin_user):
|
||||||
|
"""Test revoking a non-existent token returns 404."""
|
||||||
|
response = await client.delete(
|
||||||
|
"/api/v1/auth/tokens/nonexistent-id",
|
||||||
|
headers={"Authorization": f"Bearer {admin_user}"}
|
||||||
|
)
|
||||||
|
assert response.status_code == 404
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_api_token_with_invalid_role(client, admin_user):
|
||||||
|
"""Test that invalid role is rejected."""
|
||||||
|
response = await client.post(
|
||||||
|
"/api/v1/auth/token/generate",
|
||||||
|
json={"name": "bad-role", "role": "invalid_role"},
|
||||||
|
headers={"Authorization": f"Bearer {admin_user}"}
|
||||||
|
)
|
||||||
|
assert response.status_code == 422 # Validation error
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_api_token_auth_flow(client, admin_user):
|
||||||
|
"""Test full flow: create project with JWT, then access with API token."""
|
||||||
|
# Generate researcher API token
|
||||||
|
gen_resp = await client.post(
|
||||||
|
"/api/v1/auth/token/generate",
|
||||||
|
json={"name": "research-token", "role": "researcher"},
|
||||||
|
headers={"Authorization": f"Bearer {admin_user}"}
|
||||||
|
)
|
||||||
|
researcher_token = gen_resp.json()["token"]
|
||||||
|
|
||||||
|
# Create a project
|
||||||
|
proj_resp = await client.post(
|
||||||
|
"/api/v1/projects",
|
||||||
|
json={"name": "Test Project"},
|
||||||
|
headers={"Authorization": f"Bearer {admin_user}"}
|
||||||
|
)
|
||||||
|
proj_id = proj_resp.json()["id"]
|
||||||
|
|
||||||
|
# Create a research document
|
||||||
|
doc_resp = await client.post(
|
||||||
|
f"/api/v1/projects/{proj_id}/documents",
|
||||||
|
json={"title": "Research Doc", "content": "Research content", "agent_type": "research"},
|
||||||
|
headers={"Authorization": f"Bearer {admin_user}"}
|
||||||
|
)
|
||||||
|
assert doc_resp.status_code == 201
|
||||||
|
doc_id = doc_resp.json()["id"]
|
||||||
|
|
||||||
|
# Access document with researcher token - should work (research doc)
|
||||||
|
get_resp = await client.get(
|
||||||
|
f"/api/v1/documents/{doc_id}",
|
||||||
|
headers={"Authorization": f"Bearer {researcher_token}"}
|
||||||
|
)
|
||||||
|
assert get_resp.status_code == 200
|
||||||
|
|
||||||
|
# Create a development document
|
||||||
|
dev_doc_resp = await client.post(
|
||||||
|
f"/api/v1/projects/{proj_id}/documents",
|
||||||
|
json={"title": "Dev Doc", "content": "Dev content", "agent_type": "development"},
|
||||||
|
headers={"Authorization": f"Bearer {admin_user}"}
|
||||||
|
)
|
||||||
|
assert dev_doc_resp.status_code == 201
|
||||||
|
dev_doc_id = dev_doc_resp.json()["id"]
|
||||||
|
|
||||||
|
# Access dev document with researcher token - should fail (read access denied)
|
||||||
|
get_resp = await client.get(
|
||||||
|
f"/api/v1/documents/{dev_doc_id}",
|
||||||
|
headers={"Authorization": f"Bearer {researcher_token}"}
|
||||||
|
)
|
||||||
|
assert get_resp.status_code == 403
|
||||||
Reference in New Issue
Block a user