feat: 重构MCP功能和AI服务提供者架构

This commit is contained in:
xiamuceer-j
2026-01-09 17:13:19 +08:00
parent f3c224261d
commit 77c5489ff8
49 changed files with 4763 additions and 4307 deletions
+118 -164
View File
@@ -45,7 +45,7 @@ from app.services.memory_service import memory_service
from app.services.chapter_regenerator import ChapterRegenerator
from app.logger import get_logger
from app.api.settings import get_user_ai_service
from app.utils.sse_response import create_sse_response
from app.utils.sse_response import SSEResponse, create_sse_response
router = APIRouter(prefix="/chapters", tags=["章节管理"])
logger = get_logger(__name__)
@@ -1172,7 +1172,6 @@ async def generate_chapter_content_stream(
"""
style_id = generate_request.style_id
target_word_count = generate_request.target_word_count or 3000
enable_mcp = generate_request.enable_mcp if hasattr(generate_request, 'enable_mcp') else True
custom_model = generate_request.model if hasattr(generate_request, 'model') else None
temp_narrative_perspective = generate_request.narrative_perspective if hasattr(generate_request, 'narrative_perspective') else None
# 预先验证章节存在性(使用临时会话)
@@ -1211,25 +1210,36 @@ async def generate_chapter_content_stream(
# 获取当前用户ID(在生成器外部就需要)
current_user_id = getattr(request.state, "user_id", "system")
# 初始化标准进度追踪器
from app.utils.sse_response import WizardProgressTracker
tracker = WizardProgressTracker("章节")
try:
yield await tracker.start()
# 创建新的数据库会话
async for db_session in get_db(request):
# === 加载阶段 ===
yield await tracker.loading("加载章节信息...", 0.2)
# 重新获取章节信息
chapter_result = await db_session.execute(
select(Chapter).where(Chapter.id == chapter_id)
)
current_chapter = chapter_result.scalar_one_or_none()
if not current_chapter:
yield f"data: {json.dumps({'type': 'error', 'error': '章节不存在'}, ensure_ascii=False)}\n\n"
yield await tracker.error("章节不存在", 404)
return
yield await tracker.loading("加载项目信息...", 0.4)
# 获取项目信息
project_result = await db_session.execute(
select(Project).where(Project.id == current_chapter.project_id)
)
project = project_result.scalar_one_or_none()
if not project:
yield f"data: {json.dumps({'type': 'error', 'error': '项目不存在'}, ensure_ascii=False)}\n\n"
yield await tracker.error("项目不存在", 404)
return
# 获取项目的大纲模式
@@ -1333,80 +1343,7 @@ async def generate_chapter_content_stream(
logger.info(f" - 相关记忆: {chapter_context.context_stats.get('memory_count', 0)}")
logger.info(f" - 总上下文长度: {chapter_context.context_stats.get('total_length', 0)} 字符")
# 发送开始事件
yield f"data: {json.dumps({'type': 'start', 'message': '开始AI创作...'}, ensure_ascii=False)}\n\n"
# 发送初始进度0%
yield f"data: {json.dumps({'type': 'progress', 'progress': 0, 'message': '准备生成...', 'status': 'processing'}, ensure_ascii=False)}\n\n"
# 🔧 MCP工具增强:收集章节参考资料(优化版)
mcp_reference_materials = ""
if enable_mcp and current_user_id:
try:
# 1️⃣ 静默检查工具可用性
from app.services.mcp_tool_service import mcp_tool_service
available_tools = await mcp_tool_service.get_user_enabled_tools(
user_id=current_user_id,
db_session=db_session
)
# 2️⃣ 只在有工具时才显示消息和调用
if available_tools:
yield f"data: {json.dumps({'type': 'progress', 'message': '🔍 使用MCP工具收集参考资料...', 'progress': 28}, ensure_ascii=False)}\n\n"
# 构建资料收集提示词
planning_prompt = f"""你正在为小说《{project.title}》创作第{current_chapter.chapter_number}章《{current_chapter.title}》。
【章节大纲】
{outline.content if outline else current_chapter.summary or '暂无大纲'}
【小说信息】
- 题材:{project.genre or '未设定'}
- 主题:{project.theme or '未设定'}
- 时代背景:{project.world_time_period or '未设定'}
- 地理位置:{project.world_location or '未设定'}
【任务】
请使用可用工具搜索相关背景资料,帮助创作更真实、更有深度的章节内容。
你可以查询:
1. 该章节涉及的历史事件或时代背景
2. 地理环境和场景描写参考
3. 相关领域的专业知识(如武术、科技、魔法等)
4. 文化习俗和生活细节
请根据章节内容,有针对性地查询1-2个最关键的问题。"""
# 调用MCP增强的AI(非流式,限制1轮避免超时)
planning_result = await user_ai_service.generate_text_with_mcp(
prompt=planning_prompt,
user_id=current_user_id,
db_session=db_session,
enable_mcp=True,
max_tool_rounds=2, # ✅ 减少为1轮,避免超时
tool_choice="auto",
provider=None,
model=None
)
# 3️⃣ 提取参考资料并显示结果
if planning_result.get("tool_calls_made", 0) > 0:
tool_count = planning_result["tool_calls_made"]
yield f"data: {json.dumps({'type': 'progress', 'message': f'✅ MCP工具调用成功({tool_count}次)', 'progress': 32}, ensure_ascii=False)}\n\n"
mcp_reference_materials = planning_result.get("content", "")
logger.info(f"📚 MCP工具收集参考资料:{len(mcp_reference_materials)} 字符")
else:
yield f"data: {json.dumps({'type': 'progress', 'message': '️ MCP未使用工具,继续', 'progress': 32}, ensure_ascii=False)}\n\n"
else:
logger.debug(f"用户 {current_user_id} 未启用MCP工具,跳过MCP增强")
# 未启用MCP时也发送进度,保持连贯性
yield f"data: {json.dumps({'type': 'progress', 'message': '准备生成内容...', 'progress': 10}, ensure_ascii=False)}\n\n"
except Exception as e:
logger.warning(f"⚠️ MCP工具调用失败,降级为基础模式: {str(e)}")
yield f"data: {json.dumps({'type': 'progress', 'message': '⚠️ MCP工具暂时不可用,使用基础模式', 'progress': 10}, ensure_ascii=False)}\n\n"
else:
# 如果未启用MCP,也发送基础进度
yield f"data: {json.dumps({'type': 'progress', 'message': '开始构建创作上下文...', 'progress': 10}, ensure_ascii=False)}\n\n"
yield await tracker.loading("上下文构建完成", 0.8)
# 🎭 确定使用的叙事人称(临时指定 > 项目默认 > 系统默认)
chapter_perspective = (
@@ -1496,26 +1433,17 @@ async def generate_chapter_content_stream(
characters_info=characters_info or '暂无角色信息'
)
# 添加 MCP 参考资料(如果有)
if mcp_reference_materials:
mcp_section = f"\n\n<mcp_reference>\n{mcp_reference_materials}\n</mcp_reference>"
base_prompt = base_prompt.replace("</task>", f"{mcp_section}\n</task>")
logger.info(f"📖 已整合MCP参考资料({len(mcp_reference_materials)}字符)")
# 应用写作风格
if style_content:
prompt = WritingStyleManager.apply_style_to_prompt(base_prompt, style_content)
else:
prompt = base_prompt
if mcp_reference_materials:
logger.info(f"📖 已整合MCP参考资料({len(mcp_reference_materials)}字符)到章节生成提示词")
# === 准备阶段 ===
yield await tracker.preparing("准备AI提示词...")
logger.info(f"开始AI流式创作章节 {chapter_id}")
# 发送开始生成的进度
yield f"data: {json.dumps({'type': 'progress', 'progress': 10, 'message': '开始AI创作...', 'status': 'processing'}, ensure_ascii=False)}\n\n"
# 🎨 方案一:将写作风格注入到系统提示词(最高优先级)
system_prompt_with_style = None
if style_content:
@@ -1530,7 +1458,8 @@ async def generate_chapter_content_stream(
# 准备生成参数
generate_kwargs = {
"prompt": prompt,
"system_prompt": system_prompt_with_style # 🔑 关键:使用系统提示词传递风格
"system_prompt": system_prompt_with_style,
"tool_choice": "required"
}
if custom_model:
logger.info(f" 使用自定义模型: {custom_model}")
@@ -1538,47 +1467,38 @@ async def generate_chapter_content_stream(
# 注意:这里使用用户配置的AI服务,模型参数会覆盖默认模型
# 如果需要切换provider,需要在前端传递provider参数
# 流式生成内容
# === 生成阶段 ===
full_content = ""
chunk_count = 0
last_progress = 0
yield await tracker.generating(
current_chars=0,
estimated_total=target_word_count
)
async for chunk in user_ai_service.generate_text_stream(**generate_kwargs):
full_content += chunk
chunk_count += 1
# 发送内容块
yield f"data: {json.dumps({'type': 'content', 'content': chunk}, ensure_ascii=False)}\n\n"
yield await tracker.generating_chunk(chunk)
# 每5个chunk发送一次进度更新10-95%,更平滑)
# 每5个chunk发送一次进度更新
if chunk_count % 5 == 0:
current_word_count = len(full_content)
# 优化进度计算:使用更平滑的递增方式
# 基于chunk数量和字数的混合计算,避免大幅跳跃
chunk_progress = min(40, chunk_count // 5) # chunk贡献最多40%
word_progress = min(45, int((current_word_count / target_word_count) * 45)) # 字数贡献最多45%
estimated_progress = min(95, 10 + chunk_progress + word_progress)
# 只在进度变化时发送
if estimated_progress > last_progress:
progress_data = {
'type': 'progress',
'progress': estimated_progress,
'message': f'正在创作中... 已生成 {current_word_count}',
'word_count': current_word_count,
'status': 'processing'
}
yield f"data: {json.dumps(progress_data, ensure_ascii=False)}\n\n"
last_progress = estimated_progress
yield await tracker.generating(
current_chars=len(full_content),
estimated_total=target_word_count,
message=f'正在创作中... 已生成 {len(full_content)}'
)
# 每20个chunk发送心跳
if chunk_count % 20 == 0:
yield f"data: {json.dumps({'type': 'heartbeat'}, ensure_ascii=False)}\n\n"
yield await tracker.heartbeat()
await asyncio.sleep(0) # 让出控制权
# 发送保存进度
yield f"data: {json.dumps({'type': 'progress', 'progress': 97, 'message': '正在保存章节...', 'status': 'processing'}, ensure_ascii=False)}\n\n"
# === 保存阶段 ===
yield await tracker.saving("正在保存章节...", 0.3)
# 更新章节内容到数据库
old_word_count = current_chapter.word_count or 0
@@ -1634,25 +1554,28 @@ async def generate_chapter_content_stream(
ai_service=user_ai_service
)
# 发送最终进度100%
yield f"data: {json.dumps({'type': 'progress', 'progress': 99, 'message': '创作完成!', 'word_count': new_word_count, 'status': 'success'}, ensure_ascii=False)}\n\n"
yield await tracker.saving("章节保存完成", 0.8)
# 发送完成事件(包含分析任务ID
completion_data = {
'type': 'done',
'message': '创作完成',
# === 完成阶段 ===
yield await tracker.complete("创作完成!")
# 发送结果数据
yield await tracker.result({
'word_count': new_word_count,
'analysis_task_id': task_id
}
yield f"data: {json.dumps(completion_data, ensure_ascii=False)}\n\n"
})
# 发送分析开始事件
analysis_started_data = {
'type': 'analysis_started',
'task_id': task_id,
'message': '章节分析已开始'
}
yield f"data: {json.dumps(analysis_started_data, ensure_ascii=False)}\n\n"
# 发送分析开始事件(使用自定义事件)
yield await SSEResponse.send_event(
event='analysis_started',
data={
'task_id': task_id,
'message': '章节分析已开始'
}
)
# 发送完成信号
yield await tracker.done()
break # 退出async for db_session循环
@@ -1675,7 +1598,7 @@ async def generate_chapter_content_stream(
logger.info("章节生成事务已回滚(异常)")
except Exception as rollback_error:
logger.error(f"回滚失败: {str(rollback_error)}")
yield f"data: {json.dumps({'type': 'error', 'error': str(e)}, ensure_ascii=False)}\n\n"
yield await tracker.error(str(e))
finally:
# 确保数据库会话被正确关闭
if db_session:
@@ -2813,7 +2736,8 @@ async def generate_single_chapter_for_batch(
# 准备生成参数
generate_kwargs = {
"prompt": prompt,
"system_prompt": system_prompt_with_style # 🔑 关键:使用系统提示词传递风格
"system_prompt": system_prompt_with_style,
"tool_choice": "required"
}
# 如果传入了自定义模型,使用指定的模型
if custom_model:
@@ -3029,11 +2953,16 @@ async def regenerate_chapter_stream(
db_session = None
db_committed = False
# 初始化标准进度追踪器
from app.utils.sse_response import WizardProgressTracker
tracker = WizardProgressTracker("章节重新生成")
try:
yield await tracker.start()
# 创建独立数据库会话
async for db_session in get_db(request):
# 发送开始事件
yield f"data: {json.dumps({'type': 'start', 'message': '开始重新生成章节...'}, ensure_ascii=False)}\n\n"
yield await tracker.loading("加载章节信息...", 0.5)
# 创建重新生成任务
regen_task = RegenerationTask(
@@ -3062,13 +2991,25 @@ async def regenerate_chapter_stream(
task_id = regen_task.id
logger.info(f"📝 创建重新生成任务: {task_id}")
yield f"data: {json.dumps({'type': 'task_created', 'task_id': task_id}, ensure_ascii=False)}\n\n"
yield await tracker.preparing("准备重新生成...")
yield await SSEResponse.send_event(
event='task_created',
data={'task_id': task_id}
)
# 初始化重新生成器
regenerator = ChapterRegenerator(user_ai_service)
# 流式生成新内容
# === 生成阶段 ===
full_content = ""
estimated_total = regenerate_request.target_word_count or len(chapter.content)
yield await tracker.generating(
current_chars=0,
estimated_total=estimated_total
)
async for event in regenerator.regenerate_with_feedback(
chapter=chapter,
analysis=analysis,
@@ -3083,19 +3024,35 @@ async def regenerate_chapter_stream(
# 内容块
chunk = event['content']
full_content += chunk
yield f"data: {json.dumps({'type': 'chunk', 'content': chunk}, ensure_ascii=False)}\n\n"
yield await tracker.generating_chunk(chunk)
# 定期更新进度
if len(full_content) % 500 == 0:
yield await tracker.generating(
current_chars=len(full_content),
estimated_total=estimated_total,
message=f'重新生成中... 已生成 {len(full_content)}'
)
elif event['type'] == 'progress':
# 进度更新
progress_data = {
'type': 'progress',
'progress': event.get('progress', 0),
'message': event.get('message', ''),
'word_count': event.get('word_count', 0)
}
yield f"data: {json.dumps(progress_data, ensure_ascii=False)}\n\n"
# 进度更新 - 映射到对应阶段
progress = event.get('progress', 0)
message = event.get('message', '')
if progress < 20:
yield await tracker.preparing(message)
elif progress < 85:
yield await tracker.generating(
current_chars=len(full_content),
estimated_total=estimated_total,
message=message
)
else:
yield await tracker.parsing(message)
await asyncio.sleep(0)
# === 保存阶段 ===
yield await tracker.saving("保存重新生成的内容...", 0.5)
# 更新任务状态
regen_task.status = 'completed'
regen_task.regenerated_content = full_content
@@ -3108,25 +3065,22 @@ async def regenerate_chapter_stream(
await db_session.commit()
db_committed = True
# 先发送结果数据
result_data = {
'type': 'result',
'data': {
'task_id': task_id,
'word_count': len(full_content),
'version_number': regen_task.version_number,
'auto_applied': regenerate_request.auto_apply,
'diff_stats': diff_stats
}
}
yield f"data: {json.dumps(result_data, ensure_ascii=False)}\n\n"
yield await tracker.saving("保存完成", 0.9)
# 再发送完成事件
completion_data = {
'type': 'done',
'message': '重新生成完成'
}
yield f"data: {json.dumps(completion_data, ensure_ascii=False)}\n\n"
# === 完成阶段 ===
yield await tracker.complete("重新生成完成!")
# 发送结果数据
yield await tracker.result({
'task_id': task_id,
'word_count': len(full_content),
'version_number': regen_task.version_number,
'auto_applied': regenerate_request.auto_apply,
'diff_stats': diff_stats
})
# 发送完成信号
yield await tracker.done()
logger.info(f"✅ 章节重新生成完成: {chapter_id}, 任务: {task_id}")
@@ -3151,7 +3105,7 @@ async def regenerate_chapter_stream(
except Exception as update_error:
logger.error(f"更新任务失败状态失败: {str(update_error)}")
yield f"data: {json.dumps({'type': 'error', 'error': str(e)}, ensure_ascii=False)}\n\n"
yield await tracker.error(str(e))
finally:
if db_session:
+28 -130
View File
@@ -775,148 +775,46 @@ async def generate_character_stream(
logger.info(f"🎯 开始为项目 {request.project_id} 生成角色(SSE流式)")
try:
# 🔧 MCP工具增强:静默检查并收集参考资料
# 直接使用 AIService 流式生成
ai_response = ""
chunk_count = 0
if user_id:
try:
from app.services.mcp_tool_service import mcp_tool_service
available_tools = await mcp_tool_service.get_user_enabled_tools(
user_id=user_id,
db_session=db
)
# 只在有工具时才调用
if available_tools:
logger.info(f"🔍 检测到可用MCP工具,尝试收集参考资料...")
result = await user_ai_service.generate_text_with_mcp(
prompt=prompt,
user_id=user_id,
db_session=db,
enable_mcp=True,
max_tool_rounds=2,
tool_choice="auto",
provider=None,
model=None
)
if isinstance(result, dict):
ai_response = result.get('content', '')
finish_reason = result.get('finish_reason', '')
tool_calls_made = result.get('tool_calls_made', 0)
# 🔧 修复:检查工具调用是否真正成功
if tool_calls_made > 0:
if finish_reason == 'tool_error':
logger.warning(f"⚠️ MCP工具调用失败,降级为基础模式")
# 工具调用失败,重新用基础模式生成
ai_response = ""
elif not ai_response.strip():
logger.warning(f"⚠️ MCP工具调用后返回空响应,降级为基础模式")
# 工具调用成功但返回空内容,重新生成
ai_response = ""
else:
logger.info(f"✅ MCP工具调用成功({tool_calls_made}次),内容长度: {len(ai_response)}")
# MCP成功且有内容,模拟流式输出(分块发送)
chunk_size = 50
for i in range(0, len(ai_response), chunk_size):
chunk = ai_response[i:i+chunk_size]
chunk_count += 1
yield await SSEResponse.send_chunk(chunk)
if chunk_count % 3 == 0:
yield await SSEResponse.send_progress(
f"AI生成角色中... ({i+len(chunk)}/{len(ai_response)}字符)",
10 + min(85 * (i+len(chunk)) // len(ai_response), 85)
)
# 跳过后续的流式生成
ai_response = result.get('content', '')
else:
ai_response = result
# 如果MCP调用失败或返回空,继续走流式生成
if not ai_response or not ai_response.strip():
logger.info(f"🔄 开始流式生成...")
ai_response = ""
async for chunk in user_ai_service.generate_text_stream(prompt=prompt):
chunk_count += 1
ai_response += chunk
# 发送内容块
yield await SSEResponse.send_chunk(chunk)
# 定期更新进度
if chunk_count % 5 == 0:
yield await SSEResponse.send_progress(
f"AI生成角色中... ({len(ai_response)}字符)",
10 + min(chunk_count // 2, 85)
)
# 心跳
if chunk_count % 20 == 0:
yield await SSEResponse.send_heartbeat()
else:
logger.debug(f"用户 {user_id} 未启用MCP工具,使用流式基础模式")
async for chunk in user_ai_service.generate_text_stream(prompt=prompt):
chunk_count += 1
ai_response += chunk
# 发送内容块
yield await SSEResponse.send_chunk(chunk)
# 定期更新进度
if chunk_count % 5 == 0:
yield await SSEResponse.send_progress(
f"AI生成角色中... ({len(ai_response)}字符)",
10 + min(chunk_count // 2, 85)
)
# 心跳
if chunk_count % 20 == 0:
yield await SSEResponse.send_heartbeat()
except Exception as mcp_error:
logger.warning(f"⚠️ MCP工具调用异常,降级为流式基础模式: {str(mcp_error)}")
ai_response = ""
async for chunk in user_ai_service.generate_text_stream(prompt=prompt):
chunk_count += 1
ai_response += chunk
# 发送内容块
yield await SSEResponse.send_chunk(chunk)
# 定期更新进度
if chunk_count % 5 == 0:
yield await SSEResponse.send_progress(
f"AI生成角色中... ({len(ai_response)}字符)",
10 + min(chunk_count // 2, 85)
)
# 心跳
if chunk_count % 20 == 0:
yield await SSEResponse.send_heartbeat()
else:
logger.debug(f"未登录用户,使用流式基础模式")
async for chunk in user_ai_service.generate_text_stream(prompt=prompt):
chunk_count += 1
ai_response += chunk
logger.info(f"🎯 开始生成角色(流式模式)...")
yield await SSEResponse.send_progress("🎯 开始生成角色...", 15)
async for chunk in user_ai_service.generate_text_stream(
prompt=prompt,
tool_choice="required",
):
# chunk 现在可能是 dict 或 str,提取 content 字段
if isinstance(chunk, dict):
content = chunk.get("content", "")
else:
content = chunk
if content:
ai_response += content
# 发送内容块
yield await SSEResponse.send_chunk(chunk)
yield await SSEResponse.send_chunk(content)
# 定期更新进度
if chunk_count % 5 == 0:
# 定期更新进度(每收到约500字符更新一次,避免过于频繁)
current_len = len(ai_response)
if current_len >= chunk_count * 500:
chunk_count += 1
# 使用实际字符数量计算进度,上限85%(留15%给后续解析和保存)
# 估算最终字符数约为提示词的8倍,最少3000字符
estimated_total = max(3000, len(prompt) * 8)
progress = min(15 + int(current_len / estimated_total * 70), 85)
yield await SSEResponse.send_progress(
f"AI生成角色中... ({len(ai_response)}字符)",
10 + min(chunk_count // 2, 85)
f"AI生成角色中... ({current_len}字符)",
progress
)
# 心跳
if chunk_count % 20 == 0:
yield await SSEResponse.send_heartbeat()
except Exception as ai_error:
logger.error(f"❌ AI服务调用异常:{str(ai_error)}")
yield await SSEResponse.send_error(f"AI服务调用失败:{str(ai_error)}")
+210 -73
View File
@@ -1,4 +1,7 @@
"""MCP插件管理API"""
"""MCP插件管理API
重构后使用统一的MCPClientFacade门面来管理所有MCP操作。
"""
from fastapi import APIRouter, HTTPException, Depends, Query, Request
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select
@@ -17,9 +20,8 @@ from app.schemas.mcp_plugin import (
)
import json
from app.user_manager import User
from app.mcp.registry import mcp_registry
from app.mcp import mcp_client, MCPPluginConfig, PluginStatus
from app.services.mcp_test_service import mcp_test_service
from app.services.mcp_tool_service import mcp_tool_service
from app.logger import get_logger
logger = get_logger(__name__)
@@ -34,6 +36,31 @@ def require_login(request: Request) -> User:
return request.state.user
async def _register_plugin_to_facade(plugin: MCPPlugin, user_id: str) -> bool:
"""
将插件注册到统一门面
Args:
plugin: 插件对象
user_id: 用户ID
Returns:
是否注册成功
"""
if plugin.plugin_type in ["http", "streamable_http", "sse"] and plugin.server_url:
return await mcp_client.register(MCPPluginConfig(
user_id=user_id,
plugin_name=plugin.plugin_name,
url=plugin.server_url,
plugin_type=plugin.plugin_type,
headers=plugin.headers,
timeout=plugin.config.get('timeout', 60.0) if plugin.config else 60.0
))
else:
logger.warning(f"暂不支持的插件类型: {plugin.plugin_type}")
return False
@router.get("", response_model=List[MCPPluginResponse])
async def list_plugins(
enabled_only: bool = Query(False, description="只返回启用的插件"),
@@ -99,9 +126,9 @@ async def create_plugin(
await db.commit()
await db.refresh(plugin)
# 如果启用,加载到注册表
# 如果启用,注册到统一门面
if plugin.enabled:
success = await mcp_registry.load_plugin(plugin)
success = await _register_plugin_to_facade(plugin, user.user_id)
if success:
plugin.status = "active"
else:
@@ -153,7 +180,7 @@ async def create_plugin_simple(
# 提取配置
server_type = server_config.get("type", "http")
if server_type not in ["http", "stdio"]:
if server_type not in ["http", "stdio", "streamable_http", "sse"]:
raise HTTPException(status_code=400, detail=f"不支持的服务器类型: {server_type}")
# 检查插件名是否已存在
@@ -175,12 +202,12 @@ async def create_plugin_simple(
"sort_order": 0
}
if server_type == "http":
if server_type in ["http", "streamable_http", "sse"]:
plugin_data["server_url"] = server_config.get("url")
plugin_data["headers"] = server_config.get("headers", {})
if not plugin_data["server_url"]:
raise HTTPException(status_code=400, detail="HTTP类型插件必须提供url字段")
raise HTTPException(status_code=400, detail=f"{server_type}类型插件必须提供url字段")
elif server_type == "stdio":
plugin_data["command"] = server_config.get("command")
@@ -194,9 +221,9 @@ async def create_plugin_simple(
# 更新现有插件
logger.info(f"插件 {plugin_name} 已存在,执行更新操作")
# 先卸载旧插件
if existing.enabled:
await mcp_registry.unload_plugin(user.user_id, existing.plugin_name)
# 保存旧状态
old_enabled = existing.enabled
old_plugin_name = existing.plugin_name
# 更新字段
for key, value in plugin_data.items():
@@ -206,17 +233,24 @@ async def create_plugin_simple(
await db.commit()
await db.refresh(plugin)
# 如果启用,重新加载
# 数据库完成后进行MCP操作
if old_enabled:
try:
await mcp_client.unregister(user.user_id, old_plugin_name)
except Exception as e:
logger.warning(f"注销旧插件出错: {e}")
if plugin.enabled:
success = await mcp_registry.load_plugin(plugin)
if success:
plugin.status = "active"
plugin.last_error = None
else:
try:
success = await _register_plugin_to_facade(plugin, user.user_id)
plugin.status = "active" if success else "error"
plugin.last_error = None if success else "加载失败"
await db.commit()
except Exception as e:
logger.error(f"注册插件失败: {e}")
plugin.status = "error"
plugin.last_error = "加载失败"
await db.commit()
await db.refresh(plugin)
plugin.last_error = str(e)
await db.commit()
logger.info(f"用户 {user.user_id} 更新插件: {plugin_name}")
else:
@@ -230,16 +264,18 @@ async def create_plugin_simple(
await db.commit()
await db.refresh(plugin)
# 如果启用,加载到注册表
# 数据库完成后进行MCP操作
if plugin.enabled:
success = await mcp_registry.load_plugin(plugin)
if success:
plugin.status = "active"
else:
try:
success = await _register_plugin_to_facade(plugin, user.user_id)
plugin.status = "active" if success else "error"
plugin.last_error = None if success else "加载失败"
await db.commit()
except Exception as e:
logger.error(f"注册插件失败: {e}")
plugin.status = "error"
plugin.last_error = "加载失败"
await db.commit()
await db.refresh(plugin)
plugin.last_error = str(e)
await db.commit()
logger.info(f"用户 {user.user_id} 通过简化配置创建插件: {plugin_name}")
@@ -306,9 +342,10 @@ async def update_plugin(
await db.commit()
await db.refresh(plugin)
# 如果插件已启用,重新加载
# 如果插件已启用,重新注册
if plugin.enabled:
await mcp_registry.reload_plugin(plugin)
await mcp_client.unregister(user.user_id, plugin.plugin_name)
await _register_plugin_to_facade(plugin, user.user_id)
logger.info(f"用户 {user.user_id} 更新插件: {plugin.plugin_name}")
return plugin
@@ -334,8 +371,8 @@ async def delete_plugin(
if not plugin:
raise HTTPException(status_code=404, detail="插件不存在")
# 从注册表卸载
await mcp_registry.unload_plugin(user.user_id, plugin.plugin_name)
# 从统一门面注销
await mcp_client.unregister(user.user_id, plugin.plugin_name)
# 删除数据库记录
await db.delete(plugin)
@@ -366,27 +403,57 @@ async def toggle_plugin(
if not plugin:
raise HTTPException(status_code=404, detail="插件不存在")
plugin.enabled = enabled
# 保存插件信息用于后续MCP操作
plugin_name = plugin.plugin_name
plugin_type = plugin.plugin_type
server_url = plugin.server_url
headers = plugin.headers
config = plugin.config
if enabled:
# 启用:加载到注册表
success = await mcp_registry.load_plugin(plugin)
if success:
plugin.status = "active"
plugin.last_error = None
else:
plugin.status = "error"
plugin.last_error = "加载失败"
else:
# 禁用:从注册表卸载
await mcp_registry.unload_plugin(user.user_id, plugin.plugin_name)
# 先更新数据库状态
plugin.enabled = enabled
if not enabled:
plugin.status = "inactive"
await db.commit()
await db.refresh(plugin)
# 数据库操作完成后,再进行MCP操作
if enabled:
# 启用:注册到统一门面
try:
if plugin_type in ["http", "streamable_http", "sse"] and server_url:
success = await mcp_client.register(MCPPluginConfig(
user_id=user.user_id,
plugin_name=plugin_name,
url=server_url,
plugin_type=plugin_type,
headers=headers,
timeout=config.get('timeout', 60.0) if config else 60.0
))
else:
success = False
# 更新状态
plugin.status = "active" if success else "error"
plugin.last_error = None if success else "加载失败"
await db.commit()
await db.refresh(plugin)
except Exception as e:
logger.error(f"注册插件失败: {plugin_name}, 错误: {e}")
plugin.status = "error"
plugin.last_error = str(e)
await db.commit()
await db.refresh(plugin)
else:
# 禁用:从统一门面注销(不影响数据库状态)
try:
await mcp_client.unregister(user.user_id, plugin_name)
except Exception as e:
logger.warning(f"注销插件时出错(可忽略): {plugin_name}, 错误: {e}")
action = "启用" if enabled else "禁用"
logger.info(f"用户 {user.user_id} {action}插件: {plugin.plugin_name}")
logger.info(f"用户 {user.user_id} {action}插件: {plugin_name}")
return plugin
@@ -399,7 +466,7 @@ async def test_plugin(
"""
测试插件连接并调用工具验证功能
使用新的MCPTestService进行测试
使用MCPTestService进行测试
"""
result = await db.execute(
@@ -421,7 +488,7 @@ async def test_plugin(
suggestions=["点击开关按钮启用插件"]
)
# 使用新的测试服务
# 使用测试服务
try:
test_result = await mcp_test_service.test_plugin_with_ai(plugin, user, db)
@@ -447,32 +514,77 @@ async def test_plugin(
raise HTTPException(status_code=500, detail=f"测试失败: {str(e)}")
async def _ensure_plugin_loaded(
async def _ensure_plugin_registered(
plugin: MCPPlugin,
user_id: str
) -> bool:
"""
确保插件已加载(共享逻辑)
确保插件已注册到统一门面
Args:
plugin: 插件对象
user_id: 用户ID
Returns:
是否加载成功
是否成功
Raises:
HTTPException: 加载失败
HTTPException: 注册失败
"""
if not mcp_registry.get_client(user_id, plugin.plugin_name):
logger.info(f"插件 {plugin.plugin_name} 未加载,自动加载中...")
success = await mcp_registry.load_plugin(plugin)
try:
# 使用ensure_registered方法,它会检查是否已注册
if plugin.plugin_type in ["http", "streamable_http", "sse"] and plugin.server_url:
return await mcp_client.ensure_registered(
user_id=user_id,
plugin_name=plugin.plugin_name,
url=plugin.server_url,
plugin_type=plugin.plugin_type,
headers=plugin.headers
)
return False
except ValueError as e:
logger.info(f"插件 {plugin.plugin_name} 未注册,自动注册中...")
success = await _register_plugin_to_facade(plugin, user_id)
if not success:
raise HTTPException(
status_code=500,
detail=f"插件加载失败: {plugin.plugin_name}"
detail=f"插件注册失败: {plugin.plugin_name}"
)
return True
return True
@router.get("/{plugin_id}/status")
async def get_plugin_status(
plugin_id: str,
user: User = Depends(require_login),
db: AsyncSession = Depends(get_db)
):
"""获取插件的实时状态(包括内存中的会话状态)"""
result = await db.execute(
select(MCPPlugin).where(
MCPPlugin.id == plugin_id,
MCPPlugin.user_id == user.user_id
)
)
plugin = result.scalar_one_or_none()
if not plugin:
raise HTTPException(status_code=404, detail="插件不存在")
session_stats = mcp_client.get_session_stats()
session_key = f"{user.user_id}:{plugin.plugin_name}"
session_info = next((s for s in session_stats.get("sessions", []) if s["key"] == session_key), None)
return {
"plugin_id": plugin_id,
"plugin_name": plugin.plugin_name,
"db_status": plugin.status,
"session_status": session_info["status"] if session_info else None,
"is_registered": session_info is not None,
"error_rate": session_info["error_rate"] if session_info else 0,
"in_sync": (plugin.status == session_info["status"]) if session_info else (plugin.status == "inactive"),
"timestamp": datetime.now().isoformat()
}
@router.get("/metrics")
@@ -495,7 +607,8 @@ async def get_metrics(
- avg_duration_ms: 平均耗时(毫秒)
- last_call_time: 最后调用时间
"""
metrics = mcp_tool_service.get_metrics(tool_name)
# 使用统一门面获取指标
metrics = mcp_client.get_metrics(tool_name)
return {
"metrics": metrics,
@@ -518,7 +631,8 @@ async def get_cache_stats(
- cache_ttl_minutes: 缓存TTL(分钟)
- entries: 各缓存条目详情
"""
stats = mcp_tool_service.get_cache_stats()
# 使用统一门面获取缓存统计
stats = mcp_client.get_cache_stats()
return {
"cache_stats": stats,
@@ -526,6 +640,27 @@ async def get_cache_stats(
}
@router.get("/sessions/stats")
async def get_session_stats(
user: User = Depends(require_login)
):
"""
获取MCP会话统计信息
Returns:
会话统计信息,包含:
- total_sessions: 会话总数
- sessions: 各会话详情
"""
# 使用统一门面获取会话统计
stats = mcp_client.get_session_stats()
return {
"session_stats": stats,
"timestamp": datetime.now().isoformat()
}
@router.post("/cache/clear")
async def clear_cache(
user_id: Optional[str] = Query(None, description="用户ID(可选)"),
@@ -551,7 +686,8 @@ async def clear_cache(
# 如果没有指定user_id,使用当前用户
target_user_id = user_id or user.user_id
mcp_tool_service.clear_cache(target_user_id, plugin_name)
# 使用统一门面清理缓存
mcp_client.clear_cache(target_user_id, plugin_name)
message = "已清理"
if plugin_name:
@@ -594,12 +730,13 @@ async def get_plugin_tools(
raise HTTPException(status_code=400, detail="插件未启用")
try:
# 确保插件已加载
await _ensure_plugin_loaded(plugin, user.user_id)
# 确保插件已注册
await _ensure_plugin_registered(plugin, user.user_id)
tools = await mcp_registry.get_plugin_tools(user.user_id, plugin.plugin_name)
# 使用统一门面获取工具列表
tools = await mcp_client.get_tools(user.user_id, plugin.plugin_name)
# 更新缓存
# 更新数据库中的工具缓存
plugin.tools = tools
await db.commit()
@@ -640,22 +777,22 @@ async def call_mcp_tool(
raise HTTPException(status_code=400, detail="插件未启用")
try:
# 确保插件已加载
await _ensure_plugin_loaded(plugin, user.user_id)
# 确保插件已注册
await _ensure_plugin_registered(plugin, user.user_id)
# 调用工具
result = await mcp_registry.call_tool(
user.user_id,
plugin.plugin_name,
data.tool_name,
data.arguments
# 使用统一门面调用工具
tool_result = await mcp_client.call_tool(
user_id=user.user_id,
plugin_name=plugin.plugin_name,
tool_name=data.tool_name,
arguments=data.arguments
)
return {
"success": True,
"plugin_name": plugin.plugin_name,
"tool_name": data.tool_name,
"result": result
"result": tool_result
}
except HTTPException:
raise
File diff suppressed because it is too large Load Diff
+253 -7
View File
@@ -22,7 +22,7 @@ from app.schemas.settings import (
from app.user_manager import User
from app.logger import get_logger
from app.config import settings as app_settings, PROJECT_ROOT
from app.services.ai_service import AIService, create_user_ai_service
from app.services.ai_service import AIService, create_user_ai_service, create_user_ai_service_with_mcp
logger = get_logger(__name__)
@@ -53,9 +53,14 @@ async def get_user_ai_service(
db: AsyncSession = Depends(get_db)
) -> AIService:
"""
依赖:获取当前用户的AI服务实例
从数据库读取用户设置并创建对应的AI服务
依赖:获取当前用户的AI服务实例(支持MCP工具自动加载)
从数据库读取用户设置并创建对应的AI服务。
自动传递 user_id 和 db_session,使得 AIService 能够加载用户配置的MCP工具。
根据用户的所有MCP插件状态决定是否启用MCP:如果有启用的插件则启用,否则禁用。
"""
from app.models.mcp_plugin import MCPPlugin
result = await db.execute(
select(Settings).where(Settings.user_id == user.user_id)
)
@@ -73,15 +78,34 @@ async def get_user_ai_service(
await db.refresh(settings)
logger.info(f"用户 {user.user_id} 首次使用AI服务,已从.env同步设置到数据库")
# 使用用户设置创建AI服务实例(包括系统提示词)
return create_user_ai_service(
# 查询用户的所有MCP插件状态
mcp_result = await db.execute(
select(MCPPlugin).where(MCPPlugin.user_id == user.user_id)
)
mcp_plugins = mcp_result.scalars().all()
# 检查是否有启用的MCP插件
enable_mcp = any(plugin.enabled for plugin in mcp_plugins) if mcp_plugins else False
if mcp_plugins:
enabled_count = sum(1 for p in mcp_plugins if p.enabled)
logger.info(f"用户 {user.user_id}{len(mcp_plugins)} 个MCP插件,{enabled_count} 个启用,{enable_mcp} 决定使用MCP")
else:
logger.debug(f"用户 {user.user_id} 没有配置MCP插件,禁用MCP")
# ✅ 使用支持MCP的工厂函数创建AI服务实例
# 传递 user_id 和 db_session,使得 AIService 能够自动加载用户配置的MCP工具
return create_user_ai_service_with_mcp(
api_provider=settings.api_provider,
api_key=settings.api_key,
api_base_url=settings.api_base_url or "",
model_name=settings.llm_model,
temperature=settings.temperature,
max_tokens=settings.max_tokens,
system_prompt=settings.system_prompt # 传递系统提示词
user_id=user.user_id, # ✅ 传递 user_id
db_session=db, # ✅ 传递 db_session
system_prompt=settings.system_prompt,
enable_mcp=enable_mcp, # 根据MCP插件状态动态决定
)
@@ -327,6 +351,227 @@ class ApiTestRequest(BaseModel):
llm_model: str
@router.post("/check-function-calling")
async def check_function_calling_support(data: ApiTestRequest):
"""
检查模型是否支持 Function Calling(工具调用)
基于业界最佳实践的测试方法:
1. 发送包含工具定义的请求
2. 检查响应的 finish_reason 是否为 "tool_calls"
3. 验证响应中是否包含有效的 tool_calls 数据
Args:
data: 包含 API 配置的请求数据
Returns:
检测结果包含支持状态、详细信息和建议
"""
api_key = data.api_key
api_base_url = data.api_base_url
provider = data.provider
llm_model = data.llm_model
try:
start_time = time.time()
# 定义一个简单的测试工具(天气查询)
test_tools = [{
"type": "function",
"function": {
"name": "get_weather",
"description": "获取指定城市的当前天气信息",
"parameters": {
"type": "object",
"properties": {
"city": {
"type": "string",
"description": "城市名称,例如:北京、上海、深圳"
},
"unit": {
"type": "string",
"enum": ["celsius", "fahrenheit"],
"description": "温度单位"
}
},
"required": ["city"]
}
}
}]
# 测试提示:故意设计一个需要调用工具的问题
test_prompt = "请告诉我北京现在的天气情况如何?"
logger.info(f"🧪 开始检测 Function Calling 支持")
logger.info(f" - 提供商: {provider}")
logger.info(f" - 模型: {llm_model}")
logger.info(f" - 测试工具: get_weather")
# 创建临时 AI 服务实例进行测试
test_service = AIService(
api_provider=provider,
api_key=api_key,
api_base_url=api_base_url,
default_model=llm_model,
default_temperature=0.3, # 使用较低温度以获得更确定的行为
default_max_tokens=200
)
# 发送带工具的测试请求
response = await test_service.generate_text(
prompt=test_prompt,
provider=provider,
model=llm_model,
temperature=0.3,
max_tokens=200,
tools=test_tools,
tool_choice="auto", # 让模型自动决定是否使用工具
auto_mcp=False # 禁用 MCP 自动加载
)
end_time = time.time()
response_time = round((end_time - start_time) * 1000, 2)
# 分析响应以确定是否支持 Function Calling
supported = False
finish_reason = None
tool_calls = None
response_content = None
if isinstance(response, dict):
# 检查 finish_reasonOpenAI 标准)
finish_reason = response.get("finish_reason")
# 检查是否有 tool_calls
if "tool_calls" in response and response["tool_calls"]:
supported = True
tool_calls = response["tool_calls"]
logger.info(f"✅ 检测到工具调用: {len(tool_calls)}")
# 记录返回的内容(如果有)
if "content" in response:
response_content = response["content"]
elif isinstance(response, str):
# 如果只返回字符串,说明不支持工具调用
response_content = response
logger.info(f" - 响应时间: {response_time}ms")
logger.info(f" - finish_reason: {finish_reason}")
logger.info(f" - 支持状态: {'✅ 支持' if supported else '❌ 不支持'}")
# 构建详细的返回信息
result = {
"success": True,
"supported": supported,
"message": "✅ 模型支持 Function Calling" if supported else "❌ 模型不支持 Function Calling",
"response_time_ms": response_time,
"provider": provider,
"model": llm_model,
"details": {
"finish_reason": finish_reason,
"has_tool_calls": bool(tool_calls),
"tool_call_count": len(tool_calls) if tool_calls else 0,
"test_tool": "get_weather",
"test_prompt": test_prompt,
"response_type": "tool_calls" if supported else "text"
}
}
# 添加工具调用详情
if tool_calls:
result["tool_calls"] = tool_calls
result["suggestions"] = [
"✅ 该模型支持 Function Calling,可以正常使用 MCP 插件",
"建议:启用需要的 MCP 插件以扩展 AI 能力",
"提示:测试成功检测到工具调用,模型能够正确解析和使用外部工具"
]
else:
result["response_preview"] = response_content[:200] if response_content else None
result["suggestions"] = [
"❌ 该模型不支持 Function Calling,无法使用 MCP 插件功能",
"建议:更换支持工具调用的模型",
"推荐模型:GPT-4 系列、GPT-4-turbo、Claude 3 Opus/Sonnet、Gemini 1.5 Pro 等",
"说明:模型返回了文本回复而非工具调用,表明不支持该功能"
]
return result
except ValueError as e:
error_msg = str(e)
logger.error(f"❌ Function Calling 检测配置错误: {error_msg}")
return {
"success": False,
"supported": False,
"message": "配置错误",
"error": error_msg,
"error_type": "ConfigurationError",
"suggestions": [
"请检查 API Key 是否正确",
"请确认 API Base URL 格式是否正确",
"请验证所选提供商与配置是否匹配"
]
}
except TimeoutError as e:
error_msg = str(e)
logger.error(f"❌ Function Calling 检测超时: {error_msg}")
return {
"success": False,
"supported": None,
"message": "检测超时",
"error": error_msg,
"error_type": "TimeoutError",
"suggestions": [
"请检查网络连接是否正常",
"请确认 API 服务是否可访问",
"建议:稍后重试或使用其他网络环境"
]
}
except Exception as e:
error_msg = str(e)
error_type = type(e).__name__
logger.error(f"❌ Function Calling 检测失败: {error_msg}")
logger.error(f" - 错误类型: {error_type}")
# 智能分析错误原因
suggestions = []
if "tool" in error_msg.lower() or "function" in error_msg.lower():
suggestions = [
"该模型可能不支持 Function Calling 功能",
"API 返回了与工具调用相关的错误",
"建议:更换支持工具调用的模型或联系 API 提供商"
]
elif "unauthorized" in error_msg.lower() or "401" in error_msg:
suggestions = [
"API Key 认证失败",
"请检查 API Key 是否正确且有效",
"请确认 API Key 是否有足够的权限"
]
elif "not found" in error_msg.lower() or "404" in error_msg:
suggestions = [
"模型不存在或不可用",
"请检查模型名称是否正确",
"请确认该模型在当前 API 中是否可用"
]
else:
suggestions = [
"检测过程中遇到未知错误",
"建议:检查所有配置参数是否正确",
"提示:查看详细错误信息以获取更多线索"
]
return {
"success": False,
"supported": False,
"message": "Function Calling 检测失败",
"error": error_msg,
"error_type": error_type,
"suggestions": suggestions
}
@router.post("/test")
async def test_api_connection(data: ApiTestRequest):
"""
@@ -370,7 +615,8 @@ async def test_api_connection(data: ApiTestRequest):
provider=provider,
model=llm_model,
temperature=0.7,
max_tokens=8000
max_tokens=8000,
auto_mcp=False # 测试时不加载MCP工具
)
end_time = time.time()
File diff suppressed because it is too large Load Diff
+2 -4
View File
@@ -77,10 +77,8 @@ class Settings(BaseSettings):
default_temperature: float = 0.7
default_max_tokens: int = 32000
# MCP适配器配置
enable_mcp_adapter: bool = True # 是否启用MCP适配器(自动检测API能力
mcp_adapter_cache_ttl_hours: int = 24 # API能力检测缓存时长(小时)
mcp_adapter_auto_fallback: bool = True # 是否启用自动降级(FC失败时切换到提示词注入)
# MCP配置
mcp_max_rounds: int = 3 # MCP工具调用最大轮数(全局统一控制
# LinuxDO OAuth2 配置
LINUXDO_CLIENT_ID: Optional[str] = None
+1 -1
View File
@@ -167,7 +167,7 @@ async def get_db(request: Request):
_session_stats["created"] += 1
_session_stats["active"] += 1
logger.debug(f"📊 会话创建 [User:{user_id}][ID:{session_id}] - 活跃:{_session_stats['active']}, 总创建:{_session_stats['created']}, 总关闭:{_session_stats['closed']}")
# logger.debug(f"📊 会话创建 [User:{user_id}][ID:{session_id}] - 活跃:{_session_stats['active']}, 总创建:{_session_stats['created']}, 总关闭:{_session_stats['closed']}")
try:
yield session
+5 -1
View File
@@ -130,11 +130,15 @@ def _configure_third_party_loggers():
logging.getLogger('sqlalchemy.dialects').setLevel(logging.WARNING)
logging.getLogger('sqlalchemy.orm').setLevel(logging.WARNING)
# aiosqlite - 异步SQLite,禁用DEBUG日志
logging.getLogger('aiosqlite').setLevel(logging.WARNING)
# Watchfiles - 开发时的文件监控,降低级别
logging.getLogger('watchfiles').setLevel(logging.WARNING)
# httpx - HTTP客户端
# httpx/httpcore - HTTP客户端,禁用DEBUG日志
logging.getLogger('httpx').setLevel(logging.WARNING)
logging.getLogger('httpcore').setLevel(logging.WARNING)
# openai/anthropic - AI客户端库
logging.getLogger('openai').setLevel(logging.WARNING)
+5 -2
View File
@@ -12,7 +12,7 @@ from app.database import close_db, _session_stats
from app.logger import setup_logging, get_logger
from app.middleware import RequestIDMiddleware
from app.middleware.auth_middleware import AuthMiddleware
from app.mcp.registry import mcp_registry
from app.mcp import mcp_client, register_status_sync
setup_logging(
level=config_settings.log_level,
@@ -27,12 +27,15 @@ logger = get_logger(__name__)
@asynccontextmanager
async def lifespan(app: FastAPI):
"""应用生命周期管理"""
# 注册MCP状态同步服务
register_status_sync()
logger.info("应用启动完成")
yield
# 清理MCP插件
await mcp_registry.cleanup_all()
await mcp_client.cleanup()
# 清理HTTP客户端池
from app.services.ai_service import cleanup_http_clients
+35 -3
View File
@@ -1,4 +1,36 @@
"""MCP插件系统"""
from .registry import mcp_registry
"""MCP模块 - 统一的MCP客户端管理
__all__ = ["mcp_registry"]
本模块提供MCPModel Context Protocol客户端的统一管理接口
推荐使用方式
from app.mcp import mcp_client, MCPPluginConfig
# 注册插件
await mcp_client.register(MCPPluginConfig(
user_id="user123",
plugin_name="exa-search",
url="http://localhost:8000/mcp"
))
# 获取工具
tools = await mcp_client.get_tools("user123", "exa-search")
# 调用工具
result = await mcp_client.call_tool("user123", "exa-search", "web_search", {"query": "..."})
# 注册状态变更回调
from app.mcp.status_sync import register_status_sync
register_status_sync()
"""
from .facade import mcp_client, MCPClientFacade, MCPPluginConfig, MCPError, PluginStatus
from .status_sync import register_status_sync
__all__ = [
"mcp_client",
"MCPClientFacade",
"MCPPluginConfig",
"MCPError",
"PluginStatus",
"register_status_sync",
]
-14
View File
@@ -1,14 +0,0 @@
"""MCP适配器模块 - 支持多种AI API的工具调用方式"""
from .base import BaseMCPAdapter, AdapterType
from .prompt_injection import PromptInjectionAdapter
from .function_calling import FunctionCallingAdapter
from .universal import UniversalMCPAdapter
__all__ = [
"BaseMCPAdapter",
"AdapterType",
"PromptInjectionAdapter",
"FunctionCallingAdapter",
"UniversalMCPAdapter",
]
-89
View File
@@ -1,89 +0,0 @@
"""MCP适配器基类"""
from abc import ABC, abstractmethod
from enum import Enum
from typing import Dict, Any, List, Optional
from dataclasses import dataclass
class AdapterType(Enum):
"""适配器类型"""
FUNCTION_CALLING = "function_calling" # 标准Function Calling
PROMPT_INJECTION = "prompt_injection" # 提示词注入
REACT = "react" # ReAct模式
XML = "xml" # XML标记
@dataclass
class ToolCallResult:
"""工具调用结果"""
tool_calls: List[Dict[str, Any]] # 解析出的工具调用
raw_response: str # 原始AI响应
has_tool_calls: bool # 是否包含工具调用
needs_continuation: bool = False # 是否需要继续对话
class BaseMCPAdapter(ABC):
"""MCP适配器基类"""
def __init__(self):
self.adapter_type: AdapterType = AdapterType.PROMPT_INJECTION
@abstractmethod
def format_tools_for_prompt(
self,
tools: List[Dict[str, Any]],
user_message: str
) -> str:
"""
将工具列表格式化为提示词
Args:
tools: MCP工具列表
user_message: 用户消息
Returns:
格式化后的提示词
"""
pass
@abstractmethod
def parse_tool_calls(self, ai_response: str) -> ToolCallResult:
"""
从AI响应中解析工具调用
Args:
ai_response: AI的原始响应
Returns:
解析结果
"""
pass
@abstractmethod
def build_continuation_prompt(
self,
original_message: str,
ai_response: str,
tool_results: List[Dict[str, Any]]
) -> str:
"""
构建包含工具结果的继续对话提示词
Args:
original_message: 原始用户消息
ai_response: AI响应
tool_results: 工具执行结果
Returns:
继续对话的提示词
"""
pass
def supports_native_tools(self) -> bool:
"""是否支持原生工具调用(如Function Calling"""
return False
def get_adapter_type(self) -> AdapterType:
"""获取适配器类型"""
return self.adapter_type
@@ -1,171 +0,0 @@
"""Function Calling适配器 - 支持原生Function Calling的API"""
import json
from typing import Dict, Any, List
from app.mcp.adapters.base import BaseMCPAdapter, AdapterType, ToolCallResult
from app.logger import get_logger
logger = get_logger(__name__)
class FunctionCallingAdapter(BaseMCPAdapter):
"""Function Calling适配器 - 用于支持原生工具调用的AI API(如OpenAI"""
def __init__(self):
super().__init__()
self.adapter_type = AdapterType.FUNCTION_CALLING
def supports_native_tools(self) -> bool:
"""支持原生工具调用"""
return True
def format_tools_for_prompt(
self,
tools: List[Dict[str, Any]],
user_message: str
) -> str:
"""
Function Calling模式下工具通过API参数传递不需要修改提示词
Returns:
原始用户消息
"""
return user_message
def get_tools_for_api(self, tools: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""
获取适用于API的工具格式
Args:
tools: MCP工具列表
Returns:
适用于OpenAI Function Calling的工具格式
"""
return tools
def parse_tool_calls(self, ai_response: Any) -> ToolCallResult:
"""
从AI响应中解析工具调用Function Calling格式
Args:
ai_response: AI响应对象通常是OpenAI的ChatCompletion对象
Returns:
解析结果
"""
try:
# 处理不同类型的响应
if isinstance(ai_response, dict):
# 字典格式(OpenAI API响应)
message = ai_response.get("choices", [{}])[0].get("message", {})
tool_calls = message.get("tool_calls", [])
content = message.get("content", "")
elif hasattr(ai_response, "choices"):
# 对象格式(OpenAI SDK响应)
message = ai_response.choices[0].message
tool_calls = getattr(message, "tool_calls", None) or []
content = getattr(message, "content", "") or ""
# 转换为字典格式
if tool_calls:
tool_calls = [
{
"id": tc.id,
"type": tc.type,
"function": {
"name": tc.function.name,
"arguments": tc.function.arguments
}
}
for tc in tool_calls
]
else:
# 字符串格式(降级为文本响应)
return ToolCallResult(
tool_calls=[],
raw_response=str(ai_response),
has_tool_calls=False
)
has_tool_calls = len(tool_calls) > 0
if has_tool_calls:
logger.info(f"✅ Function Calling模式解析出 {len(tool_calls)} 个工具调用")
for tc in tool_calls:
logger.info(f" - {tc['function']['name']}")
return ToolCallResult(
tool_calls=tool_calls,
raw_response=content or "",
has_tool_calls=has_tool_calls,
needs_continuation=has_tool_calls
)
except Exception as e:
logger.error(f"❌ 解析Function Calling响应失败: {e}", exc_info=True)
return ToolCallResult(
tool_calls=[],
raw_response=str(ai_response),
has_tool_calls=False
)
def build_continuation_prompt(
self,
original_message: str,
ai_response: str,
tool_results: List[Dict[str, Any]]
) -> str:
"""
构建包含工具结果的继续对话提示词
在Function Calling模式下这通常不需要因为工具结果会作为消息历史的一部分
"""
# Function Calling模式下通常通过消息历史传递工具结果
# 这里提供一个降级方案
results_text = "\n\n".join([
f"工具 {r['name']} 的结果:\n{r['content']}"
for r in tool_results
])
return f"{original_message}\n\n工具执行结果:\n{results_text}\n\n请基于以上工具结果回答用户的问题。"
def build_messages_with_tool_results(
self,
messages: List[Dict[str, Any]],
tool_calls: List[Dict[str, Any]],
tool_results: List[Dict[str, Any]]
) -> List[Dict[str, Any]]:
"""
构建包含工具结果的消息历史Function Calling标准格式
Args:
messages: 原始消息历史
tool_calls: AI的工具调用
tool_results: 工具执行结果
Returns:
更新后的消息历史
"""
new_messages = messages.copy()
# 添加助手的工具调用消息
new_messages.append({
"role": "assistant",
"content": None,
"tool_calls": tool_calls
})
# 添加工具结果消息
for result in tool_results:
new_messages.append({
"role": "tool",
"tool_call_id": result.get("tool_call_id", ""),
"name": result.get("name", ""),
"content": result.get("content", "")
})
return new_messages
@@ -1,274 +0,0 @@
"""提示词注入适配器 - 最通用的MCP工具调用方式"""
import re
import json
from typing import Dict, Any, List
from app.mcp.adapters.base import BaseMCPAdapter, AdapterType, ToolCallResult
from app.logger import get_logger
logger = get_logger(__name__)
class PromptInjectionAdapter(BaseMCPAdapter):
"""提示词注入适配器 - 将工具转换为文本描述,通过提示词引导AI调用"""
def __init__(self):
super().__init__()
self.adapter_type = AdapterType.PROMPT_INJECTION
def format_tools_for_prompt(
self,
tools: List[Dict[str, Any]],
user_message: str
) -> str:
"""将工具列表注入到提示词中"""
if not tools:
return user_message
# 格式化工具描述
tool_descriptions = self._format_tools_as_text(tools)
# 构建增强的提示词
enhanced_prompt = f"""你现在可以使用以下工具来帮助回答用户的问题。
## 可用工具
{tool_descriptions}
## 工具使用说明
当你需要使用工具时请按以下XML格式输出可以一次调用多个工具
<tool_calls>
<tool_call>
<tool_name>工具名称</tool_name>
<arguments>
{{
"参数名1": "参数值1",
"参数名2": "参数值2"
}}
</arguments>
</tool_call>
</tool_calls>
## 重要提示
1. 只有在确实需要使用工具时才调用工具
2. 参数必须是有效的JSON格式
3. 仔细检查参数是否符合工具的要求
4. 可以在一个<tool_calls>标签内包含多个<tool_call>
5. 调用工具后你会收到工具的执行结果然后需要基于结果继续回答
---
用户问题{user_message}
请分析问题判断是否需要使用工具如果需要先输出工具调用然后等待结果如果不需要直接回答问题"""
return enhanced_prompt
def _format_tools_as_text(self, tools: List[Dict[str, Any]]) -> str:
"""将工具格式化为可读的文本描述"""
lines = []
for i, tool in enumerate(tools, 1):
func = tool.get("function", {})
name = func.get("name", "unknown")
description = func.get("description", "无描述")
parameters = func.get("parameters", {})
lines.append(f"### {i}. {name}")
lines.append(f"**描述**: {description}")
lines.append("")
# 格式化参数信息
if parameters and "properties" in parameters:
lines.append("**参数**:")
properties = parameters.get("properties", {})
required = parameters.get("required", [])
for param_name, param_info in properties.items():
param_type = param_info.get("type", "string")
param_desc = param_info.get("description", "")
is_required = "必填" if param_name in required else "可选"
lines.append(f" - `{param_name}` ({param_type}, {is_required}): {param_desc}")
lines.append("")
# 添加示例
if "example" in func:
lines.append(f"**示例**: {json.dumps(func['example'], ensure_ascii=False)}")
lines.append("")
return "\n".join(lines)
def parse_tool_calls(self, ai_response) -> ToolCallResult:
"""从AI响应中解析工具调用"""
tool_calls = []
try:
# 处理不同类型的响应
if isinstance(ai_response, dict):
# 如果是字典,提取content字段
ai_response = ai_response.get("choices", [{}])[0].get("message", {}).get("content", "")
if not ai_response:
return ToolCallResult(
tool_calls=[],
raw_response="",
has_tool_calls=False
)
elif not isinstance(ai_response, str):
# 转换为字符串
ai_response = str(ai_response)
# 使用正则提取 <tool_calls> 标签内容
tool_calls_match = re.search(
r'<tool_calls>(.*?)</tool_calls>',
ai_response,
re.DOTALL | re.IGNORECASE
)
if not tool_calls_match:
# 没有找到工具调用
return ToolCallResult(
tool_calls=[],
raw_response=ai_response,
has_tool_calls=False
)
tool_calls_content = tool_calls_match.group(1)
# 提取所有 <tool_call> 标签
tool_call_pattern = r'<tool_call>(.*?)</tool_call>'
tool_call_matches = re.findall(
tool_call_pattern,
tool_calls_content,
re.DOTALL | re.IGNORECASE
)
for i, tool_call_content in enumerate(tool_call_matches):
# 提取工具名称
name_match = re.search(
r'<tool_name>(.*?)</tool_name>',
tool_call_content,
re.DOTALL | re.IGNORECASE
)
# 提取参数
args_match = re.search(
r'<arguments>(.*?)</arguments>',
tool_call_content,
re.DOTALL | re.IGNORECASE
)
if name_match and args_match:
tool_name = name_match.group(1).strip()
arguments_str = args_match.group(1).strip()
try:
# 解析JSON参数
arguments = json.loads(arguments_str)
# 构建标准格式的工具调用
tool_calls.append({
"id": f"call_{i}",
"type": "function",
"function": {
"name": tool_name,
"arguments": json.dumps(arguments, ensure_ascii=False)
}
})
logger.info(f"✅ 解析工具调用: {tool_name}")
except json.JSONDecodeError as e:
logger.error(f"❌ 解析工具参数失败: {arguments_str}, 错误: {e}")
continue
has_tool_calls = len(tool_calls) > 0
if has_tool_calls:
logger.info(f"✅ 从响应中解析出 {len(tool_calls)} 个工具调用")
return ToolCallResult(
tool_calls=tool_calls,
raw_response=ai_response,
has_tool_calls=has_tool_calls,
needs_continuation=has_tool_calls
)
except Exception as e:
logger.error(f"❌ 解析工具调用失败: {e}", exc_info=True)
return ToolCallResult(
tool_calls=[],
raw_response=ai_response,
has_tool_calls=False
)
def build_continuation_prompt(
self,
original_message: str,
ai_response: str,
tool_results: List[Dict[str, Any]]
) -> str:
"""构建包含工具结果的继续对话提示词"""
# 格式化工具结果
results_text = self._format_tool_results(tool_results)
continuation = f"""你之前尝试使用工具来回答用户的问题。
原始问题{original_message}
你的工具调用
{self._extract_tool_calls_text(ai_response)}
工具执行结果
{results_text}
现在请基于这些工具的执行结果给出完整详细的回答不要重复调用工具直接使用已有的结果来回答用户的问题"""
return continuation
def _format_tool_results(self, tool_results: List[Dict[str, Any]]) -> str:
"""格式化工具结果为可读文本"""
lines = []
for i, result in enumerate(tool_results, 1):
tool_name = result.get("name", "unknown")
success = result.get("success", False)
content = result.get("content", "")
status = "✅ 成功" if success else "❌ 失败"
lines.append(f"{i}. {tool_name} - {status}")
if success:
# 尝试美化JSON内容
try:
if isinstance(content, str):
content_obj = json.loads(content)
content = json.dumps(content_obj, ensure_ascii=False, indent=2)
except:
pass
lines.append(f"```\n{content}\n```")
else:
error = result.get("error", "未知错误")
lines.append(f"错误信息: {error}")
lines.append("")
return "\n".join(lines)
def _extract_tool_calls_text(self, ai_response: str) -> str:
"""从AI响应中提取工具调用部分的文本"""
match = re.search(
r'<tool_calls>(.*?)</tool_calls>',
ai_response,
re.DOTALL | re.IGNORECASE
)
if match:
return match.group(0)
return "(未找到工具调用)"
-353
View File
@@ -1,353 +0,0 @@
"""通用MCP适配器 - 自动检测API能力并选择最佳适配器"""
import time
import asyncio
from typing import Dict, Any, List, Optional
from datetime import datetime, timedelta
from dataclasses import dataclass
from app.mcp.adapters.base import BaseMCPAdapter, AdapterType, ToolCallResult
from app.mcp.adapters.prompt_injection import PromptInjectionAdapter
from app.mcp.adapters.function_calling import FunctionCallingAdapter
from app.logger import get_logger
logger = get_logger(__name__)
@dataclass
class APICapability:
"""API能力检测结果"""
supports_function_calling: bool
tested_at: datetime
test_duration_ms: float
error_message: Optional[str] = None
class UniversalMCPAdapter:
"""
通用MCP适配器管理器
功能
1. 自动检测API是否支持Function Calling
2. 缓存检测结果
3. 自动降级策略FC失败时切换到提示词注入
4. 提供统一接口
"""
def __init__(
self,
cache_ttl_hours: int = 24,
enable_auto_fallback: bool = True
):
"""
初始化通用适配器
Args:
cache_ttl_hours: 能力检测缓存时长小时
enable_auto_fallback: 是否启用自动降级
"""
# 适配器实例
self.adapters = {
AdapterType.FUNCTION_CALLING: FunctionCallingAdapter(),
AdapterType.PROMPT_INJECTION: PromptInjectionAdapter()
}
# API能力缓存: {api_identifier: APICapability}
self._capability_cache: Dict[str, APICapability] = {}
self._cache_ttl = timedelta(hours=cache_ttl_hours)
self._cache_lock = asyncio.Lock()
# 配置
self._enable_auto_fallback = enable_auto_fallback
logger.info(
f"✅ UniversalMCPAdapter初始化完成 "
f"(缓存TTL={cache_ttl_hours}小时, 自动降级={'开启' if enable_auto_fallback else '关闭'})"
)
async def get_adapter(
self,
api_identifier: str,
test_function: Optional[callable] = None
) -> BaseMCPAdapter:
"""
获取适合当前API的适配器
Args:
api_identifier: API标识符"openai_official", "azure_openai"
test_function: 可选的测试函数用于检测API能力
Returns:
最适合的适配器实例
"""
# 检查缓存
capability = await self._get_cached_capability(api_identifier)
if capability is None and test_function:
# 缓存未命中,执行检测
capability = await self._detect_capability(api_identifier, test_function)
# 选择适配器
if capability and capability.supports_function_calling:
logger.info(f"🎯 使用Function Calling适配器: {api_identifier}")
return self.adapters[AdapterType.FUNCTION_CALLING]
else:
logger.info(f"🎯 使用提示词注入适配器: {api_identifier}")
return self.adapters[AdapterType.PROMPT_INJECTION]
async def _get_cached_capability(
self,
api_identifier: str
) -> Optional[APICapability]:
"""获取缓存的能力检测结果"""
async with self._cache_lock:
if api_identifier not in self._capability_cache:
return None
capability = self._capability_cache[api_identifier]
# 检查是否过期
if datetime.now() - capability.tested_at > self._cache_ttl:
logger.info(f"⏰ API能力缓存过期: {api_identifier}")
del self._capability_cache[api_identifier]
return None
logger.debug(f"🎯 API能力缓存命中: {api_identifier}")
return capability
async def _detect_capability(
self,
api_identifier: str,
test_function: callable
) -> APICapability:
"""
检测API能力
Args:
api_identifier: API标识符
test_function: 测试函数应该尝试使用Function Calling
Returns:
能力检测结果
"""
logger.info(f"🔍 开始检测API能力: {api_identifier}")
start_time = time.time()
try:
# 调用测试函数
result = await test_function()
# 判断是否成功
supports_fc = self._is_function_calling_response(result)
duration_ms = (time.time() - start_time) * 1000
capability = APICapability(
supports_function_calling=supports_fc,
tested_at=datetime.now(),
test_duration_ms=duration_ms
)
# 缓存结果
async with self._cache_lock:
self._capability_cache[api_identifier] = capability
status = "✅ 支持" if supports_fc else "❌ 不支持"
logger.info(
f"{status} Function Calling: {api_identifier} "
f"(耗时: {duration_ms:.2f}ms)"
)
return capability
except Exception as e:
duration_ms = (time.time() - start_time) * 1000
logger.warning(
f"⚠️ API能力检测失败: {api_identifier}, 错误: {e}, "
f"将使用提示词注入模式"
)
capability = APICapability(
supports_function_calling=False,
tested_at=datetime.now(),
test_duration_ms=duration_ms,
error_message=str(e)
)
# 缓存失败结果(避免重复测试)
async with self._cache_lock:
self._capability_cache[api_identifier] = capability
return capability
def _is_function_calling_response(self, response: Any) -> bool:
"""
判断响应是否是Function Calling格式
Args:
response: API响应
Returns:
是否支持Function Calling
"""
try:
# 检查字典格式
if isinstance(response, dict):
message = response.get("choices", [{}])[0].get("message", {})
return "tool_calls" in message or "function_call" in message
# 检查对象格式(OpenAI SDK
if hasattr(response, "choices"):
message = response.choices[0].message
return hasattr(message, "tool_calls") or hasattr(message, "function_call")
return False
except Exception:
return False
async def call_with_fallback(
self,
api_identifier: str,
tools: List[Dict[str, Any]],
user_message: str,
call_function: callable,
test_function: Optional[callable] = None
) -> ToolCallResult:
"""
带降级策略的工具调用
Args:
api_identifier: API标识符
tools: MCP工具列表
user_message: 用户消息
call_function: 实际调用API的函数
test_function: 可选的测试函数
Returns:
工具调用结果
"""
# 获取适配器
adapter = await self.get_adapter(api_identifier, test_function)
# 首次尝试
try:
if adapter.supports_native_tools():
# Function Calling模式
logger.info("🚀 尝试使用Function Calling模式")
result = await self._try_function_calling(
tools, user_message, call_function, adapter
)
else:
# 提示词注入模式
logger.info("🚀 使用提示词注入模式")
result = await self._try_prompt_injection(
tools, user_message, call_function, adapter
)
return result
except Exception as e:
logger.error(f"❌ 工具调用失败: {e}")
# 自动降级
if self._enable_auto_fallback and adapter.supports_native_tools():
logger.warning("⚠️ Function Calling失败,降级到提示词注入模式")
# 更新缓存,标记为不支持
async with self._cache_lock:
self._capability_cache[api_identifier] = APICapability(
supports_function_calling=False,
tested_at=datetime.now(),
test_duration_ms=0,
error_message=str(e)
)
# 使用提示词注入重试
fallback_adapter = self.adapters[AdapterType.PROMPT_INJECTION]
return await self._try_prompt_injection(
tools, user_message, call_function, fallback_adapter
)
raise
async def _try_function_calling(
self,
tools: List[Dict[str, Any]],
user_message: str,
call_function: callable,
adapter: FunctionCallingAdapter
) -> ToolCallResult:
"""尝试Function Calling模式"""
# Function Calling不需要修改提示词
response = await call_function(
message=user_message,
tools_param=tools,
tool_choice_param="auto"
)
return adapter.parse_tool_calls(response)
async def _try_prompt_injection(
self,
tools: List[Dict[str, Any]],
user_message: str,
call_function: callable,
adapter: PromptInjectionAdapter
) -> ToolCallResult:
"""尝试提示词注入模式"""
# 注入工具到提示词
enhanced_prompt = adapter.format_tools_for_prompt(tools, user_message)
# 调用API(不传tools参数)
response = await call_function(
message=enhanced_prompt,
tools_param=None,
tool_choice_param=None
)
# 从文本响应中解析工具调用
return adapter.parse_tool_calls(response)
def clear_cache(self, api_identifier: Optional[str] = None):
"""
清理能力缓存
Args:
api_identifier: 可选只清理特定API的缓存
"""
if api_identifier:
if api_identifier in self._capability_cache:
del self._capability_cache[api_identifier]
logger.info(f"🧹 已清理API能力缓存: {api_identifier}")
else:
self._capability_cache.clear()
logger.info("🧹 已清理所有API能力缓存")
def get_cache_stats(self) -> Dict[str, Any]:
"""获取缓存统计信息"""
return {
"total_cached": len(self._capability_cache),
"cache_ttl_hours": self._cache_ttl.total_seconds() / 3600,
"cached_apis": [
{
"api_identifier": api_id,
"supports_fc": cap.supports_function_calling,
"tested_at": cap.tested_at.isoformat(),
"test_duration_ms": cap.test_duration_ms
}
for api_id, cap in self._capability_cache.items()
]
}
# 全局单例
universal_mcp_adapter = UniversalMCPAdapter()
File diff suppressed because it is too large Load Diff
-385
View File
@@ -1,385 +0,0 @@
"""HTTP MCP客户端 - 使用官方 MCP Python SDK 实现"""
import asyncio
from typing import Dict, Any, List, Optional
from contextlib import asynccontextmanager
from mcp import ClientSession, types
from mcp.client.streamable_http import streamablehttp_client
from pydantic import AnyUrl
from anyio import ClosedResourceError
from app.logger import get_logger
logger = get_logger(__name__)
class MCPError(Exception):
"""MCP错误"""
pass
class HTTPMCPClient:
"""HTTP模式MCP客户端(基于官方 MCP Python SDK"""
def __init__(
self,
url: str,
headers: Optional[Dict[str, str]] = None,
env: Optional[Dict[str, str]] = None,
timeout: float = 60.0
):
"""
初始化HTTP MCP客户端
Args:
url: MCP服务器URL
headers: HTTP请求头
env: 环境变量用于API Key等
timeout: 超时时间
"""
self.url = url.rstrip('/')
self.headers = headers or {}
self.env = env or {}
self.timeout = timeout
# 如果env中有API Key,添加到headers
if 'API_KEY' in self.env:
self.headers['Authorization'] = f'Bearer {self.env["API_KEY"]}'
self._session: Optional[ClientSession] = None
self._context_stack = [] # 保存上下文管理器栈
self._initialized = False
self._lock = asyncio.Lock()
async def _ensure_connected(self):
"""确保连接已建立"""
async with self._lock:
if self._session is None:
try:
logger.info(f"🔗 连接到MCP服务器: {self.url}")
# 使用官方 SDK 的 streamable_http_client
# 保存上下文管理器以便后续正确清理
stream_context = streamablehttp_client(self.url)
read_stream, write_stream, _ = await stream_context.__aenter__()
self._context_stack.append(('stream', stream_context))
# 创建客户端会话
self._session = ClientSession(read_stream, write_stream)
session_context = self._session
await session_context.__aenter__()
self._context_stack.append(('session', session_context))
# 初始化会话
await self._session.initialize()
self._initialized = True
logger.info(f"✅ MCP会话初始化成功")
except Exception as e:
logger.error(f"❌ MCP连接失败: {e}")
await self._cleanup()
raise MCPError(f"连接MCP服务器失败: {str(e)}")
async def _cleanup(self):
"""清理连接资源(按照进入的相反顺序退出)"""
# 按照LIFO顺序清理上下文
while self._context_stack:
ctx_type, ctx = self._context_stack.pop()
try:
await ctx.__aexit__(None, None, None)
except RuntimeError as e:
# 忽略 anyio 的任务上下文错误(在关闭时可能发生)
if "cancel scope" in str(e).lower() or "different task" in str(e).lower():
logger.debug(f"忽略{ctx_type}上下文清理的任务切换警告: {e}")
else:
logger.error(f"清理{ctx_type}上下文失败: {e}")
except Exception as e:
logger.error(f"清理{ctx_type}上下文失败: {e}")
self._session = None
self._initialized = False
async def initialize(self) -> Dict[str, Any]:
"""
初始化MCP会话
Returns:
初始化响应
"""
await self._ensure_connected()
return {"status": "initialized"}
async def list_tools(self) -> List[Dict[str, Any]]:
"""
列举可用工具
Returns:
工具列表
"""
try:
await self._ensure_connected()
result = await self._session.list_tools()
# 转换为字典格式
tools = []
for tool in result.tools:
tool_dict = {
"name": tool.name,
"description": tool.description or "",
"inputSchema": tool.inputSchema
}
tools.append(tool_dict)
logger.info(f"获取到 {len(tools)} 个工具")
return tools
except Exception as e:
logger.error(f"获取工具列表失败: {e}")
raise MCPError(f"获取工具列表失败: {str(e)}")
async def call_tool(
self,
tool_name: str,
arguments: Dict[str, Any],
max_reconnect_attempts: int = 2
) -> Any:
"""
调用工具带自动重连
Args:
tool_name: 工具名称
arguments: 工具参数
max_reconnect_attempts: 最大重连尝试次数
Returns:
工具执行结果
"""
for attempt in range(max_reconnect_attempts + 1):
try:
await self._ensure_connected()
logger.info(f"调用工具: {tool_name}")
logger.debug(f" 参数类型: {type(arguments)}")
logger.debug(f" 参数内容: {arguments}")
logger.debug(f" 会话状态: initialized={self._initialized}, session={self._session is not None}")
result = await self._session.call_tool(tool_name, arguments)
logger.debug(f" 工具返回类型: {type(result)}")
logger.debug(f" 返回内容: {result}")
# 处理返回结果
# MCP SDK 返回 CallToolResult 对象
if result.content:
logger.debug(f" 返回content数量: {len(result.content)}")
# 提取第一个content的文本
for idx, content in enumerate(result.content):
logger.debug(f" content[{idx}]类型: {type(content)}")
if isinstance(content, types.TextContent):
logger.debug(f" ✅ 返回TextContent: {content.text[:100] if len(content.text) > 100 else content.text}")
return content.text
elif isinstance(content, types.ImageContent):
logger.debug(f" ✅ 返回ImageContent")
return {
"type": "image",
"data": content.data,
"mimeType": content.mimeType
}
# 如果没有文本内容,返回原始内容
logger.debug(f" ⚠️ 返回原始content[0]")
return result.content[0] if result.content else None
# 如果有结构化内容(2025-06-18规范)
if hasattr(result, 'structuredContent') and result.structuredContent:
logger.debug(f" ✅ 返回structuredContent")
return result.structuredContent
logger.warning(f" ⚠️ 工具返回为None")
return None
except ClosedResourceError as e:
# 连接已关闭,尝试重连
if attempt < max_reconnect_attempts:
logger.warning(
f"⚠️ MCP连接已关闭,尝试重新连接 "
f"(第{attempt + 1}/{max_reconnect_attempts}次重连)"
)
await self._cleanup()
await asyncio.sleep(0.5) # 短暂延迟后重连
continue
else:
logger.error(f"❌ MCP连接重连失败,已达最大重试次数")
error_msg = f"连接已关闭且重连失败 (尝试了{max_reconnect_attempts}次)"
raise MCPError(error_msg)
except Exception as e:
logger.error(f"调用工具失败: {tool_name}, 错误: {e}", exc_info=True)
logger.error(f" 参数: {arguments}")
logger.error(f" 错误类型: {type(e).__name__}")
logger.error(f" 错误详情: {repr(e)}")
logger.error(f" 错误字符串: '{str(e)}'")
error_msg = str(e) or repr(e) or f"未知错误 ({type(e).__name__})"
raise MCPError(f"调用工具失败: {error_msg}")
# 理论上不会到这里
raise MCPError(f"工具调用失败: 未知错误")
async def list_resources(self) -> List[Dict[str, Any]]:
"""
列举可用资源
Returns:
资源列表
"""
try:
await self._ensure_connected()
result = await self._session.list_resources()
# 转换为字典格式
resources = []
for resource in result.resources:
resource_dict = {
"uri": str(resource.uri),
"name": resource.name,
"description": resource.description or "",
"mimeType": resource.mimeType or ""
}
resources.append(resource_dict)
logger.info(f"获取到 {len(resources)} 个资源")
return resources
except Exception as e:
logger.error(f"获取资源列表失败: {e}")
raise MCPError(f"获取资源列表失败: {str(e)}")
async def read_resource(self, uri: str) -> Any:
"""
读取资源
Args:
uri: 资源URI
Returns:
资源内容
"""
try:
await self._ensure_connected()
result = await self._session.read_resource(AnyUrl(uri))
# 提取资源内容
if result.contents:
content = result.contents[0]
if isinstance(content, types.TextContent):
return content.text
elif isinstance(content, types.ImageContent):
return {
"type": "image",
"data": content.data,
"mimeType": content.mimeType
}
elif isinstance(content, types.BlobResourceContents):
return {
"type": "blob",
"blob": content.blob,
"mimeType": content.mimeType
}
return None
except Exception as e:
logger.error(f"读取资源失败: {uri}, 错误: {e}")
raise MCPError(f"读取资源失败: {str(e)}")
async def test_connection(self) -> Dict[str, Any]:
"""
测试连接
Returns:
测试结果
"""
import time
start_time = time.time()
try:
# 尝试连接并列举工具(直接调用SDK,避免重复日志)
await self._ensure_connected()
result = await self._session.list_tools()
# 转换为字典格式
tools = []
for tool in result.tools:
tool_dict = {
"name": tool.name,
"description": tool.description or "",
"inputSchema": tool.inputSchema
}
tools.append(tool_dict)
end_time = time.time()
response_time = round((end_time - start_time) * 1000, 2)
logger.info(f"✅ 连接测试成功,获取到 {len(tools)} 个工具")
return {
"success": True,
"message": "连接测试成功",
"response_time_ms": response_time,
"tools_count": len(tools),
"tools": tools
}
except Exception as e:
end_time = time.time()
response_time = round((end_time - start_time) * 1000, 2)
return {
"success": False,
"message": "连接测试失败",
"response_time_ms": response_time,
"error": str(e),
"error_type": type(e).__name__,
"suggestions": [
"请检查服务器URL是否正确",
"请确认API Key是否有效",
"请检查网络连接",
"请确认MCP服务器是否在线"
]
}
async def close(self):
"""关闭客户端连接"""
logger.info(f"关闭MCP客户端: {self.url}")
await self._cleanup()
@asynccontextmanager
async def create_mcp_client(
url: str,
headers: Optional[Dict[str, str]] = None,
env: Optional[Dict[str, str]] = None,
timeout: float = 60.0
):
"""
创建MCP客户端的上下文管理器
Args:
url: MCP服务器URL
headers: HTTP请求头
env: 环境变量
timeout: 超时时间
Yields:
HTTPMCPClient实例
"""
client = HTTPMCPClient(url, headers, env, timeout)
try:
await client.initialize()
yield client
finally:
await client.close()
-527
View File
@@ -1,527 +0,0 @@
"""MCP插件注册表 - 管理运行时插件实例"""
import asyncio
import time
from typing import Dict, Optional, Any, List
from dataclasses import dataclass
from datetime import datetime
from app.mcp.http_client import HTTPMCPClient, MCPError
from app.mcp.config import mcp_config
from app.models.mcp_plugin import MCPPlugin
from app.logger import get_logger
logger = get_logger(__name__)
@dataclass
class SessionInfo:
"""会话信息"""
client: HTTPMCPClient
created_at: float
last_access: float
request_count: int = 0
error_count: int = 0
status: str = "active" # active, degraded, error
class MCPPluginRegistry:
"""MCP插件注册表 - 管理运行时插件实例(优化版)"""
def __init__(
self,
max_clients: Optional[int] = None,
client_ttl: Optional[int] = None
):
"""
初始化注册表
Args:
max_clients: 最大缓存客户端数量默认使用配置
client_ttl: 客户端过期时间默认使用配置
"""
# 存储格式: {plugin_id: SessionInfo}
self._sessions: Dict[str, SessionInfo] = {}
# 全局锁用于保护会话字典
self._sessions_lock = asyncio.Lock()
# 细粒度锁:每个用户一个锁
self._user_locks: Dict[str, asyncio.Lock] = {}
self._locks_lock = asyncio.Lock() # 保护locks字典本身
# 配置参数(使用配置常量)
self._max_clients = max_clients or mcp_config.MAX_CLIENTS
self._client_ttl = client_ttl or mcp_config.CLIENT_TTL_SECONDS
# 启动后台清理任务
self._cleanup_task = None
self._health_check_task = None
self._tasks_started = False
def _ensure_background_tasks(self):
"""确保后台任务已启动(延迟初始化)"""
if not self._tasks_started:
try:
# 检查是否有运行中的事件循环
loop = asyncio.get_running_loop()
if self._cleanup_task is None:
self._cleanup_task = asyncio.create_task(self._cleanup_loop())
logger.info("✅ MCP插件注册表后台清理任务已启动")
if self._health_check_task is None:
self._health_check_task = asyncio.create_task(self._health_check_loop())
logger.info("✅ MCP会话健康检查任务已启动")
self._tasks_started = True
except RuntimeError:
# 没有运行中的事件循环,稍后再试
pass
async def _cleanup_loop(self):
"""后台清理过期客户端"""
while True:
try:
await asyncio.sleep(mcp_config.CLEANUP_INTERVAL_SECONDS)
await self._cleanup_expired_sessions()
except asyncio.CancelledError:
break
except Exception as e:
logger.error(f"清理任务异常: {e}")
async def _health_check_loop(self):
"""后台健康检查"""
while True:
try:
await asyncio.sleep(mcp_config.HEALTH_CHECK_INTERVAL_SECONDS)
await self._check_session_health()
except asyncio.CancelledError:
break
except Exception as e:
logger.error(f"健康检查任务异常: {e}")
async def _cleanup_expired_sessions(self):
"""清理过期的会话"""
now = time.time()
expired_ids = []
async with self._sessions_lock:
# 收集过期的plugin_id
for plugin_id, session in list(self._sessions.items()):
if now - session.last_access > self._client_ttl:
expired_ids.append(plugin_id)
if expired_ids:
logger.info(f"🧹 清理 {len(expired_ids)} 个过期的MCP会话")
for plugin_id in expired_ids:
# 提取user_id来获取对应的锁
user_id = plugin_id.split(':', 1)[0]
user_lock = await self._get_user_lock(user_id)
async with user_lock:
async with self._sessions_lock:
if plugin_id in self._sessions:
await self._unload_plugin_unsafe(plugin_id)
async def _check_session_health(self):
"""增强的会话健康检查"""
async with self._sessions_lock:
for plugin_id, session in list(self._sessions.items()):
# 计算错误率
if session.request_count > mcp_config.MIN_REQUESTS_FOR_HEALTH_CHECK:
error_rate = session.error_count / session.request_count
# 动态调整状态(使用配置常量)
if error_rate > mcp_config.ERROR_RATE_CRITICAL:
if session.status != "error":
session.status = "error"
logger.error(
f"❌ 会话 {plugin_id} 错误率过高 "
f"({error_rate:.1%}), 标记为error"
)
elif error_rate > mcp_config.ERROR_RATE_WARNING:
if session.status == "active":
session.status = "degraded"
logger.warning(
f"⚠️ 会话 {plugin_id} 健康状况下降 "
f"(错误率: {error_rate:.1%})"
)
elif session.status == "degraded":
# 错误率降低,恢复正常
session.status = "active"
logger.info(f"✅ 会话 {plugin_id} 恢复正常")
# 检查即将过期的会话(最后1分钟提醒)
idle_time = time.time() - session.last_access
time_until_expiry = self._client_ttl - idle_time
# 仅在最后1分钟(60秒)内提醒一次
if 0 < time_until_expiry <= 60:
# 使用会话属性避免重复提醒
if not hasattr(session, '_expiry_warned') or not session._expiry_warned:
logger.warning(
f"⏰ 会话 {plugin_id} 即将过期 "
f"(剩余 {time_until_expiry:.0f} 秒)"
)
session._expiry_warned = True
elif time_until_expiry > 60:
# 重置警告标志(如果会话被重新使用)
if hasattr(session, '_expiry_warned'):
session._expiry_warned = False
async def _get_user_lock(self, user_id: str) -> asyncio.Lock:
"""
获取用户专属的锁细粒度锁
Args:
user_id: 用户ID
Returns:
该用户的锁对象
"""
async with self._locks_lock:
if user_id not in self._user_locks:
self._user_locks[user_id] = asyncio.Lock()
return self._user_locks[user_id]
def _touch_session(self, plugin_id: str):
"""
更新会话的最后访问时间需要在锁内调用
Args:
plugin_id: 插件ID
"""
if plugin_id in self._sessions:
session = self._sessions[plugin_id]
session.last_access = time.time()
session.request_count += 1
async def _evict_lru_session(self):
"""驱逐最久未使用的会话(当达到max_clients限制时)"""
if len(self._sessions) >= self._max_clients:
# 找到最旧的会话
oldest_id = None
oldest_time = float('inf')
for plugin_id, session in self._sessions.items():
if session.last_access < oldest_time:
oldest_time = session.last_access
oldest_id = plugin_id
if oldest_id:
logger.info(f"📤 达到最大会话数量限制,驱逐: {oldest_id}")
await self._unload_plugin_unsafe(oldest_id)
async def load_plugin(self, plugin: MCPPlugin) -> bool:
"""
从配置加载插件
Args:
plugin: 插件配置
Returns:
是否加载成功
"""
# 确保后台任务已启动
self._ensure_background_tasks()
# 使用细粒度锁(只锁定当前用户)
user_lock = await self._get_user_lock(plugin.user_id)
async with user_lock:
try:
plugin_id = f"{plugin.user_id}:{plugin.plugin_name}"
# 如果已加载,先卸载
async with self._sessions_lock:
if plugin_id in self._sessions:
await self._unload_plugin_unsafe(plugin_id)
# 检查是否需要驱逐LRU会话
await self._evict_lru_session()
# 目前只支持HTTP类型
if plugin.plugin_type == "http":
if not plugin.server_url:
logger.error(f"HTTP插件缺少server_url: {plugin.plugin_name}")
return False
# 为每个插件创建独立的HTTP客户端
client = HTTPMCPClient(
url=plugin.server_url,
headers=plugin.headers or {},
env=plugin.env or {},
timeout=plugin.config.get('timeout', 60.0) if plugin.config else 60.0
)
# 创建会话信息
now = time.time()
session = SessionInfo(
client=client,
created_at=now,
last_access=now,
request_count=0,
error_count=0,
status="active"
)
# 存储会话
async with self._sessions_lock:
self._sessions[plugin_id] = session
logger.info(f"✅ 加载MCP插件: {plugin_id} (独立会话)")
return True
else:
logger.warning(f"暂不支持的插件类型: {plugin.plugin_type}")
return False
except Exception as e:
logger.error(f"加载插件失败 {plugin.plugin_name}: {e}")
return False
async def unload_plugin(self, user_id: str, plugin_name: str):
"""
卸载插件
Args:
user_id: 用户ID
plugin_name: 插件名称
"""
# 使用细粒度锁(只锁定当前用户)
user_lock = await self._get_user_lock(user_id)
async with user_lock:
plugin_id = f"{user_id}:{plugin_name}"
async with self._sessions_lock:
await self._unload_plugin_unsafe(plugin_id)
async def _unload_plugin_unsafe(self, plugin_id: str):
"""卸载插件(不加锁,内部使用,需要在sessions_lock内调用)"""
if plugin_id in self._sessions:
session = self._sessions[plugin_id]
try:
await session.client.close()
except Exception as e:
logger.error(f"关闭插件客户端失败 {plugin_id}: {e}")
del self._sessions[plugin_id]
logger.info(f"卸载MCP插件: {plugin_id}")
async def reload_plugin(self, plugin: MCPPlugin) -> bool:
"""
重新加载插件
Args:
plugin: 插件配置
Returns:
是否重载成功
"""
await self.unload_plugin(plugin.user_id, plugin.plugin_name)
return await self.load_plugin(plugin)
def get_client(self, user_id: str, plugin_name: str) -> Optional[HTTPMCPClient]:
"""
获取插件客户端线程安全支持访问时间更新
Args:
user_id: 用户ID
plugin_name: 插件名称
Returns:
客户端实例或None
"""
plugin_id = f"{user_id}:{plugin_name}"
session = self._sessions.get(plugin_id)
if session:
# 检查会话状态
if session.status == "error":
logger.warning(
f"⚠️ 会话 {plugin_id} 处于错误状态,"
f"建议调用者重新加载插件"
)
# 不返回错误状态的客户端
return None
# ✅ 使用锁保护状态更新,避免并发问题
# 注意:这里使用原子操作更新简单字段,不需要异步锁
session.last_access = time.time()
session.request_count += 1
return session.client
return None
async def get_or_reconnect_client(
self,
user_id: str,
plugin_name: str,
plugin: MCPPlugin
) -> HTTPMCPClient:
"""
获取或重连客户端自动处理错误状态
Args:
user_id: 用户ID
plugin_name: 插件名称
plugin: 插件配置对象
Returns:
客户端实例
Raises:
ValueError: 插件加载失败
"""
plugin_id = f"{user_id}:{plugin_name}"
# 获取用户锁
user_lock = await self._get_user_lock(user_id)
async with user_lock:
session = self._sessions.get(plugin_id)
# 检查会话健康状态
if session and session.status == "error":
logger.warning(f"会话 {plugin_id} 处于错误状态,尝试重连")
async with self._sessions_lock:
await self._unload_plugin_unsafe(plugin_id)
session = None
# 如果没有会话,加载插件
if not session:
success = await self.load_plugin(plugin)
if not success:
raise ValueError(f"插件加载失败: {plugin_name}")
session = self._sessions[plugin_id]
return session.client
async def call_tool(
self,
user_id: str,
plugin_name: str,
tool_name: str,
arguments: Dict[str, Any]
) -> Any:
"""
调用插件工具带错误计数和状态管理
Args:
user_id: 用户ID
plugin_name: 插件名称
tool_name: 工具名称
arguments: 工具参数
Returns:
工具执行结果
Raises:
ValueError: 插件不存在或未启用
MCPError: 工具调用失败
"""
plugin_id = f"{user_id}:{plugin_name}"
# 获取会话
session = self._sessions.get(plugin_id)
if not session:
raise ValueError(f"插件未加载: {plugin_name}")
try:
result = await session.client.call_tool(tool_name, arguments)
logger.info(f"✅ 工具调用成功: {plugin_name}.{tool_name}")
# 调用成功,重置状态(如果之前是degraded)
if session.status == "degraded":
session.status = "active"
logger.info(f"✅ 会话 {plugin_id} 恢复正常")
return result
except Exception as e:
# 增加错误计数
session.error_count += 1
# 根据错误率更新状态
if session.request_count > 0:
error_rate = session.error_count / session.request_count
if error_rate > 0.5:
session.status = "error"
elif error_rate > 0.3:
session.status = "degraded"
logger.error(
f"❌ 工具调用失败: {plugin_name}.{tool_name}, "
f"错误: {e} (错误计数: {session.error_count}/{session.request_count})"
)
raise
async def get_plugin_tools(
self,
user_id: str,
plugin_name: str
) -> List[Dict[str, Any]]:
"""
获取插件的工具列表
Args:
user_id: 用户ID
plugin_name: 插件名称
Returns:
工具列表
"""
client = self.get_client(user_id, plugin_name)
if not client:
raise ValueError(f"插件未加载: {plugin_name}")
try:
tools = await client.list_tools()
return tools
except Exception as e:
logger.error(f"获取工具列表失败: {plugin_name}, 错误: {e}")
raise
async def test_plugin(
self,
user_id: str,
plugin_name: str
) -> Dict[str, Any]:
"""
测试插件连接
Args:
user_id: 用户ID
plugin_name: 插件名称
Returns:
测试结果
"""
client = self.get_client(user_id, plugin_name)
if not client:
raise ValueError(f"插件未加载: {plugin_name}")
return await client.test_connection()
async def cleanup_all(self):
"""清理所有插件和资源"""
# 停止后台任务
if self._cleanup_task:
self._cleanup_task.cancel()
try:
await self._cleanup_task
except asyncio.CancelledError:
pass
if self._health_check_task:
self._health_check_task.cancel()
try:
await self._health_check_task
except asyncio.CancelledError:
pass
# 清理所有会话
async with self._sessions_lock:
plugin_ids = list(self._sessions.keys())
for plugin_id in plugin_ids:
await self._unload_plugin_unsafe(plugin_id)
logger.info("✅ 已清理所有MCP插件和资源")
# 全局注册表实例
mcp_registry = MCPPluginRegistry()
+50
View File
@@ -0,0 +1,50 @@
"""MCP插件状态同步服务
将内存中的会话状态变更同步到数据库确保状态一致性
"""
from typing import Dict, Any
from sqlalchemy import update
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
from app.models.mcp_plugin import MCPPlugin
from app.logger import get_logger
logger = get_logger(__name__)
async def sync_status_to_db(event: Dict[str, Any]):
"""
状态变更回调 - 同步到数据库
"""
user_id = event["user_id"]
plugin_name = event["plugin_name"]
new_status = event["new_status"]
reason = event.get("reason", "")
try:
from app.database import get_engine
engine = await get_engine(user_id)
AsyncSessionLocal = async_sessionmaker(engine, class_=AsyncSession, expire_on_commit=False)
async with AsyncSessionLocal() as db:
stmt = (
update(MCPPlugin)
.where(MCPPlugin.user_id == user_id, MCPPlugin.plugin_name == plugin_name)
.values(status=new_status, last_error=reason if new_status == "error" else None)
)
await db.execute(stmt)
await db.commit()
logger.debug(f"✅ 状态已同步到数据库: {plugin_name} -> {new_status}")
except Exception as e:
logger.error(f"❌ 状态同步失败: {plugin_name}, 错误: {e}")
def register_status_sync():
"""注册状态同步回调到MCP客户端"""
from app.mcp import mcp_client
mcp_client.register_status_callback(sync_status_to_db)
logger.info("✅ MCP状态同步服务已注册")
+2 -3
View File
@@ -1,5 +1,5 @@
"""职业相关的Pydantic模型"""
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, ConfigDict
from typing import Optional, List, Dict, Any
from datetime import datetime
@@ -63,8 +63,7 @@ class CareerResponse(BaseModel):
created_at: datetime
updated_at: datetime
class Config:
from_attributes = True
model_config = ConfigDict(from_attributes=True)
class CareerListResponse(BaseModel):
+4 -6
View File
@@ -1,5 +1,5 @@
"""章节相关的Pydantic模型"""
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, ConfigDict
from typing import Optional, List, Dict, Any
from datetime import datetime
@@ -58,8 +58,7 @@ class ChapterResponse(BaseModel):
created_at: datetime
updated_at: datetime
class Config:
from_attributes = True
model_config = ConfigDict(from_attributes=True)
class ChapterListResponse(BaseModel):
@@ -142,8 +141,7 @@ class ExpansionPlanUpdate(BaseModel):
estimated_words: Optional[int] = Field(None, description="预估字数", ge=500, le=10000)
scenes: Optional[List[SceneData]] = Field(None, description="场景列表")
class Config:
json_schema_extra = {
model_config = ConfigDict(json_schema_extra={
"example": {
"key_events": ["主角遇到挑战", "关键决策时刻"],
"character_focus": ["张三", "李四"],
@@ -159,7 +157,7 @@ class ExpansionPlanUpdate(BaseModel):
}
]
}
}
})
class ExpansionPlanResponse(BaseModel):
+2 -3
View File
@@ -1,5 +1,5 @@
"""角色相关的Pydantic模型"""
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, ConfigDict
from typing import Optional, List, Dict, Any
from datetime import datetime
@@ -98,8 +98,7 @@ class CharacterResponse(CharacterBase):
main_career_stage: Optional[int] = Field(None, description="主职业阶段")
sub_careers: Optional[List[Dict[str, Any]]] = Field(None, description="副职业列表")
class Config:
from_attributes = True
model_config = ConfigDict(from_attributes=True)
class CharacterGenerateRequest(BaseModel):
+2 -3
View File
@@ -1,5 +1,5 @@
"""MCP插件Pydantic模式"""
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, ConfigDict
from typing import Optional, Dict, Any, List
from datetime import datetime
@@ -82,8 +82,7 @@ class MCPPluginResponse(BaseModel):
# 时间戳
created_at: datetime
class Config:
from_attributes = True
model_config = ConfigDict(from_attributes=True)
class MCPToolCall(BaseModel):
+2 -3
View File
@@ -1,5 +1,5 @@
"""大纲相关的Pydantic模型"""
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, ConfigDict
from typing import Optional, List, Dict, Any
from datetime import datetime
@@ -103,8 +103,7 @@ class OutlineResponse(BaseModel):
created_at: datetime
updated_at: datetime
class Config:
from_attributes = True
model_config = ConfigDict(from_attributes=True)
class OutlineGenerateRequest(BaseModel):
+2 -3
View File
@@ -1,5 +1,5 @@
"""项目相关的Pydantic模型"""
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, ConfigDict
from typing import Optional, Literal
from datetime import datetime
@@ -59,8 +59,7 @@ class ProjectResponse(ProjectBase):
created_at: datetime
updated_at: datetime
class Config:
from_attributes = True
model_config = ConfigDict(from_attributes=True)
class ProjectListResponse(BaseModel):
+5 -9
View File
@@ -1,5 +1,5 @@
"""关系管理相关的Pydantic模型"""
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, ConfigDict
from typing import Optional, List
from datetime import datetime
@@ -17,8 +17,7 @@ class RelationshipTypeResponse(BaseModel):
description: Optional[str] = None
created_at: datetime
class Config:
from_attributes = True
model_config = ConfigDict(from_attributes=True)
# ============ 角色关系相关 ============
@@ -62,8 +61,7 @@ class CharacterRelationshipResponse(CharacterRelationshipBase):
created_at: datetime
updated_at: datetime
class Config:
from_attributes = True
model_config = ConfigDict(from_attributes=True)
class RelationshipGraphNode(BaseModel):
@@ -127,8 +125,7 @@ class OrganizationResponse(OrganizationBase):
created_at: datetime
updated_at: datetime
class Config:
from_attributes = True
model_config = ConfigDict(from_attributes=True)
class OrganizationDetailResponse(BaseModel):
@@ -185,8 +182,7 @@ class OrganizationMemberResponse(OrganizationMemberBase):
created_at: datetime
updated_at: datetime
class Config:
from_attributes = True
model_config = ConfigDict(from_attributes=True)
class OrganizationMemberDetailResponse(BaseModel):
+2 -3
View File
@@ -1,5 +1,5 @@
"""写作风格 Schema"""
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, ConfigDict
from typing import Optional
from datetime import datetime
@@ -48,8 +48,7 @@ class WritingStyleResponse(BaseModel):
created_at: datetime
updated_at: datetime
class Config:
from_attributes = True
model_config = ConfigDict(from_attributes=True)
class WritingStyleListResponse(BaseModel):
@@ -71,7 +71,18 @@ class AnthropicClient:
temperature: float,
max_tokens: int,
system_prompt: Optional[str] = None,
) -> AsyncGenerator[str, None]:
tools: Optional[list] = None,
tool_choice: Optional[str] = None,
) -> AsyncGenerator[Dict[str, Any], None]:
"""
流式生成支持工具调用
Yields:
Dict with keys:
- content: str - 文本内容块
- tool_calls: list - 工具调用列表如果有
- done: bool - 是否结束
"""
kwargs = {
"model": model,
"max_tokens": max_tokens,
@@ -80,12 +91,42 @@ class AnthropicClient:
}
if system_prompt:
kwargs["system"] = system_prompt
if tools:
kwargs["tools"] = tools
if tool_choice == "required":
kwargs["tool_choice"] = {"type": "any"}
elif tool_choice == "auto":
kwargs["tool_choice"] = {"type": "auto"}
try:
async with self.client.messages.stream(**kwargs) as stream:
try:
async for text in stream.text_stream:
yield text
tool_calls = []
async for chunk in stream:
# 处理不同类型的块
if chunk.type == "text_delta":
yield {"content": chunk.text}
elif chunk.type == "tool_use_delta":
# 工具调用增量
if not tool_calls or tool_calls[-1].get("id") != chunk.id:
tool_calls.append({
"id": chunk.id,
"type": "function",
"function": {
"name": chunk.name,
"arguments": ""
}
})
# 追加参数
if tool_calls[-1]["function"]["arguments"] is None:
tool_calls[-1]["function"]["arguments"] = ""
tool_calls[-1]["function"]["arguments"] += chunk.input_gets_new_text or ""
elif chunk.type == "message_delta":
if chunk.stop_reason:
# 流结束
if tool_calls:
yield {"tool_calls": tool_calls}
yield {"done": True, "finish_reason": chunk.stop_reason}
except GeneratorExit:
# 生成器被关闭,这是正常的清理过程
logger.debug("Anthropic 流式响应生成器被关闭(GeneratorExit)")
@@ -111,7 +111,18 @@ class GeminiClient:
temperature: float,
max_tokens: int,
system_prompt: Optional[str] = None,
) -> AsyncGenerator[str, None]:
tools: Optional[list] = None,
tool_choice: Optional[str] = None,
) -> AsyncGenerator[Dict[str, Any], None]:
"""
流式生成支持工具调用
Yields:
Dict with keys:
- content: str - 文本内容块
- tool_calls: list - 工具调用列表如果有
- done: bool - 是否结束
"""
url = f"{self.base_url}/models/{model}:streamGenerateContent?key={self.api_key}&alt=sse"
contents = []
@@ -125,6 +136,8 @@ class GeminiClient:
}
if system_prompt:
payload["systemInstruction"] = {"parts": [{"text": system_prompt}]}
if tools:
payload["tools"] = self._convert_tools_to_gemini(tools)
try:
async with self.client.stream("POST", url, json=payload) as response:
@@ -139,9 +152,26 @@ class GeminiClient:
if candidates and len(candidates) > 0:
parts = candidates[0].get("content", {}).get("parts", [])
if parts and len(parts) > 0:
text = parts[0].get("text", "")
text = ""
function_calls = []
for part in parts:
if "text" in part:
text += part["text"]
elif "functionCall" in part:
fc = part["functionCall"]
function_calls.append({
"id": f"call_{fc['name']}",
"type": "function",
"function": {
"name": fc["name"],
"arguments": fc.get("args", {})
}
})
if text:
yield text
yield {"content": text}
if function_calls:
yield {"tool_calls": function_calls}
except json.JSONDecodeError:
continue
except GeneratorExit:
@@ -86,8 +86,21 @@ class OpenAIClient(BaseAIClient):
model: str,
temperature: float,
max_tokens: int,
) -> AsyncGenerator[str, None]:
payload = self._build_payload(messages, model, temperature, max_tokens, stream=True)
tools: Optional[list] = None,
tool_choice: Optional[str] = None,
) -> AsyncGenerator[Dict[str, Any], None]:
"""
流式生成支持工具调用
Yields:
Dict with keys:
- content: str - 文本内容块
- tool_calls: list - 工具调用列表如果有
- done: bool - 是否结束
"""
payload = self._build_payload(messages, model, temperature, max_tokens, tools, tool_choice, stream=True)
tool_calls_buffer = {} # 收集工具调用块
try:
async with await self._request_with_retry("POST", "/chat/completions", payload, stream=True) as response:
@@ -97,14 +110,38 @@ class OpenAIClient(BaseAIClient):
if line.startswith("data: "):
data_str = line[6:]
if data_str.strip() == "[DONE]":
# 流结束,检查是否有工具调用需要处理
if tool_calls_buffer:
yield {"tool_calls": list(tool_calls_buffer.values()), "done": True}
yield {"done": True}
break
try:
data = json.loads(data_str)
choices = data.get("choices", [])
if choices and len(choices) > 0:
content = choices[0].get("delta", {}).get("content", "")
delta = choices[0].get("delta", {})
content = delta.get("content", "")
# 检查工具调用
tc_list = delta.get("tool_calls")
if tc_list:
for tc in tc_list:
index = tc.get("index", 0)
if index not in tool_calls_buffer:
tool_calls_buffer[index] = tc
else:
existing = tool_calls_buffer[index]
# 合并 function.arguments
if "function" in tc and "function" in existing:
if tc["function"].get("arguments"):
existing["function"]["arguments"] = (
existing["function"].get("arguments", "") +
tc["function"]["arguments"]
)
if content:
yield content
yield {"content": content}
except json.JSONDecodeError:
continue
except GeneratorExit:
@@ -1,9 +1,12 @@
"""Anthropic Provider"""
from typing import Any, AsyncGenerator, Dict, List, Optional
from app.logger import get_logger
from app.services.ai_clients.anthropic_client import AnthropicClient
from .base_provider import BaseAIProvider
logger = get_logger(__name__)
class AnthropicProvider(BaseAIProvider):
"""Anthropic 提供商"""
@@ -39,7 +42,62 @@ class AnthropicProvider(BaseAIProvider):
temperature: float,
max_tokens: int,
system_prompt: Optional[str] = None,
tools: Optional[List[Dict]] = None,
tool_choice: Optional[str] = None,
user_id: Optional[str] = None,
) -> AsyncGenerator[str, None]:
# 如果有工具,使用真正的流式工具调用
if tools:
logger.debug(f"🔧 AnthropicProvider: 有 {len(tools)} 个工具,使用流式处理")
messages = [{"role": "user", "content": prompt}]
actual_tool_choice = tool_choice if tool_choice else "auto"
tool_calls_buffer = []
async for chunk in self.client.chat_completion_stream(
messages=messages,
model=model,
temperature=temperature,
max_tokens=max_tokens,
system_prompt=system_prompt,
tools=tools,
tool_choice=actual_tool_choice,
):
# 检查是否有工具调用
if chunk.get("tool_calls"):
tool_calls_buffer.extend(chunk["tool_calls"])
logger.debug(f"🔧 收到工具调用: {len(chunk['tool_calls'])}")
# 检查是否结束
if chunk.get("done"):
if tool_calls_buffer:
logger.info(f"🔧 流式结束,处理 {len(tool_calls_buffer)} 个工具调用")
from app.mcp import mcp_client
actual_user_id = user_id or ""
tool_results = await mcp_client.batch_call_tools(
user_id=actual_user_id,
tool_calls=tool_calls_buffer
)
# 将工具结果注入到上下文中
tool_context = mcp_client.build_tool_context(tool_results, format="markdown")
# 构建最终提示词,要求AI基于工具结果回答
final_prompt = f"{prompt}\n\n{tool_context}\n\n请基于以上工具查询结果,给出完整详细的回答。"
final_messages = [{"role": "user", "content": final_prompt}]
# 递归调用生成最终结果
async for final_chunk in self._generate_with_tools(
final_messages, model, temperature, max_tokens, system_prompt, tools, user_id
):
yield final_chunk
break
# 输出文本内容
if chunk.get("content"):
yield chunk["content"]
return
# 无工具时普通流式生成
messages = [{"role": "user", "content": prompt}]
async for chunk in self.client.chat_completion_stream(
messages=messages,
@@ -48,4 +106,56 @@ class AnthropicProvider(BaseAIProvider):
max_tokens=max_tokens,
system_prompt=system_prompt,
):
yield chunk
# 确保只 yield 字符串内容,避免 yield 字典导致类型错误
if isinstance(chunk, dict):
if chunk.get("content"):
yield chunk["content"]
else:
yield chunk
async def _generate_with_tools(
self,
messages: list,
model: str,
temperature: float,
max_tokens: int,
system_prompt: Optional[str] = None,
tools: list = None,
user_id: Optional[str] = None,
) -> AsyncGenerator[str, None]:
"""辅助方法:带工具的流式生成"""
tool_calls_buffer = []
async for chunk in self.client.chat_completion_stream(
messages=messages,
model=model,
temperature=temperature,
max_tokens=max_tokens,
system_prompt=system_prompt,
tools=tools,
tool_choice="auto",
):
if chunk.get("tool_calls"):
tool_calls_buffer.extend(chunk["tool_calls"])
logger.debug(f"🔧 _generate_with_tools 收到工具调用: {len(chunk['tool_calls'])}")
if chunk.get("done"):
if tool_calls_buffer:
from app.mcp import mcp_client
actual_user_id = user_id or ""
tool_results = await mcp_client.batch_call_tools(
user_id=actual_user_id,
tool_calls=tool_calls_buffer
)
tool_context = mcp_client.build_tool_context(tool_results, format="markdown")
messages.append({"role": "user", "content": f"{tool_context}\n\n请基于以上工具查询结果,给出完整详细的回答。"})
async for final_chunk in self._generate_with_tools(
messages, model, temperature, max_tokens, system_prompt, tools, user_id
):
yield final_chunk
break
if chunk.get("content"):
yield chunk["content"]
@@ -28,6 +28,9 @@ class BaseAIProvider(ABC):
temperature: float,
max_tokens: int,
system_prompt: Optional[str] = None,
tools: Optional[List[Dict]] = None,
tool_choice: Optional[str] = None,
user_id: Optional[str] = None,
) -> AsyncGenerator[str, None]:
"""流式生成"""
pass
@@ -1,8 +1,12 @@
"""Gemini Provider"""
from typing import Any, AsyncGenerator, Dict, List, Optional
from app.logger import get_logger
from app.services.ai_clients.gemini_client import GeminiClient
from .base_provider import BaseAIProvider
logger = get_logger(__name__)
class GeminiProvider(BaseAIProvider):
def __init__(self, client: GeminiClient):
@@ -36,7 +40,62 @@ class GeminiProvider(BaseAIProvider):
temperature: float,
max_tokens: int,
system_prompt: Optional[str] = None,
tools: Optional[List[Dict]] = None,
tool_choice: Optional[str] = None,
user_id: Optional[str] = None,
) -> AsyncGenerator[str, None]:
# 如果有工具,使用真正的流式工具调用
if tools:
logger.debug(f"🔧 GeminiProvider: 有 {len(tools)} 个工具,使用流式处理")
messages = [{"role": "user", "content": prompt}]
actual_tool_choice = tool_choice if tool_choice else "auto"
tool_calls_buffer = []
async for chunk in self.client.chat_completion_stream(
messages=messages,
model=model,
temperature=temperature,
max_tokens=max_tokens,
system_prompt=system_prompt,
tools=tools,
tool_choice=actual_tool_choice,
):
# 检查是否有工具调用
if chunk.get("tool_calls"):
tool_calls_buffer.extend(chunk["tool_calls"])
logger.debug(f"🔧 收到工具调用: {len(chunk['tool_calls'])}")
# 检查是否结束
if chunk.get("done"):
if tool_calls_buffer:
logger.info(f"🔧 流式结束,处理 {len(tool_calls_buffer)} 个工具调用")
from app.mcp import mcp_client
actual_user_id = user_id or ""
tool_results = await mcp_client.batch_call_tools(
user_id=actual_user_id,
tool_calls=tool_calls_buffer
)
# 将工具结果注入到上下文中
tool_context = mcp_client.build_tool_context(tool_results, format="markdown")
# 构建最终提示词,要求AI基于工具结果回答
final_prompt = f"{prompt}\n\n{tool_context}\n\n请基于以上工具查询结果,给出完整详细的回答。"
final_messages = [{"role": "user", "content": final_prompt}]
# 递归调用生成最终结果
async for final_chunk in self._generate_with_tools(
final_messages, model, temperature, max_tokens, system_prompt, tools, user_id
):
yield final_chunk
break
# 输出文本内容
if chunk.get("content"):
yield chunk["content"]
return
# 无工具时普通流式生成
messages = [{"role": "user", "content": prompt}]
async for chunk in self.client.chat_completion_stream(
messages=messages,
@@ -45,4 +104,56 @@ class GeminiProvider(BaseAIProvider):
max_tokens=max_tokens,
system_prompt=system_prompt,
):
yield chunk
# 确保只 yield 字符串内容,避免 yield 字典导致类型错误
if isinstance(chunk, dict):
if chunk.get("content"):
yield chunk["content"]
else:
yield chunk
async def _generate_with_tools(
self,
messages: list,
model: str,
temperature: float,
max_tokens: int,
system_prompt: Optional[str] = None,
tools: list = None,
user_id: Optional[str] = None,
) -> AsyncGenerator[str, None]:
"""辅助方法:带工具的流式生成"""
tool_calls_buffer = []
async for chunk in self.client.chat_completion_stream(
messages=messages,
model=model,
temperature=temperature,
max_tokens=max_tokens,
system_prompt=system_prompt,
tools=tools,
tool_choice="auto",
):
if chunk.get("tool_calls"):
tool_calls_buffer.extend(chunk["tool_calls"])
logger.debug(f"🔧 _generate_with_tools 收到工具调用: {len(chunk['tool_calls'])}")
if chunk.get("done"):
if tool_calls_buffer:
from app.mcp import mcp_client
actual_user_id = user_id or ""
tool_results = await mcp_client.batch_call_tools(
user_id=actual_user_id,
tool_calls=tool_calls_buffer
)
tool_context = mcp_client.build_tool_context(tool_results, format="markdown")
messages.append({"role": "user", "content": f"{tool_context}\n\n请基于以上工具查询结果,给出完整详细的回答。"})
async for final_chunk in self._generate_with_tools(
messages, model, temperature, max_tokens, system_prompt, tools, user_id
):
yield final_chunk
break
if chunk.get("content"):
yield chunk["content"]
@@ -1,9 +1,12 @@
"""OpenAI Provider"""
from typing import Any, AsyncGenerator, Dict, List, Optional
from app.logger import get_logger
from app.services.ai_clients.openai_client import OpenAIClient
from .base_provider import BaseAIProvider
logger = get_logger(__name__)
class OpenAIProvider(BaseAIProvider):
"""OpenAI 提供商"""
@@ -42,16 +45,117 @@ class OpenAIProvider(BaseAIProvider):
temperature: float,
max_tokens: int,
system_prompt: Optional[str] = None,
tools: Optional[List[Dict]] = None,
tool_choice: Optional[str] = None,
user_id: Optional[str] = None,
) -> AsyncGenerator[str, None]:
messages = []
if system_prompt:
messages.append({"role": "system", "content": system_prompt})
messages.append({"role": "user", "content": prompt})
# 如果有工具,使用真正的流式工具调用
if tools:
logger.debug(f"🔧 OpenAIProvider: 有 {len(tools)} 个工具,使用流式处理")
actual_tool_choice = tool_choice if tool_choice else "auto"
tool_calls_buffer = []
async for chunk in self.client.chat_completion_stream(
messages=messages,
model=model,
temperature=temperature,
max_tokens=max_tokens,
tools=tools,
tool_choice=actual_tool_choice,
):
# 检查是否有工具调用
if chunk.get("tool_calls"):
tool_calls_buffer.extend(chunk["tool_calls"])
logger.debug(f"🔧 收到工具调用: {len(chunk['tool_calls'])}")
# 检查是否结束
if chunk.get("done"):
if tool_calls_buffer:
logger.info(f"🔧 流式结束,处理 {len(tool_calls_buffer)} 个工具调用")
from app.mcp import mcp_client
actual_user_id = user_id or ""
tool_results = await mcp_client.batch_call_tools(
user_id=actual_user_id,
tool_calls=tool_calls_buffer
)
# 将工具结果注入到上下文中
tool_context = mcp_client.build_tool_context(tool_results, format="markdown")
# 构建最终提示词,要求AI基于工具结果回答
final_prompt = f"{prompt}\n\n{tool_context}\n\n请基于以上工具查询结果,给出完整详细的回答。"
final_messages = messages.copy()
final_messages.append({"role": "user", "content": final_prompt})
# 递归调用生成最终结果
async for final_chunk in self._generate_with_tools(
final_messages, model, temperature, max_tokens, tools, user_id
):
yield final_chunk
break
# 输出文本内容
if chunk.get("content"):
yield chunk["content"]
return
# 无工具时普通流式生成
async for chunk in self.client.chat_completion_stream(
messages=messages,
model=model,
temperature=temperature,
max_tokens=max_tokens,
):
yield chunk
# 确保只 yield 字符串内容,避免 yield 字典导致类型错误
if isinstance(chunk, dict):
if chunk.get("content"):
yield chunk["content"]
else:
yield chunk
async def _generate_with_tools(
self,
messages: list,
model: str,
temperature: float,
max_tokens: int,
tools: list,
user_id: Optional[str] = None,
) -> AsyncGenerator[str, None]:
"""辅助方法:带工具的流式生成(无tool_choice,AI自由决定)"""
async for chunk in self.client.chat_completion_stream(
messages=messages,
model=model,
temperature=temperature,
max_tokens=max_tokens,
tools=tools,
tool_choice="auto",
):
if chunk.get("tool_calls"):
from app.mcp import mcp_client
actual_user_id = user_id or ""
tool_results = await mcp_client.batch_call_tools(
user_id=actual_user_id,
tool_calls=chunk["tool_calls"]
)
tool_context = mcp_client.build_tool_context(tool_results, format="markdown")
# 再次调用获取最终回答
messages.append({"role": "user", "content": f"{tool_context}\n\n请基于以上工具查询结果,给出完整详细的回答。"})
async for final_chunk in self._generate_with_tools(
messages, model, temperature, max_tokens, tools, user_id
):
yield final_chunk
break
if chunk.get("done"):
break
if chunk.get("content"):
yield chunk["content"]
+391 -112
View File
@@ -1,4 +1,10 @@
"""AI服务封装 - 统一的AI接口"""
"""AI服务封装 - 统一的AI接口
重构后支持自动MCP工具加载
- 所有AI方法在请求前自动检查用户MCP配置
- 如果有启用的MCP插件且有可用工具自动发送tools
- 通过 auto_mcp 参数控制是否启用自动工具加载
"""
from typing import Optional, AsyncGenerator, List, Dict, Any, Union
from app.config import settings as app_settings
@@ -13,7 +19,6 @@ from app.services.ai_providers.anthropic_provider import AnthropicProvider
from app.services.ai_providers.gemini_provider import GeminiProvider
from app.services.ai_providers.base_provider import BaseAIProvider
from app.services.json_helper import clean_json_response, parse_json
from app.mcp.adapters.universal import universal_mcp_adapter
# 导出清理函数
cleanup_http_clients = cleanup_all_clients
@@ -22,7 +27,41 @@ logger = get_logger(__name__)
class AIService:
"""AI服务统一接口"""
"""
AI服务统一接口
MCP工具支持
- 在创建服务时传入 user_id db_session
- 根据用户MCP插件的enabled状态自动决定是否启用MCP
- 如果有任意一个MCP插件启用则加载并使用工具
- 如果所有插件都关闭则不使用任何MCP工具
- 通过 auto_mcp=False 可临时禁用自动工具加载
- 通过 mcp_max_rounds 控制工具调用轮数
- 通过 clear_mcp_cache() 可清理MCP工具缓存
MCP启用逻辑backend/app/api/settings.py 中的 get_user_ai_service
- 查询用户的所有MCP插件
- 如果有启用的插件 (enabled=True) enable_mcp=True
- 如果所有插件都关闭或没有插件 enable_mcp=False
使用示例
# 创建支持MCP的AI服务(根据插件状态自动决定是否启用)
ai_service = create_user_ai_service_with_mcp(
api_provider="openai",
api_key="...",
user_id="user123",
db_session=db
)
# 自动加载MCP工具(如果有启用的插件)
result = await ai_service.generate_text(prompt="...")
# 临时禁用MCP工具
result = await ai_service.generate_text(prompt="...", auto_mcp=False)
# 自定义轮数
result = await ai_service.generate_text(prompt="...", mcp_max_rounds=3)
"""
def __init__(
self,
@@ -33,8 +72,11 @@ class AIService:
default_temperature: Optional[float] = None,
default_max_tokens: Optional[int] = None,
default_system_prompt: Optional[str] = None,
enable_mcp_adapter: bool = True,
config: Optional[AIClientConfig] = None,
# MCP支持参数
user_id: Optional[str] = None,
db_session: Optional[Any] = None,
enable_mcp: bool = True,
):
self.api_provider = api_provider or app_settings.default_ai_provider
self.default_model = default_model or app_settings.default_model
@@ -43,7 +85,12 @@ class AIService:
self.default_system_prompt = default_system_prompt
self.config = config or default_config
self.mcp_adapter = universal_mcp_adapter if enable_mcp_adapter else None
# MCP配置
self.user_id = user_id
self.db_session = db_session
self._enable_mcp = enable_mcp
self._cached_tools: Optional[List[Dict]] = None
self._tools_loaded = False
self._openai_provider: Optional[OpenAIProvider] = None
self._anthropic_provider: Optional[AnthropicProvider] = None
@@ -68,6 +115,36 @@ class AIService:
client = GeminiClient(api_key, api_base_url, self.config)
self._gemini_provider = GeminiProvider(client)
@property
def enable_mcp(self) -> bool:
"""是否启用MCP工具"""
return self._enable_mcp
@enable_mcp.setter
def enable_mcp(self, value: bool):
"""设置MCP启用状态,如果禁用则清理缓存"""
if value is False and self._enable_mcp is True:
# 从启用变为禁用,清理缓存
self.clear_mcp_cache()
self._enable_mcp = value
def clear_mcp_cache(self):
"""
清理MCP工具缓存
当禁用MCP时调用此方法确保后续AI调用不会使用缓存的工具
同时更新 _tools_loaded 状态使下次调用时重新检查
"""
if self._cached_tools is not None:
logger.info(f"🔧 清理MCP工具缓存,移除 {len(self._cached_tools)} 个工具")
self._cached_tools = None
else:
logger.debug(f"🔧 MCP工具缓存已经是空,无需清理")
# 更新加载状态,确保下次调用会重新检查
self._tools_loaded = False
logger.debug(f"🔧 MCP工具状态已重置: enable_mcp={self._enable_mcp}, _tools_loaded=False")
def _get_provider(self, provider: Optional[str] = None) -> BaseAIProvider:
"""获取对应的 Provider"""
p = provider or self.api_provider
@@ -79,6 +156,166 @@ class AIService:
return self._gemini_provider
raise ValueError(f"Provider {p} 未初始化")
async def _prepare_mcp_tools(self, auto_mcp: bool = True, force_refresh: bool = False) -> Optional[List[Dict]]:
"""
预处理MCP工具
检查用户MCP配置并加载可用工具
结果会被缓存避免重复加载
Args:
auto_mcp: 是否自动加载MCP工具来自调用方参数
force_refresh: 是否强制刷新缓存
Returns:
- None: 无可用工具未配置/未启用/加载失败
- List[Dict]: OpenAI格式的工具列表
"""
# 前置条件检查
if not self._enable_mcp:
logger.debug(f"🔧 MCP工具未启用 (_enable_mcp=False)")
# 即使有缓存也清理掉,确保不使用
self._cached_tools = None
self._tools_loaded = False
return None
if not auto_mcp:
logger.debug(f"🔧 auto_mcp=False,跳过MCP工具加载")
# 即使有缓存也清理掉,确保不使用
self._cached_tools = None
self._tools_loaded = False
return None
if not self.user_id:
logger.debug(f"🔧 MCP工具加载跳过: user_id未设置")
return None
if not self.db_session:
logger.debug(f"🔧 MCP工具加载跳过: db_session未设置")
return None
# 使用缓存(只有 enable_mcp=True 时才使用缓存)
if self._tools_loaded and not force_refresh:
if self._cached_tools:
logger.debug(f"🔧 使用缓存的MCP工具 ({len(self._cached_tools)}个)")
return self._cached_tools
try:
from app.services.mcp_tools_loader import mcp_tools_loader
self._cached_tools = await mcp_tools_loader.get_user_tools(
user_id=self.user_id,
db_session=self.db_session,
use_cache=True,
force_refresh=force_refresh
)
self._tools_loaded = True
if self._cached_tools:
logger.info(f"🔧 已加载 {len(self._cached_tools)} 个MCP工具")
else:
logger.debug(f"📭 用户 {self.user_id} 没有可用的MCP工具")
return self._cached_tools
except Exception as e:
logger.warning(f"⚠️ 加载MCP工具失败: {e}")
self._tools_loaded = True
self._cached_tools = None
return None
async def _handle_tool_calls(
self,
original_prompt: str,
response: Dict[str, Any],
max_rounds: int = 2,
**kwargs
) -> Dict[str, Any]:
"""
处理AI返回的工具调用
Args:
original_prompt: 原始提示词
response: AI响应包含tool_calls
max_rounds: 最大工具调用轮数
**kwargs: 传递给generate_text的其他参数
Returns:
最终的AI响应
"""
from app.mcp import mcp_client
tool_calls = response.get("tool_calls", [])
if not tool_calls or not self.user_id:
return response
result = {
"content": response.get("content", ""),
"tool_calls_made": 0,
"tools_used": [],
"finish_reason": response.get("finish_reason", ""),
"mcp_enhanced": True
}
prompt = original_prompt
for round_num in range(max_rounds):
logger.info(f"🔧 工具调用 - 第{round_num+1}/{max_rounds}轮,{len(tool_calls)}个工具")
try:
# 批量执行工具调用
tool_results = await mcp_client.batch_call_tools(
user_id=self.user_id,
tool_calls=tool_calls
)
# 记录使用的工具
for tc in tool_calls:
name = tc["function"]["name"]
if name not in result["tools_used"]:
result["tools_used"].append(name)
result["tool_calls_made"] += len(tool_calls)
# 构建工具上下文
tool_context = mcp_client.build_tool_context(tool_results, format="markdown")
# 更新提示词
if round_num == max_rounds - 1:
# 最后一轮,强制要求回答
prompt = f"{original_prompt}\n\n{tool_context}\n\n⚠️ 重要:请基于以上工具查询结果,给出完整详细的最终答案。不要再调用工具。"
tool_choice = "none"
else:
prompt = f"{original_prompt}\n\n{tool_context}\n\n请基于以上工具查询结果,继续完成任务。"
tool_choice = kwargs.get("tool_choice", "auto")
# 继续调用AI
prov = self._get_provider(kwargs.get("provider"))
next_response = await prov.generate(
prompt=prompt,
model=kwargs.get("model") or self.default_model,
temperature=kwargs.get("temperature") or self.default_temperature,
max_tokens=kwargs.get("max_tokens") or self.default_max_tokens,
system_prompt=kwargs.get("system_prompt") or self.default_system_prompt,
tools=None if tool_choice == "none" else self._cached_tools,
tool_choice=tool_choice,
)
tool_calls = next_response.get("tool_calls", [])
if not tool_calls:
# 没有更多工具调用,返回结果
result["content"] = next_response.get("content", "")
result["finish_reason"] = next_response.get("finish_reason", "stop")
break
except Exception as e:
logger.error(f"❌ 工具调用失败: {e}")
result["content"] = response.get("content", "")
result["finish_reason"] = "tool_error"
break
return result
async def generate_text(
self,
prompt: str,
@@ -89,10 +326,39 @@ class AIService:
system_prompt: Optional[str] = None,
tools: Optional[List[Dict]] = None,
tool_choice: Optional[str] = None,
auto_mcp: bool = True,
handle_tool_calls: bool = True,
mcp_max_rounds: Optional[int] = None,
) -> Dict[str, Any]:
"""生成文本"""
"""
生成文本自动支持MCP工具
Args:
prompt: 用户提示词
provider: AI提供商
model: 模型名称
temperature: 温度
max_tokens: 最大令牌数
system_prompt: 系统提示词
tools: 手动指定的工具列表优先级高于自动加载
tool_choice: 工具选择策略
auto_mcp: 是否自动加载MCP工具默认True
handle_tool_calls: 是否自动处理工具调用默认True
mcp_max_rounds: 最大工具调用轮数None使用默认值3
Returns:
包含生成内容的字典
"""
# 使用全局配置的MCP轮数(如果未指定)
if mcp_max_rounds is None:
mcp_max_rounds = app_settings.mcp_max_rounds
# 自动加载MCP工具
if auto_mcp and tools is None:
tools = await self._prepare_mcp_tools(auto_mcp=auto_mcp)
prov = self._get_provider(provider)
return await prov.generate(
response = await prov.generate(
prompt=prompt,
model=model or self.default_model,
temperature=temperature or self.default_temperature,
@@ -101,6 +367,22 @@ class AIService:
tools=tools,
tool_choice=tool_choice,
)
# 处理工具调用
if handle_tool_calls and response.get("tool_calls"):
return await self._handle_tool_calls(
original_prompt=prompt,
response=response,
provider=provider,
model=model,
temperature=temperature,
max_tokens=max_tokens,
system_prompt=system_prompt,
tool_choice=tool_choice,
max_rounds=mcp_max_rounds,
)
return response
async def generate_text_stream(
self,
@@ -110,15 +392,51 @@ class AIService:
temperature: Optional[float] = None,
max_tokens: Optional[int] = None,
system_prompt: Optional[str] = None,
tool_choice: Optional[str] = None,
auto_mcp: bool = True,
mcp_max_rounds: Optional[int] = None,
) -> AsyncGenerator[str, None]:
"""流式生成"""
"""
流式生成文本自动支持MCP工具
工具调用在 Provider 层通过流式方式处理支持真正的流式工具调用
Args:
prompt: 用户提示词
provider: AI提供商
model: 模型名称
temperature: 温度
max_tokens: 最大令牌数
system_prompt: 系统提示词
tool_choice: 工具选择策略"auto"/"none"/"required"
auto_mcp: 是否自动加载MCP工具
mcp_max_rounds: 最大工具调用轮数None使用默认值3
Yields:
生成的文本块
"""
logger.debug(f"🔧 generate_text_stream: auto_mcp={auto_mcp}, tool_choice={tool_choice}")
tools_to_use = None
# 加载MCP工具
if auto_mcp:
tools_to_use = await self._prepare_mcp_tools(auto_mcp=auto_mcp)
if tools_to_use:
logger.info(f"🔧 已获取 {len(tools_to_use)} 个MCP工具")
# 流式生成(Provider 层处理工具调用)
prov = self._get_provider(provider)
logger.debug(f"🔧 开始流式生成,provider={provider or self.api_provider}, tools_count={len(tools_to_use) if tools_to_use else 0}")
async for chunk in prov.generate_stream(
prompt=prompt,
model=model or self.default_model,
temperature=temperature or self.default_temperature,
max_tokens=max_tokens or self.default_max_tokens,
system_prompt=system_prompt or self.default_system_prompt,
tools=tools_to_use,
tool_choice=tool_choice,
user_id=self.user_id,
):
yield chunk
@@ -132,8 +450,25 @@ class AIService:
provider: Optional[str] = None,
model: Optional[str] = None,
expected_type: Optional[str] = None,
auto_mcp: bool = True,
) -> Union[Dict, List]:
"""带重试的 JSON 调用"""
"""
带重试的 JSON 调用自动支持MCP工具
Args:
prompt: 用户提示词
system_prompt: 系统提示词
max_retries: 最大重试次数
temperature: 温度
max_tokens: 最大令牌数
provider: AI提供商
model: 模型名称
expected_type: 期望的返回类型"object""array"
auto_mcp: 是否自动加载MCP工具
Returns:
解析后的JSON数据
"""
last_response = ""
for attempt in range(1, max_retries + 1):
@@ -146,6 +481,8 @@ class AIService:
temperature=temperature,
max_tokens=max_tokens,
system_prompt=system_prompt,
auto_mcp=auto_mcp,
handle_tool_calls=True,
)
last_response = result.get("content", "")
@@ -172,108 +509,6 @@ class AIService:
"""清洗 JSON 响应"""
return clean_json_response(text)
async def generate_text_with_mcp(
self,
prompt: str,
user_id: str,
db_session,
enable_mcp: bool = True,
max_tool_rounds: int = 3,
tool_choice: str = "auto",
**kwargs
) -> Dict[str, Any]:
"""支持MCP工具的AI文本生成"""
from app.services.mcp_tool_service import mcp_tool_service, MCPToolServiceError
result = {"content": "", "tool_calls_made": 0, "tools_used": [], "finish_reason": "", "mcp_enhanced": False}
tools = None
if enable_mcp:
try:
tools = await mcp_tool_service.get_user_enabled_tools(user_id=user_id, db_session=db_session)
if tools:
result["mcp_enhanced"] = True
except MCPToolServiceError:
tools = None
original_prompt = prompt # 保存原始提示词
for round_num in range(max_tool_rounds):
logger.debug(f"🔄 MCP工具调用 - 第{round_num+1}/{max_tool_rounds}")
logger.debug(f" prompt长度: {len(prompt)}, tools数量: {len(tools) if tools else 0}, tool_choice: {tool_choice}")
ai_response = await self.generate_text(prompt=prompt, tools=tools, tool_choice=tool_choice, **kwargs)
logger.debug(f" AI响应: finish_reason={ai_response.get('finish_reason')}, content长度={len(ai_response.get('content', ''))}")
tool_calls = ai_response.get("tool_calls", [])
if not tool_calls:
content = ai_response.get("content", "")
result["content"] = content
result["finish_reason"] = ai_response.get("finish_reason", "stop")
logger.debug(f" ✅ 无工具调用,返回内容长度: {len(content)}")
# 🔧 修复:如果内容为空且已经调用过工具,强制要求AI给出答案
if not content.strip() and result["tool_calls_made"] > 0:
logger.warning(f"⚠️ AI在工具调用后返回空内容,尝试强制要求回答(第{round_num+1}轮)")
prompt = f"{prompt}\n\n⚠️ 请注意:你必须基于以上工具查询结果,给出完整的回答。不要返回空内容。"
tools = None
tool_choice = "none" # 强制不使用工具
continue
break
logger.info(f"🔧 检测到 {len(tool_calls)} 个工具调用")
for idx, tc in enumerate(tool_calls):
logger.debug(f" 工具{idx+1}: {tc.get('function', {}).get('name')} - 参数: {tc.get('function', {}).get('arguments')}")
try:
logger.debug(f" 开始执行工具调用...")
tool_results = await mcp_tool_service.execute_tool_calls(user_id=user_id, tool_calls=tool_calls, db_session=db_session)
logger.debug(f" 工具执行完成,结果数量: {len(tool_results)}")
# 🔍 检查工具结果
for idx, tr in enumerate(tool_results):
success = tr.get("success", False)
content_preview = tr.get("content", "")[:200] if tr.get("content") else "None"
logger.debug(f" 工具结果[{idx}]: success={success}, content预览={content_preview}")
for tc in tool_calls:
name = tc["function"]["name"]
if name not in result["tools_used"]:
result["tools_used"].append(name)
result["tool_calls_made"] += len(tool_calls)
tool_context = await mcp_tool_service.build_tool_context(tool_results, format="markdown")
logger.debug(f" 工具上下文长度: {len(tool_context)}")
logger.debug(f" 工具上下文预览: {tool_context[:300] if len(tool_context) > 300 else tool_context}")
# 🔧 改进:在最后一轮时,明确要求AI给出完整答案
if round_num == max_tool_rounds - 1:
logger.info(f"⚠️ 最后一轮,强制要求AI给出最终答案")
prompt = f"{original_prompt}\n\n{tool_context}\n\n⚠️ 重要:这是最后一轮,请基于以上工具查询的参考资料,给出完整详细的最终答案。不要再调用工具。"
tool_choice = "none"
else:
prompt = f"{original_prompt}\n\n{tool_context}\n\n请基于以上工具查询结果,继续完成任务。"
logger.debug(f" 新prompt长度: {len(prompt)}")
tools = None # 工具调用后禁用工具列表,避免重复调用
logger.debug(f" ✅ 工具调用成功,准备下一轮")
except Exception as tool_error:
logger.error(f"❌ 工具调用执行失败: {tool_error}", exc_info=True)
logger.error(f" 错误类型: {type(tool_error).__name__}")
logger.error(f" AI响应内容: {ai_response.get('content', '')[:200]}")
result["content"] = ai_response.get("content", "")
result["finish_reason"] = "tool_error"
break
return result
# 全局实例
ai_service = AIService()
def create_user_ai_service(
api_provider: str,
@@ -284,7 +519,7 @@ def create_user_ai_service(
max_tokens: int,
system_prompt: Optional[str] = None,
) -> AIService:
"""创建用户 AI 服务"""
"""创建用户 AI 服务(不带MCP支持)"""
return AIService(
api_provider=api_provider,
api_key=api_key,
@@ -293,4 +528,48 @@ def create_user_ai_service(
default_temperature=temperature,
default_max_tokens=max_tokens,
default_system_prompt=system_prompt,
)
def create_user_ai_service_with_mcp(
api_provider: str,
api_key: str,
api_base_url: str,
model_name: str,
temperature: float,
max_tokens: int,
user_id: str,
db_session,
system_prompt: Optional[str] = None,
enable_mcp: bool = True,
) -> AIService:
"""
创建支持MCP的用户AI服务
Args:
api_provider: AI提供商
api_key: API密钥
api_base_url: API基础URL
model_name: 模型名称
temperature: 温度
max_tokens: 最大令牌数
user_id: 用户ID用于加载MCP工具
db_session: 数据库会话
system_prompt: 系统提示词
enable_mcp: 是否启用MCP工具
Returns:
配置好的AIService实例
"""
return AIService(
api_provider=api_provider,
api_key=api_key,
api_base_url=api_base_url,
default_model=model_name,
default_temperature=temperature,
default_max_tokens=max_tokens,
default_system_prompt=system_prompt,
user_id=user_id,
db_session=db_session,
enable_mcp=enable_mcp,
)
+10 -24
View File
@@ -269,25 +269,11 @@ class AutoCharacterService:
)
try:
# 调用AI分析(使用统一的JSON调用方法)
if enable_mcp and user_id:
result = await self.ai_service.generate_text_with_mcp(
prompt=prompt,
user_id=user_id,
db_session=db,
enable_mcp=True,
max_tool_rounds=2
)
content = result.get("content", "")
# 使用统一的JSON清洗方法
cleaned = self.ai_service._clean_json_response(content)
analysis = json.loads(cleaned)
else:
# 非MCP调用:使用带自动重试的JSON调用
analysis = await self.ai_service.call_with_json_retry(
prompt=prompt,
max_retries=3
)
# 使用统一的JSON调用方法(支持自动MCP工具加载
analysis = await self.ai_service.call_with_json_retry(
prompt=prompt,
max_retries=3,
)
logger.info(f" ✅ AI分析完成: needs_new_characters={analysis.get('needs_new_characters')}")
return analysis
@@ -364,16 +350,16 @@ class AutoCharacterService:
existing_characters=existing_chars_summary + careers_info,
plot_context="根据剧情需要引入的新角色",
character_specification=json.dumps(spec, ensure_ascii=False, indent=2),
mcp_references="" # 暂时不使用MCP增强
mcp_references="" # MCP工具通过AI服务自动加载
)
# 调用AI生成(禁用MCP,避免累积超时导致卡死)
logger.info(f"🔧 角色详情生成: enable_mcp={enable_mcp}")
# 调用AI生成
try:
# 🔧 优化:角色详情生成不使用MCP,只在分析阶段使用MCP
# 这样可以减少大量的外部工具调用,避免超时和卡死
character_data = await self.ai_service.call_with_json_retry(
prompt=prompt,
max_retries=2 # 减少重试次数以加快速度
max_retries=2, # 减少重试次数以加快速度
)
char_name = character_data.get('name', '未知')
@@ -292,25 +292,11 @@ class AutoOrganizationService:
)
try:
# 调用AI分析(使用统一的JSON调用方法)
if enable_mcp and user_id:
result = await self.ai_service.generate_text_with_mcp(
prompt=prompt,
user_id=user_id,
db_session=db,
enable_mcp=True,
max_tool_rounds=2
)
content = result.get("content", "")
# 使用统一的JSON清洗方法
cleaned = self.ai_service._clean_json_response(content)
analysis = json.loads(cleaned)
else:
# 非MCP调用:使用带自动重试的JSON调用
analysis = await self.ai_service.call_with_json_retry(
prompt=prompt,
max_retries=3
)
# 使用统一的JSON调用方法(支持自动MCP工具加载
analysis = await self.ai_service.call_with_json_retry(
prompt=prompt,
max_retries=3,
)
logger.info(f" ✅ AI分析完成: needs_new_organizations={analysis.get('needs_new_organizations')}")
return analysis
@@ -362,24 +348,11 @@ class AutoOrganizationService:
# 调用AI生成(使用统一的JSON调用方法)
try:
if enable_mcp and user_id:
result = await self.ai_service.generate_text_with_mcp(
prompt=prompt,
user_id=user_id,
db_session=db,
enable_mcp=True,
max_tool_rounds=2
)
content = result.get("content", "")
# 使用统一的JSON清洗方法
cleaned = self.ai_service._clean_json_response(content)
organization_data = json.loads(cleaned)
else:
# 非MCP调用:使用带自动重试的JSON调用
organization_data = await self.ai_service.call_with_json_retry(
prompt=prompt,
max_retries=3
)
# 使用统一的JSON调用方法(支持自动MCP工具加载)
organization_data = await self.ai_service.call_with_json_retry(
prompt=prompt,
max_retries=3,
)
org_name = organization_data.get('name', '未知')
logger.info(f" ✅ 组织详情生成成功: {org_name}")
+78 -60
View File
@@ -1,4 +1,7 @@
"""MCP插件测试服务 - 专门处理插件测试逻辑"""
"""MCP插件测试服务 - 专门处理插件测试逻辑
重构后使用统一的MCPClientFacade门面来管理所有MCP操作
"""
import time
import json
@@ -10,7 +13,7 @@ from sqlalchemy import select
from app.models.mcp_plugin import MCPPlugin
from app.models.settings import Settings as UserSettings
from app.mcp.registry import mcp_registry
from app.mcp import mcp_client, MCPPluginConfig # 使用新的统一门面
from app.services.ai_service import create_user_ai_service
from app.schemas.mcp_plugin import MCPTestResult
from app.services.prompt_service import prompt_service
@@ -21,7 +24,32 @@ logger = get_logger(__name__)
class MCPTestService:
"""MCP插件测试服务(分离的测试逻辑"""
"""MCP插件测试服务(使用统一门面重构"""
async def _ensure_plugin_registered(
self,
plugin: MCPPlugin,
user_id: str
) -> bool:
"""
确保插件已注册到统一门面
Args:
plugin: 插件配置
user_id: 用户ID
Returns:
是否成功
"""
if plugin.plugin_type in ("http", "streamable_http", "sse") and plugin.server_url:
return await mcp_client.ensure_registered(
user_id=user_id,
plugin_name=plugin.plugin_name,
url=plugin.server_url,
plugin_type=plugin.plugin_type,
headers=plugin.headers
)
return False
async def test_plugin_connection(
self,
@@ -41,19 +69,18 @@ class MCPTestService:
start_time = time.time()
try:
# 确保插件已加载
if not mcp_registry.get_client(user_id, plugin.plugin_name):
success = await mcp_registry.load_plugin(plugin)
if not success:
return MCPTestResult(
success=False,
message="插件加载失败",
error="无法创建MCP客户端",
suggestions=["请检查插件配置", "请确认服务器URL正确"]
)
# 确保插件已注册
registered = await self._ensure_plugin_registered(plugin, user_id)
if not registered:
return MCPTestResult(
success=False,
message="插件注册失败",
error="无法创建MCP客户端",
suggestions=["请检查插件配置", "请确认服务器URL正确"]
)
# 测试连接并获取工具列表
test_result = await mcp_registry.test_plugin(user_id, plugin.plugin_name)
# 使用统一门面测试连接
test_result = await mcp_client.test_connection(user_id, plugin.plugin_name)
end_time = time.time()
response_time = round((end_time - start_time) * 1000, 2)
@@ -70,7 +97,18 @@ class MCPTestService:
]
)
else:
return MCPTestResult(**test_result)
return MCPTestResult(
success=False,
message="❌ 连接测试失败",
response_time_ms=response_time,
error=test_result.get("message", "未知错误"),
error_type=test_result.get("error_type"),
suggestions=[
"请检查服务器是否在线",
"请确认配置正确",
"请检查API Key是否有效"
]
)
except Exception as e:
end_time = time.time()
@@ -117,8 +155,8 @@ class MCPTestService:
if not connection_result.success:
return connection_result
# 2. 获取工具列表
tools = await mcp_registry.get_plugin_tools(user.user_id, plugin.plugin_name)
# 2. 使用统一门面获取工具列表
tools = await mcp_client.get_tools(user.user_id, plugin.plugin_name)
if not tools:
return MCPTestResult(
@@ -162,8 +200,8 @@ class MCPTestService:
max_tokens=1000
)
# 转换为OpenAI Function Calling格式
openai_tools = self._convert_tools_to_openai_format(tools)
# 使用统一门面转换为OpenAI Function Calling格式
openai_tools = mcp_client.format_tools_for_openai(tools, plugin.plugin_name)
logger.info(f"📋 转换后的OpenAI工具数量: {len(openai_tools)}")
logger.debug(f"📋 OpenAI工具列表: {[t['function']['name'] for t in openai_tools]}")
@@ -175,26 +213,16 @@ class MCPTestService:
db=db_session
)
# 注意: generate_text_stream 返回的是异步生成器,但在 tool_choice="required" 模式下
# AI服务会直接返回包含 tool_calls 的完整响应,而不是流式chunks
# 因此这里需要特殊处理
accumulated_text = ""
tool_calls = None
async for chunk in ai_service.generate_text_stream(
# 使用 generate_text 进行 Function Calling(非流式)
ai_response = await ai_service.generate_text(
prompt=prompts["user"],
system_prompt=prompts["system"],
tools=openai_tools,
tool_choice="required"
):
# 在 function calling 模式下,chunk 可能是字典格式包含 tool_calls
if isinstance(chunk, dict):
if "tool_calls" in chunk:
tool_calls = chunk["tool_calls"]
if "content" in chunk:
accumulated_text += chunk.get("content", "")
else:
accumulated_text += chunk
tool_choice="auto"
)
accumulated_text = ai_response.get("content", "")
tool_calls = ai_response.get("tool_calls")
# 5. 检查AI是否返回工具调用
if not tool_calls:
@@ -214,7 +242,7 @@ class MCPTestService:
# 6. 解析工具调用
tool_call = tool_calls[0]
function = tool_call["function"]
tool_name = function["name"]
tool_name_with_prefix = function["name"]
test_arguments = function["arguments"]
if isinstance(test_arguments, str):
@@ -231,17 +259,23 @@ class MCPTestService:
tools_count=len(tools)
)
# 解析插件名和工具名
try:
_, tool_name = mcp_client.parse_function_name(tool_name_with_prefix)
except ValueError:
tool_name = tool_name_with_prefix
logger.info(f"🤖 AI选择的工具: {tool_name}")
logger.info(f"📝 AI生成的参数: {test_arguments}")
# 7. 调用MCP工具
# 7. 使用统一门面调用MCP工具
call_start = time.time()
try:
tool_result = await mcp_registry.call_tool(
user.user_id,
plugin.plugin_name,
tool_name,
test_arguments
tool_result = await mcp_client.call_tool(
user_id=user.user_id,
plugin_name=plugin.plugin_name,
tool_name=tool_name,
arguments=test_arguments
)
call_end = time.time()
@@ -307,22 +341,6 @@ class MCPTestService:
"请检查API Key是否有效"
]
)
def _convert_tools_to_openai_format(self, tools: list) -> list:
"""将MCP工具格式转换为OpenAI Function Calling格式"""
openai_tools = []
for tool in tools:
openai_tool = {
"type": "function",
"function": {
"name": tool["name"],
"description": tool.get("description", ""),
}
}
if "inputSchema" in tool:
openai_tool["function"]["parameters"] = tool["inputSchema"]
openai_tools.append(openai_tool)
return openai_tools
# 全局单例
-691
View File
@@ -1,691 +0,0 @@
"""MCP工具服务 - 统一管理MCP工具的注入和执行"""
from typing import List, Dict, Any, Optional
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select
import asyncio
import json
import time
from datetime import datetime, timedelta
from dataclasses import dataclass, field
from collections import defaultdict
from app.models.mcp_plugin import MCPPlugin
from app.mcp.registry import mcp_registry
from app.mcp.config import mcp_config
from app.logger import get_logger
logger = get_logger(__name__)
@dataclass
class ToolMetrics:
"""工具调用指标"""
total_calls: int = 0
success_calls: int = 0
failed_calls: int = 0
total_duration_ms: float = 0.0
avg_duration_ms: float = 0.0
last_call_time: Optional[datetime] = None
def update_success(self, duration_ms: float):
"""更新成功调用指标"""
self.total_calls += 1
self.success_calls += 1
self.total_duration_ms += duration_ms
self.avg_duration_ms = self.total_duration_ms / self.total_calls
self.last_call_time = datetime.now()
def update_failure(self, duration_ms: float):
"""更新失败调用指标"""
self.total_calls += 1
self.failed_calls += 1
self.total_duration_ms += duration_ms
self.avg_duration_ms = self.total_duration_ms / self.total_calls
self.last_call_time = datetime.now()
@property
def success_rate(self) -> float:
"""成功率"""
if self.total_calls == 0:
return 0.0
return self.success_calls / self.total_calls
@dataclass
class ToolCacheEntry:
"""工具缓存条目"""
tools: List[Dict[str, Any]]
expire_time: datetime
hit_count: int = 0
class MCPToolServiceError(Exception):
"""MCP工具服务异常"""
pass
class MCPToolService:
"""MCP工具服务 - 统一管理MCP工具的注入和执行(优化版)"""
def __init__(
self,
cache_ttl_minutes: Optional[int] = None,
max_retries: Optional[int] = None
):
"""
初始化MCP工具服务
Args:
cache_ttl_minutes: 工具缓存TTL分钟默认使用配置
max_retries: 最大重试次数默认使用配置
"""
# 工具定义缓存: {cache_key: ToolCacheEntry}
self._tool_cache: Dict[str, ToolCacheEntry] = {}
self._cache_ttl = timedelta(
minutes=cache_ttl_minutes or mcp_config.TOOL_CACHE_TTL_MINUTES
)
# 调用指标: {tool_key: ToolMetrics}
self._metrics: Dict[str, ToolMetrics] = defaultdict(ToolMetrics)
# 重试配置(使用配置常量)
self._max_retries = max_retries or mcp_config.MAX_RETRIES
self._base_retry_delay = mcp_config.BASE_RETRY_DELAY_SECONDS
self._max_retry_delay = mcp_config.MAX_RETRY_DELAY_SECONDS
logger.info(
f"✅ MCPToolService初始化完成 "
f"(缓存TTL={self._cache_ttl.total_seconds()/60:.1f}分钟, "
f"最大重试={self._max_retries}次)"
)
async def get_user_enabled_tools(
self,
user_id: str,
db_session: AsyncSession,
category: Optional[str] = None
) -> List[Dict[str, Any]]:
"""
获取用户启用的MCP工具列表
Args:
user_id: 用户ID
db_session: 数据库会话
category: 工具类别筛选search/analysis/filesystem等
Returns:
工具定义列表格式符合OpenAI Function Calling规范
"""
try:
# 1. 查询用户启用的插件(enabled=True即可,不强制要求status=active
# 因为新启用的插件status可能还是inactive,需要给它机会被调用
query = select(MCPPlugin).where(
MCPPlugin.user_id == user_id,
MCPPlugin.enabled == True
)
if category:
query = query.where(MCPPlugin.category == category)
result = await db_session.execute(query)
plugins = result.scalars().all()
if not plugins:
logger.info(f"用户 {user_id} 没有启用的MCP插件")
return []
# 2. 获取所有工具定义(使用缓存)
all_tools = []
for plugin in plugins:
try:
# 确保插件已加载到注册表
if not mcp_registry.get_client(user_id, plugin.plugin_name):
logger.info(f"插件 {plugin.plugin_name} 未加载,尝试加载...")
success = await mcp_registry.load_plugin(plugin)
if not success:
logger.warning(f"插件 {plugin.plugin_name} 加载失败,跳过")
continue
# ✅ 使用缓存获取工具列表
plugin_tools = await self._get_plugin_tools_cached(
user_id=user_id,
plugin_name=plugin.plugin_name
)
# 格式化为Function Calling格式
formatted_tools = self._format_tools_for_ai(
plugin_tools,
plugin.plugin_name
)
all_tools.extend(formatted_tools)
logger.info(
f"从插件 {plugin.plugin_name} 加载了 "
f"{len(formatted_tools)} 个工具"
)
except Exception as e:
logger.error(
f"获取插件 {plugin.plugin_name} 的工具失败: {e}",
exc_info=True
)
continue
logger.info(f"用户 {user_id} 共加载 {len(all_tools)} 个MCP工具")
return all_tools
except Exception as e:
logger.error(f"获取用户MCP工具失败: {e}", exc_info=True)
raise MCPToolServiceError(f"获取MCP工具失败: {str(e)}")
def _format_tools_for_ai(
self,
plugin_tools: List[Dict[str, Any]],
plugin_name: str
) -> List[Dict[str, Any]]:
"""
将MCP工具定义格式化为AI Function Calling格式
Args:
plugin_tools: MCP插件的工具列表
plugin_name: 插件名称
Returns:
格式化后的工具列表
"""
formatted_tools = []
for tool in plugin_tools:
formatted_tool = {
"type": "function",
"function": {
"name": f"{plugin_name}_{tool['name']}", # 加插件前缀避免冲突
"description": tool.get("description", ""),
"parameters": tool.get("inputSchema", {
"type": "object",
"properties": {},
"required": []
})
}
}
formatted_tools.append(formatted_tool)
return formatted_tools
async def _get_plugin_tools_cached(
self,
user_id: str,
plugin_name: str
) -> List[Dict[str, Any]]:
"""
带缓存的工具列表获取
Args:
user_id: 用户ID
plugin_name: 插件名称
Returns:
工具列表
"""
cache_key = f"{user_id}:{plugin_name}"
now = datetime.now()
# 检查缓存
if cache_key in self._tool_cache:
entry = self._tool_cache[cache_key]
if now < entry.expire_time:
entry.hit_count += 1
logger.debug(
f"🎯 工具缓存命中: {cache_key} "
f"(命中次数: {entry.hit_count})"
)
return entry.tools
else:
logger.debug(f"⏰ 工具缓存过期: {cache_key}")
del self._tool_cache[cache_key]
# 缓存未命中,从MCP获取
logger.debug(f"🔍 工具缓存未命中,从MCP获取: {cache_key}")
tools = await mcp_registry.get_plugin_tools(user_id, plugin_name)
# 更新缓存
self._tool_cache[cache_key] = ToolCacheEntry(
tools=tools,
expire_time=now + self._cache_ttl,
hit_count=0
)
return tools
def clear_cache(self, user_id: Optional[str] = None, plugin_name: Optional[str] = None):
"""
清理缓存
Args:
user_id: 用户ID可选清理特定用户的缓存
plugin_name: 插件名称可选清理特定插件的缓存
"""
if user_id is None and plugin_name is None:
# 清理所有缓存
self._tool_cache.clear()
logger.info("🧹 已清理所有工具缓存")
elif user_id and plugin_name:
# 清理特定插件缓存
cache_key = f"{user_id}:{plugin_name}"
if cache_key in self._tool_cache:
del self._tool_cache[cache_key]
logger.info(f"🧹 已清理缓存: {cache_key}")
elif user_id:
# 清理用户所有缓存
keys_to_delete = [
key for key in self._tool_cache.keys()
if key.startswith(f"{user_id}:")
]
for key in keys_to_delete:
del self._tool_cache[key]
logger.info(f"🧹 已清理用户缓存: {user_id} ({len(keys_to_delete)}个)")
async def execute_tool_calls(
self,
user_id: str,
tool_calls: List[Dict[str, Any]],
db_session: AsyncSession,
timeout: Optional[float] = None,
max_concurrent: int = 2
) -> List[Dict[str, Any]]:
"""
批量执行AI请求的工具调用限制并发数避免超时
Args:
user_id: 用户ID
tool_calls: AI返回的工具调用列表
db_session: 数据库会话
timeout: 单个工具调用的超时时间默认使用配置
max_concurrent: 最大并发工具调用数默认2
Returns:
工具调用结果列表
"""
if not tool_calls:
return []
# 使用配置的默认超时
actual_timeout = timeout or mcp_config.TOOL_CALL_TIMEOUT_SECONDS
logger.info(f"开始执行 {len(tool_calls)} 个工具调用 (超时={actual_timeout}s, 最大并发={max_concurrent})")
# ✅ 分批执行,每批最多max_concurrent个
all_results = []
for i in range(0, len(tool_calls), max_concurrent):
batch = tool_calls[i:i+max_concurrent]
batch_num = i // max_concurrent + 1
total_batches = (len(tool_calls) + max_concurrent - 1) // max_concurrent
logger.info(f"执行工具批次 {batch_num}/{total_batches}, 数量: {len(batch)}")
# 创建当前批次的异步任务
tasks = [
self._execute_single_tool(
user_id=user_id,
tool_call=tool_call,
db_session=db_session,
timeout=actual_timeout
)
for tool_call in batch
]
# 并行执行当前批次
batch_results = await asyncio.gather(*tasks, return_exceptions=True)
# 处理批次结果
for j, result in enumerate(batch_results):
tool_call = batch[j]
if isinstance(result, Exception):
# 工具调用异常
all_results.append({
"tool_call_id": tool_call.get("id", f"call_{i+j}"),
"role": "tool",
"name": tool_call["function"]["name"],
"content": f"工具调用失败: {str(result)}",
"success": False,
"error": str(result)
})
else:
all_results.append(result)
# 批次间增加短暂延迟,避免API限流
if i + max_concurrent < len(tool_calls):
await asyncio.sleep(0.5)
logger.debug(f"批次间延迟 0.5 秒...")
return all_results
async def _execute_single_tool(
self,
user_id: str,
tool_call: Dict[str, Any],
db_session: AsyncSession,
timeout: float
) -> Dict[str, Any]:
"""
执行单个工具调用
Args:
user_id: 用户ID
tool_call: 工具调用信息
db_session: 数据库会话
timeout: 超时时间
Returns:
工具调用结果
"""
tool_call_id = tool_call.get("id", "unknown")
function_name = tool_call["function"]["name"]
try:
# 解析插件名和工具名
logger.debug(f"🔍 解析工具名称: {function_name}")
if "_" in function_name:
plugin_name, tool_name = function_name.split("_", 1)
logger.debug(f" 插件: {plugin_name}, 工具: {tool_name}")
else:
raise ValueError(f"无效的工具名称格式: {function_name}")
# 解析参数
arguments_str = tool_call["function"]["arguments"]
logger.debug(f"🔍 解析参数:")
logger.debug(f" 原始类型: {type(arguments_str)}")
logger.debug(f" 原始内容: {arguments_str}")
if isinstance(arguments_str, str):
try:
arguments = json.loads(arguments_str)
logger.debug(f" ✅ JSON解析成功: {arguments}")
except json.JSONDecodeError as je:
logger.error(f" ❌ JSON解析失败: {je}")
logger.error(f" 原始字符串: '{arguments_str}'")
raise ValueError(f"参数JSON解析失败: {je}")
else:
arguments = arguments_str
logger.debug(f" 直接使用dict类型参数")
logger.info(
f"执行工具: {plugin_name}.{tool_name}, "
f"参数: {arguments}"
)
# ✅ 使用带重试的调用
tool_key = f"{plugin_name}.{tool_name}"
start_time = time.time()
try:
result = await self._call_tool_with_retry(
user_id=user_id,
plugin_name=plugin_name,
tool_name=tool_name,
arguments=arguments,
timeout=timeout
)
# 记录成功指标
duration_ms = (time.time() - start_time) * 1000
self._metrics[tool_key].update_success(duration_ms)
logger.info(
f"✅ 工具调用成功: {tool_key} "
f"(耗时: {duration_ms:.2f}ms)"
)
# 成功返回
return {
"tool_call_id": tool_call_id,
"role": "tool",
"name": function_name,
"content": json.dumps(result, ensure_ascii=False),
"success": True,
"error": None
}
except asyncio.TimeoutError:
# 记录失败指标
duration_ms = (time.time() - start_time) * 1000
self._metrics[tool_key].update_failure(duration_ms)
raise MCPToolServiceError(
f"工具调用超时(>{timeout}秒)"
)
except Exception as e:
# 记录失败指标
tool_key = f"{plugin_name}.{tool_name}" if 'plugin_name' in locals() else function_name
duration_ms = (time.time() - start_time) * 1000
self._metrics[tool_key].update_failure(duration_ms)
logger.error(
f"❌ 工具 {function_name} 调用失败: {e}",
exc_info=True
)
return {
"tool_call_id": tool_call_id,
"role": "tool",
"name": function_name,
"content": f"工具调用失败: {str(e)}",
"success": False,
"error": str(e)
}
async def _call_tool_with_retry(
self,
user_id: str,
plugin_name: str,
tool_name: str,
arguments: Dict[str, Any],
timeout: float
) -> Any:
"""
带指数退避重试的工具调用
Args:
user_id: 用户ID
plugin_name: 插件名称
tool_name: 工具名称
arguments: 工具参数
timeout: 超时时间
Returns:
工具执行结果
Raises:
MCPToolServiceError: 工具调用失败
asyncio.TimeoutError: 调用超时
"""
last_exception = None
for attempt in range(self._max_retries):
try:
# 尝试调用工具
result = await asyncio.wait_for(
mcp_registry.call_tool(
user_id=user_id,
plugin_name=plugin_name,
tool_name=tool_name,
arguments=arguments
),
timeout=timeout
)
# 成功则返回
if attempt > 0:
logger.info(
f"✅ 重试成功: {plugin_name}.{tool_name} "
f"(第{attempt + 1}次尝试)"
)
return result
except asyncio.TimeoutError:
# 超时不重试,直接抛出
raise
except Exception as e:
last_exception = e
# 最后一次尝试失败
if attempt == self._max_retries - 1:
logger.error(
f"❌ 重试失败: {plugin_name}.{tool_name} "
f"(已尝试{self._max_retries}次): {e}"
)
raise MCPToolServiceError(
f"工具调用失败(已重试{self._max_retries}次): {str(e)}"
)
# 计算指数退避延迟
delay = min(
self._base_retry_delay * (2 ** attempt),
self._max_retry_delay
)
logger.warning(
f"⚠️ 工具调用失败,{delay:.1f}秒后重试 "
f"(第{attempt + 1}/{self._max_retries}次): "
f"{plugin_name}.{tool_name} - {e}"
)
await asyncio.sleep(delay)
# 理论上不会到这里,但为了类型安全
raise MCPToolServiceError(f"工具调用失败: {last_exception}")
def get_metrics(self, tool_name: Optional[str] = None) -> Dict[str, Dict[str, Any]]:
"""
获取工具调用指标
Args:
tool_name: 工具名称可选获取特定工具的指标
Returns:
指标字典
"""
if tool_name:
if tool_name in self._metrics:
metric = self._metrics[tool_name]
return {
tool_name: {
"total_calls": metric.total_calls,
"success_calls": metric.success_calls,
"failed_calls": metric.failed_calls,
"success_rate": metric.success_rate,
"avg_duration_ms": round(metric.avg_duration_ms, 2),
"last_call_time": metric.last_call_time.isoformat() if metric.last_call_time else None
}
}
return {}
# 返回所有工具的指标
result = {}
for tool_key, metric in self._metrics.items():
result[tool_key] = {
"total_calls": metric.total_calls,
"success_calls": metric.success_calls,
"failed_calls": metric.failed_calls,
"success_rate": round(metric.success_rate, 3),
"avg_duration_ms": round(metric.avg_duration_ms, 2),
"last_call_time": metric.last_call_time.isoformat() if metric.last_call_time else None
}
return result
def get_cache_stats(self) -> Dict[str, Any]:
"""获取缓存统计信息"""
total_entries = len(self._tool_cache)
total_hits = sum(entry.hit_count for entry in self._tool_cache.values())
return {
"total_entries": total_entries,
"total_hits": total_hits,
"cache_ttl_minutes": self._cache_ttl.total_seconds() / 60,
"entries": [
{
"key": key,
"tools_count": len(entry.tools),
"hit_count": entry.hit_count,
"expire_time": entry.expire_time.isoformat()
}
for key, entry in self._tool_cache.items()
]
}
async def build_tool_context(
self,
tool_results: List[Dict[str, Any]],
format: str = "markdown"
) -> str:
"""
将工具调用结果格式化为上下文文本
Args:
tool_results: 工具调用结果列表
format: 输出格式markdown/json/plain
Returns:
格式化的上下文字符串
"""
if not tool_results:
return ""
if format == "markdown":
return self._build_markdown_context(tool_results)
elif format == "json":
return json.dumps(tool_results, ensure_ascii=False, indent=2)
else: # plain
return self._build_plain_context(tool_results)
def _build_markdown_context(
self,
tool_results: List[Dict[str, Any]]
) -> str:
"""构建Markdown格式的工具上下文"""
lines = ["## 🔧 工具调用结果\n"]
for i, result in enumerate(tool_results, 1):
tool_name = result.get("name", "unknown")
success = result.get("success", False)
content = result.get("content", "")
status_emoji = "" if success else ""
lines.append(f"### {status_emoji} {i}. {tool_name}\n")
if success:
# 尝试美化JSON内容
try:
content_obj = json.loads(content)
content = json.dumps(content_obj, ensure_ascii=False, indent=2)
except:
pass
lines.append(f"```json\n{content}\n```\n")
else:
lines.append(f"**错误**: {content}\n")
return "\n".join(lines)
def _build_plain_context(
self,
tool_results: List[Dict[str, Any]]
) -> str:
"""构建纯文本格式的工具上下文"""
lines = ["=== 工具调用结果 ===\n"]
for i, result in enumerate(tool_results, 1):
tool_name = result.get("name", "unknown")
success = result.get("success", False)
content = result.get("content", "")
status = "成功" if success else "失败"
lines.append(f"{i}. {tool_name} - {status}")
lines.append(f" 结果: {content}\n")
return "\n".join(lines)
# 全局单例
mcp_tool_service = MCPToolService()
+235
View File
@@ -0,0 +1,235 @@
"""MCP工具加载器 - 统一的工具获取入口
在AI请求之前自动检查用户MCP配置并加载可用工具
"""
from typing import Optional, List, Dict, Any
from datetime import datetime, timedelta
from dataclasses import dataclass, field
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.logger import get_logger
from app.models.mcp_plugin import MCPPlugin
from app.mcp import mcp_client
logger = get_logger(__name__)
@dataclass
class UserToolsCache:
"""用户工具缓存条目"""
tools: Optional[List[Dict[str, Any]]]
expire_time: datetime
hit_count: int = 0
class MCPToolsLoader:
"""
MCP工具加载器
负责
1. 检查用户是否配置并启用了MCP插件
2. 从各个启用的插件加载工具列表
3. 将工具转换为OpenAI Function Calling格式
4. 缓存结果以提升性能
"""
_instance: Optional['MCPToolsLoader'] = None
def __new__(cls):
"""单例模式"""
if cls._instance is None:
cls._instance = super().__new__(cls)
cls._instance._initialized = False
return cls._instance
def __init__(self):
if self._initialized:
return
# 用户工具缓存: user_id -> UserToolsCache
self._cache: Dict[str, UserToolsCache] = {}
# 缓存TTL5分钟)
self._cache_ttl = timedelta(minutes=5)
self._initialized = True
logger.info("✅ MCPToolsLoader 初始化完成")
async def has_enabled_plugins(
self,
user_id: str,
db_session: AsyncSession
) -> bool:
"""
检查用户是否有启用的MCP插件
Args:
user_id: 用户ID
db_session: 数据库会话
Returns:
是否有启用的插件
"""
try:
query = select(MCPPlugin.id).where(
MCPPlugin.user_id == user_id,
MCPPlugin.enabled == True,
MCPPlugin.plugin_type.in_(["http", "streamable_http", "sse"])
).limit(1)
result = await db_session.execute(query)
return result.scalar() is not None
except Exception as e:
logger.warning(f"检查用户MCP插件失败: {e}")
return False
async def get_user_tools(
self,
user_id: str,
db_session: AsyncSession,
use_cache: bool = True,
force_refresh: bool = False
) -> Optional[List[Dict[str, Any]]]:
"""
获取用户的MCP工具列表OpenAI格式
Args:
user_id: 用户ID
db_session: 数据库会话
use_cache: 是否使用缓存
force_refresh: 是否强制刷新
Returns:
- None: 用户未配置或未启用任何MCP插件
- []: 有配置但没有可用工具
- List[Dict]: OpenAI Function Calling格式的工具列表
"""
now = datetime.now()
# 检查缓存
if use_cache and not force_refresh and user_id in self._cache:
cache_entry = self._cache[user_id]
if now < cache_entry.expire_time:
cache_entry.hit_count += 1
logger.debug(f"🎯 用户工具缓存命中: {user_id} (命中次数: {cache_entry.hit_count})")
return cache_entry.tools
else:
del self._cache[user_id]
logger.debug(f"⏰ 用户工具缓存过期: {user_id}")
# 从数据库加载
try:
tools = await self._load_user_tools(user_id, db_session)
# 更新缓存
self._cache[user_id] = UserToolsCache(
tools=tools,
expire_time=now + self._cache_ttl
)
if tools:
logger.info(f"🔧 用户 {user_id} 加载了 {len(tools)} 个MCP工具")
else:
logger.debug(f"📭 用户 {user_id} 没有可用的MCP工具")
return tools
except Exception as e:
logger.error(f"❌ 加载用户MCP工具失败: {e}")
return None
async def _load_user_tools(
self,
user_id: str,
db_session: AsyncSession
) -> Optional[List[Dict[str, Any]]]:
"""
从数据库加载用户启用的MCP插件并获取工具
"""
# 查询启用的插件
query = select(MCPPlugin).where(
MCPPlugin.user_id == user_id,
MCPPlugin.enabled == True,
MCPPlugin.plugin_type.in_(["http", "streamable_http", "sse"])
).order_by(MCPPlugin.sort_order)
result = await db_session.execute(query)
plugins = result.scalars().all()
if not plugins:
return None
all_tools = []
for plugin in plugins:
try:
# 确定插件类型
plugin_type = plugin.plugin_type
if plugin_type == "http":
plugin_type = "streamable_http" # 默认使用streamable_http
# 确保插件已注册到MCP客户端
await mcp_client.ensure_registered(
user_id=user_id,
plugin_name=plugin.plugin_name,
url=plugin.server_url,
plugin_type=plugin_type,
headers=plugin.headers
)
# 获取工具列表
plugin_tools = await mcp_client.get_tools(user_id, plugin.plugin_name)
# 转换为OpenAI格式
formatted = mcp_client.format_tools_for_openai(plugin_tools, plugin.plugin_name)
all_tools.extend(formatted)
logger.debug(f"✅ 从插件 {plugin.plugin_name} 加载了 {len(formatted)} 个工具")
except Exception as e:
logger.warning(f"⚠️ 加载插件 {plugin.plugin_name} 工具失败: {e}")
continue
return all_tools if all_tools else None
def invalidate_cache(self, user_id: Optional[str] = None):
"""
使缓存失效
Args:
user_id: 用户ID为None时清空所有缓存
"""
if user_id:
if user_id in self._cache:
del self._cache[user_id]
logger.debug(f"🧹 清理用户工具缓存: {user_id}")
else:
count = len(self._cache)
self._cache.clear()
logger.info(f"🧹 清理所有用户工具缓存 ({count}个)")
def get_cache_stats(self) -> Dict[str, Any]:
"""获取缓存统计"""
now = datetime.now()
return {
"total_entries": len(self._cache),
"total_hits": sum(e.hit_count for e in self._cache.values()),
"cache_ttl_minutes": self._cache_ttl.total_seconds() / 60,
"entries": [
{
"user_id": uid,
"tools_count": len(e.tools) if e.tools else 0,
"hit_count": e.hit_count,
"expired": now >= e.expire_time,
"expire_time": e.expire_time.isoformat()
}
for uid, e in self._cache.items()
]
}
# 全局单例
mcp_tools_loader = MCPToolsLoader()
+219 -1
View File
@@ -1,13 +1,231 @@
"""Server-Sent Events (SSE) 响应工具类"""
import json
import asyncio
from typing import AsyncGenerator, Dict, Any, Optional
from enum import Enum
from typing import AsyncGenerator, Dict, Any, Optional, Callable
from dataclasses import dataclass
from fastapi.responses import StreamingResponse
from app.logger import get_logger
logger = get_logger(__name__)
class ProgressStage(Enum):
"""标准化进度阶段枚举"""
# 初始化阶段 (0-5%)
INIT = "init"
# 加载数据阶段 (5-15%)
LOADING = "loading"
# 准备提示词阶段 (15-20%)
PREPARING = "preparing"
# AI生成阶段 (20-85%)
GENERATING = "generating"
# 解析数据阶段 (85-92%)
PARSING = "parsing"
# 保存数据阶段 (92-98%)
SAVING = "saving"
# 完成阶段 (100%)
COMPLETE = "complete"
@dataclass
class StageConfig:
"""阶段配置"""
start: int # 起始进度
end: int # 结束进度
default_message: str # 默认消息
# 标准进度阶段配置
STAGE_CONFIGS: Dict[ProgressStage, StageConfig] = {
ProgressStage.INIT: StageConfig(0, 5, "开始处理..."),
ProgressStage.LOADING: StageConfig(5, 15, "加载数据中..."),
ProgressStage.PREPARING: StageConfig(15, 20, "准备AI提示词..."),
ProgressStage.GENERATING: StageConfig(20, 85, "AI生成中..."),
ProgressStage.PARSING: StageConfig(85, 92, "解析数据..."),
ProgressStage.SAVING: StageConfig(92, 98, "保存到数据库..."),
ProgressStage.COMPLETE: StageConfig(100, 100, "完成!"),
}
class WizardProgressTracker:
"""
向导进度追踪器 - 标准化管理SSE进度推送
使用示例:
tracker = WizardProgressTracker("世界观")
yield await tracker.start()
yield await tracker.loading("加载项目信息")
yield await tracker.preparing()
async for chunk in ai_stream:
yield await tracker.generating_chunk(chunk, len(accumulated))
yield await tracker.parsing()
yield await tracker.saving("保存世界观数据")
yield await tracker.complete()
"""
def __init__(self, task_name: str = "任务"):
"""
初始化进度追踪器
Args:
task_name: 任务名称用于消息前缀
"""
self.task_name = task_name
self.current_stage = ProgressStage.INIT
self.current_progress = 0
self._last_generating_progress = 20 # 生成阶段的最后进度值
def _get_stage_progress(
self,
stage: ProgressStage,
sub_progress: float = 0.0
) -> int:
"""
计算阶段内的进度值
Args:
stage: 当前阶段
sub_progress: 阶段内子进度 (0.0-1.0)
Returns:
总进度值 (0-100)
"""
config = STAGE_CONFIGS[stage]
if sub_progress <= 0:
return config.start
if sub_progress >= 1:
return config.end
return config.start + int((config.end - config.start) * sub_progress)
async def start(self, message: str = None) -> str:
"""开始阶段"""
self.current_stage = ProgressStage.INIT
self.current_progress = 0
msg = message or f"开始生成{self.task_name}..."
return await SSEResponse.send_progress(msg, 0, "processing")
async def loading(self, message: str = None, sub_progress: float = 0.5) -> str:
"""加载数据阶段"""
self.current_stage = ProgressStage.LOADING
progress = self._get_stage_progress(ProgressStage.LOADING, sub_progress)
self.current_progress = progress
msg = message or STAGE_CONFIGS[ProgressStage.LOADING].default_message
return await SSEResponse.send_progress(msg, progress, "processing")
async def preparing(self, message: str = None) -> str:
"""准备提示词阶段"""
self.current_stage = ProgressStage.PREPARING
progress = self._get_stage_progress(ProgressStage.PREPARING, 0.5)
self.current_progress = progress
msg = message or STAGE_CONFIGS[ProgressStage.PREPARING].default_message
return await SSEResponse.send_progress(msg, progress, "processing")
async def generating(
self,
current_chars: int = 0,
estimated_total: int = 5000,
message: str = None,
retry_count: int = 0,
max_retries: int = 3
) -> str:
"""
AI生成阶段进度更新
Args:
current_chars: 当前已生成字符数
estimated_total: 预估总字符数
message: 自定义消息
retry_count: 当前重试次数
max_retries: 最大重试次数
"""
self.current_stage = ProgressStage.GENERATING
# 计算生成进度 (0.0-1.0)
sub_progress = min(current_chars / max(estimated_total, 1), 1.0)
progress = self._get_stage_progress(ProgressStage.GENERATING, sub_progress)
# 确保进度单调递增
if progress < self._last_generating_progress:
progress = self._last_generating_progress
else:
self._last_generating_progress = progress
self.current_progress = progress
# 构建消息
retry_suffix = f" (重试 {retry_count}/{max_retries})" if retry_count > 0 else ""
if message:
msg = f"{message}{retry_suffix}"
else:
msg = f"生成{self.task_name}中... ({current_chars}字符){retry_suffix}"
return await SSEResponse.send_progress(msg, progress, "processing")
async def generating_chunk(self, chunk: str) -> str:
"""发送生成的内容块"""
return await SSEResponse.send_chunk(chunk)
async def parsing(self, message: str = None, sub_progress: float = 0.5) -> str:
"""解析数据阶段"""
self.current_stage = ProgressStage.PARSING
progress = self._get_stage_progress(ProgressStage.PARSING, sub_progress)
self.current_progress = progress
msg = message or f"解析{self.task_name}数据..."
return await SSEResponse.send_progress(msg, progress, "processing")
async def saving(self, message: str = None, sub_progress: float = 0.5) -> str:
"""保存数据阶段"""
self.current_stage = ProgressStage.SAVING
progress = self._get_stage_progress(ProgressStage.SAVING, sub_progress)
self.current_progress = progress
msg = message or f"保存{self.task_name}到数据库..."
return await SSEResponse.send_progress(msg, progress, "processing")
async def complete(self, message: str = None) -> str:
"""完成阶段"""
self.current_stage = ProgressStage.COMPLETE
self.current_progress = 100
msg = message or f"{self.task_name}生成完成!"
return await SSEResponse.send_progress(msg, 100, "success")
async def warning(self, message: str) -> str:
"""发送警告消息(保持当前进度)"""
return await SSEResponse.send_progress(
f"⚠️ {message}",
self.current_progress,
"warning"
)
async def retry(self, retry_count: int, max_retries: int, reason: str = "准备重试") -> str:
"""发送重试消息"""
return await SSEResponse.send_progress(
f"⚠️ {reason}... ({retry_count}/{max_retries})",
self.current_progress,
"warning"
)
async def error(self, error_message: str, code: int = 500) -> str:
"""发送错误消息"""
return await SSEResponse.send_error(error_message, code)
async def result(self, data: Dict[str, Any]) -> str:
"""发送结果数据"""
return await SSEResponse.send_result(data)
async def done(self) -> str:
"""发送完成信号"""
return await SSEResponse.send_done()
async def heartbeat(self) -> str:
"""发送心跳"""
return await SSEResponse.send_heartbeat()
def reset_generating_progress(self):
"""重置生成阶段进度(用于重试时)"""
self._last_generating_progress = 20
class SSEResponse:
"""SSE响应构建器"""
+3 -2
View File
@@ -19,10 +19,11 @@ anthropic==0.72.0
# 工具库
httpx==0.28.1
python-dotenv==1.0.0
python-dotenv==1.1.0
psutil==6.1.1
# MCP官方库(Model Context Protocol Python SDK
mcp==1.21.0
mcp==1.22.0
fastmcp==2.13.3
# NumPy版本锁定(兼容性要求)
+230 -71
View File
@@ -37,6 +37,14 @@ interface GenerationSteps {
outline: GenerationStep;
}
interface WorldBuildingResult {
project_id: string;
time_period: string;
location: string;
atmosphere: string;
rules: string;
}
export const AIProjectGenerator: React.FC<AIProjectGeneratorProps> = ({
config,
storagePrefix,
@@ -64,7 +72,7 @@ export const AIProjectGenerator: React.FC<AIProjectGeneratorProps> = ({
// 保存生成数据,用于重试
const [generationData, setGenerationData] = useState<GenerationConfig | null>(null);
// 保存世界观生成结果,用于后续步骤
const [worldBuildingResult, setWorldBuildingResult] = useState<any>(null);
const [worldBuildingResult, setWorldBuildingResult] = useState<WorldBuildingResult | null>(null);
// LocalStorage 键名
const storageKeys = {
@@ -102,6 +110,7 @@ export const AIProjectGenerator: React.FC<AIProjectGeneratorProps> = ({
handleAutoGenerate(config);
}
}
// eslint-disable-next-line react-hooks/exhaustive-deps
}, [config, resumeProjectId]);
// 恢复未完成项目的生成
@@ -125,33 +134,40 @@ export const AIProjectGenerator: React.FC<AIProjectGeneratorProps> = ({
const wizardStep = project.wizard_step || 0;
// 根据wizard_step判断从哪里继续
// wizard_step: 0=未开始, 1=世界观已完成, 2=职业体系已完成, 3=角色已完成, 4=大纲已完成
// 获取世界观数据(用于后续步骤)
const worldResult = {
project_id: projectIdParam,
time_period: project.world_time_period || '',
location: project.world_location || '',
atmosphere: project.world_atmosphere || '',
rules: project.world_rules || ''
};
if (wizardStep === 0) {
// 从世界观开始
message.info('从世界观步骤开始生成...');
setGenerationSteps({ worldBuilding: 'processing', careers: 'pending', characters: 'pending', outline: 'pending' });
await resumeFromWorldBuilding(data);
} else if (wizardStep === 1) {
// 世界观已完成,从角色开始
message.info('世界观已完成,从角色步骤继续...');
setGenerationSteps({ worldBuilding: 'completed', careers: 'completed', characters: 'processing', outline: 'pending' });
// 获取世界观数据
const worldResult = {
project_id: projectIdParam,
time_period: project.world_time_period || '',
location: project.world_location || '',
atmosphere: project.world_atmosphere || '',
rules: project.world_rules || ''
};
// 世界观已完成,从职业体系开始
message.info('世界观已完成,从职业体系步骤继续...');
setGenerationSteps({ worldBuilding: 'completed', careers: 'processing', characters: 'pending', outline: 'pending' });
setWorldBuildingResult(worldResult);
setProgress(33);
await resumeFromCharacters(data, worldResult);
setProgress(20);
await resumeFromCareers(data, worldResult);
} else if (wizardStep === 2) {
// 世界观和角色已完成,从大纲开始
message.info('世界观和角色已完成,从大纲步骤继续...');
// 职业体系已完成,从角色开始
message.info('职业体系已完成,从角色步骤继续...');
setGenerationSteps({ worldBuilding: 'completed', careers: 'completed', characters: 'processing', outline: 'pending' });
setWorldBuildingResult(worldResult);
setProgress(40);
await resumeFromCharacters(data, worldResult);
} else if (wizardStep === 3) {
// 角色已完成,从大纲开始
message.info('角色已完成,从大纲步骤继续...');
setGenerationSteps({ worldBuilding: 'completed', careers: 'completed', characters: 'completed', outline: 'processing' });
setProgress(66);
setProgress(70);
await resumeFromOutline(data, projectIdParam);
} else {
// 已全部完成
@@ -211,11 +227,47 @@ export const AIProjectGenerator: React.FC<AIProjectGeneratorProps> = ({
}
);
await resumeFromCareers(data, worldResult);
};
// 恢复:从职业体系步骤继续
const resumeFromCareers = async (data: GenerationConfig, worldResult: WorldBuildingResult) => {
const pid = projectId || worldResult.project_id;
setGenerationSteps(prev => ({ ...prev, careers: 'processing' }));
setProgressMessage('正在生成职业体系...');
await wizardStreamApi.generateCareerSystemStream(
{
project_id: pid,
},
{
onProgress: (msg, prog) => {
setProgress(prog);
setProgressMessage(msg);
},
onResult: (result) => {
console.log(`成功生成职业体系:主职业${result.main_careers_count}个,副职业${result.sub_careers_count}`);
setGenerationSteps(prev => ({ ...prev, careers: 'completed' }));
},
onError: (error) => {
console.error('职业体系生成失败:', error);
setErrorDetails(`职业体系生成失败: ${error}`);
setGenerationSteps(prev => ({ ...prev, careers: 'error' }));
setLoading(false);
throw new Error(error);
},
onComplete: () => {
console.log('职业体系生成完成');
}
}
);
await resumeFromCharacters(data, worldResult);
};
// 恢复:从角色步骤继续
const resumeFromCharacters = async (data: GenerationConfig, worldResult: any) => {
const resumeFromCharacters = async (data: GenerationConfig, worldResult: WorldBuildingResult) => {
const genreString = Array.isArray(data.genre) ? data.genre.join('、') : data.genre;
const pid = projectId || worldResult.project_id;
@@ -342,26 +394,11 @@ export const AIProjectGenerator: React.FC<AIProjectGeneratorProps> = ({
// 直接使用后端返回的进度值
setProgress(prog);
setProgressMessage(msg);
// 检测职业体系生成阶段
if (msg.includes('职业体系')) {
if (msg.includes('开始') || msg.includes('生成')) {
setGenerationSteps(prev => ({
...prev,
worldBuilding: 'completed',
careers: 'processing'
}));
}
if (msg.includes('完成') || msg.includes('✅')) {
setGenerationSteps(prev => ({ ...prev, careers: 'completed' }));
}
}
},
onResult: (result) => {
setProjectId(result.project_id);
setWorldBuildingResult(result);
setGenerationSteps(prev => ({ ...prev, worldBuilding: 'completed' }));
// 职业体系状态已在onProgress中更新
},
onError: (error) => {
console.error('世界观生成失败:', error);
@@ -385,7 +422,37 @@ export const AIProjectGenerator: React.FC<AIProjectGeneratorProps> = ({
setWorldBuildingResult(worldResult);
saveProgress(createdProjectId, data, 'generating');
// 步骤2: 生成角色
// 步骤2: 生成职业体系
setGenerationSteps(prev => ({ ...prev, careers: 'processing' }));
setProgressMessage('正在生成职业体系...');
await wizardStreamApi.generateCareerSystemStream(
{
project_id: createdProjectId,
},
{
onProgress: (msg, prog) => {
setProgress(prog);
setProgressMessage(msg);
},
onResult: (result) => {
console.log(`成功生成职业体系:主职业${result.main_careers_count}个,副职业${result.sub_careers_count}`);
setGenerationSteps(prev => ({ ...prev, careers: 'completed' }));
},
onError: (error) => {
console.error('职业体系生成失败:', error);
setErrorDetails(`职业体系生成失败: ${error}`);
setGenerationSteps(prev => ({ ...prev, careers: 'error' }));
setLoading(false);
throw new Error(error);
},
onComplete: () => {
console.log('职业体系生成完成');
}
}
);
// 步骤3: 生成角色
setGenerationSteps(prev => ({ ...prev, characters: 'processing' }));
setProgressMessage('正在生成角色...');
@@ -497,6 +564,9 @@ export const AIProjectGenerator: React.FC<AIProjectGeneratorProps> = ({
if (generationSteps.worldBuilding === 'error') {
message.info('从世界观步骤开始重新生成...');
await retryFromWorldBuilding();
} else if (generationSteps.careers === 'error') {
message.info('从职业体系步骤继续生成...');
await retryFromCareers();
} else if (generationSteps.characters === 'error') {
message.info('从角色步骤继续生成...');
await retryFromCharacters();
@@ -504,9 +574,10 @@ export const AIProjectGenerator: React.FC<AIProjectGeneratorProps> = ({
message.info('从大纲步骤继续生成...');
await retryFromOutline();
}
} catch (error: any) {
} catch (error) {
console.error('智能重试失败:', error);
message.error('重试失败:' + (error.message || '未知错误'));
const errorMessage = error instanceof Error ? error.message : '未知错误';
message.error('重试失败:' + errorMessage);
setLoading(false);
}
};
@@ -537,20 +608,6 @@ export const AIProjectGenerator: React.FC<AIProjectGeneratorProps> = ({
// 直接使用后端返回的进度值
setProgress(prog);
setProgressMessage(msg);
// 检测职业体系生成阶段
if (msg.includes('职业体系')) {
if (msg.includes('开始') || msg.includes('生成')) {
setGenerationSteps(prev => ({
...prev,
worldBuilding: 'completed',
careers: 'processing'
}));
}
if (msg.includes('完成') || msg.includes('✅')) {
setGenerationSteps(prev => ({ ...prev, careers: 'completed' }));
}
}
},
onResult: (result) => {
setProjectId(result.project_id);
@@ -574,17 +631,72 @@ export const AIProjectGenerator: React.FC<AIProjectGeneratorProps> = ({
throw new Error('项目创建失败:未获取到项目ID');
}
await continueFromCharacters(worldResult);
await continueFromCareers(worldResult);
};
// 从职业体系步骤继续
const retryFromCareers = async () => {
if (!worldBuildingResult) {
message.warning('缺少必要数据,无法从职业体系步骤继续');
setLoading(false);
return;
}
const pid = worldBuildingResult.project_id || projectId;
if (!pid) {
message.warning('缺少项目ID,无法从职业体系步骤继续');
setLoading(false);
return;
}
setGenerationSteps(prev => ({ ...prev, careers: 'processing' }));
setProgressMessage('重新生成职业体系...');
await wizardStreamApi.generateCareerSystemStream(
{
project_id: pid,
},
{
onProgress: (msg, prog) => {
setProgress(prog);
setProgressMessage(msg);
},
onResult: (result) => {
console.log(`成功生成职业体系:主职业${result.main_careers_count}个,副职业${result.sub_careers_count}`);
setGenerationSteps(prev => ({ ...prev, careers: 'completed' }));
},
onError: (error) => {
console.error('职业体系生成失败:', error);
setErrorDetails(`职业体系生成失败: ${error}`);
setGenerationSteps(prev => ({ ...prev, careers: 'error' }));
setLoading(false);
throw new Error(error);
},
onComplete: () => {
console.log('职业体系重新生成完成');
}
}
);
await continueFromCharacters(worldBuildingResult);
};
// 从角色步骤继续
const retryFromCharacters = async () => {
if (!generationData || !projectId || !worldBuildingResult) {
if (!generationData || !worldBuildingResult) {
message.warning('缺少必要数据,无法从角色步骤继续');
setLoading(false);
return;
}
// 优先使用 worldBuildingResult 中的 project_id,因为重试可能创建了新项目
const pid = worldBuildingResult.project_id || projectId;
if (!pid) {
message.warning('缺少项目ID,无法从角色步骤继续');
setLoading(false);
return;
}
setGenerationSteps(prev => ({ ...prev, characters: 'processing' }));
setProgressMessage('重新生成角色...');
@@ -592,7 +704,7 @@ export const AIProjectGenerator: React.FC<AIProjectGeneratorProps> = ({
await wizardStreamApi.generateCharactersStream(
{
project_id: projectId,
project_id: pid,
count: generationData.character_count,
world_context: {
time_period: worldBuildingResult.time_period || '',
@@ -626,23 +738,31 @@ export const AIProjectGenerator: React.FC<AIProjectGeneratorProps> = ({
}
);
await continueFromOutline();
await continueFromOutline(pid);
};
// 从大纲步骤继续
const retryFromOutline = async () => {
if (!generationData || !projectId) {
if (!generationData) {
message.warning('缺少必要数据,无法从大纲步骤继续');
setLoading(false);
return;
}
// 优先使用 worldBuildingResult 中的 project_idfallback 到状态中的 projectId
const pid = (worldBuildingResult?.project_id) || projectId;
if (!pid) {
message.warning('缺少项目ID,无法从大纲步骤继续');
setLoading(false);
return;
}
setGenerationSteps(prev => ({ ...prev, outline: 'processing' }));
setProgressMessage('重新生成大纲...');
await wizardStreamApi.generateCompleteOutlineStream(
{
project_id: projectId,
project_id: pid,
chapter_count: generationData.chapter_count,
narrative_perspective: generationData.narrative_perspective,
target_words: generationData.target_words,
@@ -676,20 +796,59 @@ export const AIProjectGenerator: React.FC<AIProjectGeneratorProps> = ({
setLoading(false);
// 调用完成回调
if (projectId) {
onComplete(projectId);
if (pid) {
onComplete(pid);
// 延迟1秒后自动跳转到项目详情页
setTimeout(() => {
navigate(`/project/${projectId}`);
navigate(`/project/${pid}`);
}, 1000);
}
};
// 从角色步骤开始的完整流程
const continueFromCharacters = async (worldResult: any) => {
// 从职业体系步骤开始的完整流程
const continueFromCareers = async (worldResult: WorldBuildingResult) => {
if (!generationData || !worldResult?.project_id) return;
const pid = worldResult.project_id;
setGenerationSteps(prev => ({ ...prev, careers: 'processing' }));
setProgressMessage('正在生成职业体系...');
await wizardStreamApi.generateCareerSystemStream(
{
project_id: pid,
},
{
onProgress: (msg, prog) => {
setProgress(prog);
setProgressMessage(msg);
},
onResult: (result) => {
console.log(`成功生成职业体系:主职业${result.main_careers_count}个,副职业${result.sub_careers_count}`);
setGenerationSteps(prev => ({ ...prev, careers: 'completed' }));
},
onError: (error) => {
console.error('职业体系生成失败:', error);
setErrorDetails(`职业体系生成失败: ${error}`);
setGenerationSteps(prev => ({ ...prev, careers: 'error' }));
setLoading(false);
throw new Error(error);
},
onComplete: () => {
console.log('职业体系生成完成');
}
}
);
await continueFromCharacters(worldResult);
};
// 从角色步骤开始的完整流程
const continueFromCharacters = async (worldResult: WorldBuildingResult) => {
if (!generationData || !worldResult?.project_id) return;
const pid = worldResult.project_id;
const genreString = Array.isArray(generationData.genre) ? generationData.genre.join('、') : generationData.genre;
setGenerationSteps(prev => ({ ...prev, characters: 'processing' }));
@@ -697,7 +856,7 @@ export const AIProjectGenerator: React.FC<AIProjectGeneratorProps> = ({
await wizardStreamApi.generateCharactersStream(
{
project_id: worldResult.project_id,
project_id: pid,
count: generationData.character_count,
world_context: {
time_period: worldResult.time_period || '',
@@ -731,19 +890,19 @@ export const AIProjectGenerator: React.FC<AIProjectGeneratorProps> = ({
}
);
await continueFromOutline();
await continueFromOutline(pid);
};
// 从大纲步骤开始的完整流程
const continueFromOutline = async () => {
if (!generationData || !projectId) return;
const continueFromOutline = async (pid: string) => {
if (!generationData || !pid) return;
setGenerationSteps(prev => ({ ...prev, outline: 'processing' }));
setProgressMessage('正在生成大纲...');
await wizardStreamApi.generateCompleteOutlineStream(
{
project_id: projectId,
project_id: pid,
chapter_count: generationData.chapter_count,
narrative_perspective: generationData.narrative_perspective,
target_words: generationData.target_words,
@@ -777,12 +936,12 @@ export const AIProjectGenerator: React.FC<AIProjectGeneratorProps> = ({
setLoading(false);
// 调用完成回调
if (projectId) {
onComplete(projectId);
if (pid) {
onComplete(pid);
// 延迟1秒后自动跳转到项目详情页
setTimeout(() => {
navigate(`/project/${projectId}`);
navigate(`/project/${pid}`);
}, 1000);
}
};
+422 -77
View File
@@ -28,8 +28,11 @@ import {
InfoCircleOutlined,
ToolOutlined,
ArrowLeftOutlined,
ApiOutlined,
QuestionCircleOutlined,
WarningOutlined,
} from '@ant-design/icons';
import { mcpPluginApi } from '../services/api';
import { mcpPluginApi, settingsApi } from '../services/api';
import type { MCPPlugin, MCPTool } from '../types';
const { Paragraph, Text, Title } = Typography;
@@ -46,24 +49,112 @@ export default function MCPPluginsPage() {
const [editingPlugin, setEditingPlugin] = useState<MCPPlugin | null>(null);
const [testingPluginId, setTestingPluginId] = useState<string | null>(null);
const [viewingTools, setViewingTools] = useState<{ pluginId: string; tools: MCPTool[] } | null>(null);
const [checkingFunctionCalling, setCheckingFunctionCalling] = useState(false);
const [modelSupportStatus, setModelSupportStatus] = useState<'unknown' | 'supported' | 'unsupported'>('unknown');
useEffect(() => {
loadPlugins();
}, []);
const initPage = async () => {
setLoading(true);
try {
// 1. 并行获取插件列表和当前设置
const [pluginsData, settings] = await Promise.all([
mcpPluginApi.getPlugins(),
settingsApi.getSettings()
]);
setPlugins(pluginsData);
// 2. 检查配置一致性
const verifiedConfigStr = localStorage.getItem('mcp_verified_config');
if (verifiedConfigStr) {
try {
const verifiedConfig = JSON.parse(verifiedConfigStr);
const currentConfig = {
provider: settings.api_provider,
baseUrl: settings.api_base_url,
model: settings.llm_model
};
// 比较关键配置是否发生变更
const isConfigChanged =
verifiedConfig.provider !== currentConfig.provider ||
verifiedConfig.baseUrl !== currentConfig.baseUrl ||
verifiedConfig.model !== currentConfig.model;
if (isConfigChanged) {
// 配置已变更
setModelSupportStatus('unknown');
// 检查是否有正在运行的插件
const activePlugins = pluginsData.filter(p => p.enabled);
if (activePlugins.length > 0) {
// 自动禁用所有插件
message.loading({ content: '检测到模型配置变更,正在为了安全自动禁用插件...', key: 'auto_disable' });
await Promise.all(activePlugins.map(p => mcpPluginApi.togglePlugin(p.id, false)));
// 重新加载插件列表状态
const updatedPlugins = await mcpPluginApi.getPlugins();
setPlugins(updatedPlugins);
message.success({ content: '已自动禁用所有插件,请重新检测模型能力', key: 'auto_disable' });
modal.warning({
title: '配置变更提醒',
centered: true,
content: '检测到您更换了 AI 模型或接口地址。为了防止错误调用,系统已自动暂停所有 MCP 插件。请重新进行"模型能力检查",确认新模型支持 Function Calling 后再启用插件。',
okText: '知道了',
});
} else {
// 没有运行中的插件,仅提示
message.info('检测到模型配置已变更,请重新检测模型能力');
}
// 清除旧的验证状态
localStorage.removeItem('mcp_verified_config');
} else {
// 配置未变更,恢复验证状态(根据缓存的状态恢复)
const cachedStatus = verifiedConfig.status || 'supported';
setModelSupportStatus(cachedStatus as 'unknown' | 'supported' | 'unsupported');
}
} catch (e) {
console.error('Failed to parse verified config:', e);
localStorage.removeItem('mcp_verified_config');
}
}
} catch (error) {
console.error('Init page failed:', error);
message.error('页面初始化失败');
} finally {
setLoading(false);
}
};
initPage();
}, [modal]);
const loadPlugins = async () => {
setLoading(true);
try {
const data = await mcpPluginApi.getPlugins();
setPlugins(data);
} catch (error) {
console.error('Load plugins failed:', error);
message.error('加载插件列表失败');
} finally {
setLoading(false);
}
};
const handleCreate = () => {
if (modelSupportStatus !== 'supported') {
modal.confirm({
title: '模型能力检查',
centered: true,
icon: <WarningOutlined />,
content: '为了确保 MCP 插件正常工作,您当前使用的 AI 模型必须支持 Function Calling(工具调用)能力。请先进行模型支持检测。',
okText: '去检测',
cancelText: '取消',
onOk: handleCheckFunctionCalling,
});
return;
}
setEditingPlugin(null);
form.resetFields();
form.setFieldsValue({
@@ -86,7 +177,7 @@ export default function MCPPluginsPage() {
setEditingPlugin(plugin);
// 重构为标准MCP配置格式
const mcpConfig: any = {
const mcpConfig: Record<string, Record<string, Record<string, unknown>>> = {
mcpServers: {
[plugin.plugin_name]: {
type: plugin.plugin_type || 'http'
@@ -94,7 +185,7 @@ export default function MCPPluginsPage() {
}
};
if (plugin.plugin_type === 'http') {
if (plugin.plugin_type === 'http' || plugin.plugin_type === 'streamable_http' || plugin.plugin_type === 'sse') {
mcpConfig.mcpServers[plugin.plugin_name].url = plugin.server_url;
mcpConfig.mcpServers[plugin.plugin_name].headers = plugin.headers || {};
} else {
@@ -125,6 +216,7 @@ export default function MCPPluginsPage() {
message.success('插件已删除');
loadPlugins();
} catch (error) {
console.error('Delete plugin failed:', error);
message.error('删除插件失败');
}
},
@@ -137,6 +229,7 @@ export default function MCPPluginsPage() {
message.success(enabled ? '插件已启用' : '插件已禁用');
loadPlugins();
} catch (error) {
console.error('Toggle plugin failed:', error);
message.error('切换插件状态失败');
}
};
@@ -150,45 +243,62 @@ export default function MCPPluginsPage() {
await loadPlugins();
if (result.success) {
const suggestions = result.suggestions || [];
const aiChoice = suggestions.find((s: string) => s.startsWith('🤖'))?.replace('🤖 AI选择: ', '') || '';
const paramsStr = suggestions.find((s: string) => s.startsWith('📝'))?.replace('📝 参数: ', '') || '';
const callTime = suggestions.find((s: string) => s.startsWith('⏱️'))?.replace('⏱️ 耗时: ', '') || '';
const resultStr = suggestions.find((s: string) => s.startsWith('📊'))?.replace('📊 结果:\n', '') || '';
modal.success({
title: '测试成功',
title: '🎉 测试成功',
centered: true,
width: isMobile ? '90%' : 600,
width: isMobile ? '95%' : 700,
content: (
<div style={{ padding: '8px 0' }}>
<div style={{ marginBottom: 24, padding: 16, background: 'var(--color-success-bg)', border: '1px solid var(--color-success-border)', borderRadius: 8 }}>
<Typography.Text strong style={{ color: 'var(--color-success)' }}>
<div style={{ marginBottom: 16, padding: 12, background: 'var(--color-success-bg)', border: '1px solid var(--color-success-border)', borderRadius: 8 }}>
<Typography.Text strong style={{ color: 'var(--color-success)', fontSize: 14 }}>
{result.message}
</Typography.Text>
</div>
{(result.tools_count !== undefined || result.response_time_ms !== undefined) && (
<div style={{
padding: 16,
background: 'var(--color-bg-layout)',
borderRadius: 8,
marginBottom: 16
}}>
{result.tools_count !== undefined && (
<div style={{ marginBottom: 8, fontSize: 14 }}>
<Text type="secondary"></Text>
<Text strong>{result.tools_count}</Text>
</div>
)}
{result.response_time_ms !== undefined && (
<div style={{ fontSize: 14 }}>
<Text type="secondary"></Text>
<Text strong>{result.response_time_ms}ms</Text>
</div>
)}
<div style={{ display: 'grid', gridTemplateColumns: isMobile ? '1fr' : '1fr 1fr', gap: 12, marginBottom: 16 }}>
<div style={{ padding: 12, background: 'var(--color-bg-layout)', borderRadius: 8 }}>
<Text type="secondary" style={{ fontSize: 12 }}></Text>
<div><Text strong style={{ fontSize: 20 }}>{result.tools_count || 0}</Text></div>
</div>
<div style={{ padding: 12, background: 'var(--color-bg-layout)', borderRadius: 8 }}>
<Text type="secondary" style={{ fontSize: 12 }}></Text>
<div><Text strong style={{ fontSize: 20 }}>{result.response_time_ms?.toFixed(0) || 0}ms</Text></div>
</div>
</div>
{aiChoice && (
<div style={{ marginBottom: 12, padding: 12, background: 'var(--color-info-bg)', borderRadius: 8, border: '1px solid var(--color-info-border)' }}>
<Text type="secondary" style={{ fontSize: 12, display: 'block', marginBottom: 4 }}>🤖 AI选择的工具</Text>
<Text code strong>{aiChoice}</Text>
{callTime && <Tag color="blue" style={{ marginLeft: 8 }}>{callTime}</Tag>}
</div>
)}
<Alert
message='插件状态已自动更新为"运行中"'
type="success"
showIcon
/>
{paramsStr && (
<div style={{ marginBottom: 12 }}>
<Text type="secondary" style={{ fontSize: 12, display: 'block', marginBottom: 4 }}>📝 </Text>
<pre style={{ margin: 0, padding: 8, background: 'var(--color-bg-layout)', borderRadius: 4, fontSize: 12, overflow: 'auto', maxHeight: 100 }}>
{(() => { try { return JSON.stringify(JSON.parse(paramsStr), null, 2); } catch { return paramsStr; } })()}
</pre>
</div>
)}
{resultStr && (
<div style={{ marginBottom: 12 }}>
<Text type="secondary" style={{ fontSize: 12, display: 'block', marginBottom: 4 }}>📊 </Text>
<pre style={{ margin: 0, padding: 8, background: 'var(--color-bg-layout)', borderRadius: 4, fontSize: 11, overflow: 'auto', maxHeight: 150, whiteSpace: 'pre-wrap', wordBreak: 'break-word' }}>
{resultStr}
</pre>
</div>
)}
<Alert message='插件状态已自动更新为"运行中"' type="success" showIcon />
</div>
),
});
@@ -248,7 +358,7 @@ export default function MCPPluginsPage() {
),
});
}
} catch (error: any) {
} catch {
message.error('测试插件失败');
} finally {
setTestingPluginId(null);
@@ -260,17 +370,181 @@ export default function MCPPluginsPage() {
const result = await mcpPluginApi.getPluginTools(pluginId);
setViewingTools({ pluginId, tools: result.tools });
} catch (error) {
console.error('Get tools failed:', error);
message.error('获取工具列表失败');
}
};
const handleSubmit = async (values: any) => {
const handleCheckFunctionCalling = async () => {
// 从设置中获取当前配置
setCheckingFunctionCalling(true);
try {
const settings = await settingsApi.getSettings();
if (!settings.api_key || !settings.llm_model) {
message.warning('请先在设置页面配置 API Key 和模型');
return;
}
const result = await settingsApi.checkFunctionCalling({
api_key: settings.api_key,
api_base_url: settings.api_base_url || '',
provider: settings.api_provider || 'openai',
llm_model: settings.llm_model,
});
// 无论成功失败,都缓存当前测试的配置和状态
const configToCache = {
provider: settings.api_provider,
baseUrl: settings.api_base_url,
model: settings.llm_model,
status: result.success && result.supported ? 'supported' : 'unsupported',
testedAt: new Date().toISOString()
};
localStorage.setItem('mcp_verified_config', JSON.stringify(configToCache));
if (result.success && result.supported) {
setModelSupportStatus('supported');
modal.success({
title: '✅ Function Calling 支持检测',
centered: true,
width: isMobile ? '95%' : 700,
content: (
<div style={{ padding: '8px 0' }}>
<div style={{ marginBottom: 16, padding: 12, background: 'var(--color-success-bg)', border: '1px solid var(--color-success-border)', borderRadius: 8 }}>
<Typography.Text strong style={{ color: 'var(--color-success)', fontSize: 14 }}>
{result.message}
</Typography.Text>
</div>
<div style={{ display: 'grid', gridTemplateColumns: isMobile ? '1fr' : '1fr 1fr', gap: 12, marginBottom: 16 }}>
<div style={{ padding: 12, background: 'var(--color-bg-layout)', borderRadius: 8 }}>
<Text type="secondary" style={{ fontSize: 12 }}>API </Text>
<div><Text strong style={{ fontSize: 16 }}>{result.provider}</Text></div>
</div>
<div style={{ padding: 12, background: 'var(--color-bg-layout)', borderRadius: 8 }}>
<Text type="secondary" style={{ fontSize: 12 }}></Text>
<div><Text strong style={{ fontSize: 16 }}>{result.response_time_ms?.toFixed(0) || 0}ms</Text></div>
</div>
</div>
<div style={{ marginBottom: 12, padding: 12, background: 'var(--color-info-bg)', borderRadius: 8, border: '1px solid var(--color-info-border)' }}>
<Text type="secondary" style={{ fontSize: 12, display: 'block', marginBottom: 4 }}>🔧 </Text>
<Text code strong>{result.model}</Text>
{result.details?.finish_reason && (
<Tag color="green" style={{ marginLeft: 8 }}>finish_reason: {result.details.finish_reason}</Tag>
)}
</div>
{result.details && (
<div style={{ marginBottom: 12 }}>
<Text type="secondary" style={{ fontSize: 12, display: 'block', marginBottom: 4 }}>📊 </Text>
<div style={{ padding: 8, background: 'var(--color-bg-layout)', borderRadius: 4, fontSize: 12 }}>
<div> : {result.details.tool_call_count || 0}</div>
<div> : {result.details.test_tool || 'N/A'}</div>
<div> : {result.details.response_type || 'N/A'}</div>
</div>
</div>
)}
{result.tool_calls && result.tool_calls.length > 0 && (
<div style={{ marginBottom: 12 }}>
<Text type="secondary" style={{ fontSize: 12, display: 'block', marginBottom: 4 }}>🔨 </Text>
<pre style={{ margin: 0, padding: 8, background: 'var(--color-bg-layout)', borderRadius: 4, fontSize: 11, overflow: 'auto', maxHeight: 150 }}>
{JSON.stringify(result.tool_calls[0], null, 2)}
</pre>
</div>
)}
{result.suggestions && result.suggestions.length > 0 && (
<div style={{ padding: 12, background: 'var(--color-success-bg)', border: '1px solid var(--color-success-border)', borderRadius: 8 }}>
<Text strong style={{ fontSize: 13, display: 'block', marginBottom: 8 }}>💡 </Text>
<ul style={{ margin: 0, paddingLeft: 20, fontSize: 12 }}>
{result.suggestions.map((s: string, i: number) => (
<li key={i} style={{ marginBottom: 4 }}>{s}</li>
))}
</ul>
</div>
)}
</div>
),
});
} else {
setModelSupportStatus('unsupported');
modal.warning({
title: '❌ Function Calling 支持检测',
centered: true,
width: isMobile ? '95%' : 700,
content: (
<div style={{ padding: '8px 0' }}>
<div style={{ marginBottom: 16 }}>
<Alert
message={result.message || '模型不支持 Function Calling'}
type="warning"
showIcon
/>
</div>
{result.error && (
<div style={{
padding: 16,
background: 'var(--color-warning-bg)',
border: '1px solid var(--color-warning-border)',
borderRadius: 8,
marginBottom: 16
}}>
<Text strong style={{ fontSize: 14, display: 'block', marginBottom: 8 }}>:</Text>
<Text style={{ fontSize: 13, fontFamily: 'monospace' }}>
{result.error}
</Text>
</div>
)}
{result.response_preview && (
<div style={{ marginBottom: 12 }}>
<Text type="secondary" style={{ fontSize: 12, display: 'block', marginBottom: 4 }}>📝 200</Text>
<pre style={{ margin: 0, padding: 8, background: 'var(--color-bg-layout)', borderRadius: 4, fontSize: 11, overflow: 'auto', maxHeight: 100, whiteSpace: 'pre-wrap' }}>
{result.response_preview}
</pre>
</div>
)}
{result.suggestions && result.suggestions.length > 0 && (
<div style={{
padding: 16,
background: 'var(--color-info-bg)',
border: '1px solid var(--color-info-border)',
borderRadius: 8
}}>
<Text strong style={{ fontSize: 14, display: 'block', marginBottom: 8 }}>💡 :</Text>
<ul style={{ margin: 0, paddingLeft: 20, fontSize: 13 }}>
{result.suggestions.map((s: string, i: number) => (
<li key={i} style={{ marginBottom: 4 }}>{s}</li>
))}
</ul>
</div>
)}
</div>
),
});
}
} catch (error) {
console.error('Check function calling failed:', error);
message.error('检测失败,请稍后重试');
setModelSupportStatus('unsupported');
} finally {
setCheckingFunctionCalling(false);
}
};
const handleSubmit = async (values: { config_json: string; enabled: boolean; category?: string }) => {
setLoading(true);
try {
// 验证JSON格式
try {
JSON.parse(values.config_json);
} catch (e) {
} catch {
message.error('配置JSON格式错误,请检查');
setLoading(false);
return;
@@ -289,8 +563,9 @@ export default function MCPPluginsPage() {
setModalVisible(false);
form.resetFields();
loadPlugins();
} catch (error: any) {
const errorMsg = error?.response?.data?.detail || '操作失败';
} catch (error: unknown) {
const err = error as { response?: { data?: { detail?: string } } };
const errorMsg = err?.response?.data?.detail || '操作失败';
message.error(errorMsg);
} finally {
setLoading(false);
@@ -407,38 +682,104 @@ export default function MCPPluginsPage() {
</Col>
</Row>
{/* 使用提示 */}
<Alert
message={
<Space align="center">
<InfoCircleOutlined style={{ fontSize: 16, color: 'var(--color-primary)' }} />
<Text strong style={{ fontSize: isMobile ? 13 : 14, color: 'var(--color-text-primary)' }}> MCP </Text>
</Space>
}
description={
<div>
<Text style={{ fontSize: isMobile ? 12 : 13, display: 'block', marginBottom: 8 }}>
<strong>MCP (Model Context Protocol)</strong> AI
</Text>
<Text style={{ fontSize: isMobile ? 12 : 13, display: 'block' }}>
MCP AI 访API
</Text>
<div style={{ marginTop: isMobile ? 16 : 24, display: 'flex', gap: 16, flexDirection: isMobile ? 'column' : 'row' }}>
<Card
variant="borderless"
style={{
flex: 1,
borderRadius: 12,
background: 'rgba(255, 255, 255, 0.9)',
border: '1px solid rgba(255, 255, 255, 0.6)',
backdropFilter: 'blur(10px)',
boxShadow: '0 4px 12px rgba(0, 0, 0, 0.03)'
}}
bodyStyle={{ padding: 20 }}
>
<div style={{ display: 'flex', justifyContent: 'space-between', alignItems: 'center' }}>
<Space align="start">
<div style={{
width: 40, height: 40, borderRadius: '50%',
background: modelSupportStatus === 'supported' ? 'var(--color-success-bg)' : modelSupportStatus === 'unsupported' ? 'var(--color-error-bg)' : 'var(--color-info-bg)',
display: 'flex', alignItems: 'center', justifyContent: 'center',
border: `1px solid ${modelSupportStatus === 'supported' ? 'var(--color-success-border)' : modelSupportStatus === 'unsupported' ? 'var(--color-error-border)' : 'var(--color-info-border)'}`
}}>
{modelSupportStatus === 'supported' ? (
<CheckCircleOutlined style={{ fontSize: 20, color: 'var(--color-success)' }} />
) : modelSupportStatus === 'unsupported' ? (
<CloseCircleOutlined style={{ fontSize: 20, color: 'var(--color-error)' }} />
) : (
<QuestionCircleOutlined style={{ fontSize: 20, color: 'var(--color-info)' }} />
)}
</div>
<div>
<Text strong style={{ fontSize: 16, display: 'block', color: 'var(--color-text-primary)' }}></Text>
<Text type="secondary" style={{ fontSize: 13 }}>
{modelSupportStatus === 'supported'
? '当前模型支持 Function Calling,可正常使用 MCP 插件'
: modelSupportStatus === 'unsupported'
? '当前模型不支持 Function Calling,无法使用 MCP 插件'
: '请先检测模型是否支持 Function Calling 能力'}
</Text>
</div>
</Space>
<Button
type={modelSupportStatus === 'supported' ? 'default' : 'primary'}
icon={<ApiOutlined />}
onClick={handleCheckFunctionCalling}
loading={checkingFunctionCalling}
style={{ borderRadius: 8 }}
>
{modelSupportStatus === 'unknown' ? '开始检测' : '重新检测'}
</Button>
</div>
}
type="info"
showIcon={false}
style={{
marginTop: isMobile ? 16 : 24,
borderRadius: 12,
background: 'rgba(230, 247, 255, 0.6)',
border: '1px solid rgba(145, 213, 255, 0.6)',
backdropFilter: 'blur(5px)'
}}
/>
</Card>
<Card
variant="borderless"
style={{
flex: 1,
borderRadius: 12,
background: 'rgba(230, 247, 255, 0.6)',
border: '1px solid rgba(145, 213, 255, 0.6)',
backdropFilter: 'blur(10px)',
boxShadow: '0 4px 12px rgba(0, 0, 0, 0.03)'
}}
bodyStyle={{ padding: 20 }}
>
<Space align="start">
<InfoCircleOutlined style={{ fontSize: 20, color: 'var(--color-primary)', marginTop: 4 }} />
<div>
<Text strong style={{ fontSize: 16, display: 'block', color: 'var(--color-text-primary)', marginBottom: 4 }}> MCP </Text>
<Text style={{ fontSize: 13, display: 'block', color: 'var(--color-text-secondary)', lineHeight: 1.6 }}>
MCP (Model Context Protocol) AI AI 访API
</Text>
</div>
</Space>
</Card>
</div>
</Card>
{/* 主内容区 */}
<div style={{ flex: 1 }}>
{/* 模型能力未验证时的警告提示 */}
{modelSupportStatus !== 'supported' && plugins.length > 0 && (
<Alert
message={
modelSupportStatus === 'unsupported'
? '当前模型不支持 Function Calling,所有插件操作已禁用'
: '请先完成模型能力检查,才能操作插件'
}
type={modelSupportStatus === 'unsupported' ? 'error' : 'warning'}
showIcon
icon={modelSupportStatus === 'unsupported' ? <CloseCircleOutlined /> : <WarningOutlined />}
style={{ marginBottom: 16, borderRadius: 8 }}
action={
<Button size="small" type="primary" onClick={handleCheckFunctionCalling} loading={checkingFunctionCalling}>
{modelSupportStatus === 'unknown' ? '开始检测' : '重新检测'}
</Button>
}
/>
)}
{/* 插件列表 */}
<Spin spinning={loading}>
@@ -479,7 +820,7 @@ export default function MCPPluginsPage() {
{plugin.display_name || plugin.plugin_name}
</Text>
{getStatusTag(plugin)}
<Tag color={plugin.plugin_type === 'http' ? 'blue' : 'cyan'}>
<Tag color={plugin.plugin_type === 'http' || plugin.plugin_type === 'streamable_http' || plugin.plugin_type === 'sse' ? 'blue' : 'cyan'}>
{plugin.plugin_type?.toUpperCase() || 'UNKNOWN'}
</Tag>
{plugin.category && plugin.category !== 'general' && (
@@ -500,7 +841,7 @@ export default function MCPPluginsPage() {
)}
{/* 只显示有值的URL或命令,脱敏处理敏感信息 */}
{plugin.plugin_type === 'http' && plugin.server_url && (
{(plugin.plugin_type === 'http' || plugin.plugin_type === 'streamable_http' || plugin.plugin_type === 'sse') && plugin.server_url && (
<div style={{ fontSize: isMobile ? '11px' : '12px' }}>
<Text type="secondary" code>
{(() => {
@@ -551,9 +892,10 @@ export default function MCPPluginsPage() {
<Space size="small" wrap>
<Switch
title={plugin.enabled ? '禁用插件' : '启用插件'}
title={modelSupportStatus !== 'supported' ? '请先完成模型能力检查' : (plugin.enabled ? '禁用插件' : '启用插件')}
checked={plugin.enabled}
onChange={(checked) => handleToggle(plugin, checked)}
disabled={modelSupportStatus !== 'supported'}
size={isMobile ? 'small' : 'default'}
style={{
flexShrink: 0,
@@ -563,30 +905,33 @@ export default function MCPPluginsPage() {
}}
/>
<Button
title="测试连接"
title={modelSupportStatus !== 'supported' ? '请先完成模型能力检查' : '测试连接'}
icon={<ThunderboltOutlined />}
onClick={() => handleTest(plugin.id)}
loading={testingPluginId === plugin.id}
disabled={modelSupportStatus !== 'supported'}
size={isMobile ? 'small' : 'middle'}
/>
<Button
title="查看工具"
title={modelSupportStatus !== 'supported' ? '请先完成模型能力检查' : '查看工具'}
icon={<ToolOutlined />}
onClick={() => handleViewTools(plugin.id)}
disabled={!plugin.enabled || plugin.status !== 'active'}
disabled={modelSupportStatus !== 'supported' || !plugin.enabled || plugin.status !== 'active'}
size={isMobile ? 'small' : 'middle'}
/>
<Button
title="编辑"
title={modelSupportStatus !== 'supported' ? '请先完成模型能力检查' : '编辑'}
icon={<EditOutlined />}
onClick={() => handleEdit(plugin)}
disabled={modelSupportStatus !== 'supported'}
size={isMobile ? 'small' : 'middle'}
/>
<Button
title="删除"
title={modelSupportStatus !== 'supported' ? '请先完成模型能力检查' : '删除'}
danger
icon={<DeleteOutlined />}
onClick={() => handleDelete(plugin)}
disabled={modelSupportStatus !== 'supported'}
size={isMobile ? 'small' : 'middle'}
/>
</Space>
@@ -627,7 +972,7 @@ export default function MCPPluginsPage() {
{
"mcpServers": {
"exa": {
"type": "http",
"type": "streamable_http",
"url": "https://mcp.exa.ai/mcp?exaApiKey=YOUR_API_KEY",
"headers": {}
}
+162 -2
View File
@@ -1,8 +1,8 @@
import { useState, useEffect } from 'react';
import { useNavigate } from 'react-router-dom';
import { Card, Form, Input, Button, Select, Slider, InputNumber, message, Space, Typography, Spin, Modal, Alert, Grid, Tabs, List, Tag, Popconfirm, Empty, Row, Col } from 'antd';
import { SettingOutlined, SaveOutlined, DeleteOutlined, ReloadOutlined, ArrowLeftOutlined, InfoCircleOutlined, CheckCircleOutlined, CloseCircleOutlined, ThunderboltOutlined, PlusOutlined, EditOutlined, CopyOutlined } from '@ant-design/icons';
import { settingsApi } from '../services/api';
import { SettingOutlined, SaveOutlined, DeleteOutlined, ReloadOutlined, ArrowLeftOutlined, InfoCircleOutlined, CheckCircleOutlined, CloseCircleOutlined, ThunderboltOutlined, PlusOutlined, EditOutlined, CopyOutlined, WarningOutlined } from '@ant-design/icons';
import { settingsApi, mcpPluginApi } from '../services/api';
import type { SettingsUpdate, APIKeyPreset, PresetCreateRequest, APIKeyPresetConfig } from '../types';
const { Title, Text } = Typography;
@@ -95,10 +95,86 @@ export default function SettingsPage() {
const handleSave = async (values: SettingsUpdate) => {
setLoading(true);
try {
// 检查是否与 MCP 缓存的配置不一致
const verifiedConfigStr = localStorage.getItem('mcp_verified_config');
let configChanged = false;
if (verifiedConfigStr) {
try {
const verifiedConfig = JSON.parse(verifiedConfigStr);
configChanged =
verifiedConfig.provider !== values.api_provider ||
verifiedConfig.baseUrl !== values.api_base_url ||
verifiedConfig.model !== values.llm_model;
} catch (e) {
console.error('Failed to parse verified config:', e);
}
}
await settingsApi.saveSettings(values);
message.success('设置已保存');
setHasSettings(true);
setIsDefaultSettings(false);
// 如果配置发生变化,需要处理 MCP 插件
if (configChanged) {
// 清除 MCP 验证缓存
localStorage.removeItem('mcp_verified_config');
// 检查并禁用所有 MCP 插件
try {
const plugins = await mcpPluginApi.getPlugins();
const activePlugins = plugins.filter(p => p.enabled);
if (activePlugins.length > 0) {
// 禁用所有插件
message.loading({ content: '正在禁用 MCP 插件...', key: 'disable_mcp' });
await Promise.all(activePlugins.map(p => mcpPluginApi.togglePlugin(p.id, false)));
message.success({ content: '已禁用所有 MCP 插件', key: 'disable_mcp' });
// 显示提示弹窗
modal.warning({
title: (
<Space>
<WarningOutlined style={{ color: '#faad14' }} />
<span>API </span>
</Space>
),
centered: true,
content: (
<div style={{ padding: '8px 0' }}>
<Alert
message="检测到您修改了 API 配置(提供商、地址或模型),为确保 MCP 插件正常工作,系统已自动禁用所有插件。"
type="warning"
showIcon
style={{ marginBottom: 16 }}
/>
<div style={{
padding: 12,
background: 'var(--color-info-bg)',
border: '1px solid var(--color-info-border)',
borderRadius: 8
}}>
<Text strong style={{ display: 'block', marginBottom: 8 }}></Text>
<ol style={{ margin: 0, paddingLeft: 20, fontSize: 13 }}>
<li> MCP </li>
<li>"模型能力检查"</li>
<li> Function Calling </li>
</ol>
</div>
</div>
),
okText: '前往 MCP 页面',
cancelText: '稍后处理',
onOk: () => {
navigate('/mcp-plugins');
},
});
}
} catch (err) {
console.error('Failed to disable MCP plugins:', err);
}
}
} catch (error) {
message.error('保存设置失败');
} finally {
@@ -348,10 +424,94 @@ export default function SettingsPage() {
const handlePresetActivate = async (presetId: string, presetName: string) => {
try {
// 获取预设配置用于比较
const preset = presets.find(p => p.id === presetId);
await settingsApi.activatePreset(presetId);
message.success(`已激活预设: ${presetName}`);
loadPresets();
loadSettings(); // 重新加载当前配置
// 检查是否与 MCP 缓存的配置不一致
if (preset) {
const verifiedConfigStr = localStorage.getItem('mcp_verified_config');
let configChanged = false;
if (verifiedConfigStr) {
try {
const verifiedConfig = JSON.parse(verifiedConfigStr);
configChanged =
verifiedConfig.provider !== preset.config.api_provider ||
verifiedConfig.baseUrl !== preset.config.api_base_url ||
verifiedConfig.model !== preset.config.llm_model;
} catch (e) {
console.error('Failed to parse verified config:', e);
configChanged = true; // 解析失败也视为配置变化
}
} else {
// 没有缓存的配置,如果有启用的插件也需要处理
configChanged = true;
}
if (configChanged) {
// 清除 MCP 验证缓存
localStorage.removeItem('mcp_verified_config');
// 检查并禁用所有 MCP 插件
try {
const plugins = await mcpPluginApi.getPlugins();
const activePlugins = plugins.filter(p => p.enabled);
if (activePlugins.length > 0) {
// 禁用所有插件
message.loading({ content: '正在禁用 MCP 插件...', key: 'disable_mcp' });
await Promise.all(activePlugins.map(p => mcpPluginApi.togglePlugin(p.id, false)));
message.success({ content: '已禁用所有 MCP 插件', key: 'disable_mcp' });
// 显示提示弹窗
modal.warning({
title: (
<Space>
<WarningOutlined style={{ color: '#faad14' }} />
<span>API </span>
</Space>
),
centered: true,
content: (
<div style={{ padding: '8px 0' }}>
<Alert
message={`切换到预设「${presetName}」后,API 配置发生了变化。为确保 MCP 插件正常工作,系统已自动禁用所有插件。`}
type="warning"
showIcon
style={{ marginBottom: 16 }}
/>
<div style={{
padding: 12,
background: 'var(--color-info-bg)',
border: '1px solid var(--color-info-border)',
borderRadius: 8
}}>
<Text strong style={{ display: 'block', marginBottom: 8 }}></Text>
<ol style={{ margin: 0, paddingLeft: 20, fontSize: 13 }}>
<li> MCP </li>
<li>"模型能力检查"</li>
<li> Function Calling </li>
</ol>
</div>
</div>
),
okText: '前往 MCP 页面',
cancelText: '稍后处理',
onOk: () => {
navigate('/mcp-plugins');
},
});
}
} catch (err) {
console.error('Failed to disable MCP plugins:', err);
}
}
}
} catch (error) {
message.error('激活失败');
console.error(error);
+56 -6
View File
@@ -1,9 +1,4 @@
import axios from 'axios';
interface MCPPluginSimpleCreate {
config_json: string;
enabled: boolean;
}
import { message } from 'antd';
import { ssePost } from '../utils/sseClient';
import type { SSEClientOptions } from '../utils/sseClient';
@@ -50,8 +45,14 @@ import type {
PresetCreateRequest,
PresetUpdateRequest,
PresetListResponse,
ChapterPlanItem,
} from '../types';
interface MCPPluginSimpleCreate {
config_json: string;
enabled: boolean;
}
const api = axios.create({
baseURL: '/api',
timeout: 120000,
@@ -205,6 +206,36 @@ export const settingsApi = {
suggestions?: string[];
}>('/settings/test', params),
checkFunctionCalling: (params: { api_key: string; api_base_url: string; provider: string; llm_model: string }) =>
api.post<unknown, {
success: boolean;
supported: boolean;
message: string;
response_time_ms?: number;
provider?: string;
model?: string;
details?: {
finish_reason?: string;
has_tool_calls?: boolean;
tool_call_count?: number;
test_tool?: string;
test_prompt?: string;
response_type?: string;
};
tool_calls?: Array<{
id?: string;
type?: string;
function?: {
name: string;
arguments: string;
};
}>;
response_preview?: string;
error?: string;
error_type?: string;
suggestions?: string[];
}>('/settings/check-function-calling', params),
// API配置预设管理
getPresets: () =>
api.get<unknown, PresetListResponse>('/settings/presets'),
@@ -410,7 +441,7 @@ export const outlineApi = {
api.post<unknown, OutlineExpansionResponse>(`/outlines/${outlineId}/expand`, data),
// 根据已有规划创建章节(避免重复AI调用)
createChaptersFromPlans: (outlineId: string, chapterPlans: any[]) =>
createChaptersFromPlans: (outlineId: string, chapterPlans: ChapterPlanItem[]) =>
api.post<unknown, {
outline_id: string;
outline_title: string;
@@ -711,6 +742,25 @@ export const wizardStreamApi = {
options
),
generateCareerSystemStream: (
data: {
project_id: string;
provider?: string;
model?: string;
},
options?: SSEClientOptions
) => ssePost<{
project_id: string;
main_careers_count: number;
sub_careers_count: number;
main_careers: string[];
sub_careers: string[];
}>(
'/api/wizard-stream/career-system',
data,
options
),
generateCompleteOutlineStream: (
data: {
project_id: string;
+1 -1
View File
@@ -356,7 +356,7 @@ export function useChapterSync() {
message.progress || 0
);
}
} else if (message.type === 'content' && message.content) {
} else if ((message.type === 'content' || message.type === 'chunk') && message.content) {
fullContent += message.content;
if (onProgress) {
onProgress(fullContent);
+2 -2
View File
@@ -667,7 +667,7 @@ export interface MCPPlugin {
plugin_name: string;
display_name: string;
description?: string;
plugin_type: 'http' | 'stdio';
plugin_type: 'http' | 'stdio' | 'streamable_http' | 'sse';
category: string;
// HTTP类型字段
@@ -693,7 +693,7 @@ export interface MCPPluginCreate {
plugin_name: string;
display_name?: string;
description?: string;
server_type: 'http' | 'stdio';
server_type: 'http' | 'stdio' | 'streamable_http' | 'sse';
server_url?: string;
command?: string;
args?: string[];