Phase 1 MVP - Complete implementation

- Auth: register, login, JWT with refresh tokens, blocklist
- Projects/Folders/Documents CRUD with soft deletes
- Tags CRUD and assignment
- FTS5 search with highlights and tag filtering
- ADR-001, ADR-002, ADR-003 compliant
- Security fixes applied (JWT_SECRET_KEY, exception handler, cookie secure)
- 25 tests passing
This commit is contained in:
Motoko
2026-03-30 15:17:27 +00:00
parent 33f19e02f8
commit 7f3e8a8f53
41 changed files with 2858 additions and 0 deletions

1
app/services/__init__.py Normal file
View File

@@ -0,0 +1 @@
# Services package

233
app/services/auth.py Normal file
View File

@@ -0,0 +1,233 @@
import hashlib
import secrets
import uuid
from datetime import datetime, timedelta
import bcrypt
from jose import JWTError, jwt
from sqlalchemy import select, text, update
from sqlalchemy.ext.asyncio import AsyncSession
from app.config import settings
from app.models.agent import Agent
from app.models.refresh_token import JwtBlocklist, RefreshToken
def hash_password(password: str) -> str:
return bcrypt.hashpw(password.encode("utf-8"), bcrypt.gensalt()).decode("utf-8")
def verify_password(plain: str, hashed: str) -> bool:
return bcrypt.checkpw(plain.encode("utf-8"), hashed.encode("utf-8"))
def create_access_token(agent_id: str, role: str) -> tuple[str, str]:
"""Create JWT access token. Returns (token, jti)."""
jti = str(uuid.uuid4())
now = datetime.utcnow()
expire = now + timedelta(minutes=settings.JWT_ACCESS_TOKEN_EXPIRE_MINUTES)
payload = {
"sub": agent_id,
"role": role,
"jti": jti,
"iat": now,
"exp": expire,
}
token = jwt.encode(payload, settings.JWT_SECRET_KEY, algorithm=settings.JWT_ALGORITHM)
return token, jti
def create_refresh_token() -> str:
"""Create opaque refresh token."""
return secrets.token_urlsafe(64)
def hash_token(token: str) -> str:
"""SHA-256 hash of a token."""
return hashlib.sha256(token.encode()).hexdigest()
def decode_token(token: str) -> dict | None:
"""Decode and validate JWT. Returns payload or None."""
try:
payload = jwt.decode(
token,
settings.JWT_SECRET_KEY,
algorithms=[settings.JWT_ALGORITHM],
)
return payload
except JWTError:
return None
async def get_agent_by_username(db: AsyncSession, username: str) -> Agent | None:
result = await db.execute(
select(Agent).where(Agent.username == username, Agent.is_deleted == False)
)
return result.scalar_one_or_none()
async def get_agent_by_id(db: AsyncSession, agent_id: str) -> Agent | None:
result = await db.execute(
select(Agent).where(Agent.id == agent_id, Agent.is_deleted == False)
)
return result.scalar_one_or_none()
async def create_agent(db: AsyncSession, username: str, password: str, role: str = "agent") -> Agent:
agent = Agent(
id=str(uuid.uuid4()),
username=username,
password_hash=hash_password(password),
role=role,
)
db.add(agent)
await db.flush()
return agent
async def save_refresh_token(
db: AsyncSession,
user_id: str,
token: str,
user_agent: str | None,
ip_address: str | None,
) -> RefreshToken:
"""Save a new refresh token with a new family."""
token_hash = hash_token(token)
family_id = str(uuid.uuid4())
expires_at = datetime.utcnow() + timedelta(days=settings.JWT_REFRESH_TOKEN_EXPIRE_DAYS)
rt = RefreshToken(
id=str(uuid.uuid4()),
user_id=user_id,
token_hash=token_hash,
token_family_id=family_id,
token_version=1,
user_agent=user_agent,
ip_address=ip_address,
expires_at=expires_at,
)
db.add(rt)
await db.flush()
return rt
async def rotate_refresh_token(
db: AsyncSession,
old_token: str,
user_agent: str | None,
ip_address: str | None,
) -> tuple[RefreshToken, str] | None:
"""
Rotate a refresh token: revoke old, create new.
Returns (new_rt, new_token) or None if invalid.
"""
token_hash = hash_token(old_token)
# Find existing token
result = await db.execute(
select(RefreshToken).where(
RefreshToken.token_hash == token_hash,
RefreshToken.revoked_at.is_(None),
RefreshToken.is_global_logout == False,
)
)
old_rt: RefreshToken | None = result.scalar_one_or_none()
if not old_rt:
return None
# Check expiry
if old_rt.expires_at < datetime.utcnow():
return None
# Reuse detection: check if a higher version exists
reuse_check = await db.execute(
select(RefreshToken).where(
RefreshToken.token_family_id == old_rt.token_family_id,
RefreshToken.token_version > old_rt.token_version,
RefreshToken.revoked_at.is_not(None),
)
)
if reuse_check.scalar_one_or_none():
# Possible theft detected: revoke entire family
await db.execute(
update(RefreshToken)
.where(RefreshToken.token_family_id == old_rt.token_family_id)
.values(is_global_logout=True, revoked_at=datetime.utcnow())
)
return None
# Revoke old token
old_rt.revoked_at = datetime.utcnow()
# Create new token in same family
new_token = create_refresh_token()
new_hash = hash_token(new_token)
expires_at = datetime.utcnow() + timedelta(days=settings.JWT_REFRESH_TOKEN_EXPIRE_DAYS)
new_rt = RefreshToken(
id=str(uuid.uuid4()),
user_id=old_rt.user_id,
token_hash=new_hash,
token_family_id=old_rt.token_family_id,
token_version=old_rt.token_version + 1,
user_agent=user_agent,
ip_address=ip_address,
expires_at=expires_at,
)
db.add(new_rt)
await db.flush()
return new_rt, new_token
async def revoke_refresh_token(db: AsyncSession, token: str) -> bool:
"""Revoke a single refresh token."""
token_hash = hash_token(token)
result = await db.execute(
select(RefreshToken).where(
RefreshToken.token_hash == token_hash,
RefreshToken.revoked_at.is_(None),
)
)
rt: RefreshToken | None = result.scalar_one_or_none()
if not rt:
return False
rt.revoked_at = datetime.utcnow()
await db.flush()
return True
async def revoke_all_user_tokens(db: AsyncSession, user_id: str) -> bool:
"""Revoke all refresh tokens for a user (logout-all)."""
await db.execute(
update(RefreshToken)
.where(
RefreshToken.user_id == user_id,
RefreshToken.revoked_at.is_(None),
)
.values(is_global_logout=True, revoked_at=datetime.utcnow())
)
await db.flush()
return True
async def add_to_blocklist(db: AsyncSession, jti: str, expires_at: datetime) -> None:
"""Add a token JTI to the blocklist."""
entry = JwtBlocklist(token_id=jti, expires_at=expires_at)
db.add(entry)
await db.flush()
async def is_token_blocklisted(db: AsyncSession, jti: str) -> bool:
"""Check if a token JTI is in the blocklist."""
result = await db.execute(
select(JwtBlocklist).where(JwtBlocklist.token_id == jti)
)
return result.scalar_one_or_none() is not None
async def cleanup_expired_blocklist(db: AsyncSession) -> None:
"""Remove expired entries from blocklist."""
await db.execute(
text("DELETE FROM jwt_blocklist WHERE expires_at < datetime('now')")
)
await db.flush()

