From 096c2ebbd1416bfab653e0ac531456b1980d9cf6 Mon Sep 17 00:00:00 2001 From: xiamuceer Date: Tue, 25 Nov 2025 19:35:38 +0800 Subject: [PATCH] =?UTF-8?q?update:1.=E6=9B=B4=E6=96=B0=E9=A1=B9=E7=9B=AE?= =?UTF-8?q?=E7=89=88=E6=9C=AC?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- backend/app/api/wizard_stream.py | 201 ++++++++++++++++++++++++- backend/app/models/project.py | 2 +- backend/scripts/fix_user_id_length.sql | 9 ++ frontend/package.json | 2 +- 4 files changed, 211 insertions(+), 3 deletions(-) create mode 100644 backend/scripts/fix_user_id_length.sql diff --git a/backend/app/api/wizard_stream.py b/backend/app/api/wizard_stream.py index 037b466..c83946c 100644 --- a/backend/app/api/wizard_stream.py +++ b/backend/app/api/wizard_stream.py @@ -1062,7 +1062,6 @@ async def outline_generator( logger.info("大纲生成事务已回滚(异常)") yield await SSEResponse.send_error(f"生成失败: {str(e)}") - @router.post("/outline", summary="流式生成完整大纲") async def generate_outline_stream( data: Dict[str, Any], @@ -1074,3 +1073,203 @@ async def generate_outline_stream( """ return create_sse_response(outline_generator(data, db, user_ai_service)) + +async def world_building_regenerate_generator( + project_id: str, + data: Dict[str, Any], + db: AsyncSession, + user_ai_service: AIService +) -> AsyncGenerator[str, None]: + """世界观重新生成流式生成器""" + db_committed = False + try: + yield await SSEResponse.send_progress("开始重新生成世界观...", 10) + + # 获取项目信息 + result = await db.execute( + select(Project).where(Project.id == project_id) + ) + project = result.scalar_one_or_none() + if not project: + yield await SSEResponse.send_error("项目不存在", 404) + return + + # 提取参数 + provider = data.get("provider") + model = data.get("model") + enable_mcp = data.get("enable_mcp", True) + user_id = data.get("user_id") + + # 获取基础提示词 + yield await SSEResponse.send_progress("准备AI提示词...", 15) + base_prompt = prompt_service.get_world_building_prompt( + title=project.title, + theme=project.theme or "未设定", + genre=project.genre or "通用" + ) + + # MCP工具增强:收集参考资料 + reference_materials = "" + if enable_mcp and user_id: + try: + from app.services.mcp_tool_service import mcp_tool_service + available_tools = await mcp_tool_service.get_user_enabled_tools( + user_id=user_id, + db_session=db + ) + + if available_tools: + yield await SSEResponse.send_progress("🔍 尝试使用MCP工具收集参考资料...", 18) + + planning_prompt = f"""你正在为小说《{project.title}》重新设计世界观。 + +【小说信息】 +- 题材:{project.genre} +- 主题:{project.theme} +- 简介:{project.description or '未设定'} + +【任务】 +请使用可用工具搜索相关背景资料,帮助构建更真实、更有深度的世界观设定。 +你可以查询: +1. 历史背景(如果是历史题材) +2. 地理环境和文化特征 +3. 相关领域的专业知识 +4. 类似作品的设定参考 + +请查询最关键的1个问题(不要超过1个)。""" + + planning_result = await user_ai_service.generate_text_with_mcp( + prompt=planning_prompt, + user_id=user_id, + db_session=db, + enable_mcp=True, + max_tool_rounds=1, + tool_choice="auto", + provider=None, + model=None + ) + + if planning_result.get("tool_calls_made", 0) > 0: + yield await SSEResponse.send_progress( + f"✅ MCP工具调用成功({planning_result['tool_calls_made']}次)", + 25 + ) + reference_materials = planning_result.get("content", "") + else: + logger.debug("MCP工具可用但AI未选择使用") + else: + logger.debug(f"用户 {user_id} 未启用MCP工具,跳过MCP增强") + + except Exception as e: + logger.warning(f"MCP工具调用失败(降级处理): {e}") + yield await SSEResponse.send_progress("⚠️ MCP工具暂时不可用,使用基础模式", 25) + + # 构建增强提示词 + if reference_materials: + enhanced_prompt = f"""{base_prompt} + +【参考资料】 +以下是通过MCP工具收集的真实背景资料,请参考这些信息构建更真实的世界观: + +{reference_materials} + +请结合上述资料,生成符合历史/现实的世界观设定。""" + final_prompt = enhanced_prompt + yield await SSEResponse.send_progress("💡 已整合参考资料,开始生成世界观...", 30) + else: + final_prompt = base_prompt + yield await SSEResponse.send_progress("正在调用AI生成...", 30) + + # 流式生成世界观 + accumulated_text = "" + chunk_count = 0 + + async for chunk in user_ai_service.generate_text_stream( + prompt=final_prompt, + provider=provider, + model=model + ): + chunk_count += 1 + accumulated_text += chunk + + yield await SSEResponse.send_chunk(chunk) + + if chunk_count % 5 == 0: + progress = min(30 + (chunk_count // 5), 70) + yield await SSEResponse.send_progress(f"生成中... ({len(accumulated_text)}字符)", progress) + + if chunk_count % 20 == 0: + yield await SSEResponse.send_heartbeat() + + # 解析结果 + yield await SSEResponse.send_progress("解析AI返回结果...", 80) + + world_data = {} + try: + cleaned_text = accumulated_text.strip() + + if cleaned_text.startswith('```json'): + cleaned_text = cleaned_text[7:].lstrip('\n\r') + elif cleaned_text.startswith('```'): + cleaned_text = cleaned_text[3:].lstrip('\n\r') + if cleaned_text.endswith('```'): + cleaned_text = cleaned_text[:-3].rstrip('\n\r') + cleaned_text = cleaned_text.strip() + + world_data = json.loads(cleaned_text) + + except json.JSONDecodeError as e: + logger.error(f"世界构建JSON解析失败: {e}") + world_data = { + "time_period": "AI返回格式错误,请重试", + "location": "AI返回格式错误,请重试", + "atmosphere": "AI返回格式错误,请重试", + "rules": "AI返回格式错误,请重试" + } + + # 不保存到数据库,仅返回生成结果供用户预览 + yield await SSEResponse.send_progress("生成完成,等待用户确认...", 90) + + # 发送最终结果(不包含project_id,表示未保存) + yield await SSEResponse.send_result({ + "time_period": world_data.get("time_period"), + "location": world_data.get("location"), + "atmosphere": world_data.get("atmosphere"), + "rules": world_data.get("rules") + }) + + yield await SSEResponse.send_progress("完成!", 100, "success") + yield await SSEResponse.send_done() + + except GeneratorExit: + logger.warning("世界观重新生成器被提前关闭") + if not db_committed and db.in_transaction(): + await db.rollback() + logger.info("世界观重新生成事务已回滚(GeneratorExit)") + except Exception as e: + logger.error(f"世界观重新生成失败: {str(e)}") + if not db_committed and db.in_transaction(): + await db.rollback() + logger.info("世界观重新生成事务已回滚(异常)") + yield await SSEResponse.send_error(f"生成失败: {str(e)}") + + +@router.post("/world-building/{project_id}/regenerate", summary="流式重新生成世界观") +async def regenerate_world_building_stream( + project_id: str, + request: Request, + data: Dict[str, Any], + db: AsyncSession = Depends(get_db), + user_ai_service: AIService = Depends(get_user_ai_service) +): + """ + 使用SSE流式重新生成世界观,避免超时 + 前端使用EventSource接收实时进度和结果 + """ + # 从中间件注入user_id到data中 + if hasattr(request.state, 'user_id'): + data['user_id'] = request.state.user_id + + return create_sse_response(world_building_regenerate_generator(project_id, data, db, user_ai_service)) + + diff --git a/backend/app/models/project.py b/backend/app/models/project.py index d69d57c..b506811 100644 --- a/backend/app/models/project.py +++ b/backend/app/models/project.py @@ -10,7 +10,7 @@ class Project(Base): __tablename__ = "projects" id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4())) - user_id = Column(String(36), nullable=False, index=True, comment="用户ID") + user_id = Column(String(100), nullable=False, index=True, comment="用户ID") title = Column(String(200), nullable=False, comment="项目标题") description = Column(Text, comment="项目简介") theme = Column(Text, comment="主题") diff --git a/backend/scripts/fix_user_id_length.sql b/backend/scripts/fix_user_id_length.sql new file mode 100644 index 0000000..22ff19a --- /dev/null +++ b/backend/scripts/fix_user_id_length.sql @@ -0,0 +1,9 @@ +-- 修复 projects 表中 user_id 字段长度不足的问题 +-- 将 user_id 从 VARCHAR(36) 扩展到 VARCHAR(100) + +ALTER TABLE projects ALTER COLUMN user_id TYPE VARCHAR(100); + +-- 验证修改 +SELECT column_name, data_type, character_maximum_length +FROM information_schema.columns +WHERE table_name = 'projects' AND column_name = 'user_id'; \ No newline at end of file diff --git a/frontend/package.json b/frontend/package.json index e218059..cfc1b2f 100644 --- a/frontend/package.json +++ b/frontend/package.json @@ -1,7 +1,7 @@ { "name": "frontend", "private": true, - "version": "1.0.3", + "version": "1.0.4", "type": "module", "scripts": { "dev": "vite",