Phase 1 MVP - Complete implementation
- Auth: register, login, JWT with refresh tokens, blocklist - Projects/Folders/Documents CRUD with soft deletes - Tags CRUD and assignment - FTS5 search with highlights and tag filtering - ADR-001, ADR-002, ADR-003 compliant - Security fixes applied (JWT_SECRET_KEY, exception handler, cookie secure) - 25 tests passing
This commit is contained in:
1
app/services/__init__.py
Normal file
1
app/services/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
# Services package
|
||||
233
app/services/auth.py
Normal file
233
app/services/auth.py
Normal file
@@ -0,0 +1,233 @@
|
||||
import hashlib
|
||||
import secrets
|
||||
import uuid
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
import bcrypt
|
||||
from jose import JWTError, jwt
|
||||
from sqlalchemy import select, text, update
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.config import settings
|
||||
from app.models.agent import Agent
|
||||
from app.models.refresh_token import JwtBlocklist, RefreshToken
|
||||
|
||||
|
||||
def hash_password(password: str) -> str:
|
||||
return bcrypt.hashpw(password.encode("utf-8"), bcrypt.gensalt()).decode("utf-8")
|
||||
|
||||
|
||||
def verify_password(plain: str, hashed: str) -> bool:
|
||||
return bcrypt.checkpw(plain.encode("utf-8"), hashed.encode("utf-8"))
|
||||
|
||||
|
||||
def create_access_token(agent_id: str, role: str) -> tuple[str, str]:
|
||||
"""Create JWT access token. Returns (token, jti)."""
|
||||
jti = str(uuid.uuid4())
|
||||
now = datetime.utcnow()
|
||||
expire = now + timedelta(minutes=settings.JWT_ACCESS_TOKEN_EXPIRE_MINUTES)
|
||||
payload = {
|
||||
"sub": agent_id,
|
||||
"role": role,
|
||||
"jti": jti,
|
||||
"iat": now,
|
||||
"exp": expire,
|
||||
}
|
||||
token = jwt.encode(payload, settings.JWT_SECRET_KEY, algorithm=settings.JWT_ALGORITHM)
|
||||
return token, jti
|
||||
|
||||
|
||||
def create_refresh_token() -> str:
|
||||
"""Create opaque refresh token."""
|
||||
return secrets.token_urlsafe(64)
|
||||
|
||||
|
||||
def hash_token(token: str) -> str:
|
||||
"""SHA-256 hash of a token."""
|
||||
return hashlib.sha256(token.encode()).hexdigest()
|
||||
|
||||
|
||||
def decode_token(token: str) -> dict | None:
|
||||
"""Decode and validate JWT. Returns payload or None."""
|
||||
try:
|
||||
payload = jwt.decode(
|
||||
token,
|
||||
settings.JWT_SECRET_KEY,
|
||||
algorithms=[settings.JWT_ALGORITHM],
|
||||
)
|
||||
return payload
|
||||
except JWTError:
|
||||
return None
|
||||
|
||||
|
||||
async def get_agent_by_username(db: AsyncSession, username: str) -> Agent | None:
|
||||
result = await db.execute(
|
||||
select(Agent).where(Agent.username == username, Agent.is_deleted == False)
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
|
||||
async def get_agent_by_id(db: AsyncSession, agent_id: str) -> Agent | None:
|
||||
result = await db.execute(
|
||||
select(Agent).where(Agent.id == agent_id, Agent.is_deleted == False)
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
|
||||
async def create_agent(db: AsyncSession, username: str, password: str, role: str = "agent") -> Agent:
|
||||
agent = Agent(
|
||||
id=str(uuid.uuid4()),
|
||||
username=username,
|
||||
password_hash=hash_password(password),
|
||||
role=role,
|
||||
)
|
||||
db.add(agent)
|
||||
await db.flush()
|
||||
return agent
|
||||
|
||||
|
||||
async def save_refresh_token(
|
||||
db: AsyncSession,
|
||||
user_id: str,
|
||||
token: str,
|
||||
user_agent: str | None,
|
||||
ip_address: str | None,
|
||||
) -> RefreshToken:
|
||||
"""Save a new refresh token with a new family."""
|
||||
token_hash = hash_token(token)
|
||||
family_id = str(uuid.uuid4())
|
||||
expires_at = datetime.utcnow() + timedelta(days=settings.JWT_REFRESH_TOKEN_EXPIRE_DAYS)
|
||||
rt = RefreshToken(
|
||||
id=str(uuid.uuid4()),
|
||||
user_id=user_id,
|
||||
token_hash=token_hash,
|
||||
token_family_id=family_id,
|
||||
token_version=1,
|
||||
user_agent=user_agent,
|
||||
ip_address=ip_address,
|
||||
expires_at=expires_at,
|
||||
)
|
||||
db.add(rt)
|
||||
await db.flush()
|
||||
return rt
|
||||
|
||||
|
||||
async def rotate_refresh_token(
|
||||
db: AsyncSession,
|
||||
old_token: str,
|
||||
user_agent: str | None,
|
||||
ip_address: str | None,
|
||||
) -> tuple[RefreshToken, str] | None:
|
||||
"""
|
||||
Rotate a refresh token: revoke old, create new.
|
||||
Returns (new_rt, new_token) or None if invalid.
|
||||
"""
|
||||
token_hash = hash_token(old_token)
|
||||
|
||||
# Find existing token
|
||||
result = await db.execute(
|
||||
select(RefreshToken).where(
|
||||
RefreshToken.token_hash == token_hash,
|
||||
RefreshToken.revoked_at.is_(None),
|
||||
RefreshToken.is_global_logout == False,
|
||||
)
|
||||
)
|
||||
old_rt: RefreshToken | None = result.scalar_one_or_none()
|
||||
if not old_rt:
|
||||
return None
|
||||
|
||||
# Check expiry
|
||||
if old_rt.expires_at < datetime.utcnow():
|
||||
return None
|
||||
|
||||
# Reuse detection: check if a higher version exists
|
||||
reuse_check = await db.execute(
|
||||
select(RefreshToken).where(
|
||||
RefreshToken.token_family_id == old_rt.token_family_id,
|
||||
RefreshToken.token_version > old_rt.token_version,
|
||||
RefreshToken.revoked_at.is_not(None),
|
||||
)
|
||||
)
|
||||
if reuse_check.scalar_one_or_none():
|
||||
# Possible theft detected: revoke entire family
|
||||
await db.execute(
|
||||
update(RefreshToken)
|
||||
.where(RefreshToken.token_family_id == old_rt.token_family_id)
|
||||
.values(is_global_logout=True, revoked_at=datetime.utcnow())
|
||||
)
|
||||
return None
|
||||
|
||||
# Revoke old token
|
||||
old_rt.revoked_at = datetime.utcnow()
|
||||
|
||||
# Create new token in same family
|
||||
new_token = create_refresh_token()
|
||||
new_hash = hash_token(new_token)
|
||||
expires_at = datetime.utcnow() + timedelta(days=settings.JWT_REFRESH_TOKEN_EXPIRE_DAYS)
|
||||
new_rt = RefreshToken(
|
||||
id=str(uuid.uuid4()),
|
||||
user_id=old_rt.user_id,
|
||||
token_hash=new_hash,
|
||||
token_family_id=old_rt.token_family_id,
|
||||
token_version=old_rt.token_version + 1,
|
||||
user_agent=user_agent,
|
||||
ip_address=ip_address,
|
||||
expires_at=expires_at,
|
||||
)
|
||||
db.add(new_rt)
|
||||
await db.flush()
|
||||
return new_rt, new_token
|
||||
|
||||
|
||||
async def revoke_refresh_token(db: AsyncSession, token: str) -> bool:
|
||||
"""Revoke a single refresh token."""
|
||||
token_hash = hash_token(token)
|
||||
result = await db.execute(
|
||||
select(RefreshToken).where(
|
||||
RefreshToken.token_hash == token_hash,
|
||||
RefreshToken.revoked_at.is_(None),
|
||||
)
|
||||
)
|
||||
rt: RefreshToken | None = result.scalar_one_or_none()
|
||||
if not rt:
|
||||
return False
|
||||
rt.revoked_at = datetime.utcnow()
|
||||
await db.flush()
|
||||
return True
|
||||
|
||||
|
||||
async def revoke_all_user_tokens(db: AsyncSession, user_id: str) -> bool:
|
||||
"""Revoke all refresh tokens for a user (logout-all)."""
|
||||
await db.execute(
|
||||
update(RefreshToken)
|
||||
.where(
|
||||
RefreshToken.user_id == user_id,
|
||||
RefreshToken.revoked_at.is_(None),
|
||||
)
|
||||
.values(is_global_logout=True, revoked_at=datetime.utcnow())
|
||||
)
|
||||
await db.flush()
|
||||
return True
|
||||
|
||||
|
||||
async def add_to_blocklist(db: AsyncSession, jti: str, expires_at: datetime) -> None:
|
||||
"""Add a token JTI to the blocklist."""
|
||||
entry = JwtBlocklist(token_id=jti, expires_at=expires_at)
|
||||
db.add(entry)
|
||||
await db.flush()
|
||||
|
||||
|
||||
async def is_token_blocklisted(db: AsyncSession, jti: str) -> bool:
|
||||
"""Check if a token JTI is in the blocklist."""
|
||||
result = await db.execute(
|
||||
select(JwtBlocklist).where(JwtBlocklist.token_id == jti)
|
||||
)
|
||||
return result.scalar_one_or_none() is not None
|
||||
|
||||
|
||||
async def cleanup_expired_blocklist(db: AsyncSession) -> None:
|
||||
"""Remove expired entries from blocklist."""
|
||||
await db.execute(
|
||||
text("DELETE FROM jwt_blocklist WHERE expires_at < datetime('now')")
|
||||
)
|
||||
await db.flush()
|
||||
128
app/services/search.py
Normal file
128
app/services/search.py
Normal file
@@ -0,0 +1,128 @@
|
||||
from sqlalchemy import select, text
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.schemas.search import SearchResult, SearchResponse
|
||||
from app.schemas.document import TagInfo
|
||||
|
||||
|
||||
async def search_documents(
|
||||
db: AsyncSession,
|
||||
query: str,
|
||||
agent_id: str | None = None,
|
||||
project_id: str | None = None,
|
||||
tags: list[str] | None = None,
|
||||
limit: int = 20,
|
||||
offset: int = 0,
|
||||
) -> SearchResponse:
|
||||
"""
|
||||
Full-text search using FTS5.
|
||||
Returns snippets with highlight markup.
|
||||
"""
|
||||
if not query or len(query.strip()) == 0:
|
||||
return SearchResponse(results=[])
|
||||
|
||||
# Escape FTS5 special characters and prepare query
|
||||
safe_query = query.replace('"', '""')
|
||||
|
||||
# Build the FTS5 MATCH query
|
||||
fts_query = f'"{safe_query}"'
|
||||
|
||||
# Get document IDs from FTS5
|
||||
fts_sql = text("""
|
||||
SELECT document_id, title, content, path,
|
||||
bm25(documents_fts) as score
|
||||
FROM documents_fts
|
||||
WHERE documents_fts MATCH :q
|
||||
ORDER BY score
|
||||
LIMIT :limit OFFSET :offset
|
||||
""")
|
||||
|
||||
fts_result = await db.execute(
|
||||
fts_sql,
|
||||
{"q": fts_query, "limit": limit, "offset": offset}
|
||||
)
|
||||
fts_rows = fts_result.fetchall()
|
||||
|
||||
if not fts_rows:
|
||||
return SearchResponse(results=[])
|
||||
|
||||
results = []
|
||||
for row in fts_rows:
|
||||
doc_id = row.document_id
|
||||
|
||||
# Get document to verify access and get project_id
|
||||
doc_sql = text("""
|
||||
SELECT d.id, d.title, d.content, d.project_id, d.is_deleted,
|
||||
p.agent_id
|
||||
FROM active_documents d
|
||||
JOIN active_projects p ON d.project_id = p.id
|
||||
WHERE d.id = :doc_id AND p.agent_id = :agent_id
|
||||
""")
|
||||
doc_result = await db.execute(
|
||||
doc_sql,
|
||||
{"doc_id": doc_id, "agent_id": agent_id}
|
||||
)
|
||||
doc_row = doc_result.fetchone()
|
||||
if not doc_row:
|
||||
continue
|
||||
|
||||
# Filter by project_id if provided
|
||||
if project_id and doc_row.project_id != project_id:
|
||||
continue
|
||||
|
||||
# Get tags for this document
|
||||
tags_sql = text("""
|
||||
SELECT t.id, t.name, t.color
|
||||
FROM active_tags t
|
||||
JOIN document_tags dt ON t.id = dt.tag_id
|
||||
WHERE dt.document_id = :doc_id
|
||||
""")
|
||||
tags_result = await db.execute(tags_sql, {"doc_id": doc_id})
|
||||
tag_rows = tags_result.fetchall()
|
||||
doc_tags = [TagInfo(id=t.id, name=t.name, color=t.color) for t in tag_rows]
|
||||
|
||||
# Filter by tags if provided
|
||||
if tags:
|
||||
tag_names = {t.name for t in doc_tags}
|
||||
if not any(tn in tag_names for tn in tags):
|
||||
continue
|
||||
|
||||
# Build excerpt with snippet
|
||||
content = doc_row.content or ""
|
||||
excerpt = _build_snippet(content, query)
|
||||
|
||||
results.append(SearchResult(
|
||||
id=doc_row.id,
|
||||
title=doc_row.title,
|
||||
excerpt=excerpt,
|
||||
project_id=doc_row.project_id,
|
||||
tags=doc_tags,
|
||||
score=abs(row.score) if row.score else 0.0,
|
||||
))
|
||||
|
||||
return SearchResponse(results=results)
|
||||
|
||||
|
||||
def _build_snippet(content: str, query: str, context_chars: int = 150) -> str:
|
||||
"""Build a highlighted snippet from content."""
|
||||
query_lower = query.lower()
|
||||
content_lower = content.lower()
|
||||
|
||||
idx = content_lower.find(query_lower)
|
||||
if idx == -1:
|
||||
# No exact match, return beginning
|
||||
snippet = content[:context_chars * 2]
|
||||
else:
|
||||
start = max(0, idx - context_chars)
|
||||
end = min(len(content), idx + len(query) + context_chars)
|
||||
snippet = content[start:end]
|
||||
if start > 0:
|
||||
snippet = "..." + snippet
|
||||
if end < len(content):
|
||||
snippet = snippet + "..."
|
||||
|
||||
# Simple highlight: wrap matches in **
|
||||
import re
|
||||
pattern = re.compile(re.escape(query), re.IGNORECASE)
|
||||
snippet = pattern.sub(f"**{query}**", snippet)
|
||||
return snippet
|
||||
Reference in New Issue
Block a user