128
app/services/search.py Normal file
View File

@@ -0,0 +1,128 @@
from sqlalchemy import select, text
from sqlalchemy.ext.asyncio import AsyncSession
from app.schemas.search import SearchResult, SearchResponse
from app.schemas.document import TagInfo
async def search_documents(
db: AsyncSession,
query: str,
agent_id: str | None = None,
project_id: str | None = None,
tags: list[str] | None = None,
limit: int = 20,
offset: int = 0,
) -> SearchResponse:
"""
Full-text search using FTS5.
Returns snippets with highlight markup.
"""
if not query or len(query.strip()) == 0:
return SearchResponse(results=[])
# Escape FTS5 special characters and prepare query
safe_query = query.replace('"', '""')
# Build the FTS5 MATCH query
fts_query = f'"{safe_query}"'
# Get document IDs from FTS5
fts_sql = text("""
SELECT document_id, title, content, path,
bm25(documents_fts) as score
FROM documents_fts
WHERE documents_fts MATCH :q
ORDER BY score
LIMIT :limit OFFSET :offset
""")
fts_result = await db.execute(
fts_sql,
{"q": fts_query, "limit": limit, "offset": offset}
)
fts_rows = fts_result.fetchall()
if not fts_rows:
return SearchResponse(results=[])
results = []
for row in fts_rows:
doc_id = row.document_id
# Get document to verify access and get project_id
doc_sql = text("""
SELECT d.id, d.title, d.content, d.project_id, d.is_deleted,
p.agent_id
FROM active_documents d
JOIN active_projects p ON d.project_id = p.id
WHERE d.id = :doc_id AND p.agent_id = :agent_id
""")
doc_result = await db.execute(
doc_sql,
{"doc_id": doc_id, "agent_id": agent_id}
)
doc_row = doc_result.fetchone()
if not doc_row:
continue
# Filter by project_id if provided
if project_id and doc_row.project_id != project_id:
continue
# Get tags for this document
tags_sql = text("""
SELECT t.id, t.name, t.color
FROM active_tags t
JOIN document_tags dt ON t.id = dt.tag_id
WHERE dt.document_id = :doc_id
""")
tags_result = await db.execute(tags_sql, {"doc_id": doc_id})
tag_rows = tags_result.fetchall()
doc_tags = [TagInfo(id=t.id, name=t.name, color=t.color) for t in tag_rows]
# Filter by tags if provided
if tags:
tag_names = {t.name for t in doc_tags}
if not any(tn in tag_names for tn in tags):
continue
# Build excerpt with snippet
content = doc_row.content or ""
excerpt = _build_snippet(content, query)
results.append(SearchResult(
id=doc_row.id,
title=doc_row.title,
excerpt=excerpt,
project_id=doc_row.project_id,
tags=doc_tags,
score=abs(row.score) if row.score else 0.0,
))
return SearchResponse(results=results)
def _build_snippet(content: str, query: str, context_chars: int = 150) -> str:
"""Build a highlighted snippet from content."""
query_lower = query.lower()
content_lower = content.lower()
idx = content_lower.find(query_lower)
if idx == -1:
# No exact match, return beginning
snippet = content[:context_chars * 2]
else:
start = max(0, idx - context_chars)
end = min(len(content), idx + len(query) + context_chars)
snippet = content[start:end]
if start > 0:
snippet = "..." + snippet
if end < len(content):
snippet = snippet + "..."
# Simple highlight: wrap matches in **
import re
pattern = re.compile(re.escape(query), re.IGNORECASE)
snippet = pattern.sub(f"**{query}**", snippet)
return snippet