diff --git a/backend/app/api/wizard_stream.py b/backend/app/api/wizard_stream.py index 56cdbab..c1e305a 100644 --- a/backend/app/api/wizard_stream.py +++ b/backend/app/api/wizard_stream.py @@ -26,7 +26,7 @@ router = APIRouter(prefix="/wizard-stream", tags=["项目创建向导(流式)"]) logger = get_logger(__name__) -async def get_owned_project(db: AsyncSession, project_id: str, user_id: str) -> Project | None: +async def get_owned_project(db: AsyncSession, project_id: str, user_id: str | None) -> Project | None: if not project_id or not user_id: return None result = await db.execute( @@ -1528,6 +1528,7 @@ async def outline_generator( @router.post("/outline", summary="流式生成完整大纲") async def generate_outline_stream( + request: Request, data: Dict[str, Any], db: AsyncSession = Depends(get_db), user_ai_service: AIService = Depends(get_user_ai_service) @@ -1535,6 +1536,10 @@ async def generate_outline_stream( """ 使用SSE流式生成完整大纲,避免超时 """ + # 从中间件注入user_id到data中,供outline_generator进行项目归属校验 + if hasattr(request.state, 'user_id'): + data['user_id'] = request.state.user_id + return create_sse_response(outline_generator(data, db, user_ai_service)) @@ -1552,6 +1557,12 @@ async def world_building_regenerate_generator( try: yield await tracker.start("开始重新生成世界观...") + # 提取参数 + provider = data.get("provider") + model = data.get("model") + enable_mcp = data.get("enable_mcp", True) + user_id = data.get("user_id") + # 获取项目信息 yield await tracker.loading("加载项目信息...") project = await get_owned_project(db, project_id, user_id) @@ -1559,12 +1570,6 @@ async def world_building_regenerate_generator( yield await tracker.error("项目不存在或无权访问", 404) return - # 提取参数 - provider = data.get("provider") - model = data.get("model") - enable_mcp = data.get("enable_mcp", True) - user_id = data.get("user_id") - # 获取基础提示词(支持自定义) yield await tracker.preparing("准备AI提示词...") template = await PromptService.get_template("WORLD_BUILDING", user_id, db)