Enforce project ownership on every document endpoint.

Each endpoint loads the non-deleted project and rejects the caller with 403
unless the project belongs to the authenticated agent. The ownership check
applies to JWT sessions and API tokens alike; API tokens must additionally
pass the role-based _can_access_document check.

diff --git a/app/routers/documents.py b/app/routers/documents.py
index 4f4631d..d0d438c 100644
--- a/app/routers/documents.py
+++ b/app/routers/documents.py
@@ -230,21 +230,20 @@ async def list_documents(
 ):
     agent, api_role = await get_current_agent_or_api_token(request, db)
 
-    # JWT tokens check project ownership
-    if api_role is None:
-        proj_result = await db.execute(
-            select(Project).where(
-                Project.id == project_id,
-                Project.agent_id == agent.id,
-                Project.is_deleted == False,
-            )
+    # Load the project; a missing or soft-deleted project is a 404
+    proj_result = await db.execute(
+        select(Project).where(
+            Project.id == project_id,
+            Project.is_deleted == False,
         )
-        if not proj_result.scalar_one_or_none():
-            raise HTTPException(status_code=404, detail="Project not found")
-    else:
-        # API tokens don't have project-level access control here
-        # Access is controlled at document level via agent_type
-        pass
+    )
+    project = proj_result.scalar_one_or_none()
+    if not project:
+        raise HTTPException(status_code=404, detail="Project not found")
+
+    # Enforce ownership for JWT sessions and API tokens alike
+    if project.agent_id != agent.id:
+        raise HTTPException(status_code=403, detail="Forbidden")
 
     result = await db.execute(
         select(Document).where(
@@ -283,17 +282,20 @@ async def create_document(
 ):
     agent, api_role = await get_current_agent_or_api_token(request, db)
 
-    # JWT tokens check project ownership
-    if api_role is None:
-        proj_result = await db.execute(
-            select(Project).where(
-                Project.id == project_id,
-                Project.agent_id == agent.id,
-                Project.is_deleted == False,
-            )
+    # Load the project; a missing or soft-deleted project is a 404
+    proj_result = await db.execute(
+        select(Project).where(
+            Project.id == project_id,
+            Project.is_deleted == False,
         )
-        if not proj_result.scalar_one_or_none():
-            raise HTTPException(status_code=404, detail="Project not found")
+    )
+    project = proj_result.scalar_one_or_none()
+    if not project:
+        raise HTTPException(status_code=404, detail="Project not found")
+
+    # Enforce ownership for JWT sessions and API tokens alike
+    if project.agent_id != agent.id:
+        raise HTTPException(status_code=403, detail="Forbidden")
 
     # Determine agent_type for the document
     doc_agent_type = payload.agent_type or "general"
@@ -353,21 +355,24 @@ async def get_document(
     if not doc:
         raise HTTPException(status_code=404, detail="Document not found")
 
-    # JWT tokens check project ownership
-    if api_role is None:
-        proj_result = await db.execute(
-            select(Project).where(
-                Project.id == doc.project_id,
-                Project.agent_id == agent.id,
-                Project.is_deleted == False,
-            )
+    # Load the owning project; a soft-deleted project hides its documents
+    proj_result = await db.execute(
+        select(Project).where(
+            Project.id == doc.project_id,
+            Project.is_deleted == False,
         )
-        if not proj_result.scalar_one_or_none():
-            raise HTTPException(status_code=404, detail="Document not found")
-    else:
-        # API tokens check role-based access
-        if not _can_access_document(api_role, doc.agent_type, require_write=False):
-            raise HTTPException(status_code=403, detail="Forbidden")
+    )
+    project = proj_result.scalar_one_or_none()
+    if not project:
+        raise HTTPException(status_code=404, detail="Document not found")
+
+    # Enforce ownership for JWT sessions and API tokens alike
+    if project.agent_id != agent.id:
+        raise HTTPException(status_code=403, detail="Forbidden")
+
+    # Check role-based access for API tokens
+    if api_role is not None and not _can_access_document(api_role, doc.agent_type, require_write=False):
+        raise HTTPException(status_code=403, detail="Forbidden")
 
     return await document_to_response(db, doc)
 
@@ -391,21 +396,24 @@ async def update_document(
     if not doc:
         raise HTTPException(status_code=404, detail="Document not found")
 
-    # JWT tokens check project ownership
-    if api_role is None:
-        proj_result = await db.execute(
-            select(Project).where(
-                Project.id == doc.project_id,
-                Project.agent_id == agent.id,
-                Project.is_deleted == False,
-            )
+    # Load the owning project; writes to a soft-deleted project are refused
+    proj_result = await db.execute(
+        select(Project).where(
+            Project.id == doc.project_id,
+            Project.is_deleted == False,
         )
-        if not proj_result.scalar_one_or_none():
-            raise HTTPException(status_code=403, detail="Forbidden")
-    else:
-        # API tokens check role-based write access
-        if not _can_access_document(api_role, doc.agent_type, require_write=True):
-            raise HTTPException(status_code=403, detail="Forbidden")
+    )
+    project = proj_result.scalar_one_or_none()
+    if not project:
+        raise HTTPException(status_code=403, detail="Forbidden")
+
+    # Enforce ownership for JWT sessions and API tokens alike
+    if project.agent_id != agent.id:
+        raise HTTPException(status_code=403, detail="Forbidden")
+
+    # Check role-based write access for API tokens
+    if api_role is not None and not _can_access_document(api_role, doc.agent_type, require_write=True):
+        raise HTTPException(status_code=403, detail="Forbidden")
 
     if payload.title is not None:
         doc.title = payload.title
@@ -449,21 +457,24 @@ async def delete_document(
     if not doc:
         raise HTTPException(status_code=404, detail="Document not found")
 
-    # JWT tokens check project ownership
-    if api_role is None:
-        proj_result = await db.execute(
-            select(Project).where(
-                Project.id == doc.project_id,
-                Project.agent_id == agent.id,
-                Project.is_deleted == False,
-            )
+    # Load the owning project; writes to a soft-deleted project are refused
+    proj_result = await db.execute(
+        select(Project).where(
+            Project.id == doc.project_id,
+            Project.is_deleted == False,
         )
-        if not proj_result.scalar_one_or_none():
-            raise HTTPException(status_code=403, detail="Forbidden")
-    else:
-        # API tokens check role-based write access
-        if not _can_access_document(api_role, doc.agent_type, require_write=True):
-            raise HTTPException(status_code=403, detail="Forbidden")
+    )
+    project = proj_result.scalar_one_or_none()
+    if not project:
+        raise HTTPException(status_code=403, detail="Forbidden")
+
+    # Enforce ownership for JWT sessions and API tokens alike
+    if project.agent_id != agent.id:
+        raise HTTPException(status_code=403, detail="Forbidden")
+
+    # Check role-based write access for API tokens
+    if api_role is not None and not _can_access_document(api_role, doc.agent_type, require_write=True):
+        raise HTTPException(status_code=403, detail="Forbidden")
 
     doc.is_deleted = True
     doc.deleted_at = datetime.utcnow()
@@ -496,21 +507,24 @@ async def update_document_content(
     if not doc:
         raise HTTPException(status_code=404, detail="Document not found")
 
-    # JWT tokens check project ownership
-    if api_role is None:
-        proj_result = await db.execute(
-            select(Project).where(
-                Project.id == doc.project_id,
-                Project.agent_id == agent.id,
-                Project.is_deleted == False,
-            )
+    # Load the owning project; writes to a soft-deleted project are refused
+    proj_result = await db.execute(
+        select(Project).where(
+            Project.id == doc.project_id,
+            Project.is_deleted == False,
         )
-        if not proj_result.scalar_one_or_none():
-            raise HTTPException(status_code=403, detail="Forbidden")
-    else:
-        # API tokens check role-based write access
-        if not _can_access_document(api_role, doc.agent_type, require_write=True):
-            raise HTTPException(status_code=403, detail="Forbidden")
+    )
+    project = proj_result.scalar_one_or_none()
+    if not project:
+        raise HTTPException(status_code=403, detail="Forbidden")
+
+    # Enforce ownership for JWT sessions and API tokens alike
+    if project.agent_id != agent.id:
+        raise HTTPException(status_code=403, detail="Forbidden")
+
+    # Check role-based write access for API tokens
+    if api_role is not None and not _can_access_document(api_role, doc.agent_type, require_write=True):
+        raise HTTPException(status_code=403, detail="Forbidden")
 
     # Determine actual format based on content type (backward compatibility)
     # If content is a string, treat as markdown regardless of format field
@@ -566,21 +580,24 @@ async def restore_document(
     if not doc:
         raise HTTPException(status_code=404, detail="Document not found")
 
-    # JWT tokens check project ownership
-    if api_role is None:
-        proj_result = await db.execute(
-            select(Project).where(
-                Project.id == doc.project_id,
-                Project.agent_id == agent.id,
-                Project.is_deleted == False,
-            )
+    # Load the owning project; writes to a soft-deleted project are refused
+    proj_result = await db.execute(
+        select(Project).where(
+            Project.id == doc.project_id,
+            Project.is_deleted == False,
         )
-        if not proj_result.scalar_one_or_none():
-            raise HTTPException(status_code=403, detail="Forbidden")
-    else:
-        # API tokens check role-based write access
-        if not _can_access_document(api_role, doc.agent_type, require_write=True):
-            raise HTTPException(status_code=403, detail="Forbidden")
+    )
+    project = proj_result.scalar_one_or_none()
+    if not project:
+        raise HTTPException(status_code=403, detail="Forbidden")
+
+    # Enforce ownership for JWT sessions and API tokens alike
+    if project.agent_id != agent.id:
+        raise HTTPException(status_code=403, detail="Forbidden")
+
+    # Check role-based write access for API tokens
+    if api_role is not None and not _can_access_document(api_role, doc.agent_type, require_write=True):
+        raise HTTPException(status_code=403, detail="Forbidden")
 
     doc.is_deleted = False
     doc.deleted_at = None
@@ -608,21 +625,24 @@ async def assign_tags(
     if not doc:
         raise HTTPException(status_code=404, detail="Document not found")
 
-    # JWT tokens check project ownership
-    if api_role is None:
-        proj_result = await db.execute(
-            select(Project).where(
-                Project.id == doc.project_id,
-                Project.agent_id == agent.id,
-                Project.is_deleted == False,
-            )
+    # Load the owning project; writes to a soft-deleted project are refused
+    proj_result = await db.execute(
+        select(Project).where(
+            Project.id == doc.project_id,
+            Project.is_deleted == False,
         )
-        if not proj_result.scalar_one_or_none():
-            raise HTTPException(status_code=403, detail="Forbidden")
-    else:
-        # API tokens check role-based write access
-        if not _can_access_document(api_role, doc.agent_type, require_write=True):
-            raise HTTPException(status_code=403, detail="Forbidden")
+    )
+    project = proj_result.scalar_one_or_none()
+    if not project:
+        raise HTTPException(status_code=403, detail="Forbidden")
+
+    # Enforce ownership for JWT sessions and API tokens alike
+    if project.agent_id != agent.id:
+        raise HTTPException(status_code=403, detail="Forbidden")
+
+    # Check role-based write access for API tokens
+    if api_role is not None and not _can_access_document(api_role, doc.agent_type, require_write=True):
+        raise HTTPException(status_code=403, detail="Forbidden")
 
     for tag_id in payload.tag_ids:
         tag_result = await db.execute(
@@ -668,21 +688,24 @@ async def remove_tag(
     if not doc:
         raise HTTPException(status_code=404, detail="Document not found")
 
-    # JWT tokens check project ownership
-    if api_role is None:
-        proj_result = await db.execute(
-            select(Project).where(
-                Project.id == doc.project_id,
-                Project.agent_id == agent.id,
-                Project.is_deleted == False,
-            )
+    # Load the owning project; writes to a soft-deleted project are refused
+    proj_result = await db.execute(
+        select(Project).where(
+            Project.id == doc.project_id,
+            Project.is_deleted == False,
         )
-        if not proj_result.scalar_one_or_none():
-            raise HTTPException(status_code=403, detail="Forbidden")
-    else:
-        # API tokens check role-based write access
-        if not _can_access_document(api_role, doc.agent_type, require_write=True):
-            raise HTTPException(status_code=403, detail="Forbidden")
+    )
+    project = proj_result.scalar_one_or_none()
+    if not project:
+        raise HTTPException(status_code=403, detail="Forbidden")
+
+    # Enforce ownership for JWT sessions and API tokens alike
+    if project.agent_id != agent.id:
+        raise HTTPException(status_code=403, detail="Forbidden")
+
+    # Check role-based write access for API tokens
+    if api_role is not None and not _can_access_document(api_role, doc.agent_type, require_write=True):
+        raise HTTPException(status_code=403, detail="Forbidden")
 
     await db.execute(
         delete(DocumentTag).where(
@@ -717,21 +740,24 @@ async def _get_doc_with_access(
     if not doc:
         raise HTTPException(status_code=404, detail="Document not found")
 
-    # JWT tokens check project ownership
-    if api_role is None:
-        proj_result = await db.execute(
-            select(Project).where(
-                Project.id == doc.project_id,
-                Project.agent_id == agent.id,
-                Project.is_deleted == False,
-            )
+    # Load the owning project; access to a soft-deleted project is refused
+    proj_result = await db.execute(
+        select(Project).where(
+            Project.id == doc.project_id,
+            Project.is_deleted == False,
         )
-        if not proj_result.scalar_one_or_none():
-            raise HTTPException(status_code=403, detail="Forbidden")
-    else:
-        # API tokens check role-based access
-        if not _can_access_document(api_role, doc.agent_type, require_write=require_write):
-            raise HTTPException(status_code=403, detail="Forbidden")
+    )
+    project = proj_result.scalar_one_or_none()
+    if not project:
+        raise HTTPException(status_code=403, detail="Forbidden")
+
+    # Enforce ownership for JWT sessions and API tokens alike
+    if project.agent_id != agent.id:
+        raise HTTPException(status_code=403, detail="Forbidden")
+
+    # Check role-based access for API tokens
+    if api_role is not None and not _can_access_document(api_role, doc.agent_type, require_write=require_write):
+        raise HTTPException(status_code=403, detail="Forbidden")
 
     return doc, api_role
 
diff --git a/tests/test_api_tokens.py b/tests/test_api_tokens.py
index 622d57f..662f881 100644
--- a/tests/test_api_tokens.py
+++ b/tests/test_api_tokens.py
@@ -257,3 +257,95 @@ async def test_api_token_auth_flow(client, admin_user):
         headers={"Authorization": f"Bearer {researcher_token}"}
     )
     assert get_resp.status_code == 403
+
+
+@pytest.mark.asyncio
+async def test_api_token_cannot_access_other_user_project(client, admin_user):
+    """Test that API tokens and JWT sessions can only access the owner's projects."""
+    # Create another user
+    import uuid
+    import bcrypt
+    from sqlalchemy import text
+
+    password_hash = bcrypt.hashpw("user2pass".encode(), bcrypt.gensalt()).decode()
+
+    async with async_engine.begin() as conn:
+        user2_id = str(uuid.uuid4())
+        await conn.execute(
+            text("""
+                INSERT INTO agents (id, username, password_hash, role, is_deleted, created_at, updated_at)
+                VALUES (:id, :username, :password_hash, 'agent', 0, datetime('now'), datetime('now'))
+            """),
+            {
+                "id": user2_id,
+                "username": "user2",
+                "password_hash": password_hash
+            }
+        )
+
+    # Login as user2
+    login_resp = await client.post(
+        "/api/v1/auth/login",
+        json={"username": "user2", "password": "user2pass"}
+    )
+    user2_token = login_resp.json()["access_token"]
+
+    # Create project by user2
+    proj_resp = await client.post(
+        "/api/v1/projects",
+        json={"name": "User2 Project"},
+        headers={"Authorization": f"Bearer {user2_token}"}
+    )
+    user2_proj_id = proj_resp.json()["id"]
+
+    # Create document in user2's project
+    doc_resp = await client.post(
+        f"/api/v1/projects/{user2_proj_id}/documents",
+        json={"title": "User2 Doc", "content": "Content", "agent_type": "general"},
+        headers={"Authorization": f"Bearer {user2_token}"}
+    )
+    user2_doc_id = doc_resp.json()["id"]
+
+    # Admin creates a researcher token
+    gen_resp = await client.post(
+        "/api/v1/auth/token/generate",
+        json={"name": "research-token", "role": "researcher"},
+        headers={"Authorization": f"Bearer {admin_user}"}
+    )
+    researcher_token = gen_resp.json()["token"]
+
+    # Admin creates a project and document
+    admin_proj_resp = await client.post(
+        "/api/v1/projects",
+        json={"name": "Admin Project"},
+        headers={"Authorization": f"Bearer {admin_user}"}
+    )
+    admin_proj_id = admin_proj_resp.json()["id"]
+
+    admin_doc_resp = await client.post(
+        f"/api/v1/projects/{admin_proj_id}/documents",
+        json={"title": "Admin Doc", "content": "Content", "agent_type": "research"},
+        headers={"Authorization": f"Bearer {admin_user}"}
+    )
+    admin_doc_id = admin_doc_resp.json()["id"]
+
+    # Researcher token should NOT be able to access user2's project/document
+    get_resp = await client.get(
+        f"/api/v1/documents/{user2_doc_id}",
+        headers={"Authorization": f"Bearer {researcher_token}"}
+    )
+    assert get_resp.status_code == 403
+
+    # Researcher token SHOULD be able to access admin's research document
+    get_resp = await client.get(
+        f"/api/v1/documents/{admin_doc_id}",
+        headers={"Authorization": f"Bearer {researcher_token}"}
+    )
+    assert get_resp.status_code == 200
+
+    # JWT sessions are owner-bound too: admin's JWT must NOT read user2's document
+    get_resp = await client.get(
+        f"/api/v1/documents/{user2_doc_id}",
+        headers={"Authorization": f"Bearer {admin_user}"}
+    )
+    assert get_resp.status_code == 403