diff --git a/Dockerfile b/Dockerfile index 79ad2e8..96c4bd7 100644 --- a/Dockerfile +++ b/Dockerfile @@ -60,10 +60,14 @@ RUN apt-get update && apt-get install -y \ # 复制后端依赖文件 COPY backend/requirements.txt ./ -# 安装 Python 依赖(包含 torch,避免单独安装造成重复层) +# 安装 Python 依赖 +# 先安装 torch CPU版本(~200MB vs 完整版~2GB,节省90%下载时间) +# 对于embedding场景,CPU版本完全够用 RUN if [ "$USE_CN_MIRROR" = "true" ]; then \ + pip install --no-cache-dir torch==2.8.0 --index-url https://mirrors.aliyun.com/pypi/simple/ --extra-index-url https://download.pytorch.org/whl/cpu && \ pip install --no-cache-dir -r requirements.txt -i https://mirrors.aliyun.com/pypi/simple/; \ else \ + pip install --no-cache-dir torch==2.8.0 --index-url https://download.pytorch.org/whl/cpu && \ pip install --no-cache-dir -r requirements.txt; \ fi diff --git a/backend/app/api/careers.py b/backend/app/api/careers.py index cacee6a..6f207c1 100644 --- a/backend/app/api/careers.py +++ b/backend/app/api/careers.py @@ -7,7 +7,7 @@ import json from typing import AsyncGenerator from app.database import get_db -from app.utils.sse_response import SSEResponse, create_sse_response, WizardProgressTracker +from app.utils.sse_response import SSEResponse, create_sse_response, WizardProgressTracker, wrap_stream_with_heartbeat, HEARTBEAT from app.models.career import Career, CharacterCareer from app.models.character import Character from app.models.project import Project @@ -25,6 +25,7 @@ from app.schemas.career import ( CareerStage ) from app.services.ai_service import AIService +from app.services.json_helper import loads_json from app.logger import get_logger from app.api.settings import get_user_ai_service from app.api.common import verify_project_access @@ -155,14 +156,10 @@ async def create_career( raise HTTPException(status_code=500, detail=f"创建职业失败: {str(e)}") -@router.get("/generate-system", summary="AI生成新职业(增量式,流式)") +@router.post("/generate-system", summary="AI生成新职业(增量式,流式)") async def generate_career_system( - project_id: str, - 
main_career_count: int = 3, - sub_career_count: int = 6, - user_requirements: str = "", - enable_mcp: bool = False, - http_request: Request = None, + request_data: CareerGenerateRequest, + http_request: Request, db: AsyncSession = Depends(get_db), user_ai_service: AIService = Depends(get_user_ai_service) ): @@ -176,6 +173,10 @@ async def generate_career_system( try: # 验证用户权限和项目是否存在 user_id = getattr(http_request.state, 'user_id', None) + project_id = request_data.project_id + main_career_count = request_data.main_career_count + sub_career_count = request_data.sub_career_count + user_requirements = request_data.user_requirements project = await verify_project_access(project_id, user_id, db) yield await tracker.start() @@ -316,7 +317,15 @@ async def generate_career_system( chunk_count = 0 estimated_total = max(3000, len(prompt) * 8) - async for chunk in user_ai_service.generate_text_stream(prompt=prompt): + async for chunk in wrap_stream_with_heartbeat( + user_ai_service.generate_text_stream(prompt=prompt), + heartbeat_interval=15.0 + ): + # 心跳哨兵:发送心跳保活,不混入AI响应 + if chunk is HEARTBEAT: + yield await tracker.heartbeat() + continue + chunk_count += 1 ai_response += chunk @@ -345,7 +354,7 @@ async def generate_career_system( # 清洗并解析JSON try: cleaned_response = user_ai_service._clean_json_response(ai_response) - career_data = json.loads(cleaned_response) + career_data = loads_json(cleaned_response) logger.info(f"✅ 职业体系JSON解析成功") except json.JSONDecodeError as e: logger.error(f"❌ 职业体系JSON解析失败: {e}") diff --git a/backend/app/api/characters.py b/backend/app/api/characters.py index d09eb1d..5051474 100644 --- a/backend/app/api/characters.py +++ b/backend/app/api/characters.py @@ -7,7 +7,7 @@ import json from typing import AsyncGenerator from app.database import get_db -from app.utils.sse_response import SSEResponse, create_sse_response, WizardProgressTracker +from app.utils.sse_response import SSEResponse, create_sse_response, WizardProgressTracker, 
wrap_stream_with_heartbeat, HEARTBEAT from app.models.character import Character from app.models.project import Project from app.models.generation_history import GenerationHistory @@ -20,6 +20,7 @@ from app.schemas.character import ( CharacterGenerateRequest ) from app.services.ai_service import AIService +from app.services.json_helper import loads_json from app.services.prompt_service import prompt_service, PromptService from app.services.import_export_service import ImportExportService from app.schemas.import_export import CharactersExportRequest, CharactersImportResult @@ -947,10 +948,18 @@ async def generate_character_stream( logger.info(f"🎯 开始生成角色(流式模式)...") yield await tracker.generating(0, estimated_total, "开始生成角色...") - async for chunk in user_ai_service.generate_text_stream( - prompt=prompt, - tool_choice="required", + async for chunk in wrap_stream_with_heartbeat( + user_ai_service.generate_text_stream( + prompt=prompt, + tool_choice="required", + ), + heartbeat_interval=15.0 ): + # 心跳哨兵:发送心跳保活,不混入AI响应 + if chunk is HEARTBEAT: + yield await tracker.heartbeat() + continue + # chunk 现在可能是 dict 或 str,提取 content 字段 if isinstance(chunk, dict): content = chunk.get("content", "") @@ -987,7 +996,7 @@ async def generate_character_stream( # ✅ 使用统一的 JSON 清洗方法 try: cleaned_response = user_ai_service._clean_json_response(ai_response) - character_data = json.loads(cleaned_response) + character_data = loads_json(cleaned_response) logger.info(f"✅ 角色JSON解析成功") except json.JSONDecodeError as e: logger.error(f"❌ 角色JSON解析失败: {e}") diff --git a/backend/app/api/inspiration.py b/backend/app/api/inspiration.py index ad42c39..e07fa32 100644 --- a/backend/app/api/inspiration.py +++ b/backend/app/api/inspiration.py @@ -6,6 +6,7 @@ import json from app.database import get_db from app.services.ai_service import AIService +from app.services.json_helper import loads_json from app.api.settings import get_user_ai_service from app.services.prompt_service import PromptService from 
app.logger import get_logger @@ -166,7 +167,7 @@ async def generate_options( # 使用统一的JSON清洗方法 cleaned_content = ai_service._clean_json_response(content) - result = json.loads(cleaned_content) + result = loads_json(cleaned_content) # 校验返回格式 is_valid, error_msg = validate_options_response(result, step) @@ -343,7 +344,7 @@ async def refine_options( # 解析JSON try: cleaned_content = ai_service._clean_json_response(content) - result = json.loads(cleaned_content) + result = loads_json(cleaned_content) # 校验返回格式 is_valid, error_msg = validate_options_response(result, step) @@ -466,7 +467,7 @@ async def quick_generate( # 使用统一的JSON清洗方法 cleaned_content = ai_service._clean_json_response(content) - result = json.loads(cleaned_content) + result = loads_json(cleaned_content) # 合并用户已提供的信息(用户输入优先) final_result = { @@ -487,4 +488,4 @@ async def quick_generate( logger.error(f"智能补全失败: {e}", exc_info=True) return { "error": str(e) - } \ No newline at end of file + } diff --git a/backend/app/api/mcp_plugins.py b/backend/app/api/mcp_plugins.py index d732389..9720dd1 100644 --- a/backend/app/api/mcp_plugins.py +++ b/backend/app/api/mcp_plugins.py @@ -54,65 +54,75 @@ async def _register_plugin_background( plugin_type: str, server_url: str, headers: Optional[dict], - config: Optional[dict] + config: Optional[dict], + max_retries: int = 2, + retry_delay: float = 3.0 ): """ - 后台任务:注册MCP插件并更新数据库状态 + 后台任务:注册MCP插件并更新数据库状态(带重试) - 在独立的任务中执行MCP连接,避免阻塞请求处理 + 在独立的任务中执行MCP连接,避免阻塞请求处理。 + 连接失败时会自动重试,提高对临时网络问题的容错性。 """ + last_error = None + + for attempt in range(max_retries + 1): + try: + if attempt > 0: + logger.info(f"后台注册MCP插件重试 ({attempt}/{max_retries}): {plugin_name}") + await asyncio.sleep(retry_delay) + else: + logger.info(f"后台注册MCP插件: {plugin_name}") + + if plugin_type in HTTP_PLUGIN_TYPES and server_url: + server_url = _validate_mcp_server_url(plugin_type, server_url) + success = await mcp_client.register(MCPPluginConfig( + user_id=user_id, + plugin_name=plugin_name, + url=server_url, + 
plugin_type=plugin_type, + headers=headers, + timeout=config.get('timeout', 60.0) if config else 60.0 + )) + else: + success = False + + if success: + # 更新数据库状态为active + engine = await get_engine(user_id) + AsyncSessionLocal = async_sessionmaker(engine, class_=AsyncSession, expire_on_commit=False) + async with AsyncSessionLocal() as db: + stmt = ( + update(MCPPlugin) + .where(MCPPlugin.user_id == user_id, MCPPlugin.plugin_name == plugin_name) + .values(status="active", last_error=None) + ) + await db.execute(stmt) + await db.commit() + logger.info(f"后台注册MCP插件成功: {plugin_name}") + return + else: + last_error = "连接失败" + + except Exception as e: + last_error = str(e) + logger.warning(f"后台注册MCP插件异常 (尝试 {attempt + 1}/{max_retries + 1}): {plugin_name}, 错误: {e}") + + # 所有重试都失败,更新数据库状态为error + logger.error(f"后台注册MCP插件最终失败 (已重试{max_retries}次): {plugin_name}, 错误: {last_error}") try: - logger.info(f"后台注册MCP插件: {plugin_name}") - - if plugin_type in HTTP_PLUGIN_TYPES and server_url: - server_url = _validate_mcp_server_url(plugin_type, server_url) - success = await mcp_client.register(MCPPluginConfig( - user_id=user_id, - plugin_name=plugin_name, - url=server_url, - plugin_type=plugin_type, - headers=headers, - timeout=config.get('timeout', 60.0) if config else 60.0 - )) - else: - success = False - - # 更新数据库状态 engine = await get_engine(user_id) AsyncSessionLocal = async_sessionmaker(engine, class_=AsyncSession, expire_on_commit=False) - async with AsyncSessionLocal() as db: stmt = ( update(MCPPlugin) .where(MCPPlugin.user_id == user_id, MCPPlugin.plugin_name == plugin_name) - .values( - status="active" if success else "error", - last_error=None if success else "连接失败" - ) + .values(status="error", last_error=str(last_error)[:500] if last_error else "连接失败") ) await db.execute(stmt) await db.commit() - - if success: - logger.info(f"后台注册MCP插件成功: {plugin_name}") - else: - logger.warning(f"后台注册MCP插件失败: {plugin_name}") - - except Exception as e: - logger.error(f"后台注册MCP插件异常: 
{plugin_name}, 错误: {e}") - try: - engine = await get_engine(user_id) - AsyncSessionLocal = async_sessionmaker(engine, class_=AsyncSession, expire_on_commit=False) - async with AsyncSessionLocal() as db: - stmt = ( - update(MCPPlugin) - .where(MCPPlugin.user_id == user_id, MCPPlugin.plugin_name == plugin_name) - .values(status="error", last_error=str(e)) - ) - await db.execute(stmt) - await db.commit() - except Exception as db_error: - logger.error(f"更新插件状态失败: {db_error}") + except Exception as db_error: + logger.error(f"更新插件状态失败: {db_error}") async def _unregister_plugin_safe(user_id: str, plugin_name: str): @@ -215,22 +225,26 @@ async def create_plugin( **plugin_data ) + # 如果启用,设为pending状态等待后台连接 + if plugin.enabled: + plugin.status = "pending" + db.add(plugin) await db.commit() await db.refresh(plugin) - # 如果启用,注册到统一门面 + # 如果启用,后台注册到统一门面(避免MCP操作阻塞导致超时) if plugin.enabled: - success = await _register_plugin_to_facade(plugin, user.user_id) - if success: - plugin.status = "active" - else: - plugin.status = "error" - plugin.last_error = "加载失败" - await db.commit() - await db.refresh(plugin) + asyncio.create_task(_register_plugin_background( + user_id=user.user_id, + plugin_name=plugin.plugin_name, + plugin_type=plugin.plugin_type, + server_url=plugin.server_url, + headers=plugin.headers, + config=plugin.config + )) - logger.info(f"用户 {user.user_id} 创建插件: {plugin.plugin_name}") + logger.info(f"用户 {user.user_id} 创建插件: {plugin.plugin_name}(MCP注册在后台执行)") return plugin @@ -438,15 +452,29 @@ async def update_plugin( for key, value in update_data.items(): setattr(plugin, key, value) + # 如果启用,设为pending状态等待后台连接 + if plugin.enabled: + plugin.status = "pending" + plugin.last_error = None + await db.commit() await db.refresh(plugin) - # 如果插件已启用,重新注册 + # 如果插件已启用,后台重新注册MCP连接 if plugin.enabled: - await mcp_client.unregister(user.user_id, plugin.plugin_name) - await _register_plugin_to_facade(plugin, user.user_id) + # 先后台注销旧连接 + asyncio.create_task(_unregister_plugin_safe(user.user_id, 
plugin.plugin_name)) + # 再后台注册新连接 + asyncio.create_task(_register_plugin_background( + user_id=user.user_id, + plugin_name=plugin.plugin_name, + plugin_type=plugin.plugin_type, + server_url=plugin.server_url, + headers=plugin.headers, + config=plugin.config + )) - logger.info(f"用户 {user.user_id} 更新插件: {plugin.plugin_name}") + logger.info(f"用户 {user.user_id} 更新插件: {plugin.plugin_name}(MCP操作在后台执行)") return plugin @@ -470,15 +498,19 @@ async def delete_plugin( if not plugin: raise HTTPException(status_code=404, detail="插件不存在") - # 从统一门面注销 - await mcp_client.unregister(user.user_id, plugin.plugin_name) + # 保存插件信息用于后台注销 + plugin_name = plugin.plugin_name + user_id = user.user_id - # 删除数据库记录 + # 先删除数据库记录 await db.delete(plugin) await db.commit() - logger.info(f"用户 {user.user_id} 删除插件: {plugin.plugin_name}") - return {"message": "插件已删除", "plugin_name": plugin.plugin_name} + # 后台从统一门面注销(避免MCP操作阻塞导致超时) + asyncio.create_task(_unregister_plugin_safe(user_id, plugin_name)) + + logger.info(f"用户 {user.user_id} 删除插件: {plugin_name}(MCP注销在后台执行)") + return {"message": "插件已删除", "plugin_name": plugin_name} @router.post("/{plugin_id}/toggle", response_model=MCPPluginResponse) @@ -490,6 +522,10 @@ async def toggle_plugin( ): """ 启用或禁用插件 + + 启用时:先更新数据库状态为pending,再通过后台任务注册MCP连接, + 避免长时间持有数据库会话导致超时。 + 禁用时:先更新数据库状态,再通过后台任务注销MCP连接。 """ result = await db.execute( select(MCPPlugin).where( @@ -509,51 +545,35 @@ async def toggle_plugin( headers = plugin.headers config = plugin.config - # 先更新数据库状态 + # 更新数据库状态 plugin.enabled = enabled - if not enabled: + if enabled: + # 启用时先设为pending状态,等待后台MCP连接完成 + plugin.status = "pending" + plugin.last_error = None + else: plugin.status = "inactive" await db.commit() await db.refresh(plugin) - # 数据库操作完成后,再进行MCP操作 + # 数据库操作完成后,通过后台任务进行MCP操作(避免长时间持有数据库会话) if enabled: - # 启用:注册到统一门面 - try: - if plugin_type in HTTP_PLUGIN_TYPES and server_url: - server_url = _validate_mcp_server_url(plugin_type, server_url) - success = await mcp_client.register(MCPPluginConfig( - 
user_id=user.user_id, - plugin_name=plugin_name, - url=server_url, - plugin_type=plugin_type, - headers=headers, - timeout=config.get('timeout', 60.0) if config else 60.0 - )) - else: - success = False - - # 更新状态 - plugin.status = "active" if success else "error" - plugin.last_error = None if success else "加载失败" - await db.commit() - await db.refresh(plugin) - except Exception as e: - logger.error(f"注册插件失败: {plugin_name}, 错误: {e}") - plugin.status = "error" - plugin.last_error = str(e) - await db.commit() - await db.refresh(plugin) + # 启用:后台注册到统一门面 + asyncio.create_task(_register_plugin_background( + user_id=user.user_id, + plugin_name=plugin_name, + plugin_type=plugin_type, + server_url=server_url, + headers=headers, + config=config + )) else: - # 禁用:从统一门面注销(不影响数据库状态) - try: - await mcp_client.unregister(user.user_id, plugin_name) - except Exception as e: - logger.warning(f"注销插件时出错(可忽略): {plugin_name}, 错误: {e}") + # 禁用:后台从统一门面注销(不影响数据库状态) + asyncio.create_task(_unregister_plugin_safe(user.user_id, plugin_name)) action = "启用" if enabled else "禁用" - logger.info(f"用户 {user.user_id} {action}插件: {plugin_name}") + logger.info(f"用户 {user.user_id} {action}插件: {plugin_name}(MCP操作在后台执行)") return plugin diff --git a/backend/app/api/organizations.py b/backend/app/api/organizations.py index 5082286..acddb4b 100644 --- a/backend/app/api/organizations.py +++ b/backend/app/api/organizations.py @@ -7,7 +7,7 @@ from pydantic import BaseModel, Field import json from app.database import get_db -from app.utils.sse_response import SSEResponse, create_sse_response, WizardProgressTracker +from app.utils.sse_response import SSEResponse, create_sse_response, WizardProgressTracker, wrap_stream_with_heartbeat, HEARTBEAT from app.models.relationship import Organization, OrganizationMember from app.models.character import Character from app.models.project import Project @@ -24,6 +24,7 @@ from app.schemas.relationship import ( ) from app.schemas.character import CharacterResponse from 
app.services.ai_service import AIService +from app.services.json_helper import loads_json from app.services.prompt_service import prompt_service, PromptService from app.logger import get_logger from app.api.settings import get_user_ai_service @@ -500,7 +501,15 @@ async def generate_organization_stream( chunk_count = 0 estimated_total = max(3000, len(prompt) * 8) - async for chunk in user_ai_service.generate_text_stream(prompt=prompt): + async for chunk in wrap_stream_with_heartbeat( + user_ai_service.generate_text_stream(prompt=prompt), + heartbeat_interval=15.0 + ): + # 心跳哨兵:发送心跳保活,不混入AI响应 + if chunk is HEARTBEAT: + yield await tracker.heartbeat() + continue + chunk_count += 1 ai_content += chunk @@ -529,7 +538,7 @@ async def generate_organization_stream( # ✅ 使用统一的 JSON 清洗方法 try: cleaned_response = user_ai_service._clean_json_response(ai_content) - organization_data = json.loads(cleaned_response) + organization_data = loads_json(cleaned_response) logger.info(f"✅ 组织JSON解析成功") except json.JSONDecodeError as e: logger.error(f"❌ 组织JSON解析失败: {e}") diff --git a/backend/app/api/outlines.py b/backend/app/api/outlines.py index ba1efab..aac00e4 100644 --- a/backend/app/api/outlines.py +++ b/backend/app/api/outlines.py @@ -27,6 +27,7 @@ from app.schemas.outline import ( CreateChaptersFromPlansResponse ) from app.services.ai_service import AIService +from app.services.json_helper import loads_json from app.services.prompt_service import prompt_service, PromptService from app.services.memory_service import memory_service from app.services.plot_expansion_service import PlotExpansionService @@ -850,7 +851,7 @@ def _parse_ai_response(ai_response: str, raise_on_error: bool = False) -> list: ai_service_temp = AIService() cleaned_text = ai_service_temp._clean_json_response(ai_response) - outline_data = json.loads(cleaned_text) + outline_data = loads_json(cleaned_text) # 确保是列表格式 if not isinstance(outline_data, list): @@ -1447,6 +1448,31 @@ async def continue_outline_generator( 
message=f"🤖 调用AI生成第{str(batch_num + 1)}批..." ) + # 获取伏笔提醒信息(用于大纲续写) + foreshadow_reminders_text = "暂无需要关注的伏笔" + try: + foreshadow_context = await foreshadow_service.build_chapter_context( + db=db, + project_id=project_id, + chapter_number=current_start_chapter, + include_pending=False, + include_overdue=True, + lookahead=10 + ) + if foreshadow_context and foreshadow_context.get("context_text"): + foreshadow_reminders_text = foreshadow_context["context_text"] + logger.info(f"✅ 大纲续写获取到伏笔提醒: {len(foreshadow_reminders_text)}字符") + # 追加伏笔统计信息 + foreshadow_stats = await foreshadow_service.get_stats(db, project_id) + if foreshadow_stats: + planted = foreshadow_stats.get('planted', 0) + resolved = foreshadow_stats.get('resolved', 0) + partial = foreshadow_stats.get('partially_resolved', 0) + pending = foreshadow_stats.get('pending', 0) + foreshadow_reminders_text += f"\n【📊 伏笔统计】已埋设:{planted} 已回收:{resolved} 部分回收:{partial} 待埋入:{pending}" + except Exception as e: + logger.warning(f"⚠️ 获取大纲续写伏笔提醒失败: {str(e)}") + # 使用标准续写提示词模板(简化版) template = await PromptService.get_template("OUTLINE_CONTINUE", user_id, db) prompt = PromptService.format_prompt( @@ -1463,6 +1489,8 @@ async def continue_outline_generator( # 上下文信息 recent_outlines=context['recent_outlines'], characters_info=context['characters_info'], + # 伏笔提醒 + foreshadow_reminders=foreshadow_reminders_text, # 续写参数 chapter_count=current_batch_size, start_chapter=current_start_chapter, @@ -2482,4 +2510,4 @@ async def create_chapters_from_existing_plans( except Exception as e: logger.error(f"根据已有规划创建章节失败: {str(e)}", exc_info=True) await db.rollback() - raise HTTPException(status_code=500, detail=f"创建章节失败: {str(e)}") \ No newline at end of file + raise HTTPException(status_code=500, detail=f"创建章节失败: {str(e)}") diff --git a/backend/app/api/wizard_stream.py b/backend/app/api/wizard_stream.py index 8061eb8..ea45513 100644 --- a/backend/app/api/wizard_stream.py +++ b/backend/app/api/wizard_stream.py @@ -16,6 +16,7 @@ from 
app.models.relationship import CharacterRelationship, Organization, Organiz from app.models.writing_style import WritingStyle from app.models.project_default_style import ProjectDefaultStyle from app.services.ai_service import AIService +from app.services.json_helper import loads_json from app.services.prompt_service import prompt_service, PromptService from app.services.plot_expansion_service import PlotExpansionService from app.logger import get_logger @@ -169,7 +170,7 @@ async def world_building_generator( logger.info(f"✅ JSON清洗完成,清洗后长度: {len(cleaned_text)}") logger.info(f" 清洗后预览: {cleaned_text[:300]}...") - world_data = json.loads(cleaned_text) + world_data = loads_json(cleaned_text) logger.info(f"✅ 世界观JSON解析成功(尝试{world_retry_count+1}/{MAX_WORLD_RETRIES})") world_generation_success = True # 解析成功,标记完成 @@ -433,7 +434,7 @@ async def career_system_generator( # 清洗并解析JSON try: cleaned_response = user_ai_service._clean_json_response(career_response) - career_data = json.loads(cleaned_response) + career_data = loads_json(cleaned_response) logger.info(f"✅ 职业体系JSON解析成功(尝试{career_retry_count+1}/{MAX_CAREER_RETRIES})") yield await tracker.saving("保存职业数据...") @@ -771,7 +772,7 @@ async def characters_generator( # 解析批次结果 - 使用统一的JSON清洗方法 cleaned_text = user_ai_service._clean_json_response(accumulated_text) - characters_data = json.loads(cleaned_text) + characters_data = loads_json(cleaned_text) if not isinstance(characters_data, list): characters_data = [characters_data] @@ -1362,7 +1363,7 @@ async def outline_generator( try: cleaned_text = user_ai_service._clean_json_response(accumulated_text) - outline_data = json.loads(cleaned_text) + outline_data = loads_json(cleaned_text) if not isinstance(outline_data, list): outline_data = [outline_data] except json.JSONDecodeError as e: @@ -1668,7 +1669,7 @@ async def world_building_regenerate_generator( cleaned_text = user_ai_service._clean_json_response(accumulated_text) logger.info(f"✅ JSON清洗完成,清洗后长度: {len(cleaned_text)}") - 
world_data = json.loads(cleaned_text) + world_data = loads_json(cleaned_text) logger.info(f"✅ 世界观重新生成JSON解析成功(尝试{world_retry_count+1}/{MAX_WORLD_RETRIES})") world_generation_success = True diff --git a/backend/app/mcp/facade.py b/backend/app/mcp/facade.py index 4f7572b..cf3c697 100644 --- a/backend/app/mcp/facade.py +++ b/backend/app/mcp/facade.py @@ -316,6 +316,9 @@ class MCPClientFacade: if key in self._sessions: await self._close_session_unsafe(key) + stream_ctx = None + session = None + try: logger.info(f"🔗 连接MCP服务器: {config.plugin_name} -> {config.url} (类型: {config.plugin_type})") @@ -365,11 +368,19 @@ class MCPClientFacade: error_details.append(f"{type(exc).__name__}: {exc}") error_msg = "; ".join(error_details) logger.error(f"❌ MCP连接失败 {key}: TaskGroup异常 - {error_msg}") + + # 在同一任务中清理已创建的上下文,避免跨任务清理cancel scope + await self._cleanup_contexts_in_task(session, stream_ctx) + await self._emit_status_change(config.user_id, config.plugin_name, "inactive", "error", error_msg) return False except Exception as e: logger.error(f"❌ MCP连接失败 {key}: {type(e).__name__}: {e}") + + # 在同一任务中清理已创建的上下文,避免跨任务清理cancel scope + await self._cleanup_contexts_in_task(session, stream_ctx) + await self._emit_status_change(config.user_id, config.plugin_name, "inactive", "error", str(e)) return False @@ -392,6 +403,27 @@ class MCPClientFacade: await self._emit_status_change(user_id, plugin_name, old_status, "inactive", "已注销") + async def _cleanup_contexts_in_task(self, session, stream_ctx): + """在当前任务中清理已创建的上下文(异步方法) + + 当MCP连接失败时,上下文(cancel scope)必须在与创建时相同的任务中清理。 + 由于异常处理和上下文创建在同一个任务中,这里可以安全地await __aexit__。 + """ + # 先清理session,再清理stream(LIFO顺序) + if session is not None: + try: + await session.__aexit__(None, None, None) + except Exception as e: + logger.debug(f"清理session上下文: {e}") + + if stream_ctx is not None: + try: + await stream_ctx.__aexit__(None, None, None) + except Exception as e: + logger.debug(f"清理stream上下文: {e}") + + logger.debug("已在当前任务中清理MCP上下文") + async def 
_close_session_unsafe(self, key: str): """关闭会话(不加用户锁,需要调用者确保线程安全)""" async with self._session_lock: diff --git a/backend/app/schemas/career.py b/backend/app/schemas/career.py index e95aeae..fdaca6b 100644 --- a/backend/app/schemas/career.py +++ b/backend/app/schemas/career.py @@ -78,6 +78,7 @@ class CareerGenerateRequest(BaseModel): project_id: str = Field(..., description="项目ID") main_career_count: int = Field(5, description="主职业数量", ge=1, le=20) sub_career_count: int = Field(8, description="副职业数量", ge=0, le=30) + user_requirements: str = Field("", description="用户额外要求") enable_mcp: bool = Field(False, description="是否启用MCP工具增强") diff --git a/backend/app/services/foreshadow_service.py b/backend/app/services/foreshadow_service.py index 2c4fbb1..555e47c 100644 --- a/backend/app/services/foreshadow_service.py +++ b/backend/app/services/foreshadow_service.py @@ -1283,6 +1283,10 @@ class ForeshadowService: # 预先获取所有已埋入的伏笔,用于内容匹配 planted_foreshadows = await self.get_planted_foreshadows_for_analysis(db, project_id) + # 每章最多创建的新伏笔数量 + MAX_NEW_FORESHADOWS_PER_CHAPTER = 2 + new_foreshadow_count = 0 + for fs_data in analysis_foreshadows: try: fs_type = fs_data.get("type", "planted") @@ -1416,6 +1420,11 @@ class ForeshadowService: logger.info(f"📝 更新已存在伏笔(避免重复): {fs_title} (ID: {existing_fs.id})") else: # 创建新伏笔 + # 检查每章新伏笔数量上限 + if new_foreshadow_count >= MAX_NEW_FORESHADOWS_PER_CHAPTER: + logger.info(f"🚫 已达每章新伏笔上限({MAX_NEW_FORESHADOWS_PER_CHAPTER}个),跳过: {fs_title}") + continue + # 不再为 estimated_resolve_chapter 设置默认值,避免误报"超期" estimated_resolve = fs_data.get("estimated_resolve_chapter") if estimated_resolve is None: @@ -1448,10 +1457,11 @@ class ForeshadowService: db.add(new_foreshadow) await db.flush() + new_foreshadow_count += 1 stats["planted_count"] += 1 stats["created_count"] += 1 stats["created_ids"].append(new_foreshadow.id) - logger.info(f"✅ 自动创建伏笔: {fs_title} (ID: {new_foreshadow.id})") + logger.info(f"✅ 自动创建伏笔: {fs_title} (ID: {new_foreshadow.id}) 
[{new_foreshadow_count}/{MAX_NEW_FORESHADOWS_PER_CHAPTER}]") except Exception as item_error: error_msg = f"处理伏笔时出错: {str(item_error)}" diff --git a/backend/app/services/json_helper.py b/backend/app/services/json_helper.py index 31310e5..318a03e 100644 --- a/backend/app/services/json_helper.py +++ b/backend/app/services/json_helper.py @@ -4,9 +4,154 @@ import re from typing import Any, Dict, List, Union from app.logger import get_logger +try: + import json5 + HAS_JSON5 = True +except ImportError: + HAS_JSON5 = False + logger = get_logger(__name__) +# 中文引号/括号到ASCII的映射 +_QUOTE_MAP = { + '\u201c': '"', # " → " + '\u201d': '"', # " → " + '\u2018': "'", # ' → ' + '\u2019': "'", # ' → ' + '\u300e': '"', # 『 → " + '\u300f': '"', # 』 → " + '\u300c': '"', # 「 → " + '\u300d': '"', # 」 → " +} + + +def _fix_json_string_values(text: str) -> str: + """ + 修复JSON字符串值中的常见问题: + 1. 裸换行符/制表符 → 转义 + 2. 字符串值内的中文引号 → 转义为ASCII引号(避免破坏JSON结构) + 3. 结构位置的中文引号 → 直接替换为ASCII引号 + + AI生成的JSON常在字符串值中插入未转义的换行符和中文引号。 + 此函数遍历文本,区分字符串内外,分别处理。 + """ + if not text or '"' not in text: + return text + + result = [] + i = 0 + in_string = False + fixed_count = 0 + + while i < len(text): + c = text[i] + + if c == '"' and not in_string: + # 进入字符串 + in_string = True + result.append(c) + i += 1 + continue + + if in_string: + if c == '\\': + # 转义字符,检查下一个字符是否合法 + if i + 1 < len(text): + next_c = text[i + 1] + # JSON 合法转义:\" \\ \/ \b \f \n \r \t \uXXXX + if next_c in ('"', '\\', '/', 'b', 'f', 'n', 'r', 't'): + # 合法转义,直接保留 + result.append(c) + result.append(next_c) + i += 2 + continue + elif next_c == 'u': + # Unicode 转义 \uXXXX,检查是否有4个十六进制字符 + if i + 5 < len(text) and all(text[i+2+k] in '0123456789abcdefABCDEF' for k in range(4)): + result.append(text[i:i+6]) + i += 6 + continue + else: + # 不完整的unicode转义,去掉反斜杠 + result.append(next_c) + fixed_count += 1 + i += 2 + continue + else: + # 非法转义字符(如 \c \p \d 等),去掉反斜杠只保留字符 + result.append(next_c) + fixed_count += 1 + i += 2 + continue + else: + # 末尾孤立的反斜杠,去掉 + fixed_count 
+= 1 + i += 1 + continue + + if c == '"': + # 字符串结束 + in_string = False + result.append(c) + i += 1 + continue + + if c == '\n': + # 裸换行符 → 替换为转义换行 + result.append('\\') + result.append('n') + fixed_count += 1 + i += 1 + continue + + if c == '\r': + # 裸回车符 → 忽略或替换 + if i + 1 < len(text) and text[i + 1] == '\n': + result.append('\\') + result.append('n') + fixed_count += 1 + i += 2 + else: + result.append('\\') + result.append('n') + fixed_count += 1 + i += 1 + continue + + if c == '\t': + # 裸制表符 → 替换为转义制表符 + result.append('\\') + result.append('t') + fixed_count += 1 + i += 1 + continue + + # 字符串值内的中文引号 → 转义为 \"(避免破坏JSON结构) + if c in _QUOTE_MAP: + result.append('\\') + result.append(_QUOTE_MAP[c]) + fixed_count += 1 + i += 1 + continue + + # 非字符串内的字符 + # 结构位置的中文引号 → 直接替换 + if not in_string and c in _QUOTE_MAP: + result.append(_QUOTE_MAP[c]) + fixed_count += 1 + i += 1 + continue + + result.append(c) + i += 1 + + if fixed_count > 0: + logger.debug(f"✅ 修复了{fixed_count}个JSON问题(裸控制字符/中文引号)") + + return ''.join(result) + + def clean_json_response(text: str) -> str: """清洗 AI 返回的 JSON(改进版 - 流式安全)""" try: @@ -17,6 +162,13 @@ def clean_json_response(text: str) -> str: original_length = len(text) logger.debug(f"🔍 开始清洗JSON,原始长度: {original_length}") + # 替换中文逗号/冒号(AI可能在JSON结构位置使用,全局替换是安全的) + text = text.replace('\uff0c', ',') # ,→ , + text = text.replace('\uff1a', ':') # :→ : + + # 修复JSON中的中文引号和裸控制字符(上下文感知,区分字符串内外) + text = _fix_json_string_values(text) + # 去除 markdown 代码块 text = re.sub(r'^```json\s*\n?', '', text, flags=re.MULTILINE | re.IGNORECASE) text = re.sub(r'^```\s*\n?', '', text, flags=re.MULTILINE) @@ -148,12 +300,54 @@ def clean_json_response(text: str) -> str: def parse_json(text: str) -> Union[Dict, List]: - """解析 JSON""" + """解析 JSON,优先使用标准json,失败后用json5容错解析""" + cleaned = clean_json_response(text) + + # 优先使用标准 json try: - cleaned = clean_json_response(text) return json.loads(cleaned) - except Exception as e: - logger.error(f"❌ parse_json 出错: {e}") - 
logger.error(f" 原始文本长度: {len(text) if text else 0}") - logger.error(f" 清洗后文本长度: {len(cleaned) if cleaned else 0}") - raise \ No newline at end of file + except (json.JSONDecodeError, Exception): + pass + + # json5 容错解析(处理单引号、多余逗号、宽松格式等) + if HAS_JSON5: + try: + logger.info("🔄 标准JSON解析失败,使用json5容错解析") + result = json5.loads(cleaned) + logger.info("✅ json5容错解析成功") + return result + except Exception as e5: + logger.error(f"❌ json5容错解析也失败: {e5}") + + # 最终失败 + logger.error(f"❌ parse_json 完全失败") + logger.error(f" 原始文本长度: {len(text) if text else 0}") + logger.error(f" 清洗后文本长度: {len(cleaned) if cleaned else 0}") + logger.debug(f" 清洗后文本预览: {cleaned[:500] if cleaned else 'None'}") + raise json.JSONDecodeError("JSON解析失败(标准和json5均失败)", cleaned, 0) + + +def loads_json(text: str) -> Any: + """ + json.loads 的容错替代品,可直接替换 json.loads()。 + 优先用标准 json.loads,失败后自动降级到 json5。 + 适用于解析 AI 返回的、可能包含不规范格式的 JSON。 + """ + # 优先使用标准 json + try: + return json.loads(text) + except (json.JSONDecodeError, Exception): + pass + + # json5 容错解析 + if HAS_JSON5: + try: + logger.info("🔄 json.loads失败,使用json5容错解析") + result = json5.loads(text) + logger.info("✅ json5容错解析成功") + return result + except Exception as e5: + logger.error(f"❌ json5容错解析也失败: {e5}") + + # 最终失败,抛出标准异常 + raise json.JSONDecodeError("JSON解析失败(标准和json5均失败)", text, 0) diff --git a/backend/app/services/plot_analyzer.py b/backend/app/services/plot_analyzer.py index c13992a..ec104ce 100644 --- a/backend/app/services/plot_analyzer.py +++ b/backend/app/services/plot_analyzer.py @@ -2,6 +2,7 @@ from typing import Dict, Any, List, Optional, Callable, Awaitable from sqlalchemy.ext.asyncio import AsyncSession from app.services.ai_service import AIService +from app.services.json_helper import loads_json from app.services.prompt_service import prompt_service, PromptService from app.logger import get_logger import json @@ -277,7 +278,7 @@ class PlotAnalyzer: cleaned = self.ai_service._clean_json_response(response) # 尝试解析JSON - result = json.loads(cleaned) 
class _HeartbeatSentinel:
    """Marker type for heartbeat ticks, so they are never mistaken for AI content."""


# Singleton sentinel; consumers compare with ``chunk is HEARTBEAT``.
HEARTBEAT = _HeartbeatSentinel()


async def wrap_stream_with_heartbeat(
    async_gen: AsyncGenerator,
    heartbeat_interval: float = 15.0
) -> AsyncGenerator:
    """
    Wrap an async generator and emit ``HEARTBEAT`` while it is idle.

    Whenever the wrapped stream produces nothing for ``heartbeat_interval``
    seconds, the ``HEARTBEAT`` sentinel is yielded so the caller can send a
    keep-alive and the connection does not time out. Items from the wrapped
    stream are passed through unchanged and in order.

    Usage:
        async for chunk in wrap_stream_with_heartbeat(
            ai_service.generate_text_stream(prompt),
            heartbeat_interval=15
        ):
            if chunk is HEARTBEAT:
                yield await tracker.heartbeat()
                continue
            # chunk is original AI data

    Args:
        async_gen: The source async iterable to wrap.
        heartbeat_interval: Seconds of silence before a heartbeat is yielded.

    Yields:
        Items from ``async_gen``, interleaved with ``HEARTBEAT`` during idle
        periods.
    """
    ait = async_gen.__aiter__()
    pending = None  # in-flight __anext__() task, kept alive across heartbeats
    try:
        while True:
            if pending is None:
                pending = asyncio.ensure_future(ait.__anext__())

            # asyncio.wait() does NOT cancel the task on timeout. wait_for()
            # would cancel the pending __anext__(), which throws CancelledError
            # into the source generator and ends the stream after the first
            # heartbeat.
            done, _ = await asyncio.wait({pending}, timeout=heartbeat_interval)
            if not done:
                # Source is still working: emit a heartbeat and keep waiting
                # on the same task.
                yield HEARTBEAT
                continue

            finished, pending = pending, None
            try:
                item = finished.result()
            except StopAsyncIteration:
                # Source exhausted. Return rather than let StopAsyncIteration
                # escape, which would become RuntimeError in an async generator.
                return
            yield item
    finally:
        # The consumer stopped early or we were cancelled: do not leave the
        # in-flight read running in the background.
        if pending is not None:
            pending.cancel()
            try:
                await pending
            except (asyncio.CancelledError, Exception):
                # Cleanup only; the outcome of the abandoned read is irrelevant.
                pass
- { withCredentials: true } - ); + enable_mcp: false + }) + }); - eventSource.onmessage = (event) => { - try { - const data = JSON.parse(event.data); - - if (data.type === 'progress') { - setAiProgress(data.progress || 0); - setAiMessage(data.message || ''); - } else if (data.type === 'done') { - eventSource.close(); - setTimeout(() => { - setAiGenerating(false); - message.success('AI新职业生成完成!'); - fetchCareers(); - }, 1000); - } else if (data.type === 'error') { - eventSource.close(); - setAiGenerating(false); - message.error(data.message || '生成失败'); - } - } catch (e) { - console.error('解析SSE数据失败:', e); - } - }; - - eventSource.onerror = () => { - eventSource.close(); + if (!response.ok || !response.body) { setAiGenerating(false); - message.error('连接中断,生成失败'); - }; + message.error(`请求失败: ${response.status}`); + return; + } + + const reader = response.body.getReader(); + const decoder = new TextDecoder(); + let buffer = ''; + + while (true) { + const { done, value } = await reader.read(); + if (done) break; + + buffer += decoder.decode(value, { stream: true }); + const lines = buffer.split('\n'); + buffer = lines.pop() || ''; + + for (const line of lines) { + if (line.startsWith('data: ')) { + try { + const data = JSON.parse(line.slice(6)); + + if (data.type === 'progress') { + setAiProgress(data.progress || 0); + setAiMessage(data.message || ''); + } else if (data.type === 'done') { + setTimeout(() => { + setAiGenerating(false); + message.success('AI新职业生成完成!'); + fetchCareers(); + }, 1000); + } else if (data.type === 'error') { + setAiGenerating(false); + message.error(data.error || data.message || '生成失败'); + } + } catch (e) { + // 忽略非JSON行(如心跳注释) + } + } + } + } + + setAiGenerating(false); } catch (err: unknown) { setAiGenerating(false); const error = err as Error; diff --git a/frontend/src/utils/sseClient.ts b/frontend/src/utils/sseClient.ts index dce7cfb..1788c06 100644 --- a/frontend/src/utils/sseClient.ts +++ b/frontend/src/utils/sseClient.ts @@ -151,6 +151,7 
@@ export class SSEPostClient { headers: { 'Content-Type': 'application/json', }, + credentials: 'include', body: JSON.stringify(this.data), signal: this.abortController.signal, });