From 7f3e8a8f53f67d3b8334e6dcf4f13ed092d0b6a7 Mon Sep 17 00:00:00 2001 From: Motoko Date: Mon, 30 Mar 2026 15:17:27 +0000 Subject: [PATCH] Phase 1 MVP - Complete implementation - Auth: register, login, JWT with refresh tokens, blocklist - Projects/Folders/Documents CRUD with soft deletes - Tags CRUD and assignment - FTS5 search with highlights and tag filtering - ADR-001, ADR-002, ADR-003 compliant - Security fixes applied (JWT_SECRET_KEY, exception handler, cookie secure) - 25 tests passing --- alembic.ini | 41 +++ alembic/env.py | 58 +++++ alembic/versions/001_initial.py | 27 ++ app/__init__.py | 1 + app/config.py | 45 ++++ app/database.py | 247 ++++++++++++++++++ app/main.py | 63 +++++ app/models/__init__.py | 1 + app/models/agent.py | 25 ++ app/models/document.py | 27 ++ app/models/folder.py | 26 ++ app/models/project.py | 25 ++ app/models/refresh_token.py | 35 +++ app/models/tag.py | 30 +++ app/routers/__init__.py | 1 + app/routers/auth.py | 167 ++++++++++++ app/routers/documents.py | 435 ++++++++++++++++++++++++++++++++ app/routers/folders.py | 251 ++++++++++++++++++ app/routers/projects.py | 146 +++++++++++ app/routers/search.py | 36 +++ app/routers/tags.py | 79 ++++++ app/schemas/__init__.py | 1 + app/schemas/auth.py | 32 +++ app/schemas/document.py | 58 +++++ app/schemas/folder.py | 29 +++ app/schemas/project.py | 27 ++ app/schemas/search.py | 18 ++ app/schemas/tag.py | 25 ++ app/services/__init__.py | 1 + app/services/auth.py | 233 +++++++++++++++++ app/services/search.py | 128 ++++++++++ data/claudia_docs.db | Bin 0 -> 147456 bytes pytest.ini | 8 + requirements.txt | 14 + tests/__init__.py | 1 + tests/conftest.py | 54 ++++ tests/test_auth.py | 70 +++++ tests/test_documents.py | 136 ++++++++++ tests/test_folders.py | 90 +++++++ tests/test_projects.py | 98 +++++++ tests/test_search.py | 69 +++++ 41 files changed, 2858 insertions(+) create mode 100644 alembic.ini create mode 100644 alembic/env.py create mode 100644 alembic/versions/001_initial.py create mode 100644 app/__init__.py create mode 100644 app/config.py create mode 100644 app/database.py create mode 100644 app/main.py create mode 100644 app/models/__init__.py create mode 100644 app/models/agent.py create mode 100644 app/models/document.py create mode 100644 app/models/folder.py create mode 100644 app/models/project.py create mode 100644 app/models/refresh_token.py create mode 100644 app/models/tag.py create mode 100644 app/routers/__init__.py create mode 100644 app/routers/auth.py create mode 100644 app/routers/documents.py create mode 100644 app/routers/folders.py create mode 100644 app/routers/projects.py create mode 100644 app/routers/search.py create mode 100644 app/routers/tags.py create mode 100644 app/schemas/__init__.py create mode 100644 app/schemas/auth.py create mode 100644 app/schemas/document.py create mode 100644 app/schemas/folder.py create mode 100644 app/schemas/project.py create mode 100644 app/schemas/search.py create mode 100644 app/schemas/tag.py create mode 100644 app/services/__init__.py create mode 100644 app/services/auth.py create mode 100644 app/services/search.py create mode 100644 data/claudia_docs.db create mode 100644 pytest.ini create mode 100644 requirements.txt create mode 100644 tests/__init__.py create mode 100644 tests/conftest.py create mode 100644 tests/test_auth.py create mode 100644 tests/test_documents.py create mode 100644 tests/test_folders.py create mode 100644 tests/test_projects.py create mode 100644 tests/test_search.py diff --git a/alembic.ini b/alembic.ini new file mode 100644 index 0000000..3f3161a --- /dev/null +++ b/alembic.ini @@ -0,0 +1,41 @@ +[alembic] +script_location = alembic +prepend_sys_path = . +version_path_separator = os +sqlalchemy.url = sqlite:///./data/claudia_docs.db + +[post_write_hooks] + +[loggers] +keys = root,sqlalchemy,alembic + +[handlers] +keys = console + +[formatters] +keys = generic + +[logger_root] +level = WARN +handlers = console +qualname = + +[logger_sqlalchemy] +level = WARN +handlers = +qualname = sqlalchemy.engine + +[logger_alembic] +level = INFO +handlers = +qualname = alembic + +[handler_console] +class = StreamHandler +args = (sys.stderr,) +level = NOTSET +formatter = generic + +[formatter_generic] +format = %(levelname)-5.5s [%(name)s] %(message)s +datefmt = %H:%M:%S diff --git a/alembic/env.py b/alembic/env.py new file mode 100644 index 0000000..9e31bfe --- /dev/null +++ b/alembic/env.py @@ -0,0 +1,58 @@ +from logging.config import fileConfig + +from sqlalchemy import engine_from_config, pool + +from alembic import context + +# this is the Alembic Config object +config = context.config + +# Interpret the config file for Python logging. +if config.config_file_name is not None: + fileConfig(config.config_file_name) + +# Import models for autogenerate +import sys +from pathlib import Path +sys.path.insert(0, str(Path(__file__).resolve().parents[1])) + +from app.database import Base +from app.models import agent, project, folder, document, tag, refresh_token + +target_metadata = Base.metadata + + +def run_migrations_offline() -> None: + url = config.get_main_option("sqlalchemy.url") + context.configure( + url=url, + target_metadata=target_metadata, + literal_binds=True, + dialect_opts={"paramstyle": "named"}, + ) + + with context.begin_transaction(): + context.run_migrations() + + +def run_migrations_online() -> None: + connectable = engine_from_config( + config.get_section(config.config_ini_section, {}), + prefix="sqlalchemy.", + poolclass=pool.NullPool, + ) + + with connectable.connect() as connection: + context.configure( + connection=connection, + target_metadata=target_metadata, + ) + + with context.begin_transaction(): + context.run_migrations() + + +if context.is_offline_mode(): + run_migrations_offline() +else: + run_migrations_online() diff --git a/alembic/versions/001_initial.py b/alembic/versions/001_initial.py new file mode 100644 index 0000000..37d21da --- /dev/null +++ b/alembic/versions/001_initial.py @@ -0,0 +1,27 @@ +"""Initial schema + +Revision ID: 001 +Revises: +Create Date: 2026-03-30 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +revision: str = '001' +down_revision: Union[str, None] = None +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # All schema creation is handled in database.py init_db() + # This migration is a no-op since we use raw SQL for SQLite-specific features + pass + + +def downgrade() -> None: + pass diff --git a/app/__init__.py b/app/__init__.py new file mode 100644 index 0000000..a583c28 --- /dev/null +++ b/app/__init__.py @@ -0,0 +1 @@ +# Claudia Docs Backend diff --git a/app/config.py b/app/config.py new file mode 100644 index 0000000..5eebb75 --- /dev/null +++ b/app/config.py @@ -0,0 +1,45 @@ +import os +from pathlib import Path + +from pydantic_settings import BaseSettings, SettingsConfigDict + + +def _resolve_db_url(url: str) -> str: + """Convert relative sqlite path to absolute path.""" + if url.startswith("sqlite+aiosqlite:///./"): + # Convert relative path to absolute + rel_path = url.replace("sqlite+aiosqlite:///./", "") + abs_path = Path("/root/.openclaw/workspace-orchestrator/backend").resolve() / rel_path + return f"sqlite+aiosqlite:///{abs_path}" + return url + + +class Settings(BaseSettings): + model_config = SettingsConfigDict(env_file=".env", extra="ignore") + + DATABASE_URL: str = "sqlite+aiosqlite:///./data/claudia_docs.db" + JWT_SECRET_KEY: str + JWT_ALGORITHM: str = "HS256" + JWT_ACCESS_TOKEN_EXPIRE_MINUTES: int = 15 + JWT_REFRESH_TOKEN_EXPIRE_DAYS: int = 7 + CORS_ORIGINS: str = "http://localhost:5173,http://localhost:80" + HOST: str = "0.0.0.0" + PORT: int = 8000 + LOG_LEVEL: str = "INFO" + INITIAL_ADMIN_USERNAME: str = "admin" + INITIAL_ADMIN_PASSWORD: str = "admin123" + + @property + def resolved_database_url(self) -> str: + return _resolve_db_url(self.DATABASE_URL) + + @property + def cors_origins_list(self) -> list[str]: + return [o.strip() for o in self.CORS_ORIGINS.split(",") if o.strip()] + + +settings = Settings() + +# Validate required secrets at startup +if not settings.JWT_SECRET_KEY or settings.JWT_SECRET_KEY == "change-me-super-secret-key-min32chars!!": + raise ValueError("JWT_SECRET_KEY must be set in environment variables") diff --git a/app/database.py b/app/database.py new file mode 100644 index 0000000..eed06ec --- /dev/null +++ b/app/database.py @@ -0,0 +1,247 @@ +import asyncio +import hashlib +import uuid +from contextlib import asynccontextmanager +from pathlib import Path + +from sqlalchemy import create_engine, event, text +from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine +from sqlalchemy.orm import DeclarativeBase, sessionmaker + +from app.config import settings + +# Async engine for aiosqlite +DATABASE_URL = settings.resolved_database_url + +# Sync engine for migrations and initial setup +SYNC_DATABASE_URL = DATABASE_URL.replace("sqlite+aiosqlite:///", "sqlite:///") + + +class Base(DeclarativeBase): + pass + + +# Async session factory +async_engine = create_async_engine(DATABASE_URL, echo=False) +AsyncSessionLocal = async_sessionmaker(async_engine, class_=AsyncSession, expire_on_commit=False) + + +# Sync engine for migrations +sync_engine = create_engine(SYNC_DATABASE_URL, echo=False) +SyncSessionLocal = sessionmaker(sync_engine) + + +async def get_db(): + async with AsyncSessionLocal() as session: + try: + yield session + await session.commit() + except Exception: + await session.rollback() + raise + finally: + await session.close() + + +async def get_db_simple(): + """Bare async generator for FastAPI dependency injection.""" + async with AsyncSessionLocal() as session: + yield session + + +async def init_db(): + """Initialize database with all tables, views, FTS5, and triggers.""" + # Create data directory + Path("./data").mkdir(exist_ok=True) + + async with async_engine.begin() as conn: + # Create all tables via SQL (not ORM) to handle SQLite-specific features + await conn.run_sync(_create_schema) + + +def _create_schema(sync_conn): + """Create all tables, views, FTS5 tables, and triggers synchronously.""" + + # Agents table + sync_conn.execute(text(""" + CREATE TABLE IF NOT EXISTS agents ( + id TEXT PRIMARY KEY, + username TEXT UNIQUE NOT NULL, + password_hash TEXT NOT NULL, + role TEXT NOT NULL DEFAULT 'agent' CHECK (role IN ('agent', 'admin')), + is_deleted INTEGER NOT NULL DEFAULT 0, + deleted_at TIMESTAMP NULL, + deleted_by TEXT NULL, + created_at TIMESTAMP NOT NULL DEFAULT (datetime('now')), + updated_at TIMESTAMP NOT NULL DEFAULT (datetime('now')) + ) + """)) + + # Projects table + sync_conn.execute(text(""" + CREATE TABLE IF NOT EXISTS projects ( + id TEXT PRIMARY KEY, + name TEXT NOT NULL, + description TEXT, + agent_id TEXT NOT NULL REFERENCES agents(id), + is_deleted INTEGER NOT NULL DEFAULT 0, + deleted_at TIMESTAMP NULL, + deleted_by TEXT NULL, + created_at TIMESTAMP NOT NULL DEFAULT (datetime('now')), + updated_at TIMESTAMP NOT NULL DEFAULT (datetime('now')) + ) + """)) + + # Folders table + sync_conn.execute(text(""" + CREATE TABLE IF NOT EXISTS folders ( + id TEXT PRIMARY KEY, + name TEXT NOT NULL, + project_id TEXT NOT NULL REFERENCES projects(id) ON DELETE CASCADE, + parent_id TEXT REFERENCES folders(id) ON DELETE CASCADE, + path TEXT NOT NULL, + is_deleted INTEGER NOT NULL DEFAULT 0, + deleted_at TIMESTAMP NULL, + deleted_by TEXT NULL, + created_at TIMESTAMP NOT NULL DEFAULT (datetime('now')), + updated_at TIMESTAMP NOT NULL DEFAULT (datetime('now')) + ) + """)) + + # Documents table + sync_conn.execute(text(""" + CREATE TABLE IF NOT EXISTS documents ( + id TEXT PRIMARY KEY, + title TEXT NOT NULL, + content TEXT NOT NULL DEFAULT '', + project_id TEXT NOT NULL REFERENCES projects(id) ON DELETE CASCADE, + folder_id TEXT REFERENCES folders(id) ON DELETE SET NULL, + path TEXT NOT NULL, + is_deleted INTEGER NOT NULL DEFAULT 0, + deleted_at TIMESTAMP NULL, + deleted_by TEXT NULL, + created_at TIMESTAMP NOT NULL DEFAULT (datetime('now')), + updated_at TIMESTAMP NOT NULL DEFAULT (datetime('now')) + ) + """)) + + # Tags table + sync_conn.execute(text(""" + CREATE TABLE IF NOT EXISTS tags ( + id TEXT PRIMARY KEY, + name TEXT UNIQUE NOT NULL, + color TEXT NOT NULL DEFAULT '#6366f1', + is_deleted INTEGER NOT NULL DEFAULT 0, + deleted_at TIMESTAMP NULL, + deleted_by TEXT NULL, + created_at TIMESTAMP NOT NULL DEFAULT (datetime('now')) + ) + """)) + + # Document tags junction + sync_conn.execute(text(""" + CREATE TABLE IF NOT EXISTS document_tags ( + document_id TEXT NOT NULL REFERENCES documents(id) ON DELETE CASCADE, + tag_id TEXT NOT NULL REFERENCES tags(id) ON DELETE CASCADE, + PRIMARY KEY (document_id, tag_id) + ) + """)) + + # Refresh tokens table + sync_conn.execute(text(""" + CREATE TABLE IF NOT EXISTS refresh_tokens ( + id TEXT PRIMARY KEY, + user_id TEXT NOT NULL REFERENCES agents(id), + token_hash TEXT NOT NULL UNIQUE, + token_family_id TEXT NOT NULL, + token_version INTEGER NOT NULL, + user_agent TEXT, + ip_address TEXT, + created_at TIMESTAMP NOT NULL DEFAULT (datetime('now')), + expires_at TIMESTAMP NOT NULL, + revoked_at TIMESTAMP NULL, + is_global_logout INTEGER NOT NULL DEFAULT 0 + ) + """)) + + # JWT blocklist table + sync_conn.execute(text(""" + CREATE TABLE IF NOT EXISTS jwt_blocklist ( + token_id TEXT PRIMARY KEY, + revoked_at TIMESTAMP NOT NULL DEFAULT (datetime('now')), + expires_at TIMESTAMP NOT NULL + ) + """)) + + # Indexes + sync_conn.execute(text("CREATE INDEX IF NOT EXISTS idx_projects_agent ON projects(agent_id)")) + sync_conn.execute(text("CREATE INDEX IF NOT EXISTS idx_folders_project ON folders(project_id)")) + sync_conn.execute(text("CREATE INDEX IF NOT EXISTS idx_folders_parent ON folders(parent_id)")) + sync_conn.execute(text("CREATE INDEX IF NOT EXISTS idx_documents_project ON documents(project_id)")) + sync_conn.execute(text("CREATE INDEX IF NOT EXISTS idx_documents_folder ON documents(folder_id)")) + sync_conn.execute(text("CREATE INDEX IF NOT EXISTS idx_document_tags_doc ON document_tags(document_id)")) + sync_conn.execute(text("CREATE INDEX IF NOT EXISTS idx_document_tags_tag ON document_tags(tag_id)")) + sync_conn.execute(text("CREATE INDEX IF NOT EXISTS idx_refresh_tokens_hash ON refresh_tokens(token_hash)")) + sync_conn.execute(text("CREATE INDEX IF NOT EXISTS idx_refresh_tokens_user_family ON refresh_tokens(user_id, token_family_id)")) + + # --- FTS5 virtual table --- + sync_conn.execute(text(""" + CREATE VIRTUAL TABLE IF NOT EXISTS documents_fts USING fts5( + document_id, + title, + content, + path, + tokenize='unicode61 remove_diacritics 1' + ) + """)) + + # --- Active views for soft deletes --- + sync_conn.execute(text(""" + CREATE VIEW IF NOT EXISTS active_agents AS + SELECT * FROM agents WHERE is_deleted = 0 + """)) + sync_conn.execute(text(""" + CREATE VIEW IF NOT EXISTS active_projects AS + SELECT * FROM projects WHERE is_deleted = 0 + """)) + sync_conn.execute(text(""" + CREATE VIEW IF NOT EXISTS active_folders AS + SELECT * FROM folders WHERE is_deleted = 0 + """)) + sync_conn.execute(text(""" + CREATE VIEW IF NOT EXISTS active_documents AS + SELECT * FROM documents WHERE is_deleted = 0 + """)) + sync_conn.execute(text(""" + CREATE VIEW IF NOT EXISTS active_tags AS + SELECT * FROM tags WHERE is_deleted = 0 + """)) + + # --- FTS5 Sync Triggers --- + # Insert trigger + sync_conn.execute(text(""" + CREATE TRIGGER IF NOT EXISTS documents_fts_ai AFTER INSERT ON documents BEGIN + INSERT INTO documents_fts(document_id, title, content, path) + VALUES (new.id, new.title, new.content, new.path); + END + """)) + + # Update trigger (delete old + insert new) + sync_conn.execute(text(""" + CREATE TRIGGER IF NOT EXISTS documents_fts_au AFTER UPDATE ON documents BEGIN + DELETE FROM documents_fts WHERE document_id = old.id; + INSERT INTO documents_fts(document_id, title, content, path) + VALUES (new.id, new.title, new.content, new.path); + END + """)) + + # Soft-delete trigger (when is_deleted becomes TRUE) + sync_conn.execute(text(""" + CREATE TRIGGER IF NOT EXISTS documents_fts_ad AFTER UPDATE ON documents + WHEN new.is_deleted = 1 AND old.is_deleted = 0 + BEGIN + DELETE FROM documents_fts WHERE document_id = old.id; + END + """)) + + sync_conn.commit() diff --git a/app/main.py b/app/main.py new file mode 100644 index 0000000..c4a669e --- /dev/null +++ b/app/main.py @@ -0,0 +1,63 @@ +import asyncio +from contextlib import asynccontextmanager +from datetime import datetime + +from fastapi import FastAPI, Request +from fastapi.middleware.cors import CORSMiddleware +from fastapi.responses import JSONResponse + +from app.config import settings +from app.database import init_db, get_db, async_engine +from app.routers import auth, projects, folders, documents, tags, search +from app.services.auth import cleanup_expired_blocklist + + +@asynccontextmanager +async def lifespan(app: FastAPI): + # Startup: init database + await init_db() + yield + # Shutdown + await async_engine.dispose() + + +app = FastAPI( + title="Claudia Docs API", + description="Gestor documental para agentes de IA", + version="1.0.0", + lifespan=lifespan, +) + +# CORS +app.add_middleware( + CORSMiddleware, + allow_origins=settings.cors_origins_list, + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + + +# Include routers +app.include_router(auth.router) +app.include_router(projects.router) +app.include_router(folders.router) +app.include_router(documents.router) +app.include_router(tags.router) +app.include_router(search.router) + + +@app.get("/api/v1/health") +async def health(): + return {"status": "ok", "timestamp": datetime.utcnow().isoformat()} + + +@app.exception_handler(Exception) +async def global_exception_handler(request: Request, exc: Exception): + import logging + logger = logging.getLogger(__name__) + logger.error(f"Unhandled exception: {exc}", exc_info=True) + return JSONResponse( + status_code=500, + content={"detail": "Internal server error"}, + ) diff --git a/app/models/__init__.py b/app/models/__init__.py new file mode 100644 index 0000000..f3d9f4b --- /dev/null +++ b/app/models/__init__.py @@ -0,0 +1 @@ +# Models package diff --git a/app/models/agent.py b/app/models/agent.py new file mode 100644 index 0000000..1999040 --- /dev/null +++ b/app/models/agent.py @@ -0,0 +1,25 @@ +import uuid +from datetime import datetime + +from sqlalchemy import Boolean, DateTime, String, Text +from sqlalchemy.orm import Mapped, mapped_column + +from app.database import Base + + +def generate_uuid() -> str: + return str(uuid.uuid4()) + + +class Agent(Base): + __tablename__ = "agents" + + id: Mapped[str] = mapped_column(String(36), primary_key=True, default=generate_uuid) + username: Mapped[str] = mapped_column(String(255), unique=True, nullable=False) + password_hash: Mapped[str] = mapped_column(Text, nullable=False) + role: Mapped[str] = mapped_column(String(20), nullable=False, default="agent") + is_deleted: Mapped[bool] = mapped_column(Boolean, default=False, nullable=False) + deleted_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True) + deleted_by: Mapped[str | None] = mapped_column(String(36), nullable=True) + created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow, nullable=False) + updated_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow, nullable=False) diff --git a/app/models/document.py b/app/models/document.py new file mode 100644 index 0000000..52860ed --- /dev/null +++ b/app/models/document.py @@ -0,0 +1,27 @@ +import uuid +from datetime import datetime + +from sqlalchemy import Boolean, DateTime, ForeignKey, String, Text +from sqlalchemy.orm import Mapped, mapped_column, relationship + +from app.database import Base + + +def generate_uuid() -> str: + return str(uuid.uuid4()) + + +class Document(Base): + __tablename__ = "documents" + + id: Mapped[str] = mapped_column(String(36), primary_key=True, default=generate_uuid) + title: Mapped[str] = mapped_column(String(500), nullable=False) + content: Mapped[str] = mapped_column(Text, nullable=False, default="") + project_id: Mapped[str] = mapped_column(String(36), ForeignKey("projects.id", ondelete="CASCADE"), nullable=False) + folder_id: Mapped[str | None] = mapped_column(String(36), ForeignKey("folders.id", ondelete="SET NULL"), nullable=True) + path: Mapped[str] = mapped_column(Text, nullable=False) + is_deleted: Mapped[bool] = mapped_column(Boolean, default=False, nullable=False) + deleted_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True) + deleted_by: Mapped[str | None] = mapped_column(String(36), nullable=True) + created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow, nullable=False) + updated_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow, nullable=False) diff --git a/app/models/folder.py b/app/models/folder.py new file mode 100644 index 0000000..581f4a4 --- /dev/null +++ b/app/models/folder.py @@ -0,0 +1,26 @@ +import uuid +from datetime import datetime + +from sqlalchemy import Boolean, DateTime, ForeignKey, String, Text +from sqlalchemy.orm import Mapped, mapped_column, relationship + +from app.database import Base + + +def generate_uuid() -> str: + return str(uuid.uuid4()) + + +class Folder(Base): + __tablename__ = "folders" + + id: Mapped[str] = mapped_column(String(36), primary_key=True, default=generate_uuid) + name: Mapped[str] = mapped_column(String(255), nullable=False) + project_id: Mapped[str] = mapped_column(String(36), ForeignKey("projects.id", ondelete="CASCADE"), nullable=False) + parent_id: Mapped[str | None] = mapped_column(String(36), ForeignKey("folders.id", ondelete="CASCADE"), nullable=True) + path: Mapped[str] = mapped_column(Text, nullable=False) + is_deleted: Mapped[bool] = mapped_column(Boolean, default=False, nullable=False) + deleted_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True) + deleted_by: Mapped[str | None] = mapped_column(String(36), nullable=True) + created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow, nullable=False) + updated_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow, nullable=False) diff --git a/app/models/project.py b/app/models/project.py new file mode 100644 index 0000000..88c25d0 --- /dev/null +++ b/app/models/project.py @@ -0,0 +1,25 @@ +import uuid +from datetime import datetime + +from sqlalchemy import Boolean, DateTime, ForeignKey, String, Text +from sqlalchemy.orm import Mapped, mapped_column, relationship + +from app.database import Base + + +def generate_uuid() -> str: + return str(uuid.uuid4()) + + +class Project(Base): + __tablename__ = "projects" + + id: Mapped[str] = mapped_column(String(36), primary_key=True, default=generate_uuid) + name: Mapped[str] = mapped_column(String(255), nullable=False) + description: Mapped[str | None] = mapped_column(Text, nullable=True) + agent_id: Mapped[str] = mapped_column(String(36), ForeignKey("agents.id"), nullable=False) + is_deleted: Mapped[bool] = mapped_column(Boolean, default=False, nullable=False) + deleted_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True) + deleted_by: Mapped[str | None] = mapped_column(String(36), nullable=True) + created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow, nullable=False) + updated_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow, nullable=False) diff --git a/app/models/refresh_token.py b/app/models/refresh_token.py new file mode 100644 index 0000000..1f9077b --- /dev/null +++ b/app/models/refresh_token.py @@ -0,0 +1,35 @@ +import uuid +from datetime import datetime + +from sqlalchemy import Boolean, DateTime, Integer, String, Text +from sqlalchemy.orm import Mapped, mapped_column + +from app.database import Base + + +def generate_uuid() -> str: + return str(uuid.uuid4()) + + +class RefreshToken(Base): + __tablename__ = "refresh_tokens" + + id: Mapped[str] = mapped_column(String(36), primary_key=True, default=generate_uuid) + user_id: Mapped[str] = mapped_column(String(36), nullable=False) + token_hash: Mapped[str] = mapped_column(Text, nullable=False, unique=True) + token_family_id: Mapped[str] = mapped_column(String(36), nullable=False) + token_version: Mapped[int] = mapped_column(Integer, nullable=False, default=1) + user_agent: Mapped[str | None] = mapped_column(Text, nullable=True) + ip_address: Mapped[str | None] = mapped_column(Text, nullable=True) + created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow, nullable=False) + expires_at: Mapped[datetime] = mapped_column(DateTime, nullable=False) + revoked_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True) + is_global_logout: Mapped[bool] = mapped_column(Boolean, default=False, nullable=False) + + +class JwtBlocklist(Base): + __tablename__ = "jwt_blocklist" + + token_id: Mapped[str] = mapped_column(String(36), primary_key=True) + revoked_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow, nullable=False) + expires_at: Mapped[datetime] = mapped_column(DateTime, nullable=False) diff --git a/app/models/tag.py b/app/models/tag.py new file mode 100644 index 0000000..d20a3ff --- /dev/null +++ b/app/models/tag.py @@ -0,0 +1,30 @@ +import uuid +from datetime import datetime + +from sqlalchemy import Boolean, Column, DateTime, ForeignKey, String, Table, Text +from sqlalchemy.orm import Mapped, mapped_column, relationship + +from app.database import Base + + +def generate_uuid() -> str: + return str(uuid.uuid4()) + + +class Tag(Base): + __tablename__ = "tags" + + id: Mapped[str] = mapped_column(String(36), primary_key=True, default=generate_uuid) + name: Mapped[str] = mapped_column(String(100), unique=True, nullable=False) + color: Mapped[str] = mapped_column(String(7), nullable=False, default="#6366f1") + is_deleted: Mapped[bool] = mapped_column(Boolean, default=False, nullable=False) + deleted_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True) + deleted_by: Mapped[str | None] = mapped_column(String(36), nullable=True) + created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow, nullable=False) + + +class DocumentTag(Base): + __tablename__ = "document_tags" + + document_id: Mapped[str] = mapped_column(String(36), ForeignKey("documents.id", ondelete="CASCADE"), primary_key=True) + tag_id: Mapped[str] = mapped_column(String(36), ForeignKey("tags.id", ondelete="CASCADE"), primary_key=True) diff --git a/app/routers/__init__.py b/app/routers/__init__.py new file mode 100644 index 0000000..873f7bb --- /dev/null +++ b/app/routers/__init__.py @@ -0,0 +1 @@ +# Routers package diff --git a/app/routers/auth.py b/app/routers/auth.py new file mode 100644 index 0000000..aaf16be --- /dev/null +++ b/app/routers/auth.py @@ -0,0 +1,167 @@ +from datetime import datetime + +from fastapi import APIRouter, Depends, HTTPException, Request, Response +from sqlalchemy.ext.asyncio import AsyncSession + +from app.database import get_db +from app.models.agent import Agent +from app.schemas.auth import AgentCreate, AgentLogin, AgentResponse, RefreshResponse, TokenResponse +from app.services import auth as auth_service + +router = APIRouter(prefix="/api/v1/auth", tags=["auth"]) + +COOKIE_NAME = "refresh_token" +COOKIE_PATH = "/api/v1/auth/refresh" +COOKIE_MAX_AGE = 60 * 60 * 24 * 7 # 7 days + + +def _set_refresh_cookie(response: Response, token: str): + response.set_cookie( + key=COOKIE_NAME, + value=token, + max_age=COOKIE_MAX_AGE, + httponly=True, + secure=True, # True in production with HTTPS + samesite="lax", + path=COOKIE_PATH, + ) + + +def _clear_refresh_cookie(response: Response): + response.set_cookie( + key=COOKIE_NAME, + value="", + max_age=0, + httponly=True, + secure=True, # True in production with HTTPS + samesite="lax", + path=COOKIE_PATH, + ) + + +async def get_current_agent(request: Request, db: AsyncSession) -> Agent: + """Get the current authenticated agent from request.""" + auth_header = request.headers.get("authorization", "") + if not auth_header.startswith("Bearer "): + raise HTTPException(status_code=401, detail="Not authenticated") + + token = auth_header[7:] + payload = auth_service.decode_token(token) + if not payload: + raise HTTPException(status_code=401, detail="Invalid token") + + jti = payload.get("jti") + if jti: + is_blocked = await auth_service.is_token_blocklisted(db, jti) + if is_blocked: + raise HTTPException(status_code=401, detail="Token revoked") + + agent_id = payload.get("sub") + if not agent_id: + raise HTTPException(status_code=401, detail="Invalid token payload") + + agent = await auth_service.get_agent_by_id(db, agent_id) + if not agent: + raise HTTPException(status_code=401, detail="Agent not found") + + return agent + + +@router.post("/register", response_model=AgentResponse, status_code=201) +async def register(payload: AgentCreate, db: AsyncSession = Depends(get_db)): + existing = await auth_service.get_agent_by_username(db, payload.username) + if existing: + raise HTTPException(status_code=400, detail="Username already exists") + + agent = await auth_service.create_agent(db, payload.username, payload.password) + return AgentResponse.model_validate(agent) + + +@router.post("/login", response_model=TokenResponse) +async def login( + payload: AgentLogin, + request: Request, + response: Response, + db: AsyncSession = Depends(get_db), +): + agent = await auth_service.get_agent_by_username(db, payload.username) + if not agent or not auth_service.verify_password(payload.password, agent.password_hash): + raise HTTPException(status_code=401, detail="Invalid credentials") + + access_token, jti = auth_service.create_access_token(agent.id, agent.role) + + refresh_token = auth_service.create_refresh_token() + user_agent = request.headers.get("user-agent") + client_ip = request.client.host if request.client else None + await auth_service.save_refresh_token(db, agent.id, refresh_token, user_agent, client_ip) + + _set_refresh_cookie(response, refresh_token) + return TokenResponse(access_token=access_token) + + +@router.get("/me", response_model=AgentResponse) +async def get_me(request: Request, db: AsyncSession = Depends(get_db)): + agent = await get_current_agent(request, db) + return AgentResponse.model_validate(agent) + + +@router.post("/refresh", response_model=RefreshResponse) +async def refresh(request: Request, response: Response, db: AsyncSession = Depends(get_db)): + token = request.cookies.get(COOKIE_NAME) + if not token: + raise HTTPException(status_code=401, detail="No refresh token") + + user_agent = request.headers.get("user-agent") + client_ip = request.client.host if request.client else None + + result = await auth_service.rotate_refresh_token(db, token, user_agent, client_ip) + if not result: + raise HTTPException(status_code=401, detail="Invalid or expired refresh token") + + new_rt, new_token = result + + agent = await auth_service.get_agent_by_id(db, new_rt.user_id) + if not agent: + raise HTTPException(status_code=401, detail="Agent not found") + + access_token, jti = auth_service.create_access_token(agent.id, agent.role) + _set_refresh_cookie(response, new_token) + return RefreshResponse(access_token=access_token) + + +@router.post("/logout", status_code=204) +async def logout(request: Request, response: Response, db: AsyncSession = Depends(get_db)): + token = request.cookies.get(COOKIE_NAME) + if token: + await auth_service.revoke_refresh_token(db, token) + + auth_header = request.headers.get("authorization", "") + if auth_header.startswith("Bearer "): + access_token = auth_header[7:] + payload = auth_service.decode_token(access_token) + if payload and "jti" in payload: + exp = datetime.utcfromtimestamp(payload["exp"]) + await auth_service.add_to_blocklist(db, payload["jti"], exp) + + _clear_refresh_cookie(response) + return Response(status_code=204) + + +@router.post("/logout-all", status_code=204) +async def logout_all(request: Request, response: Response, db: AsyncSession = Depends(get_db)): + auth_header = request.headers.get("authorization", "") + agent = None + if auth_header.startswith("Bearer "): + access_token = auth_header[7:] + payload = auth_service.decode_token(access_token) + if payload and "sub" in payload: + agent = await auth_service.get_agent_by_id(db, payload["sub"]) + + if agent: + await auth_service.revoke_all_user_tokens(db, agent.id) + if payload and "jti" in payload: + exp = datetime.utcfromtimestamp(payload["exp"]) + await auth_service.add_to_blocklist(db, payload["jti"], exp) + + _clear_refresh_cookie(response) + return Response(status_code=204) diff --git a/app/routers/documents.py b/app/routers/documents.py new file mode 100644 index 0000000..0745faa --- /dev/null +++ b/app/routers/documents.py @@ -0,0 +1,435 @@ +import uuid +from datetime import datetime + +from fastapi import APIRouter, Depends, HTTPException, Request +from sqlalchemy import delete, select, text +from sqlalchemy.ext.asyncio import AsyncSession + +from app.database import get_db +from app.models.document import Document +from app.models.folder import Folder +from app.models.project import Project +from app.models.tag import DocumentTag, Tag +from app.routers.auth import get_current_agent +from app.schemas.document import ( + DocumentBriefResponse, + DocumentContentUpdate, + DocumentCreate, + DocumentListResponse, + DocumentResponse, + DocumentUpdate, + TagInfo, +) +from app.schemas.tag import DocumentTagsAssign + + +router = APIRouter(tags=["documents"]) + + +def build_doc_path(project_id: str, doc_id: str, folder_id: str | None, folder_path: str | None) -> str: + if folder_id and folder_path: + return f"{folder_path}/{doc_id}" + return f"/{project_id}/{doc_id}" + + +async def get_document_tags(db: AsyncSession, doc_id: str) -> list[TagInfo]: + result = await db.execute( + text(""" + SELECT t.id, t.name, t.color + FROM active_tags t + JOIN document_tags dt ON t.id = dt.tag_id + WHERE dt.document_id = :doc_id + """), + {"doc_id": doc_id} + ) + rows = result.fetchall() + return [TagInfo(id=r.id, name=r.name, color=r.color) for r in rows] + + +async def document_to_response(db: AsyncSession, doc: Document) -> DocumentResponse: + tags = await get_document_tags(db, doc.id) + return DocumentResponse( + id=doc.id, + title=doc.title, + content=doc.content, + project_id=doc.project_id, + folder_id=doc.folder_id, + path=doc.path, + tags=tags, + created_at=doc.created_at, + updated_at=doc.updated_at, + ) + + +@router.get("/api/v1/projects/{project_id}/documents", response_model=DocumentListResponse) +async def list_documents( + request: Request, + project_id: str, + db: AsyncSession = Depends(get_db), +): + agent = await get_current_agent(request, db) + + proj_result = await db.execute( + select(Project).where( + Project.id == project_id, + Project.agent_id == agent.id, + Project.is_deleted == False, + ) + ) + if not proj_result.scalar_one_or_none(): + raise HTTPException(status_code=404, detail="Project not found") + + result = await db.execute( + select(Document).where( + Document.project_id == project_id, + Document.is_deleted == False, + ).order_by(Document.created_at.desc()) + ) + docs = result.scalars().all() + + responses = [] + for doc in docs: + tags = await get_document_tags(db, doc.id) + responses.append(DocumentBriefResponse( + id=doc.id, + title=doc.title, + project_id=doc.project_id, + folder_id=doc.folder_id, + path=doc.path, + tags=tags, + created_at=doc.created_at, + updated_at=doc.updated_at, + )) + + return DocumentListResponse(documents=responses) + + +@router.post("/api/v1/projects/{project_id}/documents", response_model=DocumentResponse, status_code=201) +async def create_document( + request: Request, + project_id: str, + payload: DocumentCreate, + db: AsyncSession = Depends(get_db), +): + agent = await get_current_agent(request, db) + + proj_result = await db.execute( + select(Project).where( + Project.id == project_id, + Project.agent_id == agent.id, + Project.is_deleted == False, + ) + ) + if not proj_result.scalar_one_or_none(): + raise HTTPException(status_code=404, detail="Project not found") + + folder_path = None + if payload.folder_id: + folder_result = await db.execute( + select(Folder).where( + Folder.id == payload.folder_id, + Folder.project_id == project_id, + Folder.is_deleted == False, + ) + ) + folder = folder_result.scalar_one_or_none() + if not folder: + raise HTTPException(status_code=400, detail="Folder not found") + folder_path = folder.path + + doc_id = str(uuid.uuid4()) + path = build_doc_path(project_id, doc_id, payload.folder_id, folder_path) + + doc = Document( + id=doc_id, + title=payload.title, + content=payload.content, + project_id=project_id, + folder_id=payload.folder_id, + path=path, + ) + db.add(doc) + await db.flush() + return await document_to_response(db, doc) + + +@router.get("/api/v1/documents/{document_id}", response_model=DocumentResponse) +async def get_document( + request: Request, + document_id: str, + db: AsyncSession = Depends(get_db), +): + agent = await get_current_agent(request, db) + + result = await db.execute( + select(Document).where( + Document.id == document_id, + Document.is_deleted == False, + ) + ) + doc = result.scalar_one_or_none() + if not doc: + raise HTTPException(status_code=404, detail="Document not found") + + proj_result = await db.execute( + select(Project).where( + Project.id == doc.project_id, + Project.agent_id == agent.id, + Project.is_deleted == False, + ) + ) + if not proj_result.scalar_one_or_none(): + raise HTTPException(status_code=404, detail="Document not found") + + return await document_to_response(db, doc) + + +@router.put("/api/v1/documents/{document_id}", response_model=DocumentResponse) +async def update_document( + request: Request, + document_id: str, + payload: DocumentUpdate, + db: AsyncSession = Depends(get_db), +): + agent = await get_current_agent(request, db) + + result = await db.execute( + select(Document).where( + Document.id == document_id, + Document.is_deleted == False, + ) + ) + doc = result.scalar_one_or_none() + if not doc: + raise HTTPException(status_code=404, detail="Document not found") + + proj_result = await db.execute( + select(Project).where( + Project.id == doc.project_id, + Project.agent_id == agent.id, + Project.is_deleted == False, + ) + ) + if not proj_result.scalar_one_or_none(): + raise HTTPException(status_code=403, detail="Forbidden") + + if payload.title is not None: + doc.title = payload.title + if payload.folder_id is not None: + if payload.folder_id: + folder_result = await db.execute( + select(Folder).where( + Folder.id == payload.folder_id, + Folder.project_id == doc.project_id, + Folder.is_deleted == False, + ) + ) + folder = folder_result.scalar_one_or_none() + if not folder: + raise HTTPException(status_code=400, detail="Folder not found") + doc.path = f"{folder.path}/{doc.id}" + else: + doc.path = f"/{doc.project_id}/{doc.id}" + doc.folder_id = payload.folder_id + + doc.updated_at = datetime.utcnow() + await db.flush() + return await document_to_response(db, doc) + + +@router.delete("/api/v1/documents/{document_id}", status_code=204) +async def delete_document( + request: Request, + document_id: str, + db: AsyncSession = Depends(get_db), +): + agent = await get_current_agent(request, db) + + result = await db.execute( + select(Document).where( + Document.id == document_id, + Document.is_deleted == False, + ) + ) + doc = result.scalar_one_or_none() + if not doc: + raise HTTPException(status_code=404, detail="Document not found") + + proj_result = await db.execute( + select(Project).where( + Project.id == doc.project_id, + Project.agent_id == agent.id, + Project.is_deleted == False, + ) + ) + if not proj_result.scalar_one_or_none(): + raise HTTPException(status_code=403, detail="Forbidden") + + doc.is_deleted = True + doc.deleted_at = datetime.utcnow() + doc.deleted_by = agent.id + await db.flush() + return None + + +@router.put("/api/v1/documents/{document_id}/content", response_model=DocumentResponse) +async def update_document_content( + request: Request, + document_id: str, + payload: DocumentContentUpdate, + db: AsyncSession = Depends(get_db), +): + agent = await get_current_agent(request, db) + + result = await db.execute( + select(Document).where( + Document.id == document_id, + Document.is_deleted == False, + ) + ) + doc = result.scalar_one_or_none() + if not doc: + raise HTTPException(status_code=404, detail="Document not found") + + proj_result = await db.execute( + select(Project).where( + Project.id == doc.project_id, + Project.agent_id == agent.id, + Project.is_deleted == False, + ) + ) + if not proj_result.scalar_one_or_none(): + raise HTTPException(status_code=403, detail="Forbidden") + + doc.content = payload.content + doc.updated_at = datetime.utcnow() + await db.flush() + return await document_to_response(db, doc) + + +@router.post("/api/v1/documents/{document_id}/restore", response_model=DocumentResponse) +async def restore_document( + request: Request, + document_id: str, + db: AsyncSession = Depends(get_db), +): + agent = await get_current_agent(request, db) + + result = await db.execute( + select(Document).where( + Document.id == document_id, + Document.is_deleted == True, + ) + ) + doc = result.scalar_one_or_none() + if not doc: + raise HTTPException(status_code=404, detail="Document not found") + + proj_result = await db.execute( + select(Project).where( + Project.id == doc.project_id, + Project.agent_id == agent.id, + Project.is_deleted == False, + ) + ) + if not proj_result.scalar_one_or_none(): + raise HTTPException(status_code=403, detail="Forbidden") + + doc.is_deleted = False + doc.deleted_at = None + doc.deleted_by = None + await db.flush() + return await document_to_response(db, doc) + + +@router.post("/api/v1/documents/{document_id}/tags", status_code=204) +async def assign_tags( + request: Request, + document_id: str, + payload: DocumentTagsAssign, + db: AsyncSession = Depends(get_db), +): + agent = await get_current_agent(request, db) + + result = await db.execute( + select(Document).where( + Document.id == document_id, + Document.is_deleted == False, + ) + ) + doc = result.scalar_one_or_none() + if not doc: + raise HTTPException(status_code=404, detail="Document not found") + + proj_result = await db.execute( + select(Project).where( + Project.id == doc.project_id, + Project.agent_id == agent.id, + Project.is_deleted == False, + ) + ) + if not proj_result.scalar_one_or_none(): + raise HTTPException(status_code=403, detail="Forbidden") + + for tag_id in payload.tag_ids: + tag_result = await db.execute( + select(Tag).where( + Tag.id == tag_id, + Tag.is_deleted == False, + ) + ) + tag = tag_result.scalar_one_or_none() + if not tag: + raise HTTPException(status_code=400, detail=f"Tag {tag_id} not found") + + existing = await db.execute( + select(DocumentTag).where( + DocumentTag.document_id == document_id, + DocumentTag.tag_id == tag_id, + ) + ) + if not existing.scalar_one_or_none(): + dt = DocumentTag(document_id=document_id, tag_id=tag_id) + db.add(dt) + + await db.flush() + return None + + +@router.delete("/api/v1/documents/{document_id}/tags/{tag_id}", status_code=204) +async def remove_tag( + request: Request, + document_id: str, + tag_id: str, + db: AsyncSession = Depends(get_db), +): + agent = await get_current_agent(request, db) + + result = await db.execute( + select(Document).where( + Document.id == document_id, + Document.is_deleted == False, + ) + ) + doc = result.scalar_one_or_none() + if not doc: + raise HTTPException(status_code=404, detail="Document not found") + + proj_result = await db.execute( + select(Project).where( + Project.id == doc.project_id, + Project.agent_id == agent.id, + Project.is_deleted == False, + ) + ) + if not proj_result.scalar_one_or_none(): + raise HTTPException(status_code=403, detail="Forbidden") + + await db.execute( + delete(DocumentTag).where( + DocumentTag.document_id == document_id, + DocumentTag.tag_id == tag_id, + ) + ) + await db.flush() + return None diff --git a/app/routers/folders.py b/app/routers/folders.py new file mode 100644 index 0000000..aabf2e9 --- /dev/null +++ b/app/routers/folders.py @@ -0,0 +1,251 @@ +import uuid +from datetime import datetime + +from fastapi import APIRouter, Depends, HTTPException, Query, Request +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from app.database import get_db +from app.models.folder import Folder +from app.models.project import Project +from app.schemas.folder import FolderCreate, FolderListResponse, FolderResponse, FolderUpdate +from app.routers.auth import get_current_agent + +router = APIRouter(tags=["folders"]) + + +def build_folder_path(project_id: str, folder_id: str, parent_id: str | None, parent_path: str | None) -> str: + if parent_id and parent_path: + return f"{parent_path}/{folder_id}" + return f"/{project_id}/{folder_id}" + + +@router.get("/api/v1/projects/{project_id}/folders", response_model=FolderListResponse) +async def list_folders( + request: Request, + project_id: str, + parent_id: str | None = Query(None), + db: AsyncSession = Depends(get_db), +): + agent = await get_current_agent(request, db) + + proj_result = await db.execute( + select(Project).where( + Project.id == project_id, + Project.agent_id == agent.id, + Project.is_deleted == False, + ) + ) + if not proj_result.scalar_one_or_none(): + raise HTTPException(status_code=404, detail="Project not found") + + query = select(Folder).where( + Folder.project_id == project_id, + Folder.parent_id == parent_id, + Folder.is_deleted == False, + ).order_by(Folder.name) + + result = await db.execute(query) + folders = result.scalars().all() + return FolderListResponse(folders=[FolderResponse.model_validate(f) for f in folders]) + + +@router.post("/api/v1/projects/{project_id}/folders", response_model=FolderResponse, status_code=201) +async def create_folder( + request: Request, + project_id: str, + payload: FolderCreate, + db: AsyncSession = Depends(get_db), +): + agent = await get_current_agent(request, db) + + proj_result = await db.execute( + select(Project).where( + Project.id == project_id, + Project.agent_id == agent.id, + Project.is_deleted == False, + ) + ) + if not proj_result.scalar_one_or_none(): + raise HTTPException(status_code=404, detail="Project not found") + + parent_path = None + if payload.parent_id: + parent_result = await db.execute( + select(Folder).where( + Folder.id == payload.parent_id, + Folder.project_id == project_id, + Folder.is_deleted == False, + ) + ) + parent = parent_result.scalar_one_or_none() + if not parent: + raise HTTPException(status_code=400, detail="Parent folder not found") + parent_path = parent.path + + folder_id = str(uuid.uuid4()) + path = build_folder_path(project_id, folder_id, payload.parent_id, parent_path) + + folder = Folder( + id=folder_id, + name=payload.name, + project_id=project_id, + parent_id=payload.parent_id, + path=path, + ) + db.add(folder) + await db.flush() + return FolderResponse.model_validate(folder) + + +@router.get("/api/v1/folders/{folder_id}", response_model=FolderResponse) +async def get_folder( + request: Request, + folder_id: str, + db: AsyncSession = Depends(get_db), +): + agent = await get_current_agent(request, db) + + result = await db.execute( + select(Folder).where( + Folder.id == folder_id, + Folder.is_deleted == False, + ) + ) + folder = result.scalar_one_or_none() + if not folder: + raise HTTPException(status_code=404, detail="Folder not found") + + proj_result = await db.execute( + select(Project).where( + Project.id == folder.project_id, + Project.agent_id == agent.id, + Project.is_deleted == False, + ) + ) + if not proj_result.scalar_one_or_none(): + raise HTTPException(status_code=404, detail="Folder not found") + + return FolderResponse.model_validate(folder) + + +@router.put("/api/v1/folders/{folder_id}", response_model=FolderResponse) +async def update_folder( + request: Request, + folder_id: str, + payload: FolderUpdate, + db: AsyncSession = Depends(get_db), +): + agent = await get_current_agent(request, db) + + result = await db.execute( + select(Folder).where( + Folder.id == folder_id, + Folder.is_deleted == False, + ) + ) + folder = result.scalar_one_or_none() + if not folder: + raise HTTPException(status_code=404, detail="Folder not found") + + proj_result = await db.execute( + select(Project).where( + Project.id == folder.project_id, + Project.agent_id == agent.id, + Project.is_deleted == False, + ) + ) + if not proj_result.scalar_one_or_none(): + raise HTTPException(status_code=403, detail="Forbidden") + + if payload.name is not None: + folder.name = payload.name + if payload.parent_id is not None: + if payload.parent_id == folder_id: + raise HTTPException(status_code=400, detail="Cannot set folder as its own parent") + parent_result = await db.execute( + select(Folder).where( + Folder.id == payload.parent_id, + Folder.project_id == folder.project_id, + Folder.is_deleted == False, + ) + ) + parent = parent_result.scalar_one_or_none() + if not parent: + raise HTTPException(status_code=400, detail="Parent folder not found") + folder.parent_id = payload.parent_id + folder.path = f"{parent.path}/{folder.id}" + folder.updated_at = datetime.utcnow() + + await db.flush() + return FolderResponse.model_validate(folder) + + +@router.delete("/api/v1/folders/{folder_id}", status_code=204) +async def delete_folder( + request: Request, + folder_id: str, + db: AsyncSession = Depends(get_db), +): + agent = await get_current_agent(request, db) + + result = await db.execute( + select(Folder).where( + Folder.id == folder_id, + Folder.is_deleted == False, + ) + ) + folder = result.scalar_one_or_none() + if not folder: + raise HTTPException(status_code=404, detail="Folder not found") + + proj_result = await db.execute( + select(Project).where( + Project.id == folder.project_id, + Project.agent_id == agent.id, + Project.is_deleted == False, + ) + ) + if not proj_result.scalar_one_or_none(): + raise HTTPException(status_code=403, detail="Forbidden") + + folder.is_deleted = True + folder.deleted_at = datetime.utcnow() + folder.deleted_by = agent.id + await db.flush() + return None + + +@router.post("/api/v1/folders/{folder_id}/restore", response_model=FolderResponse) +async def restore_folder( + request: Request, + folder_id: str, + db: AsyncSession = Depends(get_db), +): + agent = await get_current_agent(request, db) + + result = await db.execute( + select(Folder).where( + Folder.id == folder_id, + Folder.is_deleted == True, + ) + ) + folder = result.scalar_one_or_none() + if not folder: + raise HTTPException(status_code=404, detail="Folder not found") + + proj_result = await db.execute( + select(Project).where( + Project.id == folder.project_id, + Project.agent_id == agent.id, + Project.is_deleted == False, + ) + ) + if not proj_result.scalar_one_or_none(): + raise HTTPException(status_code=403, detail="Forbidden") + + folder.is_deleted = False + folder.deleted_at = None + folder.deleted_by = None + await db.flush() + return FolderResponse.model_validate(folder) diff --git a/app/routers/projects.py b/app/routers/projects.py new file mode 100644 index 0000000..25c8877 --- /dev/null +++ b/app/routers/projects.py @@ -0,0 +1,146 @@ +import uuid +from datetime import datetime + +from fastapi import APIRouter, Depends, HTTPException, Request +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from app.database import get_db +from app.models.project import Project +from app.schemas.project import ProjectCreate, ProjectListResponse, ProjectResponse, ProjectUpdate +from app.routers.auth import get_current_agent + +router = APIRouter(prefix="/api/v1/projects", tags=["projects"]) + + +@router.get("", response_model=ProjectListResponse) +async def list_projects( + request: Request, + db: AsyncSession = Depends(get_db), +): + agent = await get_current_agent(request, db) + result = await db.execute( + select(Project).where( + Project.agent_id == agent.id, + Project.is_deleted == False, + ).order_by(Project.created_at.desc()) + ) + projects = result.scalars().all() + return ProjectListResponse(projects=[ProjectResponse.model_validate(p) for p in projects]) + + +@router.post("", response_model=ProjectResponse, status_code=201) +async def create_project( + request: Request, + payload: ProjectCreate, + db: AsyncSession = Depends(get_db), +): + agent = await get_current_agent(request, db) + project = Project( + id=str(uuid.uuid4()), + name=payload.name, + description=payload.description, + agent_id=agent.id, + ) + db.add(project) + await db.flush() + return ProjectResponse.model_validate(project) + + +@router.get("/{project_id}", response_model=ProjectResponse) +async def get_project( + request: Request, + project_id: str, + db: AsyncSession = Depends(get_db), +): + agent = await get_current_agent(request, db) + result = await db.execute( + select(Project).where( + Project.id == project_id, + Project.agent_id == agent.id, + Project.is_deleted == False, + ) + ) + project = result.scalar_one_or_none() + if not project: + raise HTTPException(status_code=404, detail="Project not found") + return ProjectResponse.model_validate(project) + + +@router.put("/{project_id}", response_model=ProjectResponse) +async def update_project( + request: Request, + project_id: str, + payload: ProjectUpdate, + db: AsyncSession = Depends(get_db), +): + agent = await get_current_agent(request, db) + result = await db.execute( + select(Project).where( + Project.id == project_id, + Project.agent_id == agent.id, + Project.is_deleted == False, + ) + ) + project = result.scalar_one_or_none() + if not project: + raise HTTPException(status_code=404, detail="Project not found") + + if payload.name is not None: + project.name = payload.name + if payload.description is not None: + project.description = payload.description + project.updated_at = datetime.utcnow() + + await db.flush() + return ProjectResponse.model_validate(project) + + +@router.delete("/{project_id}", status_code=204) +async def delete_project( + request: Request, + project_id: str, + db: AsyncSession = Depends(get_db), +): + agent = await get_current_agent(request, db) + result = await db.execute( + select(Project).where( + Project.id == project_id, + Project.agent_id == agent.id, + Project.is_deleted == False, + ) + ) + project = result.scalar_one_or_none() + if not project: + raise HTTPException(status_code=404, detail="Project not found") + + project.is_deleted = True + project.deleted_at = datetime.utcnow() + project.deleted_by = agent.id + await db.flush() + return None + + +@router.post("/{project_id}/restore", response_model=ProjectResponse) +async def restore_project( + request: Request, + project_id: str, + db: AsyncSession = Depends(get_db), +): + agent = await get_current_agent(request, db) + result = await db.execute( + select(Project).where( + Project.id == project_id, + Project.agent_id == agent.id, + Project.is_deleted == True, + ) + ) + project = result.scalar_one_or_none() + if not project: + raise HTTPException(status_code=404, detail="Project not found") + + project.is_deleted = False + project.deleted_at = None + project.deleted_by = None + await db.flush() + return ProjectResponse.model_validate(project) diff --git a/app/routers/search.py b/app/routers/search.py new file mode 100644 index 0000000..78cf229 --- /dev/null +++ b/app/routers/search.py @@ -0,0 +1,36 @@ +from fastapi import APIRouter, Depends, Query, Request +from sqlalchemy.ext.asyncio import AsyncSession + +from app.database import get_db +from app.routers.auth import get_current_agent +from app.schemas.search import SearchResponse +from app.services.search import search_documents + +router = APIRouter(prefix="/api/v1/search", tags=["search"]) + + +@router.get("", response_model=SearchResponse) +async def search( + request: Request, + q: str = Query(..., min_length=1), + project_id: str | None = Query(None), + tags: str | None = Query(None), + limit: int = Query(20, ge=1, le=100), + offset: int = Query(0, ge=0), + db: AsyncSession = Depends(get_db), +): + agent = await get_current_agent(request, db) + + tag_list = None + if tags: + tag_list = [t.strip() for t in tags.split(",") if t.strip()] + + return await search_documents( + db=db, + query=q, + agent_id=agent.id, + project_id=project_id, + tags=tag_list, + limit=limit, + offset=offset, + ) diff --git a/app/routers/tags.py b/app/routers/tags.py new file mode 100644 index 0000000..975884a --- /dev/null +++ b/app/routers/tags.py @@ -0,0 +1,79 @@ +import uuid + +from fastapi import APIRouter, Depends, HTTPException, Request +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from app.database import get_db +from app.models.tag import Tag +from app.routers.auth import get_current_agent +from app.schemas.tag import TagCreate, TagListResponse, TagResponse + +router = APIRouter(prefix="/api/v1/tags", tags=["tags"]) + + +@router.get("", response_model=TagListResponse) +async def list_tags( + request: Request, + db: AsyncSession = Depends(get_db), +): + agent = await get_current_agent(request, db) + result = await db.execute( + select(Tag).where( + Tag.is_deleted == False, + ).order_by(Tag.name) + ) + tags = result.scalars().all() + return TagListResponse(tags=[TagResponse.model_validate(t) for t in tags]) + + +@router.post("", response_model=TagResponse, status_code=201) +async def create_tag( + request: Request, + payload: TagCreate, + db: AsyncSession = Depends(get_db), +): + agent = await get_current_agent(request, db) + + existing = await db.execute( + select(Tag).where( + Tag.name == payload.name, + Tag.is_deleted == False, + ) + ) + if existing.scalar_one_or_none(): + raise HTTPException(status_code=400, detail="Tag with this name already exists") + + tag = Tag( + id=str(uuid.uuid4()), + name=payload.name, + color=payload.color, + ) + db.add(tag) + await db.flush() + return TagResponse.model_validate(tag) + + +@router.post("/{tag_id}/restore", response_model=TagResponse) +async def restore_tag( + request: Request, + tag_id: str, + db: AsyncSession = Depends(get_db), +): + agent = await get_current_agent(request, db) + + result = await db.execute( + select(Tag).where( + Tag.id == tag_id, + Tag.is_deleted == True, + ) + ) + tag = result.scalar_one_or_none() + if not tag: + raise HTTPException(status_code=404, detail="Tag not found") + + tag.is_deleted = False + tag.deleted_at = None + tag.deleted_by = None + await db.flush() + return TagResponse.model_validate(tag) diff --git a/app/schemas/__init__.py b/app/schemas/__init__.py new file mode 100644 index 0000000..8d2fd85 --- /dev/null +++ b/app/schemas/__init__.py @@ -0,0 +1 @@ +# Schemas package diff --git a/app/schemas/auth.py b/app/schemas/auth.py new file mode 100644 index 0000000..ffe0de9 --- /dev/null +++ b/app/schemas/auth.py @@ -0,0 +1,32 @@ +from datetime import datetime + +from pydantic import BaseModel, Field + + +class AgentCreate(BaseModel): + username: str = Field(..., min_length=3, max_length=50) + password: str = Field(..., min_length=6) + + +class AgentResponse(BaseModel): + id: str + username: str + role: str + created_at: datetime + + model_config = {"from_attributes": True} + + +class AgentLogin(BaseModel): + username: str + password: str + + +class TokenResponse(BaseModel): + access_token: str + token_type: str = "bearer" + + +class RefreshResponse(BaseModel): + access_token: str + token_type: str = "bearer" diff --git a/app/schemas/document.py b/app/schemas/document.py new file mode 100644 index 0000000..9afcd8e --- /dev/null +++ b/app/schemas/document.py @@ -0,0 +1,58 @@ +from datetime import datetime + +from pydantic import BaseModel, Field + + +class DocumentCreate(BaseModel): + title: str + content: str = "" + folder_id: str | None = None + + +class DocumentUpdate(BaseModel): + title: str | None = None + folder_id: str | None = None + + +class DocumentContentUpdate(BaseModel): + content: str = Field(..., max_length=1_000_000) # 1MB limit + + +class TagInfo(BaseModel): + id: str + name: str + color: str + + model_config = {"from_attributes": True} + + +class DocumentResponse(BaseModel): + id: str + title: str + content: str + project_id: str + folder_id: str | None + path: str + tags: list[TagInfo] = [] + created_at: datetime + updated_at: datetime + + model_config = {"from_attributes": True} + + +class DocumentListResponse(BaseModel): + documents: list[DocumentResponse] + + +class DocumentBriefResponse(BaseModel): + """Brief document for list views without content.""" + id: str + title: str + project_id: str + folder_id: str | None + path: str + tags: list[TagInfo] = [] + created_at: datetime + updated_at: datetime + + model_config = {"from_attributes": True} diff --git a/app/schemas/folder.py b/app/schemas/folder.py new file mode 100644 index 0000000..8fd65cd --- /dev/null +++ b/app/schemas/folder.py @@ -0,0 +1,29 @@ +from datetime import datetime + +from pydantic import BaseModel + + +class FolderCreate(BaseModel): + name: str + parent_id: str | None = None + + +class FolderUpdate(BaseModel): + name: str | None = None + parent_id: str | None = None + + +class FolderResponse(BaseModel): + id: str + name: str + project_id: str + parent_id: str | None + path: str + created_at: datetime + updated_at: datetime + + model_config = {"from_attributes": True} + + +class FolderListResponse(BaseModel): + folders: list[FolderResponse] diff --git a/app/schemas/project.py b/app/schemas/project.py new file mode 100644 index 0000000..c849cac --- /dev/null +++ b/app/schemas/project.py @@ -0,0 +1,27 @@ +from datetime import datetime + +from pydantic import BaseModel + + +class ProjectCreate(BaseModel): + name: str + description: str | None = None + + +class ProjectUpdate(BaseModel): + name: str | None = None + description: str | None = None + + +class ProjectResponse(BaseModel): + id: str + name: str + description: str | None + created_at: datetime + updated_at: datetime + + model_config = {"from_attributes": True} + + +class ProjectListResponse(BaseModel): + projects: list[ProjectResponse] diff --git a/app/schemas/search.py b/app/schemas/search.py new file mode 100644 index 0000000..187e7dc --- /dev/null +++ b/app/schemas/search.py @@ -0,0 +1,18 @@ +from datetime import datetime + +from pydantic import BaseModel + +from app.schemas.document import TagInfo + + +class SearchResult(BaseModel): + id: str + title: str + excerpt: str + project_id: str + tags: list[TagInfo] = [] + score: float + + +class SearchResponse(BaseModel): + results: list[SearchResult] diff --git a/app/schemas/tag.py b/app/schemas/tag.py new file mode 100644 index 0000000..6210c30 --- /dev/null +++ b/app/schemas/tag.py @@ -0,0 +1,25 @@ +from datetime import datetime + +from pydantic import BaseModel + + +class TagCreate(BaseModel): + name: str + color: str = "#6366f1" + + +class TagResponse(BaseModel): + id: str + name: str + color: str + created_at: datetime + + model_config = {"from_attributes": True} + + +class TagListResponse(BaseModel): + tags: list[TagResponse] + + +class DocumentTagsAssign(BaseModel): + tag_ids: list[str] diff --git a/app/services/__init__.py b/app/services/__init__.py new file mode 100644 index 0000000..a70b302 --- /dev/null +++ b/app/services/__init__.py @@ -0,0 +1 @@ +# Services package diff --git a/app/services/auth.py b/app/services/auth.py new file mode 100644 index 0000000..8652a60 --- /dev/null +++ b/app/services/auth.py @@ -0,0 +1,233 @@ +import hashlib +import secrets +import uuid +from datetime import datetime, timedelta + +import bcrypt +from jose import JWTError, jwt +from sqlalchemy import select, text, update +from sqlalchemy.ext.asyncio import AsyncSession + +from app.config import settings +from app.models.agent import Agent +from app.models.refresh_token import JwtBlocklist, RefreshToken + + +def hash_password(password: str) -> str: + return bcrypt.hashpw(password.encode("utf-8"), bcrypt.gensalt()).decode("utf-8") + + +def verify_password(plain: str, hashed: str) -> bool: + return bcrypt.checkpw(plain.encode("utf-8"), hashed.encode("utf-8")) + + +def create_access_token(agent_id: str, role: str) -> tuple[str, str]: + """Create JWT access token. Returns (token, jti).""" + jti = str(uuid.uuid4()) + now = datetime.utcnow() + expire = now + timedelta(minutes=settings.JWT_ACCESS_TOKEN_EXPIRE_MINUTES) + payload = { + "sub": agent_id, + "role": role, + "jti": jti, + "iat": now, + "exp": expire, + } + token = jwt.encode(payload, settings.JWT_SECRET_KEY, algorithm=settings.JWT_ALGORITHM) + return token, jti + + +def create_refresh_token() -> str: + """Create opaque refresh token.""" + return secrets.token_urlsafe(64) + + +def hash_token(token: str) -> str: + """SHA-256 hash of a token.""" + return hashlib.sha256(token.encode()).hexdigest() + + +def decode_token(token: str) -> dict | None: + """Decode and validate JWT. Returns payload or None.""" + try: + payload = jwt.decode( + token, + settings.JWT_SECRET_KEY, + algorithms=[settings.JWT_ALGORITHM], + ) + return payload + except JWTError: + return None + + +async def get_agent_by_username(db: AsyncSession, username: str) -> Agent | None: + result = await db.execute( + select(Agent).where(Agent.username == username, Agent.is_deleted == False) + ) + return result.scalar_one_or_none() + + +async def get_agent_by_id(db: AsyncSession, agent_id: str) -> Agent | None: + result = await db.execute( + select(Agent).where(Agent.id == agent_id, Agent.is_deleted == False) + ) + return result.scalar_one_or_none() + + +async def create_agent(db: AsyncSession, username: str, password: str, role: str = "agent") -> Agent: + agent = Agent( + id=str(uuid.uuid4()), + username=username, + password_hash=hash_password(password), + role=role, + ) + db.add(agent) + await db.flush() + return agent + + +async def save_refresh_token( + db: AsyncSession, + user_id: str, + token: str, + user_agent: str | None, + ip_address: str | None, +) -> RefreshToken: + """Save a new refresh token with a new family.""" + token_hash = hash_token(token) + family_id = str(uuid.uuid4()) + expires_at = datetime.utcnow() + timedelta(days=settings.JWT_REFRESH_TOKEN_EXPIRE_DAYS) + rt = RefreshToken( + id=str(uuid.uuid4()), + user_id=user_id, + token_hash=token_hash, + token_family_id=family_id, + token_version=1, + user_agent=user_agent, + ip_address=ip_address, + expires_at=expires_at, + ) + db.add(rt) + await db.flush() + return rt + + +async def rotate_refresh_token( + db: AsyncSession, + old_token: str, + user_agent: str | None, + ip_address: str | None, +) -> tuple[RefreshToken, str] | None: + """ + Rotate a refresh token: revoke old, create new. + Returns (new_rt, new_token) or None if invalid. + """ + token_hash = hash_token(old_token) + + # Find existing token + result = await db.execute( + select(RefreshToken).where( + RefreshToken.token_hash == token_hash, + RefreshToken.revoked_at.is_(None), + RefreshToken.is_global_logout == False, + ) + ) + old_rt: RefreshToken | None = result.scalar_one_or_none() + if not old_rt: + return None + + # Check expiry + if old_rt.expires_at < datetime.utcnow(): + return None + + # Reuse detection: check if a higher version exists + reuse_check = await db.execute( + select(RefreshToken).where( + RefreshToken.token_family_id == old_rt.token_family_id, + RefreshToken.token_version > old_rt.token_version, + RefreshToken.revoked_at.is_not(None), + ) + ) + if reuse_check.scalar_one_or_none(): + # Possible theft detected: revoke entire family + await db.execute( + update(RefreshToken) + .where(RefreshToken.token_family_id == old_rt.token_family_id) + .values(is_global_logout=True, revoked_at=datetime.utcnow()) + ) + return None + + # Revoke old token + old_rt.revoked_at = datetime.utcnow() + + # Create new token in same family + new_token = create_refresh_token() + new_hash = hash_token(new_token) + expires_at = datetime.utcnow() + timedelta(days=settings.JWT_REFRESH_TOKEN_EXPIRE_DAYS) + new_rt = RefreshToken( + id=str(uuid.uuid4()), + user_id=old_rt.user_id, + token_hash=new_hash, + token_family_id=old_rt.token_family_id, + token_version=old_rt.token_version + 1, + user_agent=user_agent, + ip_address=ip_address, + expires_at=expires_at, + ) + db.add(new_rt) + await db.flush() + return new_rt, new_token + + +async def revoke_refresh_token(db: AsyncSession, token: str) -> bool: + """Revoke a single refresh token.""" + token_hash = hash_token(token) + result = await db.execute( + select(RefreshToken).where( + RefreshToken.token_hash == token_hash, + RefreshToken.revoked_at.is_(None), + ) + ) + rt: RefreshToken | None = result.scalar_one_or_none() + if not rt: + return False + rt.revoked_at = datetime.utcnow() + await db.flush() + return True + + +async def revoke_all_user_tokens(db: AsyncSession, user_id: str) -> bool: + """Revoke all refresh tokens for a user (logout-all).""" + await db.execute( + update(RefreshToken) + .where( + RefreshToken.user_id == user_id, + RefreshToken.revoked_at.is_(None), + ) + .values(is_global_logout=True, revoked_at=datetime.utcnow()) + ) + await db.flush() + return True + + +async def add_to_blocklist(db: AsyncSession, jti: str, expires_at: datetime) -> None: + """Add a token JTI to the blocklist.""" + entry = JwtBlocklist(token_id=jti, expires_at=expires_at) + db.add(entry) + await db.flush() + + +async def is_token_blocklisted(db: AsyncSession, jti: str) -> bool: + """Check if a token JTI is in the blocklist.""" + result = await db.execute( + select(JwtBlocklist).where(JwtBlocklist.token_id == jti) + ) + return result.scalar_one_or_none() is not None + + +async def cleanup_expired_blocklist(db: AsyncSession) -> None: + """Remove expired entries from blocklist.""" + await db.execute( + text("DELETE FROM jwt_blocklist WHERE expires_at < datetime('now')") + ) + await db.flush() diff --git a/app/services/search.py b/app/services/search.py new file mode 100644 index 0000000..c595571 --- /dev/null +++ b/app/services/search.py @@ -0,0 +1,128 @@ +from sqlalchemy import select, text +from sqlalchemy.ext.asyncio import AsyncSession + +from app.schemas.search import SearchResult, SearchResponse +from app.schemas.document import TagInfo + + +async def search_documents( + db: AsyncSession, + query: str, + agent_id: str | None = None, + project_id: str | None = None, + tags: list[str] | None = None, + limit: int = 20, + offset: int = 0, +) -> SearchResponse: + """ + Full-text search using FTS5. + Returns snippets with highlight markup. + """ + if not query or len(query.strip()) == 0: + return SearchResponse(results=[]) + + # Escape FTS5 special characters and prepare query + safe_query = query.replace('"', '""') + + # Build the FTS5 MATCH query + fts_query = f'"{safe_query}"' + + # Get document IDs from FTS5 + fts_sql = text(""" + SELECT document_id, title, content, path, + bm25(documents_fts) as score + FROM documents_fts + WHERE documents_fts MATCH :q + ORDER BY score + LIMIT :limit OFFSET :offset + """) + + fts_result = await db.execute( + fts_sql, + {"q": fts_query, "limit": limit, "offset": offset} + ) + fts_rows = fts_result.fetchall() + + if not fts_rows: + return SearchResponse(results=[]) + + results = [] + for row in fts_rows: + doc_id = row.document_id + + # Get document to verify access and get project_id + doc_sql = text(""" + SELECT d.id, d.title, d.content, d.project_id, d.is_deleted, + p.agent_id + FROM active_documents d + JOIN active_projects p ON d.project_id = p.id + WHERE d.id = :doc_id AND p.agent_id = :agent_id + """) + doc_result = await db.execute( + doc_sql, + {"doc_id": doc_id, "agent_id": agent_id} + ) + doc_row = doc_result.fetchone() + if not doc_row: + continue + + # Filter by project_id if provided + if project_id and doc_row.project_id != project_id: + continue + + # Get tags for this document + tags_sql = text(""" + SELECT t.id, t.name, t.color + FROM active_tags t + JOIN document_tags dt ON t.id = dt.tag_id + WHERE dt.document_id = :doc_id + """) + tags_result = await db.execute(tags_sql, {"doc_id": doc_id}) + tag_rows = tags_result.fetchall() + doc_tags = [TagInfo(id=t.id, name=t.name, color=t.color) for t in tag_rows] + + # Filter by tags if provided + if tags: + tag_names = {t.name for t in doc_tags} + if not any(tn in tag_names for tn in tags): + continue + + # Build excerpt with snippet + content = doc_row.content or "" + excerpt = _build_snippet(content, query) + + results.append(SearchResult( + id=doc_row.id, + title=doc_row.title, + excerpt=excerpt, + project_id=doc_row.project_id, + tags=doc_tags, + score=abs(row.score) if row.score else 0.0, + )) + + return SearchResponse(results=results) + + +def _build_snippet(content: str, query: str, context_chars: int = 150) -> str: + """Build a highlighted snippet from content.""" + query_lower = query.lower() + content_lower = content.lower() + + idx = content_lower.find(query_lower) + if idx == -1: + # No exact match, return beginning + snippet = content[:context_chars * 2] + else: + start = max(0, idx - context_chars) + end = min(len(content), idx + len(query) + context_chars) + snippet = content[start:end] + if start > 0: + snippet = "..." + snippet + if end < len(content): + snippet = snippet + "..." + + # Simple highlight: wrap matches in ** + import re + pattern = re.compile(re.escape(query), re.IGNORECASE) + snippet = pattern.sub(f"**{query}**", snippet) + return snippet diff --git a/data/claudia_docs.db b/data/claudia_docs.db new file mode 100644 index 0000000000000000000000000000000000000000..8d111b59ddb66d382095702d621c5194d7bb5302 GIT binary patch literal 147456 zcmeI*-*ekWVh3HJW5KAkNjc8+3f*aV0lL|o4uCDD!=S#~5cPTK2mmH;Rt zrbrzDNp|k0cZ_{MI@4+2?(rTv)2FsidFeEN!uaq+&nE#7{2 zv8n7XRx0cFmp8U=R_%t;`M&aKYx#rf+Eub?wj41pDtx}c7xJ7Sy)Eb8mWx-_f}}{o z$TiX=Ef<^qoS{E>KmY;|fB*y_009U<00Izz00bcLf(c|})rsTGJow}PFPZe0FIWMQ zT?jw`0uX=z1Rwwb2tWV=5P$##juJ@4G85iS2mby45190Wqc*{M2tWV=5P$##AOHaf zKmY;|fB*!JA<%u9ovgmeRK0fvki05uswnY#(IoU;0jA6omEOBA*NH%+qFPiN`f^*h zxx1bAN2cx^=UV>YUH;ltA+HuxapW4@|9^~%Llg)=00Izz00bZa0SG_<0uX?}^D5xp z|9_l02KV|u{m_3rAOHafKmY;|fB*y_009U<00Iy=#RBg8|1thQ#TAUAK>z{}fB*y_ z009U<00Izz00gLj|Nj3tla7M}0uX=z1Rwwb2tWV=5P$##AaJ?`x{uhI>gDTIru(}0 z-u${`7^)^|ye1KumkLCq-`FP+UJ>dVF(la(RQZQ*UpXoYUr( znqSuQStqMr@L*VbP#|51Rwwb2tWV=5P$## zAaJ?`-1q+m)Se;R#eNGdNE@NbHClo-+!)y7gqIy*n=zn?oPZ;(se|f1;RP+8CiO(yXAiXW;-1z6R5fiB1_BU(00bZa0SG_<0uX=z1pdqf-0%N8&b)>{{{M_gfA(hvgyKQ~ z0uX=z1Rwwb2tWV=5P$##UJQZdcz0&v>)vyK&;NTdDjT_k00bZa0SG_<0uX=z1Rwwb z2z(s@eE#3p(ZjY7fB*y_009U<00Izz00bZaffqwyCjDC`&GO95FVeqC|8nZ5Gv>sF zse6eZrhYR%9shOwKVttB+hqCJTVwylJf#1QdOT6n$xJTC{ygW9deiKfjgD#GuQ}}p zX3OqhO8q7gKWb;s-TGsfu4SKHS}&I>Wv)_MTrG3`C2!f>!X(%CunewJ{-DC$U0=Cf zTK|xHxBTIiNQ52R?9{?oYj-N#T6J}mTQA=%ub0=B${QTnG+U0nU>UjSh8}zEKC$on zDfG6juC2UREf0&_AX`@R(Lid$qVAd<+iJJCm9!Yvpyl4EU|3)xnCFPl3+U_RWo z=>CWAB$}Cyxl5br!SpC~G@51CHk<7_Y1W$U&GwEnQ2t)l^1WKjX=f6d+gX;eT88=1 z{(jSP%o^Eo+TL=_XzM#$v`XF~&9aXz?~sbDj4YgqG`YyC11+w% z?yyD`M)KMfcV~?jYqd;U^ipYKskD5s@sQ1dHq+li-GaI8^Qisw9RsvxLl(=p!qGj@ z6mt8d+37^)`s^rDPf~RMpUF-g8BuT}PJgmBnas@2vOl@uNxNb0?^MLq4LQIWD)wbQ z+GJW}%k+;x|G?|D;%GFz-E4Ob9I0V5&t6mtiqa6W(M@P$G|Z++50gWhVRVxq{(&Q6 zpqYo!>W_jP4aE40$LD?xK53?snfZD4@3N7WZHM)4PekN*z6AHknw zq;&>?Pun*4PYuJH=Cf{W+oa>Z(1fR~BOBY$vpw3{@UF*xTjP_7OlfXZTMODsO$eT( z9)B;M%$z^Z{%p!?XW`i+{5u|LX<_KIZfb^U>m6&`ar<;`hctTld0n_(y9pa>|7-lH zUMS8y*Txf>rSqd&Ls*|RA$*p2GR`J5=gzS|GrjiU%liM1McRTN_Ova*?UV1HcnM#`gNbSudAaZt0!!(!1P(7jtEeTL{9ha9Lt(S*?M-I~vaA zSR(VCbE6uLU->oB{~NzNc8th^A2~+H)+u*S={=?2*a+uZcVwq|GR z3b(!4YG3KEhW^#L0~NA+A(;_`5fwrl63q>(`bk{ z6Gj{9jqFa#(%Xiq2y{|zt4&WGhDB(%$FX#q6SBQiLauxETgl8yj_qFe=4fbj)uOX_ zHE(9ipD&B9nGedovbJ3QfIECkcS0n(^+GV+N2fg8S<_l?MlR>fB{SD)j$Rs-BX_1R zy0Z&MWyPBu9L&T1%wg`MZ_;Bg%XVYERD1Q~&V{&hR*``714%|!yYbyZM*De{EZyrXr=SC(KM(t0@Uhp?} z3)RTbLd_2Yw8Ph}UMk-_B$ePbDVj_W=%({muO%|V9J|*mOE}E%$9%!kxq%d-(cIk! zAq&01V(xEd6Pdys`-8!K`;#HT7T*}Xb3a~W@4b1H++V*+vwprW>*4Ef&3iQrmofvp z57+qd+`We(3*MHli}uQ4bFUxmIW_P6*0gM=CPWrr?dmW4bFbsx*)Fm&vhYffGk>~2 zvg$zR+&{9y@pu2rUGLoEc1MO!Mjtub5k2c2I{Dp8E_eN+QCFl70)MMW1Seu(>r+0vReKmT`BbeH@!*;X2a?8Dj z(hCpbaiyCT8eF}*>|RXjJ8jy%>T_EQ?JTZwEpzW`q@yZurL|>_&MroS_b*E&!@UK`cmu(2g&x9hos$GcX$T&qYcZ&58h19=5nd7 z+?R>q^)$#rurwqC!J5O94;00bZa0SG_<0uX=z1R(Hi1>Dd7 zKhAd+;{E^6)*X9;00bZa0SG_<0uX=z1Rwwb2t1Dh82>+ywSqiB00Izz00bZa0SG_< z0uX=z1fHz`?*D(bp4b}%AOHafKmY;|fB*y_009U<;3Ngy`~OK^l{HnAc)e&6ULvN< z6E$Dp~4wM9&+F zRA{KONn}xwh*UR-M(A7p4Mo$73K4ap%7_X92tWV= z5P$##AOHafKmY;|fWZH^z-qi&oe&kXF4OTmuN73Cezu>W@hQ zD?(i(h9sMUDzhmXHpRw^l0?|lI2)T} ztozIQ*Z+UVqN(hJTg+o_|%JbK2aJe|rG8Om7Uh7T#XKySER}+y70OC-fczzW=U;YsX2?AOC;J zq`&;i1;mCBfB*y_009U<00Izz00bZa0SG*w0!jLgfLGl&1o-#=f5D``cs?_VTtNT= i5P$##AOHafKmY;|fB*y_@YMvSVsq@S*|Dv5EB=3l#J#}) literal 0 HcmV?d00001 diff --git a/pytest.ini b/pytest.ini new file mode 100644 index 0000000..0e50332 --- /dev/null +++ b/pytest.ini @@ -0,0 +1,8 @@ +[pytest] +asyncio_mode = auto +testpaths = tests +python_files = test_*.py +python_functions = test_* +python_classes = Test* +filterwarnings = + ignore::DeprecationWarning diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..806d546 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,14 @@ +fastapi==0.115.0 +uvicorn[standard]==0.30.0 +pydantic==2.9.0 +pydantic-settings==2.5.0 +sqlalchemy==2.0.35 +alembic==1.13.3 +python-jose[cryptography]==3.3.0 +bcrypt==4.2.0 +passlib==1.7.4 +python-multipart==0.0.12 +aiosqlite==0.20.0 +httpx==0.27.2 +pytest==8.3.0 +pytest-asyncio==0.24.0 diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..d4839a6 --- /dev/null +++ b/tests/__init__.py @@ -0,0 +1 @@ +# Tests package diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..1c5d38b --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,54 @@ +import asyncio +import os +import pytest +import pytest_asyncio +from httpx import ASGITransport, AsyncClient +from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker, AsyncSession + +# Set test database before importing app +os.environ["DATABASE_URL"] = "sqlite+aiosqlite:///:memory:" + +from app.main import app +from app.database import Base, get_db, async_engine + + +@pytest.fixture(scope="session") +def event_loop(): + loop = asyncio.get_event_loop_policy().new_event_loop() + yield loop + loop.close() + + +@pytest_asyncio.fixture(scope="function") +async def db_session(): + """Create a fresh in-memory database for each test.""" + # Create tables + async with async_engine.begin() as conn: + # Import and run schema creation + from app.database import _create_schema + await conn.run_sync(_create_schema) + + async_session = async_sessionmaker(async_engine, class_=AsyncSession, expire_on_commit=False) + + async with async_session() as session: + yield session + await session.rollback() + + # Drop all tables after test + async with async_engine.begin() as conn: + for table in reversed(Base.metadata.sorted_tables): + await conn.execute(table.delete()) + + +@pytest_asyncio.fixture(scope="function") +async def client(db_session): + """Async HTTP client for testing.""" + async def override_get_db(): + yield db_session + + app.dependency_overrides[get_db] = override_get_db + + async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as ac: + yield ac + + app.dependency_overrides.clear() diff --git a/tests/test_auth.py b/tests/test_auth.py new file mode 100644 index 0000000..5adc520 --- /dev/null +++ b/tests/test_auth.py @@ -0,0 +1,70 @@ +import pytest + + +@pytest.mark.asyncio +async def test_register(client): + response = await client.post( + "/api/v1/auth/register", + json={"username": "testuser", "password": "testpass123"} + ) + assert response.status_code == 201 + data = response.json() + assert data["username"] == "testuser" + assert data["role"] == "agent" + assert "id" in data + + +@pytest.mark.asyncio +async def test_register_duplicate(client): + await client.post("/api/v1/auth/register", json={"username": "dup", "password": "pass123"}) + response = await client.post( + "/api/v1/auth/register", + json={"username": "dup", "password": "pass123"} + ) + assert response.status_code == 400 + + +@pytest.mark.asyncio +async def test_login(client): + await client.post("/api/v1/auth/register", json={"username": "loginuser", "password": "pass123"}) + response = await client.post( + "/api/v1/auth/login", + json={"username": "loginuser", "password": "pass123"} + ) + assert response.status_code == 200 + data = response.json() + assert "access_token" in data + assert data["token_type"] == "bearer" + + +@pytest.mark.asyncio +async def test_login_invalid_password(client): + await client.post("/api/v1/auth/register", json={"username": "user1", "password": "pass123"}) + response = await client.post( + "/api/v1/auth/login", + json={"username": "user1", "password": "wrongpass"} + ) + assert response.status_code == 401 + + +@pytest.mark.asyncio +async def test_me(client): + await client.post("/api/v1/auth/register", json={"username": "meuser", "password": "pass123"}) + login_resp = await client.post( + "/api/v1/auth/login", + json={"username": "meuser", "password": "pass123"} + ) + token = login_resp.json()["access_token"] + + response = await client.get( + "/api/v1/auth/me", + headers={"Authorization": f"Bearer {token}"} + ) + assert response.status_code == 200 + assert response.json()["username"] == "meuser" + + +@pytest.mark.asyncio +async def test_me_unauthorized(client): + response = await client.get("/api/v1/auth/me") + assert response.status_code == 401 diff --git a/tests/test_documents.py b/tests/test_documents.py new file mode 100644 index 0000000..54210b0 --- /dev/null +++ b/tests/test_documents.py @@ -0,0 +1,136 @@ +import pytest + + +async def setup_project_and_get_token(client): + await client.post("/api/v1/auth/register", json={"username": "docuser", "password": "pass123"}) + login = await client.post("/api/v1/auth/login", json={"username": "docuser", "password": "pass123"}) + token = login.json()["access_token"] + proj_resp = await client.post( + "/api/v1/projects", + json={"name": "Doc Test Project"}, + headers={"Authorization": f"Bearer {token}"} + ) + return token, proj_resp.json()["id"] + + +@pytest.mark.asyncio +async def test_create_document(client): + token, proj_id = await setup_project_and_get_token(client) + response = await client.post( + f"/api/v1/projects/{proj_id}/documents", + json={"title": "My Document", "content": "# Hello\n\nWorld"}, + headers={"Authorization": f"Bearer {token}"} + ) + assert response.status_code == 201 + data = response.json() + assert data["title"] == "My Document" + assert data["content"] == "# Hello\n\nWorld" + assert data["project_id"] == proj_id + + +@pytest.mark.asyncio +async def test_get_document(client): + token, proj_id = await setup_project_and_get_token(client) + create_resp = await client.post( + f"/api/v1/projects/{proj_id}/documents", + json={"title": "Get Doc Test"}, + headers={"Authorization": f"Bearer {token}"} + ) + doc_id = create_resp.json()["id"] + response = await client.get( + f"/api/v1/documents/{doc_id}", + headers={"Authorization": f"Bearer {token}"} + ) + assert response.status_code == 200 + assert response.json()["title"] == "Get Doc Test" + + +@pytest.mark.asyncio +async def test_update_document_content(client): + token, proj_id = await setup_project_and_get_token(client) + create_resp = await client.post( + f"/api/v1/projects/{proj_id}/documents", + json={"title": "Original Title", "content": "Original Content"}, + headers={"Authorization": f"Bearer {token}"} + ) + doc_id = create_resp.json()["id"] + response = await client.put( + f"/api/v1/documents/{doc_id}/content", + json={"content": "Updated Content"}, + headers={"Authorization": f"Bearer {token}"} + ) + assert response.status_code == 200 + assert response.json()["content"] == "Updated Content" + + +@pytest.mark.asyncio +async def test_soft_delete_document(client): + token, proj_id = await setup_project_and_get_token(client) + create_resp = await client.post( + f"/api/v1/projects/{proj_id}/documents", + json={"title": "To Delete"}, + headers={"Authorization": f"Bearer {token}"} + ) + doc_id = create_resp.json()["id"] + del_resp = await client.delete( + f"/api/v1/documents/{doc_id}", + headers={"Authorization": f"Bearer {token}"} + ) + assert del_resp.status_code == 204 + + +@pytest.mark.asyncio +async def test_assign_tag(client): + token, proj_id = await setup_project_and_get_token(client) + doc_resp = await client.post( + f"/api/v1/projects/{proj_id}/documents", + json={"title": "Tagged Doc"}, + headers={"Authorization": f"Bearer {token}"} + ) + doc_id = doc_resp.json()["id"] + tag_resp = await client.post( + "/api/v1/tags", + json={"name": "important", "color": "#ff0000"}, + headers={"Authorization": f"Bearer {token}"} + ) + tag_id = tag_resp.json()["id"] + assign_resp = await client.post( + f"/api/v1/documents/{doc_id}/tags", + json={"tag_ids": [tag_id]}, + headers={"Authorization": f"Bearer {token}"} + ) + assert assign_resp.status_code == 204 + + get_resp = await client.get( + f"/api/v1/documents/{doc_id}", + headers={"Authorization": f"Bearer {token}"} + ) + assert len(get_resp.json()["tags"]) == 1 + assert get_resp.json()["tags"][0]["name"] == "important" + + +@pytest.mark.asyncio +async def test_remove_tag(client): + token, proj_id = await setup_project_and_get_token(client) + doc_resp = await client.post( + f"/api/v1/projects/{proj_id}/documents", + json={"title": "Tagged Doc 2"}, + headers={"Authorization": f"Bearer {token}"} + ) + doc_id = doc_resp.json()["id"] + tag_resp = await client.post( + "/api/v1/tags", + json={"name": "temp", "color": "#00ff00"}, + headers={"Authorization": f"Bearer {token}"} + ) + tag_id = tag_resp.json()["id"] + await client.post( + f"/api/v1/documents/{doc_id}/tags", + json={"tag_ids": [tag_id]}, + headers={"Authorization": f"Bearer {token}"} + ) + remove_resp = await client.delete( + f"/api/v1/documents/{doc_id}/tags/{tag_id}", + headers={"Authorization": f"Bearer {token}"} + ) + assert remove_resp.status_code == 204 diff --git a/tests/test_folders.py b/tests/test_folders.py new file mode 100644 index 0000000..97094d4 --- /dev/null +++ b/tests/test_folders.py @@ -0,0 +1,90 @@ +import pytest + + +async def setup_project(client, token): + resp = await client.post( + "/api/v1/projects", + json={"name": "Test Project"}, + headers={"Authorization": f"Bearer {token}"} + ) + return resp.json()["id"] + + +async def get_token(client): + await client.post("/api/v1/auth/register", json={"username": "folderuser", "password": "pass123"}) + login = await client.post("/api/v1/auth/login", json={"username": "folderuser", "password": "pass123"}) + return login.json()["access_token"] + + +@pytest.mark.asyncio +async def test_create_folder(client): + token = await get_token(client) + proj_id = await setup_project(client, token) + response = await client.post( + f"/api/v1/projects/{proj_id}/folders", + json={"name": "Architecture"}, + headers={"Authorization": f"Bearer {token}"} + ) + assert response.status_code == 201 + data = response.json() + assert data["name"] == "Architecture" + assert data["project_id"] == proj_id + assert data["parent_id"] is None + + +@pytest.mark.asyncio +async def test_create_subfolder(client): + token = await get_token(client) + proj_id = await setup_project(client, token) + parent_resp = await client.post( + f"/api/v1/projects/{proj_id}/folders", + json={"name": "Parent"}, + headers={"Authorization": f"Bearer {token}"} + ) + parent_id = parent_resp.json()["id"] + child_resp = await client.post( + f"/api/v1/projects/{proj_id}/folders", + json={"name": "Child", "parent_id": parent_id}, + headers={"Authorization": f"Bearer {token}"} + ) + assert child_resp.status_code == 201 + assert child_resp.json()["parent_id"] == parent_id + + +@pytest.mark.asyncio +async def test_list_folders(client): + token = await get_token(client) + proj_id = await setup_project(client, token) + await client.post( + f"/api/v1/projects/{proj_id}/folders", + json={"name": "Folder 1"}, + headers={"Authorization": f"Bearer {token}"} + ) + await client.post( + f"/api/v1/projects/{proj_id}/folders", + json={"name": "Folder 2"}, + headers={"Authorization": f"Bearer {token}"} + ) + response = await client.get( + f"/api/v1/projects/{proj_id}/folders", + headers={"Authorization": f"Bearer {token}"} + ) + assert response.status_code == 200 + assert len(response.json()["folders"]) == 2 + + +@pytest.mark.asyncio +async def test_soft_delete_folder(client): + token = await get_token(client) + proj_id = await setup_project(client, token) + folder_resp = await client.post( + f"/api/v1/projects/{proj_id}/folders", + json={"name": "To Delete"}, + headers={"Authorization": f"Bearer {token}"} + ) + folder_id = folder_resp.json()["id"] + del_resp = await client.delete( + f"/api/v1/folders/{folder_id}", + headers={"Authorization": f"Bearer {token}"} + ) + assert del_resp.status_code == 204 diff --git a/tests/test_projects.py b/tests/test_projects.py new file mode 100644 index 0000000..7817daf --- /dev/null +++ b/tests/test_projects.py @@ -0,0 +1,98 @@ +import pytest + + +async def get_token(client, username="projuser", password="pass123"): + await client.post("/api/v1/auth/register", json={"username": username, "password": password}) + login = await client.post("/api/v1/auth/login", json={"username": username, "password": password}) + return login.json()["access_token"] + + +@pytest.mark.asyncio +async def test_create_project(client): + token = await get_token(client) + response = await client.post( + "/api/v1/projects", + json={"name": "My Project", "description": "Description"}, + headers={"Authorization": f"Bearer {token}"} + ) + assert response.status_code == 201 + data = response.json() + assert data["name"] == "My Project" + assert data["description"] == "Description" + + +@pytest.mark.asyncio +async def test_list_projects(client): + token = await get_token(client) + await client.post("/api/v1/projects", json={"name": "Project 1"}, headers={"Authorization": f"Bearer {token}"}) + await client.post("/api/v1/projects", json={"name": "Project 2"}, headers={"Authorization": f"Bearer {token}"}) + response = await client.get("/api/v1/projects", headers={"Authorization": f"Bearer {token}"}) + assert response.status_code == 200 + assert len(response.json()["projects"]) == 2 + + +@pytest.mark.asyncio +async def test_get_project(client): + token = await get_token(client) + create_resp = await client.post( + "/api/v1/projects", json={"name": "Get Test"}, headers={"Authorization": f"Bearer {token}"} + ) + proj_id = create_resp.json()["id"] + response = await client.get( + f"/api/v1/projects/{proj_id}", + headers={"Authorization": f"Bearer {token}"} + ) + assert response.status_code == 200 + assert response.json()["name"] == "Get Test" + + +@pytest.mark.asyncio +async def test_update_project(client): + token = await get_token(client) + create_resp = await client.post( + "/api/v1/projects", json={"name": "Original"}, headers={"Authorization": f"Bearer {token}"} + ) + proj_id = create_resp.json()["id"] + response = await client.put( + f"/api/v1/projects/{proj_id}", + json={"name": "Updated"}, + headers={"Authorization": f"Bearer {token}"} + ) + assert response.status_code == 200 + assert response.json()["name"] == "Updated" + + +@pytest.mark.asyncio +async def test_soft_delete_project(client): + token = await get_token(client) + create_resp = await client.post( + "/api/v1/projects", json={"name": "To Delete"}, headers={"Authorization": f"Bearer {token}"} + ) + proj_id = create_resp.json()["id"] + + del_resp = await client.delete( + f"/api/v1/projects/{proj_id}", + headers={"Authorization": f"Bearer {token}"} + ) + assert del_resp.status_code == 204 + + # Should not appear in list + list_resp = await client.get("/api/v1/projects", headers={"Authorization": f"Bearer {token}"}) + assert len(list_resp.json()["projects"]) == 0 + + +@pytest.mark.asyncio +async def test_restore_project(client): + token = await get_token(client) + create_resp = await client.post( + "/api/v1/projects", json={"name": "To Restore"}, headers={"Authorization": f"Bearer {token}"} + ) + proj_id = create_resp.json()["id"] + await client.delete(f"/api/v1/projects/{proj_id}", headers={"Authorization": f"Bearer {token}"}) + + restore_resp = await client.post( + f"/api/v1/projects/{proj_id}/restore", + headers={"Authorization": f"Bearer {token}"} + ) + assert restore_resp.status_code == 200 + assert restore_resp.json()["name"] == "To Restore" diff --git a/tests/test_search.py b/tests/test_search.py new file mode 100644 index 0000000..d96eae8 --- /dev/null +++ b/tests/test_search.py @@ -0,0 +1,69 @@ +import pytest + + +async def setup_project_and_get_token(client): + await client.post("/api/v1/auth/register", json={"username": "searchuser", "password": "pass123"}) + login = await client.post("/api/v1/auth/login", json={"username": "searchuser", "password": "pass123"}) + token = login.json()["access_token"] + proj_resp = await client.post( + "/api/v1/projects", + json={"name": "Search Test Project"}, + headers={"Authorization": f"Bearer {token}"} + ) + return token, proj_resp.json()["id"] + + +@pytest.mark.asyncio +async def test_search_basic(client): + token, proj_id = await setup_project_and_get_token(client) + await client.post( + f"/api/v1/projects/{proj_id}/documents", + json={"title": "Python Tutorial", "content": "Learn Python programming language"}, + headers={"Authorization": f"Bearer {token}"} + ) + await client.post( + f"/api/v1/projects/{proj_id}/documents", + json={"title": "Rust Tutorial", "content": "Learn Rust programming language"}, + headers={"Authorization": f"Bearer {token}"} + ) + response = await client.get( + "/api/v1/search?q=Python", + headers={"Authorization": f"Bearer {token}"} + ) + assert response.status_code == 200 + results = response.json()["results"] + assert len(results) == 1 + assert results[0]["title"] == "Python Tutorial" + + +@pytest.mark.asyncio +async def test_search_returns_excerpt_with_highlight(client): + token, proj_id = await setup_project_and_get_token(client) + await client.post( + f"/api/v1/projects/{proj_id}/documents", + json={"title": "FastAPI Guide", "content": "Building APIs with FastAPI framework"}, + headers={"Authorization": f"Bearer {token}"} + ) + response = await client.get( + "/api/v1/search?q=FastAPI", + headers={"Authorization": f"Bearer {token}"} + ) + results = response.json()["results"] + assert len(results) == 1 + assert "**FastAPI**" in results[0]["excerpt"] + + +@pytest.mark.asyncio +async def test_search_no_results(client): + token, proj_id = await setup_project_and_get_token(client) + await client.post( + f"/api/v1/projects/{proj_id}/documents", + json={"title": "Random Doc", "content": "Some random content"}, + headers={"Authorization": f"Bearer {token}"} + ) + response = await client.get( + "/api/v1/search?q=nonexistent", + headers={"Authorization": f"Bearer {token}"} + ) + assert response.status_code == 200 + assert len(response.json()["results"]) == 0