update:1.切换数据库PostgreSQL
This commit is contained in:
+125
-17
@@ -43,6 +43,39 @@ logger = get_logger(__name__)
|
||||
db_write_locks: dict[str, Lock] = {}
|
||||
|
||||
|
||||
async def verify_project_access(project_id: str, user_id: str, db: AsyncSession) -> Project:
|
||||
"""
|
||||
验证用户是否有权访问指定项目
|
||||
|
||||
Args:
|
||||
project_id: 项目ID
|
||||
user_id: 用户ID
|
||||
db: 数据库会话
|
||||
|
||||
Returns:
|
||||
Project: 项目对象
|
||||
|
||||
Raises:
|
||||
HTTPException: 401 未登录,404 项目不存在或无权访问
|
||||
"""
|
||||
if not user_id:
|
||||
raise HTTPException(status_code=401, detail="未登录")
|
||||
|
||||
result = await db.execute(
|
||||
select(Project).where(
|
||||
Project.id == project_id,
|
||||
Project.user_id == user_id
|
||||
)
|
||||
)
|
||||
project = result.scalar_one_or_none()
|
||||
|
||||
if not project:
|
||||
logger.warning(f"项目访问被拒绝: project_id={project_id}, user_id={user_id}")
|
||||
raise HTTPException(status_code=404, detail="项目不存在或无权访问")
|
||||
|
||||
return project
|
||||
|
||||
|
||||
async def get_db_write_lock(user_id: str) -> Lock:
|
||||
"""获取或创建用户的数据库写入锁"""
|
||||
if user_id not in db_write_locks:
|
||||
@@ -54,16 +87,13 @@ async def get_db_write_lock(user_id: str) -> Lock:
|
||||
@router.post("", response_model=ChapterResponse, summary="创建章节")
|
||||
async def create_chapter(
|
||||
chapter: ChapterCreate,
|
||||
request: Request,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""创建新的章节"""
|
||||
# 验证项目是否存在
|
||||
result = await db.execute(
|
||||
select(Project).where(Project.id == chapter.project_id)
|
||||
)
|
||||
project = result.scalar_one_or_none()
|
||||
if not project:
|
||||
raise HTTPException(status_code=404, detail="项目不存在")
|
||||
# 验证用户权限和项目是否存在
|
||||
user_id = getattr(request.state, 'user_id', None)
|
||||
project = await verify_project_access(chapter.project_id, user_id, db)
|
||||
|
||||
# 计算字数
|
||||
word_count = len(chapter.content)
|
||||
@@ -85,9 +115,14 @@ async def create_chapter(
|
||||
@router.get("/project/{project_id}", response_model=ChapterListResponse, summary="获取项目的所有章节")
|
||||
async def get_project_chapters(
|
||||
project_id: str,
|
||||
request: Request,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""获取指定项目的所有章节(路径参数版本)"""
|
||||
# 验证用户权限
|
||||
user_id = getattr(request.state, 'user_id', None)
|
||||
await verify_project_access(project_id, user_id, db)
|
||||
|
||||
# 获取总数
|
||||
count_result = await db.execute(
|
||||
select(func.count(Chapter.id)).where(Chapter.project_id == project_id)
|
||||
@@ -108,6 +143,7 @@ async def get_project_chapters(
|
||||
@router.get("/{chapter_id}", response_model=ChapterResponse, summary="获取章节详情")
|
||||
async def get_chapter(
|
||||
chapter_id: str,
|
||||
request: Request,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""根据ID获取章节详情"""
|
||||
@@ -119,12 +155,17 @@ async def get_chapter(
|
||||
if not chapter:
|
||||
raise HTTPException(status_code=404, detail="章节不存在")
|
||||
|
||||
# 验证用户权限
|
||||
user_id = getattr(request.state, 'user_id', None)
|
||||
await verify_project_access(chapter.project_id, user_id, db)
|
||||
|
||||
return chapter
|
||||
|
||||
|
||||
@router.get("/{chapter_id}/navigation", summary="获取章节导航信息")
|
||||
async def get_chapter_navigation(
|
||||
chapter_id: str,
|
||||
request: Request,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
@@ -140,6 +181,10 @@ async def get_chapter_navigation(
|
||||
if not current_chapter:
|
||||
raise HTTPException(status_code=404, detail="章节不存在")
|
||||
|
||||
# 验证用户权限
|
||||
user_id = getattr(request.state, 'user_id', None)
|
||||
await verify_project_access(current_chapter.project_id, user_id, db)
|
||||
|
||||
# 获取上一章
|
||||
prev_result = await db.execute(
|
||||
select(Chapter)
|
||||
@@ -183,6 +228,7 @@ async def get_chapter_navigation(
|
||||
async def update_chapter(
|
||||
chapter_id: str,
|
||||
chapter_update: ChapterUpdate,
|
||||
request: Request,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""更新章节信息"""
|
||||
@@ -194,6 +240,10 @@ async def update_chapter(
|
||||
if not chapter:
|
||||
raise HTTPException(status_code=404, detail="章节不存在")
|
||||
|
||||
# 验证用户权限
|
||||
user_id = getattr(request.state, 'user_id', None)
|
||||
await verify_project_access(chapter.project_id, user_id, db)
|
||||
|
||||
# 记录旧字数
|
||||
old_word_count = chapter.word_count or 0
|
||||
|
||||
@@ -223,6 +273,7 @@ async def update_chapter(
|
||||
@router.delete("/{chapter_id}", summary="删除章节")
|
||||
async def delete_chapter(
|
||||
chapter_id: str,
|
||||
request: Request,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""删除章节"""
|
||||
@@ -234,6 +285,10 @@ async def delete_chapter(
|
||||
if not chapter:
|
||||
raise HTTPException(status_code=404, detail="章节不存在")
|
||||
|
||||
# 验证用户权限
|
||||
user_id = getattr(request.state, 'user_id', None)
|
||||
await verify_project_access(chapter.project_id, user_id, db)
|
||||
|
||||
# 更新项目字数
|
||||
result = await db.execute(
|
||||
select(Project).where(Project.id == chapter.project_id)
|
||||
@@ -481,6 +536,7 @@ async def build_smart_chapter_context(
|
||||
@router.get("/{chapter_id}/can-generate", summary="检查章节是否可以生成")
|
||||
async def check_can_generate(
|
||||
chapter_id: str,
|
||||
request: Request,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
@@ -495,6 +551,10 @@ async def check_can_generate(
|
||||
if not chapter:
|
||||
raise HTTPException(status_code=404, detail="章节不存在")
|
||||
|
||||
# 验证用户权限
|
||||
user_id = getattr(request.state, 'user_id', None)
|
||||
await verify_project_access(chapter.project_id, user_id, db)
|
||||
|
||||
# 检查前置条件
|
||||
can_generate, error_msg, previous_chapters = await check_prerequisites(db, chapter)
|
||||
|
||||
@@ -1238,6 +1298,7 @@ async def generate_chapter_content_stream(
|
||||
@router.get("/{chapter_id}/analysis/status", summary="查询章节分析任务状态")
|
||||
async def get_analysis_task_status(
|
||||
chapter_id: str,
|
||||
request: Request,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
@@ -1248,16 +1309,32 @@ async def get_analysis_task_status(
|
||||
- 如果任务状态为pending且超过2分钟未启动,自动标记为failed
|
||||
|
||||
返回:
|
||||
- task_id: 任务ID
|
||||
- status: pending/running/completed/failed
|
||||
- has_task: 是否存在分析任务
|
||||
- task_id: 任务ID(如果存在)
|
||||
- status: pending/running/completed/failed/none(如果不存在则为none)
|
||||
- progress: 0-100
|
||||
- error_message: 错误信息(如果失败)
|
||||
- auto_recovered: 是否被自动恢复
|
||||
- created_at: 创建时间
|
||||
- completed_at: 完成时间
|
||||
|
||||
注意:当章节不存在或无权访问时返回404,当没有分析任务时返回has_task=false
|
||||
"""
|
||||
from datetime import timedelta
|
||||
|
||||
# 先获取章节以验证存在性和权限
|
||||
chapter_result = await db.execute(
|
||||
select(Chapter).where(Chapter.id == chapter_id)
|
||||
)
|
||||
chapter = chapter_result.scalar_one_or_none()
|
||||
|
||||
if not chapter:
|
||||
raise HTTPException(status_code=404, detail="章节不存在")
|
||||
|
||||
# 验证用户权限
|
||||
user_id = getattr(request.state, 'user_id', None)
|
||||
await verify_project_access(chapter.project_id, user_id, db)
|
||||
|
||||
# 获取该章节最新的分析任务
|
||||
result = await db.execute(
|
||||
select(AnalysisTask)
|
||||
@@ -1268,7 +1345,19 @@ async def get_analysis_task_status(
|
||||
task = result.scalar_one_or_none()
|
||||
|
||||
if not task:
|
||||
raise HTTPException(status_code=404, detail="未找到分析任务")
|
||||
# 返回无任务状态,而不是抛出404错误
|
||||
return {
|
||||
"has_task": False,
|
||||
"chapter_id": chapter_id,
|
||||
"status": "none",
|
||||
"progress": 0,
|
||||
"error_message": None,
|
||||
"auto_recovered": False,
|
||||
"task_id": None,
|
||||
"created_at": None,
|
||||
"started_at": None,
|
||||
"completed_at": None
|
||||
}
|
||||
|
||||
auto_recovered = False
|
||||
current_time = datetime.now()
|
||||
@@ -1299,6 +1388,7 @@ async def get_analysis_task_status(
|
||||
logger.warning(f"🔄 自动恢复未启动的任务: {task.id}, 章节: {chapter_id}")
|
||||
|
||||
return {
|
||||
"has_task": True,
|
||||
"task_id": task.id,
|
||||
"chapter_id": task.chapter_id,
|
||||
"status": task.status,
|
||||
@@ -1314,6 +1404,7 @@ async def get_analysis_task_status(
|
||||
@router.get("/{chapter_id}/analysis", summary="获取章节分析结果")
|
||||
async def get_chapter_analysis(
|
||||
chapter_id: str,
|
||||
request: Request,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
@@ -1325,6 +1416,16 @@ async def get_chapter_analysis(
|
||||
- memories: 提取的记忆列表
|
||||
- created_at: 分析时间
|
||||
"""
|
||||
# 先获取章节以验证权限
|
||||
chapter_result_check = await db.execute(
|
||||
select(Chapter).where(Chapter.id == chapter_id)
|
||||
)
|
||||
chapter_check = chapter_result_check.scalar_one_or_none()
|
||||
if chapter_check:
|
||||
# 验证用户权限
|
||||
user_id = getattr(request.state, 'user_id', None)
|
||||
await verify_project_access(chapter_check.project_id, user_id, db)
|
||||
|
||||
# 获取分析结果
|
||||
analysis_result = await db.execute(
|
||||
select(PlotAnalysis)
|
||||
@@ -1369,6 +1470,7 @@ async def get_chapter_analysis(
|
||||
@router.get("/{chapter_id}/annotations", summary="获取章节标注数据")
|
||||
async def get_chapter_annotations(
|
||||
chapter_id: str,
|
||||
request: Request,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
@@ -1377,6 +1479,9 @@ async def get_chapter_annotations(
|
||||
返回格式化的标注列表,包含精确位置信息
|
||||
适用于章节内容的可视化标注展示
|
||||
"""
|
||||
# 验证用户权限
|
||||
user_id = getattr(request.state, 'user_id', None)
|
||||
|
||||
# 获取章节
|
||||
chapter_result = await db.execute(
|
||||
select(Chapter).where(Chapter.id == chapter_id)
|
||||
@@ -1386,6 +1491,9 @@ async def get_chapter_annotations(
|
||||
if not chapter:
|
||||
raise HTTPException(status_code=404, detail="章节不存在")
|
||||
|
||||
# 验证项目访问权限
|
||||
await verify_project_access(chapter.project_id, user_id, db)
|
||||
|
||||
# 获取分析结果
|
||||
analysis_result = await db.execute(
|
||||
select(PlotAnalysis)
|
||||
@@ -1623,13 +1731,8 @@ async def batch_generate_chapters_in_order(
|
||||
if not user_id:
|
||||
raise HTTPException(status_code=401, detail="未登录")
|
||||
|
||||
# 验证项目存在
|
||||
project_result = await db.execute(
|
||||
select(Project).where(Project.id == project_id)
|
||||
)
|
||||
project = project_result.scalar_one_or_none()
|
||||
if not project:
|
||||
raise HTTPException(status_code=404, detail="项目不存在")
|
||||
# 验证项目存在和用户权限
|
||||
project = await verify_project_access(project_id, user_id, db)
|
||||
|
||||
# 获取项目的所有章节,按序号排序
|
||||
result = await db.execute(
|
||||
@@ -1750,12 +1853,17 @@ async def get_batch_generation_status(
|
||||
@router.get("/project/{project_id}/batch-generate/active", summary="获取项目当前运行中的批量生成任务")
|
||||
async def get_active_batch_generation(
|
||||
project_id: str,
|
||||
request: Request,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
获取项目当前运行中的批量生成任务
|
||||
用于页面刷新后恢复任务状态
|
||||
"""
|
||||
# 验证用户权限
|
||||
user_id = getattr(request.state, 'user_id', None)
|
||||
await verify_project_access(project_id, user_id, db)
|
||||
|
||||
result = await db.execute(
|
||||
select(BatchGenerationTask)
|
||||
.where(BatchGenerationTask.project_id == project_id)
|
||||
|
||||
Reference in New Issue
Block a user