From 1b6cc68188766e6ec46fc20376a3ed1f6b0766a4 Mon Sep 17 00:00:00 2001 From: xiamuceer Date: Fri, 24 Apr 2026 16:18:23 +0800 Subject: [PATCH] =?UTF-8?q?fix:=20=E4=BF=AE=E5=A4=8D=E5=90=91=E5=AF=BC?= =?UTF-8?q?=E5=A4=A7=E7=BA=B2=E7=94=9F=E6=88=90=E6=8E=A5=E5=8F=A3=E6=A0=A1?= =?UTF-8?q?=E9=AA=8C=E5=80=BC=E9=97=AE=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- backend/app/api/wizard_stream.py | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/backend/app/api/wizard_stream.py b/backend/app/api/wizard_stream.py index 56cdbab..c1e305a 100644 --- a/backend/app/api/wizard_stream.py +++ b/backend/app/api/wizard_stream.py @@ -26,7 +26,7 @@ router = APIRouter(prefix="/wizard-stream", tags=["项目创建向导(流式)"]) logger = get_logger(__name__) -async def get_owned_project(db: AsyncSession, project_id: str, user_id: str) -> Project | None: +async def get_owned_project(db: AsyncSession, project_id: str, user_id: str | None) -> Project | None: if not project_id or not user_id: return None result = await db.execute( @@ -1528,6 +1528,7 @@ async def outline_generator( @router.post("/outline", summary="流式生成完整大纲") async def generate_outline_stream( + request: Request, data: Dict[str, Any], db: AsyncSession = Depends(get_db), user_ai_service: AIService = Depends(get_user_ai_service) @@ -1535,6 +1536,10 @@ async def generate_outline_stream( """ 使用SSE流式生成完整大纲,避免超时 """ + # 从中间件注入user_id到data中,供outline_generator进行项目归属校验 + if hasattr(request.state, 'user_id'): + data['user_id'] = request.state.user_id + return create_sse_response(outline_generator(data, db, user_ai_service)) @@ -1552,6 +1557,12 @@ async def world_building_regenerate_generator( try: yield await tracker.start("开始重新生成世界观...") + # 提取参数 + provider = data.get("provider") + model = data.get("model") + enable_mcp = data.get("enable_mcp", True) + user_id = data.get("user_id") + # 获取项目信息 yield await tracker.loading("加载项目信息...") project = await get_owned_project(db, project_id, user_id) @@ -1559,12 +1570,6 @@ async def world_building_regenerate_generator( yield await tracker.error("项目不存在或无权访问", 404) return - # 提取参数 - provider = data.get("provider") - model = data.get("model") - enable_mcp = data.get("enable_mcp", True) - user_id = data.get("user_id") - # 获取基础提示词(支持自定义) yield await tracker.preparing("准备AI提示词...") template = await PromptService.get_template("WORLD_BUILDING", user_id, db)