"""Authentication service helpers.

Covers password hashing (bcrypt), JWT access tokens, opaque refresh tokens
with rotation and reuse detection, the JWT ``jti`` blocklist, and role-based
API tokens. Functions taking an ``AsyncSession`` only ``flush``; none of them
commits, so committing is left to the caller.
"""
import hashlib
import secrets
import uuid
from datetime import datetime, timedelta, timezone
from typing import TYPE_CHECKING

import bcrypt
from jose import JWTError, jwt
from sqlalchemy import delete, select, text, update
from sqlalchemy.ext.asyncio import AsyncSession

from app.config import settings
from app.models.agent import Agent
from app.models.refresh_token import JwtBlocklist, RefreshToken

if TYPE_CHECKING:
    # Only for annotations; the API-token functions import it lazily at runtime.
    from app.models.api_token import ApiToken


def _utcnow() -> datetime:
    """Return the current UTC time as a *naive* datetime.

    Drop-in replacement for the deprecated ``datetime.utcnow()``. The value is
    kept naive because this module stores naive UTC timestamps
    (``expires_at``, ``revoked_at``) and compares against them.
    """
    return datetime.now(timezone.utc).replace(tzinfo=None)


def hash_password(password: str) -> str:
    """Hash *password* with bcrypt and a freshly generated salt.

    Returns the encoded hash as a UTF-8 string suitable for storage.
    """
    return bcrypt.hashpw(password.encode("utf-8"), bcrypt.gensalt()).decode("utf-8")


def verify_password(plain: str, hashed: str) -> bool:
    """Return True if *plain* matches the stored bcrypt hash *hashed*.

    ``bcrypt.checkpw`` raises ``ValueError`` for a malformed stored hash; that
    is reported as a non-match rather than escaping as a server error.
    """
    try:
        return bcrypt.checkpw(plain.encode("utf-8"), hashed.encode("utf-8"))
    except ValueError:
        return False


def create_access_token(agent_id: str, role: str) -> tuple[str, str]:
    """Create a signed JWT access token.

    Args:
        agent_id: Stored in the ``sub`` claim.
        role: Stored in the ``role`` claim.

    Returns:
        ``(token, jti)`` — the encoded JWT and its unique ID. The ``jti`` is
        what ``add_to_blocklist`` / ``is_token_blocklisted`` operate on.
    """
    jti = str(uuid.uuid4())
    now = _utcnow()
    payload = {
        "sub": agent_id,
        "role": role,
        "jti": jti,
        "iat": now,
        "exp": now + timedelta(minutes=settings.JWT_ACCESS_TOKEN_EXPIRE_MINUTES),
    }
    token = jwt.encode(payload, settings.JWT_SECRET_KEY, algorithm=settings.JWT_ALGORITHM)
    return token, jti


def create_refresh_token() -> str:
    """Return a new opaque refresh token (64 random bytes, URL-safe base64)."""
    return secrets.token_urlsafe(64)


def hash_token(token: str) -> str:
    """Return the hex SHA-256 digest of *token*.

    Refresh and API tokens are persisted only as this digest, never raw.
    """
    return hashlib.sha256(token.encode()).hexdigest()


def decode_token(token: str) -> dict | None:
    """Decode and validate a JWT access token.

    Returns the claims dict, or ``None`` if the signature, algorithm, or
    time-based claims fail validation (any ``JWTError``).
    """
    try:
        return jwt.decode(
            token,
            settings.JWT_SECRET_KEY,
            algorithms=[settings.JWT_ALGORITHM],
        )
    except JWTError:
        return None


async def get_agent_by_username(db: AsyncSession, username: str) -> Agent | None:
    """Return the non-deleted agent with *username*, or ``None``."""
    result = await db.execute(
        # `== False` builds a SQL expression here; it is not a Python comparison.
        select(Agent).where(Agent.username == username, Agent.is_deleted == False)  # noqa: E712
    )
    return result.scalar_one_or_none()


async def get_agent_by_id(db: AsyncSession, agent_id: str) -> Agent | None:
    """Return the non-deleted agent with primary key *agent_id*, or ``None``."""
    result = await db.execute(
        select(Agent).where(Agent.id == agent_id, Agent.is_deleted == False)  # noqa: E712
    )
    return result.scalar_one_or_none()


