update:1.切换数据库PostgreSQL
This commit is contained in:
+125
-17
@@ -43,6 +43,39 @@ logger = get_logger(__name__)
|
||||
db_write_locks: dict[str, Lock] = {}
|
||||
|
||||
|
||||
async def verify_project_access(project_id: str, user_id: str, db: AsyncSession) -> Project:
|
||||
"""
|
||||
验证用户是否有权访问指定项目
|
||||
|
||||
Args:
|
||||
project_id: 项目ID
|
||||
user_id: 用户ID
|
||||
db: 数据库会话
|
||||
|
||||
Returns:
|
||||
Project: 项目对象
|
||||
|
||||
Raises:
|
||||
HTTPException: 401 未登录,404 项目不存在或无权访问
|
||||
"""
|
||||
if not user_id:
|
||||
raise HTTPException(status_code=401, detail="未登录")
|
||||
|
||||
result = await db.execute(
|
||||
select(Project).where(
|
||||
Project.id == project_id,
|
||||
Project.user_id == user_id
|
||||
)
|
||||
)
|
||||
project = result.scalar_one_or_none()
|
||||
|
||||
if not project:
|
||||
logger.warning(f"项目访问被拒绝: project_id={project_id}, user_id={user_id}")
|
||||
raise HTTPException(status_code=404, detail="项目不存在或无权访问")
|
||||
|
||||
return project
|
||||
|
||||
|
||||
async def get_db_write_lock(user_id: str) -> Lock:
|
||||
"""获取或创建用户的数据库写入锁"""
|
||||
if user_id not in db_write_locks:
|
||||
@@ -54,16 +87,13 @@ async def get_db_write_lock(user_id: str) -> Lock:
|
||||
@router.post("", response_model=ChapterResponse, summary="创建章节")
|
||||
async def create_chapter(
|
||||
chapter: ChapterCreate,
|
||||
request: Request,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""创建新的章节"""
|
||||
# 验证项目是否存在
|
||||
result = await db.execute(
|
||||
select(Project).where(Project.id == chapter.project_id)
|
||||
)
|
||||
project = result.scalar_one_or_none()
|
||||
if not project:
|
||||
raise HTTPException(status_code=404, detail="项目不存在")
|
||||
# 验证用户权限和项目是否存在
|
||||
user_id = getattr(request.state, 'user_id', None)
|
||||
project = await verify_project_access(chapter.project_id, user_id, db)
|
||||
|
||||
# 计算字数
|
||||
word_count = len(chapter.content)
|
||||
@@ -85,9 +115,14 @@ async def create_chapter(
|
||||
@router.get("/project/{project_id}", response_model=ChapterListResponse, summary="获取项目的所有章节")
|
||||
async def get_project_chapters(
|
||||
project_id: str,
|
||||
request: Request,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""获取指定项目的所有章节(路径参数版本)"""
|
||||
# 验证用户权限
|
||||
user_id = getattr(request.state, 'user_id', None)
|
||||
await verify_project_access(project_id, user_id, db)
|
||||
|
||||
# 获取总数
|
||||
count_result = await db.execute(
|
||||
select(func.count(Chapter.id)).where(Chapter.project_id == project_id)
|
||||
@@ -108,6 +143,7 @@ async def get_project_chapters(
|
||||
@router.get("/{chapter_id}", response_model=ChapterResponse, summary="获取章节详情")
|
||||
async def get_chapter(
|
||||
chapter_id: str,
|
||||
request: Request,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""根据ID获取章节详情"""
|
||||
@@ -119,12 +155,17 @@ async def get_chapter(
|
||||
if not chapter:
|
||||
raise HTTPException(status_code=404, detail="章节不存在")
|
||||
|
||||
# 验证用户权限
|
||||
user_id = getattr(request.state, 'user_id', None)
|
||||
await verify_project_access(chapter.project_id, user_id, db)
|
||||
|
||||
return chapter
|
||||
|
||||
|
||||
@router.get("/{chapter_id}/navigation", summary="获取章节导航信息")
|
||||
async def get_chapter_navigation(
|
||||
chapter_id: str,
|
||||
request: Request,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
@@ -140,6 +181,10 @@ async def get_chapter_navigation(
|
||||
if not current_chapter:
|
||||
raise HTTPException(status_code=404, detail="章节不存在")
|
||||
|
||||
# 验证用户权限
|
||||
user_id = getattr(request.state, 'user_id', None)
|
||||
await verify_project_access(current_chapter.project_id, user_id, db)
|
||||
|
||||
# 获取上一章
|
||||
prev_result = await db.execute(
|
||||
select(Chapter)
|
||||
@@ -183,6 +228,7 @@ async def get_chapter_navigation(
|
||||
async def update_chapter(
|
||||
chapter_id: str,
|
||||
chapter_update: ChapterUpdate,
|
||||
request: Request,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""更新章节信息"""
|
||||
@@ -194,6 +240,10 @@ async def update_chapter(
|
||||
if not chapter:
|
||||
raise HTTPException(status_code=404, detail="章节不存在")
|
||||
|
||||
# 验证用户权限
|
||||
user_id = getattr(request.state, 'user_id', None)
|
||||
await verify_project_access(chapter.project_id, user_id, db)
|
||||
|
||||
# 记录旧字数
|
||||
old_word_count = chapter.word_count or 0
|
||||
|
||||
@@ -223,6 +273,7 @@ async def update_chapter(
|
||||
@router.delete("/{chapter_id}", summary="删除章节")
|
||||
async def delete_chapter(
|
||||
chapter_id: str,
|
||||
request: Request,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""删除章节"""
|
||||
@@ -234,6 +285,10 @@ async def delete_chapter(
|
||||
if not chapter:
|
||||
raise HTTPException(status_code=404, detail="章节不存在")
|
||||
|
||||
# 验证用户权限
|
||||
user_id = getattr(request.state, 'user_id', None)
|
||||
await verify_project_access(chapter.project_id, user_id, db)
|
||||
|
||||
# 更新项目字数
|
||||
result = await db.execute(
|
||||
select(Project).where(Project.id == chapter.project_id)
|
||||
@@ -481,6 +536,7 @@ async def build_smart_chapter_context(
|
||||
@router.get("/{chapter_id}/can-generate", summary="检查章节是否可以生成")
|
||||
async def check_can_generate(
|
||||
chapter_id: str,
|
||||
request: Request,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
@@ -495,6 +551,10 @@ async def check_can_generate(
|
||||
if not chapter:
|
||||
raise HTTPException(status_code=404, detail="章节不存在")
|
||||
|
||||
# 验证用户权限
|
||||
user_id = getattr(request.state, 'user_id', None)
|
||||
await verify_project_access(chapter.project_id, user_id, db)
|
||||
|
||||
# 检查前置条件
|
||||
can_generate, error_msg, previous_chapters = await check_prerequisites(db, chapter)
|
||||
|
||||
@@ -1238,6 +1298,7 @@ async def generate_chapter_content_stream(
|
||||
@router.get("/{chapter_id}/analysis/status", summary="查询章节分析任务状态")
|
||||
async def get_analysis_task_status(
|
||||
chapter_id: str,
|
||||
request: Request,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
@@ -1248,16 +1309,32 @@ async def get_analysis_task_status(
|
||||
- 如果任务状态为pending且超过2分钟未启动,自动标记为failed
|
||||
|
||||
返回:
|
||||
- task_id: 任务ID
|
||||
- status: pending/running/completed/failed
|
||||
- has_task: 是否存在分析任务
|
||||
- task_id: 任务ID(如果存在)
|
||||
- status: pending/running/completed/failed/none(如果不存在则为none)
|
||||
- progress: 0-100
|
||||
- error_message: 错误信息(如果失败)
|
||||
- auto_recovered: 是否被自动恢复
|
||||
- created_at: 创建时间
|
||||
- completed_at: 完成时间
|
||||
|
||||
注意:当章节不存在或无权访问时返回404,当没有分析任务时返回has_task=false
|
||||
"""
|
||||
from datetime import timedelta
|
||||
|
||||
# 先获取章节以验证存在性和权限
|
||||
chapter_result = await db.execute(
|
||||
select(Chapter).where(Chapter.id == chapter_id)
|
||||
)
|
||||
chapter = chapter_result.scalar_one_or_none()
|
||||
|
||||
if not chapter:
|
||||
raise HTTPException(status_code=404, detail="章节不存在")
|
||||
|
||||
# 验证用户权限
|
||||
user_id = getattr(request.state, 'user_id', None)
|
||||
await verify_project_access(chapter.project_id, user_id, db)
|
||||
|
||||
# 获取该章节最新的分析任务
|
||||
result = await db.execute(
|
||||
select(AnalysisTask)
|
||||
@@ -1268,7 +1345,19 @@ async def get_analysis_task_status(
|
||||
task = result.scalar_one_or_none()
|
||||
|
||||
if not task:
|
||||
raise HTTPException(status_code=404, detail="未找到分析任务")
|
||||
# 返回无任务状态,而不是抛出404错误
|
||||
return {
|
||||
"has_task": False,
|
||||
"chapter_id": chapter_id,
|
||||
"status": "none",
|
||||
"progress": 0,
|
||||
"error_message": None,
|
||||
"auto_recovered": False,
|
||||
"task_id": None,
|
||||
"created_at": None,
|
||||
"started_at": None,
|
||||
"completed_at": None
|
||||
}
|
||||
|
||||
auto_recovered = False
|
||||
current_time = datetime.now()
|
||||
@@ -1299,6 +1388,7 @@ async def get_analysis_task_status(
|
||||
logger.warning(f"🔄 自动恢复未启动的任务: {task.id}, 章节: {chapter_id}")
|
||||
|
||||
return {
|
||||
"has_task": True,
|
||||
"task_id": task.id,
|
||||
"chapter_id": task.chapter_id,
|
||||
"status": task.status,
|
||||
@@ -1314,6 +1404,7 @@ async def get_analysis_task_status(
|
||||
@router.get("/{chapter_id}/analysis", summary="获取章节分析结果")
|
||||
async def get_chapter_analysis(
|
||||
chapter_id: str,
|
||||
request: Request,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
@@ -1325,6 +1416,16 @@ async def get_chapter_analysis(
|
||||
- memories: 提取的记忆列表
|
||||
- created_at: 分析时间
|
||||
"""
|
||||
# 先获取章节以验证权限
|
||||
chapter_result_check = await db.execute(
|
||||
select(Chapter).where(Chapter.id == chapter_id)
|
||||
)
|
||||
chapter_check = chapter_result_check.scalar_one_or_none()
|
||||
if chapter_check:
|
||||
# 验证用户权限
|
||||
user_id = getattr(request.state, 'user_id', None)
|
||||
await verify_project_access(chapter_check.project_id, user_id, db)
|
||||
|
||||
# 获取分析结果
|
||||
analysis_result = await db.execute(
|
||||
select(PlotAnalysis)
|
||||
@@ -1369,6 +1470,7 @@ async def get_chapter_analysis(
|
||||
@router.get("/{chapter_id}/annotations", summary="获取章节标注数据")
|
||||
async def get_chapter_annotations(
|
||||
chapter_id: str,
|
||||
request: Request,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
@@ -1377,6 +1479,9 @@ async def get_chapter_annotations(
|
||||
返回格式化的标注列表,包含精确位置信息
|
||||
适用于章节内容的可视化标注展示
|
||||
"""
|
||||
# 验证用户权限
|
||||
user_id = getattr(request.state, 'user_id', None)
|
||||
|
||||
# 获取章节
|
||||
chapter_result = await db.execute(
|
||||
select(Chapter).where(Chapter.id == chapter_id)
|
||||
@@ -1386,6 +1491,9 @@ async def get_chapter_annotations(
|
||||
if not chapter:
|
||||
raise HTTPException(status_code=404, detail="章节不存在")
|
||||
|
||||
# 验证项目访问权限
|
||||
await verify_project_access(chapter.project_id, user_id, db)
|
||||
|
||||
# 获取分析结果
|
||||
analysis_result = await db.execute(
|
||||
select(PlotAnalysis)
|
||||
@@ -1623,13 +1731,8 @@ async def batch_generate_chapters_in_order(
|
||||
if not user_id:
|
||||
raise HTTPException(status_code=401, detail="未登录")
|
||||
|
||||
# 验证项目存在
|
||||
project_result = await db.execute(
|
||||
select(Project).where(Project.id == project_id)
|
||||
)
|
||||
project = project_result.scalar_one_or_none()
|
||||
if not project:
|
||||
raise HTTPException(status_code=404, detail="项目不存在")
|
||||
# 验证项目存在和用户权限
|
||||
project = await verify_project_access(project_id, user_id, db)
|
||||
|
||||
# 获取项目的所有章节,按序号排序
|
||||
result = await db.execute(
|
||||
@@ -1750,12 +1853,17 @@ async def get_batch_generation_status(
|
||||
@router.get("/project/{project_id}/batch-generate/active", summary="获取项目当前运行中的批量生成任务")
|
||||
async def get_active_batch_generation(
|
||||
project_id: str,
|
||||
request: Request,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
获取项目当前运行中的批量生成任务
|
||||
用于页面刷新后恢复任务状态
|
||||
"""
|
||||
# 验证用户权限
|
||||
user_id = getattr(request.state, 'user_id', None)
|
||||
await verify_project_access(project_id, user_id, db)
|
||||
|
||||
result = await db.execute(
|
||||
select(BatchGenerationTask)
|
||||
.where(BatchGenerationTask.project_id == project_id)
|
||||
|
||||
@@ -24,12 +24,50 @@ router = APIRouter(prefix="/characters", tags=["角色管理"])
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
async def verify_project_access(project_id: str, user_id: str, db: AsyncSession) -> Project:
|
||||
"""
|
||||
验证用户是否有权访问指定项目
|
||||
|
||||
Args:
|
||||
project_id: 项目ID
|
||||
user_id: 用户ID
|
||||
db: 数据库会话
|
||||
|
||||
Returns:
|
||||
Project: 项目对象
|
||||
|
||||
Raises:
|
||||
HTTPException: 401 未登录,404 项目不存在或无权访问
|
||||
"""
|
||||
if not user_id:
|
||||
raise HTTPException(status_code=401, detail="未登录")
|
||||
|
||||
result = await db.execute(
|
||||
select(Project).where(
|
||||
Project.id == project_id,
|
||||
Project.user_id == user_id
|
||||
)
|
||||
)
|
||||
project = result.scalar_one_or_none()
|
||||
|
||||
if not project:
|
||||
logger.warning(f"项目访问被拒绝: project_id={project_id}, user_id={user_id}")
|
||||
raise HTTPException(status_code=404, detail="项目不存在或无权访问")
|
||||
|
||||
return project
|
||||
|
||||
|
||||
@router.get("", response_model=CharacterListResponse, summary="获取角色列表")
|
||||
async def get_characters(
|
||||
project_id: str,
|
||||
request: Request,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""获取指定项目的所有角色(query参数版本)"""
|
||||
# 验证用户权限
|
||||
user_id = getattr(request.state, 'user_id', None)
|
||||
await verify_project_access(project_id, user_id, db)
|
||||
|
||||
# 获取总数
|
||||
count_result = await db.execute(
|
||||
select(func.count(Character.id)).where(Character.project_id == project_id)
|
||||
@@ -93,9 +131,14 @@ async def get_characters(
|
||||
@router.get("/project/{project_id}", response_model=CharacterListResponse, summary="获取项目的所有角色")
|
||||
async def get_project_characters(
|
||||
project_id: str,
|
||||
request: Request,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""获取指定项目的所有角色(路径参数版本)"""
|
||||
# 验证用户权限
|
||||
user_id = getattr(request.state, 'user_id', None)
|
||||
await verify_project_access(project_id, user_id, db)
|
||||
|
||||
# 获取总数
|
||||
count_result = await db.execute(
|
||||
select(func.count(Character.id)).where(Character.project_id == project_id)
|
||||
@@ -159,6 +202,7 @@ async def get_project_characters(
|
||||
@router.get("/{character_id}", response_model=CharacterResponse, summary="获取角色详情")
|
||||
async def get_character(
|
||||
character_id: str,
|
||||
request: Request,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""根据ID获取角色详情"""
|
||||
@@ -170,6 +214,10 @@ async def get_character(
|
||||
if not character:
|
||||
raise HTTPException(status_code=404, detail="角色不存在")
|
||||
|
||||
# 验证用户权限
|
||||
user_id = getattr(request.state, 'user_id', None)
|
||||
await verify_project_access(character.project_id, user_id, db)
|
||||
|
||||
return character
|
||||
|
||||
|
||||
@@ -177,6 +225,7 @@ async def get_character(
|
||||
async def update_character(
|
||||
character_id: str,
|
||||
character_update: CharacterUpdate,
|
||||
request: Request,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""更新角色信息"""
|
||||
@@ -188,6 +237,10 @@ async def update_character(
|
||||
if not character:
|
||||
raise HTTPException(status_code=404, detail="角色不存在")
|
||||
|
||||
# 验证用户权限
|
||||
user_id = getattr(request.state, 'user_id', None)
|
||||
await verify_project_access(character.project_id, user_id, db)
|
||||
|
||||
# 更新字段
|
||||
update_data = character_update.model_dump(exclude_unset=True)
|
||||
for field, value in update_data.items():
|
||||
@@ -201,6 +254,7 @@ async def update_character(
|
||||
@router.delete("/{character_id}", summary="删除角色")
|
||||
async def delete_character(
|
||||
character_id: str,
|
||||
request: Request,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""删除角色"""
|
||||
@@ -212,6 +266,10 @@ async def delete_character(
|
||||
if not character:
|
||||
raise HTTPException(status_code=404, detail="角色不存在")
|
||||
|
||||
# 验证用户权限
|
||||
user_id = getattr(request.state, 'user_id', None)
|
||||
await verify_project_access(character.project_id, user_id, db)
|
||||
|
||||
await db.delete(character)
|
||||
await db.commit()
|
||||
|
||||
@@ -233,13 +291,9 @@ async def generate_character(
|
||||
|
||||
生成内容包括:姓名、年龄、性别、性格、外貌、背景故事、人际关系等
|
||||
"""
|
||||
# 验证项目是否存在并获取项目信息
|
||||
result = await db.execute(
|
||||
select(Project).where(Project.id == request.project_id)
|
||||
)
|
||||
project = result.scalar_one_or_none()
|
||||
if not project:
|
||||
raise HTTPException(status_code=404, detail="项目不存在")
|
||||
# 验证用户权限和项目是否存在
|
||||
user_id = getattr(http_request.state, 'user_id', None)
|
||||
project = await verify_project_access(request.project_id, user_id, db)
|
||||
|
||||
try:
|
||||
# 获取已存在的角色列表,用于关系网络
|
||||
@@ -295,9 +349,6 @@ async def generate_character(
|
||||
user_input=user_input
|
||||
)
|
||||
|
||||
# 获取user_id用于MCP工具调用
|
||||
user_id = http_request.state.user_id if hasattr(http_request.state, 'user_id') else 'default_user'
|
||||
|
||||
# 调用AI生成角色(支持MCP工具)
|
||||
logger.info(f"🎯 开始为项目 {request.project_id} 生成角色(启用MCP)")
|
||||
logger.info(f" - 角色名:{request.name or 'AI生成'}")
|
||||
|
||||
@@ -6,6 +6,7 @@ from typing import List, Optional
|
||||
from app.database import get_db
|
||||
from app.models.memory import StoryMemory, PlotAnalysis
|
||||
from app.models.chapter import Chapter
|
||||
from app.models.project import Project
|
||||
from app.services.memory_service import memory_service
|
||||
from app.services.plot_analyzer import get_plot_analyzer
|
||||
from app.services.ai_service import create_user_ai_service
|
||||
@@ -17,6 +18,26 @@ logger = get_logger(__name__)
|
||||
router = APIRouter(prefix="/api/memories", tags=["memories"])
|
||||
|
||||
|
||||
async def verify_project_access(project_id: str, user_id: str, db: AsyncSession) -> Project:
|
||||
"""验证用户是否有权访问指定项目"""
|
||||
if not user_id:
|
||||
raise HTTPException(status_code=401, detail="未登录")
|
||||
|
||||
result = await db.execute(
|
||||
select(Project).where(
|
||||
Project.id == project_id,
|
||||
Project.user_id == user_id
|
||||
)
|
||||
)
|
||||
project = result.scalar_one_or_none()
|
||||
|
||||
if not project:
|
||||
logger.warning(f"项目访问被拒绝: project_id={project_id}, user_id={user_id}")
|
||||
raise HTTPException(status_code=404, detail="项目不存在或无权访问")
|
||||
|
||||
return project
|
||||
|
||||
|
||||
@router.post("/projects/{project_id}/analyze-chapter/{chapter_id}")
|
||||
async def analyze_chapter(
|
||||
project_id: str,
|
||||
@@ -30,7 +51,10 @@ async def analyze_chapter(
|
||||
对指定章节进行剧情分析,提取钩子、伏笔、情节点等,并存入记忆系统
|
||||
"""
|
||||
try:
|
||||
user_id = request.state.user_id
|
||||
user_id = getattr(request.state, 'user_id', None)
|
||||
|
||||
# 验证用户权限
|
||||
await verify_project_access(project_id, user_id, db)
|
||||
|
||||
# 获取章节内容
|
||||
result = await db.execute(
|
||||
@@ -192,7 +216,10 @@ async def get_project_memories(
|
||||
):
|
||||
"""获取项目的记忆列表"""
|
||||
try:
|
||||
user_id = request.state.user_id
|
||||
user_id = getattr(request.state, 'user_id', None)
|
||||
|
||||
# 验证用户权限
|
||||
await verify_project_access(project_id, user_id, db)
|
||||
|
||||
# 构建查询
|
||||
query = select(StoryMemory).where(StoryMemory.project_id == project_id)
|
||||
@@ -222,10 +249,16 @@ async def get_project_memories(
|
||||
async def get_chapter_analysis(
|
||||
project_id: str,
|
||||
chapter_id: str,
|
||||
request: Request,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""获取章节的剧情分析"""
|
||||
try:
|
||||
user_id = getattr(request.state, 'user_id', None)
|
||||
|
||||
# 验证用户权限
|
||||
await verify_project_access(project_id, user_id, db)
|
||||
|
||||
result = await db.execute(
|
||||
select(PlotAnalysis).where(
|
||||
and_(
|
||||
@@ -258,11 +291,15 @@ async def search_memories(
|
||||
query: str,
|
||||
memory_types: Optional[List[str]] = None,
|
||||
limit: int = 10,
|
||||
min_importance: float = 0.0
|
||||
min_importance: float = 0.0,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""语义搜索项目记忆"""
|
||||
try:
|
||||
user_id = request.state.user_id
|
||||
user_id = getattr(request.state, 'user_id', None)
|
||||
|
||||
# 验证用户权限
|
||||
await verify_project_access(project_id, user_id, db)
|
||||
|
||||
memories = await memory_service.search_memories(
|
||||
user_id=user_id,
|
||||
@@ -294,7 +331,10 @@ async def get_unresolved_foreshadows(
|
||||
):
|
||||
"""获取未完结的伏笔"""
|
||||
try:
|
||||
user_id = request.state.user_id
|
||||
user_id = getattr(request.state, 'user_id', None)
|
||||
|
||||
# 验证用户权限
|
||||
await verify_project_access(project_id, user_id, db)
|
||||
|
||||
# 从向量库搜索
|
||||
foreshadows = await memory_service.find_unresolved_foreshadows(
|
||||
@@ -317,11 +357,15 @@ async def get_unresolved_foreshadows(
|
||||
@router.get("/projects/{project_id}/stats")
|
||||
async def get_memory_stats(
|
||||
project_id: str,
|
||||
request: Request
|
||||
request: Request,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""获取记忆统计信息"""
|
||||
try:
|
||||
user_id = request.state.user_id
|
||||
user_id = getattr(request.state, 'user_id', None)
|
||||
|
||||
# 验证用户权限
|
||||
await verify_project_access(project_id, user_id, db)
|
||||
|
||||
stats = await memory_service.get_memory_stats(
|
||||
user_id=user_id,
|
||||
@@ -347,7 +391,10 @@ async def delete_chapter_memories(
|
||||
):
|
||||
"""删除章节的所有记忆"""
|
||||
try:
|
||||
user_id = request.state.user_id
|
||||
user_id = getattr(request.state, 'user_id', None)
|
||||
|
||||
# 验证用户权限
|
||||
await verify_project_access(project_id, user_id, db)
|
||||
|
||||
# 从数据库删除
|
||||
result = await db.execute(
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
"""组织管理API"""
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select, and_
|
||||
from typing import List, Optional
|
||||
@@ -31,6 +31,26 @@ router = APIRouter(prefix="/organizations", tags=["组织管理"])
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
async def verify_project_access(project_id: str, user_id: str, db: AsyncSession) -> Project:
|
||||
"""验证用户是否有权访问指定项目"""
|
||||
if not user_id:
|
||||
raise HTTPException(status_code=401, detail="未登录")
|
||||
|
||||
result = await db.execute(
|
||||
select(Project).where(
|
||||
Project.id == project_id,
|
||||
Project.user_id == user_id
|
||||
)
|
||||
)
|
||||
project = result.scalar_one_or_none()
|
||||
|
||||
if not project:
|
||||
logger.warning(f"项目访问被拒绝: project_id={project_id}, user_id={user_id}")
|
||||
raise HTTPException(status_code=404, detail="项目不存在或无权访问")
|
||||
|
||||
return project
|
||||
|
||||
|
||||
class OrganizationGenerateRequest(BaseModel):
|
||||
"""AI生成组织的请求模型"""
|
||||
project_id: str = Field(..., description="项目ID")
|
||||
@@ -44,8 +64,13 @@ class OrganizationGenerateRequest(BaseModel):
|
||||
@router.get("/project/{project_id}", response_model=List[OrganizationDetailResponse], summary="获取项目的所有组织")
|
||||
async def get_project_organizations(
|
||||
project_id: str,
|
||||
request: Request,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
# 验证用户权限
|
||||
user_id = getattr(request.state, 'user_id', None)
|
||||
await verify_project_access(project_id, user_id, db)
|
||||
|
||||
"""
|
||||
获取项目中的所有组织及其详情
|
||||
|
||||
@@ -85,6 +110,7 @@ async def get_project_organizations(
|
||||
@router.get("/{org_id}", response_model=OrganizationResponse, summary="获取组织详情")
|
||||
async def get_organization(
|
||||
org_id: str,
|
||||
request: Request,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""获取组织的详细信息"""
|
||||
@@ -96,12 +122,17 @@ async def get_organization(
|
||||
if not org:
|
||||
raise HTTPException(status_code=404, detail="组织不存在")
|
||||
|
||||
# 验证用户权限
|
||||
user_id = getattr(request.state, 'user_id', None)
|
||||
await verify_project_access(org.project_id, user_id, db)
|
||||
|
||||
return org
|
||||
|
||||
|
||||
@router.post("/", response_model=OrganizationResponse, summary="创建组织")
|
||||
async def create_organization(
|
||||
organization: OrganizationCreate,
|
||||
request: Request,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
@@ -110,6 +141,10 @@ async def create_organization(
|
||||
- 需要关联到一个已存在的角色记录(is_organization=True)
|
||||
- 可以设置父组织、势力等级等属性
|
||||
"""
|
||||
# 验证用户权限
|
||||
user_id = getattr(request.state, 'user_id', None)
|
||||
await verify_project_access(organization.project_id, user_id, db)
|
||||
|
||||
# 验证角色是否存在且是组织
|
||||
char_result = await db.execute(
|
||||
select(Character).where(Character.id == organization.character_id)
|
||||
@@ -142,6 +177,7 @@ async def create_organization(
|
||||
async def update_organization(
|
||||
org_id: str,
|
||||
organization: OrganizationUpdate,
|
||||
request: Request,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""更新组织的属性"""
|
||||
@@ -153,6 +189,10 @@ async def update_organization(
|
||||
if not db_org:
|
||||
raise HTTPException(status_code=404, detail="组织不存在")
|
||||
|
||||
# 验证用户权限
|
||||
user_id = getattr(request.state, 'user_id', None)
|
||||
await verify_project_access(db_org.project_id, user_id, db)
|
||||
|
||||
# 更新字段
|
||||
update_data = organization.model_dump(exclude_unset=True)
|
||||
for field, value in update_data.items():
|
||||
@@ -168,6 +208,7 @@ async def update_organization(
|
||||
@router.delete("/{org_id}", summary="删除组织")
|
||||
async def delete_organization(
|
||||
org_id: str,
|
||||
request: Request,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""删除组织(会级联删除所有成员关系)"""
|
||||
@@ -179,6 +220,10 @@ async def delete_organization(
|
||||
if not db_org:
|
||||
raise HTTPException(status_code=404, detail="组织不存在")
|
||||
|
||||
# 验证用户权限
|
||||
user_id = getattr(request.state, 'user_id', None)
|
||||
await verify_project_access(db_org.project_id, user_id, db)
|
||||
|
||||
await db.delete(db_org)
|
||||
await db.commit()
|
||||
|
||||
@@ -191,6 +236,7 @@ async def delete_organization(
|
||||
@router.get("/{org_id}/members", response_model=List[OrganizationMemberDetailResponse], summary="获取组织成员")
|
||||
async def get_organization_members(
|
||||
org_id: str,
|
||||
request: Request,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
@@ -202,9 +248,14 @@ async def get_organization_members(
|
||||
org_result = await db.execute(
|
||||
select(Organization).where(Organization.id == org_id)
|
||||
)
|
||||
if not org_result.scalar_one_or_none():
|
||||
org = org_result.scalar_one_or_none()
|
||||
if not org:
|
||||
raise HTTPException(status_code=404, detail="组织不存在")
|
||||
|
||||
# 验证用户权限
|
||||
user_id = getattr(request.state, 'user_id', None)
|
||||
await verify_project_access(org.project_id, user_id, db)
|
||||
|
||||
# 获取成员列表
|
||||
result = await db.execute(
|
||||
select(OrganizationMember)
|
||||
@@ -244,6 +295,7 @@ async def get_organization_members(
|
||||
async def add_organization_member(
|
||||
org_id: str,
|
||||
member: OrganizationMemberCreate,
|
||||
request: Request,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
@@ -260,6 +312,10 @@ async def add_organization_member(
|
||||
if not org:
|
||||
raise HTTPException(status_code=404, detail="组织不存在")
|
||||
|
||||
# 验证用户权限
|
||||
user_id = getattr(request.state, 'user_id', None)
|
||||
await verify_project_access(org.project_id, user_id, db)
|
||||
|
||||
# 验证角色存在
|
||||
char_result = await db.execute(
|
||||
select(Character).where(Character.id == member.character_id)
|
||||
@@ -304,6 +360,7 @@ async def add_organization_member(
|
||||
async def update_organization_member(
|
||||
member_id: str,
|
||||
member: OrganizationMemberUpdate,
|
||||
request: Request,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""更新组织成员的职位、忠诚度等信息"""
|
||||
@@ -315,6 +372,14 @@ async def update_organization_member(
|
||||
if not db_member:
|
||||
raise HTTPException(status_code=404, detail="成员记录不存在")
|
||||
|
||||
# 通过成员所属的组织验证用户权限
|
||||
org_result = await db.execute(
|
||||
select(Organization).where(Organization.id == db_member.organization_id)
|
||||
)
|
||||
org = org_result.scalar_one()
|
||||
user_id = getattr(request.state, 'user_id', None)
|
||||
await verify_project_access(org.project_id, user_id, db)
|
||||
|
||||
# 更新字段
|
||||
update_data = member.model_dump(exclude_unset=True)
|
||||
for field, value in update_data.items():
|
||||
@@ -330,6 +395,7 @@ async def update_organization_member(
|
||||
@router.delete("/members/{member_id}", summary="移除组织成员")
|
||||
async def remove_organization_member(
|
||||
member_id: str,
|
||||
request: Request,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
@@ -350,6 +416,10 @@ async def remove_organization_member(
|
||||
select(Organization).where(Organization.id == db_member.organization_id)
|
||||
)
|
||||
org = org_result.scalar_one()
|
||||
|
||||
# 验证用户权限
|
||||
user_id = getattr(request.state, 'user_id', None)
|
||||
await verify_project_access(org.project_id, user_id, db)
|
||||
org.member_count = max(0, org.member_count - 1)
|
||||
|
||||
await db.delete(db_member)
|
||||
@@ -360,7 +430,8 @@ async def remove_organization_member(
|
||||
|
||||
@router.post("/generate", response_model=CharacterResponse, summary="AI生成组织")
|
||||
async def generate_organization(
|
||||
request: OrganizationGenerateRequest,
|
||||
gen_request: OrganizationGenerateRequest,
|
||||
http_request: Request,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
user_ai_service: AIService = Depends(get_user_ai_service)
|
||||
):
|
||||
@@ -372,19 +443,15 @@ async def generate_organization(
|
||||
|
||||
生成内容包括:组织名称、类型、特性、背景、目的、势力等级等
|
||||
"""
|
||||
# 验证项目是否存在并获取项目信息
|
||||
result = await db.execute(
|
||||
select(Project).where(Project.id == request.project_id)
|
||||
)
|
||||
project = result.scalar_one_or_none()
|
||||
if not project:
|
||||
raise HTTPException(status_code=404, detail="项目不存在")
|
||||
# 验证用户权限
|
||||
user_id = getattr(http_request.state, 'user_id', None)
|
||||
project = await verify_project_access(gen_request.project_id, user_id, db)
|
||||
|
||||
try:
|
||||
# 获取已存在的角色和组织列表
|
||||
existing_chars_result = await db.execute(
|
||||
select(Character)
|
||||
.where(Character.project_id == request.project_id)
|
||||
.where(Character.project_id == gen_request.project_id)
|
||||
.order_by(Character.created_at.desc())
|
||||
)
|
||||
existing_characters = existing_chars_result.scalars().all()
|
||||
@@ -422,10 +489,10 @@ async def generate_organization(
|
||||
# 构建用户输入信息
|
||||
user_input = f"""
|
||||
用户要求:
|
||||
- 组织名称:{request.name or '请AI生成'}
|
||||
- 组织类型:{request.organization_type or '请AI根据世界观决定'}
|
||||
- 背景设定:{request.background or '无特殊要求'}
|
||||
- 其他要求:{request.requirements or '无'}
|
||||
- 组织名称:{gen_request.name or '请AI生成'}
|
||||
- 组织类型:{gen_request.organization_type or '请AI根据世界观决定'}
|
||||
- 背景设定:{gen_request.background or '无特殊要求'}
|
||||
- 其他要求:{gen_request.requirements or '无'}
|
||||
"""
|
||||
|
||||
# 使用统一的提示词服务
|
||||
@@ -435,10 +502,10 @@ async def generate_organization(
|
||||
)
|
||||
|
||||
# 调用AI生成组织
|
||||
logger.info(f"🎯 开始为项目 {request.project_id} 生成组织")
|
||||
logger.info(f" - 组织名:{request.name or 'AI生成'}")
|
||||
logger.info(f" - 组织类型:{request.organization_type or 'AI决定'}")
|
||||
logger.info(f" - 背景设定:{request.background or '无'}")
|
||||
logger.info(f"🎯 开始为项目 {gen_request.project_id} 生成组织")
|
||||
logger.info(f" - 组织名:{gen_request.name or 'AI生成'}")
|
||||
logger.info(f" - 组织类型:{gen_request.organization_type or 'AI决定'}")
|
||||
logger.info(f" - 背景设定:{gen_request.background or '无'}")
|
||||
logger.info(f" - AI提供商:{user_ai_service.api_provider}")
|
||||
logger.info(f" - AI模型:{user_ai_service.default_model}")
|
||||
logger.info(f" - Prompt长度:{len(prompt)} 字符")
|
||||
@@ -492,8 +559,8 @@ async def generate_organization(
|
||||
|
||||
# 创建角色记录(组织也是角色的一种)
|
||||
character = Character(
|
||||
project_id=request.project_id,
|
||||
name=organization_data.get("name", request.name or "未命名组织"),
|
||||
project_id=gen_request.project_id,
|
||||
name=organization_data.get("name", gen_request.name or "未命名组织"),
|
||||
is_organization=True,
|
||||
role_type="supporting", # 组织通常作为配角
|
||||
personality=organization_data.get("personality", ""),
|
||||
@@ -518,7 +585,7 @@ async def generate_organization(
|
||||
# 自动创建Organization详情记录
|
||||
organization = Organization(
|
||||
character_id=character.id,
|
||||
project_id=request.project_id,
|
||||
project_id=gen_request.project_id,
|
||||
member_count=0,
|
||||
power_level=organization_data.get("power_level", 50),
|
||||
location=organization_data.get("location"),
|
||||
@@ -532,7 +599,7 @@ async def generate_organization(
|
||||
|
||||
# 记录生成历史
|
||||
history = GenerationHistory(
|
||||
project_id=request.project_id,
|
||||
project_id=gen_request.project_id,
|
||||
prompt=prompt,
|
||||
generated_content=ai_content,
|
||||
model=user_ai_service.default_model
|
||||
@@ -542,7 +609,7 @@ async def generate_organization(
|
||||
await db.commit()
|
||||
await db.refresh(character)
|
||||
|
||||
logger.info(f"🎉 成功为项目 {request.project_id} 生成组织: {character.name}")
|
||||
logger.info(f"🎉 成功为项目 {gen_request.project_id} 生成组织: {character.name}")
|
||||
|
||||
return character
|
||||
|
||||
|
||||
+81
-23
@@ -30,19 +30,49 @@ router = APIRouter(prefix="/outlines", tags=["大纲管理"])
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
async def verify_project_access(project_id: str, user_id: str, db: AsyncSession) -> Project:
|
||||
"""
|
||||
验证用户是否有权访问指定项目
|
||||
|
||||
Args:
|
||||
project_id: 项目ID
|
||||
user_id: 用户ID
|
||||
db: 数据库会话
|
||||
|
||||
Returns:
|
||||
Project: 项目对象
|
||||
|
||||
Raises:
|
||||
HTTPException: 401 未登录,404 项目不存在或无权访问
|
||||
"""
|
||||
if not user_id:
|
||||
raise HTTPException(status_code=401, detail="未登录")
|
||||
|
||||
result = await db.execute(
|
||||
select(Project).where(
|
||||
Project.id == project_id,
|
||||
Project.user_id == user_id
|
||||
)
|
||||
)
|
||||
project = result.scalar_one_or_none()
|
||||
|
||||
if not project:
|
||||
logger.warning(f"项目访问被拒绝: project_id={project_id}, user_id={user_id}")
|
||||
raise HTTPException(status_code=404, detail="项目不存在或无权访问")
|
||||
|
||||
return project
|
||||
|
||||
|
||||
@router.post("", response_model=OutlineResponse, summary="创建大纲")
|
||||
async def create_outline(
|
||||
outline: OutlineCreate,
|
||||
request: Request,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""创建新的章节大纲,同时创建对应的章节记录"""
|
||||
# 验证项目是否存在
|
||||
result = await db.execute(
|
||||
select(Project).where(Project.id == outline.project_id)
|
||||
)
|
||||
project = result.scalar_one_or_none()
|
||||
if not project:
|
||||
raise HTTPException(status_code=404, detail="项目不存在")
|
||||
# 验证用户权限
|
||||
user_id = getattr(request.state, 'user_id', None)
|
||||
await verify_project_access(outline.project_id, user_id, db)
|
||||
|
||||
# 创建大纲
|
||||
db_outline = Outline(**outline.model_dump())
|
||||
@@ -66,9 +96,14 @@ async def create_outline(
|
||||
@router.get("", response_model=OutlineListResponse, summary="获取大纲列表")
|
||||
async def get_outlines(
|
||||
project_id: str,
|
||||
request: Request,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""获取指定项目的所有大纲"""
|
||||
# 验证用户权限
|
||||
user_id = getattr(request.state, 'user_id', None)
|
||||
await verify_project_access(project_id, user_id, db)
|
||||
|
||||
# 获取总数
|
||||
count_result = await db.execute(
|
||||
select(func.count(Outline.id)).where(Outline.project_id == project_id)
|
||||
@@ -89,9 +124,14 @@ async def get_outlines(
|
||||
@router.get("/project/{project_id}", response_model=OutlineListResponse, summary="获取项目的所有大纲")
|
||||
async def get_project_outlines(
|
||||
project_id: str,
|
||||
request: Request,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""获取指定项目的所有大纲(路径参数版本)"""
|
||||
# 验证用户权限
|
||||
user_id = getattr(request.state, 'user_id', None)
|
||||
await verify_project_access(project_id, user_id, db)
|
||||
|
||||
# 获取总数
|
||||
count_result = await db.execute(
|
||||
select(func.count(Outline.id)).where(Outline.project_id == project_id)
|
||||
@@ -112,6 +152,7 @@ async def get_project_outlines(
|
||||
@router.get("/{outline_id}", response_model=OutlineResponse, summary="获取大纲详情")
|
||||
async def get_outline(
|
||||
outline_id: str,
|
||||
request: Request,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""根据ID获取大纲详情"""
|
||||
@@ -123,6 +164,10 @@ async def get_outline(
|
||||
if not outline:
|
||||
raise HTTPException(status_code=404, detail="大纲不存在")
|
||||
|
||||
# 验证用户权限
|
||||
user_id = getattr(request.state, 'user_id', None)
|
||||
await verify_project_access(outline.project_id, user_id, db)
|
||||
|
||||
return outline
|
||||
|
||||
|
||||
@@ -130,6 +175,7 @@ async def get_outline(
|
||||
async def update_outline(
|
||||
outline_id: str,
|
||||
outline_update: OutlineUpdate,
|
||||
request: Request,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""更新大纲信息,同步更新对应章节和structure字段"""
|
||||
@@ -141,6 +187,10 @@ async def update_outline(
|
||||
if not outline:
|
||||
raise HTTPException(status_code=404, detail="大纲不存在")
|
||||
|
||||
# 验证用户权限
|
||||
user_id = getattr(request.state, 'user_id', None)
|
||||
await verify_project_access(outline.project_id, user_id, db)
|
||||
|
||||
# 更新字段
|
||||
update_data = outline_update.model_dump(exclude_unset=True)
|
||||
for field, value in update_data.items():
|
||||
@@ -196,6 +246,7 @@ async def update_outline(
|
||||
@router.delete("/{outline_id}", summary="删除大纲")
|
||||
async def delete_outline(
|
||||
outline_id: str,
|
||||
request: Request,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""删除大纲,同步删除章节,并重新排序后续项"""
|
||||
@@ -207,6 +258,10 @@ async def delete_outline(
|
||||
if not outline:
|
||||
raise HTTPException(status_code=404, detail="大纲不存在")
|
||||
|
||||
# 验证用户权限
|
||||
user_id = getattr(request.state, 'user_id', None)
|
||||
await verify_project_access(outline.project_id, user_id, db)
|
||||
|
||||
project_id = outline.project_id
|
||||
deleted_order = outline.order_index
|
||||
|
||||
@@ -252,7 +307,8 @@ async def delete_outline(
|
||||
|
||||
@router.post("/reorder", summary="批量重排序大纲")
|
||||
async def reorder_outlines(
|
||||
request: OutlineReorderRequest,
|
||||
reorder_request: OutlineReorderRequest,
|
||||
http_request: Request,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
@@ -261,10 +317,20 @@ async def reorder_outlines(
|
||||
策略:先收集所有变更,最后一次性提交,避免临时冲突
|
||||
"""
|
||||
try:
|
||||
# 验证用户权限(通过第一个大纲的project_id)
|
||||
user_id = getattr(http_request.state, 'user_id', None)
|
||||
if reorder_request.orders and len(reorder_request.orders) > 0:
|
||||
first_outline_result = await db.execute(
|
||||
select(Outline).where(Outline.id == reorder_request.orders[0].id)
|
||||
)
|
||||
first_outline = first_outline_result.scalar_one_or_none()
|
||||
if first_outline:
|
||||
await verify_project_access(first_outline.project_id, user_id, db)
|
||||
|
||||
# 第一步:收集所有大纲和对应的章节
|
||||
outline_chapter_map = {} # {outline_id: (outline, chapter, old_order, new_order)}
|
||||
|
||||
for item in request.orders:
|
||||
for item in reorder_request.orders:
|
||||
outline_id = item.id
|
||||
new_order = item.order_index
|
||||
|
||||
@@ -341,13 +407,9 @@ async def generate_outline(
|
||||
- new: 强制全新生成
|
||||
- continue: 强制续写模式
|
||||
"""
|
||||
# 验证项目是否存在
|
||||
result = await db.execute(
|
||||
select(Project).where(Project.id == request.project_id)
|
||||
)
|
||||
project = result.scalar_one_or_none()
|
||||
if not project:
|
||||
raise HTTPException(status_code=404, detail="项目不存在")
|
||||
# 验证用户权限
|
||||
user_id = getattr(http_request.state, 'user_id', None)
|
||||
project = await verify_project_access(request.project_id, user_id, db)
|
||||
|
||||
try:
|
||||
# 获取现有大纲(强制从数据库获取最新数据,包括用户手动修改的内容)
|
||||
@@ -1472,13 +1534,9 @@ async def generate_outline_stream(
|
||||
"model": "gpt-4" // 可选
|
||||
}
|
||||
"""
|
||||
# 验证项目是否存在
|
||||
result = await db.execute(
|
||||
select(Project).where(Project.id == data.get("project_id"))
|
||||
)
|
||||
project = result.scalar_one_or_none()
|
||||
if not project:
|
||||
raise HTTPException(status_code=404, detail="项目不存在")
|
||||
# 验证用户权限
|
||||
user_id = getattr(request.state, 'user_id', None)
|
||||
project = await verify_project_access(data.get("project_id"), user_id, db)
|
||||
|
||||
# 判断模式
|
||||
mode = data.get("mode", "auto")
|
||||
|
||||
+155
-41
@@ -41,17 +41,31 @@ router = APIRouter(prefix="/projects", tags=["项目管理"])
|
||||
@router.post("", response_model=ProjectResponse, summary="创建项目")
|
||||
async def create_project(
|
||||
project: ProjectCreate,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_db),
|
||||
request: Request = None
|
||||
):
|
||||
try:
|
||||
logger.info(f"创建新项目: {project.title}")
|
||||
db_project = Project(**project.model_dump())
|
||||
# 从认证中间件获取用户ID
|
||||
user_id = getattr(request.state, 'user_id', None)
|
||||
if not user_id:
|
||||
logger.warning("未登录用户尝试创建项目")
|
||||
raise HTTPException(status_code=401, detail="未登录")
|
||||
|
||||
logger.info(f"创建新项目: {project.title}, user_id={user_id}")
|
||||
|
||||
# 创建项目时自动设置user_id
|
||||
project_data = project.model_dump()
|
||||
project_data['user_id'] = user_id
|
||||
db_project = Project(**project_data)
|
||||
|
||||
db.add(db_project)
|
||||
await db.commit()
|
||||
await db.refresh(db_project)
|
||||
logger.info(f"项目创建成功: {db_project.id}")
|
||||
logger.info(f"项目创建成功: project_id={db_project.id}, user_id={user_id}")
|
||||
|
||||
return db_project
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"创建项目失败: {str(e)}", exc_info=True)
|
||||
raise
|
||||
@@ -61,24 +75,38 @@ async def create_project(
|
||||
async def get_projects(
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_db),
|
||||
request: Request = None
|
||||
):
|
||||
"""获取所有项目列表"""
|
||||
"""获取当前用户的项目列表"""
|
||||
try:
|
||||
logger.debug(f"获取项目列表: skip={skip}, limit={limit}")
|
||||
count_result = await db.execute(select(func.count(Project.id)))
|
||||
# 从认证中间件获取用户ID
|
||||
user_id = getattr(request.state, 'user_id', None)
|
||||
if not user_id:
|
||||
logger.warning("未登录用户尝试获取项目列表")
|
||||
raise HTTPException(status_code=401, detail="未登录")
|
||||
|
||||
logger.debug(f"获取项目列表: user_id={user_id}, skip={skip}, limit={limit}")
|
||||
|
||||
# 只查询当前用户的项目
|
||||
count_result = await db.execute(
|
||||
select(func.count(Project.id)).where(Project.user_id == user_id)
|
||||
)
|
||||
total = count_result.scalar_one()
|
||||
|
||||
result = await db.execute(
|
||||
select(Project)
|
||||
.where(Project.user_id == user_id)
|
||||
.order_by(Project.updated_at.desc())
|
||||
.offset(skip)
|
||||
.limit(limit)
|
||||
)
|
||||
projects = result.scalars().all()
|
||||
logger.info(f"获取项目列表成功: 共{total}个项目")
|
||||
logger.info(f"获取项目列表成功: user_id={user_id}, 共{total}个项目")
|
||||
|
||||
return ProjectListResponse(total=total, items=projects)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"获取项目列表失败: {str(e)}", exc_info=True)
|
||||
raise
|
||||
@@ -87,17 +115,29 @@ async def get_projects(
|
||||
@router.get("/{project_id}", response_model=ProjectResponse, summary="获取项目详情")
|
||||
async def get_project(
|
||||
project_id: str,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_db),
|
||||
request: Request = None
|
||||
):
|
||||
try:
|
||||
logger.debug(f"获取项目详情: {project_id}")
|
||||
# 从认证中间件获取用户ID
|
||||
user_id = getattr(request.state, 'user_id', None)
|
||||
if not user_id:
|
||||
logger.warning("未登录用户尝试获取项目详情")
|
||||
raise HTTPException(status_code=401, detail="未登录")
|
||||
|
||||
logger.debug(f"获取项目详情: project_id={project_id}, user_id={user_id}")
|
||||
|
||||
# 只查询当前用户的项目
|
||||
result = await db.execute(
|
||||
select(Project).where(Project.id == project_id)
|
||||
select(Project).where(
|
||||
Project.id == project_id,
|
||||
Project.user_id == user_id
|
||||
)
|
||||
)
|
||||
project = result.scalar_one_or_none()
|
||||
|
||||
if not project:
|
||||
logger.warning(f"项目不存在: {project_id}")
|
||||
logger.warning(f"项目不存在或无权访问: project_id={project_id}, user_id={user_id}")
|
||||
raise HTTPException(status_code=404, detail="项目不存在")
|
||||
|
||||
logger.info(f"获取项目详情成功: {project.title}")
|
||||
@@ -113,17 +153,29 @@ async def get_project(
|
||||
async def update_project(
|
||||
project_id: str,
|
||||
project_update: ProjectUpdate,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_db),
|
||||
request: Request = None
|
||||
):
|
||||
try:
|
||||
logger.info(f"更新项目: {project_id}")
|
||||
# 从认证中间件获取用户ID
|
||||
user_id = getattr(request.state, 'user_id', None)
|
||||
if not user_id:
|
||||
logger.warning("未登录用户尝试更新项目")
|
||||
raise HTTPException(status_code=401, detail="未登录")
|
||||
|
||||
logger.info(f"更新项目: project_id={project_id}, user_id={user_id}")
|
||||
|
||||
# 只查询当前用户的项目
|
||||
result = await db.execute(
|
||||
select(Project).where(Project.id == project_id)
|
||||
select(Project).where(
|
||||
Project.id == project_id,
|
||||
Project.user_id == user_id
|
||||
)
|
||||
)
|
||||
project = result.scalar_one_or_none()
|
||||
|
||||
if not project:
|
||||
logger.warning(f"项目不存在: {project_id}")
|
||||
logger.warning(f"项目不存在或无权访问: project_id={project_id}, user_id={user_id}")
|
||||
raise HTTPException(status_code=404, detail="项目不存在")
|
||||
|
||||
update_data = project_update.model_dump(exclude_unset=True)
|
||||
@@ -149,22 +201,30 @@ async def delete_project(
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
try:
|
||||
logger.info(f"删除项目: {project_id}")
|
||||
# 从认证中间件获取用户ID
|
||||
user_id = getattr(request.state, 'user_id', None)
|
||||
if not user_id:
|
||||
logger.warning("未登录用户尝试删除项目")
|
||||
raise HTTPException(status_code=401, detail="未登录")
|
||||
|
||||
logger.info(f"删除项目: project_id={project_id}, user_id={user_id}")
|
||||
|
||||
# 只查询当前用户的项目
|
||||
result = await db.execute(
|
||||
select(Project).where(Project.id == project_id)
|
||||
select(Project).where(
|
||||
Project.id == project_id,
|
||||
Project.user_id == user_id
|
||||
)
|
||||
)
|
||||
project = result.scalar_one_or_none()
|
||||
|
||||
if not project:
|
||||
logger.warning(f"项目不存在: {project_id}")
|
||||
logger.warning(f"项目不存在或无权访问: project_id={project_id}, user_id={user_id}")
|
||||
raise HTTPException(status_code=404, detail="项目不存在")
|
||||
|
||||
project_title = project.title
|
||||
|
||||
# 从认证中间件获取用户ID
|
||||
user_id = getattr(request.state, 'user_id', None)
|
||||
|
||||
# 删除向量数据库中的记忆
|
||||
# 删除向量数据库中的记忆(user_id已在上面获取)
|
||||
if user_id:
|
||||
try:
|
||||
await memory_service.delete_project_memories(user_id, project_id)
|
||||
@@ -234,22 +294,33 @@ async def delete_project(
|
||||
@router.get("/{project_id}/export", summary="导出项目章节为TXT")
|
||||
async def export_project_chapters(
|
||||
project_id: str,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_db),
|
||||
request: Request = None
|
||||
):
|
||||
"""
|
||||
导出项目的所有章节内容为TXT文本文件
|
||||
按章节顺序组织,包含项目基本信息
|
||||
"""
|
||||
try:
|
||||
logger.info(f"开始导出项目: {project_id}")
|
||||
# 从认证中间件获取用户ID
|
||||
user_id = getattr(request.state, 'user_id', None)
|
||||
if not user_id:
|
||||
logger.warning("未登录用户尝试导出项目")
|
||||
raise HTTPException(status_code=401, detail="未登录")
|
||||
|
||||
logger.info(f"开始导出项目: project_id={project_id}, user_id={user_id}")
|
||||
|
||||
# 只查询当前用户的项目
|
||||
result = await db.execute(
|
||||
select(Project).where(Project.id == project_id)
|
||||
select(Project).where(
|
||||
Project.id == project_id,
|
||||
Project.user_id == user_id
|
||||
)
|
||||
)
|
||||
project = result.scalar_one_or_none()
|
||||
|
||||
if not project:
|
||||
logger.warning(f"项目不存在: {project_id}")
|
||||
logger.warning(f"项目不存在或无权访问: project_id={project_id}, user_id={user_id}")
|
||||
raise HTTPException(status_code=404, detail="项目不存在")
|
||||
|
||||
chapters_result = await db.execute(
|
||||
@@ -326,6 +397,7 @@ async def export_project_chapters(
|
||||
@router.post("/{project_id}/check-consistency", summary="检查数据一致性")
|
||||
async def check_project_consistency(
|
||||
project_id: str,
|
||||
request: Request,
|
||||
auto_fix: bool = True,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
@@ -343,15 +415,25 @@ async def check_project_consistency(
|
||||
- organization_members: 验证组织成员数据完整性
|
||||
"""
|
||||
try:
|
||||
logger.info(f"开始数据一致性检查: {project_id}, auto_fix={auto_fix}")
|
||||
# 从认证中间件获取用户ID
|
||||
user_id = getattr(request.state, 'user_id', None)
|
||||
if not user_id:
|
||||
logger.warning("未登录用户尝试检查数据一致性")
|
||||
raise HTTPException(status_code=401, detail="未登录")
|
||||
|
||||
logger.info(f"开始数据一致性检查: project_id={project_id}, user_id={user_id}, auto_fix={auto_fix}")
|
||||
|
||||
# 只查询当前用户的项目
|
||||
result = await db.execute(
|
||||
select(Project).where(Project.id == project_id)
|
||||
select(Project).where(
|
||||
Project.id == project_id,
|
||||
Project.user_id == user_id
|
||||
)
|
||||
)
|
||||
project = result.scalar_one_or_none()
|
||||
|
||||
if not project:
|
||||
logger.warning(f"项目不存在: {project_id}")
|
||||
logger.warning(f"项目不存在或无权访问: project_id={project_id}, user_id={user_id}")
|
||||
raise HTTPException(status_code=404, detail="项目不存在")
|
||||
|
||||
report = await run_full_data_consistency_check(project_id, db, auto_fix)
|
||||
@@ -369,6 +451,7 @@ async def check_project_consistency(
|
||||
@router.post("/{project_id}/fix-organizations", summary="修复组织记录")
|
||||
async def fix_project_organizations(
|
||||
project_id: str,
|
||||
request: Request,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
@@ -377,15 +460,25 @@ async def fix_project_organizations(
|
||||
为所有is_organization=True但没有Organization记录的Character创建记录
|
||||
"""
|
||||
try:
|
||||
logger.info(f"开始修复组织记录: {project_id}")
|
||||
# 从认证中间件获取用户ID
|
||||
user_id = getattr(request.state, 'user_id', None)
|
||||
if not user_id:
|
||||
logger.warning("未登录用户尝试修复组织记录")
|
||||
raise HTTPException(status_code=401, detail="未登录")
|
||||
|
||||
logger.info(f"开始修复组织记录: project_id={project_id}, user_id={user_id}")
|
||||
|
||||
# 只查询当前用户的项目
|
||||
result = await db.execute(
|
||||
select(Project).where(Project.id == project_id)
|
||||
select(Project).where(
|
||||
Project.id == project_id,
|
||||
Project.user_id == user_id
|
||||
)
|
||||
)
|
||||
project = result.scalar_one_or_none()
|
||||
|
||||
if not project:
|
||||
logger.warning(f"项目不存在: {project_id}")
|
||||
logger.warning(f"项目不存在或无权访问: project_id={project_id}, user_id={user_id}")
|
||||
raise HTTPException(status_code=404, detail="项目不存在")
|
||||
|
||||
fixed_count, total_count = await fix_missing_organization_records(project_id, db)
|
||||
@@ -407,6 +500,7 @@ async def fix_project_organizations(
|
||||
@router.post("/{project_id}/fix-member-counts", summary="修复成员计数")
|
||||
async def fix_project_member_counts(
|
||||
project_id: str,
|
||||
request: Request,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
@@ -415,15 +509,25 @@ async def fix_project_member_counts(
|
||||
从实际成员记录重新计算每个组织的member_count
|
||||
"""
|
||||
try:
|
||||
logger.info(f"开始修复成员计数: {project_id}")
|
||||
# 从认证中间件获取用户ID
|
||||
user_id = getattr(request.state, 'user_id', None)
|
||||
if not user_id:
|
||||
logger.warning("未登录用户尝试修复成员计数")
|
||||
raise HTTPException(status_code=401, detail="未登录")
|
||||
|
||||
logger.info(f"开始修复成员计数: project_id={project_id}, user_id={user_id}")
|
||||
|
||||
# 只查询当前用户的项目
|
||||
result = await db.execute(
|
||||
select(Project).where(Project.id == project_id)
|
||||
select(Project).where(
|
||||
Project.id == project_id,
|
||||
Project.user_id == user_id
|
||||
)
|
||||
)
|
||||
project = result.scalar_one_or_none()
|
||||
|
||||
if not project:
|
||||
logger.warning(f"项目不存在: {project_id}")
|
||||
logger.warning(f"项目不存在或无权访问: project_id={project_id}, user_id={user_id}")
|
||||
raise HTTPException(status_code=404, detail="项目不存在")
|
||||
|
||||
fixed_count, total_count = await fix_organization_member_counts(project_id, db)
|
||||
@@ -445,6 +549,7 @@ async def fix_project_member_counts(
|
||||
@router.post("/{project_id}/export-data", summary="导出项目数据为JSON")
|
||||
async def export_project_data(
|
||||
project_id: str,
|
||||
request: Request,
|
||||
options: ExportOptions,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
@@ -459,16 +564,25 @@ async def export_project_data(
|
||||
JSON文件下载
|
||||
"""
|
||||
try:
|
||||
logger.info(f"开始导出项目数据: {project_id}")
|
||||
# 从认证中间件获取用户ID
|
||||
user_id = getattr(request.state, 'user_id', None)
|
||||
if not user_id:
|
||||
logger.warning("未登录用户尝试导出项目数据")
|
||||
raise HTTPException(status_code=401, detail="未登录")
|
||||
|
||||
# 检查项目是否存在
|
||||
logger.info(f"开始导出项目数据: project_id={project_id}, user_id={user_id}")
|
||||
|
||||
# 只查询当前用户的项目
|
||||
result = await db.execute(
|
||||
select(Project).where(Project.id == project_id)
|
||||
select(Project).where(
|
||||
Project.id == project_id,
|
||||
Project.user_id == user_id
|
||||
)
|
||||
)
|
||||
project = result.scalar_one_or_none()
|
||||
|
||||
if not project:
|
||||
logger.warning(f"项目不存在: {project_id}")
|
||||
logger.warning(f"项目不存在或无权访问: project_id={project_id}, user_id={user_id}")
|
||||
raise HTTPException(status_code=404, detail="项目不存在")
|
||||
|
||||
# 导出数据
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
"""关系管理API"""
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, Request
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select, or_, and_
|
||||
from typing import List, Optional
|
||||
@@ -12,6 +12,7 @@ from app.models.relationship import (
|
||||
OrganizationMember
|
||||
)
|
||||
from app.models.character import Character
|
||||
from app.models.project import Project
|
||||
from app.schemas.relationship import (
|
||||
RelationshipTypeResponse,
|
||||
CharacterRelationshipCreate,
|
||||
@@ -27,6 +28,26 @@ router = APIRouter(prefix="/relationships", tags=["关系管理"])
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
async def verify_project_access(project_id: str, user_id: str, db: AsyncSession) -> Project:
|
||||
"""验证用户是否有权访问指定项目"""
|
||||
if not user_id:
|
||||
raise HTTPException(status_code=401, detail="未登录")
|
||||
|
||||
result = await db.execute(
|
||||
select(Project).where(
|
||||
Project.id == project_id,
|
||||
Project.user_id == user_id
|
||||
)
|
||||
)
|
||||
project = result.scalar_one_or_none()
|
||||
|
||||
if not project:
|
||||
logger.warning(f"项目访问被拒绝: project_id={project_id}, user_id={user_id}")
|
||||
raise HTTPException(status_code=404, detail="项目不存在或无权访问")
|
||||
|
||||
return project
|
||||
|
||||
|
||||
@router.get("/types", response_model=List[RelationshipTypeResponse], summary="获取关系类型列表")
|
||||
async def get_relationship_types(db: AsyncSession = Depends(get_db)):
|
||||
"""获取所有预定义的关系类型"""
|
||||
@@ -38,9 +59,14 @@ async def get_relationship_types(db: AsyncSession = Depends(get_db)):
|
||||
@router.get("/project/{project_id}", response_model=List[CharacterRelationshipResponse], summary="获取项目的所有关系")
|
||||
async def get_project_relationships(
|
||||
project_id: str,
|
||||
request: Request,
|
||||
character_id: Optional[str] = Query(None, description="筛选特定角色的关系"),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
# 验证用户权限
|
||||
user_id = getattr(request.state, 'user_id', None)
|
||||
await verify_project_access(project_id, user_id, db)
|
||||
|
||||
"""
|
||||
获取项目中的所有角色关系
|
||||
|
||||
@@ -70,8 +96,13 @@ async def get_project_relationships(
|
||||
@router.get("/graph/{project_id}", response_model=RelationshipGraphData, summary="获取关系图谱数据")
|
||||
async def get_relationship_graph(
|
||||
project_id: str,
|
||||
request: Request,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
# 验证用户权限
|
||||
user_id = getattr(request.state, 'user_id', None)
|
||||
await verify_project_access(project_id, user_id, db)
|
||||
|
||||
"""
|
||||
获取用于可视化的关系图谱数据
|
||||
|
||||
@@ -122,6 +153,7 @@ async def get_relationship_graph(
|
||||
@router.post("/", response_model=CharacterRelationshipResponse, summary="创建角色关系")
|
||||
async def create_relationship(
|
||||
relationship: CharacterRelationshipCreate,
|
||||
request: Request,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
@@ -131,6 +163,10 @@ async def create_relationship(
|
||||
- 可以指定预定义的关系类型或自定义关系名称
|
||||
- 可以设置亲密度、状态等属性
|
||||
"""
|
||||
# 验证用户权限
|
||||
user_id = getattr(request.state, 'user_id', None)
|
||||
await verify_project_access(relationship.project_id, user_id, db)
|
||||
|
||||
# 验证角色是否存在
|
||||
char_from = await db.execute(
|
||||
select(Character).where(Character.id == relationship.character_from_id)
|
||||
@@ -161,6 +197,7 @@ async def create_relationship(
|
||||
async def update_relationship(
|
||||
relationship_id: str,
|
||||
relationship: CharacterRelationshipUpdate,
|
||||
request: Request,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""更新角色关系的属性(亲密度、状态等)"""
|
||||
@@ -174,6 +211,10 @@ async def update_relationship(
|
||||
if not db_rel:
|
||||
raise HTTPException(status_code=404, detail="关系不存在")
|
||||
|
||||
# 验证用户权限
|
||||
user_id = getattr(request.state, 'user_id', None)
|
||||
await verify_project_access(db_rel.project_id, user_id, db)
|
||||
|
||||
# 更新字段
|
||||
update_data = relationship.model_dump(exclude_unset=True)
|
||||
for field, value in update_data.items():
|
||||
@@ -189,6 +230,7 @@ async def update_relationship(
|
||||
@router.delete("/{relationship_id}", summary="删除关系")
|
||||
async def delete_relationship(
|
||||
relationship_id: str,
|
||||
request: Request,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""删除角色关系"""
|
||||
@@ -202,6 +244,10 @@ async def delete_relationship(
|
||||
if not db_rel:
|
||||
raise HTTPException(status_code=404, detail="关系不存在")
|
||||
|
||||
# 验证用户权限
|
||||
user_id = getattr(request.state, 'user_id', None)
|
||||
await verify_project_access(db_rel.project_id, user_id, db)
|
||||
|
||||
await db.delete(db_rel)
|
||||
await db.commit()
|
||||
|
||||
|
||||
@@ -183,7 +183,13 @@ async def world_building_generator(
|
||||
# 保存到数据库
|
||||
yield await SSEResponse.send_progress("保存到数据库...", 90)
|
||||
|
||||
# 确保user_id存在
|
||||
if not user_id:
|
||||
yield await SSEResponse.send_error("用户ID缺失,无法创建项目", 401)
|
||||
return
|
||||
|
||||
project = Project(
|
||||
user_id=user_id, # 添加user_id字段
|
||||
title=title,
|
||||
description=description,
|
||||
theme=theme,
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
"""写作风格管理 API"""
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select, func, delete
|
||||
from typing import List
|
||||
@@ -16,8 +16,30 @@ from ..schemas.writing_style import (
|
||||
SetDefaultStyleRequest
|
||||
)
|
||||
from ..services.prompt_service import WritingStyleManager
|
||||
from ..logger import get_logger
|
||||
|
||||
router = APIRouter(prefix="/writing-styles", tags=["writing-styles"])
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
async def verify_project_access(project_id: str, user_id: str, db: AsyncSession) -> Project:
|
||||
"""验证用户是否有权访问指定项目"""
|
||||
if not user_id:
|
||||
raise HTTPException(status_code=401, detail="未登录")
|
||||
|
||||
result = await db.execute(
|
||||
select(Project).where(
|
||||
Project.id == project_id,
|
||||
Project.user_id == user_id
|
||||
)
|
||||
)
|
||||
project = result.scalar_one_or_none()
|
||||
|
||||
if not project:
|
||||
logger.warning(f"项目访问被拒绝: project_id={project_id}, user_id={user_id}")
|
||||
raise HTTPException(status_code=404, detail="项目不存在或无权访问")
|
||||
|
||||
return project
|
||||
|
||||
|
||||
@router.get("/presets/list", response_model=List[dict])
|
||||
@@ -42,6 +64,7 @@ async def get_preset_styles():
|
||||
@router.post("", response_model=WritingStyleResponse, status_code=201)
|
||||
async def create_writing_style(
|
||||
style_data: WritingStyleCreate,
|
||||
request: Request,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
@@ -50,13 +73,9 @@ async def create_writing_style(
|
||||
- **基于预设创建**:提供 preset_id,系统会自动填充预设内容
|
||||
- **完全自定义**:不提供 preset_id,需要手动填写所有字段
|
||||
"""
|
||||
# 验证项目是否存在
|
||||
result = await db.execute(
|
||||
select(Project).where(Project.id == style_data.project_id)
|
||||
)
|
||||
project = result.scalar_one_or_none()
|
||||
if not project:
|
||||
raise HTTPException(status_code=404, detail="项目不存在")
|
||||
# 验证用户权限
|
||||
user_id = getattr(request.state, 'user_id', None)
|
||||
await verify_project_access(style_data.project_id, user_id, db)
|
||||
|
||||
# 如果基于预设创建,获取预设内容
|
||||
if style_data.preset_id:
|
||||
@@ -120,6 +139,7 @@ async def create_writing_style(
|
||||
@router.get("/project/{project_id}", response_model=WritingStyleListResponse)
|
||||
async def get_project_styles(
|
||||
project_id: str,
|
||||
request: Request,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
@@ -128,13 +148,9 @@ async def get_project_styles(
|
||||
返回:全局预设风格 + 该项目的自定义风格
|
||||
按 order_index 排序,并标记哪个是当前项目的默认风格
|
||||
"""
|
||||
# 验证项目是否存在
|
||||
result = await db.execute(
|
||||
select(Project).where(Project.id == project_id)
|
||||
)
|
||||
project = result.scalar_one_or_none()
|
||||
if not project:
|
||||
raise HTTPException(status_code=404, detail="项目不存在")
|
||||
# 验证用户权限
|
||||
user_id = getattr(request.state, 'user_id', None)
|
||||
await verify_project_access(project_id, user_id, db)
|
||||
|
||||
# 获取该项目的默认风格ID
|
||||
result = await db.execute(
|
||||
@@ -222,6 +238,7 @@ async def get_writing_style(
|
||||
async def update_writing_style(
|
||||
style_id: int,
|
||||
style_data: WritingStyleUpdate,
|
||||
request: Request,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
@@ -241,6 +258,10 @@ async def update_writing_style(
|
||||
if style.project_id is None:
|
||||
raise HTTPException(status_code=403, detail="不能修改全局预设风格,只能修改自定义风格")
|
||||
|
||||
# 验证用户权限
|
||||
user_id = getattr(request.state, 'user_id', None)
|
||||
await verify_project_access(style.project_id, user_id, db)
|
||||
|
||||
# 更新字段
|
||||
update_data = style_data.model_dump(exclude_unset=True)
|
||||
|
||||
@@ -279,6 +300,7 @@ async def update_writing_style(
|
||||
@router.delete("/{style_id}", status_code=204)
|
||||
async def delete_writing_style(
|
||||
style_id: int,
|
||||
request: Request,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
@@ -300,6 +322,10 @@ async def delete_writing_style(
|
||||
if style.project_id is None:
|
||||
raise HTTPException(status_code=403, detail="不能删除全局预设风格,只能删除自定义风格")
|
||||
|
||||
# 验证用户权限
|
||||
user_id = getattr(request.state, 'user_id', None)
|
||||
await verify_project_access(style.project_id, user_id, db)
|
||||
|
||||
# 检查是否有项目将其设置为默认风格
|
||||
result = await db.execute(
|
||||
select(ProjectDefaultStyle).where(ProjectDefaultStyle.style_id == style_id)
|
||||
@@ -321,6 +347,7 @@ async def delete_writing_style(
|
||||
async def set_default_style(
|
||||
style_id: int,
|
||||
request_data: SetDefaultStyleRequest,
|
||||
request: Request,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
@@ -335,13 +362,9 @@ async def set_default_style(
|
||||
"""
|
||||
project_id = request_data.project_id
|
||||
|
||||
# 验证项目是否存在
|
||||
result = await db.execute(
|
||||
select(Project).where(Project.id == project_id)
|
||||
)
|
||||
project = result.scalar_one_or_none()
|
||||
if not project:
|
||||
raise HTTPException(status_code=404, detail="项目不存在")
|
||||
# 验证用户权限
|
||||
user_id = getattr(request.state, 'user_id', None)
|
||||
await verify_project_access(project_id, user_id, db)
|
||||
|
||||
# 验证风格是否存在
|
||||
result = await db.execute(
|
||||
@@ -379,6 +402,7 @@ async def set_default_style(
|
||||
@router.post("/project/{project_id}/init-defaults", response_model=WritingStyleListResponse)
|
||||
async def initialize_default_styles(
|
||||
project_id: str,
|
||||
request: Request,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
@@ -387,13 +411,9 @@ async def initialize_default_styles(
|
||||
新架构下,预设风格是全局的,不需要为每个项目单独初始化
|
||||
该接口保留用于兼容性,直接返回项目可用的所有风格
|
||||
"""
|
||||
# 验证项目是否存在
|
||||
result = await db.execute(
|
||||
select(Project).where(Project.id == project_id)
|
||||
)
|
||||
project = result.scalar_one_or_none()
|
||||
if not project:
|
||||
raise HTTPException(status_code=404, detail="项目不存在")
|
||||
# 验证用户权限
|
||||
user_id = getattr(request.state, 'user_id', None)
|
||||
await verify_project_access(project_id, user_id, db)
|
||||
|
||||
# 直接返回项目可用的所有风格(全局预设 + 项目自定义)
|
||||
return await get_project_styles(project_id, db)
|
||||
return await get_project_styles(project_id, request, db)
|
||||
Reference in New Issue
Block a user