refactor: 后端代码重构,提取通用权限验证逻辑至common模块,减少代码冗余

This commit is contained in:
xiamuceer-j
2026-01-13 16:45:58 +08:00
parent 6f33e12ead
commit 46debab624
14 changed files with 907 additions and 716 deletions
+1 -20
View File
@@ -27,31 +27,12 @@ from app.schemas.career import (
from app.services.ai_service import AIService
from app.logger import get_logger
from app.api.settings import get_user_ai_service
from app.api.common import verify_project_access
router = APIRouter(prefix="/careers", tags=["职业管理"])
logger = get_logger(__name__)
async def verify_project_access(project_id: str, user_id: str, db: AsyncSession) -> Project:
"""验证用户是否有权访问指定项目"""
if not user_id:
raise HTTPException(status_code=401, detail="未登录")
result = await db.execute(
select(Project).where(
Project.id == project_id,
Project.user_id == user_id
)
)
project = result.scalar_one_or_none()
if not project:
logger.warning(f"项目访问被拒绝: project_id={project_id}, user_id={user_id}")
raise HTTPException(status_code=404, detail="项目不存在或无权访问")
return project
@router.get("", response_model=CareerListResponse, summary="获取职业列表")
async def get_careers(
project_id: str,
+1 -33
View File
@@ -10,6 +10,7 @@ from datetime import datetime
from asyncio import Queue, Lock
from app.database import get_db
from app.api.common import verify_project_access
from app.services.chapter_context_service import ChapterContextBuilder, FocusedMemoryRetriever
from app.models.chapter import Chapter
from app.models.project import Project
@@ -54,39 +55,6 @@ logger = get_logger(__name__)
db_write_locks: dict[str, Lock] = {}
async def verify_project_access(project_id: str, user_id: str, db: AsyncSession) -> Project:
"""
验证用户是否有权访问指定项目
Args:
project_id: 项目ID
user_id: 用户ID
db: 数据库会话
Returns:
Project: 项目对象
Raises:
HTTPException: 401 未登录,404 项目不存在或无权访问
"""
if not user_id:
raise HTTPException(status_code=401, detail="未登录")
result = await db.execute(
select(Project).where(
Project.id == project_id,
Project.user_id == user_id
)
)
project = result.scalar_one_or_none()
if not project:
logger.warning(f"项目访问被拒绝: project_id={project_id}, user_id={user_id}")
raise HTTPException(status_code=404, detail="项目不存在或无权访问")
return project
async def get_db_write_lock(user_id: str) -> Lock:
"""获取或创建用户的数据库写入锁"""
if user_id not in db_write_locks:
+100
View File
@@ -0,0 +1,100 @@
"""API 公共函数模块
包含跨 API 模块共享的通用函数和工具。
"""
from fastapi import HTTPException, Request
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select
from typing import Optional
from app.models.project import Project
from app.logger import get_logger
logger = get_logger(__name__)
async def verify_project_access(
project_id: str,
user_id: Optional[str],
db: AsyncSession
) -> Project:
"""
验证用户是否有权访问指定项目
统一的项目访问验证函数,确保:
1. 用户已登录
2. 项目存在
3. 用户有权访问该项目
Args:
project_id: 项目ID
user_id: 用户ID(从 request.state.user_id 获取)
db: 数据库会话
Returns:
Project: 验证通过后返回项目对象
Raises:
HTTPException:
- 401: 用户未登录
- 404: 项目不存在或用户无权访问
"""
if not user_id:
raise HTTPException(status_code=401, detail="未登录")
result = await db.execute(
select(Project).where(
Project.id == project_id,
Project.user_id == user_id
)
)
project = result.scalar_one_or_none()
if not project:
logger.warning(f"项目访问被拒绝: project_id={project_id}, user_id={user_id}")
raise HTTPException(status_code=404, detail="项目不存在或无权访问")
return project
def get_user_id(request: Request) -> Optional[str]:
"""
从请求中获取用户ID
这是一个便捷函数,用于从 request.state 中提取 user_id。
Args:
request: FastAPI 请求对象
Returns:
用户ID,如果未登录则返回 None
"""
return getattr(request.state, 'user_id', None)
async def verify_project_access_from_request(
project_id: str,
request: Request,
db: AsyncSession
) -> Project:
"""
从请求中验证项目访问权限(便捷函数)
结合 get_user_id 和 verify_project_access,简化调用。
Args:
project_id: 项目ID
request: FastAPI 请求对象
db: 数据库会话
Returns:
Project: 验证通过后返回项目对象
Raises:
HTTPException: 401/404
Usage:
project = await verify_project_access_from_request(project_id, request, db)
"""
user_id = get_user_id(request)
return await verify_project_access(project_id, user_id, db)
+1 -20
View File
@@ -12,32 +12,13 @@ from app.services.plot_analyzer import get_plot_analyzer
from app.services.ai_service import create_user_ai_service
from app.models.settings import Settings
from app.logger import get_logger
from app.api.common import verify_project_access
import uuid
logger = get_logger(__name__)
router = APIRouter(prefix="/api/memories", tags=["memories"])
async def verify_project_access(project_id: str, user_id: str, db: AsyncSession) -> Project:
"""验证用户是否有权访问指定项目"""
if not user_id:
raise HTTPException(status_code=401, detail="未登录")
result = await db.execute(
select(Project).where(
Project.id == project_id,
Project.user_id == user_id
)
)
project = result.scalar_one_or_none()
if not project:
logger.warning(f"项目访问被拒绝: project_id={project_id}, user_id={user_id}")
raise HTTPException(status_code=404, detail="项目不存在或无权访问")
return project
@router.post("/projects/{project_id}/analyze-chapter/{chapter_id}")
async def analyze_chapter(
project_id: str,
+1 -20
View File
@@ -27,31 +27,12 @@ from app.services.ai_service import AIService
from app.services.prompt_service import prompt_service, PromptService
from app.logger import get_logger
from app.api.settings import get_user_ai_service
from app.api.common import verify_project_access
router = APIRouter(prefix="/organizations", tags=["组织管理"])
logger = get_logger(__name__)
async def verify_project_access(project_id: str, user_id: str, db: AsyncSession) -> Project:
"""验证用户是否有权访问指定项目"""
if not user_id:
raise HTTPException(status_code=401, detail="未登录")
result = await db.execute(
select(Project).where(
Project.id == project_id,
Project.user_id == user_id
)
)
project = result.scalar_one_or_none()
if not project:
logger.warning(f"项目访问被拒绝: project_id={project_id}, user_id={user_id}")
raise HTTPException(status_code=404, detail="项目不存在或无权访问")
return project
class OrganizationGenerateRequest(BaseModel):
"""AI生成组织的请求模型"""
project_id: str = Field(..., description="项目ID")
+1 -33
View File
@@ -6,6 +6,7 @@ from typing import List, AsyncGenerator, Dict, Any
import json
from app.database import get_db
from app.api.common import verify_project_access
from app.models.outline import Outline
from app.models.project import Project
from app.models.chapter import Chapter
@@ -42,39 +43,6 @@ router = APIRouter(prefix="/outlines", tags=["大纲管理"])
logger = get_logger(__name__)
async def verify_project_access(project_id: str, user_id: str, db: AsyncSession) -> Project:
"""
验证用户是否有权访问指定项目
Args:
project_id: 项目ID
user_id: 用户ID
db: 数据库会话
Returns:
Project: 项目对象
Raises:
HTTPException: 401 未登录,404 项目不存在或无权访问
"""
if not user_id:
raise HTTPException(status_code=401, detail="未登录")
result = await db.execute(
select(Project).where(
Project.id == project_id,
Project.user_id == user_id
)
)
project = result.scalar_one_or_none()
if not project:
logger.warning(f"项目访问被拒绝: project_id={project_id}, user_id={user_id}")
raise HTTPException(status_code=404, detail="项目不存在或无权访问")
return project
def _build_chapters_brief(outlines: List[Outline], max_recent: int = 20) -> str:
"""构建章节概览字符串"""
target = outlines[-max_recent:] if len(outlines) > max_recent else outlines
+1 -20
View File
@@ -23,31 +23,12 @@ from app.schemas.relationship import (
RelationshipGraphLink
)
from app.logger import get_logger
from app.api.common import verify_project_access
router = APIRouter(prefix="/relationships", tags=["关系管理"])
logger = get_logger(__name__)
async def verify_project_access(project_id: str, user_id: str, db: AsyncSession) -> Project:
"""验证用户是否有权访问指定项目"""
if not user_id:
raise HTTPException(status_code=401, detail="未登录")
result = await db.execute(
select(Project).where(
Project.id == project_id,
Project.user_id == user_id
)
)
project = result.scalar_one_or_none()
if not project:
logger.warning(f"项目访问被拒绝: project_id={project_id}, user_id={user_id}")
raise HTTPException(status_code=404, detail="项目不存在或无权访问")
return project
@router.get("/types", response_model=List[RelationshipTypeResponse], summary="获取关系类型列表")
async def get_relationship_types(db: AsyncSession = Depends(get_db)):
"""获取所有预定义的关系类型"""