fix: 修复向导大纲生成接口校验值问题
This commit is contained in:
@@ -26,7 +26,7 @@ router = APIRouter(prefix="/wizard-stream", tags=["项目创建向导(流式)"])
|
|||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
async def get_owned_project(db: AsyncSession, project_id: str, user_id: str) -> Project | None:
|
async def get_owned_project(db: AsyncSession, project_id: str, user_id: str | None) -> Project | None:
|
||||||
if not project_id or not user_id:
|
if not project_id or not user_id:
|
||||||
return None
|
return None
|
||||||
result = await db.execute(
|
result = await db.execute(
|
||||||
@@ -1528,6 +1528,7 @@ async def outline_generator(
|
|||||||
|
|
||||||
@router.post("/outline", summary="流式生成完整大纲")
|
@router.post("/outline", summary="流式生成完整大纲")
|
||||||
async def generate_outline_stream(
|
async def generate_outline_stream(
|
||||||
|
request: Request,
|
||||||
data: Dict[str, Any],
|
data: Dict[str, Any],
|
||||||
db: AsyncSession = Depends(get_db),
|
db: AsyncSession = Depends(get_db),
|
||||||
user_ai_service: AIService = Depends(get_user_ai_service)
|
user_ai_service: AIService = Depends(get_user_ai_service)
|
||||||
@@ -1535,6 +1536,10 @@ async def generate_outline_stream(
|
|||||||
"""
|
"""
|
||||||
使用SSE流式生成完整大纲,避免超时
|
使用SSE流式生成完整大纲,避免超时
|
||||||
"""
|
"""
|
||||||
|
# 从中间件注入user_id到data中,供outline_generator进行项目归属校验
|
||||||
|
if hasattr(request.state, 'user_id'):
|
||||||
|
data['user_id'] = request.state.user_id
|
||||||
|
|
||||||
return create_sse_response(outline_generator(data, db, user_ai_service))
|
return create_sse_response(outline_generator(data, db, user_ai_service))
|
||||||
|
|
||||||
|
|
||||||
@@ -1552,6 +1557,12 @@ async def world_building_regenerate_generator(
|
|||||||
try:
|
try:
|
||||||
yield await tracker.start("开始重新生成世界观...")
|
yield await tracker.start("开始重新生成世界观...")
|
||||||
|
|
||||||
|
# 提取参数
|
||||||
|
provider = data.get("provider")
|
||||||
|
model = data.get("model")
|
||||||
|
enable_mcp = data.get("enable_mcp", True)
|
||||||
|
user_id = data.get("user_id")
|
||||||
|
|
||||||
# 获取项目信息
|
# 获取项目信息
|
||||||
yield await tracker.loading("加载项目信息...")
|
yield await tracker.loading("加载项目信息...")
|
||||||
project = await get_owned_project(db, project_id, user_id)
|
project = await get_owned_project(db, project_id, user_id)
|
||||||
@@ -1559,12 +1570,6 @@ async def world_building_regenerate_generator(
|
|||||||
yield await tracker.error("项目不存在或无权访问", 404)
|
yield await tracker.error("项目不存在或无权访问", 404)
|
||||||
return
|
return
|
||||||
|
|
||||||
# 提取参数
|
|
||||||
provider = data.get("provider")
|
|
||||||
model = data.get("model")
|
|
||||||
enable_mcp = data.get("enable_mcp", True)
|
|
||||||
user_id = data.get("user_id")
|
|
||||||
|
|
||||||
# 获取基础提示词(支持自定义)
|
# 获取基础提示词(支持自定义)
|
||||||
yield await tracker.preparing("准备AI提示词...")
|
yield await tracker.preparing("准备AI提示词...")
|
||||||
template = await PromptService.get_template("WORLD_BUILDING", user_id, db)
|
template = await PromptService.get_template("WORLD_BUILDING", user_id, db)
|
||||||
|
|||||||
Reference in New Issue
Block a user