diff --git a/backend/app/api/chapters.py b/backend/app/api/chapters.py index 7e52559..ec77aee 100644 --- a/backend/app/api/chapters.py +++ b/backend/app/api/chapters.py @@ -45,7 +45,7 @@ from app.services.memory_service import memory_service from app.services.chapter_regenerator import ChapterRegenerator from app.logger import get_logger from app.api.settings import get_user_ai_service -from app.utils.sse_response import create_sse_response +from app.utils.sse_response import SSEResponse, create_sse_response router = APIRouter(prefix="/chapters", tags=["章节管理"]) logger = get_logger(__name__) @@ -1172,7 +1172,6 @@ async def generate_chapter_content_stream( """ style_id = generate_request.style_id target_word_count = generate_request.target_word_count or 3000 - enable_mcp = generate_request.enable_mcp if hasattr(generate_request, 'enable_mcp') else True custom_model = generate_request.model if hasattr(generate_request, 'model') else None temp_narrative_perspective = generate_request.narrative_perspective if hasattr(generate_request, 'narrative_perspective') else None # 预先验证章节存在性(使用临时会话) @@ -1211,25 +1210,36 @@ async def generate_chapter_content_stream( # 获取当前用户ID(在生成器外部就需要) current_user_id = getattr(request.state, "user_id", "system") + # 初始化标准进度追踪器 + from app.utils.sse_response import WizardProgressTracker + tracker = WizardProgressTracker("章节") + try: + yield await tracker.start() + # 创建新的数据库会话 async for db_session in get_db(request): + # === 加载阶段 === + yield await tracker.loading("加载章节信息...", 0.2) + # 重新获取章节信息 chapter_result = await db_session.execute( select(Chapter).where(Chapter.id == chapter_id) ) current_chapter = chapter_result.scalar_one_or_none() if not current_chapter: - yield f"data: {json.dumps({'type': 'error', 'error': '章节不存在'}, ensure_ascii=False)}\n\n" + yield await tracker.error("章节不存在", 404) return + yield await tracker.loading("加载项目信息...", 0.4) + # 获取项目信息 project_result = await db_session.execute( select(Project).where(Project.id == current_chapter.project_id) ) project = project_result.scalar_one_or_none() if not project: - yield f"data: {json.dumps({'type': 'error', 'error': '项目不存在'}, ensure_ascii=False)}\n\n" + yield await tracker.error("项目不存在", 404) return # 获取项目的大纲模式 @@ -1333,80 +1343,7 @@ async def generate_chapter_content_stream( logger.info(f" - 相关记忆: {chapter_context.context_stats.get('memory_count', 0)} 条") logger.info(f" - 总上下文长度: {chapter_context.context_stats.get('total_length', 0)} 字符") - # 发送开始事件 - yield f"data: {json.dumps({'type': 'start', 'message': '开始AI创作...'}, ensure_ascii=False)}\n\n" - - # 发送初始进度0% - yield f"data: {json.dumps({'type': 'progress', 'progress': 0, 'message': '准备生成...', 'status': 'processing'}, ensure_ascii=False)}\n\n" - - # 🔧 MCP工具增强:收集章节参考资料(优化版) - mcp_reference_materials = "" - if enable_mcp and current_user_id: - try: - # 1️⃣ 静默检查工具可用性 - from app.services.mcp_tool_service import mcp_tool_service - available_tools = await mcp_tool_service.get_user_enabled_tools( - user_id=current_user_id, - db_session=db_session - ) - - # 2️⃣ 只在有工具时才显示消息和调用 - if available_tools: - yield f"data: {json.dumps({'type': 'progress', 'message': '🔍 使用MCP工具收集参考资料...', 'progress': 28}, ensure_ascii=False)}\n\n" - - # 构建资料收集提示词 - planning_prompt = f"""你正在为小说《{project.title}》创作第{current_chapter.chapter_number}章《{current_chapter.title}》。 - -【章节大纲】 -{outline.content if outline else current_chapter.summary or '暂无大纲'} - -【小说信息】 -- 题材:{project.genre or '未设定'} -- 主题:{project.theme or '未设定'} -- 时代背景:{project.world_time_period or '未设定'} -- 地理位置:{project.world_location or '未设定'} - -【任务】 -请使用可用工具搜索相关背景资料,帮助创作更真实、更有深度的章节内容。 -你可以查询: -1. 该章节涉及的历史事件或时代背景 -2. 地理环境和场景描写参考 -3. 相关领域的专业知识(如武术、科技、魔法等) -4. 文化习俗和生活细节 - -请根据章节内容,有针对性地查询1-2个最关键的问题。""" - - # 调用MCP增强的AI(非流式,限制1轮避免超时) - planning_result = await user_ai_service.generate_text_with_mcp( - prompt=planning_prompt, - user_id=current_user_id, - db_session=db_session, - enable_mcp=True, - max_tool_rounds=2, # ✅ 减少为1轮,避免超时 - tool_choice="auto", - provider=None, - model=None - ) - - # 3️⃣ 提取参考资料并显示结果 - if planning_result.get("tool_calls_made", 0) > 0: - tool_count = planning_result["tool_calls_made"] - yield f"data: {json.dumps({'type': 'progress', 'message': f'✅ MCP工具调用成功({tool_count}次)', 'progress': 32}, ensure_ascii=False)}\n\n" - mcp_reference_materials = planning_result.get("content", "") - logger.info(f"📚 MCP工具收集参考资料:{len(mcp_reference_materials)} 字符") - else: - yield f"data: {json.dumps({'type': 'progress', 'message': 'ℹ️ MCP未使用工具,继续', 'progress': 32}, ensure_ascii=False)}\n\n" - else: - logger.debug(f"用户 {current_user_id} 未启用MCP工具,跳过MCP增强") - # 未启用MCP时也发送进度,保持连贯性 - yield f"data: {json.dumps({'type': 'progress', 'message': '准备生成内容...', 'progress': 10}, ensure_ascii=False)}\n\n" - - except Exception as e: - logger.warning(f"⚠️ MCP工具调用失败,降级为基础模式: {str(e)}") - yield f"data: {json.dumps({'type': 'progress', 'message': '⚠️ MCP工具暂时不可用,使用基础模式', 'progress': 10}, ensure_ascii=False)}\n\n" - else: - # 如果未启用MCP,也发送基础进度 - yield f"data: {json.dumps({'type': 'progress', 'message': '开始构建创作上下文...', 'progress': 10}, ensure_ascii=False)}\n\n" + yield await tracker.loading("上下文构建完成", 0.8) # 🎭 确定使用的叙事人称(临时指定 > 项目默认 > 系统默认) chapter_perspective = ( @@ -1496,26 +1433,17 @@ async def generate_chapter_content_stream( characters_info=characters_info or '暂无角色信息' ) - # 添加 MCP 参考资料(如果有) - if mcp_reference_materials: - mcp_section = f"\n\n\n{mcp_reference_materials}\n" - base_prompt = base_prompt.replace("", f"{mcp_section}\n") - logger.info(f"📖 已整合MCP参考资料({len(mcp_reference_materials)}字符)") - # 应用写作风格 if style_content: prompt = WritingStyleManager.apply_style_to_prompt(base_prompt, style_content) else: prompt = base_prompt - if mcp_reference_materials: - logger.info(f"📖 已整合MCP参考资料({len(mcp_reference_materials)}字符)到章节生成提示词") + # === 准备阶段 === + yield await tracker.preparing("准备AI提示词...") logger.info(f"开始AI流式创作章节 {chapter_id}") - # 发送开始生成的进度 - yield f"data: {json.dumps({'type': 'progress', 'progress': 10, 'message': '开始AI创作...', 'status': 'processing'}, ensure_ascii=False)}\n\n" - # 🎨 方案一:将写作风格注入到系统提示词(最高优先级) system_prompt_with_style = None if style_content: @@ -1530,7 +1458,8 @@ async def generate_chapter_content_stream( # 准备生成参数 generate_kwargs = { "prompt": prompt, - "system_prompt": system_prompt_with_style # 🔑 关键:使用系统提示词传递风格 + "system_prompt": system_prompt_with_style, + "tool_choice": "required" } if custom_model: logger.info(f" 使用自定义模型: {custom_model}") @@ -1538,47 +1467,38 @@ async def generate_chapter_content_stream( # 注意:这里使用用户配置的AI服务,模型参数会覆盖默认模型 # 如果需要切换provider,需要在前端传递provider参数 - # 流式生成内容 + # === 生成阶段 === full_content = "" chunk_count = 0 - last_progress = 0 + + yield await tracker.generating( + current_chars=0, + estimated_total=target_word_count + ) async for chunk in user_ai_service.generate_text_stream(**generate_kwargs): full_content += chunk chunk_count += 1 # 发送内容块 - yield f"data: {json.dumps({'type': 'content', 'content': chunk}, ensure_ascii=False)}\n\n" + yield await tracker.generating_chunk(chunk) - # 每5个chunk发送一次进度更新(10-95%,更平滑) + # 每5个chunk发送一次进度更新 if chunk_count % 5 == 0: - current_word_count = len(full_content) - # 优化进度计算:使用更平滑的递增方式 - # 基于chunk数量和字数的混合计算,避免大幅跳跃 - chunk_progress = min(40, chunk_count // 5) # chunk贡献最多40% - word_progress = min(45, int((current_word_count / target_word_count) * 45)) # 字数贡献最多45% - estimated_progress = min(95, 10 + chunk_progress + word_progress) - - # 只在进度变化时发送 - if estimated_progress > last_progress: - progress_data = { - 'type': 'progress', - 'progress': estimated_progress, - 'message': f'正在创作中... 已生成 {current_word_count} 字', - 'word_count': current_word_count, - 'status': 'processing' - } - yield f"data: {json.dumps(progress_data, ensure_ascii=False)}\n\n" - last_progress = estimated_progress + yield await tracker.generating( + current_chars=len(full_content), + estimated_total=target_word_count, + message=f'正在创作中... 已生成 {len(full_content)} 字' + ) # 每20个chunk发送心跳 if chunk_count % 20 == 0: - yield f"data: {json.dumps({'type': 'heartbeat'}, ensure_ascii=False)}\n\n" + yield await tracker.heartbeat() await asyncio.sleep(0) # 让出控制权 - # 发送保存进度 - yield f"data: {json.dumps({'type': 'progress', 'progress': 97, 'message': '正在保存章节...', 'status': 'processing'}, ensure_ascii=False)}\n\n" + # === 保存阶段 === + yield await tracker.saving("正在保存章节...", 0.3) # 更新章节内容到数据库 old_word_count = current_chapter.word_count or 0 @@ -1634,25 +1554,28 @@ async def generate_chapter_content_stream( ai_service=user_ai_service ) - # 发送最终进度100% - yield f"data: {json.dumps({'type': 'progress', 'progress': 99, 'message': '创作完成!', 'word_count': new_word_count, 'status': 'success'}, ensure_ascii=False)}\n\n" + yield await tracker.saving("章节保存完成", 0.8) - # 发送完成事件(包含分析任务ID) - completion_data = { - 'type': 'done', - 'message': '创作完成', + # === 完成阶段 === + yield await tracker.complete("创作完成!") + + # 发送结果数据 + yield await tracker.result({ 'word_count': new_word_count, 'analysis_task_id': task_id - } - yield f"data: {json.dumps(completion_data, ensure_ascii=False)}\n\n" + }) - # 发送分析开始事件 - analysis_started_data = { - 'type': 'analysis_started', - 'task_id': task_id, - 'message': '章节分析已开始' - } - yield f"data: {json.dumps(analysis_started_data, ensure_ascii=False)}\n\n" + # 发送分析开始事件(使用自定义事件) + yield await SSEResponse.send_event( + event='analysis_started', + data={ + 'task_id': task_id, + 'message': '章节分析已开始' + } + ) + + # 发送完成信号 + yield await tracker.done() break # 退出async for db_session循环 @@ -1675,7 +1598,7 @@ async def generate_chapter_content_stream( logger.info("章节生成事务已回滚(异常)") except Exception as rollback_error: logger.error(f"回滚失败: {str(rollback_error)}") - yield f"data: {json.dumps({'type': 'error', 'error': str(e)}, ensure_ascii=False)}\n\n" + yield await tracker.error(str(e)) finally: # 确保数据库会话被正确关闭 if db_session: @@ -2813,7 +2736,8 @@ async def generate_single_chapter_for_batch( # 准备生成参数 generate_kwargs = { "prompt": prompt, - "system_prompt": system_prompt_with_style # 🔑 关键:使用系统提示词传递风格 + "system_prompt": system_prompt_with_style, + "tool_choice": "required" } # 如果传入了自定义模型,使用指定的模型 if custom_model: @@ -3029,11 +2953,16 @@ async def regenerate_chapter_stream( db_session = None db_committed = False + # 初始化标准进度追踪器 + from app.utils.sse_response import WizardProgressTracker + tracker = WizardProgressTracker("章节重新生成") + try: + yield await tracker.start() + # 创建独立数据库会话 async for db_session in get_db(request): - # 发送开始事件 - yield f"data: {json.dumps({'type': 'start', 'message': '开始重新生成章节...'}, ensure_ascii=False)}\n\n" + yield await tracker.loading("加载章节信息...", 0.5) # 创建重新生成任务 regen_task = RegenerationTask( @@ -3062,13 +2991,25 @@ async def regenerate_chapter_stream( task_id = regen_task.id logger.info(f"📝 创建重新生成任务: {task_id}") - yield f"data: {json.dumps({'type': 'task_created', 'task_id': task_id}, ensure_ascii=False)}\n\n" + yield await tracker.preparing("准备重新生成...") + + yield await SSEResponse.send_event( + event='task_created', + data={'task_id': task_id} + ) # 初始化重新生成器 regenerator = ChapterRegenerator(user_ai_service) - # 流式生成新内容 + # === 生成阶段 === full_content = "" + estimated_total = regenerate_request.target_word_count or len(chapter.content) + + yield await tracker.generating( + current_chars=0, + estimated_total=estimated_total + ) + async for event in regenerator.regenerate_with_feedback( chapter=chapter, analysis=analysis, @@ -3083,19 +3024,35 @@ async def regenerate_chapter_stream( # 内容块 chunk = event['content'] full_content += chunk - yield f"data: {json.dumps({'type': 'chunk', 'content': chunk}, ensure_ascii=False)}\n\n" + yield await tracker.generating_chunk(chunk) + + # 定期更新进度 + if len(full_content) % 500 == 0: + yield await tracker.generating( + current_chars=len(full_content), + estimated_total=estimated_total, + message=f'重新生成中... 已生成 {len(full_content)} 字' + ) elif event['type'] == 'progress': - # 进度更新 - progress_data = { - 'type': 'progress', - 'progress': event.get('progress', 0), - 'message': event.get('message', ''), - 'word_count': event.get('word_count', 0) - } - yield f"data: {json.dumps(progress_data, ensure_ascii=False)}\n\n" + # 进度更新 - 映射到对应阶段 + progress = event.get('progress', 0) + message = event.get('message', '') + if progress < 20: + yield await tracker.preparing(message) + elif progress < 85: + yield await tracker.generating( + current_chars=len(full_content), + estimated_total=estimated_total, + message=message + ) + else: + yield await tracker.parsing(message) await asyncio.sleep(0) + # === 保存阶段 === + yield await tracker.saving("保存重新生成的内容...", 0.5) + # 更新任务状态 regen_task.status = 'completed' regen_task.regenerated_content = full_content @@ -3108,25 +3065,22 @@ async def regenerate_chapter_stream( await db_session.commit() db_committed = True - # 先发送结果数据 - result_data = { - 'type': 'result', - 'data': { - 'task_id': task_id, - 'word_count': len(full_content), - 'version_number': regen_task.version_number, - 'auto_applied': regenerate_request.auto_apply, - 'diff_stats': diff_stats - } - } - yield f"data: {json.dumps(result_data, ensure_ascii=False)}\n\n" + yield await tracker.saving("保存完成", 0.9) - # 再发送完成事件 - completion_data = { - 'type': 'done', - 'message': '重新生成完成' - } - yield f"data: {json.dumps(completion_data, ensure_ascii=False)}\n\n" + # === 完成阶段 === + yield await tracker.complete("重新生成完成!") + + # 发送结果数据 + yield await tracker.result({ + 'task_id': task_id, + 'word_count': len(full_content), + 'version_number': regen_task.version_number, + 'auto_applied': regenerate_request.auto_apply, + 'diff_stats': diff_stats + }) + + # 发送完成信号 + yield await tracker.done() logger.info(f"✅ 章节重新生成完成: {chapter_id}, 任务: {task_id}") @@ -3151,7 +3105,7 @@ async def regenerate_chapter_stream( except Exception as update_error: logger.error(f"更新任务失败状态失败: {str(update_error)}") - yield f"data: {json.dumps({'type': 'error', 'error': str(e)}, ensure_ascii=False)}\n\n" + yield await tracker.error(str(e)) finally: if db_session: diff --git a/backend/app/api/characters.py b/backend/app/api/characters.py index 02a7215..396580e 100644 --- a/backend/app/api/characters.py +++ b/backend/app/api/characters.py @@ -775,148 +775,46 @@ async def generate_character_stream( logger.info(f"🎯 开始为项目 {request.project_id} 生成角色(SSE流式)") try: - # 🔧 MCP工具增强:静默检查并收集参考资料 + # 直接使用 AIService 流式生成 ai_response = "" chunk_count = 0 - if user_id: - try: - from app.services.mcp_tool_service import mcp_tool_service - available_tools = await mcp_tool_service.get_user_enabled_tools( - user_id=user_id, - db_session=db - ) - - # 只在有工具时才调用 - if available_tools: - logger.info(f"🔍 检测到可用MCP工具,尝试收集参考资料...") - result = await user_ai_service.generate_text_with_mcp( - prompt=prompt, - user_id=user_id, - db_session=db, - enable_mcp=True, - max_tool_rounds=2, - tool_choice="auto", - provider=None, - model=None - ) - - if isinstance(result, dict): - ai_response = result.get('content', '') - finish_reason = result.get('finish_reason', '') - tool_calls_made = result.get('tool_calls_made', 0) - - # 🔧 修复:检查工具调用是否真正成功 - if tool_calls_made > 0: - if finish_reason == 'tool_error': - logger.warning(f"⚠️ MCP工具调用失败,降级为基础模式") - # 工具调用失败,重新用基础模式生成 - ai_response = "" - elif not ai_response.strip(): - logger.warning(f"⚠️ MCP工具调用后返回空响应,降级为基础模式") - # 工具调用成功但返回空内容,重新生成 - ai_response = "" - else: - logger.info(f"✅ MCP工具调用成功({tool_calls_made}次),内容长度: {len(ai_response)}") - # MCP成功且有内容,模拟流式输出(分块发送) - chunk_size = 50 - for i in range(0, len(ai_response), chunk_size): - chunk = ai_response[i:i+chunk_size] - chunk_count += 1 - yield await SSEResponse.send_chunk(chunk) - - if chunk_count % 3 == 0: - yield await SSEResponse.send_progress( - f"AI生成角色中... ({i+len(chunk)}/{len(ai_response)}字符)", - 10 + min(85 * (i+len(chunk)) // len(ai_response), 85) - ) - - # 跳过后续的流式生成 - ai_response = result.get('content', '') - else: - ai_response = result - - # 如果MCP调用失败或返回空,继续走流式生成 - if not ai_response or not ai_response.strip(): - logger.info(f"🔄 开始流式生成...") - ai_response = "" - async for chunk in user_ai_service.generate_text_stream(prompt=prompt): - chunk_count += 1 - ai_response += chunk - - # 发送内容块 - yield await SSEResponse.send_chunk(chunk) - - # 定期更新进度 - if chunk_count % 5 == 0: - yield await SSEResponse.send_progress( - f"AI生成角色中... ({len(ai_response)}字符)", - 10 + min(chunk_count // 2, 85) - ) - - # 心跳 - if chunk_count % 20 == 0: - yield await SSEResponse.send_heartbeat() - else: - logger.debug(f"用户 {user_id} 未启用MCP工具,使用流式基础模式") - async for chunk in user_ai_service.generate_text_stream(prompt=prompt): - chunk_count += 1 - ai_response += chunk - - # 发送内容块 - yield await SSEResponse.send_chunk(chunk) - - # 定期更新进度 - if chunk_count % 5 == 0: - yield await SSEResponse.send_progress( - f"AI生成角色中... ({len(ai_response)}字符)", - 10 + min(chunk_count // 2, 85) - ) - - # 心跳 - if chunk_count % 20 == 0: - yield await SSEResponse.send_heartbeat() - - except Exception as mcp_error: - logger.warning(f"⚠️ MCP工具调用异常,降级为流式基础模式: {str(mcp_error)}") - ai_response = "" - async for chunk in user_ai_service.generate_text_stream(prompt=prompt): - chunk_count += 1 - ai_response += chunk - - # 发送内容块 - yield await SSEResponse.send_chunk(chunk) - - # 定期更新进度 - if chunk_count % 5 == 0: - yield await SSEResponse.send_progress( - f"AI生成角色中... ({len(ai_response)}字符)", - 10 + min(chunk_count // 2, 85) - ) - - # 心跳 - if chunk_count % 20 == 0: - yield await SSEResponse.send_heartbeat() - else: - logger.debug(f"未登录用户,使用流式基础模式") - async for chunk in user_ai_service.generate_text_stream(prompt=prompt): - chunk_count += 1 - ai_response += chunk + logger.info(f"🎯 开始生成角色(流式模式)...") + yield await SSEResponse.send_progress("🎯 开始生成角色...", 15) + + async for chunk in user_ai_service.generate_text_stream( + prompt=prompt, + tool_choice="required", + ): + # chunk 现在可能是 dict 或 str,提取 content 字段 + if isinstance(chunk, dict): + content = chunk.get("content", "") + else: + content = chunk + + if content: + ai_response += content # 发送内容块 - yield await SSEResponse.send_chunk(chunk) + yield await SSEResponse.send_chunk(content) - # 定期更新进度 - if chunk_count % 5 == 0: + # 定期更新进度(每收到约500字符更新一次,避免过于频繁) + current_len = len(ai_response) + if current_len >= chunk_count * 500: + chunk_count += 1 + # 使用实际字符数量计算进度,上限85%(留15%给后续解析和保存) + # 估算最终字符数约为提示词的8倍,最少3000字符 + estimated_total = max(3000, len(prompt) * 8) + progress = min(15 + int(current_len / estimated_total * 70), 85) yield await SSEResponse.send_progress( - f"AI生成角色中... ({len(ai_response)}字符)", - 10 + min(chunk_count // 2, 85) + f"AI生成角色中... ({current_len}字符)", + progress ) # 心跳 if chunk_count % 20 == 0: yield await SSEResponse.send_heartbeat() - + except Exception as ai_error: logger.error(f"❌ AI服务调用异常:{str(ai_error)}") yield await SSEResponse.send_error(f"AI服务调用失败:{str(ai_error)}") diff --git a/backend/app/api/mcp_plugins.py b/backend/app/api/mcp_plugins.py index 8b936a1..496e438 100644 --- a/backend/app/api/mcp_plugins.py +++ b/backend/app/api/mcp_plugins.py @@ -1,4 +1,7 @@ -"""MCP插件管理API""" +"""MCP插件管理API + +重构后使用统一的MCPClientFacade门面来管理所有MCP操作。 +""" from fastapi import APIRouter, HTTPException, Depends, Query, Request from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy import select @@ -17,9 +20,8 @@ from app.schemas.mcp_plugin import ( ) import json from app.user_manager import User -from app.mcp.registry import mcp_registry +from app.mcp import mcp_client, MCPPluginConfig, PluginStatus from app.services.mcp_test_service import mcp_test_service -from app.services.mcp_tool_service import mcp_tool_service from app.logger import get_logger logger = get_logger(__name__) @@ -34,6 +36,31 @@ def require_login(request: Request) -> User: return request.state.user +async def _register_plugin_to_facade(plugin: MCPPlugin, user_id: str) -> bool: + """ + 将插件注册到统一门面 + + Args: + plugin: 插件对象 + user_id: 用户ID + + Returns: + 是否注册成功 + """ + if plugin.plugin_type in ["http", "streamable_http", "sse"] and plugin.server_url: + return await mcp_client.register(MCPPluginConfig( + user_id=user_id, + plugin_name=plugin.plugin_name, + url=plugin.server_url, + plugin_type=plugin.plugin_type, + headers=plugin.headers, + timeout=plugin.config.get('timeout', 60.0) if plugin.config else 60.0 + )) + else: + logger.warning(f"暂不支持的插件类型: {plugin.plugin_type}") + return False + + @router.get("", response_model=List[MCPPluginResponse]) async def list_plugins( enabled_only: bool = Query(False, description="只返回启用的插件"), @@ -99,9 +126,9 @@ async def create_plugin( await db.commit() await db.refresh(plugin) - # 如果启用,加载到注册表 + # 如果启用,注册到统一门面 if plugin.enabled: - success = await mcp_registry.load_plugin(plugin) + success = await _register_plugin_to_facade(plugin, user.user_id) if success: plugin.status = "active" else: @@ -153,7 +180,7 @@ async def create_plugin_simple( # 提取配置 server_type = server_config.get("type", "http") - if server_type not in ["http", "stdio"]: + if server_type not in ["http", "stdio", "streamable_http", "sse"]: raise HTTPException(status_code=400, detail=f"不支持的服务器类型: {server_type}") # 检查插件名是否已存在 @@ -175,12 +202,12 @@ async def create_plugin_simple( "sort_order": 0 } - if server_type == "http": + if server_type in ["http", "streamable_http", "sse"]: plugin_data["server_url"] = server_config.get("url") plugin_data["headers"] = server_config.get("headers", {}) if not plugin_data["server_url"]: - raise HTTPException(status_code=400, detail="HTTP类型插件必须提供url字段") + raise HTTPException(status_code=400, detail=f"{server_type}类型插件必须提供url字段") elif server_type == "stdio": plugin_data["command"] = server_config.get("command") @@ -194,9 +221,9 @@ async def create_plugin_simple( # 更新现有插件 logger.info(f"插件 {plugin_name} 已存在,执行更新操作") - # 先卸载旧插件 - if existing.enabled: - await mcp_registry.unload_plugin(user.user_id, existing.plugin_name) + # 保存旧状态 + old_enabled = existing.enabled + old_plugin_name = existing.plugin_name # 更新字段 for key, value in plugin_data.items(): @@ -206,17 +233,24 @@ async def create_plugin_simple( await db.commit() await db.refresh(plugin) - # 如果启用,重新加载 + # 数据库完成后进行MCP操作 + if old_enabled: + try: + await mcp_client.unregister(user.user_id, old_plugin_name) + except Exception as e: + logger.warning(f"注销旧插件出错: {e}") + if plugin.enabled: - success = await mcp_registry.load_plugin(plugin) - if success: - plugin.status = "active" - plugin.last_error = None - else: + try: + success = await _register_plugin_to_facade(plugin, user.user_id) + plugin.status = "active" if success else "error" + plugin.last_error = None if success else "加载失败" + await db.commit() + except Exception as e: + logger.error(f"注册插件失败: {e}") plugin.status = "error" - plugin.last_error = "加载失败" - await db.commit() - await db.refresh(plugin) + plugin.last_error = str(e) + await db.commit() logger.info(f"用户 {user.user_id} 更新插件: {plugin_name}") else: @@ -230,16 +264,18 @@ async def create_plugin_simple( await db.commit() await db.refresh(plugin) - # 如果启用,加载到注册表 + # 数据库完成后进行MCP操作 if plugin.enabled: - success = await mcp_registry.load_plugin(plugin) - if success: - plugin.status = "active" - else: + try: + success = await _register_plugin_to_facade(plugin, user.user_id) + plugin.status = "active" if success else "error" + plugin.last_error = None if success else "加载失败" + await db.commit() + except Exception as e: + logger.error(f"注册插件失败: {e}") plugin.status = "error" - plugin.last_error = "加载失败" - await db.commit() - await db.refresh(plugin) + plugin.last_error = str(e) + await db.commit() logger.info(f"用户 {user.user_id} 通过简化配置创建插件: {plugin_name}") @@ -306,9 +342,10 @@ async def update_plugin( await db.commit() await db.refresh(plugin) - # 如果插件已启用,重新加载 + # 如果插件已启用,重新注册 if plugin.enabled: - await mcp_registry.reload_plugin(plugin) + await mcp_client.unregister(user.user_id, plugin.plugin_name) + await _register_plugin_to_facade(plugin, user.user_id) logger.info(f"用户 {user.user_id} 更新插件: {plugin.plugin_name}") return plugin @@ -334,8 +371,8 @@ async def delete_plugin( if not plugin: raise HTTPException(status_code=404, detail="插件不存在") - # 从注册表卸载 - await mcp_registry.unload_plugin(user.user_id, plugin.plugin_name) + # 从统一门面注销 + await mcp_client.unregister(user.user_id, plugin.plugin_name) # 删除数据库记录 await db.delete(plugin) @@ -366,27 +403,57 @@ async def toggle_plugin( if not plugin: raise HTTPException(status_code=404, detail="插件不存在") - plugin.enabled = enabled + # 保存插件信息用于后续MCP操作 + plugin_name = plugin.plugin_name + plugin_type = plugin.plugin_type + server_url = plugin.server_url + headers = plugin.headers + config = plugin.config - if enabled: - # 启用:加载到注册表 - success = await mcp_registry.load_plugin(plugin) - if success: - plugin.status = "active" - plugin.last_error = None - else: - plugin.status = "error" - plugin.last_error = "加载失败" - else: - # 禁用:从注册表卸载 - await mcp_registry.unload_plugin(user.user_id, plugin.plugin_name) + # 先更新数据库状态 + plugin.enabled = enabled + if not enabled: plugin.status = "inactive" await db.commit() await db.refresh(plugin) + # 数据库操作完成后,再进行MCP操作 + if enabled: + # 启用:注册到统一门面 + try: + if plugin_type in ["http", "streamable_http", "sse"] and server_url: + success = await mcp_client.register(MCPPluginConfig( + user_id=user.user_id, + plugin_name=plugin_name, + url=server_url, + plugin_type=plugin_type, + headers=headers, + timeout=config.get('timeout', 60.0) if config else 60.0 + )) + else: + success = False + + # 更新状态 + plugin.status = "active" if success else "error" + plugin.last_error = None if success else "加载失败" + await db.commit() + await db.refresh(plugin) + except Exception as e: + logger.error(f"注册插件失败: {plugin_name}, 错误: {e}") + plugin.status = "error" + plugin.last_error = str(e) + await db.commit() + await db.refresh(plugin) + else: + # 禁用:从统一门面注销(不影响数据库状态) + try: + await mcp_client.unregister(user.user_id, plugin_name) + except Exception as e: + logger.warning(f"注销插件时出错(可忽略): {plugin_name}, 错误: {e}") + action = "启用" if enabled else "禁用" - logger.info(f"用户 {user.user_id} {action}插件: {plugin.plugin_name}") + logger.info(f"用户 {user.user_id} {action}插件: {plugin_name}") return plugin @@ -399,7 +466,7 @@ async def test_plugin( """ 测试插件连接并调用工具验证功能 - 使用新的MCPTestService进行测试 + 使用MCPTestService进行测试 """ result = await db.execute( @@ -421,7 +488,7 @@ async def test_plugin( suggestions=["点击开关按钮启用插件"] ) - # 使用新的测试服务 + # 使用测试服务 try: test_result = await mcp_test_service.test_plugin_with_ai(plugin, user, db) @@ -447,32 +514,77 @@ async def test_plugin( raise HTTPException(status_code=500, detail=f"测试失败: {str(e)}") -async def _ensure_plugin_loaded( +async def _ensure_plugin_registered( plugin: MCPPlugin, user_id: str ) -> bool: """ - 确保插件已加载(共享逻辑) + 确保插件已注册到统一门面 Args: plugin: 插件对象 user_id: 用户ID Returns: - 是否加载成功 + 是否成功 Raises: - HTTPException: 加载失败 + HTTPException: 注册失败 """ - if not mcp_registry.get_client(user_id, plugin.plugin_name): - logger.info(f"插件 {plugin.plugin_name} 未加载,自动加载中...") - success = await mcp_registry.load_plugin(plugin) + try: + # 使用ensure_registered方法,它会检查是否已注册 + if plugin.plugin_type in ["http", "streamable_http", "sse"] and plugin.server_url: + return await mcp_client.ensure_registered( + user_id=user_id, + plugin_name=plugin.plugin_name, + url=plugin.server_url, + plugin_type=plugin.plugin_type, + headers=plugin.headers + ) + return False + except ValueError as e: + logger.info(f"插件 {plugin.plugin_name} 未注册,自动注册中...") + success = await _register_plugin_to_facade(plugin, user_id) if not success: raise HTTPException( status_code=500, - detail=f"插件加载失败: {plugin.plugin_name}" + detail=f"插件注册失败: {plugin.plugin_name}" ) - return True + return True + + +@router.get("/{plugin_id}/status") +async def get_plugin_status( + plugin_id: str, + user: User = Depends(require_login), + db: AsyncSession = Depends(get_db) +): + """获取插件的实时状态(包括内存中的会话状态)""" + result = await db.execute( + select(MCPPlugin).where( + MCPPlugin.id == plugin_id, + MCPPlugin.user_id == user.user_id + ) + ) + plugin = result.scalar_one_or_none() + + if not plugin: + raise HTTPException(status_code=404, detail="插件不存在") + + session_stats = mcp_client.get_session_stats() + session_key = f"{user.user_id}:{plugin.plugin_name}" + session_info = next((s for s in session_stats.get("sessions", []) if s["key"] == session_key), None) + + return { + "plugin_id": plugin_id, + "plugin_name": plugin.plugin_name, + "db_status": plugin.status, + "session_status": session_info["status"] if session_info else None, + "is_registered": session_info is not None, + "error_rate": session_info["error_rate"] if session_info else 0, + "in_sync": (plugin.status == session_info["status"]) if session_info else (plugin.status == "inactive"), + "timestamp": datetime.now().isoformat() + } @router.get("/metrics") @@ -495,7 +607,8 @@ async def get_metrics( - avg_duration_ms: 平均耗时(毫秒) - last_call_time: 最后调用时间 """ - metrics = mcp_tool_service.get_metrics(tool_name) + # 使用统一门面获取指标 + metrics = mcp_client.get_metrics(tool_name) return { "metrics": metrics, @@ -518,7 +631,8 @@ async def get_cache_stats( - cache_ttl_minutes: 缓存TTL(分钟) - entries: 各缓存条目详情 """ - stats = mcp_tool_service.get_cache_stats() + # 使用统一门面获取缓存统计 + stats = mcp_client.get_cache_stats() return { "cache_stats": stats, @@ -526,6 +640,27 @@ async def get_cache_stats( } +@router.get("/sessions/stats") +async def get_session_stats( + user: User = Depends(require_login) +): + """ + 获取MCP会话统计信息 + + Returns: + 会话统计信息,包含: + - total_sessions: 会话总数 + - sessions: 各会话详情 + """ + # 使用统一门面获取会话统计 + stats = mcp_client.get_session_stats() + + return { + "session_stats": stats, + "timestamp": datetime.now().isoformat() + } + + @router.post("/cache/clear") async def clear_cache( user_id: Optional[str] = Query(None, description="用户ID(可选)"), @@ -551,7 +686,8 @@ async def clear_cache( # 如果没有指定user_id,使用当前用户 target_user_id = user_id or user.user_id - mcp_tool_service.clear_cache(target_user_id, plugin_name) + # 使用统一门面清理缓存 + mcp_client.clear_cache(target_user_id, plugin_name) message = "已清理" if plugin_name: @@ -594,12 +730,13 @@ async def get_plugin_tools( raise HTTPException(status_code=400, detail="插件未启用") try: - # 确保插件已加载 - await _ensure_plugin_loaded(plugin, user.user_id) + # 确保插件已注册 + await _ensure_plugin_registered(plugin, user.user_id) - tools = await mcp_registry.get_plugin_tools(user.user_id, plugin.plugin_name) + # 使用统一门面获取工具列表 + tools = await mcp_client.get_tools(user.user_id, plugin.plugin_name) - # 更新缓存 + # 更新数据库中的工具缓存 plugin.tools = tools await db.commit() @@ -640,22 +777,22 @@ async def call_mcp_tool( raise HTTPException(status_code=400, detail="插件未启用") try: - # 确保插件已加载 - await _ensure_plugin_loaded(plugin, user.user_id) + # 确保插件已注册 + await _ensure_plugin_registered(plugin, user.user_id) - # 调用工具 - result = await mcp_registry.call_tool( - user.user_id, - plugin.plugin_name, - data.tool_name, - data.arguments + # 使用统一门面调用工具 + tool_result = await mcp_client.call_tool( + user_id=user.user_id, + plugin_name=plugin.plugin_name, + tool_name=data.tool_name, + arguments=data.arguments ) return { "success": True, "plugin_name": plugin.plugin_name, "tool_name": data.tool_name, - "result": result + "result": tool_result } except HTTPException: raise diff --git a/backend/app/api/outlines.py b/backend/app/api/outlines.py index dfbb415..9c91b13 100644 --- a/backend/app/api/outlines.py +++ b/backend/app/api/outlines.py @@ -36,7 +36,7 @@ from app.services.memory_service import memory_service from app.services.plot_expansion_service import PlotExpansionService from app.logger import get_logger from app.api.settings import get_user_ai_service -from app.utils.sse_response import SSEResponse, create_sse_response +from app.utils.sse_response import SSEResponse, create_sse_response, WizardProgressTracker router = APIRouter(prefix="/outlines", tags=["大纲管理"]) logger = get_logger(__name__) @@ -597,75 +597,12 @@ async def _generate_new_outline( characters = characters_result.scalars().all() characters_info = _build_characters_info(characters) - # 🔍 MCP工具增强:收集情节设计参考资料(优化版) - mcp_reference_materials = "" - if request.enable_mcp: - try: - # 1️⃣ 静默检查工具可用性(注意:新建大纲时user_id可能不可用) - from app.services.mcp_tool_service import mcp_tool_service - # 使用传入的user_id参数 - - if user_id: - available_tools = await mcp_tool_service.get_user_enabled_tools( - user_id=user_id, - db_session=db - ) - - # 2️⃣ 只在有工具时才调用 - if available_tools: - logger.info(f"🔍 检测到可用MCP工具,收集大纲设计参考资料...") - - # 构建资料收集查询 - planning_query = f"""你正在为小说《{project.title}》设计完整大纲。 -项目信息: -- 主题:{request.theme or project.theme} -- 类型:{request.genre or project.genre} -- 章节数:{request.chapter_count} -- 叙事视角:{request.narrative_perspective} -- 目标字数:{request.target_words} - -世界观设定: -- 时间背景:{project.world_time_period or '未设定'} -- 地理位置:{project.world_location or '未设定'} -- 氛围基调:{project.world_atmosphere or '未设定'} - -角色信息: -{characters_info or '暂无角色'} - -请搜索: -1. 该类型小说的经典情节结构和套路 -2. 适合该主题的冲突设计思路 -3. 符合世界观的情节元素和场景设计灵感 - -请有针对性地查询1-2个最关键的问题。""" - - # 调用MCP增强的AI(非流式,限制1轮避免超时) - planning_result = await user_ai_service.generate_text_with_mcp( - prompt=planning_query, - user_id=user_id, - db_session=db, - enable_mcp=True, - max_tool_rounds=2, - tool_choice="auto", - provider=None, - model=None - ) - - # 提取参考资料 - if planning_result.get("tool_calls_made", 0) > 0: - mcp_reference_materials = planning_result.get("content", "") - logger.info(f"✅ MCP工具收集参考资料:{len(mcp_reference_materials)} 字符") - else: - logger.info(f"ℹ️ MCP未使用工具,继续") - else: - logger.debug(f"用户 {user_id} 未启用MCP工具,跳过MCP增强") - else: - logger.debug("无用户上下文,跳过MCP增强") - except Exception as e: - logger.warning(f"⚠️ MCP工具调用失败,降级为基础模式: {str(e)}") - mcp_reference_materials = "" + # 设置用户信息以启用MCP + if user_id: + user_ai_service.user_id = user_id + user_ai_service.db_session = db - # 使用完整提示词(插入MCP参考资料,支持自定义) + # 使用提示词模板 template = await PromptService.get_template("OUTLINE_CREATE", user_id, db) prompt = PromptService.format_prompt( template, @@ -681,7 +618,7 @@ async def _generate_new_outline( rules=project.world_rules or "未设定", characters_info=characters_info or "暂无角色信息", requirements=request.requirements or "", - mcp_references=mcp_reference_materials + mcp_references="" ) # 调用AI流式生成大纲(带字数统计) @@ -691,7 +628,8 @@ async def _generate_new_outline( async for chunk in user_ai_service.generate_text_stream( prompt=prompt, provider=request.provider, - model=request.model + model=request.model, + auto_mcp=request.enable_mcp ): chunk_count += 1 accumulated_text += chunk @@ -1270,66 +1208,10 @@ async def _continue_outline( logger.warning(f"⚠️ 记忆上下文构建失败,继续不使用记忆: {str(e)}") memory_context = None - # 🔍 MCP工具增强:收集续写参考资料(优化版) - mcp_reference_materials = "" - if request.enable_mcp: - try: - # 1️⃣ 静默检查工具可用性 - from app.services.mcp_tool_service import mcp_tool_service - available_tools = await mcp_tool_service.get_user_enabled_tools( - user_id=user_id, - db_session=db - ) - - # 2️⃣ 只在有工具时才调用 - if available_tools: - logger.info(f"🔍 第{batch_num + 1}批:检测到可用MCP工具,收集续写参考资料...") - - # 构建资料收集查询 - latest_summary = latest_outlines[-1].content if latest_outlines else "" - planning_query = f"""你正在为小说《{project.title}》续写大纲。 -当前进度:已有{len(latest_outlines)}章,即将续写第{current_start_chapter}-{current_start_chapter + current_batch_size - 1}章 - -项目信息: -- 主题:{request.theme or project.theme} -- 类型:{request.genre or project.genre} -- 叙事视角:{request.narrative_perspective} -- 情节阶段:{request.plot_stage} -- 故事发展方向:{request.story_direction or '自然延续'} - -最近章节概要: -{latest_summary[:200]} - -请搜索: -1. 该情节阶段的经典处理手法和技巧 -2. 适合该发展方向的情节转折和冲突设计 -3. 符合类型特点的场景设计和剧情元素 - -请有针对性地查询1-2个最关键的问题。""" - - # 调用MCP增强的AI(非流式,限制1轮避免超时) - planning_result = await user_ai_service.generate_text_with_mcp( - prompt=planning_query, - user_id=user_id, - db_session=db, - enable_mcp=True, - max_tool_rounds=2, # ✅ 减少为1轮,避免超时 - tool_choice="auto", - provider=None, - model=None - ) - - # 提取参考资料 - if planning_result.get("tool_calls_made", 0) > 0: - mcp_reference_materials = planning_result.get("content", "") - logger.info(f"✅ 第{batch_num + 1}批MCP工具收集参考资料:{len(mcp_reference_materials)} 字符") - else: - logger.info(f"ℹ️ 第{batch_num + 1}批MCP未使用工具,继续") - else: - logger.debug(f"用户 {user_id} 未启用MCP工具,跳过第{batch_num + 1}批MCP增强") - except Exception as e: - logger.warning(f"⚠️ 第{batch_num + 1}批MCP工具调用失败,降级为基础模式: {str(e)}") - mcp_reference_materials = "" + # 设置用户信息以启用MCP + if user_id: + user_ai_service.user_id = user_id + user_ai_service.db_session = db # 使用标准续写提示词模板(支持记忆+MCP增强+自定义) template = await PromptService.get_template("OUTLINE_CONTINUE", user_id, db) @@ -1354,7 +1236,7 @@ async def _continue_outline( story_direction=request.story_direction or "自然延续", requirements=request.requirements or "", memory_context=memory_context, - mcp_references=mcp_reference_materials + mcp_references="" ) # 调用AI生成当前批次(带重试机制) @@ -1601,8 +1483,11 @@ async def new_outline_generator( ) -> AsyncGenerator[str, None]: """全新生成大纲SSE生成器(MCP增强版)""" db_committed = False + # 初始化标准进度追踪器 + tracker = WizardProgressTracker("大纲") + try: - yield await SSEResponse.send_progress("开始生成大纲...", 5) + yield await tracker.start() project_id = data.get("project_id") # 确保chapter_count是整数(前端可能传字符串) @@ -1610,16 +1495,16 @@ async def new_outline_generator( enable_mcp = data.get("enable_mcp", True) # 验证项目 - yield await SSEResponse.send_progress("加载项目信息...", 10) + yield await tracker.loading("加载项目信息...", 0.3) result = await db.execute( select(Project).where(Project.id == project_id) ) project = result.scalar_one_or_none() if not project: - yield await SSEResponse.send_error("项目不存在", 404) + yield await tracker.error("项目不存在", 404) return - yield await SSEResponse.send_progress(f"准备生成{chapter_count}章大纲...", 15) + yield await tracker.loading(f"准备生成{chapter_count}章大纲...", 0.6) # 获取角色信息 characters_result = await db.execute( @@ -1628,80 +1513,14 @@ async def new_outline_generator( characters = characters_result.scalars().all() characters_info = _build_characters_info(characters) - # 🔍 MCP工具增强:收集情节设计参考资料(优化版) - mcp_reference_materials = "" - if enable_mcp: - try: - # 1️⃣ 静默检查工具可用性 - from app.services.mcp_tool_service import mcp_tool_service - # 尝试从环境获取user_id(SSE流式场景下可能没有) - # 这里可以考虑让前端传递user_id - user_id_for_mcp = data.get("user_id") # 需要前端传递 - - if user_id_for_mcp: - available_tools = await mcp_tool_service.get_user_enabled_tools( - user_id=user_id_for_mcp, - db_session=db - ) - - # 2️⃣ 只在有工具时才显示消息和调用 - if available_tools: - yield await SSEResponse.send_progress("🔍 使用MCP工具收集参考资料...", 18) - logger.info(f"🔍 检测到可用MCP工具,收集大纲设计参考资料...") - - # 构建资料收集查询 - planning_query = f"""你正在为小说《{project.title}》设计完整大纲。 -项目信息: -- 主题:{data.get('theme') or project.theme} -- 类型:{data.get('genre') or project.genre} -- 章节数:{chapter_count} -- 叙事视角:{data.get('narrative_perspective') or '第三人称'} -- 目标字数:{data.get('target_words') or project.target_words or 100000} - -世界观设定: -- 时间背景:{project.world_time_period or '未设定'} -- 地理位置:{project.world_location or '未设定'} -- 氛围基调:{project.world_atmosphere or '未设定'} - -角色信息: -{characters_info or '暂无角色'} - -请搜索: -1. 该类型小说的经典情节结构和套路 -2. 适合该主题的冲突设计思路 -3. 符合世界观的情节元素和场景设计灵感 - -请有针对性地查询1-2个最关键的问题。""" - - # 调用MCP增强的AI(非流式,限制1轮避免超时) - planning_result = await user_ai_service.generate_text_with_mcp( - prompt=planning_query, - user_id=user_id_for_mcp, - db_session=db, - enable_mcp=True, - max_tool_rounds=2, # ✅ 减少为1轮,避免超时 - tool_choice="auto", - provider=None, - model=None - ) - - # 提取参考资料 - if planning_result.get("tool_calls_made", 0) > 0: - mcp_reference_materials = planning_result.get("content", "") - logger.info(f"✅ MCP工具收集参考资料:{len(mcp_reference_materials)} 字符") - yield await SSEResponse.send_progress(f"✅ MCP收集到参考资料 ({len(mcp_reference_materials)}字符)", 19) - else: - logger.info(f"ℹ️ MCP未使用工具,继续") - else: - logger.debug(f"用户 {user_id_for_mcp} 未启用MCP工具,跳过MCP增强") - else: - logger.debug("无用户上下文,跳过MCP增强") - except Exception as e: - logger.warning(f"⚠️ MCP工具调用失败,降级为基础模式: {str(e)}") - mcp_reference_materials = "" + # 设置用户信息以启用MCP + user_id_for_mcp = data.get("user_id") + if user_id_for_mcp: + user_ai_service.user_id = user_id_for_mcp + user_ai_service.db_session = db - # 使用完整提示词(插入MCP参考资料,支持自定义) - yield await SSEResponse.send_progress("准备AI提示词...", 20) + # 使用提示词模板 + yield await tracker.preparing("准备AI提示词...") template = await PromptService.get_template("OUTLINE_CREATE", user_id_for_mcp, db) prompt = PromptService.format_prompt( template, @@ -1717,12 +1536,9 @@ async def new_outline_generator( rules=project.world_rules or "未设定", characters_info=characters_info or "暂无角色信息", requirements=data.get("requirements") or "", - mcp_references=mcp_reference_materials + mcp_references="" ) - # 调用AI流式生成 - yield await SSEResponse.send_progress("🤖 正在调用AI生成...", 30) - # 添加调试日志 model_param = data.get("model") provider_param = data.get("provider") @@ -1731,9 +1547,12 @@ async def new_outline_generator( logger.info(f" model参数: {model_param}") # ✅ 流式生成(带字数统计和进度) + estimated_total = chapter_count * 1000 accumulated_text = "" chunk_count = 0 + yield await tracker.generating(current_chars=0, estimated_total=estimated_total) + async for chunk in user_ai_service.generate_text_stream( prompt=prompt, provider=provider_param, @@ -1743,21 +1562,20 @@ async def new_outline_generator( accumulated_text += chunk # 发送内容块 - yield await SSEResponse.send_chunk(chunk) + yield await tracker.generating_chunk(chunk) - # 定期更新进度和字数(30-95%,AI生成占65%) - if chunk_count % 5 == 0: - progress = min(30 + (chunk_count // 2), 95) - yield await SSEResponse.send_progress( - f"AI生成大纲中... ({len(accumulated_text)}字符)", - progress + # 定期更新进度 + if chunk_count % 10 == 0: + yield await tracker.generating( + current_chars=len(accumulated_text), + estimated_total=estimated_total ) # 每20个块发送心跳 if chunk_count % 20 == 0: - yield await SSEResponse.send_heartbeat() + yield await tracker.heartbeat() - yield await SSEResponse.send_progress("✅ AI生成完成,正在解析...", 96) + yield await tracker.parsing("解析大纲数据...") ai_content = accumulated_text ai_response = {"content": ai_content} @@ -1778,18 +1596,15 @@ async def new_outline_generator( if retry_count > max_retries: # 超过最大重试次数,使用fallback数据 logger.error(f"❌ 大纲解析失败,已达最大重试次数({max_retries}),使用fallback数据") - yield await SSEResponse.send_progress( - f"⚠️ 解析失败,使用备用数据", - 96.5 - ) + yield await tracker.warning("解析失败,使用备用数据") outline_data = _parse_ai_response(ai_content, raise_on_error=False) break logger.warning(f"⚠️ JSON解析失败(第{retry_count}次),正在重试...") - yield await SSEResponse.send_progress( - f"⚠️ 解析失败,正在重试({retry_count}/{max_retries})...", - 96 - ) + yield await tracker.retry(retry_count, max_retries, "JSON解析失败") + + # 重试时重置生成进度 + tracker.reset_generating_progress() # 重新调用AI生成 accumulated_text = "" @@ -1807,18 +1622,18 @@ async def new_outline_generator( accumulated_text += chunk # 发送内容块 - yield await SSEResponse.send_chunk(chunk) + yield await tracker.generating_chunk(chunk) # 每20个块发送心跳 if chunk_count % 20 == 0: - yield await SSEResponse.send_heartbeat() + yield await tracker.heartbeat() ai_content = accumulated_text ai_response = {"content": ai_content} logger.info(f"🔄 重试生成完成,累计{len(ai_content)}字符") # 全新生成模式:删除旧大纲和关联的所有章节 - yield await SSEResponse.send_progress("清理旧大纲和章节...", 97) + yield await tracker.saving("清理旧大纲和章节...", 0.2) logger.info(f"全新生成:删除项目 {project_id} 的旧大纲和章节(outline_mode: {project.outline_mode})") from sqlalchemy import delete as sql_delete @@ -1850,7 +1665,7 @@ async def new_outline_generator( logger.info(f"✅ 全新生成:删除了 {deleted_outlines_count} 个旧大纲") # 保存新大纲 - yield await SSEResponse.send_progress("💾 保存大纲到数据库...", 98) + yield await tracker.saving("保存大纲到数据库...", 0.6) outlines = await _save_outlines( project_id, outline_data, db, start_index=1 ) @@ -1870,12 +1685,12 @@ async def new_outline_generator( for outline in outlines: await db.refresh(outline) - yield await SSEResponse.send_progress("整理结果数据...", 99) - logger.info(f"全新生成完成 - {len(outlines)} 章") + yield await tracker.complete() + # 发送最终结果 - yield await SSEResponse.send_result({ + yield await tracker.result({ "message": f"成功生成{len(outlines)}章大纲", "total_chapters": len(outlines), "outlines": [ @@ -1892,8 +1707,7 @@ async def new_outline_generator( ] }) - yield await SSEResponse.send_progress("🎉 生成完成!", 100, "success") - yield await SSEResponse.send_done() + yield await tracker.done() except GeneratorExit: logger.warning("大纲生成器被提前关闭") @@ -1905,7 +1719,7 @@ async def new_outline_generator( if not db_committed and db.in_transaction(): await db.rollback() logger.info("大纲生成事务已回滚(异常)") - yield await SSEResponse.send_error(f"生成失败: {str(e)}") + yield await tracker.error(f"生成失败: {str(e)}") async def continue_outline_generator( @@ -1916,26 +1730,29 @@ async def continue_outline_generator( ) -> AsyncGenerator[str, None]: """大纲续写SSE生成器 - 分批生成,推送进度(记忆+MCP增强版)""" db_committed = False + # 初始化标准进度追踪器 + tracker = WizardProgressTracker("大纲续写") + try: - # === 初始化阶段 5-10% === - yield await SSEResponse.send_progress("开始续写大纲...", 5) + # === 初始化阶段 === + yield await tracker.start("开始续写大纲...") project_id = data.get("project_id") # 确保chapter_count是整数(前端可能传字符串) total_chapters_to_generate = int(data.get("chapter_count", 5)) # 验证项目 - yield await SSEResponse.send_progress("加载项目信息...", 6) + yield await tracker.loading("加载项目信息...", 0.2) result = await db.execute( select(Project).where(Project.id == project_id) ) project = result.scalar_one_or_none() if not project: - yield await SSEResponse.send_error("项目不存在", 404) + yield await tracker.error("项目不存在", 404) return # 获取现有大纲 - yield await SSEResponse.send_progress("分析已有大纲...", 8) + yield await tracker.loading("分析已有大纲...", 0.5) existing_result = await db.execute( select(Outline) .where(Outline.project_id == project_id) @@ -1944,15 +1761,15 @@ async def continue_outline_generator( existing_outlines = existing_result.scalars().all() if not existing_outlines: - yield await SSEResponse.send_error("续写模式需要已有大纲,当前项目没有大纲", 400) + yield await tracker.error("续写模式需要已有大纲,当前项目没有大纲", 400) return current_chapter_count = len(existing_outlines) last_chapter_number = existing_outlines[-1].order_index - yield await SSEResponse.send_progress( + yield await tracker.loading( f"当前已有{str(current_chapter_count)}章,将续写{str(total_chapters_to_generate)}章", - 10 + 0.8 ) # 获取角色信息 @@ -1979,16 +1796,15 @@ async def continue_outline_generator( confirmed_characters = data.get("confirmed_characters") confirmed_organizations = data.get("confirmed_organizations") - # === 角色引入阶段 10-20% === + # === 角色引入阶段 === # 🔧 判断:如果confirmed_organizations存在,说明已经是组织确认阶段,跳过角色处理 if enable_auto_characters and not confirmed_organizations: # 检查是否有用户确认的角色列表 if confirmed_characters: # 直接使用用户确认的角色列表创建角色 try: - yield await SSEResponse.send_progress( - f"🎭 【确认模式】创建 {len(confirmed_characters)} 个用户确认的角色...", - 11 + yield await tracker.preparing( + f"🎭 【确认模式】创建 {len(confirmed_characters)} 个用户确认的角色..." ) from app.services.auto_character_service import get_auto_character_service @@ -2010,16 +1826,14 @@ async def continue_outline_generator( char_name = char_data.get("name") or char_data.get("character_name") if char_name in existing_character_names: logger.warning(f"⚠️ 角色 '{char_name}' 已存在,跳过创建") - yield await SSEResponse.send_progress( - f"⏭️ [{idx+1}/{len(confirmed_characters)}] 角色 '{char_name}' 已存在,跳过", - char_progress + yield await tracker.preparing( + f"⏭️ [{idx+1}/{len(confirmed_characters)}] 角色 '{char_name}' 已存在,跳过" ) continue # 生成角色详细信息 - yield await SSEResponse.send_progress( - f"🤖 [{idx+1}/{len(confirmed_characters)}] AI生成角色详情:{char_name}...", - char_progress + yield await tracker.preparing( + f"🤖 [{idx+1}/{len(confirmed_characters)}] AI生成角色详情:{char_name}..." ) character_data = await auto_char_service._generate_character_details( spec=char_data, @@ -2031,9 +1845,8 @@ async def continue_outline_generator( ) # 创建角色记录 - yield await SSEResponse.send_progress( - f"💾 [{idx+1}/{len(confirmed_characters)}] 保存角色:{char_name}...", - char_progress + 1 + yield await tracker.preparing( + f"💾 [{idx+1}/{len(confirmed_characters)}] 保存角色:{char_name}..." ) character = await auto_char_service._create_character_record( project_id=project_id, @@ -2044,9 +1857,8 @@ async def continue_outline_generator( # 建立关系 relationships_data = character_data.get("relationships") or character_data.get("relationships_array", []) if relationships_data: - yield await SSEResponse.send_progress( - f"🔗 [{idx+1}/{len(confirmed_characters)}] 建立 {len(relationships_data)} 个关系:{char_name}...", - char_progress + 2 + yield await tracker.preparing( + f"🔗 [{idx+1}/{len(confirmed_characters)}] 建立 {len(relationships_data)} 个关系:{char_name}..." ) await auto_char_service._create_relationships( new_character=character, @@ -2060,40 +1872,33 @@ async def continue_outline_generator( existing_character_names.add(character.name) # 更新已存在的角色名称集合 actually_created_count += 1 logger.info(f"✅ 创建确认的角色: {character.name}") - yield await SSEResponse.send_progress( - f"✅ [{idx+1}/{len(confirmed_characters)}] 角色创建成功:{character.name}", - char_progress + 3 + yield await tracker.preparing( + f"✅ [{idx+1}/{len(confirmed_characters)}] 角色创建成功:{character.name}" ) except Exception as e: logger.error(f"创建确认的角色失败: {e}", exc_info=True) - yield await SSEResponse.send_progress( - f"❌ [{idx+1}/{len(confirmed_characters)}] 角色创建失败:{char_name}", - char_progress + 3 + yield await tracker.warning( + f"[{idx+1}/{len(confirmed_characters)}] 角色创建失败:{char_name}" ) continue # 提交角色到数据库 if actually_created_count > 0: await db.commit() - yield await SSEResponse.send_progress( - f"✅ 【确认模式】实际创建了 {actually_created_count} 个新角色(跳过 {len(confirmed_characters) - actually_created_count} 个已存在)", - 20 + yield await tracker.preparing( + f"✅ 【确认模式】实际创建了 {actually_created_count} 个新角色(跳过 {len(confirmed_characters) - actually_created_count} 个已存在)" ) logger.info(f"✅ 【确认模式】实际创建了 {actually_created_count} 个新角色(跳过了 {len(confirmed_characters) - actually_created_count} 个已存在的角色)") else: - yield await SSEResponse.send_progress( - f"ℹ️ 【确认模式】所有角色均已存在,无需创建", - 20 + yield await tracker.preparing( + f"ℹ️ 【确认模式】所有角色均已存在,无需创建" ) logger.info(f"ℹ️ 【确认模式】所有角色均已存在,无需创建") except Exception as e: logger.error(f"⚠️ 【确认模式】创建确认角色失败: {e}", exc_info=True) - yield await SSEResponse.send_progress( - f"⚠️ 角色创建失败,继续生成大纲", - 20 - ) + yield await tracker.warning("角色创建失败,继续生成大纲") else: # 根据 require_character_confirmation 决定处理方式 require_confirmation = data.get("require_character_confirmation", True) @@ -2108,10 +1913,7 @@ async def continue_outline_generator( if require_confirmation: # 🔮 预测模式:仅预测角色,不自动创建,需要用户确认 - yield await SSEResponse.send_progress( - "🔮 【预测模式】开始分析角色需求...", - 12 - ) + yield await tracker.preparing("🔮 【预测模式】开始分析角色需求...") logger.info(f"🔮 【预测模式】在生成大纲前预测是否需要新角色") # 进度消息不使用回调,因为在async generator中无法嵌套yield @@ -2130,10 +1932,7 @@ async def continue_outline_generator( preview_only=True # ✅ 仅预测不创建 ) - yield await SSEResponse.send_progress( - "✅ 【预测模式】角色需求分析完成", - 18 - ) + yield await tracker.preparing("✅ 【预测模式】角色需求分析完成") # 检查是否需要新角色 if auto_result.get("needs_new_characters") and auto_result.get("predicted_characters"): @@ -2154,17 +1953,11 @@ async def continue_outline_generator( ) return else: - yield await SSEResponse.send_progress( - "✅ 【预测模式】无需引入新角色,继续生成大纲", - 20 - ) + yield await tracker.preparing("✅ 【预测模式】无需引入新角色,继续生成大纲") logger.info(f"✅ 【预测模式】AI判断无需引入新角色") else: # 🚀 直接创建模式:预测后自动创建,无需用户确认 - yield await SSEResponse.send_progress( - "🚀 【直接创建模式】开始分析并创建角色...", - 14 - ) + yield await tracker.preparing("🚀 【直接创建模式】开始分析并创建角色...") logger.info(f"🚀 【直接创建模式】在生成大纲前预测并直接创建新角色") # 使用队列桥接回调和generator @@ -2198,27 +1991,22 @@ async def continue_outline_generator( while not char_task.done(): try: message = await asyncio.wait_for(progress_queue.get(), timeout=0.1) - char_progress_base = min(char_progress_base + 1, 17) - yield await SSEResponse.send_progress(message, char_progress_base) + yield await tracker.preparing(message) except asyncio.TimeoutError: pass # 获取结果 auto_result = await char_task - yield await SSEResponse.send_progress( - "✅ 【直接创建模式】角色分析和创建完成", - 18 - ) + yield await tracker.preparing("✅ 【直接创建模式】角色分析和创建完成") # 如果创建了新角色,更新角色列表 if auto_result.get("new_characters"): new_count = len(auto_result["new_characters"]) logger.info(f"✅ 【直接创建模式】自动创建了 {new_count} 个新角色") - yield await SSEResponse.send_progress( - f"✅ 【直接创建模式】自动创建了 {new_count} 个新角色", - 18 + yield await tracker.preparing( + f"✅ 【直接创建模式】自动创建了 {new_count} 个新角色" ) # 提交角色到数据库 @@ -2228,21 +2016,15 @@ async def continue_outline_generator( characters.extend(auto_result["new_characters"]) characters_info = _build_characters_info(characters) else: - yield await SSEResponse.send_progress( - "✅ 【直接创建模式】无需引入新角色,继续生成大纲", - 20 - ) + yield await tracker.preparing("✅ 【直接创建模式】无需引入新角色,继续生成大纲") logger.info(f"✅ 【直接创建模式】AI判断无需引入新角色") except Exception as e: logger.error(f"⚠️ 【方案A】预测性角色引入失败: {e}", exc_info=True) - yield await SSEResponse.send_progress( - f"⚠️ 角色预测失败,继续生成大纲", - 20 - ) + yield await tracker.warning("角色预测失败,继续生成大纲") # 不阻断大纲生成流程 - # === 组织引入阶段 20-30% === + # === 组织引入阶段 === # 🏛️ 【组织引入】在生成大纲前预测并创建组织 enable_auto_organizations = data.get("enable_auto_organizations", True) # confirmed_organizations在上面已经获取了,这里注释掉避免重复 @@ -2256,9 +2038,8 @@ async def continue_outline_generator( if confirmed_organizations: # 直接使用用户确认的组织列表创建组织 try: - yield await SSEResponse.send_progress( - f"🏛️ 【确认模式】创建 {len(confirmed_organizations)} 个用户确认的组织...", - 20 + yield await tracker.preparing( + f"🏛️ 【确认模式】创建 {len(confirmed_organizations)} 个用户确认的组织..." ) from app.services.auto_organization_service import get_auto_organization_service @@ -2275,9 +2056,8 @@ async def continue_outline_generator( org_progress = 21 + int((idx / max(len(confirmed_organizations), 1)) * 8) # 生成组织详细信息 - yield await SSEResponse.send_progress( - f"🤖 [{idx+1}/{len(confirmed_organizations)}] AI生成组织详情:{org_name}...", - org_progress + yield await tracker.preparing( + f"🤖 [{idx+1}/{len(confirmed_organizations)}] AI生成组织详情:{org_name}..." ) organization_data = await auto_org_service._generate_organization_details( spec=org_data, @@ -2290,9 +2070,8 @@ async def continue_outline_generator( ) # 创建组织记录 - yield await SSEResponse.send_progress( - f"💾 [{idx+1}/{len(confirmed_organizations)}] 保存组织:{org_name}...", - org_progress + 0.5 + yield await tracker.preparing( + f"💾 [{idx+1}/{len(confirmed_organizations)}] 保存组织:{org_name}..." ) org_character, organization = await auto_org_service._create_organization_record( project_id=project_id, @@ -2303,9 +2082,8 @@ async def continue_outline_generator( # 建立成员关系 members_data = organization_data.get("initial_members", []) if members_data: - yield await SSEResponse.send_progress( - f"🔗 [{idx+1}/{len(confirmed_organizations)}] 建立 {len(members_data)} 个成员关系:{org_name}...", - org_progress + 1 + yield await tracker.preparing( + f"🔗 [{idx+1}/{len(confirmed_organizations)}] 建立 {len(members_data)} 个成员关系:{org_name}..." ) await auto_org_service._create_member_relationships( organization=organization, @@ -2328,34 +2106,28 @@ async def continue_outline_generator( }) created_org_count += 1 logger.info(f"✅ 创建确认的组织: {org_character.name}") - yield await SSEResponse.send_progress( - f"✅ [{idx+1}/{len(confirmed_organizations)}] 组织创建成功:{org_character.name}", - org_progress + 1.5 + yield await tracker.preparing( + f"✅ [{idx+1}/{len(confirmed_organizations)}] 组织创建成功:{org_character.name}" ) except Exception as e: logger.error(f"创建确认的组织失败: {e}", exc_info=True) - yield await SSEResponse.send_progress( - f"❌ [{idx+1}/{len(confirmed_organizations)}] 组织创建失败:{org_name}", - org_progress + 1.5 + yield await tracker.warning( + f"[{idx+1}/{len(confirmed_organizations)}] 组织创建失败:{org_name}" ) continue # 提交组织到数据库 await db.commit() - yield await SSEResponse.send_progress( - f"✅ 【确认模式】成功创建 {created_org_count} 个组织", - 30 + yield await tracker.preparing( + f"✅ 【确认模式】成功创建 {created_org_count} 个组织" ) logger.info(f"✅ 【确认模式】成功创建 {created_org_count} 个用户确认的组织") except Exception as e: logger.error(f"⚠️ 【确认模式】创建确认组织失败: {e}", exc_info=True) - yield await SSEResponse.send_progress( - f"⚠️ 组织创建失败,继续生成大纲", - 30 - ) + yield await tracker.warning("组织创建失败,继续生成大纲") else: # 根据 require_organization_confirmation 决定处理方式 require_org_confirmation = data.get("require_organization_confirmation", True) @@ -2370,10 +2142,7 @@ async def continue_outline_generator( if require_org_confirmation: # 🔮 预测模式:仅预测组织,不自动创建,需要用户确认 - yield await SSEResponse.send_progress( - "🔮 【预测模式】开始分析组织需求...", - 22 - ) + yield await tracker.preparing("🔮 【预测模式】开始分析组织需求...") logger.info(f"🔮 【预测模式】在生成大纲前预测是否需要新组织") auto_result = await auto_org_service.analyze_and_create_organizations( @@ -2392,10 +2161,7 @@ async def continue_outline_generator( preview_only=True # ✅ 仅预测不创建 ) - yield await SSEResponse.send_progress( - "✅ 【预测模式】组织需求分析完成", - 28 - ) + yield await tracker.preparing("✅ 【预测模式】组织需求分析完成") # 检查是否需要新组织 if auto_result.get("needs_new_organizations") and auto_result.get("predicted_organizations"): @@ -2416,17 +2182,11 @@ async def continue_outline_generator( ) return else: - yield await SSEResponse.send_progress( - "✅ 【预测模式】无需引入新组织,继续生成大纲", - 30 - ) + yield await tracker.preparing("✅ 【预测模式】无需引入新组织,继续生成大纲") logger.info(f"✅ 【预测模式】AI判断无需引入新组织") else: # 🚀 直接创建模式:预测后自动创建,无需用户确认 - yield await SSEResponse.send_progress( - "🚀 【直接创建模式】开始分析并创建组织...", - 24 - ) + yield await tracker.preparing("🚀 【直接创建模式】开始分析并创建组织...") logger.info(f"🚀 【直接创建模式】在生成大纲前预测并直接创建新组织") # 使用队列桥接回调和generator @@ -2461,18 +2221,14 @@ async def continue_outline_generator( while not org_task.done(): try: message = await asyncio.wait_for(org_progress_queue.get(), timeout=0.1) - org_progress_base = min(org_progress_base + 1, 27) - yield await SSEResponse.send_progress(message, org_progress_base) + yield await tracker.preparing(message) except asyncio.TimeoutError: pass # 获取结果 auto_result = await org_task - yield await SSEResponse.send_progress( - "✅ 【直接创建模式】组织分析和创建完成", - 28 - ) + yield await tracker.preparing("✅ 【直接创建模式】组织分析和创建完成") # 如果创建了新组织,更新角色列表 if auto_result.get("new_organizations"): @@ -2484,9 +2240,8 @@ async def continue_outline_generator( new_org_names.append(org_char.name) logger.info(f"✅ 【直接创建模式】自动创建了 {new_count} 个新组织") - yield await SSEResponse.send_progress( - f"✅ 【直接创建模式】成功创建 {new_count} 个新组织:{', '.join(new_org_names[:3])}{'...' if new_count > 3 else ''}", - 30 + yield await tracker.preparing( + f"✅ 【直接创建模式】成功创建 {new_count} 个新组织:{', '.join(new_org_names[:3])}{'...' if new_count > 3 else ''}" ) # 提交组织到数据库 @@ -2499,35 +2254,33 @@ async def continue_outline_generator( characters.append(org_char) characters_info = _build_characters_info(characters) else: - yield await SSEResponse.send_progress( - "✅ 【直接创建模式】无需引入新组织,继续生成大纲", - 30 - ) + yield await tracker.preparing("✅ 【直接创建模式】无需引入新组织,继续生成大纲") logger.info(f"✅ 【直接创建模式】AI判断无需引入新组织") except Exception as e: logger.error(f"⚠️ 【组织引入】预测性组织引入失败: {e}", exc_info=True) - yield await SSEResponse.send_progress( - f"⚠️ 组织预测失败,继续生成大纲", - 30 - ) + yield await tracker.warning("组织预测失败,继续生成大纲") # 不阻断大纲生成流程 - # === 批次生成阶段 30-90% === + # === 批次生成阶段 === all_new_outlines = [] current_start_chapter = last_chapter_number + 1 - + for batch_num in range(total_batches): # 计算当前批次的章节数 remaining_chapters = int(total_chapters_to_generate) - len(all_new_outlines) current_batch_size = min(batch_size, remaining_chapters) - # 批次进度:30-90%,每批平均分配 - batch_progress = 30 + (batch_num * 60 // total_batches) + # 每批使用的进度预估 + estimated_chars_per_batch = current_batch_size * 1000 - yield await SSEResponse.send_progress( - f"📝 第{str(batch_num + 1)}/{str(total_batches)}批: 生成第{str(current_start_chapter)}-{str(current_start_chapter + current_batch_size - 1)}章", - batch_progress + # 重置生成进度以便于每批独立计算 + tracker.reset_generating_progress() + + yield await tracker.generating( + current_chars=0, + estimated_total=estimated_chars_per_batch, + message=f"📝 第{str(batch_num + 1)}/{str(total_batches)}批: 生成第{str(current_start_chapter)}-{str(current_start_chapter + current_batch_size - 1)}章" ) # 获取最新的大纲列表(包括之前批次生成的) @@ -2564,9 +2317,10 @@ async def continue_outline_generator( # 🧠 构建记忆增强上下文 memory_context = None try: - yield await SSEResponse.send_progress( - f"🧠 构建记忆上下文...", - batch_progress + 3 + yield await tracker.generating( + current_chars=0, + estimated_total=estimated_chars_per_batch, + message="🧠 构建记忆上下文..." ) query_outline = latest_outlines[-1].content if latest_outlines else "" memory_context = await memory_service.build_context_for_generation( @@ -2580,80 +2334,15 @@ async def continue_outline_generator( except Exception as e: logger.warning(f"⚠️ 记忆上下文构建失败: {str(e)}") memory_context = None - # 🔍 MCP工具增强:收集续写参考资料(优化版) - mcp_reference_materials = "" - enable_mcp = data.get("enable_mcp", True) - if enable_mcp: - try: - # 1️⃣ 静默检查工具可用性 - from app.services.mcp_tool_service import mcp_tool_service - available_tools = await mcp_tool_service.get_user_enabled_tools( - user_id=user_id, - db_session=db - ) - - # 2️⃣ 只在有工具时才显示消息和调用 - if available_tools: - yield await SSEResponse.send_progress( - f"🔍 第{str(batch_num + 1)}批:使用MCP工具收集参考资料...", - batch_progress + 4 - ) - logger.info(f"🔍 第{batch_num + 1}批:检测到可用MCP工具,收集续写参考资料...") - - # 构建资料收集查询 - latest_summary = latest_outlines[-1].content if latest_outlines else "" - planning_query = f"""你正在为小说《{project.title}》续写大纲。 -当前进度:已有{len(latest_outlines)}章,即将续写第{current_start_chapter}-{current_start_chapter + current_batch_size - 1}章 - -项目信息: -- 主题:{data.get('theme') or project.theme} -- 类型:{data.get('genre') or project.genre} -- 叙事视角:{data.get('narrative_perspective') or project.narrative_perspective or '第三人称'} -- 情节阶段:{data.get('plot_stage', 'development')} -- 故事发展方向:{data.get('story_direction', '自然延续')} - -最近章节概要: -{latest_summary[:200]} - -请搜索: -1. 该情节阶段的经典处理手法和技巧 -2. 适合该发展方向的情节转折和冲突设计 -3. 符合类型特点的场景设计和剧情元素 - -请有针对性地查询1-2个最关键的问题。""" - - # 调用MCP增强的AI(非流式,限制1轮避免超时) - planning_result = await user_ai_service.generate_text_with_mcp( - prompt=planning_query, - user_id=user_id, - db_session=db, - enable_mcp=True, - max_tool_rounds=2, # ✅ 减少为1轮,避免超时 - tool_choice="auto", - provider=None, - model=None - ) - - # 提取参考资料 - if planning_result.get("tool_calls_made", 0) > 0: - mcp_reference_materials = planning_result.get("content", "") - logger.info(f"✅ 第{batch_num + 1}批MCP工具收集参考资料:{len(mcp_reference_materials)} 字符") - yield await SSEResponse.send_progress( - f"✅ 第{str(batch_num + 1)}批收集到参考资料 ({len(mcp_reference_materials)}字符)", - batch_progress + 4.5 - ) - else: - logger.info(f"ℹ️ 第{batch_num + 1}批MCP未使用工具,继续") - else: - logger.debug(f"用户 {user_id} 未启用MCP工具,跳过第{batch_num + 1}批MCP增强") - except Exception as e: - logger.warning(f"⚠️ 第{batch_num + 1}批MCP工具调用失败,降级为基础模式: {str(e)}") - mcp_reference_materials = "" + # 设置用户信息以启用MCP + if user_id: + user_ai_service.user_id = user_id + user_ai_service.db_session = db - - yield await SSEResponse.send_progress( - f" 调用AI生成第{str(batch_num + 1)}批...", - batch_progress + 5 + yield await tracker.generating( + current_chars=0, + estimated_total=estimated_chars_per_batch, + message=f"🤖 调用AI生成第{str(batch_num + 1)}批..." ) # 使用标准续写提示词模板(支持记忆+MCP增强+自定义) @@ -2679,7 +2368,7 @@ async def continue_outline_generator( story_direction=data.get("story_direction", "自然延续"), requirements=data.get("requirements", ""), memory_context=memory_context, - mcp_references=mcp_reference_materials + mcp_references="" ) # 调用AI生成当前批次 @@ -2702,26 +2391,21 @@ async def continue_outline_generator( accumulated_text += chunk # 发送内容块 - yield await SSEResponse.send_chunk(chunk) + yield await tracker.generating_chunk(chunk) - # 定期更新进度(每批在分配范围内平滑递增) - if chunk_count % 5 == 0: - # 在批次范围内平滑递增 - batch_range = 60 // total_batches # 每批分配的进度范围 - progress_in_batch = batch_progress + min((chunk_count // 3), batch_range - 2) - yield await SSEResponse.send_progress( - f"📝 第{str(batch_num + 1)}/{str(total_batches)}批生成中... ({len(accumulated_text)}字符)", - progress_in_batch + # 定期更新进度 + if chunk_count % 10 == 0: + yield await tracker.generating( + current_chars=len(accumulated_text), + estimated_total=estimated_chars_per_batch, + message=f"📝 第{str(batch_num + 1)}/{str(total_batches)}批生成中" ) # 每20个块发送心跳 if chunk_count % 20 == 0: - yield await SSEResponse.send_heartbeat() + yield await tracker.heartbeat() - yield await SSEResponse.send_progress( - f"✅ 第{str(batch_num + 1)}批AI生成完成,正在解析...", - min(batch_progress + batch_range - 5, 88) - ) + yield await tracker.parsing(f"✅ 第{str(batch_num + 1)}批AI生成完成,正在解析...") # 提取内容 ai_content = accumulated_text @@ -2743,18 +2427,15 @@ async def continue_outline_generator( if retry_count > max_retries: # 超过最大重试次数,使用fallback数据 logger.error(f"❌ 第{batch_num + 1}批解析失败,已达最大重试次数({max_retries}),使用fallback数据") - yield await SSEResponse.send_progress( - f"⚠️ 第{str(batch_num + 1)}批解析失败,使用备用数据", - min(batch_progress + batch_range - 3, 89) - ) + yield await tracker.warning(f"第{str(batch_num + 1)}批解析失败,使用备用数据") outline_data = _parse_ai_response(ai_content, raise_on_error=False) break logger.warning(f"⚠️ 第{batch_num + 1}批JSON解析失败(第{retry_count}次),正在重试...") - yield await SSEResponse.send_progress( - f"⚠️ 第{str(batch_num + 1)}批解析失败,正在重试({retry_count}/{max_retries})...", - min(batch_progress + batch_range - 4, 88) - ) + yield await tracker.retry(retry_count, max_retries, f"第{str(batch_num + 1)}批解析失败") + + # 重试时重置生成进度 + tracker.reset_generating_progress() # 重新调用AI生成 accumulated_text = "" @@ -2772,11 +2453,11 @@ async def continue_outline_generator( accumulated_text += chunk # 发送内容块 - yield await SSEResponse.send_chunk(chunk) + yield await tracker.generating_chunk(chunk) # 每20个块发送心跳 if chunk_count % 20 == 0: - yield await SSEResponse.send_heartbeat() + yield await tracker.heartbeat() ai_content = accumulated_text ai_response = {"content": ai_content} @@ -2805,9 +2486,9 @@ async def continue_outline_generator( all_new_outlines.extend(batch_outlines) current_start_chapter += current_batch_size - yield await SSEResponse.send_progress( + yield await tracker.saving( f"💾 第{str(batch_num + 1)}批保存成功!本批生成{str(len(batch_outlines))}章,累计新增{str(len(all_new_outlines))}章", - min(batch_progress + batch_range, 90) + (batch_num + 1) / total_batches ) logger.info(f"第{str(batch_num + 1)}批生成完成,本批生成{str(len(batch_outlines))}章") @@ -2822,11 +2503,10 @@ async def continue_outline_generator( ) all_outlines = final_result.scalars().all() - # === 结果整理阶段 90-100% === - yield await SSEResponse.send_progress("整理结果数据...", 92) + yield await tracker.complete() # 发送最终结果 - yield await SSEResponse.send_result({ + yield await tracker.result({ "message": f"续写完成!共{str(total_batches)}批,新增{str(len(all_new_outlines))}章,总计{str(len(all_outlines))}章", "total_batches": total_batches, "new_chapters": len(all_new_outlines), @@ -2845,8 +2525,7 @@ async def continue_outline_generator( ] }) - yield await SSEResponse.send_progress("🎉 续写完成!", 100, "success") - yield await SSEResponse.send_done() + yield await tracker.done() except GeneratorExit: logger.warning("大纲续写生成器被提前关闭") @@ -2858,7 +2537,7 @@ async def continue_outline_generator( if not db_committed and db.in_transaction(): await db.rollback() logger.info("大纲续写事务已回滚(异常)") - yield await SSEResponse.send_error(f"续写失败: {str(e)}") + yield await tracker.error(f"续写失败: {str(e)}") @router.post("/generate-stream", summary="AI生成/续写大纲(SSE流式)") @@ -2938,8 +2617,11 @@ async def expand_outline_generator( ) -> AsyncGenerator[str, None]: """单个大纲展开SSE生成器 - 实时推送进度(支持分批生成)""" db_committed = False + # 初始化标准进度追踪器 + tracker = WizardProgressTracker("大纲展开") + try: - yield await SSEResponse.send_progress("开始展开大纲...", 5) + yield await tracker.start() target_chapter_count = int(data.get("target_chapter_count", 3)) expansion_strategy = data.get("expansion_strategy", "balanced") @@ -2948,50 +2630,46 @@ async def expand_outline_generator( batch_size = int(data.get("batch_size", 5)) # 支持自定义批次大小 # 获取大纲 - yield await SSEResponse.send_progress("加载大纲信息...", 10) + yield await tracker.loading("加载大纲信息...", 0.3) result = await db.execute( select(Outline).where(Outline.id == outline_id) ) outline = result.scalar_one_or_none() if not outline: - yield await SSEResponse.send_error("大纲不存在", 404) + yield await tracker.error("大纲不存在", 404) return # 获取项目信息 - yield await SSEResponse.send_progress("加载项目信息...", 15) + yield await tracker.loading("加载项目信息...", 0.7) project_result = await db.execute( select(Project).where(Project.id == outline.project_id) ) project = project_result.scalar_one_or_none() if not project: - yield await SSEResponse.send_error("项目不存在", 404) + yield await tracker.error("项目不存在", 404) return - yield await SSEResponse.send_progress( - f"准备展开《{outline.title}》为 {target_chapter_count} 章...", - 20 + yield await tracker.preparing( + f"准备展开《{outline.title}》为 {target_chapter_count} 章..." ) # 创建展开服务实例 expansion_service = PlotExpansionService(user_ai_service) - # 定义进度回调函数 - async def progress_callback(batch_num: int, total_batches: int, start_idx: int, batch_size: int): - progress = 30 + int((batch_num - 1) / total_batches * 40) - yield await SSEResponse.send_progress( - f"📝 生成第{batch_num}/{total_batches}批(第{start_idx}-{start_idx + batch_size - 1}节)...", - progress - ) - # 分析大纲并生成章节规划(支持分批) if target_chapter_count > batch_size: - yield await SSEResponse.send_progress( - f"🤖 AI分批生成章节规划(每批{batch_size}章)...", - 30 + yield await tracker.generating( + current_chars=0, + estimated_total=target_chapter_count * 500, + message=f"🤖 AI分批生成章节规划(每批{batch_size}章)..." ) else: - yield await SSEResponse.send_progress("🤖 AI分析大纲,生成章节规划...", 30) + yield await tracker.generating( + current_chars=0, + estimated_total=target_chapter_count * 500, + message="🤖 AI分析大纲,生成章节规划..." + ) chapter_plans = await expansion_service.analyze_outline_for_chapters( outline=outline, @@ -3007,18 +2685,17 @@ async def expand_outline_generator( ) if not chapter_plans: - yield await SSEResponse.send_error("AI分析失败,未能生成章节规划", 500) + yield await tracker.error("AI分析失败,未能生成章节规划", 500) return - yield await SSEResponse.send_progress( - f"✅ 规划生成完成!共 {len(chapter_plans)} 个章节", - 70 + yield await tracker.parsing( + f"✅ 规划生成完成!共 {len(chapter_plans)} 个章节" ) # 根据配置决定是否创建章节记录 created_chapters = None if auto_create_chapters: - yield await SSEResponse.send_progress("💾 创建章节记录...", 80) + yield await tracker.saving("💾 创建章节记录...", 0.3) created_chapters = await expansion_service.create_chapters_from_plans( outline_id=outline_id, @@ -3035,12 +2712,12 @@ async def expand_outline_generator( for chapter in created_chapters: await db.refresh(chapter) - yield await SSEResponse.send_progress( + yield await tracker.saving( f"✅ 成功创建 {len(created_chapters)} 个章节记录", - 90 + 0.8 ) - yield await SSEResponse.send_progress("整理结果数据...", 95) + yield await tracker.complete() # 构建响应数据 result_data = { @@ -3064,9 +2741,8 @@ async def expand_outline_generator( ] if created_chapters else None } - yield await SSEResponse.send_result(result_data) - yield await SSEResponse.send_progress("🎉 展开完成!", 100, "success") - yield await SSEResponse.send_done() + yield await tracker.result(result_data) + yield await tracker.done() except GeneratorExit: logger.warning("大纲展开生成器被提前关闭") @@ -3078,7 +2754,7 @@ async def expand_outline_generator( if not db_committed and db.in_transaction(): await db.rollback() logger.info("大纲展开事务已回滚(异常)") - yield await SSEResponse.send_error(f"展开失败: {str(e)}") + yield await tracker.error(f"展开失败: {str(e)}") @router.post("/{outline_id}/create-single-chapter", summary="一对一创建章节(传统模式)") @@ -3313,8 +2989,11 @@ async def batch_expand_outlines_generator( ) -> AsyncGenerator[str, None]: """批量展开大纲SSE生成器 - 实时推送进度""" db_committed = False + # 初始化标准进度追踪器 + tracker = WizardProgressTracker("批量大纲展开") + try: - yield await SSEResponse.send_progress("开始批量展开大纲...", 5) + yield await tracker.start() project_id = data.get("project_id") chapters_per_outline = int(data.get("chapters_per_outline", 3)) @@ -3323,17 +3002,17 @@ async def batch_expand_outlines_generator( outline_ids = data.get("outline_ids") # 获取项目信息 - yield await SSEResponse.send_progress("加载项目信息...", 10) + yield await tracker.loading("加载项目信息...", 0.5) project_result = await db.execute( select(Project).where(Project.id == project_id) ) project = project_result.scalar_one_or_none() if not project: - yield await SSEResponse.send_error("项目不存在", 404) + yield await tracker.error("项目不存在", 404) return # 获取要展开的大纲列表 - yield await SSEResponse.send_progress("获取大纲列表...", 15) + yield await tracker.loading("获取大纲列表...", 0.8) if outline_ids: outlines_result = await db.execute( select(Outline) @@ -3353,13 +3032,12 @@ async def batch_expand_outlines_generator( outlines = outlines_result.scalars().all() if not outlines: - yield await SSEResponse.send_error("没有找到要展开的大纲", 404) + yield await tracker.error("没有找到要展开的大纲", 404) return total_outlines = len(outlines) - yield await SSEResponse.send_progress( - f"共找到 {total_outlines} 个大纲,开始批量展开...", - 20 + yield await tracker.preparing( + f"共找到 {total_outlines} 个大纲,开始批量展开..." ) # 创建展开服务实例 @@ -3371,12 +3049,13 @@ async def batch_expand_outlines_generator( for idx, outline in enumerate(outlines): try: - # 计算当前进度 (20% - 90%) - progress = 20 + int((idx / total_outlines) * 70) + # 计算当前子进度 (0.0-1.0),用于generating阶段 + sub_progress = idx / max(total_outlines, 1) - yield await SSEResponse.send_progress( - f"📝 处理第 {idx + 1}/{total_outlines} 个大纲: {outline.title}", - progress + yield await tracker.generating( + current_chars=idx * chapters_per_outline * 500, + estimated_total=total_outlines * chapters_per_outline * 500, + message=f"📝 处理第 {idx + 1}/{total_outlines} 个大纲: {outline.title}" ) # 检查大纲是否已经展开过 @@ -3394,16 +3073,18 @@ async def batch_expand_outlines_generator( "outline_title": outline.title, "reason": "已展开" }) - yield await SSEResponse.send_progress( - f"⏭️ {outline.title} 已展开过,跳过", - progress + 1 + yield await tracker.generating( + current_chars=(idx + 1) * chapters_per_outline * 500, + estimated_total=total_outlines * chapters_per_outline * 500, + message=f"⏭️ {outline.title} 已展开过,跳过" ) continue # 分析大纲生成章节规划 - yield await SSEResponse.send_progress( - f"🤖 AI分析大纲: {outline.title}", - progress + 2 + yield await tracker.generating( + current_chars=idx * chapters_per_outline * 500, + estimated_total=total_outlines * chapters_per_outline * 500, + message=f"🤖 AI分析大纲: {outline.title}" ) chapter_plans = await expansion_service.analyze_outline_for_chapters( @@ -3417,9 +3098,10 @@ async def batch_expand_outlines_generator( model=data.get("model") ) - yield await SSEResponse.send_progress( - f"✅ {outline.title} 规划生成完成 ({len(chapter_plans)} 章)", - progress + 3 + yield await tracker.generating( + current_chars=(idx + 0.5) * chapters_per_outline * 500, + estimated_total=total_outlines * chapters_per_outline * 500, + message=f"✅ {outline.title} 规划生成完成 ({len(chapter_plans)} 章)" ) created_chapters = None @@ -3446,9 +3128,10 @@ async def batch_expand_outlines_generator( ] total_chapters_created += len(chapters) - yield await SSEResponse.send_progress( - f"💾 {outline.title} 章节创建完成 ({len(chapters)} 章)", - progress + 4 + yield await tracker.generating( + current_chars=(idx + 1) * chapters_per_outline * 500, + estimated_total=total_outlines * chapters_per_outline * 500, + message=f"💾 {outline.title} 章节创建完成 ({len(chapters)} 章)" ) expansion_results.append({ @@ -3465,9 +3148,8 @@ async def batch_expand_outlines_generator( except Exception as e: logger.error(f"展开大纲 {outline.id} 失败: {str(e)}", exc_info=True) - yield await SSEResponse.send_progress( - f"❌ {outline.title} 展开失败: {str(e)}", - progress + yield await tracker.warning( + f"❌ {outline.title} 展开失败: {str(e)}" ) expansion_results.append({ "outline_id": outline.id, @@ -3480,12 +3162,14 @@ async def batch_expand_outlines_generator( "error": str(e) }) - yield await SSEResponse.send_progress("整理结果数据...", 95) + yield await tracker.parsing("整理结果数据...") db_committed = True logger.info(f"批量展开完成: {len(expansion_results)} 个大纲,跳过 {len(skipped_outlines)} 个,共生成 {total_chapters_created} 个章节") + yield await tracker.complete() + # 发送最终结果 result_data = { "project_id": project_id, @@ -3507,9 +3191,8 @@ async def batch_expand_outlines_generator( ] } - yield await SSEResponse.send_result(result_data) - yield await SSEResponse.send_progress("🎉 批量展开完成!", 100, "success") - yield await SSEResponse.send_done() + yield await tracker.result(result_data) + yield await tracker.done() except GeneratorExit: logger.warning("批量展开生成器被提前关闭") diff --git a/backend/app/api/settings.py b/backend/app/api/settings.py index 9b33b88..e737dd9 100644 --- a/backend/app/api/settings.py +++ b/backend/app/api/settings.py @@ -22,7 +22,7 @@ from app.schemas.settings import ( from app.user_manager import User from app.logger import get_logger from app.config import settings as app_settings, PROJECT_ROOT -from app.services.ai_service import AIService, create_user_ai_service +from app.services.ai_service import AIService, create_user_ai_service, create_user_ai_service_with_mcp logger = get_logger(__name__) @@ -53,9 +53,14 @@ async def get_user_ai_service( db: AsyncSession = Depends(get_db) ) -> AIService: """ - 依赖:获取当前用户的AI服务实例 - 从数据库读取用户设置并创建对应的AI服务 + 依赖:获取当前用户的AI服务实例(支持MCP工具自动加载) + + 从数据库读取用户设置并创建对应的AI服务。 + 自动传递 user_id 和 db_session,使得 AIService 能够加载用户配置的MCP工具。 + 根据用户的所有MCP插件状态决定是否启用MCP:如果有启用的插件则启用,否则禁用。 """ + from app.models.mcp_plugin import MCPPlugin + result = await db.execute( select(Settings).where(Settings.user_id == user.user_id) ) @@ -73,15 +78,34 @@ async def get_user_ai_service( await db.refresh(settings) logger.info(f"用户 {user.user_id} 首次使用AI服务,已从.env同步设置到数据库") - # 使用用户设置创建AI服务实例(包括系统提示词) - return create_user_ai_service( + # 查询用户的所有MCP插件状态 + mcp_result = await db.execute( + select(MCPPlugin).where(MCPPlugin.user_id == user.user_id) + ) + mcp_plugins = mcp_result.scalars().all() + + # 检查是否有启用的MCP插件 + enable_mcp = any(plugin.enabled for plugin in mcp_plugins) if mcp_plugins else False + + if mcp_plugins: + enabled_count = sum(1 for p in mcp_plugins if p.enabled) + logger.info(f"用户 {user.user_id} 有 {len(mcp_plugins)} 个MCP插件,{enabled_count} 个启用,{enable_mcp} 决定使用MCP") + else: + logger.debug(f"用户 {user.user_id} 没有配置MCP插件,禁用MCP") + + # ✅ 使用支持MCP的工厂函数创建AI服务实例 + # 传递 user_id 和 db_session,使得 AIService 能够自动加载用户配置的MCP工具 + return create_user_ai_service_with_mcp( api_provider=settings.api_provider, api_key=settings.api_key, api_base_url=settings.api_base_url or "", model_name=settings.llm_model, temperature=settings.temperature, max_tokens=settings.max_tokens, - system_prompt=settings.system_prompt # 传递系统提示词 + user_id=user.user_id, # ✅ 传递 user_id + db_session=db, # ✅ 传递 db_session + system_prompt=settings.system_prompt, + enable_mcp=enable_mcp, # 根据MCP插件状态动态决定 ) @@ -327,6 +351,227 @@ class ApiTestRequest(BaseModel): llm_model: str +@router.post("/check-function-calling") +async def check_function_calling_support(data: ApiTestRequest): + """ + 检查模型是否支持 Function Calling(工具调用) + + 基于业界最佳实践的测试方法: + 1. 发送包含工具定义的请求 + 2. 检查响应的 finish_reason 是否为 "tool_calls" + 3. 验证响应中是否包含有效的 tool_calls 数据 + + Args: + data: 包含 API 配置的请求数据 + + Returns: + 检测结果包含支持状态、详细信息和建议 + """ + api_key = data.api_key + api_base_url = data.api_base_url + provider = data.provider + llm_model = data.llm_model + + try: + start_time = time.time() + + # 定义一个简单的测试工具(天气查询) + test_tools = [{ + "type": "function", + "function": { + "name": "get_weather", + "description": "获取指定城市的当前天气信息", + "parameters": { + "type": "object", + "properties": { + "city": { + "type": "string", + "description": "城市名称,例如:北京、上海、深圳" + }, + "unit": { + "type": "string", + "enum": ["celsius", "fahrenheit"], + "description": "温度单位" + } + }, + "required": ["city"] + } + } + }] + + # 测试提示:故意设计一个需要调用工具的问题 + test_prompt = "请告诉我北京现在的天气情况如何?" + + logger.info(f"🧪 开始检测 Function Calling 支持") + logger.info(f" - 提供商: {provider}") + logger.info(f" - 模型: {llm_model}") + logger.info(f" - 测试工具: get_weather") + + # 创建临时 AI 服务实例进行测试 + test_service = AIService( + api_provider=provider, + api_key=api_key, + api_base_url=api_base_url, + default_model=llm_model, + default_temperature=0.3, # 使用较低温度以获得更确定的行为 + default_max_tokens=200 + ) + + # 发送带工具的测试请求 + response = await test_service.generate_text( + prompt=test_prompt, + provider=provider, + model=llm_model, + temperature=0.3, + max_tokens=200, + tools=test_tools, + tool_choice="auto", # 让模型自动决定是否使用工具 + auto_mcp=False # 禁用 MCP 自动加载 + ) + + end_time = time.time() + response_time = round((end_time - start_time) * 1000, 2) + + # 分析响应以确定是否支持 Function Calling + supported = False + finish_reason = None + tool_calls = None + response_content = None + + if isinstance(response, dict): + # 检查 finish_reason(OpenAI 标准) + finish_reason = response.get("finish_reason") + + # 检查是否有 tool_calls + if "tool_calls" in response and response["tool_calls"]: + supported = True + tool_calls = response["tool_calls"] + logger.info(f"✅ 检测到工具调用: {len(tool_calls)} 个") + + # 记录返回的内容(如果有) + if "content" in response: + response_content = response["content"] + elif isinstance(response, str): + # 如果只返回字符串,说明不支持工具调用 + response_content = response + + logger.info(f" - 响应时间: {response_time}ms") + logger.info(f" - finish_reason: {finish_reason}") + logger.info(f" - 支持状态: {'✅ 支持' if supported else '❌ 不支持'}") + + # 构建详细的返回信息 + result = { + "success": True, + "supported": supported, + "message": "✅ 模型支持 Function Calling" if supported else "❌ 模型不支持 Function Calling", + "response_time_ms": response_time, + "provider": provider, + "model": llm_model, + "details": { + "finish_reason": finish_reason, + "has_tool_calls": bool(tool_calls), + "tool_call_count": len(tool_calls) if tool_calls else 0, + "test_tool": "get_weather", + "test_prompt": test_prompt, + "response_type": "tool_calls" if supported else "text" + } + } + + # 添加工具调用详情 + if tool_calls: + result["tool_calls"] = tool_calls + result["suggestions"] = [ + "✅ 该模型支持 Function Calling,可以正常使用 MCP 插件", + "建议:启用需要的 MCP 插件以扩展 AI 能力", + "提示:测试成功检测到工具调用,模型能够正确解析和使用外部工具" + ] + else: + result["response_preview"] = response_content[:200] if response_content else None + result["suggestions"] = [ + "❌ 该模型不支持 Function Calling,无法使用 MCP 插件功能", + "建议:更换支持工具调用的模型", + "推荐模型:GPT-4 系列、GPT-4-turbo、Claude 3 Opus/Sonnet、Gemini 1.5 Pro 等", + "说明:模型返回了文本回复而非工具调用,表明不支持该功能" + ] + + return result + + except ValueError as e: + error_msg = str(e) + logger.error(f"❌ Function Calling 检测配置错误: {error_msg}") + return { + "success": False, + "supported": False, + "message": "配置错误", + "error": error_msg, + "error_type": "ConfigurationError", + "suggestions": [ + "请检查 API Key 是否正确", + "请确认 API Base URL 格式是否正确", + "请验证所选提供商与配置是否匹配" + ] + } + + except TimeoutError as e: + error_msg = str(e) + logger.error(f"❌ Function Calling 检测超时: {error_msg}") + return { + "success": False, + "supported": None, + "message": "检测超时", + "error": error_msg, + "error_type": "TimeoutError", + "suggestions": [ + "请检查网络连接是否正常", + "请确认 API 服务是否可访问", + "建议:稍后重试或使用其他网络环境" + ] + } + + except Exception as e: + error_msg = str(e) + error_type = type(e).__name__ + + logger.error(f"❌ Function Calling 检测失败: {error_msg}") + logger.error(f" - 错误类型: {error_type}") + + # 智能分析错误原因 + suggestions = [] + if "tool" in error_msg.lower() or "function" in error_msg.lower(): + suggestions = [ + "该模型可能不支持 Function Calling 功能", + "API 返回了与工具调用相关的错误", + "建议:更换支持工具调用的模型或联系 API 提供商" + ] + elif "unauthorized" in error_msg.lower() or "401" in error_msg: + suggestions = [ + "API Key 认证失败", + "请检查 API Key 是否正确且有效", + "请确认 API Key 是否有足够的权限" + ] + elif "not found" in error_msg.lower() or "404" in error_msg: + suggestions = [ + "模型不存在或不可用", + "请检查模型名称是否正确", + "请确认该模型在当前 API 中是否可用" + ] + else: + suggestions = [ + "检测过程中遇到未知错误", + "建议:检查所有配置参数是否正确", + "提示:查看详细错误信息以获取更多线索" + ] + + return { + "success": False, + "supported": False, + "message": "Function Calling 检测失败", + "error": error_msg, + "error_type": error_type, + "suggestions": suggestions + } + + @router.post("/test") async def test_api_connection(data: ApiTestRequest): """ @@ -370,7 +615,8 @@ async def test_api_connection(data: ApiTestRequest): provider=provider, model=llm_model, temperature=0.7, - max_tokens=8000 + max_tokens=8000, + auto_mcp=False # 测试时不加载MCP工具 ) end_time = time.time() diff --git a/backend/app/api/wizard_stream.py b/backend/app/api/wizard_stream.py index 2bb851a..dee9669 100644 --- a/backend/app/api/wizard_stream.py +++ b/backend/app/api/wizard_stream.py @@ -16,11 +16,10 @@ from app.models.relationship import CharacterRelationship, Organization, Organiz from app.models.writing_style import WritingStyle from app.models.project_default_style import ProjectDefaultStyle from app.services.ai_service import AIService -from app.services.mcp_tool_service import MCPToolService from app.services.prompt_service import prompt_service, PromptService from app.services.plot_expansion_service import PlotExpansionService from app.logger import get_logger -from app.utils.sse_response import SSEResponse, create_sse_response +from app.utils.sse_response import SSEResponse, create_sse_response, WizardProgressTracker from app.api.settings import get_user_ai_service router = APIRouter(prefix="/wizard-stream", tags=["项目创建向导(流式)"]) @@ -35,9 +34,12 @@ async def world_building_generator( """世界构建流式生成器 - 支持MCP工具增强""" # 标记数据库会话是否已提交 db_committed = False + # 初始化标准进度追踪器 + tracker = WizardProgressTracker("世界观") + try: # 发送开始消息 - yield await SSEResponse.send_progress("开始生成世界观...", 10) + yield await tracker.start() # 提取参数 title = data.get("title") @@ -55,11 +57,11 @@ async def world_building_generator( user_id = data.get("user_id") # 从中间件注入 if not title or not description or not theme or not genre: - yield await SSEResponse.send_error("title、description、theme 和 genre 是必需的参数", 400) + yield await tracker.error("title、description、theme 和 genre 是必需的参数", 400) return # 获取基础提示词(支持自定义) - yield await SSEResponse.send_progress("准备AI提示词...", 15) + yield await tracker.preparing("准备AI提示词...") template = await PromptService.get_template("WORLD_BUILDING", user_id, db) base_prompt = PromptService.format_prompt( template, @@ -69,121 +71,67 @@ async def world_building_generator( description=description or "暂无简介" ) - # MCP工具增强:收集参考资料 - reference_materials = "" - if enable_mcp and user_id: - try: - # 先静默检查是否有可用工具 - from app.services.mcp_tool_service import mcp_tool_service - available_tools = await mcp_tool_service.get_user_enabled_tools( - user_id=user_id, - db_session=db - ) - - # 只有在真正有可用工具时才显示消息和调用 - if available_tools: - yield await SSEResponse.send_progress("🔍 尝试使用MCP工具收集参考资料...", 18) - - mcp_template = await PromptService.get_template("MCP_WORLD_BUILDING_PLANNING", user_id, db) - planning_prompt = PromptService.format_prompt( - mcp_template, - title=title, - genre=genre, - theme=theme, - description=description - ) - - # 调用MCP增强的AI(非流式,最多1轮工具调用,避免超时) - planning_result = await user_ai_service.generate_text_with_mcp( - prompt=planning_prompt, - user_id=user_id, - db_session=db, - enable_mcp=True, - max_tool_rounds=2, - tool_choice="auto", - provider=None, - model=None - ) - - # 提取参考资料 - if planning_result.get("tool_calls_made", 0) > 0: - yield await SSEResponse.send_progress( - f"✅ MCP工具调用成功({planning_result['tool_calls_made']}次)", - 25 - ) - reference_materials = planning_result.get("content", "") - else: - # 有工具但未使用 - logger.debug("MCP工具可用但AI未选择使用") - else: - # 没有可用工具,静默跳过 - logger.debug(f"用户 {user_id} 未启用MCP工具,跳过MCP增强") - - except Exception as e: - logger.warning(f"MCP工具调用失败(降级处理): {e}") - yield await SSEResponse.send_progress("⚠️ MCP工具暂时不可用,使用基础模式", 25) - - # 构建增强提示词 - if reference_materials: - enhanced_prompt = f"""{base_prompt} - -【参考资料】 -以下是通过MCP工具收集的真实背景资料,请参考这些信息构建更真实的世界观: - -{reference_materials} - -请结合上述资料,生成符合历史/现实的世界观设定。""" - final_prompt = enhanced_prompt - yield await SSEResponse.send_progress("💡 已整合参考资料,开始生成世界观...", 10) - else: - final_prompt = base_prompt - yield await SSEResponse.send_progress("正在调用AI生成...", 10) + # 设置用户信息以启用MCP + if user_id: + user_ai_service.user_id = user_id + user_ai_service.db_session = db # ===== 流式生成世界观(带重试机制) ===== MAX_WORLD_RETRIES = 3 # 最多重试3次 world_retry_count = 0 world_generation_success = False world_data = {} + estimated_total = 1000 while world_retry_count < MAX_WORLD_RETRIES and not world_generation_success: try: - retry_suffix = f" (重试{world_retry_count}/{MAX_WORLD_RETRIES})" if world_retry_count > 0 else "" - yield await SSEResponse.send_progress(f"生成世界观{retry_suffix}...", 10 + world_retry_count * 5) + # 重试时重置生成进度 + if world_retry_count > 0: + tracker.reset_generating_progress() + + yield await tracker.generating( + current_chars=0, + estimated_total=estimated_total, + retry_count=world_retry_count, + max_retries=MAX_WORLD_RETRIES + ) # 流式生成世界观 accumulated_text = "" chunk_count = 0 async for chunk in user_ai_service.generate_text_stream( - prompt=final_prompt, + prompt=base_prompt, provider=provider, - model=model + model=model, + tool_choice="required", ): chunk_count += 1 accumulated_text += chunk # 发送内容块 - yield await SSEResponse.send_chunk(chunk) + yield await tracker.generating_chunk(chunk) - # 世界观生成独立进度:5-95% - if chunk_count % 5 == 0: - progress = min(5 + (chunk_count // 3), 95) - yield await SSEResponse.send_progress(f"世界观生成中... ({len(accumulated_text)}字符)", progress) + # 定期更新进度 + current_len = len(accumulated_text) + if chunk_count % 10 == 0: + yield await tracker.generating( + current_chars=current_len, + estimated_total=estimated_total, + retry_count=world_retry_count, + max_retries=MAX_WORLD_RETRIES + ) # 每20个块发送心跳 if chunk_count % 20 == 0: - yield await SSEResponse.send_heartbeat() + yield await tracker.heartbeat() # 检查是否返回空响应 if not accumulated_text or not accumulated_text.strip(): logger.warning(f"⚠️ AI返回空世界观(尝试{world_retry_count+1}/{MAX_WORLD_RETRIES})") world_retry_count += 1 if world_retry_count < MAX_WORLD_RETRIES: - yield await SSEResponse.send_progress( - f"⚠️ AI返回为空,准备重试...", - 10 + world_retry_count * 5, - "warning" - ) + yield await tracker.retry(world_retry_count, MAX_WORLD_RETRIES, "AI返回为空") continue else: # 达到最大重试次数,使用默认值 @@ -198,7 +146,7 @@ async def world_building_generator( break # 解析结果 - 使用统一的JSON清洗方法 - yield await SSEResponse.send_progress("解析世界观数据...", 96) + yield await tracker.parsing("解析世界观数据...") try: logger.info(f"🔍 开始清洗JSON,原始长度: {len(accumulated_text)}") @@ -219,11 +167,7 @@ async def world_building_generator( logger.error(f" 原始内容预览: {accumulated_text[:200]}") world_retry_count += 1 if world_retry_count < MAX_WORLD_RETRIES: - yield await SSEResponse.send_progress( - f"⚠️ JSON解析失败,准备重试...", - 10 + world_retry_count * 5, - "warning" - ) + yield await tracker.retry(world_retry_count, MAX_WORLD_RETRIES, "JSON解析失败") continue else: # 达到最大重试次数,使用默认值 @@ -239,18 +183,15 @@ async def world_building_generator( logger.error(f"❌ 世界构建生成异常(尝试{world_retry_count+1}/{MAX_WORLD_RETRIES}): {type(e).__name__}: {e}") world_retry_count += 1 if world_retry_count < MAX_WORLD_RETRIES: - yield await SSEResponse.send_progress( - f"⚠️ 生成异常,准备重试...", - 10 + world_retry_count * 5, - "warning" - ) + yield await tracker.retry(world_retry_count, MAX_WORLD_RETRIES, "生成异常") continue else: # 最后一次重试仍失败,抛出异常 logger.error(f" accumulated_text 长度: {len(accumulated_text) if 'accumulated_text' in locals() else 'N/A'}") raise + # 保存到数据库 - yield await SSEResponse.send_progress("保存世界观到数据库...", 99) + yield await tracker.saving("保存世界观到数据库...") # 确保user_id存在 if not user_id: @@ -304,78 +245,180 @@ async def world_building_generator( logger.warning(f"设置默认写作风格失败: {e},不影响项目创建") # 更新向导步骤状态为1(世界观已完成) + # wizard_step: 0=未开始, 1=世界观已完成, 2=职业体系已完成, 3=角色已完成, 4=大纲已完成 project.wizard_step = 1 await db.commit() - # ===== 自动生成职业体系(带重试机制+流式) ===== - yield await SSEResponse.send_progress("世界观完成!", 100, "success") - yield await SSEResponse.send_progress("🎯 开始生成职业体系框架...", 5) - logger.info(f"🎯 世界观已完成,开始为项目 {project.id} 自动生成职业体系") + # ===== 世界观生成完成 ===== + db_committed = True + yield await tracker.complete() + + # 发送世界观结果 + yield await tracker.result({ + "project_id": project.id, + "time_period": world_data.get("time_period"), + "location": world_data.get("location"), + "atmosphere": world_data.get("atmosphere"), + "rules": world_data.get("rules") + }) + + # 发送世界观完成信号 + yield await tracker.done() + + logger.info(f"✅ 世界观生成完成,项目ID: {project.id}") + + except GeneratorExit: + # SSE连接断开,回滚未提交的事务 + logger.warning("世界构建生成器被提前关闭") + if not db_committed and db.in_transaction(): + await db.rollback() + logger.info("世界构建事务已回滚(GeneratorExit)") + except Exception as e: + logger.error(f"世界构建流式生成失败: {str(e)}") + # 异常时回滚事务 + if not db_committed and db.in_transaction(): + await db.rollback() + logger.info("世界构建事务已回滚(异常)") + yield await tracker.error(f"生成失败: {str(e)}") + + +@router.post("/world-building", summary="流式生成世界构建") +async def generate_world_building_stream( + request: Request, + data: Dict[str, Any], + db: AsyncSession = Depends(get_db), + user_ai_service: AIService = Depends(get_user_ai_service) +): + """ + 使用SSE流式生成世界构建,避免超时 + 前端使用EventSource接收实时进度和结果 + """ + # 从中间件注入user_id到data中 + if hasattr(request.state, 'user_id'): + data['user_id'] = request.state.user_id + + return create_sse_response(world_building_generator(data, db, user_ai_service)) + + +async def career_system_generator( + data: Dict[str, Any], + db: AsyncSession, + user_ai_service: AIService +) -> AsyncGenerator[str, None]: + """职业体系生成流式生成器 - 独立接口""" + db_committed = False + # 初始化标准进度追踪器 + tracker = WizardProgressTracker("职业体系") + + try: + yield await tracker.start() + + # 提取参数 + project_id = data.get("project_id") + provider = data.get("provider") + model = data.get("model") + user_id = data.get("user_id") + + if not project_id: + yield await tracker.error("project_id 是必需的参数", 400) + return + + # 获取项目信息 + yield await tracker.loading("加载项目信息...") + result = await db.execute( + select(Project).where(Project.id == project_id) + ) + project = result.scalar_one_or_none() + if not project: + yield await tracker.error("项目不存在", 404) + return + + # 设置用户信息以启用MCP + if user_id: + user_ai_service.user_id = user_id + user_ai_service.db_session = db + + # 获取世界观数据 + world_data = { + "time_period": project.world_time_period or "未设定", + "location": project.world_location or "未设定", + "atmosphere": project.world_atmosphere or "未设定", + "rules": project.world_rules or "未设定" + } + + # 获取职业生成提示词模板(支持用户自定义) + yield await tracker.preparing("准备AI提示词...") + template = await PromptService.get_template("CAREER_SYSTEM_GENERATION", user_id, db) + career_prompt = PromptService.format_prompt( + template, + title=project.title, + genre=project.genre or '未设定', + theme=project.theme or '未设定', + time_period=world_data.get('time_period', '未设定'), + location=world_data.get('location', '未设定'), + atmosphere=world_data.get('atmosphere', '未设定'), + rules=world_data.get('rules', '未设定') + ) + + estimated_total = 8000 MAX_CAREER_RETRIES = 3 # 最多重试3次 career_retry_count = 0 career_generation_success = False while career_retry_count < MAX_CAREER_RETRIES and not career_generation_success: try: - retry_suffix = f" (重试{career_retry_count}/{MAX_CAREER_RETRIES})" if career_retry_count > 0 else "" - yield await SSEResponse.send_progress(f"正在生成职业体系{retry_suffix}...", 10) + # 重试时重置生成进度 + if career_retry_count > 0: + tracker.reset_generating_progress() - # 获取职业生成提示词模板(支持用户自定义) - template = await PromptService.get_template("CAREER_SYSTEM_GENERATION", user_id, db) - career_prompt = PromptService.format_prompt( - template, - title=project.title, - genre=genre or '未设定', - theme=theme or '未设定', - time_period=world_data.get('time_period', '未设定'), - location=world_data.get('location', '未设定'), - atmosphere=world_data.get('atmosphere', '未设定'), - rules=world_data.get('rules', '未设定') + yield await tracker.generating( + current_chars=0, + estimated_total=estimated_total, + retry_count=career_retry_count, + max_retries=MAX_CAREER_RETRIES ) - # ✅ 使用流式生成职业体系 + # 使用流式生成职业体系 career_response = "" chunk_count = 0 async for chunk in user_ai_service.generate_text_stream( prompt=career_prompt, provider=provider, - model=model + model=model, ): chunk_count += 1 career_response += chunk # 发送内容块 - yield await SSEResponse.send_chunk(chunk) + yield await tracker.generating_chunk(chunk) - # 职业体系生成独立进度:10-95% - if chunk_count % 5 == 0: - progress = min(10 + (chunk_count // 3), 95) - yield await SSEResponse.send_progress( - f"生成职业体系中... ({len(career_response)}字符)", - progress + # 定期更新进度 + current_len = len(career_response) + if chunk_count % 10 == 0: + yield await tracker.generating( + current_chars=current_len, + estimated_total=estimated_total, + retry_count=career_retry_count, + max_retries=MAX_CAREER_RETRIES ) # 每20个块发送心跳 if chunk_count % 20 == 0: - yield await SSEResponse.send_heartbeat() + yield await tracker.heartbeat() if not career_response or not career_response.strip(): logger.warning(f"⚠️ AI返回空职业体系(尝试{career_retry_count+1}/{MAX_CAREER_RETRIES})") career_retry_count += 1 if career_retry_count < MAX_CAREER_RETRIES: - yield await SSEResponse.send_progress( - f"⚠️ AI返回为空,准备重试...", - 10, - "warning" - ) + yield await tracker.retry(career_retry_count, MAX_CAREER_RETRIES, "AI返回为空") continue else: - yield await SSEResponse.send_progress("职业体系生成跳过(AI多次返回为空)", 99) - break + yield await tracker.error("职业体系生成失败(AI多次返回为空)") + return - yield await SSEResponse.send_progress("解析职业体系数据...", 96) + yield await tracker.parsing("解析职业体系数据...") # 清洗并解析JSON try: @@ -383,6 +426,8 @@ async def world_building_generator( career_data = json.loads(cleaned_response) logger.info(f"✅ 职业体系JSON解析成功(尝试{career_retry_count+1}/{MAX_CAREER_RETRIES})") + yield await tracker.saving("保存职业数据...") + # 保存主职业 main_careers_created = [] for idx, career_info in enumerate(career_data.get("main_careers", [])): @@ -443,100 +488,88 @@ async def world_building_generator( logger.error(f" ❌ 创建副职业失败:{str(e)}") continue + # 更新向导步骤状态为2(职业体系已完成) + # wizard_step: 0=未开始, 1=世界观已完成, 2=职业体系已完成, 3=角色已完成, 4=大纲已完成 + project.wizard_step = 2 + await db.commit() + db_committed = True # 标记成功 career_generation_success = True logger.info(f"🎉 职业体系生成完成:主职业{len(main_careers_created)}个,副职业{len(sub_careers_created)}个") - yield await SSEResponse.send_progress( - f"✅ 职业体系生成完成(主{len(main_careers_created)}+副{len(sub_careers_created)})", - 99 - ) + + yield await tracker.complete() + + # 发送结果 + yield await tracker.result({ + "project_id": project.id, + "main_careers_count": len(main_careers_created), + "sub_careers_count": len(sub_careers_created), + "main_careers": main_careers_created, + "sub_careers": sub_careers_created + }) + + yield await tracker.done() except json.JSONDecodeError as e: logger.error(f"❌ 职业体系JSON解析失败(尝试{career_retry_count+1}/{MAX_CAREER_RETRIES}): {e}") career_retry_count += 1 if career_retry_count < MAX_CAREER_RETRIES: - yield await SSEResponse.send_progress( - f"⚠️ JSON解析失败,准备重试...", - 10, - "warning" - ) + yield await tracker.retry(career_retry_count, MAX_CAREER_RETRIES, "JSON解析失败") continue else: - yield await SSEResponse.send_progress("⚠️ 职业体系解析失败(已达最大重试次数),已跳过", 99) + yield await tracker.error("职业体系解析失败(已达最大重试次数)") + return except Exception as e: logger.error(f"❌ 职业体系保存失败(尝试{career_retry_count+1}/{MAX_CAREER_RETRIES}): {e}") career_retry_count += 1 if career_retry_count < MAX_CAREER_RETRIES: - yield await SSEResponse.send_progress( - f"⚠️ 保存失败,准备重试...", - 10, - "warning" - ) + yield await tracker.retry(career_retry_count, MAX_CAREER_RETRIES, "保存失败") continue else: - yield await SSEResponse.send_progress("⚠️ 职业体系保存失败(已达最大重试次数),已跳过", 99) + yield await tracker.error("职业体系保存失败(已达最大重试次数)") + return except Exception as e: logger.error(f"❌ 职业体系生成异常(尝试{career_retry_count+1}/{MAX_CAREER_RETRIES}): {e}") career_retry_count += 1 if career_retry_count < MAX_CAREER_RETRIES: - yield await SSEResponse.send_progress( - f"⚠️ 生成异常,准备重试...", - 10, - "warning" - ) + yield await tracker.retry(career_retry_count, MAX_CAREER_RETRIES, "生成异常") continue else: - yield await SSEResponse.send_progress("⚠️ 职业体系生成失败(已达最大重试次数),已跳过(不影响项目创建)", 99) - - db_committed = True - - # 发送最终结果 - yield await SSEResponse.send_result({ - "project_id": project.id, - "time_period": world_data.get("time_period"), - "location": world_data.get("location"), - "atmosphere": world_data.get("atmosphere"), - "rules": world_data.get("rules") - }) - - yield await SSEResponse.send_progress("职业体系完成!", 100, "success") - yield await SSEResponse.send_progress("🎉 所有步骤已完成!", 100, "success") - yield await SSEResponse.send_done() + yield await tracker.error(f"职业体系生成失败: {str(e)}") + return except GeneratorExit: - # SSE连接断开,回滚未提交的事务 - logger.warning("世界构建生成器被提前关闭") + logger.warning("职业体系生成器被提前关闭") if not db_committed and db.in_transaction(): await db.rollback() - logger.info("世界构建事务已回滚(GeneratorExit)") + logger.info("职业体系事务已回滚(GeneratorExit)") except Exception as e: - logger.error(f"世界构建流式生成失败: {str(e)}") - # 异常时回滚事务 + logger.error(f"职业体系流式生成失败: {str(e)}") if not db_committed and db.in_transaction(): await db.rollback() - logger.info("世界构建事务已回滚(异常)") - yield await SSEResponse.send_error(f"生成失败: {str(e)}") + logger.info("职业体系事务已回滚(异常)") + yield await tracker.error(f"生成失败: {str(e)}") -@router.post("/world-building", summary="流式生成世界构建") -async def generate_world_building_stream( +@router.post("/career-system", summary="流式生成职业体系") +async def generate_career_system_stream( request: Request, data: Dict[str, Any], db: AsyncSession = Depends(get_db), user_ai_service: AIService = Depends(get_user_ai_service) ): """ - 使用SSE流式生成世界构建,避免超时 + 使用SSE流式生成职业体系,避免超时 前端使用EventSource接收实时进度和结果 """ # 从中间件注入user_id到data中 if hasattr(request.state, 'user_id'): data['user_id'] = request.state.user_id - return create_sse_response(world_building_generator(data, db, user_ai_service)) + return create_sse_response(career_system_generator(data, db, user_ai_service)) async def characters_generator( @@ -546,8 +579,11 @@ async def characters_generator( ) -> AsyncGenerator[str, None]: """角色批量生成流式生成器 - 优化版:分批+重试+MCP工具增强""" db_committed = False + # 初始化标准进度追踪器 + tracker = WizardProgressTracker("角色") + try: - yield await SSEResponse.send_progress("开始生成角色...", 5) + yield await tracker.start() project_id = data.get("project_id") count = data.get("count", 5) @@ -561,13 +597,13 @@ async def characters_generator( user_id = data.get("user_id") # 从中间件注入 # 验证项目 - yield await SSEResponse.send_progress("验证项目...", 10) + yield await tracker.loading("验证项目...", 0.3) result = await db.execute( select(Project).where(Project.id == project_id) ) project = result.scalar_one_or_none() if not project: - yield await SSEResponse.send_error("项目不存在", 404) + yield await tracker.error("项目不存在", 404) return project.wizard_step = 2 @@ -579,63 +615,13 @@ async def characters_generator( "rules": project.world_rules or "未设定" } - # MCP工具增强:收集角色参考资料 - character_reference_materials = "" - if enable_mcp and user_id: - try: - # 先静默检查是否有可用工具 - from app.services.mcp_tool_service import mcp_tool_service - available_tools = await mcp_tool_service.get_user_enabled_tools( - user_id=user_id, - db_session=db - ) - - # 只有在真正有可用工具时才显示消息和调用 - if available_tools: - yield await SSEResponse.send_progress("🔍 尝试使用MCP工具收集角色参考资料...", 8) - - mcp_template = await PromptService.get_template("MCP_CHARACTER_PLANNING", user_id, db) - planning_prompt = PromptService.format_prompt( - mcp_template, - title=project.title, - genre=genre or project.genre, - theme=theme or project.theme, - time_period=world_context.get('time_period', '未设定'), - location=world_context.get('location', '未设定') - ) - - # 调用MCP增强的AI(非流式,最多1轮工具调用,避免超时) - planning_result = await user_ai_service.generate_text_with_mcp( - prompt=planning_prompt, - user_id=user_id, - db_session=db, - enable_mcp=True, - max_tool_rounds=2, # ✅ 优化: 从2轮减少到1轮 - tool_choice="auto", - provider=None, - model=None - ) - - # 提取参考资料 - if planning_result.get("tool_calls_made", 0) > 0: - yield await SSEResponse.send_progress( - f"✅ MCP工具调用成功({planning_result['tool_calls_made']}次)", - 12 - ) - character_reference_materials = planning_result.get("content", "") - else: - # 有工具但未使用 - logger.debug("MCP工具可用但AI未选择使用") - else: - # 没有可用工具,静默跳过 - logger.debug(f"用户 {user_id} 未启用MCP工具,跳过MCP增强") - - except Exception as e: - logger.warning(f"MCP工具调用失败(降级处理): {e}") - yield await SSEResponse.send_progress("⚠️ MCP工具暂时不可用,使用基础模式", 12) + # 设置用户信息以启用MCP + if user_id: + user_ai_service.user_id = user_id + user_ai_service.db_session = db # 获取项目的职业列表,用于角色职业分配 - yield await SSEResponse.send_progress("加载职业体系...", 13) + yield await tracker.loading("加载职业体系...", 0.8) career_result = await db.execute( select(Career).where(Career.project_id == project_id).order_by(Career.type, Career.id) ) @@ -668,8 +654,8 @@ async def characters_generator( else: logger.warning("⚠️ 项目没有职业体系,跳过职业分配") - # 优化的分批策略:每批生成3个,平衡效率和成功率 - BATCH_SIZE = 3 # 每批生成3个角色 + # 优化的分批策略:每批生成5个,平衡效率和成功率 + BATCH_SIZE = 5 # 每批生成5个角色 MAX_RETRIES = 3 # 每批最多重试3次 all_characters = [] total_batches = (count + BATCH_SIZE - 1) // BATCH_SIZE @@ -693,10 +679,16 @@ async def characters_generator( while retry_count < MAX_RETRIES and not batch_success: try: - retry_suffix = f" (重试{retry_count}/{MAX_RETRIES})" if retry_count > 0 else "" - yield await SSEResponse.send_progress( - f"生成第{batch_idx+1}/{total_batches}批角色 ({current_batch_size}个){retry_suffix}...", - batch_progress + # 重试时重置生成进度 + if retry_count > 0: + tracker.reset_generating_progress() + + yield await tracker.generating( + current_chars=0, + estimated_total=1000, + message=f"生成第{batch_idx+1}/{total_batches}批角色 ({current_batch_size}个)", + retry_count=retry_count, + max_retries=MAX_RETRIES ) # 构建批次要求 - 包含已生成角色信息保持连贯 @@ -735,45 +727,40 @@ async def characters_generator( requirements=batch_requirements + careers_context # 添加职业上下文 ) - # 如果有MCP参考资料,增强提示词 - if character_reference_materials: - prompt = f"""{base_prompt} - -【参考资料】 -以下是通过MCP工具收集的真实背景资料,请参考这些信息设计更真实的角色: - -{character_reference_materials} - -请结合上述资料,设计符合历史/文化背景的角色。""" - else: - prompt = base_prompt + prompt = base_prompt # 流式生成(带字数统计) accumulated_text = "" chunk_count = 0 + estimated_total = 1000 + async for chunk in user_ai_service.generate_text_stream( prompt=prompt, provider=provider, - model=model + model=model, + tool_choice="required", ): chunk_count += 1 accumulated_text += chunk # 发送内容块 - yield await SSEResponse.send_chunk(chunk) + yield await tracker.generating_chunk(chunk) - # 定期更新进度和字数 - if chunk_count % 5 == 0: - progress = min(batch_progress + 5 + (chunk_count // 10), batch_progress + 15) - yield await SSEResponse.send_progress( - f"生成角色中... ({len(accumulated_text)}字符)", - progress + # 定期更新进度 + current_len = len(accumulated_text) + if chunk_count % 10 == 0: + yield await tracker.generating( + current_chars=current_len, + estimated_total=estimated_total, + message=f"生成第{batch_idx+1}/{total_batches}批角色中", + retry_count=retry_count, + max_retries=MAX_RETRIES ) # 每20个块发送心跳 if chunk_count % 20 == 0: - yield await SSEResponse.send_heartbeat() + yield await tracker.heartbeat() # 解析批次结果 - 使用统一的JSON清洗方法 cleaned_text = user_ai_service._clean_json_response(accumulated_text) @@ -789,15 +776,11 @@ async def characters_generator( # 如果还有重试机会,继续重试 if retry_count < MAX_RETRIES - 1: retry_count += 1 - yield await SSEResponse.send_progress( - f"⚠️ {error_msg},准备重试...", - batch_progress, - "warning" - ) + yield await tracker.retry(retry_count, MAX_RETRIES, error_msg) continue else: # 最后一次重试仍失败,直接返回错误 - yield await SSEResponse.send_error(error_msg) + yield await tracker.error(error_msg) return all_characters.extend(characters_data) @@ -809,21 +792,13 @@ async def characters_generator( batch_error_message = f"JSON解析失败: {str(e)}" retry_count += 1 if retry_count < MAX_RETRIES: - yield await SSEResponse.send_progress( - f"解析失败,准备重试...", - batch_progress, - "warning" - ) + yield await tracker.retry(retry_count, MAX_RETRIES, "JSON解析失败") except Exception as e: logger.error(f"批次{batch_idx+1}生成异常(尝试{retry_count+1}/{MAX_RETRIES}): {e}") batch_error_message = f"生成异常: {str(e)}" retry_count += 1 if retry_count < MAX_RETRIES: - yield await SSEResponse.send_progress( - f"生成异常,准备重试...", - batch_progress, - "warning" - ) + yield await tracker.retry(retry_count, MAX_RETRIES, "生成异常") # 检查批次是否成功 if not batch_success: @@ -831,11 +806,11 @@ async def characters_generator( if batch_error_message: error_msg += f": {batch_error_message}" logger.error(error_msg) - yield await SSEResponse.send_error(error_msg) + yield await tracker.error(error_msg) return # 保存到数据库 - 分阶段处理以保证一致性 - yield await SSEResponse.send_progress("验证角色数据...", 82) + yield await tracker.parsing("验证角色数据...") # 预处理:构建本批次所有实体的名称集合 valid_entity_names = set() @@ -879,9 +854,9 @@ async def characters_generator( if cleaned_count > 0: logger.info(f"✨ 清理了{cleaned_count}个AI幻觉引用") - yield await SSEResponse.send_progress(f"已清理{cleaned_count}个无效引用", 84) + yield await tracker.parsing(f"已清理{cleaned_count}个无效引用", 0.7) - yield await SSEResponse.send_progress("保存角色到数据库...", 85) + yield await tracker.saving("保存角色到数据库...") # 第一阶段:创建所有Character记录 created_characters = [] @@ -932,7 +907,7 @@ async def characters_generator( # 第二阶段:为角色分配职业并创建CharacterCareer关联 if main_careers or sub_careers: - yield await SSEResponse.send_progress("分配角色职业...", 86) + yield await tracker.saving("分配角色职业...", 0.3) careers_assigned = 0 # 构建职业名称到对象的映射 @@ -1016,7 +991,7 @@ async def characters_generator( await db.flush() logger.info(f"💼 职业分配完成:共分配{careers_assigned}个职业") - yield await SSEResponse.send_progress(f"已分配{careers_assigned}个职业", 87) + yield await tracker.saving(f"已分配{careers_assigned}个职业", 0.4) # 刷新并建立名称映射 for character, _ in created_characters: @@ -1025,7 +1000,7 @@ async def characters_generator( logger.info(f"向导创建角色:{character.name} (ID: {character.id}, 是否组织: {character.is_organization})") # 第三阶段:为is_organization=True的角色创建Organization记录 - yield await SSEResponse.send_progress("创建组织记录...", 88) + yield await tracker.saving("创建组织记录...", 0.5) organization_name_to_obj = {} # 组织名称到Organization对象的映射 for character, char_data in created_characters: @@ -1062,7 +1037,7 @@ async def characters_generator( await db.refresh(character) # 第四阶段:创建角色间的关系 - yield await SSEResponse.send_progress("创建角色关系...", 91) + yield await tracker.saving("创建角色关系...", 0.7) relationships_created = 0 for character, char_data in created_characters: @@ -1130,7 +1105,7 @@ async def characters_generator( continue # 第五阶段:创建组织成员关系 - yield await SSEResponse.send_progress("创建组织成员关系...", 94) + yield await tracker.saving("创建组织成员关系...", 0.9) members_created = 0 for character, char_data in created_characters: @@ -1194,9 +1169,10 @@ async def characters_generator( logger.info(f" - 创建角色关系:{relationships_created} 条") logger.info(f" - 创建组织成员:{members_created} 条") - # 更新项目的角色数量和向导步骤状态为2(角色已完成) + # 更新项目的角色数量和向导步骤状态为3(角色已完成) + # wizard_step: 0=未开始, 1=世界观已完成, 2=职业体系已完成, 3=角色已完成, 4=大纲已完成 project.character_count = len(created_characters) - project.wizard_step = 2 + project.wizard_step = 3 logger.info(f"✅ 更新项目角色数量: {project.character_count}") await db.commit() @@ -1205,8 +1181,10 @@ async def characters_generator( # 重新提取character对象 created_characters = [char for char, _ in created_characters] + yield await tracker.complete() + # 发送结果 - yield await SSEResponse.send_result({ + yield await tracker.result({ "message": f"成功生成{len(created_characters)}个角色/组织(分{total_batches}批完成)", "count": len(created_characters), "batches": total_batches, @@ -1233,8 +1211,7 @@ async def characters_generator( ] }) - yield await SSEResponse.send_progress("完成!", 100, "success") - yield await SSEResponse.send_done() + yield await tracker.done() except GeneratorExit: logger.warning("角色生成器被提前关闭") @@ -1246,7 +1223,7 @@ async def characters_generator( if not db_committed and db.in_transaction(): await db.rollback() logger.info("角色生成事务已回滚(异常)") - yield await SSEResponse.send_error(f"生成失败: {str(e)}") + yield await tracker.error(f"生成失败: {str(e)}") @router.post("/characters", summary="流式批量生成角色") @@ -1274,8 +1251,11 @@ async def outline_generator( ) -> AsyncGenerator[str, None]: """大纲生成流式生成器 - 向导仅生成大纲节点,不展开章节(避免等待过久)""" db_committed = False + # 初始化标准进度追踪器 + tracker = WizardProgressTracker("大纲") + try: - yield await SSEResponse.send_progress("开始生成大纲...", 5) + yield await tracker.start() project_id = data.get("project_id") # 向导固定生成3个大纲节点(不展开) @@ -1285,20 +1265,21 @@ async def outline_generator( requirements = data.get("requirements", "") provider = data.get("provider") model = data.get("model") + enable_mcp = data.get("enable_mcp", True) # 默认启用MCP user_id = data.get("user_id") # 从中间件注入 # 获取项目信息 - yield await SSEResponse.send_progress("加载项目信息...", 10) + yield await tracker.loading("加载项目信息...", 0.3) result = await db.execute( select(Project).where(Project.id == project_id) ) project = result.scalar_one_or_none() if not project: - yield await SSEResponse.send_error("项目不存在", 404) + yield await tracker.error("项目不存在", 404) return # 获取角色信息 - yield await SSEResponse.send_progress("加载角色信息...", 15) + yield await tracker.loading("加载角色信息...", 0.8) result = await db.execute( select(Character).where(Character.project_id == project_id) ) @@ -1309,8 +1290,8 @@ async def outline_generator( for char in characters ]) - # 第一阶段:生成3个粗粒度大纲节点 - yield await SSEResponse.send_progress(f"生成{outline_count}个大纲节点...", 10) + # 准备提示词 + yield await tracker.preparing(f"准备生成{outline_count}个大纲节点...") outline_requirements = f"{requirements}\n\n【重要说明】这是小说的开局部分,请生成{outline_count}个大纲节点,重点关注:\n" outline_requirements += "1. 引入主要角色和世界观设定\n" @@ -1338,35 +1319,38 @@ async def outline_generator( requirements=outline_requirements ) - # 流式生成大纲(带字数统计) + # 流式生成大纲 + estimated_total = 1000 accumulated_text = "" chunk_count = 0 + yield await tracker.generating(current_chars=0, estimated_total=estimated_total) + async for chunk in user_ai_service.generate_text_stream( prompt=outline_prompt, provider=provider, - model=model + model=model, ): chunk_count += 1 accumulated_text += chunk # 发送内容块 - yield await SSEResponse.send_chunk(chunk) + yield await tracker.generating_chunk(chunk) - # 定期更新进度和字数(5-95%,AI生成占90%) - if chunk_count % 5 == 0: - progress = min(10 + (chunk_count // 3), 90) - yield await SSEResponse.send_progress( - f"生成大纲中... ({len(accumulated_text)}字符)", - progress + # 定期更新进度 + current_len = len(accumulated_text) + if chunk_count % 10 == 0: + yield await tracker.generating( + current_chars=current_len, + estimated_total=estimated_total ) # 每20个块发送心跳 if chunk_count % 20 == 0: - yield await SSEResponse.send_heartbeat() + yield await tracker.heartbeat() # 解析大纲结果 - 使用统一的JSON清洗方法 - yield await SSEResponse.send_progress("解析大纲...", 96) + yield await tracker.parsing("解析大纲数据...") try: cleaned_text = user_ai_service._clean_json_response(accumulated_text) @@ -1375,11 +1359,11 @@ async def outline_generator( outline_data = [outline_data] except json.JSONDecodeError as e: logger.error(f"大纲JSON解析失败: {e}") - yield await SSEResponse.send_error("大纲生成失败,请重试") + yield await tracker.error("大纲生成失败,请重试") return # 保存大纲到数据库 - yield await SSEResponse.send_progress("保存大纲到数据库...", 97) + yield await tracker.saving("保存大纲到数据库...") created_outlines = [] for index, outline_item in enumerate(outline_data[:outline_count], 1): outline = Outline( @@ -1402,7 +1386,7 @@ async def outline_generator( created_chapters = [] if project.outline_mode == 'one-to-one': # 一对一模式:自动为每个大纲创建对应的章节 - yield await SSEResponse.send_progress("一对一模式:自动创建章节...", 98) + yield await tracker.saving("一对一模式:自动创建章节...", 0.7) for outline in created_outlines: chapter = Chapter( @@ -1421,19 +1405,20 @@ async def outline_generator( await db.refresh(chapter) logger.info(f"✅ 一对一模式:自动创建了{len(created_chapters)}个章节") - yield await SSEResponse.send_progress(f"已自动创建{len(created_chapters)}个章节", 99) + yield await tracker.saving(f"已自动创建{len(created_chapters)}个章节", 0.9) else: # 一对多模式:跳过自动创建,用户可手动展开 - yield await SSEResponse.send_progress("细化模式:跳过自动创建章节", 99) + yield await tracker.saving("细化模式:跳过自动创建章节", 0.9) logger.info(f"📝 细化模式:跳过章节创建,用户可在大纲页面手动展开") # 更新项目信息 + # wizard_step: 0=未开始, 1=世界观已完成, 2=职业体系已完成, 3=角色已完成, 4=大纲已完成 project.chapter_count = len(created_chapters) # 记录实际创建的章节数 project.narrative_perspective = narrative_perspective project.target_words = target_words project.status = "writing" project.wizard_status = "completed" - project.wizard_step = 3 + project.wizard_step = 4 await db.commit() db_committed = True @@ -1451,8 +1436,10 @@ async def outline_generator( result_message = f"成功生成{len(created_outlines)}个大纲节点(细化模式,可在大纲页面手动展开)" result_note = "可在大纲页面展开为多个章节" + yield await tracker.complete() + # 发送结果 - yield await SSEResponse.send_result({ + yield await tracker.result({ "message": result_message, "outline_count": len(created_outlines), "chapter_count": len(created_chapters), @@ -1476,8 +1463,7 @@ async def outline_generator( ] if created_chapters else [] }) - yield await SSEResponse.send_progress("完成!", 100, "success") - yield await SSEResponse.send_done() + yield await tracker.done() except GeneratorExit: logger.warning("大纲生成器被提前关闭") @@ -1489,7 +1475,7 @@ async def outline_generator( if not db_committed and db.in_transaction(): await db.rollback() logger.info("大纲生成事务已回滚(异常)") - yield await SSEResponse.send_error(f"生成失败: {str(e)}") + yield await tracker.error(f"生成失败: {str(e)}") @router.post("/outline", summary="流式生成完整大纲") async def generate_outline_stream( @@ -1511,16 +1497,20 @@ async def world_building_regenerate_generator( ) -> AsyncGenerator[str, None]: """世界观重新生成流式生成器""" db_committed = False + # 初始化标准进度追踪器 + tracker = WizardProgressTracker("世界观") + try: - yield await SSEResponse.send_progress("开始重新生成世界观...", 10) + yield await tracker.start("开始重新生成世界观...") # 获取项目信息 + yield await tracker.loading("加载项目信息...") result = await db.execute( select(Project).where(Project.id == project_id) ) project = result.scalar_one_or_none() if not project: - yield await SSEResponse.send_error("项目不存在", 404) + yield await tracker.error("项目不存在", 404) return # 提取参数 @@ -1530,7 +1520,7 @@ async def world_building_regenerate_generator( user_id = data.get("user_id") # 获取基础提示词(支持自定义) - yield await SSEResponse.send_progress("准备AI提示词...", 15) + yield await tracker.preparing("准备AI提示词...") template = await PromptService.get_template("WORLD_BUILDING", user_id, db) base_prompt = PromptService.format_prompt( template, @@ -1540,112 +1530,67 @@ async def world_building_regenerate_generator( description=project.description or "暂无简介" ) - # MCP工具增强:收集参考资料 - reference_materials = "" - if enable_mcp and user_id: - try: - from app.services.mcp_tool_service import mcp_tool_service - available_tools = await mcp_tool_service.get_user_enabled_tools( - user_id=user_id, - db_session=db - ) - - if available_tools: - yield await SSEResponse.send_progress("🔍 尝试使用MCP工具收集参考资料...", 18) - - mcp_template = await PromptService.get_template("MCP_WORLD_BUILDING_PLANNING", user_id, db) - planning_prompt = PromptService.format_prompt( - mcp_template, - title=project.title, - genre=project.genre, - theme=project.theme, - description=project.description or '未设定' - ) - - planning_result = await user_ai_service.generate_text_with_mcp( - prompt=planning_prompt, - user_id=user_id, - db_session=db, - enable_mcp=True, - max_tool_rounds=2, - tool_choice="auto", - provider=None, - model=None - ) - - if planning_result.get("tool_calls_made", 0) > 0: - yield await SSEResponse.send_progress( - f"✅ MCP工具调用成功({planning_result['tool_calls_made']}次)", - 25 - ) - reference_materials = planning_result.get("content", "") - else: - logger.debug("MCP工具可用但AI未选择使用") - else: - logger.debug(f"用户 {user_id} 未启用MCP工具,跳过MCP增强") - - except Exception as e: - logger.warning(f"MCP工具调用失败(降级处理): {e}") - yield await SSEResponse.send_progress("⚠️ MCP工具暂时不可用,使用基础模式", 25) - - # 构建增强提示词 - if reference_materials: - enhanced_prompt = f"""{base_prompt} - -【参考资料】 -以下是通过MCP工具收集的真实背景资料,请参考这些信息构建更真实的世界观: - -{reference_materials} - -请结合上述资料,生成符合历史/现实的世界观设定。""" - final_prompt = enhanced_prompt - yield await SSEResponse.send_progress("💡 已整合参考资料,开始生成世界观...", 10) - else: - final_prompt = base_prompt - yield await SSEResponse.send_progress("正在调用AI生成...", 10) + # 设置用户信息以启用MCP + if user_id: + user_ai_service.user_id = user_id + user_ai_service.db_session = db # ===== 流式生成世界观(带重试机制) ===== MAX_WORLD_RETRIES = 3 # 最多重试3次 world_retry_count = 0 world_generation_success = False world_data = {} + estimated_total = 1000 while world_retry_count < MAX_WORLD_RETRIES and not world_generation_success: try: - retry_suffix = f" (重试{world_retry_count}/{MAX_WORLD_RETRIES})" if world_retry_count > 0 else "" - yield await SSEResponse.send_progress(f"重新生成世界观{retry_suffix}...", 10 + world_retry_count * 5) + # 重试时重置生成进度 + if world_retry_count > 0: + tracker.reset_generating_progress() + + yield await tracker.generating( + current_chars=0, + estimated_total=estimated_total, + message="重新生成世界观", + retry_count=world_retry_count, + max_retries=MAX_WORLD_RETRIES + ) # 流式生成世界观 accumulated_text = "" chunk_count = 0 async for chunk in user_ai_service.generate_text_stream( - prompt=final_prompt, + prompt=base_prompt, provider=provider, - model=model + model=model, + tool_choice="required", ): chunk_count += 1 accumulated_text += chunk - yield await SSEResponse.send_chunk(chunk) + yield await tracker.generating_chunk(chunk) - if chunk_count % 5 == 0: - progress = min(10 + (chunk_count // 5), 85) - yield await SSEResponse.send_progress(f"生成中... ({len(accumulated_text)}字符)", progress) + # 定期更新进度 + current_len = len(accumulated_text) + if chunk_count % 10 == 0: + yield await tracker.generating( + current_chars=current_len, + estimated_total=estimated_total, + message="重新生成世界观", + retry_count=world_retry_count, + max_retries=MAX_WORLD_RETRIES + ) if chunk_count % 20 == 0: - yield await SSEResponse.send_heartbeat() + yield await tracker.heartbeat() # 检查是否返回空响应 if not accumulated_text or not accumulated_text.strip(): logger.warning(f"⚠️ AI返回空世界观(尝试{world_retry_count+1}/{MAX_WORLD_RETRIES})") world_retry_count += 1 if world_retry_count < MAX_WORLD_RETRIES: - yield await SSEResponse.send_progress( - f"⚠️ AI返回为空,准备重试...", - 10 + world_retry_count * 5, - "warning" - ) + yield await tracker.retry(world_retry_count, MAX_WORLD_RETRIES, "AI返回为空") continue else: # 达到最大重试次数,使用默认值 @@ -1660,7 +1605,7 @@ async def world_building_regenerate_generator( break # 解析结果 - 使用统一的JSON清洗方法 - yield await SSEResponse.send_progress("解析AI返回结果...", 80) + yield await tracker.parsing("解析AI返回结果...") try: logger.info(f"🔍 开始清洗JSON,原始长度: {len(accumulated_text)}") @@ -1677,11 +1622,7 @@ async def world_building_regenerate_generator( logger.error(f" 原始内容预览: {accumulated_text[:200]}") world_retry_count += 1 if world_retry_count < MAX_WORLD_RETRIES: - yield await SSEResponse.send_progress( - f"⚠️ JSON解析失败,准备重试...", - 10 + world_retry_count * 5, - "warning" - ) + yield await tracker.retry(world_retry_count, MAX_WORLD_RETRIES, "JSON解析失败") continue else: # 达到最大重试次数,使用默认值 @@ -1697,11 +1638,7 @@ async def world_building_regenerate_generator( logger.error(f"❌ 世界观重新生成异常(尝试{world_retry_count+1}/{MAX_WORLD_RETRIES}): {type(e).__name__}: {e}") world_retry_count += 1 if world_retry_count < MAX_WORLD_RETRIES: - yield await SSEResponse.send_progress( - f"⚠️ 生成异常,准备重试...", - 10 + world_retry_count * 5, - "warning" - ) + yield await tracker.retry(world_retry_count, MAX_WORLD_RETRIES, "生成异常") continue else: # 最后一次重试仍失败,抛出异常 @@ -1709,18 +1646,19 @@ async def world_building_regenerate_generator( raise # 不保存到数据库,仅返回生成结果供用户预览 - yield await SSEResponse.send_progress("生成完成,等待用户确认...", 90) + yield await tracker.saving("生成完成,等待用户确认...", 0.5) + + yield await tracker.complete() # 发送最终结果(不包含project_id,表示未保存) - yield await SSEResponse.send_result({ + yield await tracker.result({ "time_period": world_data.get("time_period"), "location": world_data.get("location"), "atmosphere": world_data.get("atmosphere"), "rules": world_data.get("rules") }) - yield await SSEResponse.send_progress("完成!", 100, "success") - yield await SSEResponse.send_done() + yield await tracker.done() except GeneratorExit: logger.warning("世界观重新生成器被提前关闭") @@ -1732,7 +1670,7 @@ async def world_building_regenerate_generator( if not db_committed and db.in_transaction(): await db.rollback() logger.info("世界观重新生成事务已回滚(异常)") - yield await SSEResponse.send_error(f"生成失败: {str(e)}") + yield await tracker.error(f"生成失败: {str(e)}") @router.post("/world-building/{project_id}/regenerate", summary="流式重新生成世界观") @@ -1751,5 +1689,3 @@ async def regenerate_world_building_stream( if hasattr(request.state, 'user_id'): data['user_id'] = request.state.user_id return create_sse_response(world_building_regenerate_generator(project_id, data, db, user_ai_service)) - - diff --git a/backend/app/config.py b/backend/app/config.py index 6d6a43a..7c28375 100644 --- a/backend/app/config.py +++ b/backend/app/config.py @@ -77,10 +77,8 @@ class Settings(BaseSettings): default_temperature: float = 0.7 default_max_tokens: int = 32000 - # MCP适配器配置 - enable_mcp_adapter: bool = True # 是否启用MCP适配器(自动检测API能力) - mcp_adapter_cache_ttl_hours: int = 24 # API能力检测缓存时长(小时) - mcp_adapter_auto_fallback: bool = True # 是否启用自动降级(FC失败时切换到提示词注入) + # MCP配置 + mcp_max_rounds: int = 3 # MCP工具调用最大轮数(全局统一控制) # LinuxDO OAuth2 配置 LINUXDO_CLIENT_ID: Optional[str] = None diff --git a/backend/app/database.py b/backend/app/database.py index 3101a98..29afc91 100644 --- a/backend/app/database.py +++ b/backend/app/database.py @@ -167,7 +167,7 @@ async def get_db(request: Request): _session_stats["created"] += 1 _session_stats["active"] += 1 - logger.debug(f"📊 会话创建 [User:{user_id}][ID:{session_id}] - 活跃:{_session_stats['active']}, 总创建:{_session_stats['created']}, 总关闭:{_session_stats['closed']}") + # logger.debug(f"📊 会话创建 [User:{user_id}][ID:{session_id}] - 活跃:{_session_stats['active']}, 总创建:{_session_stats['created']}, 总关闭:{_session_stats['closed']}") try: yield session diff --git a/backend/app/logger.py b/backend/app/logger.py index 0c699c0..475451d 100644 --- a/backend/app/logger.py +++ b/backend/app/logger.py @@ -130,11 +130,15 @@ def _configure_third_party_loggers(): logging.getLogger('sqlalchemy.dialects').setLevel(logging.WARNING) logging.getLogger('sqlalchemy.orm').setLevel(logging.WARNING) + # aiosqlite - 异步SQLite,禁用DEBUG日志 + logging.getLogger('aiosqlite').setLevel(logging.WARNING) + # Watchfiles - 开发时的文件监控,降低级别 logging.getLogger('watchfiles').setLevel(logging.WARNING) - # httpx - HTTP客户端 + # httpx/httpcore - HTTP客户端,禁用DEBUG日志 logging.getLogger('httpx').setLevel(logging.WARNING) + logging.getLogger('httpcore').setLevel(logging.WARNING) # openai/anthropic - AI客户端库 logging.getLogger('openai').setLevel(logging.WARNING) diff --git a/backend/app/main.py b/backend/app/main.py index 5a3a57e..526859e 100644 --- a/backend/app/main.py +++ b/backend/app/main.py @@ -12,7 +12,7 @@ from app.database import close_db, _session_stats from app.logger import setup_logging, get_logger from app.middleware import RequestIDMiddleware from app.middleware.auth_middleware import AuthMiddleware -from app.mcp.registry import mcp_registry +from app.mcp import mcp_client, register_status_sync setup_logging( level=config_settings.log_level, @@ -27,12 +27,15 @@ logger = get_logger(__name__) @asynccontextmanager async def lifespan(app: FastAPI): """应用生命周期管理""" + # 注册MCP状态同步服务 + register_status_sync() + logger.info("应用启动完成") yield # 清理MCP插件 - await mcp_registry.cleanup_all() + await mcp_client.cleanup() # 清理HTTP客户端池 from app.services.ai_service import cleanup_http_clients diff --git a/backend/app/mcp/__init__.py b/backend/app/mcp/__init__.py index 1a8b35f..f6a006c 100644 --- a/backend/app/mcp/__init__.py +++ b/backend/app/mcp/__init__.py @@ -1,4 +1,36 @@ -"""MCP插件系统""" -from .registry import mcp_registry +"""MCP模块 - 统一的MCP客户端管理 -__all__ = ["mcp_registry"] \ No newline at end of file +本模块提供MCP(Model Context Protocol)客户端的统一管理接口。 + +推荐使用方式: + from app.mcp import mcp_client, MCPPluginConfig + + # 注册插件 + await mcp_client.register(MCPPluginConfig( + user_id="user123", + plugin_name="exa-search", + url="http://localhost:8000/mcp" + )) + + # 获取工具 + tools = await mcp_client.get_tools("user123", "exa-search") + + # 调用工具 + result = await mcp_client.call_tool("user123", "exa-search", "web_search", {"query": "..."}) + + # 注册状态变更回调 + from app.mcp.status_sync import register_status_sync + register_status_sync() +""" + +from .facade import mcp_client, MCPClientFacade, MCPPluginConfig, MCPError, PluginStatus +from .status_sync import register_status_sync + +__all__ = [ + "mcp_client", + "MCPClientFacade", + "MCPPluginConfig", + "MCPError", + "PluginStatus", + "register_status_sync", +] \ No newline at end of file diff --git a/backend/app/mcp/adapters/__init__.py b/backend/app/mcp/adapters/__init__.py deleted file mode 100644 index 54f5236..0000000 --- a/backend/app/mcp/adapters/__init__.py +++ /dev/null @@ -1,14 +0,0 @@ -"""MCP适配器模块 - 支持多种AI API的工具调用方式""" - -from .base import BaseMCPAdapter, AdapterType -from .prompt_injection import PromptInjectionAdapter -from .function_calling import FunctionCallingAdapter -from .universal import UniversalMCPAdapter - -__all__ = [ - "BaseMCPAdapter", - "AdapterType", - "PromptInjectionAdapter", - "FunctionCallingAdapter", - "UniversalMCPAdapter", -] \ No newline at end of file diff --git a/backend/app/mcp/adapters/base.py b/backend/app/mcp/adapters/base.py deleted file mode 100644 index a744987..0000000 --- a/backend/app/mcp/adapters/base.py +++ /dev/null @@ -1,89 +0,0 @@ -"""MCP适配器基类""" - -from abc import ABC, abstractmethod -from enum import Enum -from typing import Dict, Any, List, Optional -from dataclasses import dataclass - - -class AdapterType(Enum): - """适配器类型""" - FUNCTION_CALLING = "function_calling" # 标准Function Calling - PROMPT_INJECTION = "prompt_injection" # 提示词注入 - REACT = "react" # ReAct模式 - XML = "xml" # XML标记 - - -@dataclass -class ToolCallResult: - """工具调用结果""" - tool_calls: List[Dict[str, Any]] # 解析出的工具调用 - raw_response: str # 原始AI响应 - has_tool_calls: bool # 是否包含工具调用 - needs_continuation: bool = False # 是否需要继续对话 - - -class BaseMCPAdapter(ABC): - """MCP适配器基类""" - - def __init__(self): - self.adapter_type: AdapterType = AdapterType.PROMPT_INJECTION - - @abstractmethod - def format_tools_for_prompt( - self, - tools: List[Dict[str, Any]], - user_message: str - ) -> str: - """ - 将工具列表格式化为提示词 - - Args: - tools: MCP工具列表 - user_message: 用户消息 - - Returns: - 格式化后的提示词 - """ - pass - - @abstractmethod - def parse_tool_calls(self, ai_response: str) -> ToolCallResult: - """ - 从AI响应中解析工具调用 - - Args: - ai_response: AI的原始响应 - - Returns: - 解析结果 - """ - pass - - @abstractmethod - def build_continuation_prompt( - self, - original_message: str, - ai_response: str, - tool_results: List[Dict[str, Any]] - ) -> str: - """ - 构建包含工具结果的继续对话提示词 - - Args: - original_message: 原始用户消息 - ai_response: AI响应 - tool_results: 工具执行结果 - - Returns: - 继续对话的提示词 - """ - pass - - def supports_native_tools(self) -> bool: - """是否支持原生工具调用(如Function Calling)""" - return False - - def get_adapter_type(self) -> AdapterType: - """获取适配器类型""" - return self.adapter_type \ No newline at end of file diff --git a/backend/app/mcp/adapters/function_calling.py b/backend/app/mcp/adapters/function_calling.py deleted file mode 100644 index d302bb5..0000000 --- a/backend/app/mcp/adapters/function_calling.py +++ /dev/null @@ -1,171 +0,0 @@ -"""Function Calling适配器 - 支持原生Function Calling的API""" - -import json -from typing import Dict, Any, List -from app.mcp.adapters.base import BaseMCPAdapter, AdapterType, ToolCallResult -from app.logger import get_logger - -logger = get_logger(__name__) - - -class FunctionCallingAdapter(BaseMCPAdapter): - """Function Calling适配器 - 用于支持原生工具调用的AI API(如OpenAI)""" - - def __init__(self): - super().__init__() - self.adapter_type = AdapterType.FUNCTION_CALLING - - def supports_native_tools(self) -> bool: - """支持原生工具调用""" - return True - - def format_tools_for_prompt( - self, - tools: List[Dict[str, Any]], - user_message: str - ) -> str: - """ - Function Calling模式下,工具通过API参数传递,不需要修改提示词 - - Returns: - 原始用户消息 - """ - return user_message - - def get_tools_for_api(self, tools: List[Dict[str, Any]]) -> List[Dict[str, Any]]: - """ - 获取适用于API的工具格式 - - Args: - tools: MCP工具列表 - - Returns: - 适用于OpenAI Function Calling的工具格式 - """ - return tools - - def parse_tool_calls(self, ai_response: Any) -> ToolCallResult: - """ - 从AI响应中解析工具调用(Function Calling格式) - - Args: - ai_response: AI响应对象(通常是OpenAI的ChatCompletion对象) - - Returns: - 解析结果 - """ - - try: - # 处理不同类型的响应 - if isinstance(ai_response, dict): - # 字典格式(OpenAI API响应) - message = ai_response.get("choices", [{}])[0].get("message", {}) - tool_calls = message.get("tool_calls", []) - content = message.get("content", "") - - elif hasattr(ai_response, "choices"): - # 对象格式(OpenAI SDK响应) - message = ai_response.choices[0].message - tool_calls = getattr(message, "tool_calls", None) or [] - content = getattr(message, "content", "") or "" - - # 转换为字典格式 - if tool_calls: - tool_calls = [ - { - "id": tc.id, - "type": tc.type, - "function": { - "name": tc.function.name, - "arguments": tc.function.arguments - } - } - for tc in tool_calls - ] - else: - # 字符串格式(降级为文本响应) - return ToolCallResult( - tool_calls=[], - raw_response=str(ai_response), - has_tool_calls=False - ) - - has_tool_calls = len(tool_calls) > 0 - - if has_tool_calls: - logger.info(f"✅ Function Calling模式解析出 {len(tool_calls)} 个工具调用") - for tc in tool_calls: - logger.info(f" - {tc['function']['name']}") - - return ToolCallResult( - tool_calls=tool_calls, - raw_response=content or "", - has_tool_calls=has_tool_calls, - needs_continuation=has_tool_calls - ) - - except Exception as e: - logger.error(f"❌ 解析Function Calling响应失败: {e}", exc_info=True) - return ToolCallResult( - tool_calls=[], - raw_response=str(ai_response), - has_tool_calls=False - ) - - def build_continuation_prompt( - self, - original_message: str, - ai_response: str, - tool_results: List[Dict[str, Any]] - ) -> str: - """ - 构建包含工具结果的继续对话提示词 - - 在Function Calling模式下,这通常不需要,因为工具结果会作为消息历史的一部分 - """ - # Function Calling模式下通常通过消息历史传递工具结果 - # 这里提供一个降级方案 - results_text = "\n\n".join([ - f"工具 {r['name']} 的结果:\n{r['content']}" - for r in tool_results - ]) - - return f"{original_message}\n\n工具执行结果:\n{results_text}\n\n请基于以上工具结果回答用户的问题。" - - def build_messages_with_tool_results( - self, - messages: List[Dict[str, Any]], - tool_calls: List[Dict[str, Any]], - tool_results: List[Dict[str, Any]] - ) -> List[Dict[str, Any]]: - """ - 构建包含工具结果的消息历史(Function Calling标准格式) - - Args: - messages: 原始消息历史 - tool_calls: AI的工具调用 - tool_results: 工具执行结果 - - Returns: - 更新后的消息历史 - """ - - new_messages = messages.copy() - - # 添加助手的工具调用消息 - new_messages.append({ - "role": "assistant", - "content": None, - "tool_calls": tool_calls - }) - - # 添加工具结果消息 - for result in tool_results: - new_messages.append({ - "role": "tool", - "tool_call_id": result.get("tool_call_id", ""), - "name": result.get("name", ""), - "content": result.get("content", "") - }) - - return new_messages \ No newline at end of file diff --git a/backend/app/mcp/adapters/prompt_injection.py b/backend/app/mcp/adapters/prompt_injection.py deleted file mode 100644 index 15ea7b8..0000000 --- a/backend/app/mcp/adapters/prompt_injection.py +++ /dev/null @@ -1,274 +0,0 @@ -"""提示词注入适配器 - 最通用的MCP工具调用方式""" - -import re -import json -from typing import Dict, Any, List -from app.mcp.adapters.base import BaseMCPAdapter, AdapterType, ToolCallResult -from app.logger import get_logger - -logger = get_logger(__name__) - - -class PromptInjectionAdapter(BaseMCPAdapter): - """提示词注入适配器 - 将工具转换为文本描述,通过提示词引导AI调用""" - - def __init__(self): - super().__init__() - self.adapter_type = AdapterType.PROMPT_INJECTION - - def format_tools_for_prompt( - self, - tools: List[Dict[str, Any]], - user_message: str - ) -> str: - """将工具列表注入到提示词中""" - - if not tools: - return user_message - - # 格式化工具描述 - tool_descriptions = self._format_tools_as_text(tools) - - # 构建增强的提示词 - enhanced_prompt = f"""你现在可以使用以下工具来帮助回答用户的问题。 - -## 可用工具 - -{tool_descriptions} - -## 工具使用说明 - -当你需要使用工具时,请按以下XML格式输出(可以一次调用多个工具): - - - -工具名称 - -{{ - "参数名1": "参数值1", - "参数名2": "参数值2" -}} - - - - -## 重要提示 - -1. 只有在确实需要使用工具时才调用工具 -2. 参数必须是有效的JSON格式 -3. 仔细检查参数是否符合工具的要求 -4. 可以在一个标签内包含多个 -5. 调用工具后,你会收到工具的执行结果,然后需要基于结果继续回答 - ---- - -用户问题:{user_message} - -请分析问题,判断是否需要使用工具。如果需要,先输出工具调用,然后等待结果。如果不需要,直接回答问题。""" - - return enhanced_prompt - - def _format_tools_as_text(self, tools: List[Dict[str, Any]]) -> str: - """将工具格式化为可读的文本描述""" - lines = [] - - for i, tool in enumerate(tools, 1): - func = tool.get("function", {}) - name = func.get("name", "unknown") - description = func.get("description", "无描述") - parameters = func.get("parameters", {}) - - lines.append(f"### {i}. {name}") - lines.append(f"**描述**: {description}") - lines.append("") - - # 格式化参数信息 - if parameters and "properties" in parameters: - lines.append("**参数**:") - properties = parameters.get("properties", {}) - required = parameters.get("required", []) - - for param_name, param_info in properties.items(): - param_type = param_info.get("type", "string") - param_desc = param_info.get("description", "") - is_required = "必填" if param_name in required else "可选" - - lines.append(f" - `{param_name}` ({param_type}, {is_required}): {param_desc}") - lines.append("") - - # 添加示例 - if "example" in func: - lines.append(f"**示例**: {json.dumps(func['example'], ensure_ascii=False)}") - lines.append("") - - return "\n".join(lines) - - def parse_tool_calls(self, ai_response) -> ToolCallResult: - """从AI响应中解析工具调用""" - - tool_calls = [] - - try: - # 处理不同类型的响应 - if isinstance(ai_response, dict): - # 如果是字典,提取content字段 - ai_response = ai_response.get("choices", [{}])[0].get("message", {}).get("content", "") - if not ai_response: - return ToolCallResult( - tool_calls=[], - raw_response="", - has_tool_calls=False - ) - elif not isinstance(ai_response, str): - # 转换为字符串 - ai_response = str(ai_response) - - # 使用正则提取 标签内容 - tool_calls_match = re.search( - r'(.*?)', - ai_response, - re.DOTALL | re.IGNORECASE - ) - - if not tool_calls_match: - # 没有找到工具调用 - return ToolCallResult( - tool_calls=[], - raw_response=ai_response, - has_tool_calls=False - ) - - tool_calls_content = tool_calls_match.group(1) - - # 提取所有 标签 - tool_call_pattern = r'(.*?)' - tool_call_matches = re.findall( - tool_call_pattern, - tool_calls_content, - re.DOTALL | re.IGNORECASE - ) - - for i, tool_call_content in enumerate(tool_call_matches): - # 提取工具名称 - name_match = re.search( - r'(.*?)', - tool_call_content, - re.DOTALL | re.IGNORECASE - ) - - # 提取参数 - args_match = re.search( - r'(.*?)', - tool_call_content, - re.DOTALL | re.IGNORECASE - ) - - if name_match and args_match: - tool_name = name_match.group(1).strip() - arguments_str = args_match.group(1).strip() - - try: - # 解析JSON参数 - arguments = json.loads(arguments_str) - - # 构建标准格式的工具调用 - tool_calls.append({ - "id": f"call_{i}", - "type": "function", - "function": { - "name": tool_name, - "arguments": json.dumps(arguments, ensure_ascii=False) - } - }) - - logger.info(f"✅ 解析工具调用: {tool_name}") - - except json.JSONDecodeError as e: - logger.error(f"❌ 解析工具参数失败: {arguments_str}, 错误: {e}") - continue - - has_tool_calls = len(tool_calls) > 0 - - if has_tool_calls: - logger.info(f"✅ 从响应中解析出 {len(tool_calls)} 个工具调用") - - return ToolCallResult( - tool_calls=tool_calls, - raw_response=ai_response, - has_tool_calls=has_tool_calls, - needs_continuation=has_tool_calls - ) - - except Exception as e: - logger.error(f"❌ 解析工具调用失败: {e}", exc_info=True) - return ToolCallResult( - tool_calls=[], - raw_response=ai_response, - has_tool_calls=False - ) - - def build_continuation_prompt( - self, - original_message: str, - ai_response: str, - tool_results: List[Dict[str, Any]] - ) -> str: - """构建包含工具结果的继续对话提示词""" - - # 格式化工具结果 - results_text = self._format_tool_results(tool_results) - - continuation = f"""你之前尝试使用工具来回答用户的问题。 - -原始问题:{original_message} - -你的工具调用: -{self._extract_tool_calls_text(ai_response)} - -工具执行结果: -{results_text} - -现在,请基于这些工具的执行结果,给出完整、详细的回答。不要重复调用工具,直接使用已有的结果来回答用户的问题。""" - - return continuation - - def _format_tool_results(self, tool_results: List[Dict[str, Any]]) -> str: - """格式化工具结果为可读文本""" - lines = [] - - for i, result in enumerate(tool_results, 1): - tool_name = result.get("name", "unknown") - success = result.get("success", False) - content = result.get("content", "") - - status = "✅ 成功" if success else "❌ 失败" - lines.append(f"{i}. {tool_name} - {status}") - - if success: - # 尝试美化JSON内容 - try: - if isinstance(content, str): - content_obj = json.loads(content) - content = json.dumps(content_obj, ensure_ascii=False, indent=2) - except: - pass - lines.append(f"```\n{content}\n```") - else: - error = result.get("error", "未知错误") - lines.append(f"错误信息: {error}") - - lines.append("") - - return "\n".join(lines) - - def _extract_tool_calls_text(self, ai_response: str) -> str: - """从AI响应中提取工具调用部分的文本""" - match = re.search( - r'(.*?)', - ai_response, - re.DOTALL | re.IGNORECASE - ) - - if match: - return match.group(0) - return "(未找到工具调用)" \ No newline at end of file diff --git a/backend/app/mcp/adapters/universal.py b/backend/app/mcp/adapters/universal.py deleted file mode 100644 index aa2be30..0000000 --- a/backend/app/mcp/adapters/universal.py +++ /dev/null @@ -1,353 +0,0 @@ -"""通用MCP适配器 - 自动检测API能力并选择最佳适配器""" - -import time -import asyncio -from typing import Dict, Any, List, Optional -from datetime import datetime, timedelta -from dataclasses import dataclass - -from app.mcp.adapters.base import BaseMCPAdapter, AdapterType, ToolCallResult -from app.mcp.adapters.prompt_injection import PromptInjectionAdapter -from app.mcp.adapters.function_calling import FunctionCallingAdapter -from app.logger import get_logger - -logger = get_logger(__name__) - - -@dataclass -class APICapability: - """API能力检测结果""" - supports_function_calling: bool - tested_at: datetime - test_duration_ms: float - error_message: Optional[str] = None - - -class UniversalMCPAdapter: - """ - 通用MCP适配器管理器 - - 功能: - 1. 自动检测API是否支持Function Calling - 2. 缓存检测结果 - 3. 自动降级策略:FC失败时切换到提示词注入 - 4. 提供统一接口 - """ - - def __init__( - self, - cache_ttl_hours: int = 24, - enable_auto_fallback: bool = True - ): - """ - 初始化通用适配器 - - Args: - cache_ttl_hours: 能力检测缓存时长(小时) - enable_auto_fallback: 是否启用自动降级 - """ - # 适配器实例 - self.adapters = { - AdapterType.FUNCTION_CALLING: FunctionCallingAdapter(), - AdapterType.PROMPT_INJECTION: PromptInjectionAdapter() - } - - # API能力缓存: {api_identifier: APICapability} - self._capability_cache: Dict[str, APICapability] = {} - self._cache_ttl = timedelta(hours=cache_ttl_hours) - self._cache_lock = asyncio.Lock() - - # 配置 - self._enable_auto_fallback = enable_auto_fallback - - logger.info( - f"✅ UniversalMCPAdapter初始化完成 " - f"(缓存TTL={cache_ttl_hours}小时, 自动降级={'开启' if enable_auto_fallback else '关闭'})" - ) - - async def get_adapter( - self, - api_identifier: str, - test_function: Optional[callable] = None - ) -> BaseMCPAdapter: - """ - 获取适合当前API的适配器 - - Args: - api_identifier: API标识符(如"openai_official", "azure_openai"等) - test_function: 可选的测试函数,用于检测API能力 - - Returns: - 最适合的适配器实例 - """ - - # 检查缓存 - capability = await self._get_cached_capability(api_identifier) - - if capability is None and test_function: - # 缓存未命中,执行检测 - capability = await self._detect_capability(api_identifier, test_function) - - # 选择适配器 - if capability and capability.supports_function_calling: - logger.info(f"🎯 使用Function Calling适配器: {api_identifier}") - return self.adapters[AdapterType.FUNCTION_CALLING] - else: - logger.info(f"🎯 使用提示词注入适配器: {api_identifier}") - return self.adapters[AdapterType.PROMPT_INJECTION] - - async def _get_cached_capability( - self, - api_identifier: str - ) -> Optional[APICapability]: - """获取缓存的能力检测结果""" - - async with self._cache_lock: - if api_identifier not in self._capability_cache: - return None - - capability = self._capability_cache[api_identifier] - - # 检查是否过期 - if datetime.now() - capability.tested_at > self._cache_ttl: - logger.info(f"⏰ API能力缓存过期: {api_identifier}") - del self._capability_cache[api_identifier] - return None - - logger.debug(f"🎯 API能力缓存命中: {api_identifier}") - return capability - - async def _detect_capability( - self, - api_identifier: str, - test_function: callable - ) -> APICapability: - """ - 检测API能力 - - Args: - api_identifier: API标识符 - test_function: 测试函数,应该尝试使用Function Calling - - Returns: - 能力检测结果 - """ - - logger.info(f"🔍 开始检测API能力: {api_identifier}") - start_time = time.time() - - try: - # 调用测试函数 - result = await test_function() - - # 判断是否成功 - supports_fc = self._is_function_calling_response(result) - - duration_ms = (time.time() - start_time) * 1000 - - capability = APICapability( - supports_function_calling=supports_fc, - tested_at=datetime.now(), - test_duration_ms=duration_ms - ) - - # 缓存结果 - async with self._cache_lock: - self._capability_cache[api_identifier] = capability - - status = "✅ 支持" if supports_fc else "❌ 不支持" - logger.info( - f"{status} Function Calling: {api_identifier} " - f"(耗时: {duration_ms:.2f}ms)" - ) - - return capability - - except Exception as e: - duration_ms = (time.time() - start_time) * 1000 - - logger.warning( - f"⚠️ API能力检测失败: {api_identifier}, 错误: {e}, " - f"将使用提示词注入模式" - ) - - capability = APICapability( - supports_function_calling=False, - tested_at=datetime.now(), - test_duration_ms=duration_ms, - error_message=str(e) - ) - - # 缓存失败结果(避免重复测试) - async with self._cache_lock: - self._capability_cache[api_identifier] = capability - - return capability - - def _is_function_calling_response(self, response: Any) -> bool: - """ - 判断响应是否是Function Calling格式 - - Args: - response: API响应 - - Returns: - 是否支持Function Calling - """ - - try: - # 检查字典格式 - if isinstance(response, dict): - message = response.get("choices", [{}])[0].get("message", {}) - return "tool_calls" in message or "function_call" in message - - # 检查对象格式(OpenAI SDK) - if hasattr(response, "choices"): - message = response.choices[0].message - return hasattr(message, "tool_calls") or hasattr(message, "function_call") - - return False - - except Exception: - return False - - async def call_with_fallback( - self, - api_identifier: str, - tools: List[Dict[str, Any]], - user_message: str, - call_function: callable, - test_function: Optional[callable] = None - ) -> ToolCallResult: - """ - 带降级策略的工具调用 - - Args: - api_identifier: API标识符 - tools: MCP工具列表 - user_message: 用户消息 - call_function: 实际调用API的函数 - test_function: 可选的测试函数 - - Returns: - 工具调用结果 - """ - - # 获取适配器 - adapter = await self.get_adapter(api_identifier, test_function) - - # 首次尝试 - try: - if adapter.supports_native_tools(): - # Function Calling模式 - logger.info("🚀 尝试使用Function Calling模式") - result = await self._try_function_calling( - tools, user_message, call_function, adapter - ) - else: - # 提示词注入模式 - logger.info("🚀 使用提示词注入模式") - result = await self._try_prompt_injection( - tools, user_message, call_function, adapter - ) - - return result - - except Exception as e: - logger.error(f"❌ 工具调用失败: {e}") - - # 自动降级 - if self._enable_auto_fallback and adapter.supports_native_tools(): - logger.warning("⚠️ Function Calling失败,降级到提示词注入模式") - - # 更新缓存,标记为不支持 - async with self._cache_lock: - self._capability_cache[api_identifier] = APICapability( - supports_function_calling=False, - tested_at=datetime.now(), - test_duration_ms=0, - error_message=str(e) - ) - - # 使用提示词注入重试 - fallback_adapter = self.adapters[AdapterType.PROMPT_INJECTION] - return await self._try_prompt_injection( - tools, user_message, call_function, fallback_adapter - ) - - raise - - async def _try_function_calling( - self, - tools: List[Dict[str, Any]], - user_message: str, - call_function: callable, - adapter: FunctionCallingAdapter - ) -> ToolCallResult: - """尝试Function Calling模式""" - - # Function Calling不需要修改提示词 - response = await call_function( - message=user_message, - tools_param=tools, - tool_choice_param="auto" - ) - - return adapter.parse_tool_calls(response) - - async def _try_prompt_injection( - self, - tools: List[Dict[str, Any]], - user_message: str, - call_function: callable, - adapter: PromptInjectionAdapter - ) -> ToolCallResult: - """尝试提示词注入模式""" - - # 注入工具到提示词 - enhanced_prompt = adapter.format_tools_for_prompt(tools, user_message) - - # 调用API(不传tools参数) - response = await call_function( - message=enhanced_prompt, - tools_param=None, - tool_choice_param=None - ) - - # 从文本响应中解析工具调用 - return adapter.parse_tool_calls(response) - - def clear_cache(self, api_identifier: Optional[str] = None): - """ - 清理能力缓存 - - Args: - api_identifier: 可选,只清理特定API的缓存 - """ - if api_identifier: - if api_identifier in self._capability_cache: - del self._capability_cache[api_identifier] - logger.info(f"🧹 已清理API能力缓存: {api_identifier}") - else: - self._capability_cache.clear() - logger.info("🧹 已清理所有API能力缓存") - - def get_cache_stats(self) -> Dict[str, Any]: - """获取缓存统计信息""" - return { - "total_cached": len(self._capability_cache), - "cache_ttl_hours": self._cache_ttl.total_seconds() / 3600, - "cached_apis": [ - { - "api_identifier": api_id, - "supports_fc": cap.supports_function_calling, - "tested_at": cap.tested_at.isoformat(), - "test_duration_ms": cap.test_duration_ms - } - for api_id, cap in self._capability_cache.items() - ] - } - - -# 全局单例 -universal_mcp_adapter = UniversalMCPAdapter() \ No newline at end of file diff --git a/backend/app/mcp/facade.py b/backend/app/mcp/facade.py new file mode 100644 index 0000000..162646e --- /dev/null +++ b/backend/app/mcp/facade.py @@ -0,0 +1,1171 @@ +"""MCP客户端统一门面 - 所有MCP操作的唯一入口 + +本模块提供统一的MCP(Model Context Protocol)客户端接口, +整合了连接管理、工具操作、格式转换、缓存和指标收集等功能。 + +使用示例: + from app.mcp import mcp_client, MCPPluginConfig + + # 注册插件 + await mcp_client.register(MCPPluginConfig( + user_id="user123", + plugin_name="exa-search", + url="http://localhost:8000/mcp" + )) + + # 获取工具列表 + tools = await mcp_client.get_tools("user123", "exa-search") + + # 调用工具 + result = await mcp_client.call_tool("user123", "exa-search", "web_search", {"query": "..."}) + + # 注册状态变更回调 + async def on_status_change(event): + print(f"插件 {event['plugin_name']} 状态: {event['old_status']} -> {event['new_status']}") + + mcp_client.register_status_callback(on_status_change) +""" + +from typing import Dict, Any, List, Optional, Callable, Awaitable +from dataclasses import dataclass, field +from datetime import datetime, timedelta +from collections import defaultdict +from enum import Enum +import asyncio +import time +import json + +from mcp import ClientSession, types +from mcp.client.streamable_http import streamablehttp_client +from mcp.client.sse import sse_client +from anyio import ClosedResourceError + +from app.mcp.config import mcp_config +from app.logger import get_logger + +logger = get_logger(__name__) + + +# ==================== 数据结构 ==================== + +class PluginStatus(str, Enum): + """插件状态枚举""" + ACTIVE = "active" + INACTIVE = "inactive" + DEGRADED = "degraded" + ERROR = "error" + + +# 状态变更回调类型 +StatusCallback = Callable[[Dict[str, Any]], Awaitable[None]] + + +@dataclass +class MCPPluginConfig: + """MCP插件配置""" + user_id: str + plugin_name: str + url: str + plugin_type: str = "streamable_http" # streamable_http, sse, http + headers: Optional[Dict[str, str]] = None + env: Optional[Dict[str, str]] = None + timeout: float = 60.0 + + +@dataclass +class SessionInfo: + """会话信息""" + session: ClientSession + url: str + plugin_type: str = "streamable_http" + created_at: float = field(default_factory=time.time) + last_access: float = field(default_factory=time.time) + request_count: int = 0 + error_count: int = 0 + status: str = "active" # active, degraded, error + _context_stack: List = field(default_factory=list) + _expiry_warned: bool = False + + @property + def error_rate(self) -> float: + """计算错误率""" + if self.request_count == 0: + return 0.0 + return self.error_count / self.request_count + + +@dataclass +class ToolCacheEntry: + """工具缓存条目""" + tools: List[Dict[str, Any]] + expire_time: datetime + hit_count: int = 0 + + +@dataclass +class ToolMetrics: + """工具调用指标""" + total_calls: int = 0 + success_calls: int = 0 + failed_calls: int = 0 + total_duration_ms: float = 0.0 + last_call_time: Optional[datetime] = None + + @property + def avg_duration_ms(self) -> float: + """平均调用时间""" + return self.total_duration_ms / self.total_calls if self.total_calls > 0 else 0.0 + + @property + def success_rate(self) -> float: + """成功率""" + return self.success_calls / self.total_calls if self.total_calls > 0 else 0.0 + + def record_success(self, duration_ms: float): + """记录成功调用""" + self.total_calls += 1 + self.success_calls += 1 + self.total_duration_ms += duration_ms + self.last_call_time = datetime.now() + + def record_failure(self, duration_ms: float): + """记录失败调用""" + self.total_calls += 1 + self.failed_calls += 1 + self.total_duration_ms += duration_ms + self.last_call_time = datetime.now() + + +class MCPError(Exception): + """MCP操作异常""" + pass + + +# ==================== 统一门面 ==================== + +class MCPClientFacade: + """ + MCP客户端统一门面 + + 这是所有MCP操作的唯一入口,提供: + 1. 连接管理(注册、注销、测试) + 2. 工具操作(获取、调用、批量调用) + 3. 格式转换(MCP ↔ OpenAI Function Calling) + 4. 缓存和指标 + + 设计模式: + - 单例模式:全局唯一实例 + - 门面模式:统一对外接口 + + 线程安全: + - 使用asyncio.Lock保护会话操作 + - 使用用户级别的细粒度锁避免阻塞 + """ + + _instance: Optional['MCPClientFacade'] = None + + def __new__(cls): + """单例模式""" + if cls._instance is None: + cls._instance = super().__new__(cls) + cls._instance._initialized = False + return cls._instance + + def __init__(self): + if self._initialized: + return + + # 会话管理 + self._sessions: Dict[str, SessionInfo] = {} + self._session_lock = asyncio.Lock() + self._user_locks: Dict[str, asyncio.Lock] = {} + self._locks_lock = asyncio.Lock() + + # 工具缓存 + self._tool_cache: Dict[str, ToolCacheEntry] = {} + self._cache_ttl = timedelta(minutes=mcp_config.TOOL_CACHE_TTL_MINUTES) + + # 调用指标 + self._metrics: Dict[str, ToolMetrics] = defaultdict(ToolMetrics) + + # 后台任务 + self._cleanup_task: Optional[asyncio.Task] = None + self._health_check_task: Optional[asyncio.Task] = None + self._tasks_started = False + + # 状态变更回调 + self._status_callbacks: List[StatusCallback] = [] + + self._initialized = True + logger.info("✅ MCPClientFacade 初始化完成") + + def _get_key(self, user_id: str, plugin_name: str) -> str: + """生成会话键""" + return f"{user_id}:{plugin_name}" + + async def _get_user_lock(self, user_id: str) -> asyncio.Lock: + """获取用户专属锁(细粒度锁)""" + async with self._locks_lock: + if user_id not in self._user_locks: + self._user_locks[user_id] = asyncio.Lock() + return self._user_locks[user_id] + + def _ensure_background_tasks(self): + """确保后台任务已启动(延迟初始化)""" + if not self._tasks_started: + try: + loop = asyncio.get_running_loop() + if self._cleanup_task is None: + self._cleanup_task = asyncio.create_task(self._cleanup_loop()) + logger.info("✅ MCP后台清理任务已启动") + + if self._health_check_task is None: + self._health_check_task = asyncio.create_task(self._health_check_loop()) + logger.info("✅ MCP健康检查任务已启动") + + self._tasks_started = True + except RuntimeError: + # 没有运行中的事件循环,稍后再试 + pass + + async def _cleanup_loop(self): + """后台清理过期会话""" + while True: + try: + await asyncio.sleep(mcp_config.CLEANUP_INTERVAL_SECONDS) + await self._cleanup_expired_sessions() + except asyncio.CancelledError: + break + except Exception as e: + logger.error(f"清理任务异常: {e}") + + async def _health_check_loop(self): + """后台健康检查""" + while True: + try: + await asyncio.sleep(mcp_config.HEALTH_CHECK_INTERVAL_SECONDS) + await self._check_session_health() + except asyncio.CancelledError: + break + except Exception as e: + logger.error(f"健康检查任务异常: {e}") + + async def _cleanup_expired_sessions(self): + """清理过期的会话""" + now = time.time() + expired_keys = [] + + async with self._session_lock: + for key, session in list(self._sessions.items()): + if now - session.last_access > mcp_config.CLIENT_TTL_SECONDS: + expired_keys.append(key) + + if expired_keys: + logger.info(f"🧹 清理 {len(expired_keys)} 个过期的MCP会话") + for key in expired_keys: + user_id = key.split(':', 1)[0] + user_lock = await self._get_user_lock(user_id) + async with user_lock: + await self._close_session_unsafe(key) + + async def _check_session_health(self): + """检查会话健康状态""" + async with self._session_lock: + for key, session in list(self._sessions.items()): + # 检查错误率 + if session.request_count > mcp_config.MIN_REQUESTS_FOR_HEALTH_CHECK: + old_status = session.status + user_id, plugin_name = key.split(':', 1) + + if session.error_rate > mcp_config.ERROR_RATE_CRITICAL: + if session.status != "error": + session.status = "error" + logger.error(f"❌ 会话 {key} 错误率过高 ({session.error_rate:.1%})") + await self._emit_status_change(user_id, plugin_name, old_status, "error", + f"错误率过高: {session.error_rate:.1%}") + elif session.error_rate > mcp_config.ERROR_RATE_WARNING: + if session.status == "active": + session.status = "degraded" + logger.warning(f"⚠️ 会话 {key} 健康状况下降 ({session.error_rate:.1%})") + await self._emit_status_change(user_id, plugin_name, old_status, "degraded", + f"错误率较高: {session.error_rate:.1%}") + elif session.status == "degraded": + session.status = "active" + logger.info(f"✅ 会话 {key} 恢复正常") + await self._emit_status_change(user_id, plugin_name, old_status, "active", "恢复正常") + + # ==================== 连接管理 ==================== + + async def register(self, config: MCPPluginConfig) -> bool: + """ + 注册MCP插件并建立连接 + + Args: + config: 插件配置 + + Returns: + 是否注册成功 + """ + self._ensure_background_tasks() + + key = self._get_key(config.user_id, config.plugin_name) + user_lock = await self._get_user_lock(config.user_id) + + async with user_lock: + # 如果已存在,先关闭 + if key in self._sessions: + await self._close_session_unsafe(key) + + try: + logger.info(f"🔗 连接MCP服务器: {config.plugin_name} -> {config.url} (类型: {config.plugin_type})") + + # 根据类型选择客户端 + if config.plugin_type == "sse": + # SSE 客户端 - 返回 2 个值 + stream_ctx = sse_client( + url=config.url, + headers=config.headers, + timeout=config.timeout + ) + read, write = await stream_ctx.__aenter__() + else: + # streamable_http 客户端(默认,也用于 http 类型)- 返回 3 个值 + stream_ctx = streamablehttp_client( + url=config.url, + headers=config.headers, + timeout=config.timeout + ) + read, write, _ = await stream_ctx.__aenter__() + + session = ClientSession(read, write) + await session.__aenter__() + await session.initialize() + + now = time.time() + info = SessionInfo( + session=session, + url=config.url, + plugin_type=config.plugin_type, + created_at=now, + last_access=now, + _context_stack=[('stream', stream_ctx), ('session', session)] + ) + + async with self._session_lock: + self._sessions[key] = info + + logger.info(f"✅ MCP会话建立成功: {key}") + await self._emit_status_change(config.user_id, config.plugin_name, "inactive", "active", "连接成功") + return True + + except Exception as e: + logger.error(f"❌ MCP连接失败 {key}: {e}") + await self._emit_status_change(config.user_id, config.plugin_name, "inactive", "error", str(e)) + return False + + async def unregister(self, user_id: str, plugin_name: str): + """ + 注销MCP插件 + + Args: + user_id: 用户ID + plugin_name: 插件名称 + """ + key = self._get_key(user_id, plugin_name) + user_lock = await self._get_user_lock(user_id) + + old_status = self._sessions.get(key, SessionInfo(session=None, url="")).status if key in self._sessions else "active" + + async with user_lock: + await self._close_session_unsafe(key) + self._invalidate_cache(key) + + await self._emit_status_change(user_id, plugin_name, old_status, "inactive", "已注销") + + async def _close_session_unsafe(self, key: str): + """关闭会话(不加用户锁,需要调用者确保线程安全)""" + async with self._session_lock: + info = self._sessions.pop(key, None) + + if info: + # 按LIFO顺序清理上下文 + for ctx_type, ctx in reversed(info._context_stack): + try: + await ctx.__aexit__(None, None, None) + except RuntimeError as e: + if "cancel scope" in str(e).lower() or "different task" in str(e).lower(): + logger.debug(f"忽略{ctx_type}上下文清理的任务切换警告: {e}") + else: + logger.error(f"清理{ctx_type}上下文失败: {e}") + except Exception as e: + logger.debug(f"清理{ctx_type}上下文: {e}") + + logger.info(f"🗑️ 关闭MCP会话: {key}") + + async def _get_session(self, user_id: str, plugin_name: str) -> ClientSession: + """ + 获取会话 + + Args: + user_id: 用户ID + plugin_name: 插件名称 + + Returns: + ClientSession实例 + + Raises: + ValueError: 会话不存在 + """ + key = self._get_key(user_id, plugin_name) + + info = self._sessions.get(key) + if not info: + raise ValueError(f"MCP会话不存在: {plugin_name},请先调用register()") + + if info.status == "error": + logger.warning(f"⚠️ 会话 {key} 处于错误状态,可能需要重新注册") + + info.last_access = time.time() + info.request_count += 1 + return info.session + + async def ensure_registered( + self, + user_id: str, + plugin_name: str, + url: str, + plugin_type: str = "streamable_http", + headers: Optional[Dict[str, str]] = None + ) -> bool: + """ + 确保插件已注册(如果未注册则自动注册) + + Args: + user_id: 用户ID + plugin_name: 插件名称 + url: 服务器URL + plugin_type: 插件类型 (streamable_http, sse, http) + headers: HTTP头 + + Returns: + 是否成功 + """ + key = self._get_key(user_id, plugin_name) + + if key in self._sessions: + info = self._sessions[key] + # 检查URL和类型是否变化 + if info.url == url and info.plugin_type == plugin_type and info.status != "error": + return True + + # 注册 + return await self.register(MCPPluginConfig( + user_id=user_id, + plugin_name=plugin_name, + url=url, + plugin_type=plugin_type, + headers=headers + )) + + async def test_connection(self, user_id: str, plugin_name: str) -> Dict[str, Any]: + """ + 测试连接 + + Args: + user_id: 用户ID + plugin_name: 插件名称 + + Returns: + 测试结果字典 + """ + start = time.time() + + try: + session = await self._get_session(user_id, plugin_name) + result = await session.list_tools() + + tools = [ + {"name": t.name, "description": t.description or ""} + for t in result.tools + ] + + return { + "success": True, + "message": "连接成功", + "response_time_ms": round((time.time() - start) * 1000, 2), + "tools_count": len(tools), + "tools": tools + } + except Exception as e: + return { + "success": False, + "message": str(e), + "response_time_ms": round((time.time() - start) * 1000, 2), + "error_type": type(e).__name__ + } + + # ==================== 工具操作 ==================== + + async def get_tools( + self, + user_id: str, + plugin_name: str, + use_cache: bool = True + ) -> List[Dict[str, Any]]: + """ + 获取工具列表 + + Args: + user_id: 用户ID + plugin_name: 插件名称 + use_cache: 是否使用缓存 + + Returns: + 工具列表 [{"name": ..., "description": ..., "inputSchema": ...}] + """ + cache_key = self._get_key(user_id, plugin_name) + now = datetime.now() + + # 检查缓存 + if use_cache and cache_key in self._tool_cache: + entry = self._tool_cache[cache_key] + if now < entry.expire_time: + entry.hit_count += 1 + logger.debug(f"🎯 工具缓存命中: {cache_key} (命中次数: {entry.hit_count})") + return entry.tools + else: + del self._tool_cache[cache_key] + logger.debug(f"⏰ 工具缓存过期: {cache_key}") + + # 从服务器获取 + session = await self._get_session(user_id, plugin_name) + result = await session.list_tools() + + tools = [ + { + "name": t.name, + "description": t.description or "", + "inputSchema": t.inputSchema + } + for t in result.tools + ] + + # 更新缓存 + self._tool_cache[cache_key] = ToolCacheEntry( + tools=tools, + expire_time=now + self._cache_ttl + ) + + logger.info(f"获取到 {len(tools)} 个工具: {plugin_name}") + return tools + + async def call_tool( + self, + user_id: str, + plugin_name: str, + tool_name: str, + arguments: Dict[str, Any], + timeout: Optional[float] = None, + max_reconnect_attempts: int = 2 + ) -> Any: + """ + 调用单个工具 + + Args: + user_id: 用户ID + plugin_name: 插件名称 + tool_name: 工具名称 + arguments: 工具参数 + timeout: 超时时间(秒) + max_reconnect_attempts: 最大重连次数 + + Returns: + 工具执行结果 + """ + tool_key = f"{plugin_name}.{tool_name}" + start_time = time.time() + actual_timeout = timeout or mcp_config.TOOL_CALL_TIMEOUT_SECONDS + + for attempt in range(max_reconnect_attempts + 1): + try: + session = await self._get_session(user_id, plugin_name) + + logger.info(f"调用工具: {tool_key}") + logger.debug(f" 参数: {arguments}") + + # 带超时调用 + result = await asyncio.wait_for( + session.call_tool(tool_name, arguments), + timeout=actual_timeout + ) + + # 处理返回结果 + output = self._extract_tool_result(result) + + # 记录成功指标 + duration_ms = (time.time() - start_time) * 1000 + self._metrics[tool_key].record_success(duration_ms) + + logger.info(f"✅ 工具调用成功: {tool_key} ({duration_ms:.2f}ms)") + return output + + except asyncio.TimeoutError: + duration_ms = (time.time() - start_time) * 1000 + self._metrics[tool_key].record_failure(duration_ms) + raise MCPError(f"工具调用超时(>{actual_timeout}秒)") + + except ClosedResourceError as e: + # 连接已关闭,尝试重连 + if attempt < max_reconnect_attempts: + logger.warning(f"⚠️ MCP连接已关闭,尝试重连 (第{attempt + 1}/{max_reconnect_attempts}次)") + key = self._get_key(user_id, plugin_name) + + # 保存旧的会话信息用于重新注册 + old_info = None + async with self._session_lock: + if key in self._sessions: + old_info = self._sessions[key] + + # 关闭旧会话 + try: + await self._close_session_unsafe(key) + except Exception as close_err: + logger.debug(f"关闭旧会话时出错: {close_err}") + + # 使用旧的会话信息重新注册 + url = old_info.url if old_info else "" + plugin_type = old_info.plugin_type if old_info else "streamable_http" + + if url: + success = await self.ensure_registered( + user_id, plugin_name, url, plugin_type + ) + if success: + logger.info(f"✅ MCP会话重新建立成功: {key}") + await asyncio.sleep(0.5) + continue + + # 如果无法获取旧信息或重新注册失败,等待后重试 + await asyncio.sleep(0.5) + continue + else: + duration_ms = (time.time() - start_time) * 1000 + self._metrics[tool_key].record_failure(duration_ms) + raise MCPError(f"连接已关闭且重连失败 (尝试了{max_reconnect_attempts}次)") + + except ValueError as e: + # 会话不存在,尝试重新注册 + if "MCP会话不存在" in str(e) and attempt < max_reconnect_attempts: + logger.warning(f"⚠️ MCP会话不存在,尝试重新注册 (第{attempt + 1}/{max_reconnect_attempts}次)") + + # 尝试获取会话信息用于重新注册 + key = self._get_key(user_id, plugin_name) + old_info = None + async with self._session_lock: + if key in self._sessions: + old_info = self._sessions[key] + + url = old_info.url if old_info else "" + plugin_type = old_info.plugin_type if old_info else "streamable_http" + + if url: + success = await self.ensure_registered( + user_id, plugin_name, url, plugin_type + ) + if success: + logger.info(f"✅ MCP会话重新注册成功: {key}") + await asyncio.sleep(0.5) + continue + + await asyncio.sleep(0.5) + continue + else: + duration_ms = (time.time() - start_time) * 1000 + self._metrics[tool_key].record_failure(duration_ms) + raise MCPError(f"会话不存在: {e}") + + except Exception as e: + duration_ms = (time.time() - start_time) * 1000 + self._metrics[tool_key].record_failure(duration_ms) + + # 更新会话错误计数 + key = self._get_key(user_id, plugin_name) + if key in self._sessions: + session_info = self._sessions[key] + session_info.error_count += 1 + + # 检查是否需要更新状态 + if session_info.request_count >= mcp_config.MIN_REQUESTS_FOR_HEALTH_CHECK: + old_status = session_info.status + if session_info.error_rate > mcp_config.ERROR_RATE_CRITICAL and old_status != "error": + session_info.status = "error" + asyncio.create_task(self._emit_status_change( + user_id, plugin_name, old_status, "error", f"错误率过高: {session_info.error_rate:.1%}" + )) + elif session_info.error_rate > mcp_config.ERROR_RATE_WARNING and old_status == "active": + session_info.status = "degraded" + asyncio.create_task(self._emit_status_change( + user_id, plugin_name, old_status, "degraded", f"错误率较高: {session_info.error_rate:.1%}" + )) + + error_msg = str(e) + error_type = type(e).__name__ + + # 检查是否是 JSON 解析错误(MCP SDK 内部错误) + if "parsing JSON" in error_msg.lower() or "json" in error_msg.lower(): + logger.error(f"❌ 工具调用失败 (JSON解析错误): {tool_key}: {e}") + raise MCPError(f"MCP服务器响应格式错误,请检查服务器状态或稍后重试") + + logger.error(f"❌ 工具调用失败: {tool_key} [{error_type}]: {e}") + raise MCPError(f"工具调用失败: {error_msg}") + + raise MCPError("工具调用失败: 未知错误") + + def _extract_tool_result(self, result) -> Any: + """从MCP结果中提取实际内容""" + if result.content: + for content in result.content: + if isinstance(content, types.TextContent): + return content.text + elif isinstance(content, types.ImageContent): + return { + "type": "image", + "data": content.data, + "mimeType": content.mimeType + } + return result.content[0] if result.content else None + + if hasattr(result, 'structuredContent') and result.structuredContent: + return result.structuredContent + + return None + + async def batch_call_tools( + self, + user_id: str, + tool_calls: List[Dict[str, Any]], + max_concurrent: int = 2, + timeout: Optional[float] = None + ) -> List[Dict[str, Any]]: + """ + 批量执行AI返回的工具调用 + + Args: + user_id: 用户ID + tool_calls: AI返回的工具调用列表,格式: + [{"id": "...", "function": {"name": "plugin_tool", "arguments": "{...}"}}] + max_concurrent: 最大并发数 + timeout: 单个工具超时时间 + + Returns: + 工具调用结果列表 + """ + if not tool_calls: + return [] + + logger.info(f"开始执行 {len(tool_calls)} 个工具调用 (最大并发={max_concurrent})") + + results = [] + + for i in range(0, len(tool_calls), max_concurrent): + batch = tool_calls[i:i+max_concurrent] + batch_num = i // max_concurrent + 1 + total_batches = (len(tool_calls) + max_concurrent - 1) // max_concurrent + + logger.info(f"执行工具批次 {batch_num}/{total_batches}, 数量: {len(batch)}") + + tasks = [ + self._execute_single_tool_call(user_id, tc, timeout) + for tc in batch + ] + + batch_results = await asyncio.gather(*tasks, return_exceptions=True) + + for j, result in enumerate(batch_results): + tc = batch[j] + if isinstance(result, Exception): + results.append({ + "tool_call_id": tc.get("id", f"call_{i+j}"), + "role": "tool", + "name": tc["function"]["name"], + "content": f"工具调用失败: {str(result)}", + "success": False, + "error": str(result) + }) + else: + results.append(result) + + # 批次间延迟,避免API限流 + if i + max_concurrent < len(tool_calls): + await asyncio.sleep(0.3) + + return results + + async def _execute_single_tool_call( + self, + user_id: str, + tool_call: Dict[str, Any], + timeout: Optional[float] = None + ) -> Dict[str, Any]: + """执行单个工具调用""" + tool_call_id = tool_call.get("id", "unknown") + function_name = tool_call["function"]["name"] + + try: + # 解析插件名和工具名 + plugin_name, tool_name = self.parse_function_name(function_name) + + # 解析参数 + arguments = tool_call["function"]["arguments"] + if isinstance(arguments, str): + arguments = json.loads(arguments) + + # 调用工具 + result = await self.call_tool( + user_id=user_id, + plugin_name=plugin_name, + tool_name=tool_name, + arguments=arguments, + timeout=timeout + ) + + return { + "tool_call_id": tool_call_id, + "role": "tool", + "name": function_name, + "content": json.dumps(result, ensure_ascii=False) if result else "", + "success": True + } + + except json.JSONDecodeError as e: + return { + "tool_call_id": tool_call_id, + "role": "tool", + "name": function_name, + "content": f"参数JSON解析失败: {str(e)}", + "success": False, + "error": str(e) + } + except Exception as e: + return { + "tool_call_id": tool_call_id, + "role": "tool", + "name": function_name, + "content": f"工具调用失败: {str(e)}", + "success": False, + "error": str(e) + } + + # ==================== 格式转换 ==================== + + def format_tools_for_openai( + self, + tools: List[Dict[str, Any]], + plugin_name: str + ) -> List[Dict[str, Any]]: + """ + 将MCP工具转换为OpenAI Function Calling格式 + + Args: + tools: MCP工具列表 + plugin_name: 插件名称(作为前缀) + + Returns: + OpenAI格式的工具列表 + """ + return [ + { + "type": "function", + "function": { + "name": f"{plugin_name}_{tool['name']}", + "description": tool.get("description", ""), + "parameters": tool.get("inputSchema", { + "type": "object", + "properties": {}, + "required": [] + }) + } + } + for tool in tools + ] + + def parse_function_name(self, function_name: str) -> tuple: + """ + 解析函数名为插件名和工具名 + + 支持两种格式: + - "plugin_tool" (下划线分隔) + - "plugin.tool" (点号分隔) + + Args: + function_name: 工具名称 + + Returns: + (plugin_name, tool_name) + + Raises: + ValueError: 格式无效 + """ + # 优先尝试用下划线分割 + if "_" in function_name: + parts = function_name.split("_", 1) + if len(parts) == 2 and parts[0] and parts[1]: + return (parts[0], parts[1]) + + # 如果下划线分割失败,尝试用点号分割 + if "." in function_name: + parts = function_name.split(".", 1) + if len(parts) == 2 and parts[0] and parts[1]: + logger.debug(f"🔧 工具名使用点号分隔: {function_name} -> plugin={parts[0]}, tool={parts[1]}") + return (parts[0], parts[1]) + + raise ValueError(f"无效的工具名称格式: {function_name},应为 'plugin_tool' 或 'plugin.tool' 格式") + + def build_tool_context( + self, + tool_results: List[Dict[str, Any]], + format: str = "markdown" + ) -> str: + """ + 将工具结果格式化为上下文 + + Args: + tool_results: 工具调用结果列表 + format: 输出格式(markdown/json/plain) + + Returns: + 格式化的上下文字符串 + """ + if not tool_results: + return "" + + if format == "markdown": + return self._build_markdown_context(tool_results) + elif format == "json": + return json.dumps(tool_results, ensure_ascii=False, indent=2) + else: + return self._build_plain_context(tool_results) + + def _build_markdown_context(self, tool_results: List[Dict[str, Any]]) -> str: + """构建Markdown格式的工具上下文""" + lines = ["## 🔧 工具调用结果\n"] + + for i, result in enumerate(tool_results, 1): + tool_name = result.get("name", "unknown") + success = result.get("success", False) + content = result.get("content", "") + + status_emoji = "✅" if success else "❌" + lines.append(f"### {status_emoji} {i}. {tool_name}\n") + + if success: + # 尝试美化JSON内容 + try: + content_obj = json.loads(content) + content = json.dumps(content_obj, ensure_ascii=False, indent=2) + except: + pass + lines.append(f"```json\n{content}\n```\n") + else: + lines.append(f"**错误**: {content}\n") + + return "\n".join(lines) + + def _build_plain_context(self, tool_results: List[Dict[str, Any]]) -> str: + """构建纯文本格式的工具上下文""" + lines = ["=== 工具调用结果 ===\n"] + + for i, result in enumerate(tool_results, 1): + tool_name = result.get("name", "unknown") + success = result.get("success", False) + content = result.get("content", "") + + status = "成功" if success else "失败" + lines.append(f"{i}. {tool_name} - {status}") + lines.append(f" 结果: {content}\n") + + return "\n".join(lines) + + # ==================== 缓存和指标 ==================== + + def _invalidate_cache(self, key: str): + """使缓存失效""" + if key in self._tool_cache: + del self._tool_cache[key] + logger.debug(f"🧹 已清理缓存: {key}") + + def clear_cache( + self, + user_id: Optional[str] = None, + plugin_name: Optional[str] = None + ): + """ + 清理缓存 + + Args: + user_id: 用户ID(可选) + plugin_name: 插件名称(可选) + """ + if user_id and plugin_name: + key = self._get_key(user_id, plugin_name) + self._invalidate_cache(key) + logger.info(f"🧹 已清理缓存: {key}") + elif user_id: + keys = [k for k in self._tool_cache if k.startswith(f"{user_id}:")] + for k in keys: + del self._tool_cache[k] + logger.info(f"🧹 已清理用户缓存: {user_id} ({len(keys)}个)") + else: + count = len(self._tool_cache) + self._tool_cache.clear() + logger.info(f"🧹 已清理所有缓存 ({count}个)") + + def get_metrics(self, tool_name: Optional[str] = None) -> Dict[str, Any]: + """ + 获取调用指标 + + Args: + tool_name: 工具名称(可选) + + Returns: + 指标字典 + """ + if tool_name and tool_name in self._metrics: + m = self._metrics[tool_name] + return { + tool_name: { + "total_calls": m.total_calls, + "success_calls": m.success_calls, + "failed_calls": m.failed_calls, + "success_rate": round(m.success_rate, 3), + "avg_duration_ms": round(m.avg_duration_ms, 2), + "last_call_time": m.last_call_time.isoformat() if m.last_call_time else None + } + } + + return { + k: { + "total_calls": m.total_calls, + "success_calls": m.success_calls, + "failed_calls": m.failed_calls, + "success_rate": round(m.success_rate, 3), + "avg_duration_ms": round(m.avg_duration_ms, 2), + "last_call_time": m.last_call_time.isoformat() if m.last_call_time else None + } + for k, m in self._metrics.items() + } + + def get_cache_stats(self) -> Dict[str, Any]: + """获取缓存统计""" + return { + "total_entries": len(self._tool_cache), + "total_hits": sum(e.hit_count for e in self._tool_cache.values()), + "cache_ttl_minutes": self._cache_ttl.total_seconds() / 60, + "entries": [ + { + "key": k, + "tools_count": len(e.tools), + "hit_count": e.hit_count, + "expire_time": e.expire_time.isoformat() + } + for k, e in self._tool_cache.items() + ] + } + + def get_session_stats(self) -> Dict[str, Any]: + """获取会话统计""" + return { + "total_sessions": len(self._sessions), + "sessions": [ + { + "key": k, + "url": s.url, + "status": s.status, + "request_count": s.request_count, + "error_count": s.error_count, + "error_rate": round(s.error_rate, 3), + "created_at": datetime.fromtimestamp(s.created_at).isoformat(), + "last_access": datetime.fromtimestamp(s.last_access).isoformat() + } + for k, s in self._sessions.items() + ] + } + + # ==================== 状态回调 ==================== + + def register_status_callback(self, callback: StatusCallback): + """注册状态变更回调""" + if callback not in self._status_callbacks: + self._status_callbacks.append(callback) + logger.info(f"✅ 已注册状态变更回调: {callback.__name__ if hasattr(callback, '__name__') else 'anonymous'}") + + def unregister_status_callback(self, callback: StatusCallback): + """注销状态变更回调""" + if callback in self._status_callbacks: + self._status_callbacks.remove(callback) + + async def _emit_status_change( + self, + user_id: str, + plugin_name: str, + old_status: str, + new_status: str, + reason: str = "" + ): + """触发状态变更事件""" + if old_status == new_status: + return + + event = { + "user_id": user_id, + "plugin_name": plugin_name, + "old_status": old_status, + "new_status": new_status, + "reason": reason, + "timestamp": datetime.now().isoformat() + } + + logger.info(f"📢 状态变更: {plugin_name} [{old_status} -> {new_status}] {reason}") + + for callback in self._status_callbacks: + try: + await callback(event) + except Exception as e: + logger.error(f"状态回调执行失败: {e}") + + # ==================== 生命周期 ==================== + + async def cleanup(self): + """清理所有资源""" + # 停止后台任务 + if self._cleanup_task: + self._cleanup_task.cancel() + try: + await self._cleanup_task + except asyncio.CancelledError: + pass + + if self._health_check_task: + self._health_check_task.cancel() + try: + await self._health_check_task + except asyncio.CancelledError: + pass + + # 关闭所有会话 + async with self._session_lock: + keys = list(self._sessions.keys()) + + for key in keys: + await self._close_session_unsafe(key) + + # 清理缓存 + self._tool_cache.clear() + + self._tasks_started = False + logger.info("✅ MCPClientFacade 资源已清理") + + +# ==================== 全局单例 ==================== + +mcp_client = MCPClientFacade() \ No newline at end of file diff --git a/backend/app/mcp/http_client.py b/backend/app/mcp/http_client.py deleted file mode 100644 index 139ff90..0000000 --- a/backend/app/mcp/http_client.py +++ /dev/null @@ -1,385 +0,0 @@ -"""HTTP MCP客户端 - 使用官方 MCP Python SDK 实现""" -import asyncio -from typing import Dict, Any, List, Optional -from contextlib import asynccontextmanager - -from mcp import ClientSession, types -from mcp.client.streamable_http import streamablehttp_client -from pydantic import AnyUrl -from anyio import ClosedResourceError - -from app.logger import get_logger - -logger = get_logger(__name__) - - -class MCPError(Exception): - """MCP错误""" - pass - - -class HTTPMCPClient: - """HTTP模式MCP客户端(基于官方 MCP Python SDK)""" - - def __init__( - self, - url: str, - headers: Optional[Dict[str, str]] = None, - env: Optional[Dict[str, str]] = None, - timeout: float = 60.0 - ): - """ - 初始化HTTP MCP客户端 - - Args: - url: MCP服务器URL - headers: HTTP请求头 - env: 环境变量(用于API Key等) - timeout: 超时时间(秒) - """ - self.url = url.rstrip('/') - self.headers = headers or {} - self.env = env or {} - self.timeout = timeout - - # 如果env中有API Key,添加到headers - if 'API_KEY' in self.env: - self.headers['Authorization'] = f'Bearer {self.env["API_KEY"]}' - - self._session: Optional[ClientSession] = None - self._context_stack = [] # 保存上下文管理器栈 - self._initialized = False - self._lock = asyncio.Lock() - - async def _ensure_connected(self): - """确保连接已建立""" - async with self._lock: - if self._session is None: - try: - logger.info(f"🔗 连接到MCP服务器: {self.url}") - - # 使用官方 SDK 的 streamable_http_client - # 保存上下文管理器以便后续正确清理 - stream_context = streamablehttp_client(self.url) - read_stream, write_stream, _ = await stream_context.__aenter__() - self._context_stack.append(('stream', stream_context)) - - # 创建客户端会话 - self._session = ClientSession(read_stream, write_stream) - session_context = self._session - await session_context.__aenter__() - self._context_stack.append(('session', session_context)) - - # 初始化会话 - await self._session.initialize() - self._initialized = True - - logger.info(f"✅ MCP会话初始化成功") - - except Exception as e: - logger.error(f"❌ MCP连接失败: {e}") - await self._cleanup() - raise MCPError(f"连接MCP服务器失败: {str(e)}") - - async def _cleanup(self): - """清理连接资源(按照进入的相反顺序退出)""" - # 按照LIFO顺序清理上下文 - while self._context_stack: - ctx_type, ctx = self._context_stack.pop() - try: - await ctx.__aexit__(None, None, None) - except RuntimeError as e: - # 忽略 anyio 的任务上下文错误(在关闭时可能发生) - if "cancel scope" in str(e).lower() or "different task" in str(e).lower(): - logger.debug(f"忽略{ctx_type}上下文清理的任务切换警告: {e}") - else: - logger.error(f"清理{ctx_type}上下文失败: {e}") - except Exception as e: - logger.error(f"清理{ctx_type}上下文失败: {e}") - - self._session = None - self._initialized = False - - async def initialize(self) -> Dict[str, Any]: - """ - 初始化MCP会话 - - Returns: - 初始化响应 - """ - await self._ensure_connected() - return {"status": "initialized"} - - async def list_tools(self) -> List[Dict[str, Any]]: - """ - 列举可用工具 - - Returns: - 工具列表 - """ - try: - await self._ensure_connected() - - result = await self._session.list_tools() - - # 转换为字典格式 - tools = [] - for tool in result.tools: - tool_dict = { - "name": tool.name, - "description": tool.description or "", - "inputSchema": tool.inputSchema - } - tools.append(tool_dict) - - logger.info(f"获取到 {len(tools)} 个工具") - return tools - - except Exception as e: - logger.error(f"获取工具列表失败: {e}") - raise MCPError(f"获取工具列表失败: {str(e)}") - - async def call_tool( - self, - tool_name: str, - arguments: Dict[str, Any], - max_reconnect_attempts: int = 2 - ) -> Any: - """ - 调用工具(带自动重连) - - Args: - tool_name: 工具名称 - arguments: 工具参数 - max_reconnect_attempts: 最大重连尝试次数 - - Returns: - 工具执行结果 - """ - for attempt in range(max_reconnect_attempts + 1): - try: - await self._ensure_connected() - - logger.info(f"调用工具: {tool_name}") - logger.debug(f" 参数类型: {type(arguments)}") - logger.debug(f" 参数内容: {arguments}") - logger.debug(f" 会话状态: initialized={self._initialized}, session={self._session is not None}") - - result = await self._session.call_tool(tool_name, arguments) - - logger.debug(f" 工具返回类型: {type(result)}") - logger.debug(f" 返回内容: {result}") - - # 处理返回结果 - # MCP SDK 返回 CallToolResult 对象 - if result.content: - logger.debug(f" 返回content数量: {len(result.content)}") - # 提取第一个content的文本 - for idx, content in enumerate(result.content): - logger.debug(f" content[{idx}]类型: {type(content)}") - if isinstance(content, types.TextContent): - logger.debug(f" ✅ 返回TextContent: {content.text[:100] if len(content.text) > 100 else content.text}") - return content.text - elif isinstance(content, types.ImageContent): - logger.debug(f" ✅ 返回ImageContent") - return { - "type": "image", - "data": content.data, - "mimeType": content.mimeType - } - # 如果没有文本内容,返回原始内容 - logger.debug(f" ⚠️ 返回原始content[0]") - return result.content[0] if result.content else None - - # 如果有结构化内容(2025-06-18规范) - if hasattr(result, 'structuredContent') and result.structuredContent: - logger.debug(f" ✅ 返回structuredContent") - return result.structuredContent - - logger.warning(f" ⚠️ 工具返回为None") - return None - - except ClosedResourceError as e: - # 连接已关闭,尝试重连 - if attempt < max_reconnect_attempts: - logger.warning( - f"⚠️ MCP连接已关闭,尝试重新连接 " - f"(第{attempt + 1}/{max_reconnect_attempts}次重连)" - ) - await self._cleanup() - await asyncio.sleep(0.5) # 短暂延迟后重连 - continue - else: - logger.error(f"❌ MCP连接重连失败,已达最大重试次数") - error_msg = f"连接已关闭且重连失败 (尝试了{max_reconnect_attempts}次)" - raise MCPError(error_msg) - - except Exception as e: - logger.error(f"调用工具失败: {tool_name}, 错误: {e}", exc_info=True) - logger.error(f" 参数: {arguments}") - logger.error(f" 错误类型: {type(e).__name__}") - logger.error(f" 错误详情: {repr(e)}") - logger.error(f" 错误字符串: '{str(e)}'") - error_msg = str(e) or repr(e) or f"未知错误 ({type(e).__name__})" - raise MCPError(f"调用工具失败: {error_msg}") - - # 理论上不会到这里 - raise MCPError(f"工具调用失败: 未知错误") - - async def list_resources(self) -> List[Dict[str, Any]]: - """ - 列举可用资源 - - Returns: - 资源列表 - """ - try: - await self._ensure_connected() - - result = await self._session.list_resources() - - # 转换为字典格式 - resources = [] - for resource in result.resources: - resource_dict = { - "uri": str(resource.uri), - "name": resource.name, - "description": resource.description or "", - "mimeType": resource.mimeType or "" - } - resources.append(resource_dict) - - logger.info(f"获取到 {len(resources)} 个资源") - return resources - - except Exception as e: - logger.error(f"获取资源列表失败: {e}") - raise MCPError(f"获取资源列表失败: {str(e)}") - - async def read_resource(self, uri: str) -> Any: - """ - 读取资源 - - Args: - uri: 资源URI - - Returns: - 资源内容 - """ - try: - await self._ensure_connected() - - result = await self._session.read_resource(AnyUrl(uri)) - - # 提取资源内容 - if result.contents: - content = result.contents[0] - if isinstance(content, types.TextContent): - return content.text - elif isinstance(content, types.ImageContent): - return { - "type": "image", - "data": content.data, - "mimeType": content.mimeType - } - elif isinstance(content, types.BlobResourceContents): - return { - "type": "blob", - "blob": content.blob, - "mimeType": content.mimeType - } - - return None - - except Exception as e: - logger.error(f"读取资源失败: {uri}, 错误: {e}") - raise MCPError(f"读取资源失败: {str(e)}") - - async def test_connection(self) -> Dict[str, Any]: - """ - 测试连接 - - Returns: - 测试结果 - """ - import time - start_time = time.time() - - try: - # 尝试连接并列举工具(直接调用SDK,避免重复日志) - await self._ensure_connected() - - result = await self._session.list_tools() - - # 转换为字典格式 - tools = [] - for tool in result.tools: - tool_dict = { - "name": tool.name, - "description": tool.description or "", - "inputSchema": tool.inputSchema - } - tools.append(tool_dict) - - end_time = time.time() - response_time = round((end_time - start_time) * 1000, 2) - - logger.info(f"✅ 连接测试成功,获取到 {len(tools)} 个工具") - - return { - "success": True, - "message": "连接测试成功", - "response_time_ms": response_time, - "tools_count": len(tools), - "tools": tools - } - - except Exception as e: - end_time = time.time() - response_time = round((end_time - start_time) * 1000, 2) - - return { - "success": False, - "message": "连接测试失败", - "response_time_ms": response_time, - "error": str(e), - "error_type": type(e).__name__, - "suggestions": [ - "请检查服务器URL是否正确", - "请确认API Key是否有效", - "请检查网络连接", - "请确认MCP服务器是否在线" - ] - } - - async def close(self): - """关闭客户端连接""" - logger.info(f"关闭MCP客户端: {self.url}") - await self._cleanup() - - -@asynccontextmanager -async def create_mcp_client( - url: str, - headers: Optional[Dict[str, str]] = None, - env: Optional[Dict[str, str]] = None, - timeout: float = 60.0 -): - """ - 创建MCP客户端的上下文管理器 - - Args: - url: MCP服务器URL - headers: HTTP请求头 - env: 环境变量 - timeout: 超时时间 - - Yields: - HTTPMCPClient实例 - """ - client = HTTPMCPClient(url, headers, env, timeout) - try: - await client.initialize() - yield client - finally: - await client.close() \ No newline at end of file diff --git a/backend/app/mcp/registry.py b/backend/app/mcp/registry.py deleted file mode 100644 index e6ac981..0000000 --- a/backend/app/mcp/registry.py +++ /dev/null @@ -1,527 +0,0 @@ -"""MCP插件注册表 - 管理运行时插件实例""" -import asyncio -import time -from typing import Dict, Optional, Any, List -from dataclasses import dataclass -from datetime import datetime -from app.mcp.http_client import HTTPMCPClient, MCPError -from app.mcp.config import mcp_config -from app.models.mcp_plugin import MCPPlugin -from app.logger import get_logger - -logger = get_logger(__name__) - - -@dataclass -class SessionInfo: - """会话信息""" - client: HTTPMCPClient - created_at: float - last_access: float - request_count: int = 0 - error_count: int = 0 - status: str = "active" # active, degraded, error - - -class MCPPluginRegistry: - """MCP插件注册表 - 管理运行时插件实例(优化版)""" - - def __init__( - self, - max_clients: Optional[int] = None, - client_ttl: Optional[int] = None - ): - """ - 初始化注册表 - - Args: - max_clients: 最大缓存客户端数量(默认使用配置) - client_ttl: 客户端过期时间(秒,默认使用配置) - """ - # 存储格式: {plugin_id: SessionInfo} - self._sessions: Dict[str, SessionInfo] = {} - - # 全局锁用于保护会话字典 - self._sessions_lock = asyncio.Lock() - - # 细粒度锁:每个用户一个锁 - self._user_locks: Dict[str, asyncio.Lock] = {} - self._locks_lock = asyncio.Lock() # 保护locks字典本身 - - # 配置参数(使用配置常量) - self._max_clients = max_clients or mcp_config.MAX_CLIENTS - self._client_ttl = client_ttl or mcp_config.CLIENT_TTL_SECONDS - - # 启动后台清理任务 - self._cleanup_task = None - self._health_check_task = None - self._tasks_started = False - - def _ensure_background_tasks(self): - """确保后台任务已启动(延迟初始化)""" - if not self._tasks_started: - try: - # 检查是否有运行中的事件循环 - loop = asyncio.get_running_loop() - if self._cleanup_task is None: - self._cleanup_task = asyncio.create_task(self._cleanup_loop()) - logger.info("✅ MCP插件注册表后台清理任务已启动") - - if self._health_check_task is None: - self._health_check_task = asyncio.create_task(self._health_check_loop()) - logger.info("✅ MCP会话健康检查任务已启动") - - self._tasks_started = True - except RuntimeError: - # 没有运行中的事件循环,稍后再试 - pass - - async def _cleanup_loop(self): - """后台清理过期客户端""" - while True: - try: - await asyncio.sleep(mcp_config.CLEANUP_INTERVAL_SECONDS) - await self._cleanup_expired_sessions() - except asyncio.CancelledError: - break - except Exception as e: - logger.error(f"清理任务异常: {e}") - - async def _health_check_loop(self): - """后台健康检查""" - while True: - try: - await asyncio.sleep(mcp_config.HEALTH_CHECK_INTERVAL_SECONDS) - await self._check_session_health() - except asyncio.CancelledError: - break - except Exception as e: - logger.error(f"健康检查任务异常: {e}") - - async def _cleanup_expired_sessions(self): - """清理过期的会话""" - now = time.time() - expired_ids = [] - - async with self._sessions_lock: - # 收集过期的plugin_id - for plugin_id, session in list(self._sessions.items()): - if now - session.last_access > self._client_ttl: - expired_ids.append(plugin_id) - - if expired_ids: - logger.info(f"🧹 清理 {len(expired_ids)} 个过期的MCP会话") - for plugin_id in expired_ids: - # 提取user_id来获取对应的锁 - user_id = plugin_id.split(':', 1)[0] - user_lock = await self._get_user_lock(user_id) - - async with user_lock: - async with self._sessions_lock: - if plugin_id in self._sessions: - await self._unload_plugin_unsafe(plugin_id) - - async def _check_session_health(self): - """增强的会话健康检查""" - async with self._sessions_lock: - for plugin_id, session in list(self._sessions.items()): - # 计算错误率 - if session.request_count > mcp_config.MIN_REQUESTS_FOR_HEALTH_CHECK: - error_rate = session.error_count / session.request_count - - # 动态调整状态(使用配置常量) - if error_rate > mcp_config.ERROR_RATE_CRITICAL: - if session.status != "error": - session.status = "error" - logger.error( - f"❌ 会话 {plugin_id} 错误率过高 " - f"({error_rate:.1%}), 标记为error" - ) - elif error_rate > mcp_config.ERROR_RATE_WARNING: - if session.status == "active": - session.status = "degraded" - logger.warning( - f"⚠️ 会话 {plugin_id} 健康状况下降 " - f"(错误率: {error_rate:.1%})" - ) - elif session.status == "degraded": - # 错误率降低,恢复正常 - session.status = "active" - logger.info(f"✅ 会话 {plugin_id} 恢复正常") - - # 检查即将过期的会话(最后1分钟提醒) - idle_time = time.time() - session.last_access - time_until_expiry = self._client_ttl - idle_time - - # 仅在最后1分钟(60秒)内提醒一次 - if 0 < time_until_expiry <= 60: - # 使用会话属性避免重复提醒 - if not hasattr(session, '_expiry_warned') or not session._expiry_warned: - logger.warning( - f"⏰ 会话 {plugin_id} 即将过期 " - f"(剩余 {time_until_expiry:.0f} 秒)" - ) - session._expiry_warned = True - elif time_until_expiry > 60: - # 重置警告标志(如果会话被重新使用) - if hasattr(session, '_expiry_warned'): - session._expiry_warned = False - - async def _get_user_lock(self, user_id: str) -> asyncio.Lock: - """ - 获取用户专属的锁(细粒度锁) - - Args: - user_id: 用户ID - - Returns: - 该用户的锁对象 - """ - async with self._locks_lock: - if user_id not in self._user_locks: - self._user_locks[user_id] = asyncio.Lock() - return self._user_locks[user_id] - - def _touch_session(self, plugin_id: str): - """ - 更新会话的最后访问时间(需要在锁内调用) - - Args: - plugin_id: 插件ID - """ - if plugin_id in self._sessions: - session = self._sessions[plugin_id] - session.last_access = time.time() - session.request_count += 1 - - async def _evict_lru_session(self): - """驱逐最久未使用的会话(当达到max_clients限制时)""" - if len(self._sessions) >= self._max_clients: - # 找到最旧的会话 - oldest_id = None - oldest_time = float('inf') - - for plugin_id, session in self._sessions.items(): - if session.last_access < oldest_time: - oldest_time = session.last_access - oldest_id = plugin_id - - if oldest_id: - logger.info(f"📤 达到最大会话数量限制,驱逐: {oldest_id}") - await self._unload_plugin_unsafe(oldest_id) - - async def load_plugin(self, plugin: MCPPlugin) -> bool: - """ - 从配置加载插件 - - Args: - plugin: 插件配置 - - Returns: - 是否加载成功 - """ - # 确保后台任务已启动 - self._ensure_background_tasks() - - # 使用细粒度锁(只锁定当前用户) - user_lock = await self._get_user_lock(plugin.user_id) - async with user_lock: - try: - plugin_id = f"{plugin.user_id}:{plugin.plugin_name}" - - # 如果已加载,先卸载 - async with self._sessions_lock: - if plugin_id in self._sessions: - await self._unload_plugin_unsafe(plugin_id) - - # 检查是否需要驱逐LRU会话 - await self._evict_lru_session() - - # 目前只支持HTTP类型 - if plugin.plugin_type == "http": - if not plugin.server_url: - logger.error(f"HTTP插件缺少server_url: {plugin.plugin_name}") - return False - - # 为每个插件创建独立的HTTP客户端 - client = HTTPMCPClient( - url=plugin.server_url, - headers=plugin.headers or {}, - env=plugin.env or {}, - timeout=plugin.config.get('timeout', 60.0) if plugin.config else 60.0 - ) - - # 创建会话信息 - now = time.time() - session = SessionInfo( - client=client, - created_at=now, - last_access=now, - request_count=0, - error_count=0, - status="active" - ) - - # 存储会话 - async with self._sessions_lock: - self._sessions[plugin_id] = session - - logger.info(f"✅ 加载MCP插件: {plugin_id} (独立会话)") - return True - else: - logger.warning(f"暂不支持的插件类型: {plugin.plugin_type}") - return False - - except Exception as e: - logger.error(f"加载插件失败 {plugin.plugin_name}: {e}") - return False - - async def unload_plugin(self, user_id: str, plugin_name: str): - """ - 卸载插件 - - Args: - user_id: 用户ID - plugin_name: 插件名称 - """ - # 使用细粒度锁(只锁定当前用户) - user_lock = await self._get_user_lock(user_id) - async with user_lock: - plugin_id = f"{user_id}:{plugin_name}" - async with self._sessions_lock: - await self._unload_plugin_unsafe(plugin_id) - - async def _unload_plugin_unsafe(self, plugin_id: str): - """卸载插件(不加锁,内部使用,需要在sessions_lock内调用)""" - if plugin_id in self._sessions: - session = self._sessions[plugin_id] - try: - await session.client.close() - except Exception as e: - logger.error(f"关闭插件客户端失败 {plugin_id}: {e}") - - del self._sessions[plugin_id] - logger.info(f"卸载MCP插件: {plugin_id}") - - async def reload_plugin(self, plugin: MCPPlugin) -> bool: - """ - 重新加载插件 - - Args: - plugin: 插件配置 - - Returns: - 是否重载成功 - """ - await self.unload_plugin(plugin.user_id, plugin.plugin_name) - return await self.load_plugin(plugin) - - def get_client(self, user_id: str, plugin_name: str) -> Optional[HTTPMCPClient]: - """ - 获取插件客户端(线程安全,支持访问时间更新) - - Args: - user_id: 用户ID - plugin_name: 插件名称 - - Returns: - 客户端实例或None - """ - plugin_id = f"{user_id}:{plugin_name}" - - session = self._sessions.get(plugin_id) - if session: - # 检查会话状态 - if session.status == "error": - logger.warning( - f"⚠️ 会话 {plugin_id} 处于错误状态," - f"建议调用者重新加载插件" - ) - # 不返回错误状态的客户端 - return None - - # ✅ 使用锁保护状态更新,避免并发问题 - # 注意:这里使用原子操作更新简单字段,不需要异步锁 - session.last_access = time.time() - session.request_count += 1 - return session.client - return None - - async def get_or_reconnect_client( - self, - user_id: str, - plugin_name: str, - plugin: MCPPlugin - ) -> HTTPMCPClient: - """ - 获取或重连客户端(自动处理错误状态) - - Args: - user_id: 用户ID - plugin_name: 插件名称 - plugin: 插件配置对象 - - Returns: - 客户端实例 - - Raises: - ValueError: 插件加载失败 - """ - plugin_id = f"{user_id}:{plugin_name}" - - # 获取用户锁 - user_lock = await self._get_user_lock(user_id) - async with user_lock: - session = self._sessions.get(plugin_id) - - # 检查会话健康状态 - if session and session.status == "error": - logger.warning(f"会话 {plugin_id} 处于错误状态,尝试重连") - async with self._sessions_lock: - await self._unload_plugin_unsafe(plugin_id) - session = None - - # 如果没有会话,加载插件 - if not session: - success = await self.load_plugin(plugin) - if not success: - raise ValueError(f"插件加载失败: {plugin_name}") - session = self._sessions[plugin_id] - - return session.client - - async def call_tool( - self, - user_id: str, - plugin_name: str, - tool_name: str, - arguments: Dict[str, Any] - ) -> Any: - """ - 调用插件工具(带错误计数和状态管理) - - Args: - user_id: 用户ID - plugin_name: 插件名称 - tool_name: 工具名称 - arguments: 工具参数 - - Returns: - 工具执行结果 - - Raises: - ValueError: 插件不存在或未启用 - MCPError: 工具调用失败 - """ - plugin_id = f"{user_id}:{plugin_name}" - - # 获取会话 - session = self._sessions.get(plugin_id) - if not session: - raise ValueError(f"插件未加载: {plugin_name}") - - try: - result = await session.client.call_tool(tool_name, arguments) - logger.info(f"✅ 工具调用成功: {plugin_name}.{tool_name}") - - # 调用成功,重置状态(如果之前是degraded) - if session.status == "degraded": - session.status = "active" - logger.info(f"✅ 会话 {plugin_id} 恢复正常") - - return result - except Exception as e: - # 增加错误计数 - session.error_count += 1 - - # 根据错误率更新状态 - if session.request_count > 0: - error_rate = session.error_count / session.request_count - if error_rate > 0.5: - session.status = "error" - elif error_rate > 0.3: - session.status = "degraded" - - logger.error( - f"❌ 工具调用失败: {plugin_name}.{tool_name}, " - f"错误: {e} (错误计数: {session.error_count}/{session.request_count})" - ) - raise - - async def get_plugin_tools( - self, - user_id: str, - plugin_name: str - ) -> List[Dict[str, Any]]: - """ - 获取插件的工具列表 - - Args: - user_id: 用户ID - plugin_name: 插件名称 - - Returns: - 工具列表 - """ - client = self.get_client(user_id, plugin_name) - - if not client: - raise ValueError(f"插件未加载: {plugin_name}") - - try: - tools = await client.list_tools() - return tools - except Exception as e: - logger.error(f"获取工具列表失败: {plugin_name}, 错误: {e}") - raise - - async def test_plugin( - self, - user_id: str, - plugin_name: str - ) -> Dict[str, Any]: - """ - 测试插件连接 - - Args: - user_id: 用户ID - plugin_name: 插件名称 - - Returns: - 测试结果 - """ - client = self.get_client(user_id, plugin_name) - - if not client: - raise ValueError(f"插件未加载: {plugin_name}") - - return await client.test_connection() - - async def cleanup_all(self): - """清理所有插件和资源""" - # 停止后台任务 - if self._cleanup_task: - self._cleanup_task.cancel() - try: - await self._cleanup_task - except asyncio.CancelledError: - pass - - if self._health_check_task: - self._health_check_task.cancel() - try: - await self._health_check_task - except asyncio.CancelledError: - pass - - # 清理所有会话 - async with self._sessions_lock: - plugin_ids = list(self._sessions.keys()) - for plugin_id in plugin_ids: - await self._unload_plugin_unsafe(plugin_id) - - logger.info("✅ 已清理所有MCP插件和资源") - - -# 全局注册表实例 -mcp_registry = MCPPluginRegistry() \ No newline at end of file diff --git a/backend/app/mcp/status_sync.py b/backend/app/mcp/status_sync.py new file mode 100644 index 0000000..949d52f --- /dev/null +++ b/backend/app/mcp/status_sync.py @@ -0,0 +1,50 @@ +"""MCP插件状态同步服务 + +将内存中的会话状态变更同步到数据库,确保状态一致性。 +""" + +from typing import Dict, Any +from sqlalchemy import update +from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker + +from app.models.mcp_plugin import MCPPlugin +from app.logger import get_logger + +logger = get_logger(__name__) + + +async def sync_status_to_db(event: Dict[str, Any]): + """ + 状态变更回调 - 同步到数据库 + """ + user_id = event["user_id"] + plugin_name = event["plugin_name"] + new_status = event["new_status"] + reason = event.get("reason", "") + + try: + from app.database import get_engine + + engine = await get_engine(user_id) + AsyncSessionLocal = async_sessionmaker(engine, class_=AsyncSession, expire_on_commit=False) + + async with AsyncSessionLocal() as db: + stmt = ( + update(MCPPlugin) + .where(MCPPlugin.user_id == user_id, MCPPlugin.plugin_name == plugin_name) + .values(status=new_status, last_error=reason if new_status == "error" else None) + ) + await db.execute(stmt) + await db.commit() + + logger.debug(f"✅ 状态已同步到数据库: {plugin_name} -> {new_status}") + + except Exception as e: + logger.error(f"❌ 状态同步失败: {plugin_name}, 错误: {e}") + + +def register_status_sync(): + """注册状态同步回调到MCP客户端""" + from app.mcp import mcp_client + mcp_client.register_status_callback(sync_status_to_db) + logger.info("✅ MCP状态同步服务已注册") \ No newline at end of file diff --git a/backend/app/schemas/career.py b/backend/app/schemas/career.py index 3603c6c..e95aeae 100644 --- a/backend/app/schemas/career.py +++ b/backend/app/schemas/career.py @@ -1,5 +1,5 @@ """职业相关的Pydantic模型""" -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, ConfigDict from typing import Optional, List, Dict, Any from datetime import datetime @@ -63,8 +63,7 @@ class CareerResponse(BaseModel): created_at: datetime updated_at: datetime - class Config: - from_attributes = True + model_config = ConfigDict(from_attributes=True) class CareerListResponse(BaseModel): diff --git a/backend/app/schemas/chapter.py b/backend/app/schemas/chapter.py index fe72893..4819bf4 100644 --- a/backend/app/schemas/chapter.py +++ b/backend/app/schemas/chapter.py @@ -1,5 +1,5 @@ """章节相关的Pydantic模型""" -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, ConfigDict from typing import Optional, List, Dict, Any from datetime import datetime @@ -58,8 +58,7 @@ class ChapterResponse(BaseModel): created_at: datetime updated_at: datetime - class Config: - from_attributes = True + model_config = ConfigDict(from_attributes=True) class ChapterListResponse(BaseModel): @@ -142,8 +141,7 @@ class ExpansionPlanUpdate(BaseModel): estimated_words: Optional[int] = Field(None, description="预估字数", ge=500, le=10000) scenes: Optional[List[SceneData]] = Field(None, description="场景列表") - class Config: - json_schema_extra = { + model_config = ConfigDict(json_schema_extra={ "example": { "key_events": ["主角遇到挑战", "关键决策时刻"], "character_focus": ["张三", "李四"], @@ -159,7 +157,7 @@ class ExpansionPlanUpdate(BaseModel): } ] } - } + }) class ExpansionPlanResponse(BaseModel): diff --git a/backend/app/schemas/character.py b/backend/app/schemas/character.py index cb5df43..0e80a5c 100644 --- a/backend/app/schemas/character.py +++ b/backend/app/schemas/character.py @@ -1,5 +1,5 @@ """角色相关的Pydantic模型""" -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, ConfigDict from typing import Optional, List, Dict, Any from datetime import datetime @@ -98,8 +98,7 @@ class CharacterResponse(CharacterBase): main_career_stage: Optional[int] = Field(None, description="主职业阶段") sub_careers: Optional[List[Dict[str, Any]]] = Field(None, description="副职业列表") - class Config: - from_attributes = True + model_config = ConfigDict(from_attributes=True) class CharacterGenerateRequest(BaseModel): diff --git a/backend/app/schemas/mcp_plugin.py b/backend/app/schemas/mcp_plugin.py index 4c22bdb..6bdacae 100644 --- a/backend/app/schemas/mcp_plugin.py +++ b/backend/app/schemas/mcp_plugin.py @@ -1,5 +1,5 @@ """MCP插件Pydantic模式""" -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, ConfigDict from typing import Optional, Dict, Any, List from datetime import datetime @@ -82,8 +82,7 @@ class MCPPluginResponse(BaseModel): # 时间戳 created_at: datetime - class Config: - from_attributes = True + model_config = ConfigDict(from_attributes=True) class MCPToolCall(BaseModel): diff --git a/backend/app/schemas/outline.py b/backend/app/schemas/outline.py index 45cbb5c..a31cbbe 100644 --- a/backend/app/schemas/outline.py +++ b/backend/app/schemas/outline.py @@ -1,5 +1,5 @@ """大纲相关的Pydantic模型""" -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, ConfigDict from typing import Optional, List, Dict, Any from datetime import datetime @@ -103,8 +103,7 @@ class OutlineResponse(BaseModel): created_at: datetime updated_at: datetime - class Config: - from_attributes = True + model_config = ConfigDict(from_attributes=True) class OutlineGenerateRequest(BaseModel): diff --git a/backend/app/schemas/project.py b/backend/app/schemas/project.py index 78806e5..bd3a1c9 100644 --- a/backend/app/schemas/project.py +++ b/backend/app/schemas/project.py @@ -1,5 +1,5 @@ """项目相关的Pydantic模型""" -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, ConfigDict from typing import Optional, Literal from datetime import datetime @@ -59,8 +59,7 @@ class ProjectResponse(ProjectBase): created_at: datetime updated_at: datetime - class Config: - from_attributes = True + model_config = ConfigDict(from_attributes=True) class ProjectListResponse(BaseModel): diff --git a/backend/app/schemas/relationship.py b/backend/app/schemas/relationship.py index da8f5d3..bbff643 100644 --- a/backend/app/schemas/relationship.py +++ b/backend/app/schemas/relationship.py @@ -1,5 +1,5 @@ """关系管理相关的Pydantic模型""" -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, ConfigDict from typing import Optional, List from datetime import datetime @@ -17,8 +17,7 @@ class RelationshipTypeResponse(BaseModel): description: Optional[str] = None created_at: datetime - class Config: - from_attributes = True + model_config = ConfigDict(from_attributes=True) # ============ 角色关系相关 ============ @@ -62,8 +61,7 @@ class CharacterRelationshipResponse(CharacterRelationshipBase): created_at: datetime updated_at: datetime - class Config: - from_attributes = True + model_config = ConfigDict(from_attributes=True) class RelationshipGraphNode(BaseModel): @@ -127,8 +125,7 @@ class OrganizationResponse(OrganizationBase): created_at: datetime updated_at: datetime - class Config: - from_attributes = True + model_config = ConfigDict(from_attributes=True) class OrganizationDetailResponse(BaseModel): @@ -185,8 +182,7 @@ class OrganizationMemberResponse(OrganizationMemberBase): created_at: datetime updated_at: datetime - class Config: - from_attributes = True + model_config = ConfigDict(from_attributes=True) class OrganizationMemberDetailResponse(BaseModel): diff --git a/backend/app/schemas/writing_style.py b/backend/app/schemas/writing_style.py index 1311375..7e0caa1 100644 --- a/backend/app/schemas/writing_style.py +++ b/backend/app/schemas/writing_style.py @@ -1,5 +1,5 @@ """写作风格 Schema""" -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, ConfigDict from typing import Optional from datetime import datetime @@ -48,8 +48,7 @@ class WritingStyleResponse(BaseModel): created_at: datetime updated_at: datetime - class Config: - from_attributes = True + model_config = ConfigDict(from_attributes=True) class WritingStyleListResponse(BaseModel): diff --git a/backend/app/services/ai_clients/anthropic_client.py b/backend/app/services/ai_clients/anthropic_client.py index 396f234..7bde5ff 100644 --- a/backend/app/services/ai_clients/anthropic_client.py +++ b/backend/app/services/ai_clients/anthropic_client.py @@ -71,7 +71,18 @@ class AnthropicClient: temperature: float, max_tokens: int, system_prompt: Optional[str] = None, - ) -> AsyncGenerator[str, None]: + tools: Optional[list] = None, + tool_choice: Optional[str] = None, + ) -> AsyncGenerator[Dict[str, Any], None]: + """ + 流式生成,支持工具调用 + + Yields: + Dict with keys: + - content: str - 文本内容块 + - tool_calls: list - 工具调用列表(如果有) + - done: bool - 是否结束 + """ kwargs = { "model": model, "max_tokens": max_tokens, @@ -80,12 +91,42 @@ class AnthropicClient: } if system_prompt: kwargs["system"] = system_prompt + if tools: + kwargs["tools"] = tools + if tool_choice == "required": + kwargs["tool_choice"] = {"type": "any"} + elif tool_choice == "auto": + kwargs["tool_choice"] = {"type": "auto"} try: async with self.client.messages.stream(**kwargs) as stream: try: - async for text in stream.text_stream: - yield text + tool_calls = [] + async for chunk in stream: + # 处理不同类型的块 + if chunk.type == "text_delta": + yield {"content": chunk.text} + elif chunk.type == "tool_use_delta": + # 工具调用增量 + if not tool_calls or tool_calls[-1].get("id") != chunk.id: + tool_calls.append({ + "id": chunk.id, + "type": "function", + "function": { + "name": chunk.name, + "arguments": "" + } + }) + # 追加参数 + if tool_calls[-1]["function"]["arguments"] is None: + tool_calls[-1]["function"]["arguments"] = "" + tool_calls[-1]["function"]["arguments"] += chunk.input_gets_new_text or "" + elif chunk.type == "message_delta": + if chunk.stop_reason: + # 流结束 + if tool_calls: + yield {"tool_calls": tool_calls} + yield {"done": True, "finish_reason": chunk.stop_reason} except GeneratorExit: # 生成器被关闭,这是正常的清理过程 logger.debug("Anthropic 流式响应生成器被关闭(GeneratorExit)") diff --git a/backend/app/services/ai_clients/gemini_client.py b/backend/app/services/ai_clients/gemini_client.py index 1354767..8497652 100644 --- a/backend/app/services/ai_clients/gemini_client.py +++ b/backend/app/services/ai_clients/gemini_client.py @@ -111,7 +111,18 @@ class GeminiClient: temperature: float, max_tokens: int, system_prompt: Optional[str] = None, - ) -> AsyncGenerator[str, None]: + tools: Optional[list] = None, + tool_choice: Optional[str] = None, + ) -> AsyncGenerator[Dict[str, Any], None]: + """ + 流式生成,支持工具调用 + + Yields: + Dict with keys: + - content: str - 文本内容块 + - tool_calls: list - 工具调用列表(如果有) + - done: bool - 是否结束 + """ url = f"{self.base_url}/models/{model}:streamGenerateContent?key={self.api_key}&alt=sse" contents = [] @@ -125,6 +136,8 @@ class GeminiClient: } if system_prompt: payload["systemInstruction"] = {"parts": [{"text": system_prompt}]} + if tools: + payload["tools"] = self._convert_tools_to_gemini(tools) try: async with self.client.stream("POST", url, json=payload) as response: @@ -139,9 +152,26 @@ class GeminiClient: if candidates and len(candidates) > 0: parts = candidates[0].get("content", {}).get("parts", []) if parts and len(parts) > 0: - text = parts[0].get("text", "") + text = "" + function_calls = [] + for part in parts: + if "text" in part: + text += part["text"] + elif "functionCall" in part: + fc = part["functionCall"] + function_calls.append({ + "id": f"call_{fc['name']}", + "type": "function", + "function": { + "name": fc["name"], + "arguments": fc.get("args", {}) + } + }) + if text: - yield text + yield {"content": text} + if function_calls: + yield {"tool_calls": function_calls} except json.JSONDecodeError: continue except GeneratorExit: diff --git a/backend/app/services/ai_clients/openai_client.py b/backend/app/services/ai_clients/openai_client.py index b1be00d..6b6c690 100644 --- a/backend/app/services/ai_clients/openai_client.py +++ b/backend/app/services/ai_clients/openai_client.py @@ -86,8 +86,21 @@ class OpenAIClient(BaseAIClient): model: str, temperature: float, max_tokens: int, - ) -> AsyncGenerator[str, None]: - payload = self._build_payload(messages, model, temperature, max_tokens, stream=True) + tools: Optional[list] = None, + tool_choice: Optional[str] = None, + ) -> AsyncGenerator[Dict[str, Any], None]: + """ + 流式生成,支持工具调用 + + Yields: + Dict with keys: + - content: str - 文本内容块 + - tool_calls: list - 工具调用列表(如果有) + - done: bool - 是否结束 + """ + payload = self._build_payload(messages, model, temperature, max_tokens, tools, tool_choice, stream=True) + + tool_calls_buffer = {} # 收集工具调用块 try: async with await self._request_with_retry("POST", "/chat/completions", payload, stream=True) as response: @@ -97,14 +110,38 @@ class OpenAIClient(BaseAIClient): if line.startswith("data: "): data_str = line[6:] if data_str.strip() == "[DONE]": + # 流结束,检查是否有工具调用需要处理 + if tool_calls_buffer: + yield {"tool_calls": list(tool_calls_buffer.values()), "done": True} + yield {"done": True} break try: data = json.loads(data_str) choices = data.get("choices", []) if choices and len(choices) > 0: - content = choices[0].get("delta", {}).get("content", "") + delta = choices[0].get("delta", {}) + content = delta.get("content", "") + + # 检查工具调用 + tc_list = delta.get("tool_calls") + if tc_list: + for tc in tc_list: + index = tc.get("index", 0) + if index not in tool_calls_buffer: + tool_calls_buffer[index] = tc + else: + existing = tool_calls_buffer[index] + # 合并 function.arguments + if "function" in tc and "function" in existing: + if tc["function"].get("arguments"): + existing["function"]["arguments"] = ( + existing["function"].get("arguments", "") + + tc["function"]["arguments"] + ) + if content: - yield content + yield {"content": content} + except json.JSONDecodeError: continue except GeneratorExit: diff --git a/backend/app/services/ai_providers/anthropic_provider.py b/backend/app/services/ai_providers/anthropic_provider.py index bff9773..ce11207 100644 --- a/backend/app/services/ai_providers/anthropic_provider.py +++ b/backend/app/services/ai_providers/anthropic_provider.py @@ -1,9 +1,12 @@ """Anthropic Provider""" from typing import Any, AsyncGenerator, Dict, List, Optional +from app.logger import get_logger from app.services.ai_clients.anthropic_client import AnthropicClient from .base_provider import BaseAIProvider +logger = get_logger(__name__) + class AnthropicProvider(BaseAIProvider): """Anthropic 提供商""" @@ -39,7 +42,62 @@ class AnthropicProvider(BaseAIProvider): temperature: float, max_tokens: int, system_prompt: Optional[str] = None, + tools: Optional[List[Dict]] = None, + tool_choice: Optional[str] = None, + user_id: Optional[str] = None, ) -> AsyncGenerator[str, None]: + # 如果有工具,使用真正的流式工具调用 + if tools: + logger.debug(f"🔧 AnthropicProvider: 有 {len(tools)} 个工具,使用流式处理") + messages = [{"role": "user", "content": prompt}] + actual_tool_choice = tool_choice if tool_choice else "auto" + + tool_calls_buffer = [] + + async for chunk in self.client.chat_completion_stream( + messages=messages, + model=model, + temperature=temperature, + max_tokens=max_tokens, + system_prompt=system_prompt, + tools=tools, + tool_choice=actual_tool_choice, + ): + # 检查是否有工具调用 + if chunk.get("tool_calls"): + tool_calls_buffer.extend(chunk["tool_calls"]) + logger.debug(f"🔧 收到工具调用: {len(chunk['tool_calls'])} 个") + + # 检查是否结束 + if chunk.get("done"): + if tool_calls_buffer: + logger.info(f"🔧 流式结束,处理 {len(tool_calls_buffer)} 个工具调用") + from app.mcp import mcp_client + actual_user_id = user_id or "" + tool_results = await mcp_client.batch_call_tools( + user_id=actual_user_id, + tool_calls=tool_calls_buffer + ) + # 将工具结果注入到上下文中 + tool_context = mcp_client.build_tool_context(tool_results, format="markdown") + + # 构建最终提示词,要求AI基于工具结果回答 + final_prompt = f"{prompt}\n\n{tool_context}\n\n请基于以上工具查询结果,给出完整详细的回答。" + final_messages = [{"role": "user", "content": final_prompt}] + + # 递归调用生成最终结果 + async for final_chunk in self._generate_with_tools( + final_messages, model, temperature, max_tokens, system_prompt, tools, user_id + ): + yield final_chunk + break + + # 输出文本内容 + if chunk.get("content"): + yield chunk["content"] + return + + # 无工具时普通流式生成 messages = [{"role": "user", "content": prompt}] async for chunk in self.client.chat_completion_stream( messages=messages, @@ -48,4 +106,56 @@ class AnthropicProvider(BaseAIProvider): max_tokens=max_tokens, system_prompt=system_prompt, ): - yield chunk \ No newline at end of file + # 确保只 yield 字符串内容,避免 yield 字典导致类型错误 + if isinstance(chunk, dict): + if chunk.get("content"): + yield chunk["content"] + else: + yield chunk + + async def _generate_with_tools( + self, + messages: list, + model: str, + temperature: float, + max_tokens: int, + system_prompt: Optional[str] = None, + tools: list = None, + user_id: Optional[str] = None, + ) -> AsyncGenerator[str, None]: + """辅助方法:带工具的流式生成""" + tool_calls_buffer = [] + + async for chunk in self.client.chat_completion_stream( + messages=messages, + model=model, + temperature=temperature, + max_tokens=max_tokens, + system_prompt=system_prompt, + tools=tools, + tool_choice="auto", + ): + if chunk.get("tool_calls"): + tool_calls_buffer.extend(chunk["tool_calls"]) + logger.debug(f"🔧 _generate_with_tools 收到工具调用: {len(chunk['tool_calls'])} 个") + + if chunk.get("done"): + if tool_calls_buffer: + from app.mcp import mcp_client + actual_user_id = user_id or "" + tool_results = await mcp_client.batch_call_tools( + user_id=actual_user_id, + tool_calls=tool_calls_buffer + ) + tool_context = mcp_client.build_tool_context(tool_results, format="markdown") + + messages.append({"role": "user", "content": f"{tool_context}\n\n请基于以上工具查询结果,给出完整详细的回答。"}) + + async for final_chunk in self._generate_with_tools( + messages, model, temperature, max_tokens, system_prompt, tools, user_id + ): + yield final_chunk + break + + if chunk.get("content"): + yield chunk["content"] \ No newline at end of file diff --git a/backend/app/services/ai_providers/base_provider.py b/backend/app/services/ai_providers/base_provider.py index e9c1934..3b883e1 100644 --- a/backend/app/services/ai_providers/base_provider.py +++ b/backend/app/services/ai_providers/base_provider.py @@ -28,6 +28,9 @@ class BaseAIProvider(ABC): temperature: float, max_tokens: int, system_prompt: Optional[str] = None, + tools: Optional[List[Dict]] = None, + tool_choice: Optional[str] = None, + user_id: Optional[str] = None, ) -> AsyncGenerator[str, None]: """流式生成""" pass \ No newline at end of file diff --git a/backend/app/services/ai_providers/gemini_provider.py b/backend/app/services/ai_providers/gemini_provider.py index 5b16cd9..6efc04d 100644 --- a/backend/app/services/ai_providers/gemini_provider.py +++ b/backend/app/services/ai_providers/gemini_provider.py @@ -1,8 +1,12 @@ """Gemini Provider""" from typing import Any, AsyncGenerator, Dict, List, Optional + +from app.logger import get_logger from app.services.ai_clients.gemini_client import GeminiClient from .base_provider import BaseAIProvider +logger = get_logger(__name__) + class GeminiProvider(BaseAIProvider): def __init__(self, client: GeminiClient): @@ -36,7 +40,62 @@ class GeminiProvider(BaseAIProvider): temperature: float, max_tokens: int, system_prompt: Optional[str] = None, + tools: Optional[List[Dict]] = None, + tool_choice: Optional[str] = None, + user_id: Optional[str] = None, ) -> AsyncGenerator[str, None]: + # 如果有工具,使用真正的流式工具调用 + if tools: + logger.debug(f"🔧 GeminiProvider: 有 {len(tools)} 个工具,使用流式处理") + messages = [{"role": "user", "content": prompt}] + actual_tool_choice = tool_choice if tool_choice else "auto" + + tool_calls_buffer = [] + + async for chunk in self.client.chat_completion_stream( + messages=messages, + model=model, + temperature=temperature, + max_tokens=max_tokens, + system_prompt=system_prompt, + tools=tools, + tool_choice=actual_tool_choice, + ): + # 检查是否有工具调用 + if chunk.get("tool_calls"): + tool_calls_buffer.extend(chunk["tool_calls"]) + logger.debug(f"🔧 收到工具调用: {len(chunk['tool_calls'])} 个") + + # 检查是否结束 + if chunk.get("done"): + if tool_calls_buffer: + logger.info(f"🔧 流式结束,处理 {len(tool_calls_buffer)} 个工具调用") + from app.mcp import mcp_client + actual_user_id = user_id or "" + tool_results = await mcp_client.batch_call_tools( + user_id=actual_user_id, + tool_calls=tool_calls_buffer + ) + # 将工具结果注入到上下文中 + tool_context = mcp_client.build_tool_context(tool_results, format="markdown") + + # 构建最终提示词,要求AI基于工具结果回答 + final_prompt = f"{prompt}\n\n{tool_context}\n\n请基于以上工具查询结果,给出完整详细的回答。" + final_messages = [{"role": "user", "content": final_prompt}] + + # 递归调用生成最终结果 + async for final_chunk in self._generate_with_tools( + final_messages, model, temperature, max_tokens, system_prompt, tools, user_id + ): + yield final_chunk + break + + # 输出文本内容 + if chunk.get("content"): + yield chunk["content"] + return + + # 无工具时普通流式生成 messages = [{"role": "user", "content": prompt}] async for chunk in self.client.chat_completion_stream( messages=messages, @@ -45,4 +104,56 @@ class GeminiProvider(BaseAIProvider): max_tokens=max_tokens, system_prompt=system_prompt, ): - yield chunk \ No newline at end of file + # 确保只 yield 字符串内容,避免 yield 字典导致类型错误 + if isinstance(chunk, dict): + if chunk.get("content"): + yield chunk["content"] + else: + yield chunk + + async def _generate_with_tools( + self, + messages: list, + model: str, + temperature: float, + max_tokens: int, + system_prompt: Optional[str] = None, + tools: list = None, + user_id: Optional[str] = None, + ) -> AsyncGenerator[str, None]: + """辅助方法:带工具的流式生成""" + tool_calls_buffer = [] + + async for chunk in self.client.chat_completion_stream( + messages=messages, + model=model, + temperature=temperature, + max_tokens=max_tokens, + system_prompt=system_prompt, + tools=tools, + tool_choice="auto", + ): + if chunk.get("tool_calls"): + tool_calls_buffer.extend(chunk["tool_calls"]) + logger.debug(f"🔧 _generate_with_tools 收到工具调用: {len(chunk['tool_calls'])} 个") + + if chunk.get("done"): + if tool_calls_buffer: + from app.mcp import mcp_client + actual_user_id = user_id or "" + tool_results = await mcp_client.batch_call_tools( + user_id=actual_user_id, + tool_calls=tool_calls_buffer + ) + tool_context = mcp_client.build_tool_context(tool_results, format="markdown") + + messages.append({"role": "user", "content": f"{tool_context}\n\n请基于以上工具查询结果,给出完整详细的回答。"}) + + async for final_chunk in self._generate_with_tools( + messages, model, temperature, max_tokens, system_prompt, tools, user_id + ): + yield final_chunk + break + + if chunk.get("content"): + yield chunk["content"] \ No newline at end of file diff --git a/backend/app/services/ai_providers/openai_provider.py b/backend/app/services/ai_providers/openai_provider.py index c9db53b..ddf5d90 100644 --- a/backend/app/services/ai_providers/openai_provider.py +++ b/backend/app/services/ai_providers/openai_provider.py @@ -1,9 +1,12 @@ """OpenAI Provider""" from typing import Any, AsyncGenerator, Dict, List, Optional +from app.logger import get_logger from app.services.ai_clients.openai_client import OpenAIClient from .base_provider import BaseAIProvider +logger = get_logger(__name__) + class OpenAIProvider(BaseAIProvider): """OpenAI 提供商""" @@ -42,16 +45,117 @@ class OpenAIProvider(BaseAIProvider): temperature: float, max_tokens: int, system_prompt: Optional[str] = None, + tools: Optional[List[Dict]] = None, + tool_choice: Optional[str] = None, + user_id: Optional[str] = None, ) -> AsyncGenerator[str, None]: messages = [] if system_prompt: messages.append({"role": "system", "content": system_prompt}) messages.append({"role": "user", "content": prompt}) + # 如果有工具,使用真正的流式工具调用 + if tools: + logger.debug(f"🔧 OpenAIProvider: 有 {len(tools)} 个工具,使用流式处理") + actual_tool_choice = tool_choice if tool_choice else "auto" + + tool_calls_buffer = [] + + async for chunk in self.client.chat_completion_stream( + messages=messages, + model=model, + temperature=temperature, + max_tokens=max_tokens, + tools=tools, + tool_choice=actual_tool_choice, + ): + # 检查是否有工具调用 + if chunk.get("tool_calls"): + tool_calls_buffer.extend(chunk["tool_calls"]) + logger.debug(f"🔧 收到工具调用: {len(chunk['tool_calls'])} 个") + + # 检查是否结束 + if chunk.get("done"): + if tool_calls_buffer: + logger.info(f"🔧 流式结束,处理 {len(tool_calls_buffer)} 个工具调用") + from app.mcp import mcp_client + actual_user_id = user_id or "" + tool_results = await mcp_client.batch_call_tools( + user_id=actual_user_id, + tool_calls=tool_calls_buffer + ) + # 将工具结果注入到上下文中 + tool_context = mcp_client.build_tool_context(tool_results, format="markdown") + + # 构建最终提示词,要求AI基于工具结果回答 + final_prompt = f"{prompt}\n\n{tool_context}\n\n请基于以上工具查询结果,给出完整详细的回答。" + final_messages = messages.copy() + final_messages.append({"role": "user", "content": final_prompt}) + + # 递归调用生成最终结果 + async for final_chunk in self._generate_with_tools( + final_messages, model, temperature, max_tokens, tools, user_id + ): + yield final_chunk + break + + # 输出文本内容 + if chunk.get("content"): + yield chunk["content"] + return + + # 无工具时普通流式生成 async for chunk in self.client.chat_completion_stream( messages=messages, model=model, temperature=temperature, max_tokens=max_tokens, ): - yield chunk \ No newline at end of file + # 确保只 yield 字符串内容,避免 yield 字典导致类型错误 + if isinstance(chunk, dict): + if chunk.get("content"): + yield chunk["content"] + else: + yield chunk + + async def _generate_with_tools( + self, + messages: list, + model: str, + temperature: float, + max_tokens: int, + tools: list, + user_id: Optional[str] = None, + ) -> AsyncGenerator[str, None]: + """辅助方法:带工具的流式生成(无tool_choice,AI自由决定)""" + async for chunk in self.client.chat_completion_stream( + messages=messages, + model=model, + temperature=temperature, + max_tokens=max_tokens, + tools=tools, + tool_choice="auto", + ): + if chunk.get("tool_calls"): + from app.mcp import mcp_client + actual_user_id = user_id or "" + tool_results = await mcp_client.batch_call_tools( + user_id=actual_user_id, + tool_calls=chunk["tool_calls"] + ) + tool_context = mcp_client.build_tool_context(tool_results, format="markdown") + + # 再次调用获取最终回答 + messages.append({"role": "user", "content": f"{tool_context}\n\n请基于以上工具查询结果,给出完整详细的回答。"}) + + async for final_chunk in self._generate_with_tools( + messages, model, temperature, max_tokens, tools, user_id + ): + yield final_chunk + break + + if chunk.get("done"): + break + + if chunk.get("content"): + yield chunk["content"] \ No newline at end of file diff --git a/backend/app/services/ai_service.py b/backend/app/services/ai_service.py index 063c073..bd31b86 100644 --- a/backend/app/services/ai_service.py +++ b/backend/app/services/ai_service.py @@ -1,4 +1,10 @@ -"""AI服务封装 - 统一的AI接口""" +"""AI服务封装 - 统一的AI接口 + +重构后支持自动MCP工具加载: +- 所有AI方法在请求前自动检查用户MCP配置 +- 如果有启用的MCP插件且有可用工具,自动发送tools +- 通过 auto_mcp 参数控制是否启用自动工具加载 +""" from typing import Optional, AsyncGenerator, List, Dict, Any, Union from app.config import settings as app_settings @@ -13,7 +19,6 @@ from app.services.ai_providers.anthropic_provider import AnthropicProvider from app.services.ai_providers.gemini_provider import GeminiProvider from app.services.ai_providers.base_provider import BaseAIProvider from app.services.json_helper import clean_json_response, parse_json -from app.mcp.adapters.universal import universal_mcp_adapter # 导出清理函数 cleanup_http_clients = cleanup_all_clients @@ -22,7 +27,41 @@ logger = get_logger(__name__) class AIService: - """AI服务统一接口""" + """ + AI服务统一接口 + + MCP工具支持: + - 在创建服务时传入 user_id 和 db_session + - 根据用户MCP插件的enabled状态自动决定是否启用MCP + - 如果有任意一个MCP插件启用,则加载并使用工具 + - 如果所有插件都关闭,则不使用任何MCP工具 + - 通过 auto_mcp=False 可临时禁用自动工具加载 + - 通过 mcp_max_rounds 控制工具调用轮数 + - 通过 clear_mcp_cache() 可清理MCP工具缓存 + + MCP启用逻辑(backend/app/api/settings.py 中的 get_user_ai_service): + - 查询用户的所有MCP插件 + - 如果有启用的插件 (enabled=True),则 enable_mcp=True + - 如果所有插件都关闭或没有插件,则 enable_mcp=False + + 使用示例: + # 创建支持MCP的AI服务(根据插件状态自动决定是否启用) + ai_service = create_user_ai_service_with_mcp( + api_provider="openai", + api_key="...", + user_id="user123", + db_session=db + ) + + # 自动加载MCP工具(如果有启用的插件) + result = await ai_service.generate_text(prompt="...") + + # 临时禁用MCP工具 + result = await ai_service.generate_text(prompt="...", auto_mcp=False) + + # 自定义轮数 + result = await ai_service.generate_text(prompt="...", mcp_max_rounds=3) + """ def __init__( self, @@ -33,8 +72,11 @@ class AIService: default_temperature: Optional[float] = None, default_max_tokens: Optional[int] = None, default_system_prompt: Optional[str] = None, - enable_mcp_adapter: bool = True, config: Optional[AIClientConfig] = None, + # MCP支持参数 + user_id: Optional[str] = None, + db_session: Optional[Any] = None, + enable_mcp: bool = True, ): self.api_provider = api_provider or app_settings.default_ai_provider self.default_model = default_model or app_settings.default_model @@ -43,7 +85,12 @@ class AIService: self.default_system_prompt = default_system_prompt self.config = config or default_config - self.mcp_adapter = universal_mcp_adapter if enable_mcp_adapter else None + # MCP配置 + self.user_id = user_id + self.db_session = db_session + self._enable_mcp = enable_mcp + self._cached_tools: Optional[List[Dict]] = None + self._tools_loaded = False self._openai_provider: Optional[OpenAIProvider] = None self._anthropic_provider: Optional[AnthropicProvider] = None @@ -68,6 +115,36 @@ class AIService: client = GeminiClient(api_key, api_base_url, self.config) self._gemini_provider = GeminiProvider(client) + @property + def enable_mcp(self) -> bool: + """是否启用MCP工具""" + return self._enable_mcp + + @enable_mcp.setter + def enable_mcp(self, value: bool): + """设置MCP启用状态,如果禁用则清理缓存""" + if value is False and self._enable_mcp is True: + # 从启用变为禁用,清理缓存 + self.clear_mcp_cache() + self._enable_mcp = value + + def clear_mcp_cache(self): + """ + 清理MCP工具缓存 + + 当禁用MCP时调用此方法,确保后续AI调用不会使用缓存的工具。 + 同时更新 _tools_loaded 状态,使下次调用时重新检查。 + """ + if self._cached_tools is not None: + logger.info(f"🔧 清理MCP工具缓存,移除 {len(self._cached_tools)} 个工具") + self._cached_tools = None + else: + logger.debug(f"🔧 MCP工具缓存已经是空,无需清理") + + # 更新加载状态,确保下次调用会重新检查 + self._tools_loaded = False + logger.debug(f"🔧 MCP工具状态已重置: enable_mcp={self._enable_mcp}, _tools_loaded=False") + def _get_provider(self, provider: Optional[str] = None) -> BaseAIProvider: """获取对应的 Provider""" p = provider or self.api_provider @@ -79,6 +156,166 @@ class AIService: return self._gemini_provider raise ValueError(f"Provider {p} 未初始化") + async def _prepare_mcp_tools(self, auto_mcp: bool = True, force_refresh: bool = False) -> Optional[List[Dict]]: + """ + 预处理MCP工具 + + 检查用户MCP配置并加载可用工具。 + 结果会被缓存,避免重复加载。 + + Args: + auto_mcp: 是否自动加载MCP工具(来自调用方参数) + force_refresh: 是否强制刷新缓存 + + Returns: + - None: 无可用工具(未配置/未启用/加载失败) + - List[Dict]: OpenAI格式的工具列表 + """ + # 前置条件检查 + if not self._enable_mcp: + logger.debug(f"🔧 MCP工具未启用 (_enable_mcp=False)") + # 即使有缓存也清理掉,确保不使用 + self._cached_tools = None + self._tools_loaded = False + return None + + if not auto_mcp: + logger.debug(f"🔧 auto_mcp=False,跳过MCP工具加载") + # 即使有缓存也清理掉,确保不使用 + self._cached_tools = None + self._tools_loaded = False + return None + + if not self.user_id: + logger.debug(f"🔧 MCP工具加载跳过: user_id未设置") + return None + + if not self.db_session: + logger.debug(f"🔧 MCP工具加载跳过: db_session未设置") + return None + + # 使用缓存(只有 enable_mcp=True 时才使用缓存) + if self._tools_loaded and not force_refresh: + if self._cached_tools: + logger.debug(f"🔧 使用缓存的MCP工具 ({len(self._cached_tools)}个)") + return self._cached_tools + + try: + from app.services.mcp_tools_loader import mcp_tools_loader + + self._cached_tools = await mcp_tools_loader.get_user_tools( + user_id=self.user_id, + db_session=self.db_session, + use_cache=True, + force_refresh=force_refresh + ) + self._tools_loaded = True + + if self._cached_tools: + logger.info(f"🔧 已加载 {len(self._cached_tools)} 个MCP工具") + else: + logger.debug(f"📭 用户 {self.user_id} 没有可用的MCP工具") + + return self._cached_tools + + except Exception as e: + logger.warning(f"⚠️ 加载MCP工具失败: {e}") + self._tools_loaded = True + self._cached_tools = None + return None + + async def _handle_tool_calls( + self, + original_prompt: str, + response: Dict[str, Any], + max_rounds: int = 2, + **kwargs + ) -> Dict[str, Any]: + """ + 处理AI返回的工具调用 + + Args: + original_prompt: 原始提示词 + response: AI响应(包含tool_calls) + max_rounds: 最大工具调用轮数 + **kwargs: 传递给generate_text的其他参数 + + Returns: + 最终的AI响应 + """ + from app.mcp import mcp_client + + tool_calls = response.get("tool_calls", []) + if not tool_calls or not self.user_id: + return response + + result = { + "content": response.get("content", ""), + "tool_calls_made": 0, + "tools_used": [], + "finish_reason": response.get("finish_reason", ""), + "mcp_enhanced": True + } + + prompt = original_prompt + + for round_num in range(max_rounds): + logger.info(f"🔧 工具调用 - 第{round_num+1}/{max_rounds}轮,{len(tool_calls)}个工具") + + try: + # 批量执行工具调用 + tool_results = await mcp_client.batch_call_tools( + user_id=self.user_id, + tool_calls=tool_calls + ) + + # 记录使用的工具 + for tc in tool_calls: + name = tc["function"]["name"] + if name not in result["tools_used"]: + result["tools_used"].append(name) + result["tool_calls_made"] += len(tool_calls) + + # 构建工具上下文 + tool_context = mcp_client.build_tool_context(tool_results, format="markdown") + + # 更新提示词 + if round_num == max_rounds - 1: + # 最后一轮,强制要求回答 + prompt = f"{original_prompt}\n\n{tool_context}\n\n⚠️ 重要:请基于以上工具查询结果,给出完整详细的最终答案。不要再调用工具。" + tool_choice = "none" + else: + prompt = f"{original_prompt}\n\n{tool_context}\n\n请基于以上工具查询结果,继续完成任务。" + tool_choice = kwargs.get("tool_choice", "auto") + + # 继续调用AI + prov = self._get_provider(kwargs.get("provider")) + next_response = await prov.generate( + prompt=prompt, + model=kwargs.get("model") or self.default_model, + temperature=kwargs.get("temperature") or self.default_temperature, + max_tokens=kwargs.get("max_tokens") or self.default_max_tokens, + system_prompt=kwargs.get("system_prompt") or self.default_system_prompt, + tools=None if tool_choice == "none" else self._cached_tools, + tool_choice=tool_choice, + ) + + tool_calls = next_response.get("tool_calls", []) + + if not tool_calls: + # 没有更多工具调用,返回结果 + result["content"] = next_response.get("content", "") + result["finish_reason"] = next_response.get("finish_reason", "stop") + break + + except Exception as e: + logger.error(f"❌ 工具调用失败: {e}") + result["content"] = response.get("content", "") + result["finish_reason"] = "tool_error" + break + + return result + async def generate_text( self, prompt: str, @@ -89,10 +326,39 @@ class AIService: system_prompt: Optional[str] = None, tools: Optional[List[Dict]] = None, tool_choice: Optional[str] = None, + auto_mcp: bool = True, + handle_tool_calls: bool = True, + mcp_max_rounds: Optional[int] = None, ) -> Dict[str, Any]: - """生成文本""" + """ + 生成文本(自动支持MCP工具) + + Args: + prompt: 用户提示词 + provider: AI提供商 + model: 模型名称 + temperature: 温度 + max_tokens: 最大令牌数 + system_prompt: 系统提示词 + tools: 手动指定的工具列表(优先级高于自动加载) + tool_choice: 工具选择策略 + auto_mcp: 是否自动加载MCP工具(默认True) + handle_tool_calls: 是否自动处理工具调用(默认True) + mcp_max_rounds: 最大工具调用轮数(None使用默认值3) + + Returns: + 包含生成内容的字典 + """ + # 使用全局配置的MCP轮数(如果未指定) + if mcp_max_rounds is None: + mcp_max_rounds = app_settings.mcp_max_rounds + + # 自动加载MCP工具 + if auto_mcp and tools is None: + tools = await self._prepare_mcp_tools(auto_mcp=auto_mcp) + prov = self._get_provider(provider) - return await prov.generate( + response = await prov.generate( prompt=prompt, model=model or self.default_model, temperature=temperature or self.default_temperature, @@ -101,6 +367,22 @@ class AIService: tools=tools, tool_choice=tool_choice, ) + + # 处理工具调用 + if handle_tool_calls and response.get("tool_calls"): + return await self._handle_tool_calls( + original_prompt=prompt, + response=response, + provider=provider, + model=model, + temperature=temperature, + max_tokens=max_tokens, + system_prompt=system_prompt, + tool_choice=tool_choice, + max_rounds=mcp_max_rounds, + ) + + return response async def generate_text_stream( self, @@ -110,15 +392,51 @@ class AIService: temperature: Optional[float] = None, max_tokens: Optional[int] = None, system_prompt: Optional[str] = None, + tool_choice: Optional[str] = None, + auto_mcp: bool = True, + mcp_max_rounds: Optional[int] = None, ) -> AsyncGenerator[str, None]: - """流式生成""" + """ + 流式生成文本(自动支持MCP工具) + + 工具调用在 Provider 层通过流式方式处理,支持真正的流式工具调用。 + + Args: + prompt: 用户提示词 + provider: AI提供商 + model: 模型名称 + temperature: 温度 + max_tokens: 最大令牌数 + system_prompt: 系统提示词 + tool_choice: 工具选择策略("auto"/"none"/"required") + auto_mcp: 是否自动加载MCP工具 + mcp_max_rounds: 最大工具调用轮数(None使用默认值3) + + Yields: + 生成的文本块 + """ + logger.debug(f"🔧 generate_text_stream: auto_mcp={auto_mcp}, tool_choice={tool_choice}") + + tools_to_use = None + + # 加载MCP工具 + if auto_mcp: + tools_to_use = await self._prepare_mcp_tools(auto_mcp=auto_mcp) + if tools_to_use: + logger.info(f"🔧 已获取 {len(tools_to_use)} 个MCP工具") + + # 流式生成(Provider 层处理工具调用) prov = self._get_provider(provider) + logger.debug(f"🔧 开始流式生成,provider={provider or self.api_provider}, tools_count={len(tools_to_use) if tools_to_use else 0}") async for chunk in prov.generate_stream( prompt=prompt, model=model or self.default_model, temperature=temperature or self.default_temperature, max_tokens=max_tokens or self.default_max_tokens, system_prompt=system_prompt or self.default_system_prompt, + tools=tools_to_use, + tool_choice=tool_choice, + user_id=self.user_id, ): yield chunk @@ -132,8 +450,25 @@ class AIService: provider: Optional[str] = None, model: Optional[str] = None, expected_type: Optional[str] = None, + auto_mcp: bool = True, ) -> Union[Dict, List]: - """带重试的 JSON 调用""" + """ + 带重试的 JSON 调用(自动支持MCP工具) + + Args: + prompt: 用户提示词 + system_prompt: 系统提示词 + max_retries: 最大重试次数 + temperature: 温度 + max_tokens: 最大令牌数 + provider: AI提供商 + model: 模型名称 + expected_type: 期望的返回类型("object"或"array") + auto_mcp: 是否自动加载MCP工具 + + Returns: + 解析后的JSON数据 + """ last_response = "" for attempt in range(1, max_retries + 1): @@ -146,6 +481,8 @@ class AIService: temperature=temperature, max_tokens=max_tokens, system_prompt=system_prompt, + auto_mcp=auto_mcp, + handle_tool_calls=True, ) last_response = result.get("content", "") @@ -172,108 +509,6 @@ class AIService: """清洗 JSON 响应""" return clean_json_response(text) - async def generate_text_with_mcp( - self, - prompt: str, - user_id: str, - db_session, - enable_mcp: bool = True, - max_tool_rounds: int = 3, - tool_choice: str = "auto", - **kwargs - ) -> Dict[str, Any]: - """支持MCP工具的AI文本生成""" - from app.services.mcp_tool_service import mcp_tool_service, MCPToolServiceError - - result = {"content": "", "tool_calls_made": 0, "tools_used": [], "finish_reason": "", "mcp_enhanced": False} - tools = None - - if enable_mcp: - try: - tools = await mcp_tool_service.get_user_enabled_tools(user_id=user_id, db_session=db_session) - if tools: - result["mcp_enhanced"] = True - except MCPToolServiceError: - tools = None - - original_prompt = prompt # 保存原始提示词 - - for round_num in range(max_tool_rounds): - logger.debug(f"🔄 MCP工具调用 - 第{round_num+1}/{max_tool_rounds}轮") - logger.debug(f" prompt长度: {len(prompt)}, tools数量: {len(tools) if tools else 0}, tool_choice: {tool_choice}") - - ai_response = await self.generate_text(prompt=prompt, tools=tools, tool_choice=tool_choice, **kwargs) - logger.debug(f" AI响应: finish_reason={ai_response.get('finish_reason')}, content长度={len(ai_response.get('content', ''))}") - - tool_calls = ai_response.get("tool_calls", []) - - if not tool_calls: - content = ai_response.get("content", "") - result["content"] = content - result["finish_reason"] = ai_response.get("finish_reason", "stop") - logger.debug(f" ✅ 无工具调用,返回内容长度: {len(content)}") - - # 🔧 修复:如果内容为空且已经调用过工具,强制要求AI给出答案 - if not content.strip() and result["tool_calls_made"] > 0: - logger.warning(f"⚠️ AI在工具调用后返回空内容,尝试强制要求回答(第{round_num+1}轮)") - prompt = f"{prompt}\n\n⚠️ 请注意:你必须基于以上工具查询结果,给出完整的回答。不要返回空内容。" - tools = None - tool_choice = "none" # 强制不使用工具 - continue - - break - - logger.info(f"🔧 检测到 {len(tool_calls)} 个工具调用") - for idx, tc in enumerate(tool_calls): - logger.debug(f" 工具{idx+1}: {tc.get('function', {}).get('name')} - 参数: {tc.get('function', {}).get('arguments')}") - - try: - logger.debug(f" 开始执行工具调用...") - tool_results = await mcp_tool_service.execute_tool_calls(user_id=user_id, tool_calls=tool_calls, db_session=db_session) - logger.debug(f" 工具执行完成,结果数量: {len(tool_results)}") - - # 🔍 检查工具结果 - for idx, tr in enumerate(tool_results): - success = tr.get("success", False) - content_preview = tr.get("content", "")[:200] if tr.get("content") else "None" - logger.debug(f" 工具结果[{idx}]: success={success}, content预览={content_preview}") - - for tc in tool_calls: - name = tc["function"]["name"] - if name not in result["tools_used"]: - result["tools_used"].append(name) - result["tool_calls_made"] += len(tool_calls) - - tool_context = await mcp_tool_service.build_tool_context(tool_results, format="markdown") - logger.debug(f" 工具上下文长度: {len(tool_context)}") - logger.debug(f" 工具上下文预览: {tool_context[:300] if len(tool_context) > 300 else tool_context}") - - # 🔧 改进:在最后一轮时,明确要求AI给出完整答案 - if round_num == max_tool_rounds - 1: - logger.info(f"⚠️ 最后一轮,强制要求AI给出最终答案") - prompt = f"{original_prompt}\n\n{tool_context}\n\n⚠️ 重要:这是最后一轮,请基于以上工具查询的参考资料,给出完整详细的最终答案。不要再调用工具。" - tool_choice = "none" - else: - prompt = f"{original_prompt}\n\n{tool_context}\n\n请基于以上工具查询结果,继续完成任务。" - logger.debug(f" 新prompt长度: {len(prompt)}") - - tools = None # 工具调用后禁用工具列表,避免重复调用 - logger.debug(f" ✅ 工具调用成功,准备下一轮") - - except Exception as tool_error: - logger.error(f"❌ 工具调用执行失败: {tool_error}", exc_info=True) - logger.error(f" 错误类型: {type(tool_error).__name__}") - logger.error(f" AI响应内容: {ai_response.get('content', '')[:200]}") - result["content"] = ai_response.get("content", "") - result["finish_reason"] = "tool_error" - break - - return result - - -# 全局实例 -ai_service = AIService() - def create_user_ai_service( api_provider: str, @@ -284,7 +519,7 @@ def create_user_ai_service( max_tokens: int, system_prompt: Optional[str] = None, ) -> AIService: - """创建用户 AI 服务""" + """创建用户 AI 服务(不带MCP支持)""" return AIService( api_provider=api_provider, api_key=api_key, @@ -293,4 +528,48 @@ def create_user_ai_service( default_temperature=temperature, default_max_tokens=max_tokens, default_system_prompt=system_prompt, + ) + + +def create_user_ai_service_with_mcp( + api_provider: str, + api_key: str, + api_base_url: str, + model_name: str, + temperature: float, + max_tokens: int, + user_id: str, + db_session, + system_prompt: Optional[str] = None, + enable_mcp: bool = True, +) -> AIService: + """ + 创建支持MCP的用户AI服务 + + Args: + api_provider: AI提供商 + api_key: API密钥 + api_base_url: API基础URL + model_name: 模型名称 + temperature: 温度 + max_tokens: 最大令牌数 + user_id: 用户ID(用于加载MCP工具) + db_session: 数据库会话 + system_prompt: 系统提示词 + enable_mcp: 是否启用MCP工具 + + Returns: + 配置好的AIService实例 + """ + return AIService( + api_provider=api_provider, + api_key=api_key, + api_base_url=api_base_url, + default_model=model_name, + default_temperature=temperature, + default_max_tokens=max_tokens, + default_system_prompt=system_prompt, + user_id=user_id, + db_session=db_session, + enable_mcp=enable_mcp, ) \ No newline at end of file diff --git a/backend/app/services/auto_character_service.py b/backend/app/services/auto_character_service.py index 128c3f8..8e6e827 100644 --- a/backend/app/services/auto_character_service.py +++ b/backend/app/services/auto_character_service.py @@ -269,25 +269,11 @@ class AutoCharacterService: ) try: - # 调用AI分析(使用统一的JSON调用方法) - if enable_mcp and user_id: - result = await self.ai_service.generate_text_with_mcp( - prompt=prompt, - user_id=user_id, - db_session=db, - enable_mcp=True, - max_tool_rounds=2 - ) - content = result.get("content", "") - # 使用统一的JSON清洗方法 - cleaned = self.ai_service._clean_json_response(content) - analysis = json.loads(cleaned) - else: - # 非MCP调用:使用带自动重试的JSON调用 - analysis = await self.ai_service.call_with_json_retry( - prompt=prompt, - max_retries=3 - ) + # 使用统一的JSON调用方法(支持自动MCP工具加载) + analysis = await self.ai_service.call_with_json_retry( + prompt=prompt, + max_retries=3, + ) logger.info(f" ✅ AI分析完成: needs_new_characters={analysis.get('needs_new_characters')}") return analysis @@ -364,16 +350,16 @@ class AutoCharacterService: existing_characters=existing_chars_summary + careers_info, plot_context="根据剧情需要引入的新角色", character_specification=json.dumps(spec, ensure_ascii=False, indent=2), - mcp_references="" # 暂时不使用MCP增强 + mcp_references="" # MCP工具通过AI服务自动加载 ) - # 调用AI生成(禁用MCP,避免累积超时导致卡死) + logger.info(f"🔧 角色详情生成: enable_mcp={enable_mcp}") + + # 调用AI生成 try: - # 🔧 优化:角色详情生成不使用MCP,只在分析阶段使用MCP - # 这样可以减少大量的外部工具调用,避免超时和卡死 character_data = await self.ai_service.call_with_json_retry( prompt=prompt, - max_retries=2 # 减少重试次数以加快速度 + max_retries=2, # 减少重试次数以加快速度 ) char_name = character_data.get('name', '未知') diff --git a/backend/app/services/auto_organization_service.py b/backend/app/services/auto_organization_service.py index 8a16612..395604c 100644 --- a/backend/app/services/auto_organization_service.py +++ b/backend/app/services/auto_organization_service.py @@ -292,25 +292,11 @@ class AutoOrganizationService: ) try: - # 调用AI分析(使用统一的JSON调用方法) - if enable_mcp and user_id: - result = await self.ai_service.generate_text_with_mcp( - prompt=prompt, - user_id=user_id, - db_session=db, - enable_mcp=True, - max_tool_rounds=2 - ) - content = result.get("content", "") - # 使用统一的JSON清洗方法 - cleaned = self.ai_service._clean_json_response(content) - analysis = json.loads(cleaned) - else: - # 非MCP调用:使用带自动重试的JSON调用 - analysis = await self.ai_service.call_with_json_retry( - prompt=prompt, - max_retries=3 - ) + # 使用统一的JSON调用方法(支持自动MCP工具加载) + analysis = await self.ai_service.call_with_json_retry( + prompt=prompt, + max_retries=3, + ) logger.info(f" ✅ AI分析完成: needs_new_organizations={analysis.get('needs_new_organizations')}") return analysis @@ -362,24 +348,11 @@ class AutoOrganizationService: # 调用AI生成(使用统一的JSON调用方法) try: - if enable_mcp and user_id: - result = await self.ai_service.generate_text_with_mcp( - prompt=prompt, - user_id=user_id, - db_session=db, - enable_mcp=True, - max_tool_rounds=2 - ) - content = result.get("content", "") - # 使用统一的JSON清洗方法 - cleaned = self.ai_service._clean_json_response(content) - organization_data = json.loads(cleaned) - else: - # 非MCP调用:使用带自动重试的JSON调用 - organization_data = await self.ai_service.call_with_json_retry( - prompt=prompt, - max_retries=3 - ) + # 使用统一的JSON调用方法(支持自动MCP工具加载) + organization_data = await self.ai_service.call_with_json_retry( + prompt=prompt, + max_retries=3, + ) org_name = organization_data.get('name', '未知') logger.info(f" ✅ 组织详情生成成功: {org_name}") diff --git a/backend/app/services/mcp_test_service.py b/backend/app/services/mcp_test_service.py index f061bba..50bb406 100644 --- a/backend/app/services/mcp_test_service.py +++ b/backend/app/services/mcp_test_service.py @@ -1,4 +1,7 @@ -"""MCP插件测试服务 - 专门处理插件测试逻辑""" +"""MCP插件测试服务 - 专门处理插件测试逻辑 + +重构后使用统一的MCPClientFacade门面来管理所有MCP操作。 +""" import time import json @@ -10,7 +13,7 @@ from sqlalchemy import select from app.models.mcp_plugin import MCPPlugin from app.models.settings import Settings as UserSettings -from app.mcp.registry import mcp_registry +from app.mcp import mcp_client, MCPPluginConfig # 使用新的统一门面 from app.services.ai_service import create_user_ai_service from app.schemas.mcp_plugin import MCPTestResult from app.services.prompt_service import prompt_service @@ -21,7 +24,32 @@ logger = get_logger(__name__) class MCPTestService: - """MCP插件测试服务(分离的测试逻辑)""" + """MCP插件测试服务(使用统一门面重构)""" + + async def _ensure_plugin_registered( + self, + plugin: MCPPlugin, + user_id: str + ) -> bool: + """ + 确保插件已注册到统一门面 + + Args: + plugin: 插件配置 + user_id: 用户ID + + Returns: + 是否成功 + """ + if plugin.plugin_type in ("http", "streamable_http", "sse") and plugin.server_url: + return await mcp_client.ensure_registered( + user_id=user_id, + plugin_name=plugin.plugin_name, + url=plugin.server_url, + plugin_type=plugin.plugin_type, + headers=plugin.headers + ) + return False async def test_plugin_connection( self, @@ -41,19 +69,18 @@ class MCPTestService: start_time = time.time() try: - # 确保插件已加载 - if not mcp_registry.get_client(user_id, plugin.plugin_name): - success = await mcp_registry.load_plugin(plugin) - if not success: - return MCPTestResult( - success=False, - message="插件加载失败", - error="无法创建MCP客户端", - suggestions=["请检查插件配置", "请确认服务器URL正确"] - ) + # 确保插件已注册 + registered = await self._ensure_plugin_registered(plugin, user_id) + if not registered: + return MCPTestResult( + success=False, + message="插件注册失败", + error="无法创建MCP客户端", + suggestions=["请检查插件配置", "请确认服务器URL正确"] + ) - # 测试连接并获取工具列表 - test_result = await mcp_registry.test_plugin(user_id, plugin.plugin_name) + # 使用统一门面测试连接 + test_result = await mcp_client.test_connection(user_id, plugin.plugin_name) end_time = time.time() response_time = round((end_time - start_time) * 1000, 2) @@ -70,7 +97,18 @@ class MCPTestService: ] ) else: - return MCPTestResult(**test_result) + return MCPTestResult( + success=False, + message="❌ 连接测试失败", + response_time_ms=response_time, + error=test_result.get("message", "未知错误"), + error_type=test_result.get("error_type"), + suggestions=[ + "请检查服务器是否在线", + "请确认配置正确", + "请检查API Key是否有效" + ] + ) except Exception as e: end_time = time.time() @@ -117,8 +155,8 @@ class MCPTestService: if not connection_result.success: return connection_result - # 2. 获取工具列表 - tools = await mcp_registry.get_plugin_tools(user.user_id, plugin.plugin_name) + # 2. 使用统一门面获取工具列表 + tools = await mcp_client.get_tools(user.user_id, plugin.plugin_name) if not tools: return MCPTestResult( @@ -162,8 +200,8 @@ class MCPTestService: max_tokens=1000 ) - # 转换为OpenAI Function Calling格式 - openai_tools = self._convert_tools_to_openai_format(tools) + # 使用统一门面转换为OpenAI Function Calling格式 + openai_tools = mcp_client.format_tools_for_openai(tools, plugin.plugin_name) logger.info(f"📋 转换后的OpenAI工具数量: {len(openai_tools)}") logger.debug(f"📋 OpenAI工具列表: {[t['function']['name'] for t in openai_tools]}") @@ -175,26 +213,16 @@ class MCPTestService: db=db_session ) - # 注意: generate_text_stream 返回的是异步生成器,但在 tool_choice="required" 模式下 - # AI服务会直接返回包含 tool_calls 的完整响应,而不是流式chunks - # 因此这里需要特殊处理 - accumulated_text = "" - tool_calls = None - - async for chunk in ai_service.generate_text_stream( + # 使用 generate_text 进行 Function Calling(非流式) + ai_response = await ai_service.generate_text( prompt=prompts["user"], system_prompt=prompts["system"], tools=openai_tools, - tool_choice="required" - ): - # 在 function calling 模式下,chunk 可能是字典格式包含 tool_calls - if isinstance(chunk, dict): - if "tool_calls" in chunk: - tool_calls = chunk["tool_calls"] - if "content" in chunk: - accumulated_text += chunk.get("content", "") - else: - accumulated_text += chunk + tool_choice="auto" + ) + + accumulated_text = ai_response.get("content", "") + tool_calls = ai_response.get("tool_calls") # 5. 检查AI是否返回工具调用 if not tool_calls: @@ -214,7 +242,7 @@ class MCPTestService: # 6. 解析工具调用 tool_call = tool_calls[0] function = tool_call["function"] - tool_name = function["name"] + tool_name_with_prefix = function["name"] test_arguments = function["arguments"] if isinstance(test_arguments, str): @@ -231,17 +259,23 @@ class MCPTestService: tools_count=len(tools) ) + # 解析插件名和工具名 + try: + _, tool_name = mcp_client.parse_function_name(tool_name_with_prefix) + except ValueError: + tool_name = tool_name_with_prefix + logger.info(f"🤖 AI选择的工具: {tool_name}") logger.info(f"📝 AI生成的参数: {test_arguments}") - # 7. 调用MCP工具 + # 7. 使用统一门面调用MCP工具 call_start = time.time() try: - tool_result = await mcp_registry.call_tool( - user.user_id, - plugin.plugin_name, - tool_name, - test_arguments + tool_result = await mcp_client.call_tool( + user_id=user.user_id, + plugin_name=plugin.plugin_name, + tool_name=tool_name, + arguments=test_arguments ) call_end = time.time() @@ -307,22 +341,6 @@ class MCPTestService: "请检查API Key是否有效" ] ) - - def _convert_tools_to_openai_format(self, tools: list) -> list: - """将MCP工具格式转换为OpenAI Function Calling格式""" - openai_tools = [] - for tool in tools: - openai_tool = { - "type": "function", - "function": { - "name": tool["name"], - "description": tool.get("description", ""), - } - } - if "inputSchema" in tool: - openai_tool["function"]["parameters"] = tool["inputSchema"] - openai_tools.append(openai_tool) - return openai_tools # 全局单例 diff --git a/backend/app/services/mcp_tool_service.py b/backend/app/services/mcp_tool_service.py deleted file mode 100644 index 6d8c315..0000000 --- a/backend/app/services/mcp_tool_service.py +++ /dev/null @@ -1,691 +0,0 @@ -"""MCP工具服务 - 统一管理MCP工具的注入和执行""" - -from typing import List, Dict, Any, Optional -from sqlalchemy.ext.asyncio import AsyncSession -from sqlalchemy import select -import asyncio -import json -import time -from datetime import datetime, timedelta -from dataclasses import dataclass, field -from collections import defaultdict - -from app.models.mcp_plugin import MCPPlugin -from app.mcp.registry import mcp_registry -from app.mcp.config import mcp_config -from app.logger import get_logger - -logger = get_logger(__name__) - - -@dataclass -class ToolMetrics: - """工具调用指标""" - total_calls: int = 0 - success_calls: int = 0 - failed_calls: int = 0 - total_duration_ms: float = 0.0 - avg_duration_ms: float = 0.0 - last_call_time: Optional[datetime] = None - - def update_success(self, duration_ms: float): - """更新成功调用指标""" - self.total_calls += 1 - self.success_calls += 1 - self.total_duration_ms += duration_ms - self.avg_duration_ms = self.total_duration_ms / self.total_calls - self.last_call_time = datetime.now() - - def update_failure(self, duration_ms: float): - """更新失败调用指标""" - self.total_calls += 1 - self.failed_calls += 1 - self.total_duration_ms += duration_ms - self.avg_duration_ms = self.total_duration_ms / self.total_calls - self.last_call_time = datetime.now() - - @property - def success_rate(self) -> float: - """成功率""" - if self.total_calls == 0: - return 0.0 - return self.success_calls / self.total_calls - - -@dataclass -class ToolCacheEntry: - """工具缓存条目""" - tools: List[Dict[str, Any]] - expire_time: datetime - hit_count: int = 0 - - -class MCPToolServiceError(Exception): - """MCP工具服务异常""" - pass - - -class MCPToolService: - """MCP工具服务 - 统一管理MCP工具的注入和执行(优化版)""" - - def __init__( - self, - cache_ttl_minutes: Optional[int] = None, - max_retries: Optional[int] = None - ): - """ - 初始化MCP工具服务 - - Args: - cache_ttl_minutes: 工具缓存TTL(分钟,默认使用配置) - max_retries: 最大重试次数(默认使用配置) - """ - # 工具定义缓存: {cache_key: ToolCacheEntry} - self._tool_cache: Dict[str, ToolCacheEntry] = {} - self._cache_ttl = timedelta( - minutes=cache_ttl_minutes or mcp_config.TOOL_CACHE_TTL_MINUTES - ) - - # 调用指标: {tool_key: ToolMetrics} - self._metrics: Dict[str, ToolMetrics] = defaultdict(ToolMetrics) - - # 重试配置(使用配置常量) - self._max_retries = max_retries or mcp_config.MAX_RETRIES - self._base_retry_delay = mcp_config.BASE_RETRY_DELAY_SECONDS - self._max_retry_delay = mcp_config.MAX_RETRY_DELAY_SECONDS - - logger.info( - f"✅ MCPToolService初始化完成 " - f"(缓存TTL={self._cache_ttl.total_seconds()/60:.1f}分钟, " - f"最大重试={self._max_retries}次)" - ) - - async def get_user_enabled_tools( - self, - user_id: str, - db_session: AsyncSession, - category: Optional[str] = None - ) -> List[Dict[str, Any]]: - """ - 获取用户启用的MCP工具列表 - - Args: - user_id: 用户ID - db_session: 数据库会话 - category: 工具类别筛选(search/analysis/filesystem等) - - Returns: - 工具定义列表,格式符合OpenAI Function Calling规范 - """ - try: - # 1. 查询用户启用的插件(enabled=True即可,不强制要求status=active) - # 因为新启用的插件status可能还是inactive,需要给它机会被调用 - query = select(MCPPlugin).where( - MCPPlugin.user_id == user_id, - MCPPlugin.enabled == True - ) - - if category: - query = query.where(MCPPlugin.category == category) - - result = await db_session.execute(query) - plugins = result.scalars().all() - - if not plugins: - logger.info(f"用户 {user_id} 没有启用的MCP插件") - return [] - - # 2. 获取所有工具定义(使用缓存) - all_tools = [] - for plugin in plugins: - try: - # 确保插件已加载到注册表 - if not mcp_registry.get_client(user_id, plugin.plugin_name): - logger.info(f"插件 {plugin.plugin_name} 未加载,尝试加载...") - success = await mcp_registry.load_plugin(plugin) - if not success: - logger.warning(f"插件 {plugin.plugin_name} 加载失败,跳过") - continue - - # ✅ 使用缓存获取工具列表 - plugin_tools = await self._get_plugin_tools_cached( - user_id=user_id, - plugin_name=plugin.plugin_name - ) - - # 格式化为Function Calling格式 - formatted_tools = self._format_tools_for_ai( - plugin_tools, - plugin.plugin_name - ) - all_tools.extend(formatted_tools) - - logger.info( - f"从插件 {plugin.plugin_name} 加载了 " - f"{len(formatted_tools)} 个工具" - ) - - except Exception as e: - logger.error( - f"获取插件 {plugin.plugin_name} 的工具失败: {e}", - exc_info=True - ) - continue - - logger.info(f"用户 {user_id} 共加载 {len(all_tools)} 个MCP工具") - return all_tools - - except Exception as e: - logger.error(f"获取用户MCP工具失败: {e}", exc_info=True) - raise MCPToolServiceError(f"获取MCP工具失败: {str(e)}") - - def _format_tools_for_ai( - self, - plugin_tools: List[Dict[str, Any]], - plugin_name: str - ) -> List[Dict[str, Any]]: - """ - 将MCP工具定义格式化为AI Function Calling格式 - - Args: - plugin_tools: MCP插件的工具列表 - plugin_name: 插件名称 - - Returns: - 格式化后的工具列表 - """ - formatted_tools = [] - - for tool in plugin_tools: - formatted_tool = { - "type": "function", - "function": { - "name": f"{plugin_name}_{tool['name']}", # 加插件前缀避免冲突 - "description": tool.get("description", ""), - "parameters": tool.get("inputSchema", { - "type": "object", - "properties": {}, - "required": [] - }) - } - } - formatted_tools.append(formatted_tool) - - return formatted_tools - - async def _get_plugin_tools_cached( - self, - user_id: str, - plugin_name: str - ) -> List[Dict[str, Any]]: - """ - 带缓存的工具列表获取 - - Args: - user_id: 用户ID - plugin_name: 插件名称 - - Returns: - 工具列表 - """ - cache_key = f"{user_id}:{plugin_name}" - now = datetime.now() - - # 检查缓存 - if cache_key in self._tool_cache: - entry = self._tool_cache[cache_key] - if now < entry.expire_time: - entry.hit_count += 1 - logger.debug( - f"🎯 工具缓存命中: {cache_key} " - f"(命中次数: {entry.hit_count})" - ) - return entry.tools - else: - logger.debug(f"⏰ 工具缓存过期: {cache_key}") - del self._tool_cache[cache_key] - - # 缓存未命中,从MCP获取 - logger.debug(f"🔍 工具缓存未命中,从MCP获取: {cache_key}") - tools = await mcp_registry.get_plugin_tools(user_id, plugin_name) - - # 更新缓存 - self._tool_cache[cache_key] = ToolCacheEntry( - tools=tools, - expire_time=now + self._cache_ttl, - hit_count=0 - ) - - return tools - - def clear_cache(self, user_id: Optional[str] = None, plugin_name: Optional[str] = None): - """ - 清理缓存 - - Args: - user_id: 用户ID(可选,清理特定用户的缓存) - plugin_name: 插件名称(可选,清理特定插件的缓存) - """ - if user_id is None and plugin_name is None: - # 清理所有缓存 - self._tool_cache.clear() - logger.info("🧹 已清理所有工具缓存") - elif user_id and plugin_name: - # 清理特定插件缓存 - cache_key = f"{user_id}:{plugin_name}" - if cache_key in self._tool_cache: - del self._tool_cache[cache_key] - logger.info(f"🧹 已清理缓存: {cache_key}") - elif user_id: - # 清理用户所有缓存 - keys_to_delete = [ - key for key in self._tool_cache.keys() - if key.startswith(f"{user_id}:") - ] - for key in keys_to_delete: - del self._tool_cache[key] - logger.info(f"🧹 已清理用户缓存: {user_id} ({len(keys_to_delete)}个)") - - async def execute_tool_calls( - self, - user_id: str, - tool_calls: List[Dict[str, Any]], - db_session: AsyncSession, - timeout: Optional[float] = None, - max_concurrent: int = 2 - ) -> List[Dict[str, Any]]: - """ - 批量执行AI请求的工具调用(限制并发数,避免超时) - - Args: - user_id: 用户ID - tool_calls: AI返回的工具调用列表 - db_session: 数据库会话 - timeout: 单个工具调用的超时时间(秒,默认使用配置) - max_concurrent: 最大并发工具调用数(默认2) - - Returns: - 工具调用结果列表 - """ - if not tool_calls: - return [] - - # 使用配置的默认超时 - actual_timeout = timeout or mcp_config.TOOL_CALL_TIMEOUT_SECONDS - - logger.info(f"开始执行 {len(tool_calls)} 个工具调用 (超时={actual_timeout}s, 最大并发={max_concurrent})") - - # ✅ 分批执行,每批最多max_concurrent个 - all_results = [] - for i in range(0, len(tool_calls), max_concurrent): - batch = tool_calls[i:i+max_concurrent] - batch_num = i // max_concurrent + 1 - total_batches = (len(tool_calls) + max_concurrent - 1) // max_concurrent - - logger.info(f"执行工具批次 {batch_num}/{total_batches}, 数量: {len(batch)}") - - # 创建当前批次的异步任务 - tasks = [ - self._execute_single_tool( - user_id=user_id, - tool_call=tool_call, - db_session=db_session, - timeout=actual_timeout - ) - for tool_call in batch - ] - - # 并行执行当前批次 - batch_results = await asyncio.gather(*tasks, return_exceptions=True) - - # 处理批次结果 - for j, result in enumerate(batch_results): - tool_call = batch[j] - - if isinstance(result, Exception): - # 工具调用异常 - all_results.append({ - "tool_call_id": tool_call.get("id", f"call_{i+j}"), - "role": "tool", - "name": tool_call["function"]["name"], - "content": f"工具调用失败: {str(result)}", - "success": False, - "error": str(result) - }) - else: - all_results.append(result) - - # 批次间增加短暂延迟,避免API限流 - if i + max_concurrent < len(tool_calls): - await asyncio.sleep(0.5) - logger.debug(f"批次间延迟 0.5 秒...") - - return all_results - - async def _execute_single_tool( - self, - user_id: str, - tool_call: Dict[str, Any], - db_session: AsyncSession, - timeout: float - ) -> Dict[str, Any]: - """ - 执行单个工具调用 - - Args: - user_id: 用户ID - tool_call: 工具调用信息 - db_session: 数据库会话 - timeout: 超时时间 - - Returns: - 工具调用结果 - """ - tool_call_id = tool_call.get("id", "unknown") - function_name = tool_call["function"]["name"] - - try: - # 解析插件名和工具名 - logger.debug(f"🔍 解析工具名称: {function_name}") - if "_" in function_name: - plugin_name, tool_name = function_name.split("_", 1) - logger.debug(f" 插件: {plugin_name}, 工具: {tool_name}") - else: - raise ValueError(f"无效的工具名称格式: {function_name}") - - # 解析参数 - arguments_str = tool_call["function"]["arguments"] - logger.debug(f"🔍 解析参数:") - logger.debug(f" 原始类型: {type(arguments_str)}") - logger.debug(f" 原始内容: {arguments_str}") - - if isinstance(arguments_str, str): - try: - arguments = json.loads(arguments_str) - logger.debug(f" ✅ JSON解析成功: {arguments}") - except json.JSONDecodeError as je: - logger.error(f" ❌ JSON解析失败: {je}") - logger.error(f" 原始字符串: '{arguments_str}'") - raise ValueError(f"参数JSON解析失败: {je}") - else: - arguments = arguments_str - logger.debug(f" 直接使用dict类型参数") - - logger.info( - f"执行工具: {plugin_name}.{tool_name}, " - f"参数: {arguments}" - ) - - # ✅ 使用带重试的调用 - tool_key = f"{plugin_name}.{tool_name}" - start_time = time.time() - - try: - result = await self._call_tool_with_retry( - user_id=user_id, - plugin_name=plugin_name, - tool_name=tool_name, - arguments=arguments, - timeout=timeout - ) - - # 记录成功指标 - duration_ms = (time.time() - start_time) * 1000 - self._metrics[tool_key].update_success(duration_ms) - - logger.info( - f"✅ 工具调用成功: {tool_key} " - f"(耗时: {duration_ms:.2f}ms)" - ) - - # 成功返回 - return { - "tool_call_id": tool_call_id, - "role": "tool", - "name": function_name, - "content": json.dumps(result, ensure_ascii=False), - "success": True, - "error": None - } - - except asyncio.TimeoutError: - # 记录失败指标 - duration_ms = (time.time() - start_time) * 1000 - self._metrics[tool_key].update_failure(duration_ms) - raise MCPToolServiceError( - f"工具调用超时(>{timeout}秒)" - ) - - except Exception as e: - # 记录失败指标 - tool_key = f"{plugin_name}.{tool_name}" if 'plugin_name' in locals() else function_name - duration_ms = (time.time() - start_time) * 1000 - self._metrics[tool_key].update_failure(duration_ms) - - logger.error( - f"❌ 工具 {function_name} 调用失败: {e}", - exc_info=True - ) - return { - "tool_call_id": tool_call_id, - "role": "tool", - "name": function_name, - "content": f"工具调用失败: {str(e)}", - "success": False, - "error": str(e) - } - - async def _call_tool_with_retry( - self, - user_id: str, - plugin_name: str, - tool_name: str, - arguments: Dict[str, Any], - timeout: float - ) -> Any: - """ - 带指数退避重试的工具调用 - - Args: - user_id: 用户ID - plugin_name: 插件名称 - tool_name: 工具名称 - arguments: 工具参数 - timeout: 超时时间 - - Returns: - 工具执行结果 - - Raises: - MCPToolServiceError: 工具调用失败 - asyncio.TimeoutError: 调用超时 - """ - last_exception = None - - for attempt in range(self._max_retries): - try: - # 尝试调用工具 - result = await asyncio.wait_for( - mcp_registry.call_tool( - user_id=user_id, - plugin_name=plugin_name, - tool_name=tool_name, - arguments=arguments - ), - timeout=timeout - ) - - # 成功则返回 - if attempt > 0: - logger.info( - f"✅ 重试成功: {plugin_name}.{tool_name} " - f"(第{attempt + 1}次尝试)" - ) - return result - - except asyncio.TimeoutError: - # 超时不重试,直接抛出 - raise - - except Exception as e: - last_exception = e - - # 最后一次尝试失败 - if attempt == self._max_retries - 1: - logger.error( - f"❌ 重试失败: {plugin_name}.{tool_name} " - f"(已尝试{self._max_retries}次): {e}" - ) - raise MCPToolServiceError( - f"工具调用失败(已重试{self._max_retries}次): {str(e)}" - ) - - # 计算指数退避延迟 - delay = min( - self._base_retry_delay * (2 ** attempt), - self._max_retry_delay - ) - - logger.warning( - f"⚠️ 工具调用失败,{delay:.1f}秒后重试 " - f"(第{attempt + 1}/{self._max_retries}次): " - f"{plugin_name}.{tool_name} - {e}" - ) - - await asyncio.sleep(delay) - - # 理论上不会到这里,但为了类型安全 - raise MCPToolServiceError(f"工具调用失败: {last_exception}") - - def get_metrics(self, tool_name: Optional[str] = None) -> Dict[str, Dict[str, Any]]: - """ - 获取工具调用指标 - - Args: - tool_name: 工具名称(可选,获取特定工具的指标) - - Returns: - 指标字典 - """ - if tool_name: - if tool_name in self._metrics: - metric = self._metrics[tool_name] - return { - tool_name: { - "total_calls": metric.total_calls, - "success_calls": metric.success_calls, - "failed_calls": metric.failed_calls, - "success_rate": metric.success_rate, - "avg_duration_ms": round(metric.avg_duration_ms, 2), - "last_call_time": metric.last_call_time.isoformat() if metric.last_call_time else None - } - } - return {} - - # 返回所有工具的指标 - result = {} - for tool_key, metric in self._metrics.items(): - result[tool_key] = { - "total_calls": metric.total_calls, - "success_calls": metric.success_calls, - "failed_calls": metric.failed_calls, - "success_rate": round(metric.success_rate, 3), - "avg_duration_ms": round(metric.avg_duration_ms, 2), - "last_call_time": metric.last_call_time.isoformat() if metric.last_call_time else None - } - return result - - def get_cache_stats(self) -> Dict[str, Any]: - """获取缓存统计信息""" - total_entries = len(self._tool_cache) - total_hits = sum(entry.hit_count for entry in self._tool_cache.values()) - - return { - "total_entries": total_entries, - "total_hits": total_hits, - "cache_ttl_minutes": self._cache_ttl.total_seconds() / 60, - "entries": [ - { - "key": key, - "tools_count": len(entry.tools), - "hit_count": entry.hit_count, - "expire_time": entry.expire_time.isoformat() - } - for key, entry in self._tool_cache.items() - ] - } - - async def build_tool_context( - self, - tool_results: List[Dict[str, Any]], - format: str = "markdown" - ) -> str: - """ - 将工具调用结果格式化为上下文文本 - - Args: - tool_results: 工具调用结果列表 - format: 输出格式(markdown/json/plain) - - Returns: - 格式化的上下文字符串 - """ - if not tool_results: - return "" - - if format == "markdown": - return self._build_markdown_context(tool_results) - elif format == "json": - return json.dumps(tool_results, ensure_ascii=False, indent=2) - else: # plain - return self._build_plain_context(tool_results) - - def _build_markdown_context( - self, - tool_results: List[Dict[str, Any]] - ) -> str: - """构建Markdown格式的工具上下文""" - lines = ["## 🔧 工具调用结果\n"] - - for i, result in enumerate(tool_results, 1): - tool_name = result.get("name", "unknown") - success = result.get("success", False) - content = result.get("content", "") - - status_emoji = "✅" if success else "❌" - lines.append(f"### {status_emoji} {i}. {tool_name}\n") - - if success: - # 尝试美化JSON内容 - try: - content_obj = json.loads(content) - content = json.dumps(content_obj, ensure_ascii=False, indent=2) - except: - pass - lines.append(f"```json\n{content}\n```\n") - else: - lines.append(f"**错误**: {content}\n") - - return "\n".join(lines) - - def _build_plain_context( - self, - tool_results: List[Dict[str, Any]] - ) -> str: - """构建纯文本格式的工具上下文""" - lines = ["=== 工具调用结果 ===\n"] - - for i, result in enumerate(tool_results, 1): - tool_name = result.get("name", "unknown") - success = result.get("success", False) - content = result.get("content", "") - - status = "成功" if success else "失败" - lines.append(f"{i}. {tool_name} - {status}") - lines.append(f" 结果: {content}\n") - - return "\n".join(lines) - - -# 全局单例 -mcp_tool_service = MCPToolService() \ No newline at end of file diff --git a/backend/app/services/mcp_tools_loader.py b/backend/app/services/mcp_tools_loader.py new file mode 100644 index 0000000..7167ca7 --- /dev/null +++ b/backend/app/services/mcp_tools_loader.py @@ -0,0 +1,235 @@ +"""MCP工具加载器 - 统一的工具获取入口 + +在AI请求之前,自动检查用户MCP配置并加载可用工具。 +""" +from typing import Optional, List, Dict, Any +from datetime import datetime, timedelta +from dataclasses import dataclass, field + +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from app.logger import get_logger +from app.models.mcp_plugin import MCPPlugin +from app.mcp import mcp_client + +logger = get_logger(__name__) + + +@dataclass +class UserToolsCache: + """用户工具缓存条目""" + tools: Optional[List[Dict[str, Any]]] + expire_time: datetime + hit_count: int = 0 + + +class MCPToolsLoader: + """ + MCP工具加载器 + + 负责: + 1. 检查用户是否配置并启用了MCP插件 + 2. 从各个启用的插件加载工具列表 + 3. 将工具转换为OpenAI Function Calling格式 + 4. 缓存结果以提升性能 + """ + + _instance: Optional['MCPToolsLoader'] = None + + def __new__(cls): + """单例模式""" + if cls._instance is None: + cls._instance = super().__new__(cls) + cls._instance._initialized = False + return cls._instance + + def __init__(self): + if self._initialized: + return + + # 用户工具缓存: user_id -> UserToolsCache + self._cache: Dict[str, UserToolsCache] = {} + + # 缓存TTL(5分钟) + self._cache_ttl = timedelta(minutes=5) + + self._initialized = True + logger.info("✅ MCPToolsLoader 初始化完成") + + async def has_enabled_plugins( + self, + user_id: str, + db_session: AsyncSession + ) -> bool: + """ + 检查用户是否有启用的MCP插件 + + Args: + user_id: 用户ID + db_session: 数据库会话 + + Returns: + 是否有启用的插件 + """ + try: + query = select(MCPPlugin.id).where( + MCPPlugin.user_id == user_id, + MCPPlugin.enabled == True, + MCPPlugin.plugin_type.in_(["http", "streamable_http", "sse"]) + ).limit(1) + + result = await db_session.execute(query) + return result.scalar() is not None + + except Exception as e: + logger.warning(f"检查用户MCP插件失败: {e}") + return False + + async def get_user_tools( + self, + user_id: str, + db_session: AsyncSession, + use_cache: bool = True, + force_refresh: bool = False + ) -> Optional[List[Dict[str, Any]]]: + """ + 获取用户的MCP工具列表(OpenAI格式) + + Args: + user_id: 用户ID + db_session: 数据库会话 + use_cache: 是否使用缓存 + force_refresh: 是否强制刷新 + + Returns: + - None: 用户未配置或未启用任何MCP插件 + - []: 有配置但没有可用工具 + - List[Dict]: OpenAI Function Calling格式的工具列表 + """ + now = datetime.now() + + # 检查缓存 + if use_cache and not force_refresh and user_id in self._cache: + cache_entry = self._cache[user_id] + if now < cache_entry.expire_time: + cache_entry.hit_count += 1 + logger.debug(f"🎯 用户工具缓存命中: {user_id} (命中次数: {cache_entry.hit_count})") + return cache_entry.tools + else: + del self._cache[user_id] + logger.debug(f"⏰ 用户工具缓存过期: {user_id}") + + # 从数据库加载 + try: + tools = await self._load_user_tools(user_id, db_session) + + # 更新缓存 + self._cache[user_id] = UserToolsCache( + tools=tools, + expire_time=now + self._cache_ttl + ) + + if tools: + logger.info(f"🔧 用户 {user_id} 加载了 {len(tools)} 个MCP工具") + else: + logger.debug(f"📭 用户 {user_id} 没有可用的MCP工具") + + return tools + + except Exception as e: + logger.error(f"❌ 加载用户MCP工具失败: {e}") + return None + + async def _load_user_tools( + self, + user_id: str, + db_session: AsyncSession + ) -> Optional[List[Dict[str, Any]]]: + """ + 从数据库加载用户启用的MCP插件并获取工具 + """ + # 查询启用的插件 + query = select(MCPPlugin).where( + MCPPlugin.user_id == user_id, + MCPPlugin.enabled == True, + MCPPlugin.plugin_type.in_(["http", "streamable_http", "sse"]) + ).order_by(MCPPlugin.sort_order) + + result = await db_session.execute(query) + plugins = result.scalars().all() + + if not plugins: + return None + + all_tools = [] + + for plugin in plugins: + try: + # 确定插件类型 + plugin_type = plugin.plugin_type + if plugin_type == "http": + plugin_type = "streamable_http" # 默认使用streamable_http + + # 确保插件已注册到MCP客户端 + await mcp_client.ensure_registered( + user_id=user_id, + plugin_name=plugin.plugin_name, + url=plugin.server_url, + plugin_type=plugin_type, + headers=plugin.headers + ) + + # 获取工具列表 + plugin_tools = await mcp_client.get_tools(user_id, plugin.plugin_name) + + # 转换为OpenAI格式 + formatted = mcp_client.format_tools_for_openai(plugin_tools, plugin.plugin_name) + all_tools.extend(formatted) + + logger.debug(f"✅ 从插件 {plugin.plugin_name} 加载了 {len(formatted)} 个工具") + + except Exception as e: + logger.warning(f"⚠️ 加载插件 {plugin.plugin_name} 工具失败: {e}") + continue + + return all_tools if all_tools else None + + def invalidate_cache(self, user_id: Optional[str] = None): + """ + 使缓存失效 + + Args: + user_id: 用户ID,为None时清空所有缓存 + """ + if user_id: + if user_id in self._cache: + del self._cache[user_id] + logger.debug(f"🧹 清理用户工具缓存: {user_id}") + else: + count = len(self._cache) + self._cache.clear() + logger.info(f"🧹 清理所有用户工具缓存 ({count}个)") + + def get_cache_stats(self) -> Dict[str, Any]: + """获取缓存统计""" + now = datetime.now() + return { + "total_entries": len(self._cache), + "total_hits": sum(e.hit_count for e in self._cache.values()), + "cache_ttl_minutes": self._cache_ttl.total_seconds() / 60, + "entries": [ + { + "user_id": uid, + "tools_count": len(e.tools) if e.tools else 0, + "hit_count": e.hit_count, + "expired": now >= e.expire_time, + "expire_time": e.expire_time.isoformat() + } + for uid, e in self._cache.items() + ] + } + + +# 全局单例 +mcp_tools_loader = MCPToolsLoader() \ No newline at end of file diff --git a/backend/app/utils/sse_response.py b/backend/app/utils/sse_response.py index c97a08c..e624ca2 100644 --- a/backend/app/utils/sse_response.py +++ b/backend/app/utils/sse_response.py @@ -1,13 +1,231 @@ """Server-Sent Events (SSE) 响应工具类""" import json import asyncio -from typing import AsyncGenerator, Dict, Any, Optional +from enum import Enum +from typing import AsyncGenerator, Dict, Any, Optional, Callable +from dataclasses import dataclass from fastapi.responses import StreamingResponse from app.logger import get_logger logger = get_logger(__name__) +class ProgressStage(Enum): + """标准化进度阶段枚举""" + # 初始化阶段 (0-5%) + INIT = "init" + # 加载数据阶段 (5-15%) + LOADING = "loading" + # 准备提示词阶段 (15-20%) + PREPARING = "preparing" + # AI生成阶段 (20-85%) + GENERATING = "generating" + # 解析数据阶段 (85-92%) + PARSING = "parsing" + # 保存数据阶段 (92-98%) + SAVING = "saving" + # 完成阶段 (100%) + COMPLETE = "complete" + + +@dataclass +class StageConfig: + """阶段配置""" + start: int # 起始进度 + end: int # 结束进度 + default_message: str # 默认消息 + + +# 标准进度阶段配置 +STAGE_CONFIGS: Dict[ProgressStage, StageConfig] = { + ProgressStage.INIT: StageConfig(0, 5, "开始处理..."), + ProgressStage.LOADING: StageConfig(5, 15, "加载数据中..."), + ProgressStage.PREPARING: StageConfig(15, 20, "准备AI提示词..."), + ProgressStage.GENERATING: StageConfig(20, 85, "AI生成中..."), + ProgressStage.PARSING: StageConfig(85, 92, "解析数据..."), + ProgressStage.SAVING: StageConfig(92, 98, "保存到数据库..."), + ProgressStage.COMPLETE: StageConfig(100, 100, "完成!"), +} + + +class WizardProgressTracker: + """ + 向导进度追踪器 - 标准化管理SSE进度推送 + + 使用示例: + tracker = WizardProgressTracker("世界观") + yield await tracker.start() + yield await tracker.loading("加载项目信息") + yield await tracker.preparing() + async for chunk in ai_stream: + yield await tracker.generating_chunk(chunk, len(accumulated)) + yield await tracker.parsing() + yield await tracker.saving("保存世界观数据") + yield await tracker.complete() + """ + + def __init__(self, task_name: str = "任务"): + """ + 初始化进度追踪器 + + Args: + task_name: 任务名称,用于消息前缀 + """ + self.task_name = task_name + self.current_stage = ProgressStage.INIT + self.current_progress = 0 + self._last_generating_progress = 20 # 生成阶段的最后进度值 + + def _get_stage_progress( + self, + stage: ProgressStage, + sub_progress: float = 0.0 + ) -> int: + """ + 计算阶段内的进度值 + + Args: + stage: 当前阶段 + sub_progress: 阶段内子进度 (0.0-1.0) + + Returns: + 总进度值 (0-100) + """ + config = STAGE_CONFIGS[stage] + if sub_progress <= 0: + return config.start + if sub_progress >= 1: + return config.end + return config.start + int((config.end - config.start) * sub_progress) + + async def start(self, message: str = None) -> str: + """开始阶段""" + self.current_stage = ProgressStage.INIT + self.current_progress = 0 + msg = message or f"开始生成{self.task_name}..." + return await SSEResponse.send_progress(msg, 0, "processing") + + async def loading(self, message: str = None, sub_progress: float = 0.5) -> str: + """加载数据阶段""" + self.current_stage = ProgressStage.LOADING + progress = self._get_stage_progress(ProgressStage.LOADING, sub_progress) + self.current_progress = progress + msg = message or STAGE_CONFIGS[ProgressStage.LOADING].default_message + return await SSEResponse.send_progress(msg, progress, "processing") + + async def preparing(self, message: str = None) -> str: + """准备提示词阶段""" + self.current_stage = ProgressStage.PREPARING + progress = self._get_stage_progress(ProgressStage.PREPARING, 0.5) + self.current_progress = progress + msg = message or STAGE_CONFIGS[ProgressStage.PREPARING].default_message + return await SSEResponse.send_progress(msg, progress, "processing") + + async def generating( + self, + current_chars: int = 0, + estimated_total: int = 5000, + message: str = None, + retry_count: int = 0, + max_retries: int = 3 + ) -> str: + """ + AI生成阶段进度更新 + + Args: + current_chars: 当前已生成字符数 + estimated_total: 预估总字符数 + message: 自定义消息 + retry_count: 当前重试次数 + max_retries: 最大重试次数 + """ + self.current_stage = ProgressStage.GENERATING + + # 计算生成进度 (0.0-1.0) + sub_progress = min(current_chars / max(estimated_total, 1), 1.0) + progress = self._get_stage_progress(ProgressStage.GENERATING, sub_progress) + + # 确保进度单调递增 + if progress < self._last_generating_progress: + progress = self._last_generating_progress + else: + self._last_generating_progress = progress + + self.current_progress = progress + + # 构建消息 + retry_suffix = f" (重试 {retry_count}/{max_retries})" if retry_count > 0 else "" + if message: + msg = f"{message}{retry_suffix}" + else: + msg = f"生成{self.task_name}中... ({current_chars}字符){retry_suffix}" + + return await SSEResponse.send_progress(msg, progress, "processing") + + async def generating_chunk(self, chunk: str) -> str: + """发送生成的内容块""" + return await SSEResponse.send_chunk(chunk) + + async def parsing(self, message: str = None, sub_progress: float = 0.5) -> str: + """解析数据阶段""" + self.current_stage = ProgressStage.PARSING + progress = self._get_stage_progress(ProgressStage.PARSING, sub_progress) + self.current_progress = progress + msg = message or f"解析{self.task_name}数据..." + return await SSEResponse.send_progress(msg, progress, "processing") + + async def saving(self, message: str = None, sub_progress: float = 0.5) -> str: + """保存数据阶段""" + self.current_stage = ProgressStage.SAVING + progress = self._get_stage_progress(ProgressStage.SAVING, sub_progress) + self.current_progress = progress + msg = message or f"保存{self.task_name}到数据库..." + return await SSEResponse.send_progress(msg, progress, "processing") + + async def complete(self, message: str = None) -> str: + """完成阶段""" + self.current_stage = ProgressStage.COMPLETE + self.current_progress = 100 + msg = message or f"{self.task_name}生成完成!" + return await SSEResponse.send_progress(msg, 100, "success") + + async def warning(self, message: str) -> str: + """发送警告消息(保持当前进度)""" + return await SSEResponse.send_progress( + f"⚠️ {message}", + self.current_progress, + "warning" + ) + + async def retry(self, retry_count: int, max_retries: int, reason: str = "准备重试") -> str: + """发送重试消息""" + return await SSEResponse.send_progress( + f"⚠️ {reason}... ({retry_count}/{max_retries})", + self.current_progress, + "warning" + ) + + async def error(self, error_message: str, code: int = 500) -> str: + """发送错误消息""" + return await SSEResponse.send_error(error_message, code) + + async def result(self, data: Dict[str, Any]) -> str: + """发送结果数据""" + return await SSEResponse.send_result(data) + + async def done(self) -> str: + """发送完成信号""" + return await SSEResponse.send_done() + + async def heartbeat(self) -> str: + """发送心跳""" + return await SSEResponse.send_heartbeat() + + def reset_generating_progress(self): + """重置生成阶段进度(用于重试时)""" + self._last_generating_progress = 20 + + class SSEResponse: """SSE响应构建器""" diff --git a/backend/requirements.txt b/backend/requirements.txt index 8025cf9..d66e59f 100644 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -19,10 +19,11 @@ anthropic==0.72.0 # 工具库 httpx==0.28.1 -python-dotenv==1.0.0 +python-dotenv==1.1.0 psutil==6.1.1 # MCP官方库(Model Context Protocol Python SDK) -mcp==1.21.0 +mcp==1.22.0 +fastmcp==2.13.3 # NumPy版本锁定(兼容性要求) diff --git a/frontend/src/components/AIProjectGenerator.tsx b/frontend/src/components/AIProjectGenerator.tsx index 68f7d8f..c1b5538 100644 --- a/frontend/src/components/AIProjectGenerator.tsx +++ b/frontend/src/components/AIProjectGenerator.tsx @@ -37,6 +37,14 @@ interface GenerationSteps { outline: GenerationStep; } +interface WorldBuildingResult { + project_id: string; + time_period: string; + location: string; + atmosphere: string; + rules: string; +} + export const AIProjectGenerator: React.FC = ({ config, storagePrefix, @@ -64,7 +72,7 @@ export const AIProjectGenerator: React.FC = ({ // 保存生成数据,用于重试 const [generationData, setGenerationData] = useState(null); // 保存世界观生成结果,用于后续步骤 - const [worldBuildingResult, setWorldBuildingResult] = useState(null); + const [worldBuildingResult, setWorldBuildingResult] = useState(null); // LocalStorage 键名 const storageKeys = { @@ -102,6 +110,7 @@ export const AIProjectGenerator: React.FC = ({ handleAutoGenerate(config); } } + // eslint-disable-next-line react-hooks/exhaustive-deps }, [config, resumeProjectId]); // 恢复未完成项目的生成 @@ -125,33 +134,40 @@ export const AIProjectGenerator: React.FC = ({ const wizardStep = project.wizard_step || 0; // 根据wizard_step判断从哪里继续 + // wizard_step: 0=未开始, 1=世界观已完成, 2=职业体系已完成, 3=角色已完成, 4=大纲已完成 + // 获取世界观数据(用于后续步骤) + const worldResult = { + project_id: projectIdParam, + time_period: project.world_time_period || '', + location: project.world_location || '', + atmosphere: project.world_atmosphere || '', + rules: project.world_rules || '' + }; + if (wizardStep === 0) { // 从世界观开始 message.info('从世界观步骤开始生成...'); setGenerationSteps({ worldBuilding: 'processing', careers: 'pending', characters: 'pending', outline: 'pending' }); await resumeFromWorldBuilding(data); } else if (wizardStep === 1) { - // 世界观已完成,从角色开始 - message.info('世界观已完成,从角色步骤继续...'); - setGenerationSteps({ worldBuilding: 'completed', careers: 'completed', characters: 'processing', outline: 'pending' }); - - // 获取世界观数据 - const worldResult = { - project_id: projectIdParam, - time_period: project.world_time_period || '', - location: project.world_location || '', - atmosphere: project.world_atmosphere || '', - rules: project.world_rules || '' - }; + // 世界观已完成,从职业体系开始 + message.info('世界观已完成,从职业体系步骤继续...'); + setGenerationSteps({ worldBuilding: 'completed', careers: 'processing', characters: 'pending', outline: 'pending' }); setWorldBuildingResult(worldResult); - setProgress(33); - - await resumeFromCharacters(data, worldResult); + setProgress(20); + await resumeFromCareers(data, worldResult); } else if (wizardStep === 2) { - // 世界观和角色已完成,从大纲开始 - message.info('世界观和角色已完成,从大纲步骤继续...'); + // 职业体系已完成,从角色开始 + message.info('职业体系已完成,从角色步骤继续...'); + setGenerationSteps({ worldBuilding: 'completed', careers: 'completed', characters: 'processing', outline: 'pending' }); + setWorldBuildingResult(worldResult); + setProgress(40); + await resumeFromCharacters(data, worldResult); + } else if (wizardStep === 3) { + // 角色已完成,从大纲开始 + message.info('角色已完成,从大纲步骤继续...'); setGenerationSteps({ worldBuilding: 'completed', careers: 'completed', characters: 'completed', outline: 'processing' }); - setProgress(66); + setProgress(70); await resumeFromOutline(data, projectIdParam); } else { // 已全部完成 @@ -211,11 +227,47 @@ export const AIProjectGenerator: React.FC = ({ } ); + await resumeFromCareers(data, worldResult); + }; + + // 恢复:从职业体系步骤继续 + const resumeFromCareers = async (data: GenerationConfig, worldResult: WorldBuildingResult) => { + const pid = projectId || worldResult.project_id; + + setGenerationSteps(prev => ({ ...prev, careers: 'processing' })); + setProgressMessage('正在生成职业体系...'); + + await wizardStreamApi.generateCareerSystemStream( + { + project_id: pid, + }, + { + onProgress: (msg, prog) => { + setProgress(prog); + setProgressMessage(msg); + }, + onResult: (result) => { + console.log(`成功生成职业体系:主职业${result.main_careers_count}个,副职业${result.sub_careers_count}个`); + setGenerationSteps(prev => ({ ...prev, careers: 'completed' })); + }, + onError: (error) => { + console.error('职业体系生成失败:', error); + setErrorDetails(`职业体系生成失败: ${error}`); + setGenerationSteps(prev => ({ ...prev, careers: 'error' })); + setLoading(false); + throw new Error(error); + }, + onComplete: () => { + console.log('职业体系生成完成'); + } + } + ); + await resumeFromCharacters(data, worldResult); }; // 恢复:从角色步骤继续 - const resumeFromCharacters = async (data: GenerationConfig, worldResult: any) => { + const resumeFromCharacters = async (data: GenerationConfig, worldResult: WorldBuildingResult) => { const genreString = Array.isArray(data.genre) ? data.genre.join('、') : data.genre; const pid = projectId || worldResult.project_id; @@ -342,26 +394,11 @@ export const AIProjectGenerator: React.FC = ({ // 直接使用后端返回的进度值 setProgress(prog); setProgressMessage(msg); - - // 检测职业体系生成阶段 - if (msg.includes('职业体系')) { - if (msg.includes('开始') || msg.includes('生成')) { - setGenerationSteps(prev => ({ - ...prev, - worldBuilding: 'completed', - careers: 'processing' - })); - } - if (msg.includes('完成') || msg.includes('✅')) { - setGenerationSteps(prev => ({ ...prev, careers: 'completed' })); - } - } }, onResult: (result) => { setProjectId(result.project_id); setWorldBuildingResult(result); setGenerationSteps(prev => ({ ...prev, worldBuilding: 'completed' })); - // 职业体系状态已在onProgress中更新 }, onError: (error) => { console.error('世界观生成失败:', error); @@ -385,7 +422,37 @@ export const AIProjectGenerator: React.FC = ({ setWorldBuildingResult(worldResult); saveProgress(createdProjectId, data, 'generating'); - // 步骤2: 生成角色 + // 步骤2: 生成职业体系 + setGenerationSteps(prev => ({ ...prev, careers: 'processing' })); + setProgressMessage('正在生成职业体系...'); + + await wizardStreamApi.generateCareerSystemStream( + { + project_id: createdProjectId, + }, + { + onProgress: (msg, prog) => { + setProgress(prog); + setProgressMessage(msg); + }, + onResult: (result) => { + console.log(`成功生成职业体系:主职业${result.main_careers_count}个,副职业${result.sub_careers_count}个`); + setGenerationSteps(prev => ({ ...prev, careers: 'completed' })); + }, + onError: (error) => { + console.error('职业体系生成失败:', error); + setErrorDetails(`职业体系生成失败: ${error}`); + setGenerationSteps(prev => ({ ...prev, careers: 'error' })); + setLoading(false); + throw new Error(error); + }, + onComplete: () => { + console.log('职业体系生成完成'); + } + } + ); + + // 步骤3: 生成角色 setGenerationSteps(prev => ({ ...prev, characters: 'processing' })); setProgressMessage('正在生成角色...'); @@ -497,6 +564,9 @@ export const AIProjectGenerator: React.FC = ({ if (generationSteps.worldBuilding === 'error') { message.info('从世界观步骤开始重新生成...'); await retryFromWorldBuilding(); + } else if (generationSteps.careers === 'error') { + message.info('从职业体系步骤继续生成...'); + await retryFromCareers(); } else if (generationSteps.characters === 'error') { message.info('从角色步骤继续生成...'); await retryFromCharacters(); @@ -504,9 +574,10 @@ export const AIProjectGenerator: React.FC = ({ message.info('从大纲步骤继续生成...'); await retryFromOutline(); } - } catch (error: any) { + } catch (error) { console.error('智能重试失败:', error); - message.error('重试失败:' + (error.message || '未知错误')); + const errorMessage = error instanceof Error ? error.message : '未知错误'; + message.error('重试失败:' + errorMessage); setLoading(false); } }; @@ -537,20 +608,6 @@ export const AIProjectGenerator: React.FC = ({ // 直接使用后端返回的进度值 setProgress(prog); setProgressMessage(msg); - - // 检测职业体系生成阶段 - if (msg.includes('职业体系')) { - if (msg.includes('开始') || msg.includes('生成')) { - setGenerationSteps(prev => ({ - ...prev, - worldBuilding: 'completed', - careers: 'processing' - })); - } - if (msg.includes('完成') || msg.includes('✅')) { - setGenerationSteps(prev => ({ ...prev, careers: 'completed' })); - } - } }, onResult: (result) => { setProjectId(result.project_id); @@ -574,17 +631,72 @@ export const AIProjectGenerator: React.FC = ({ throw new Error('项目创建失败:未获取到项目ID'); } - await continueFromCharacters(worldResult); + await continueFromCareers(worldResult); + }; + + // 从职业体系步骤继续 + const retryFromCareers = async () => { + if (!worldBuildingResult) { + message.warning('缺少必要数据,无法从职业体系步骤继续'); + setLoading(false); + return; + } + + const pid = worldBuildingResult.project_id || projectId; + if (!pid) { + message.warning('缺少项目ID,无法从职业体系步骤继续'); + setLoading(false); + return; + } + + setGenerationSteps(prev => ({ ...prev, careers: 'processing' })); + setProgressMessage('重新生成职业体系...'); + + await wizardStreamApi.generateCareerSystemStream( + { + project_id: pid, + }, + { + onProgress: (msg, prog) => { + setProgress(prog); + setProgressMessage(msg); + }, + onResult: (result) => { + console.log(`成功生成职业体系:主职业${result.main_careers_count}个,副职业${result.sub_careers_count}个`); + setGenerationSteps(prev => ({ ...prev, careers: 'completed' })); + }, + onError: (error) => { + console.error('职业体系生成失败:', error); + setErrorDetails(`职业体系生成失败: ${error}`); + setGenerationSteps(prev => ({ ...prev, careers: 'error' })); + setLoading(false); + throw new Error(error); + }, + onComplete: () => { + console.log('职业体系重新生成完成'); + } + } + ); + + await continueFromCharacters(worldBuildingResult); }; // 从角色步骤继续 const retryFromCharacters = async () => { - if (!generationData || !projectId || !worldBuildingResult) { + if (!generationData || !worldBuildingResult) { message.warning('缺少必要数据,无法从角色步骤继续'); setLoading(false); return; } + // 优先使用 worldBuildingResult 中的 project_id,因为重试可能创建了新项目 + const pid = worldBuildingResult.project_id || projectId; + if (!pid) { + message.warning('缺少项目ID,无法从角色步骤继续'); + setLoading(false); + return; + } + setGenerationSteps(prev => ({ ...prev, characters: 'processing' })); setProgressMessage('重新生成角色...'); @@ -592,7 +704,7 @@ export const AIProjectGenerator: React.FC = ({ await wizardStreamApi.generateCharactersStream( { - project_id: projectId, + project_id: pid, count: generationData.character_count, world_context: { time_period: worldBuildingResult.time_period || '', @@ -626,23 +738,31 @@ export const AIProjectGenerator: React.FC = ({ } ); - await continueFromOutline(); + await continueFromOutline(pid); }; // 从大纲步骤继续 const retryFromOutline = async () => { - if (!generationData || !projectId) { + if (!generationData) { message.warning('缺少必要数据,无法从大纲步骤继续'); setLoading(false); return; } + // 优先使用 worldBuildingResult 中的 project_id,fallback 到状态中的 projectId + const pid = (worldBuildingResult?.project_id) || projectId; + if (!pid) { + message.warning('缺少项目ID,无法从大纲步骤继续'); + setLoading(false); + return; + } + setGenerationSteps(prev => ({ ...prev, outline: 'processing' })); setProgressMessage('重新生成大纲...'); await wizardStreamApi.generateCompleteOutlineStream( { - project_id: projectId, + project_id: pid, chapter_count: generationData.chapter_count, narrative_perspective: generationData.narrative_perspective, target_words: generationData.target_words, @@ -676,20 +796,59 @@ export const AIProjectGenerator: React.FC = ({ setLoading(false); // 调用完成回调 - if (projectId) { - onComplete(projectId); + if (pid) { + onComplete(pid); // 延迟1秒后自动跳转到项目详情页 setTimeout(() => { - navigate(`/project/${projectId}`); + navigate(`/project/${pid}`); }, 1000); } }; - // 从角色步骤开始的完整流程 - const continueFromCharacters = async (worldResult: any) => { + // 从职业体系步骤开始的完整流程 + const continueFromCareers = async (worldResult: WorldBuildingResult) => { if (!generationData || !worldResult?.project_id) return; + const pid = worldResult.project_id; + + setGenerationSteps(prev => ({ ...prev, careers: 'processing' })); + setProgressMessage('正在生成职业体系...'); + + await wizardStreamApi.generateCareerSystemStream( + { + project_id: pid, + }, + { + onProgress: (msg, prog) => { + setProgress(prog); + setProgressMessage(msg); + }, + onResult: (result) => { + console.log(`成功生成职业体系:主职业${result.main_careers_count}个,副职业${result.sub_careers_count}个`); + setGenerationSteps(prev => ({ ...prev, careers: 'completed' })); + }, + onError: (error) => { + console.error('职业体系生成失败:', error); + setErrorDetails(`职业体系生成失败: ${error}`); + setGenerationSteps(prev => ({ ...prev, careers: 'error' })); + setLoading(false); + throw new Error(error); + }, + onComplete: () => { + console.log('职业体系生成完成'); + } + } + ); + + await continueFromCharacters(worldResult); + }; + + // 从角色步骤开始的完整流程 + const continueFromCharacters = async (worldResult: WorldBuildingResult) => { + if (!generationData || !worldResult?.project_id) return; + + const pid = worldResult.project_id; const genreString = Array.isArray(generationData.genre) ? generationData.genre.join('、') : generationData.genre; setGenerationSteps(prev => ({ ...prev, characters: 'processing' })); @@ -697,7 +856,7 @@ export const AIProjectGenerator: React.FC = ({ await wizardStreamApi.generateCharactersStream( { - project_id: worldResult.project_id, + project_id: pid, count: generationData.character_count, world_context: { time_period: worldResult.time_period || '', @@ -731,19 +890,19 @@ export const AIProjectGenerator: React.FC = ({ } ); - await continueFromOutline(); + await continueFromOutline(pid); }; // 从大纲步骤开始的完整流程 - const continueFromOutline = async () => { - if (!generationData || !projectId) return; + const continueFromOutline = async (pid: string) => { + if (!generationData || !pid) return; setGenerationSteps(prev => ({ ...prev, outline: 'processing' })); setProgressMessage('正在生成大纲...'); await wizardStreamApi.generateCompleteOutlineStream( { - project_id: projectId, + project_id: pid, chapter_count: generationData.chapter_count, narrative_perspective: generationData.narrative_perspective, target_words: generationData.target_words, @@ -777,12 +936,12 @@ export const AIProjectGenerator: React.FC = ({ setLoading(false); // 调用完成回调 - if (projectId) { - onComplete(projectId); + if (pid) { + onComplete(pid); // 延迟1秒后自动跳转到项目详情页 setTimeout(() => { - navigate(`/project/${projectId}`); + navigate(`/project/${pid}`); }, 1000); } }; diff --git a/frontend/src/pages/MCPPlugins.tsx b/frontend/src/pages/MCPPlugins.tsx index b1348c8..4d17914 100644 --- a/frontend/src/pages/MCPPlugins.tsx +++ b/frontend/src/pages/MCPPlugins.tsx @@ -28,8 +28,11 @@ import { InfoCircleOutlined, ToolOutlined, ArrowLeftOutlined, + ApiOutlined, + QuestionCircleOutlined, + WarningOutlined, } from '@ant-design/icons'; -import { mcpPluginApi } from '../services/api'; +import { mcpPluginApi, settingsApi } from '../services/api'; import type { MCPPlugin, MCPTool } from '../types'; const { Paragraph, Text, Title } = Typography; @@ -46,24 +49,112 @@ export default function MCPPluginsPage() { const [editingPlugin, setEditingPlugin] = useState(null); const [testingPluginId, setTestingPluginId] = useState(null); const [viewingTools, setViewingTools] = useState<{ pluginId: string; tools: MCPTool[] } | null>(null); + const [checkingFunctionCalling, setCheckingFunctionCalling] = useState(false); + const [modelSupportStatus, setModelSupportStatus] = useState<'unknown' | 'supported' | 'unsupported'>('unknown'); useEffect(() => { - loadPlugins(); - }, []); + const initPage = async () => { + setLoading(true); + try { + // 1. 并行获取插件列表和当前设置 + const [pluginsData, settings] = await Promise.all([ + mcpPluginApi.getPlugins(), + settingsApi.getSettings() + ]); + + setPlugins(pluginsData); + + // 2. 检查配置一致性 + const verifiedConfigStr = localStorage.getItem('mcp_verified_config'); + if (verifiedConfigStr) { + try { + const verifiedConfig = JSON.parse(verifiedConfigStr); + const currentConfig = { + provider: settings.api_provider, + baseUrl: settings.api_base_url, + model: settings.llm_model + }; + + // 比较关键配置是否发生变更 + const isConfigChanged = + verifiedConfig.provider !== currentConfig.provider || + verifiedConfig.baseUrl !== currentConfig.baseUrl || + verifiedConfig.model !== currentConfig.model; + + if (isConfigChanged) { + // 配置已变更 + setModelSupportStatus('unknown'); + + // 检查是否有正在运行的插件 + const activePlugins = pluginsData.filter(p => p.enabled); + if (activePlugins.length > 0) { + // 自动禁用所有插件 + message.loading({ content: '检测到模型配置变更,正在为了安全自动禁用插件...', key: 'auto_disable' }); + + await Promise.all(activePlugins.map(p => mcpPluginApi.togglePlugin(p.id, false))); + + // 重新加载插件列表状态 + const updatedPlugins = await mcpPluginApi.getPlugins(); + setPlugins(updatedPlugins); + + message.success({ content: '已自动禁用所有插件,请重新检测模型能力', key: 'auto_disable' }); + + modal.warning({ + title: '配置变更提醒', + centered: true, + content: '检测到您更换了 AI 模型或接口地址。为了防止错误调用,系统已自动暂停所有 MCP 插件。请重新进行"模型能力检查",确认新模型支持 Function Calling 后再启用插件。', + okText: '知道了', + }); + } else { + // 没有运行中的插件,仅提示 + message.info('检测到模型配置已变更,请重新检测模型能力'); + } + + // 清除旧的验证状态 + localStorage.removeItem('mcp_verified_config'); + } else { + // 配置未变更,恢复验证状态(根据缓存的状态恢复) + const cachedStatus = verifiedConfig.status || 'supported'; + setModelSupportStatus(cachedStatus as 'unknown' | 'supported' | 'unsupported'); + } + } catch (e) { + console.error('Failed to parse verified config:', e); + localStorage.removeItem('mcp_verified_config'); + } + } + } catch (error) { + console.error('Init page failed:', error); + message.error('页面初始化失败'); + } finally { + setLoading(false); + } + }; + initPage(); + }, [modal]); const loadPlugins = async () => { - setLoading(true); try { const data = await mcpPluginApi.getPlugins(); setPlugins(data); } catch (error) { + console.error('Load plugins failed:', error); message.error('加载插件列表失败'); - } finally { - setLoading(false); } }; const handleCreate = () => { + if (modelSupportStatus !== 'supported') { + modal.confirm({ + title: '模型能力检查', + centered: true, + icon: , + content: '为了确保 MCP 插件正常工作,您当前使用的 AI 模型必须支持 Function Calling(工具调用)能力。请先进行模型支持检测。', + okText: '去检测', + cancelText: '取消', + onOk: handleCheckFunctionCalling, + }); + return; + } setEditingPlugin(null); form.resetFields(); form.setFieldsValue({ @@ -86,7 +177,7 @@ export default function MCPPluginsPage() { setEditingPlugin(plugin); // 重构为标准MCP配置格式 - const mcpConfig: any = { + const mcpConfig: Record>> = { mcpServers: { [plugin.plugin_name]: { type: plugin.plugin_type || 'http' @@ -94,7 +185,7 @@ export default function MCPPluginsPage() { } }; - if (plugin.plugin_type === 'http') { + if (plugin.plugin_type === 'http' || plugin.plugin_type === 'streamable_http' || plugin.plugin_type === 'sse') { mcpConfig.mcpServers[plugin.plugin_name].url = plugin.server_url; mcpConfig.mcpServers[plugin.plugin_name].headers = plugin.headers || {}; } else { @@ -125,6 +216,7 @@ export default function MCPPluginsPage() { message.success('插件已删除'); loadPlugins(); } catch (error) { + console.error('Delete plugin failed:', error); message.error('删除插件失败'); } }, @@ -137,6 +229,7 @@ export default function MCPPluginsPage() { message.success(enabled ? '插件已启用' : '插件已禁用'); loadPlugins(); } catch (error) { + console.error('Toggle plugin failed:', error); message.error('切换插件状态失败'); } }; @@ -150,45 +243,62 @@ export default function MCPPluginsPage() { await loadPlugins(); if (result.success) { + const suggestions = result.suggestions || []; + const aiChoice = suggestions.find((s: string) => s.startsWith('🤖'))?.replace('🤖 AI选择: ', '') || ''; + const paramsStr = suggestions.find((s: string) => s.startsWith('📝'))?.replace('📝 参数: ', '') || ''; + const callTime = suggestions.find((s: string) => s.startsWith('⏱️'))?.replace('⏱️ 耗时: ', '') || ''; + const resultStr = suggestions.find((s: string) => s.startsWith('📊'))?.replace('📊 结果:\n', '') || ''; + modal.success({ - title: '测试成功', + title: '🎉 测试成功', centered: true, - width: isMobile ? '90%' : 600, + width: isMobile ? '95%' : 700, content: (
-
- +
+ ✓ {result.message}
- {(result.tools_count !== undefined || result.response_time_ms !== undefined) && ( -
- {result.tools_count !== undefined && ( -
- 可用工具数: - {result.tools_count} -
- )} - {result.response_time_ms !== undefined && ( -
- 响应时间: - {result.response_time_ms}ms -
- )} +
+
+ 可用工具数 +
{result.tools_count || 0}
+
+
+ 总响应时间 +
{result.response_time_ms?.toFixed(0) || 0}ms
+
+
+ + {aiChoice && ( +
+ 🤖 AI选择的工具 + {aiChoice} + {callTime && {callTime}}
)} - + {paramsStr && ( +
+ 📝 调用参数 +
+                    {(() => { try { return JSON.stringify(JSON.parse(paramsStr), null, 2); } catch { return paramsStr; } })()}
+                  
+
+ )} + + {resultStr && ( +
+ 📊 返回结果预览 +
+                    {resultStr}
+                  
+
+ )} + +
), }); @@ -248,7 +358,7 @@ export default function MCPPluginsPage() { ), }); } - } catch (error: any) { + } catch { message.error('测试插件失败'); } finally { setTestingPluginId(null); @@ -260,17 +370,181 @@ export default function MCPPluginsPage() { const result = await mcpPluginApi.getPluginTools(pluginId); setViewingTools({ pluginId, tools: result.tools }); } catch (error) { + console.error('Get tools failed:', error); message.error('获取工具列表失败'); } }; - const handleSubmit = async (values: any) => { + const handleCheckFunctionCalling = async () => { + // 从设置中获取当前配置 + setCheckingFunctionCalling(true); + try { + const settings = await settingsApi.getSettings(); + + if (!settings.api_key || !settings.llm_model) { + message.warning('请先在设置页面配置 API Key 和模型'); + return; + } + + const result = await settingsApi.checkFunctionCalling({ + api_key: settings.api_key, + api_base_url: settings.api_base_url || '', + provider: settings.api_provider || 'openai', + llm_model: settings.llm_model, + }); + + // 无论成功失败,都缓存当前测试的配置和状态 + const configToCache = { + provider: settings.api_provider, + baseUrl: settings.api_base_url, + model: settings.llm_model, + status: result.success && result.supported ? 'supported' : 'unsupported', + testedAt: new Date().toISOString() + }; + localStorage.setItem('mcp_verified_config', JSON.stringify(configToCache)); + + if (result.success && result.supported) { + setModelSupportStatus('supported'); + + modal.success({ + title: '✅ Function Calling 支持检测', + centered: true, + width: isMobile ? '95%' : 700, + content: ( +
+
+ + ✓ {result.message} + +
+ +
+
+ API 提供商 +
{result.provider}
+
+
+ 响应时间 +
{result.response_time_ms?.toFixed(0) || 0}ms
+
+
+ +
+ 🔧 模型信息 + {result.model} + {result.details?.finish_reason && ( + finish_reason: {result.details.finish_reason} + )} +
+ + {result.details && ( +
+ 📊 检测详情 +
+
✓ 工具调用数量: {result.details.tool_call_count || 0}
+
✓ 测试工具: {result.details.test_tool || 'N/A'}
+
✓ 响应类型: {result.details.response_type || 'N/A'}
+
+
+ )} + + {result.tool_calls && result.tool_calls.length > 0 && ( +
+ 🔨 工具调用示例 +
+                    {JSON.stringify(result.tool_calls[0], null, 2)}
+                  
+
+ )} + + {result.suggestions && result.suggestions.length > 0 && ( +
+ 💡 建议 +
    + {result.suggestions.map((s: string, i: number) => ( +
  • {s}
  • + ))} +
+
+ )} +
+ ), + }); + } else { + setModelSupportStatus('unsupported'); + modal.warning({ + title: '❌ Function Calling 支持检测', + centered: true, + width: isMobile ? '95%' : 700, + content: ( +
+
+ +
+ + {result.error && ( +
+ 错误信息: + + {result.error} + +
+ )} + + {result.response_preview && ( +
+ 📝 模型返回内容(前200字符) +
+                    {result.response_preview}
+                  
+
+ )} + + {result.suggestions && result.suggestions.length > 0 && ( +
+ 💡 建议: +
    + {result.suggestions.map((s: string, i: number) => ( +
  • {s}
  • + ))} +
+
+ )} +
+ ), + }); + } + } catch (error) { + console.error('Check function calling failed:', error); + message.error('检测失败,请稍后重试'); + setModelSupportStatus('unsupported'); + } finally { + setCheckingFunctionCalling(false); + } + }; + + const handleSubmit = async (values: { config_json: string; enabled: boolean; category?: string }) => { setLoading(true); try { // 验证JSON格式 try { JSON.parse(values.config_json); - } catch (e) { + } catch { message.error('配置JSON格式错误,请检查'); setLoading(false); return; @@ -289,8 +563,9 @@ export default function MCPPluginsPage() { setModalVisible(false); form.resetFields(); loadPlugins(); - } catch (error: any) { - const errorMsg = error?.response?.data?.detail || '操作失败'; + } catch (error: unknown) { + const err = error as { response?: { data?: { detail?: string } } }; + const errorMsg = err?.response?.data?.detail || '操作失败'; message.error(errorMsg); } finally { setLoading(false); @@ -407,38 +682,104 @@ export default function MCPPluginsPage() { - {/* 使用提示 */} - - - 什么是 MCP 插件? - - } - description={ -
- - • MCP (Model Context Protocol) 是一个标准化的协议,允许 AI 调用外部工具获取数据。 - - - • 通过添加 MCP 插件,AI 可以访问搜索引擎、数据库、API 等外部服务,增强创作能力。 - +
+ +
+ +
+ {modelSupportStatus === 'supported' ? ( + + ) : modelSupportStatus === 'unsupported' ? ( + + ) : ( + + )} +
+
+ 模型能力检查 + + {modelSupportStatus === 'supported' + ? '当前模型支持 Function Calling,可正常使用 MCP 插件' + : modelSupportStatus === 'unsupported' + ? '当前模型不支持 Function Calling,无法使用 MCP 插件' + : '请先检测模型是否支持 Function Calling 能力'} + +
+
+
- } - type="info" - showIcon={false} - style={{ - marginTop: isMobile ? 16 : 24, - borderRadius: 12, - background: 'rgba(230, 247, 255, 0.6)', - border: '1px solid rgba(145, 213, 255, 0.6)', - backdropFilter: 'blur(5px)' - }} - /> +
+ + + + +
+ 什么是 MCP 插件? + + MCP (Model Context Protocol) 协议允许 AI 调用外部工具获取数据。通过添加插件,AI 可以访问搜索引擎、数据库、API 等服务,大幅增强创作能力。 + +
+
+
+
{/* 主内容区 */}
+ {/* 模型能力未验证时的警告提示 */} + {modelSupportStatus !== 'supported' && plugins.length > 0 && ( + : } + style={{ marginBottom: 16, borderRadius: 8 }} + action={ + + } + /> + )} {/* 插件列表 */} @@ -479,7 +820,7 @@ export default function MCPPluginsPage() { {plugin.display_name || plugin.plugin_name} {getStatusTag(plugin)} - + {plugin.plugin_type?.toUpperCase() || 'UNKNOWN'} {plugin.category && plugin.category !== 'general' && ( @@ -500,7 +841,7 @@ export default function MCPPluginsPage() { )} {/* 只显示有值的URL或命令,脱敏处理敏感信息 */} - {plugin.plugin_type === 'http' && plugin.server_url && ( + {(plugin.plugin_type === 'http' || plugin.plugin_type === 'streamable_http' || plugin.plugin_type === 'sse') && plugin.server_url && (
{(() => { @@ -551,9 +892,10 @@ export default function MCPPluginsPage() { handleToggle(plugin, checked)} + disabled={modelSupportStatus !== 'supported'} size={isMobile ? 'small' : 'default'} style={{ flexShrink: 0, @@ -563,30 +905,33 @@ export default function MCPPluginsPage() { }} />