refactor: 后端代码重构,提取通用权限验证逻辑至common模块,减少代码冗余
This commit is contained in:
@@ -27,31 +27,12 @@ from app.schemas.career import (
|
||||
from app.services.ai_service import AIService
|
||||
from app.logger import get_logger
|
||||
from app.api.settings import get_user_ai_service
|
||||
from app.api.common import verify_project_access
|
||||
|
||||
router = APIRouter(prefix="/careers", tags=["职业管理"])
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
async def verify_project_access(project_id: str, user_id: str, db: AsyncSession) -> Project:
|
||||
"""验证用户是否有权访问指定项目"""
|
||||
if not user_id:
|
||||
raise HTTPException(status_code=401, detail="未登录")
|
||||
|
||||
result = await db.execute(
|
||||
select(Project).where(
|
||||
Project.id == project_id,
|
||||
Project.user_id == user_id
|
||||
)
|
||||
)
|
||||
project = result.scalar_one_or_none()
|
||||
|
||||
if not project:
|
||||
logger.warning(f"项目访问被拒绝: project_id={project_id}, user_id={user_id}")
|
||||
raise HTTPException(status_code=404, detail="项目不存在或无权访问")
|
||||
|
||||
return project
|
||||
|
||||
|
||||
@router.get("", response_model=CareerListResponse, summary="获取职业列表")
|
||||
async def get_careers(
|
||||
project_id: str,
|
||||
|
||||
@@ -10,6 +10,7 @@ from datetime import datetime
|
||||
from asyncio import Queue, Lock
|
||||
|
||||
from app.database import get_db
|
||||
from app.api.common import verify_project_access
|
||||
from app.services.chapter_context_service import ChapterContextBuilder, FocusedMemoryRetriever
|
||||
from app.models.chapter import Chapter
|
||||
from app.models.project import Project
|
||||
@@ -54,39 +55,6 @@ logger = get_logger(__name__)
|
||||
db_write_locks: dict[str, Lock] = {}
|
||||
|
||||
|
||||
async def verify_project_access(project_id: str, user_id: str, db: AsyncSession) -> Project:
|
||||
"""
|
||||
验证用户是否有权访问指定项目
|
||||
|
||||
Args:
|
||||
project_id: 项目ID
|
||||
user_id: 用户ID
|
||||
db: 数据库会话
|
||||
|
||||
Returns:
|
||||
Project: 项目对象
|
||||
|
||||
Raises:
|
||||
HTTPException: 401 未登录,404 项目不存在或无权访问
|
||||
"""
|
||||
if not user_id:
|
||||
raise HTTPException(status_code=401, detail="未登录")
|
||||
|
||||
result = await db.execute(
|
||||
select(Project).where(
|
||||
Project.id == project_id,
|
||||
Project.user_id == user_id
|
||||
)
|
||||
)
|
||||
project = result.scalar_one_or_none()
|
||||
|
||||
if not project:
|
||||
logger.warning(f"项目访问被拒绝: project_id={project_id}, user_id={user_id}")
|
||||
raise HTTPException(status_code=404, detail="项目不存在或无权访问")
|
||||
|
||||
return project
|
||||
|
||||
|
||||
async def get_db_write_lock(user_id: str) -> Lock:
|
||||
"""获取或创建用户的数据库写入锁"""
|
||||
if user_id not in db_write_locks:
|
||||
|
||||
@@ -0,0 +1,100 @@
|
||||
"""API 公共函数模块
|
||||
|
||||
包含跨 API 模块共享的通用函数和工具。
|
||||
"""
|
||||
from fastapi import HTTPException, Request
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select
|
||||
from typing import Optional
|
||||
|
||||
from app.models.project import Project
|
||||
from app.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
async def verify_project_access(
|
||||
project_id: str,
|
||||
user_id: Optional[str],
|
||||
db: AsyncSession
|
||||
) -> Project:
|
||||
"""
|
||||
验证用户是否有权访问指定项目
|
||||
|
||||
统一的项目访问验证函数,确保:
|
||||
1. 用户已登录
|
||||
2. 项目存在
|
||||
3. 用户有权访问该项目
|
||||
|
||||
Args:
|
||||
project_id: 项目ID
|
||||
user_id: 用户ID(从 request.state.user_id 获取)
|
||||
db: 数据库会话
|
||||
|
||||
Returns:
|
||||
Project: 验证通过后返回项目对象
|
||||
|
||||
Raises:
|
||||
HTTPException:
|
||||
- 401: 用户未登录
|
||||
- 404: 项目不存在或用户无权访问
|
||||
"""
|
||||
if not user_id:
|
||||
raise HTTPException(status_code=401, detail="未登录")
|
||||
|
||||
result = await db.execute(
|
||||
select(Project).where(
|
||||
Project.id == project_id,
|
||||
Project.user_id == user_id
|
||||
)
|
||||
)
|
||||
project = result.scalar_one_or_none()
|
||||
|
||||
if not project:
|
||||
logger.warning(f"项目访问被拒绝: project_id={project_id}, user_id={user_id}")
|
||||
raise HTTPException(status_code=404, detail="项目不存在或无权访问")
|
||||
|
||||
return project
|
||||
|
||||
|
||||
def get_user_id(request: Request) -> Optional[str]:
|
||||
"""
|
||||
从请求中获取用户ID
|
||||
|
||||
这是一个便捷函数,用于从 request.state 中提取 user_id。
|
||||
|
||||
Args:
|
||||
request: FastAPI 请求对象
|
||||
|
||||
Returns:
|
||||
用户ID,如果未登录则返回 None
|
||||
"""
|
||||
return getattr(request.state, 'user_id', None)
|
||||
|
||||
|
||||
async def verify_project_access_from_request(
|
||||
project_id: str,
|
||||
request: Request,
|
||||
db: AsyncSession
|
||||
) -> Project:
|
||||
"""
|
||||
从请求中验证项目访问权限(便捷函数)
|
||||
|
||||
结合 get_user_id 和 verify_project_access,简化调用。
|
||||
|
||||
Args:
|
||||
project_id: 项目ID
|
||||
request: FastAPI 请求对象
|
||||
db: 数据库会话
|
||||
|
||||
Returns:
|
||||
Project: 验证通过后返回项目对象
|
||||
|
||||
Raises:
|
||||
HTTPException: 401/404
|
||||
|
||||
Usage:
|
||||
project = await verify_project_access_from_request(project_id, request, db)
|
||||
"""
|
||||
user_id = get_user_id(request)
|
||||
return await verify_project_access(project_id, user_id, db)
|
||||
@@ -12,32 +12,13 @@ from app.services.plot_analyzer import get_plot_analyzer
|
||||
from app.services.ai_service import create_user_ai_service
|
||||
from app.models.settings import Settings
|
||||
from app.logger import get_logger
|
||||
from app.api.common import verify_project_access
|
||||
import uuid
|
||||
|
||||
logger = get_logger(__name__)
|
||||
router = APIRouter(prefix="/api/memories", tags=["memories"])
|
||||
|
||||
|
||||
async def verify_project_access(project_id: str, user_id: str, db: AsyncSession) -> Project:
|
||||
"""验证用户是否有权访问指定项目"""
|
||||
if not user_id:
|
||||
raise HTTPException(status_code=401, detail="未登录")
|
||||
|
||||
result = await db.execute(
|
||||
select(Project).where(
|
||||
Project.id == project_id,
|
||||
Project.user_id == user_id
|
||||
)
|
||||
)
|
||||
project = result.scalar_one_or_none()
|
||||
|
||||
if not project:
|
||||
logger.warning(f"项目访问被拒绝: project_id={project_id}, user_id={user_id}")
|
||||
raise HTTPException(status_code=404, detail="项目不存在或无权访问")
|
||||
|
||||
return project
|
||||
|
||||
|
||||
@router.post("/projects/{project_id}/analyze-chapter/{chapter_id}")
|
||||
async def analyze_chapter(
|
||||
project_id: str,
|
||||
|
||||
@@ -27,31 +27,12 @@ from app.services.ai_service import AIService
|
||||
from app.services.prompt_service import prompt_service, PromptService
|
||||
from app.logger import get_logger
|
||||
from app.api.settings import get_user_ai_service
|
||||
from app.api.common import verify_project_access
|
||||
|
||||
router = APIRouter(prefix="/organizations", tags=["组织管理"])
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
async def verify_project_access(project_id: str, user_id: str, db: AsyncSession) -> Project:
|
||||
"""验证用户是否有权访问指定项目"""
|
||||
if not user_id:
|
||||
raise HTTPException(status_code=401, detail="未登录")
|
||||
|
||||
result = await db.execute(
|
||||
select(Project).where(
|
||||
Project.id == project_id,
|
||||
Project.user_id == user_id
|
||||
)
|
||||
)
|
||||
project = result.scalar_one_or_none()
|
||||
|
||||
if not project:
|
||||
logger.warning(f"项目访问被拒绝: project_id={project_id}, user_id={user_id}")
|
||||
raise HTTPException(status_code=404, detail="项目不存在或无权访问")
|
||||
|
||||
return project
|
||||
|
||||
|
||||
class OrganizationGenerateRequest(BaseModel):
|
||||
"""AI生成组织的请求模型"""
|
||||
project_id: str = Field(..., description="项目ID")
|
||||
|
||||
@@ -6,6 +6,7 @@ from typing import List, AsyncGenerator, Dict, Any
|
||||
import json
|
||||
|
||||
from app.database import get_db
|
||||
from app.api.common import verify_project_access
|
||||
from app.models.outline import Outline
|
||||
from app.models.project import Project
|
||||
from app.models.chapter import Chapter
|
||||
@@ -42,39 +43,6 @@ router = APIRouter(prefix="/outlines", tags=["大纲管理"])
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
async def verify_project_access(project_id: str, user_id: str, db: AsyncSession) -> Project:
|
||||
"""
|
||||
验证用户是否有权访问指定项目
|
||||
|
||||
Args:
|
||||
project_id: 项目ID
|
||||
user_id: 用户ID
|
||||
db: 数据库会话
|
||||
|
||||
Returns:
|
||||
Project: 项目对象
|
||||
|
||||
Raises:
|
||||
HTTPException: 401 未登录,404 项目不存在或无权访问
|
||||
"""
|
||||
if not user_id:
|
||||
raise HTTPException(status_code=401, detail="未登录")
|
||||
|
||||
result = await db.execute(
|
||||
select(Project).where(
|
||||
Project.id == project_id,
|
||||
Project.user_id == user_id
|
||||
)
|
||||
)
|
||||
project = result.scalar_one_or_none()
|
||||
|
||||
if not project:
|
||||
logger.warning(f"项目访问被拒绝: project_id={project_id}, user_id={user_id}")
|
||||
raise HTTPException(status_code=404, detail="项目不存在或无权访问")
|
||||
|
||||
return project
|
||||
|
||||
|
||||
def _build_chapters_brief(outlines: List[Outline], max_recent: int = 20) -> str:
|
||||
"""构建章节概览字符串"""
|
||||
target = outlines[-max_recent:] if len(outlines) > max_recent else outlines
|
||||
|
||||
@@ -23,31 +23,12 @@ from app.schemas.relationship import (
|
||||
RelationshipGraphLink
|
||||
)
|
||||
from app.logger import get_logger
|
||||
from app.api.common import verify_project_access
|
||||
|
||||
router = APIRouter(prefix="/relationships", tags=["关系管理"])
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
async def verify_project_access(project_id: str, user_id: str, db: AsyncSession) -> Project:
|
||||
"""验证用户是否有权访问指定项目"""
|
||||
if not user_id:
|
||||
raise HTTPException(status_code=401, detail="未登录")
|
||||
|
||||
result = await db.execute(
|
||||
select(Project).where(
|
||||
Project.id == project_id,
|
||||
Project.user_id == user_id
|
||||
)
|
||||
)
|
||||
project = result.scalar_one_or_none()
|
||||
|
||||
if not project:
|
||||
logger.warning(f"项目访问被拒绝: project_id={project_id}, user_id={user_id}")
|
||||
raise HTTPException(status_code=404, detail="项目不存在或无权访问")
|
||||
|
||||
return project
|
||||
|
||||
|
||||
@router.get("/types", response_model=List[RelationshipTypeResponse], summary="获取关系类型列表")
|
||||
async def get_relationship_types(db: AsyncSession = Depends(get_db)):
|
||||
"""获取所有预定义的关系类型"""
|
||||
|
||||
Reference in New Issue
Block a user