async def create_agent(db: AsyncSession, username: str, password: str, role: str = "agent") -> Agent:
    """Create a new agent with a bcrypt-hashed password and flush it.

    Flushing sends the INSERT now, so database errors surface here rather
    than at commit time.
    """
    agent = Agent(
        id=str(uuid.uuid4()),
        username=username,
        password_hash=hash_password(password),
        role=role,
    )
    db.add(agent)
    await db.flush()
    return agent


async def save_refresh_token(
    db: AsyncSession,
    user_id: str,
    token: str,
    user_agent: str | None,
    ip_address: str | None,
) -> RefreshToken:
    """Persist a freshly issued refresh token as version 1 of a new family.

    Only the SHA-256 hash of *token* is stored. Expiry is
    ``settings.JWT_REFRESH_TOKEN_EXPIRE_DAYS`` from now (naive UTC).
    """
    rt = RefreshToken(
        id=str(uuid.uuid4()),
        user_id=user_id,
        token_hash=hash_token(token),
        token_family_id=str(uuid.uuid4()),
        token_version=1,
        user_agent=user_agent,
        ip_address=ip_address,
        expires_at=_utcnow() + timedelta(days=settings.JWT_REFRESH_TOKEN_EXPIRE_DAYS),
    )
    db.add(rt)
    await db.flush()
    return rt


async def _revoke_token_family(db: AsyncSession, family_id: str, now: datetime) -> None:
    """Revoke every still-active token in *family_id*, flagging a global logout."""
    await db.execute(
        update(RefreshToken)
        .where(
            RefreshToken.token_family_id == family_id,
            # Leave already-revoked rows untouched so their original
            # revocation timestamps are preserved.
            RefreshToken.revoked_at.is_(None),
        )
        .values(is_global_logout=True, revoked_at=now)
    )
    await db.flush()


async def rotate_refresh_token(
    db: AsyncSession,
    old_token: str,
    user_agent: str | None,
    ip_address: str | None,
) -> tuple[RefreshToken, str] | None:
    """Rotate a refresh token: revoke the presented one and issue a successor.

    The successor joins the same token family with ``token_version + 1``.

    Reuse detection: if the presented token exists but is already revoked
    (previously rotated, logged out, or globally revoked), the whole family
    is revoked, because replaying a superseded token suggests it leaked.
    Expired tokens are simply rejected.

    Returns:
        ``(new_rt, new_token)`` on success; ``None`` if the token is unknown,
        expired, or being reused.
    """
    now = _utcnow()

    # Look the token up WITHOUT filtering on revocation state: a replayed
    # token is by definition already revoked, so filtering it out would make
    # reuse detection impossible.
    result = await db.execute(
        select(RefreshToken).where(RefreshToken.token_hash == hash_token(old_token))
    )
    old_rt: RefreshToken | None = result.scalar_one_or_none()
    if old_rt is None:
        return None

    if old_rt.expires_at < now:
        return None

    if old_rt.revoked_at is not None or old_rt.is_global_logout:
        # NOTE(review): this revocation is only flushed. If the caller rolls
        # back when this returns None, the family revocation is lost — confirm
        # the caller commits in that path.
        await _revoke_token_family(db, old_rt.token_family_id, now)
        return None

    old_rt.revoked_at = now

    new_token = create_refresh_token()
    new_rt = RefreshToken(
        id=str(uuid.uuid4()),
        user_id=old_rt.user_id,
        token_hash=hash_token(new_token),
        token_family_id=old_rt.token_family_id,
        token_version=old_rt.token_version + 1,
        user_agent=user_agent,
        ip_address=ip_address,
        expires_at=now + timedelta(days=settings.JWT_REFRESH_TOKEN_EXPIRE_DAYS),
    )
    db.add(new_rt)
    await db.flush()
    return new_rt, new_token


async def revoke_refresh_token(db: AsyncSession, token: str) -> bool:
    """Revoke a single active refresh token.

    Returns True if an active token with that value was found and revoked,
    False otherwise.
    """
    result = await db.execute(
        select(RefreshToken).where(
            RefreshToken.token_hash == hash_token(token),
            RefreshToken.revoked_at.is_(None),
        )
    )
    rt: RefreshToken | None = result.scalar_one_or_none()
    if not rt:
        return False
    rt.revoked_at = _utcnow()
    await db.flush()
    return True


async def revoke_all_user_tokens(db: AsyncSession, user_id: str) -> bool:
    """Revoke every active refresh token of *user_id* (logout everywhere).

    Always returns True, including when the user had no active tokens.
    """
    await db.execute(
        update(RefreshToken)
        .where(
            RefreshToken.user_id == user_id,
            RefreshToken.revoked_at.is_(None),
        )
        .values(is_global_logout=True, revoked_at=_utcnow())
    )
    await db.flush()
    return True


async def add_to_blocklist(db: AsyncSession, jti: str, expires_at: datetime) -> None:
    """Blocklist an access token by its ``jti`` until *expires_at*.

    *expires_at* is compared against naive UTC "now" by
    ``cleanup_expired_blocklist``; presumably callers pass the token's own
    expiry in naive UTC — confirm at call sites.
    """
    db.add(JwtBlocklist(token_id=jti, expires_at=expires_at))
    await db.flush()


async def is_token_blocklisted(db: AsyncSession, jti: str) -> bool:
    """Return True if *jti* has at least one blocklist entry."""
    # Existence check that tolerates duplicate entries for the same jti
    # (scalar_one_or_none() would raise MultipleResultsFound on those).
    result = await db.execute(
        select(JwtBlocklist.token_id).where(JwtBlocklist.token_id == jti).limit(1)
    )
    return result.first() is not None


async def cleanup_expired_blocklist(db: AsyncSession) -> None:
    """Delete blocklist entries whose ``expires_at`` is in the past (UTC).

    Uses an ORM delete rather than raw SQL so it does not depend on the
    SQLite-only ``datetime('now')`` function or the literal table name.
    """
    await db.execute(delete(JwtBlocklist).where(JwtBlocklist.expires_at < _utcnow()))
    await db.flush()


# =============================================================================
# API Token Management (Role-based)
# =============================================================================


def generate_api_token() -> str:
    """Return a new random API token (48 random bytes, ~64 URL-safe chars)."""
    return secrets.token_urlsafe(48)


async def create_api_token(
    db: AsyncSession,
    name: str,
    role: str,
    agent_id: str,
) -> tuple[str, "ApiToken"]:
    """Create and persist a new API token owned by *agent_id*.

    Only the SHA-256 hash is stored, so the raw token is returned here once
    and cannot be recovered later.

    Returns:
        ``(raw_token, token_record)``.
    """
    from app.models.api_token import ApiToken

    raw_token = generate_api_token()
    token_record = ApiToken(
        id=str(uuid.uuid4()),
        name=name,
        token_hash=hash_token(raw_token),
        role=role,
        agent_id=agent_id,
    )
    db.add(token_record)
    await db.flush()
    return raw_token, token_record


async def list_api_tokens(db: AsyncSession, agent_id: str) -> list["ApiToken"]:
    """List *agent_id*'s API token records, newest first (hashes only, no raw tokens)."""
    from app.models.api_token import ApiToken

    result = await db.execute(
        select(ApiToken).where(ApiToken.agent_id == agent_id).order_by(ApiToken.created_at.desc())
    )
    return list(result.scalars().all())


async def revoke_api_token(db: AsyncSession, token_id: str, agent_id: str) -> bool:
    """Delete API token *token_id* if it belongs to *agent_id*.

    Returns True if the token was found and deleted, False otherwise.
    """
    from app.models.api_token import ApiToken

    result = await db.execute(
        select(ApiToken).where(
            ApiToken.id == token_id,
            # Scoping by owner prevents an agent from deleting others' tokens.
            ApiToken.agent_id == agent_id,
        )
    )
    token = result.scalar_one_or_none()
    if not token:
        return False
    await db.delete(token)
    await db.flush()
    return True


async def verify_api_token(db: AsyncSession, raw_token: str) -> tuple[str, str] | None:
    """Verify a raw API token and record its use.

    Returns ``(agent_id, role)`` if a record with the token's hash exists,
    otherwise ``None``. On success, ``last_used_at`` is set to now (naive UTC).
    """
    from app.models.api_token import ApiToken

    result = await db.execute(
        select(ApiToken).where(ApiToken.token_hash == hash_token(raw_token))
    )
    token: ApiToken | None = result.scalar_one_or_none()
    if not token:
        return None

    token.last_used_at = _utcnow()
    await db.flush()
    return token.agent_id, token.role